In [None]:
from importlib import reload
import dataclasses
import jax
from jax import tree_util
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as onp
import optax
import time

from fmmax import basis, fmm

from totypes import types
from invrs_gym.challenge.diffract import metagrating_challenge, splitter_challenge
from invrs_gym.challenge.extractor import challenge, component

In [None]:
# Import the optimization helpers and force a re-import, so that edits made to
# the `optimizer` module on disk are picked up without restarting the kernel.
from invrs_gym.utils import optimizer
reload(optimizer)


def optimize(challenge, opt, num_steps, response_kwargs=None, use_jit=True):
    """Runs up to `num_steps` optimization steps for `challenge`.

    The step function is obtained from `optimizer.setup_optimization` and is
    called repeatedly, printing the wall-clock time and loss of every step.
    A keyboard interrupt stops the loop early and returns the results
    gathered so far instead of propagating.

    Args:
        challenge: The challenge passed to `optimizer.setup_optimization`.
        opt: The optimizer passed to `optimizer.setup_optimization`.
        num_steps: The maximum number of steps to carry out.
        response_kwargs: Optional dict forwarded to
            `optimizer.setup_optimization`. `None` is treated as an empty dict.
        use_jit: If `True`, the step function is wrapped with `jax.jit`.

    Returns:
        The tuple `(params, response, aux, loss_values, metrics_values)`, where
        `params`, `response` and `aux` come from the last completed step and
        the two lists hold one entry per completed step. If no step completed,
        `params` is the initial value and `response` and `aux` are `None`.
    """
    # Avoid sharing one mutable default dict between calls.
    if response_kwargs is None:
        response_kwargs = {}

    params, state, step_fn = optimizer.setup_optimization(
        challenge=challenge, optimizer=opt, response_kwargs=response_kwargs
    )
    if use_jit:
        step_fn = jax.jit(step_fn)

    # Pre-bind so the return statement is valid even when no step completes
    # (`num_steps == 0`, or an interrupt during the very first step).
    response = None
    aux = None
    loss_values = []
    metrics_values = []
    for i in range(num_steps):
        try:
            t0 = time.time()
            params, state, (value, response, aux, metrics) = step_fn(params, state)
            print(f"step {i} ({time.time() - t0:.2f}s): loss={value}")
            loss_values.append(value)
            metrics_values.append(metrics)
        except KeyboardInterrupt:
            # `KeyboardInterrupt` derives from `BaseException`, so it must be
            # caught explicitly; `except Exception` does not see it.
            print("Terminating optimization")
            break
        except Exception as e:
            # An interrupt may also surface wrapped inside another exception
            # whose message mentions it; treat that as a request to stop too.
            if "KeyboardInterrupt" in str(e):
                print("Terminating optimization")
                break
            raise

    return params, response, aux, loss_values, metrics_values

In [None]:
# Optimize the metagrating challenge with Adam. The names are unpacked in the
# order returned by `optimize`: params, response, aux, loss values, metrics.
(
    metagrating_params,
    metagrating_response,
    metagrating_aux,
    metagrating_loss_values,
    metagrating_metrics_values,
) = optimize(
    metagrating_challenge.metagrating(),
    opt=optax.adam(0.02),
    num_steps=200,
)

In [None]:
# Left panel: loss and the "distance_to_window" metric for every step.
# Right panel: the optimized parameter array, with a colorbar.
fig, (history_ax, design_ax) = plt.subplots(1, 2, figsize=(8, 5))

history_ax.plot(metagrating_loss_values)
history_ax.plot(
    [step_metrics["distance_to_window"] for step_metrics in metagrating_metrics_values]
)

design_image = design_ax.imshow(jnp.tile(metagrating_params.array, (1, 1)))
fig.colorbar(design_image, ax=design_ax)

In [None]:
# Optimize the diffractive splitter challenge for 100 Adam steps; `optimize`
# returns params, response, aux, loss values and metrics, in that order.
splitter_params, splitter_response, splitter_aux, splitter_loss_values, splitter_metrics_values = optimize(
    challenge=splitter_challenge.diffractive_splitter(),
    num_steps=100,
    opt=optax.adam(learning_rate=0.02),
)

In [None]:
# Text summary: the optimized thickness, then every metric of the last step.
print(f"thickness={splitter_params['thickness'].array}")
for key, value in splitter_metrics_values[-1].items():
    print(f"{key}={value}")

# Three panels: loss history, the density tiled 2x2, and the orders extracted
# from the transmission efficiency of the final response.
fig, axes = plt.subplots(1, 3, figsize=(12, 5))

axes[0].plot(splitter_loss_values)

axes[1].imshow(jnp.tile(splitter_params["density"].array, (2, 2)))

axes[2].imshow(
    splitter_challenge.extract_orders_for_splitting(
        splitter_response.transmission_efficiency,
        splitter_response.expansion,
        (9, 9),
    )
)

In [None]:
# Build the photon extractor challenge and optimize it with Adam.
pec = challenge.photon_extractor()

# `optimize` returns (params, response, aux, loss_values, metrics_values); the
# targets below must follow that order so that `extractor_params` really holds
# the optimized parameters used in the call further down.
(
    extractor_params,
    extractor_response,
    extractor_aux,
    extractor_loss_values,
    extractor_metrics_values,
) = optimize(challenge=pec, opt=optax.adam(0.02), num_steps=100)

# Re-evaluate the component at the optimized parameters with field computation
# enabled, replacing the aux value from the optimization loop.
# NOTE(review): the result is stored as-is under `extractor_aux`; if
# `component.response` returns a `(response, aux)` pair it needs unpacking
# here - confirm against the component API.
extractor_aux = pec.component.response(extractor_params, compute_fields=True)

In [None]:
# Pull the field sampling coordinates and the electric field components out of
# the aux data computed for the optimized extractor.
x, y, z = extractor_aux["field_coordinates"]
ex, ey, ez = extractor_aux["efield"]

print(extractor_response)

fig, axes = plt.subplots(1, 2, figsize=(10, 4))

# Field magnitude: square root of the summed squared component magnitudes.
xplot, zplot = jnp.meshgrid(x, z, indexing="ij")
field_plot = jnp.sqrt(sum(jnp.abs(component) ** 2 for component in (ex, ey, ez)))

# Left panel: slice at index 1 of the last axis, drawn over the (x, z) grid
# with equal scaling, hidden axes and a flipped vertical direction.
ax = axes[0]
ax.pcolormesh(xplot, zplot, field_plot[:, :, 1], cmap="magma")
ax.axis("equal")
ax.axis("off")
lower, upper = ax.get_ylim()
ax.set_ylim((upper, lower))

# Right panel: the optimized parameter array in grayscale.
ax = axes[1]
ax.imshow(extractor_params.array, cmap="gray")
ax.axis("off")