In [None]:
import jax
import jax.numpy as jnp

import matplotlib.pyplot as plt
import numpy as onp
from skimage import measure
import gifcm

import invrs_gym
import invrs_opt

In [None]:
# Build the lightweight ceviche waveguide-bend challenge from invrs_gym. The
# resulting `challenge` object is used below for the simulated component, the
# loss, and the auxiliary evaluation helpers.
challenge = invrs_gym.challenges.ceviche_lightweight_waveguide_bend()

def loss_fn(params):
    """Evaluate the challenge loss for a set of design parameters.

    Args:
        params: Design parameters in the form accepted by
            `challenge.component.response`.

    Returns:
        A `(loss, (response, distance, aux))` tuple. `loss` is the value of
        `challenge.loss(response)` and is the quantity differentiated by
        `jax.value_and_grad(..., has_aux=True)`. The second element carries
        the component response, the result of
        `challenge.distance_to_target(response)`, and the auxiliary output of
        the component response call.
    """
    response, aux = challenge.component.response(params)
    loss = challenge.loss(response)
    distance = challenge.distance_to_target(response)
    # NOTE: `challenge.metrics(...)` was previously evaluated here but its
    # result was never used or returned, so it only added work to every
    # traced evaluation. It has been dropped.
    return loss, (response, distance, aux)

# Differentiate the loss with respect to `params`; the auxiliary tuple returned
# by `loss_fn` is passed through alongside the loss value.
value_and_grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

opt = invrs_opt.density_lbfgsb(beta=4)

# Seed the optimizer state with the component's initial parameters.
init_key = jax.random.PRNGKey(0)
params = challenge.component.init(init_key)
state = opt.init(params)

# Number of optimization steps to run (one animation frame is recorded per step).
num_steps = 36

# Per-step record of (step index, loss value, parameters, aux output).
data = []
for i in range(num_steps):
    params = opt.params(state)
    (value, loss_aux), grad = value_and_grad_fn(params)
    response, distance, aux = loss_aux
    state = opt.update(grad=grad, value=value, params=params, state=state)
    data.append((i, value, params, aux))

In [None]:
# Render one animation frame per recorded optimization step: the design density
# on the left, and the real part of the simulated field on the right.
anim = gifcm.AnimatedFigure(figure=plt.figure(figsize=(8, 4)))

for (i, _, params, aux) in data:
    with anim.frame():
        # Compute the density, using a method specific to the underlying ceviche model.
        density = challenge.component.ceviche_model.density(params.array)

        ax = plt.subplot(121)
        ax.imshow(density, cmap="gray")
        ax.text(100, 90, f"step {i:02}", color="w", fontsize=20)
        ax.axis(False)
        # Reverse both axis directions so the image is displayed flipped in x and y.
        ax.set_xlim(ax.get_xlim()[::-1])
        ax.set_ylim(ax.get_ylim()[::-1])

        # Plot the field, which is a part of the `aux` returned with the challenge response.
        # The field will be overlaid with contours of the design density.
        field = onp.real(aux["fields"])
        field = field[0, 0, :, :]  # First wavelength, first excitation port.
        contours = measure.find_contours(density)

        ax = plt.subplot(122)
        im = ax.imshow(field, cmap="bwr")
        # Use limits symmetric about zero based on the largest *magnitude*, so
        # that zero maps to the center of the diverging colormap and neither
        # the positive nor the negative extreme of the field is clipped.
        field_max = onp.amax(onp.abs(field))
        im.set_clim(-field_max, field_max)
        for c in contours:
            # Contour vertices are (row, column); plot as (x=column, y=row).
            ax.plot(c[:, 1], c[:, 0], "k", lw=1)
        ax.axis(False)
        ax.set_xlim(ax.get_xlim()[::-1])
        ax.set_ylim(ax.get_ylim()[::-1])

anim.save_gif("waveguide_bend.gif", duration=200)