In [None]:
from importlib import reload

import jax
from jax import tree_util
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as onp
import optax

from totypes import types
from invrs_gym.loss import transmission_loss
from invrs_gym.challenge.metagrating import challenge

In [None]:
# Construct the metagrating challenge once; `loss_fn` below closes over it.
metagrating = challenge.metagrating()

def loss_fn(params):
    """Evaluate the challenge loss for `params`.

    Returns a pair `(loss, (response, aux))` so that it can be wrapped with
    `jax.value_and_grad(..., has_aux=True)`: only `loss` is differentiated,
    while the component response and its auxiliary output are passed through
    for computing metrics.
    """
    component_response, component_aux = metagrating.component.response(params)
    return metagrating.loss(component_response), (component_response, component_aux)

# Optimization settings, kept together so they can be tuned without editing the loop.
SEED = 0
LEARNING_RATE = 0.02
NUM_STEPS = 300
PRINT_EVERY = 10  # log every N steps (and the final step) instead of all 300

params = metagrating.component.init(jax.random.PRNGKey(SEED))

opt = optax.adam(LEARNING_RATE)
state = opt.init(params)

# The gradient transform does not change between iterations, so build it once.
value_and_grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

loss_values = []
distance_values = []
for step in range(NUM_STEPS):
    (value, (response, aux)), grad = value_and_grad_fn(params)
    metrics = metagrating.metrics(response, params=params, aux=aux)
    loss_values.append(value)
    distance_values.append(metrics["distance_to_window"])
    if step % PRINT_EVERY == 0 or step == NUM_STEPS - 1:
        # `value` is the scalar loss returned by value_and_grad, so float() is safe.
        print(f"step {step:4d}  loss {float(value):.6f}  "
              f"distance_to_window {metrics['distance_to_window']}")
    updates, state = opt.update(grad, state)
    params = optax.apply_updates(params, updates)
    # Symmetrize the density after the raw Adam step.
    params = types.symmetrize_density(params)
    # Clip every array leaf of the params pytree back into [0, 1].
    params = tree_util.tree_map(lambda x: jnp.clip(x, 0, 1), params)

In [None]:
fig, (ax_trace, ax_design) = plt.subplots(1, 2, figsize=(8, 5))

# Left: optimization trace. Both curves share the step axis, so label them.
ax_trace.plot(loss_values, label="loss")
ax_trace.plot(distance_values, label="distance to window")
ax_trace.set(xlabel="step", title="Optimization trace")
ax_trace.legend()

# Right: the density array of the final (post-loop) params.
design_image = ax_design.imshow(params.array)
ax_design.set(title="Final density")
fig.colorbar(design_image, ax=ax_design)

plt.show()