In [None]:
import time

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as onp
from skimage import measure

import invrs_opt

from invrs_gym import challenges

In [None]:
# The benchmark being optimized in this notebook.
challenge = challenges.ceviche_lightweight_beam_splitter()


def loss_fn(params):
    """Compute the challenge's default loss for `params`.

    The return value is shaped for `jax.value_and_grad(..., has_aux=True)`:
    a scalar loss followed by auxiliary quantities that are passed through
    alongside the loss without being differentiated.

    Returns:
        `(loss, (response, aux, distance, metrics))`, where `response` and
        `aux` come from the challenge component, `distance` is the challenge's
        `distance_to_target` for the response, and `metrics` is the
        challenge's `metrics` for the response, parameters and `aux`.
    """
    response, aux = challenge.component.response(params)
    loss = challenge.loss(response)
    extras = (
        response,
        aux,
        challenge.distance_to_target(response),
        challenge.metrics(response, params, aux),
    )
    return loss, extras


# Get the initial parameters, and initialize the optimizer.
params = challenge.component.init(jax.random.PRNGKey(0))
opt = invrs_opt.density_lbfgsb(beta=4)
state = opt.init(params)

value_and_grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

# Carry out optimization for a fixed number of steps, recording the loss,
# distance to target and metrics at every step.
num_steps = 25
loss_values = []
distance_values = []
metrics_values = []
for i in range(num_steps):
    t0 = time.time()
    params = opt.params(state)
    # JAX dispatches work asynchronously: wait for the results before reading
    # the clock, otherwise the evaluation time is under-reported and silently
    # attributed to the optimizer update that follows.
    outputs = value_and_grad_fn(params)
    jax.block_until_ready(outputs)
    (value, (response, aux, distance, metrics)), grad = outputs
    t1 = time.time()
    state = opt.update(grad=grad, value=value, params=params, state=state)
    jax.block_until_ready(state)
    t2 = time.time()

    # Timings printed as (loss-and-gradient evaluation / optimizer update).
    print(
        f"{i:03} ({t1 - t0:.2f}/{t2 - t1:.2f}s): "
        f"loss={value:.3f}, distance={distance:.3f}"
    )
    loss_values.append(value)
    distance_values.append(distance)
    metrics_values.append(metrics)

# NOTE: after the loop, `params`, `response` and `aux` belong to the last
# *evaluated* step; `state` has received one further update. The plotting code
# uses `params` together with `aux`, so they describe the same design.

In [None]:
# Plot the loss vs. step, and the `distance_to_target`. When the distance is
# zero or negative, the challenge is considered to be solved.
distance_to_target = onp.asarray(distance_values)
step = onp.arange(len(loss_values))
mask = onp.less_equal(distance_to_target, 0)

fig, (ax_loss, ax_distance) = plt.subplots(1, 2, figsize=(8, 4))

# Left: loss on a logarithmic y-axis.
ax_loss.semilogy(step, loss_values)
ax_loss.set_xlabel("Step")
ax_loss.set_ylabel("Loss value")

# Right: distance to target, with solved steps (distance <= 0) marked by
# blue circles.
ax_distance.plot(step, distance_to_target)
ax_distance.plot(step[mask], distance_to_target[mask], "bo")
ax_distance.set_xlabel("Step")
ax_distance.set_ylabel("Distance to target")

plt.tight_layout()

In [None]:
plt.figure(figsize=(8, 5))

# Left panel: the design density, computed with the `density` method of the
# underlying ceviche model.
density = challenge.component.ceviche_model.density(params.array)

ax = plt.subplot(121)
ax.imshow(density, cmap="gray")
ax.axis(False)

# Right panel: the real part of the field, which is a part of the `aux`
# returned with the challenge response, overlaid with contours of the density.
field = onp.real(aux["fields"])
field = field[0, 0, :, :]  # First wavelength, first excitation port.
# No explicit `level` is passed; recent scikit-image versions then trace the
# midpoint between the density's minimum and maximum values.
contours = measure.find_contours(onp.asarray(density))

ax = plt.subplot(122)
im = ax.imshow(field, cmap="bwr")
# Color limits symmetric about zero, using the largest *magnitude* so that zero
# maps to the center of the diverging colormap and neither sign saturates.
field_limit = onp.amax(onp.abs(field))
im.set_clim([-field_limit, field_limit])
for c in contours:
    # Contour points are (row, col); plot as (x=col, y=row) on this axes.
    ax.plot(c[:, 1], c[:, 0], "k", lw=1)
ax.axis(False)

plt.tight_layout()