In [None]:
import dataclasses
import glob
import os
import time

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as onp
from skimage import measure

import invrs_opt

from invrs_gym import challenges

In [None]:
# Select the challenge to be solved. The `challenge` object provides the
# component to simulate, along with the loss, distance-to-target, and metrics
# functions used in the rest of this notebook.
challenge = challenges.metagrating()


# Define the loss function; in this case we simply use the default challenge
# loss. Note that the loss function can return auxiliary quantities.
def loss_fn(params):
    """Evaluates the challenge loss for the given parameters.

    Args:
        params: The parameters passed to `challenge.component.response`.

    Returns:
        The scalar loss, and a tuple `(response, aux, distance, metrics)` of
        auxiliary quantities computed alongside it.
    """
    simulated, extras = challenge.component.response(params)
    objective = challenge.loss(simulated)
    auxiliary = (
        simulated,
        extras,
        challenge.distance_to_target(simulated),
        challenge.metrics(simulated, params, extras),
    )
    return objective, auxiliary


# Get the initial parameters, and initialize the optimizer.
params = challenge.component.init(jax.random.PRNGKey(1))
opt = invrs_opt.density_lbfgsb(beta=4)
state = opt.init(params)

# The metagrating challenge can be jit-compiled.
value_and_grad_fn = jax.jit(jax.value_and_grad(loss_fn, has_aux=True))

# Carry out optimization for a fixed number of steps, recording the loss,
# distance to target, and metrics evaluated at each step. After the loop,
# `params` and `metrics` hold the last design that was evaluated.
loss_values = []
distance_values = []
metrics_values = []
for i in range(45):
    # Use a monotonic clock for the interval measurements.
    t0 = time.perf_counter()
    params = opt.params(state)
    # JAX dispatches computations asynchronously, so wait for the results
    # before reading the clock. Otherwise the simulation time could be
    # attributed to the optimizer update that follows.
    (value, (response, aux, distance, metrics)), grad = jax.block_until_ready(
        value_and_grad_fn(params)
    )
    t1 = time.perf_counter()
    state = jax.block_until_ready(
        opt.update(grad=grad, value=value, params=params, state=state)
    )
    t2 = time.perf_counter()

    # Reported times are (loss and gradient evaluation / optimizer update).
    print(
        f"{i:03} ({t1 - t0:.2f}/{t2 - t1:.2f}s): loss={value:.3f}, distance={distance:.3f}"
    )
    loss_values.append(value)
    distance_values.append(distance)
    metrics_values.append(metrics)

In [None]:
# Show how the optimization progressed: the loss on a logarithmic scale (left)
# and the `distance_to_target` (right). Steps where the distance is zero or
# negative, i.e. where the challenge is considered to be solved, are marked
# with blue circles.
step = onp.arange(len(loss_values))
distance_to_target = onp.asarray(distance_values)
mask = distance_to_target <= 0

fig, (ax_loss, ax_distance) = plt.subplots(1, 2, figsize=(8, 4))

ax_loss.semilogy(step, loss_values)
ax_loss.set_xlabel("Step")
ax_loss.set_ylabel("Loss value")

ax_distance.plot(step, distance_to_target)
ax_distance.plot(step[mask], distance_to_target[mask], "bo")
ax_distance.set_xlabel("Step")
ax_distance.set_ylabel("Distance to target")

fig.tight_layout()

In [None]:
# Show the optimized density in grayscale on a fixed [0, 1] color range, with
# its contours outlined in red and the average efficiency in the title.
fig, ax = plt.subplots(figsize=(2, 5))
im = ax.imshow(params.array, cmap="gray")
im.set_clim(0, 1)
ax.axis(False)

contours = measure.find_contours(onp.asarray(params.array))
for contour in contours:
    # Column 1 of each contour is used as x and column 0 as y.
    ax.plot(contour[:, 1], contour[:, 0], "r", lw=1)

ax.set_title(f"eff={metrics['average_efficiency'] * 100:.1f}%")

fig.tight_layout()

In [None]:
# Load and simulate the reference designs, plotting each one in its own
# subplot titled with the file name and its average efficiency.
fnames = sorted(glob.glob("../reference_designs/metagrating/*.csv"))
if not fnames:
    # Fail with a clear message: without any files, the figure requested below
    # would have zero width, and the resulting error would not mention the
    # missing files.
    raise FileNotFoundError(
        "No reference designs found matching "
        "'../reference_designs/metagrating/*.csv'."
    )

plt.figure(figsize=(2 * len(fnames), 5))

for i, fname in enumerate(fnames):
    design = onp.genfromtxt(fname, delimiter=",")
    if design.ndim == 1:
        # A 1D design is repeated along the second axis so that its width
        # matches that of `params`.
        design = onp.broadcast_to(
            design[:, onp.newaxis], (design.size, params.shape[1])
        )
    # Reuse every field of `params` except the array itself.
    reference_params = dataclasses.replace(params, array=design)

    reference_response, reference_aux = challenge.component.response(reference_params)
    reference_metrics = challenge.metrics(
        reference_response, reference_params, reference_aux
    )

    ax = plt.subplot(1, len(fnames), i + 1)
    im = ax.imshow(reference_params.array, cmap="gray")
    # Use the same fixed color range as the optimized-design plot, so that gray
    # levels are comparable between figures.
    im.set_clim([0, 1])
    ax.axis(False)
    contours = measure.find_contours(onp.asarray(reference_params.array))
    for c in contours:
        plt.plot(c[:, 1], c[:, 0], "r", lw=1)
    # `os.path.basename` handles the platform's path separator, unlike
    # splitting on "/".
    plt.title(
        f"{os.path.basename(fname)}\neff={reference_metrics['average_efficiency'] * 100:.1f}%"
    )

plt.tight_layout()