In [None]:
from importlib import reload

import time
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as onp
import optax

from invrs_gym import challenges

In [None]:
# Run a short Adam optimization on the lightweight ceviche beam splitter and plot
# the loss trajectory as a quick check that optimization makes progress.
LEARNING_RATE = 0.01
NUM_STEPS = 20

bsc = challenges.ceviche_lightweight_beam_splitter()
params = bsc.component.init(jax.random.PRNGKey(0))


def loss_fn(params):
    """Simulates the component and evaluates the challenge loss.

    Args:
        params: Parameters in the structure returned by `bsc.component.init`.

    Returns:
        `(loss, (response, aux))`: the scalar challenge loss (differentiated by
        `jax.value_and_grad`), plus the component response and auxiliary
        outputs, returned as `has_aux` data so they can be inspected.
    """
    response, aux = bsc.component.response(params)
    loss = bsc.loss(response)
    return loss, (response, aux)


# Build the gradient function once: wrapping `loss_fn` is loop-invariant work.
value_and_grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

opt = optax.adam(LEARNING_RATE)
state = opt.init(params)

loss_values = []
for _ in range(NUM_STEPS):
    # `response`/`aux` are evaluated at `params` *before* this step's update, so
    # after the loop they describe the second-to-last parameter iterate.
    (value, (response, aux)), grad = value_and_grad_fn(params)
    loss_values.append(value)
    updates, state = opt.update(grad, state)
    params = optax.apply_updates(params, updates)

fig, ax = plt.subplots()
ax.plot(loss_values)
ax.set(xlabel="step", ylabel="loss", title="Beam splitter optimization");

In [None]:
from invrs_gym.challenges.ceviche import defaults, transmission_loss

# Visualize the orthotope smooth transmission loss over a 2D grid of
# transmission values, with the target box [lower_bound, upper_bound] in red.
t1, t2 = jnp.meshgrid(
    jnp.linspace(0, 1),
    jnp.linspace(0, 1),
    indexing="ij",
)
# Last axis holds the pair (t1, t2) at each grid point: shape (50, 50, 2)
# with linspace's default of 50 samples.
transmission = jnp.stack([t1, t2], axis=-1)

lower_bound = defaults.WAVEGUIDE_BEND_TRANSMISSION_LOWER_BOUND
upper_bound = defaults.WAVEGUIDE_BEND_TRANSMISSION_UPPER_BOUND


def transmission_loss_fn(transmission):
    """Evaluates the orthotope smooth loss for one transmission vector.

    Named distinctly from the optimization cell's `loss_fn` so this cell does
    not silently shadow that definition.

    Args:
        transmission: A single grid point's transmission values (the last axis
            of the `transmission` grid above).

    Returns:
        The loss from `orthotope_smooth_transmission_loss` against the
        module-level `lower_bound` / `upper_bound`.
    """
    return transmission_loss.orthotope_smooth_transmission_loss(
        transmission,
        lower_bound,
        upper_bound,
        transmission_exponent=0.5,
        scalar_exponent=2.0,
    )


# Nested vmaps map over both grid axes, leaving the per-point function unbatched.
loss = jax.vmap(jax.vmap(transmission_loss_fn))(transmission)

fig, ax = plt.subplots()
# NOTE(review): if the loss is exactly zero anywhere (e.g. inside the box),
# log10 yields -inf there and those cells render blank — confirm this is intended.
mesh = ax.pcolor(t1, t2, jnp.log10(loss))
fig.colorbar(mesh, ax=ax, label="log10(loss)")
# Outline of the target box: x uses index 0 of the bounds, y uses index 1,
# matching the stacking order of `transmission`.
ax.plot(
    [lower_bound[0], upper_bound[0], upper_bound[0], lower_bound[0], lower_bound[0]],
    [lower_bound[1], lower_bound[1], upper_bound[1], upper_bound[1], lower_bound[1]],
    "r",
)
ax.set(
    xlabel="transmission[0]",
    ylabel="transmission[1]",
    title="Orthotope smooth transmission loss",
);

In [None]:
import datetime

In [None]:
# Display when this notebook was last run (local time, two-digit year).
format(datetime.datetime.now(), "%y/%m/%d-%H:%M:%S")