In [None]:
from importlib import reload

import jax
from jax import tree_util
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as onp
import optax

from totypes import types
from invrs_gym.challenge.diffract import metagrating_challenge, splitter_challenge

In [None]:
def optimize(challenge, opt=optax.adam(0.02), num_steps=200, seed=0):
    """Run a simple gradient-based optimization of an `invrs_gym` challenge.

    At every step the loss `challenge.loss(response)` is differentiated with
    respect to the parameters, metrics are computed with `challenge.metrics`,
    an `optax` update is applied, and custom-type parameters are projected back
    onto their own declared bounds: `types.BoundedArray` leaves are clipped to
    `[lower_bound, upper_bound]`, and `types.Density2DArray` leaves are
    symmetrized with `types.symmetrize_density` and then clipped.

    Args:
        challenge: The challenge to optimize. It must provide `component.init`,
            `component.response`, `loss`, and `metrics`.
        opt: The `optax` gradient transformation used to compute updates.
        num_steps: Number of optimization steps; must be at least 1.
        seed: Seed for the `jax.random.PRNGKey` passed to `component.init`.

    Returns:
        A tuple `(loss_values, metrics_values, response, params)`, where:
        - `loss_values` and `metrics_values` each have one entry per step;
        - `response` was computed at the final step, i.e. for the parameters
          *before* the final update;
        - `params` are the parameters *after* the final update.

    Raises:
        ValueError: If `num_steps` is less than 1.
    """
    if num_steps < 1:
        # With zero steps there is no response to return.
        raise ValueError(f"`num_steps` must be at least 1, got {num_steps}.")

    def loss_fn(params):
        response, aux = challenge.component.response(params)
        loss = challenge.loss(response)
        return loss, (response, aux)

    def clip(leaf):
        # Clip the single array leaf of a bounded custom type to the type's own
        # bounds. The remaining fields are carried by `treedef` and are restored
        # unchanged by `tree_unflatten`.
        (value,), treedef = tree_util.tree_flatten(leaf)
        return tree_util.tree_unflatten(
            treedef, (jnp.clip(value, leaf.lower_bound, leaf.upper_bound),)
        )

    def transform(leaf):
        # Project custom-type leaves back into their feasible set after an update.
        if isinstance(leaf, types.BoundedArray):
            return clip(leaf)
        if isinstance(leaf, types.Density2DArray):
            return clip(types.symmetrize_density(leaf))
        return leaf

    @jax.jit
    def step_fn(params, state):
        (value, (response, aux)), grad = jax.value_and_grad(loss_fn, has_aux=True)(
            params
        )
        metrics = challenge.metrics(response, params, aux)
        # Passing `params` lets optimizers that need them (e.g. weight decay)
        # work; optimizers such as `adam` ignore the argument.
        updates, state = opt.update(grad, state, params)
        params = optax.apply_updates(params, updates)
        # Each leaf is bounded only by its *own* declared bounds. A blanket clip
        # of every leaf to [0, 1] would override `BoundedArray` parameters whose
        # bounds lie outside that range.
        params = tree_util.tree_map(
            transform, params, is_leaf=lambda x: isinstance(x, types.CUSTOM_TYPES)
        )
        return params, state, (value, response, aux, metrics)

    params = challenge.component.init(jax.random.PRNGKey(seed))
    state = opt.init(params)

    loss_values = []
    metrics_values = []
    for i in range(num_steps):
        params, state, (value, response, _aux, metrics) = step_fn(params, state)
        print(i, value)
        loss_values.append(value)
        metrics_values.append(metrics)

    return loss_values, metrics_values, response, params

In [None]:
# Short (5-step) optimization run of the diffractive splitter challenge.
splitter_loss_values, splitter_metrics_values, splitter_response, splitter_params = (
    optimize(splitter_challenge.diffractive_splitter(), num_steps=5)
)

In [None]:
# Print every metric recorded at the final optimization step as `name=value`.
final_splitter_metrics = splitter_metrics_values[-1]
for key in final_splitter_metrics:
    value = final_splitter_metrics[key]
    print(f"{key}={value}")

In [None]:
# Left: loss at each optimization step. Right: the optimized density, tiled 2x2
# so the (presumably periodic) unit cell can be seen in context.
fig, (loss_ax, density_ax) = plt.subplots(1, 2, figsize=(8, 5))
loss_ax.plot(splitter_loss_values)
loss_ax.set(title="Splitter loss", xlabel="step", ylabel="loss")
density_im = density_ax.imshow(jnp.tile(splitter_params["density"].array, (2, 2)))
density_ax.set_title("Splitter density (2x2 tiled)")
fig.colorbar(density_im, ax=density_ax)
plt.show()

In [None]:
# Image of the transmission efficiency returned by `extract_orders_for_splitting`
# for the final response. NOTE(review): `(9, 9)` is presumably the extent of the
# extracted order grid along each axis; confirm against invrs_gym.
splitter_orders_efficiency = splitter_challenge.extract_orders_for_splitting(
    splitter_response.transmission_efficiency,
    splitter_response.expansion,
    (9, 9),
)
fig, orders_ax = plt.subplots()
orders_im = orders_ax.imshow(splitter_orders_efficiency)
orders_ax.set_title("Splitter transmission efficiency (extracted orders)")
fig.colorbar(orders_im, ax=orders_ax)
plt.show()

In [None]:
# Short (5-step) optimization run of the metagrating challenge.
(metagrating_loss_values, metagrating_metrics_values,
 metagrating_response, metagrating_params) = optimize(
    metagrating_challenge.metagrating(), num_steps=5
)

In [None]:
# Left: loss at each optimization step. Right: the optimized density, tiled 2x2
# so the (presumably periodic) unit cell can be seen in context.
fig, (loss_ax, density_ax) = plt.subplots(1, 2, figsize=(8, 5))
loss_ax.plot(metagrating_loss_values)
loss_ax.set(title="Metagrating loss", xlabel="step", ylabel="loss")
density_im = density_ax.imshow(jnp.tile(metagrating_params.array, (2, 2)))
density_ax.set_title("Metagrating density (2x2 tiled)")
fig.colorbar(density_im, ax=density_ax)
plt.show()