# Metagrating

The metagrating challenge entails designing a beam deflector that couples a normally incident plane wave into a plane wave propagating at a polar angle of 50 degrees. This problem was studied in "[Validation and characterization of algorithms and software for photonics inverse design](https://opg.optica.org/josab/ViewMedia.cfm?uri=josab-41-2-A161)" by Chen et al.; the associated [photonics-opt-testbed repo](https://github.com/NanoComp/photonics-opt-testbed) contains several example designs.
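The 50 degree deflection angle is fixed by the grating period through the grating equation. At normal incidence, diffraction order m leaves at an angle θ_m satisfying n_out sin θ_m = m λ / Λ, where λ is the free-space wavelength, Λ the period, and n_out the refractive index of the medium the beam exits into. The sketch below solves this for Λ and lists which orders propagate. The wavelength of 1.05 µm and the exit medium being air (n_out = 1) are illustrative assumptions here, not values read from the challenge code.

```python
import math


def grating_period(wavelength, angle_deg, order=1, n_out=1.0):
    """Period that diffracts a normally incident wave into `order` at `angle_deg`.

    Grating equation at normal incidence: n_out * sin(theta_m) = m * wavelength / period.
    """
    return order * wavelength / (n_out * math.sin(math.radians(angle_deg)))


def propagating_orders(wavelength, period, n_out=1.0, max_order=5):
    """Orders m with |sin(theta_m)| <= 1, i.e. not evanescent in the exit medium."""
    return [
        m
        for m in range(-max_order, max_order + 1)
        if abs(m * wavelength / (n_out * period)) <= 1.0
    ]


# Assumed values for illustration only: 1.05 um wavelength, deflection into air.
example_period = grating_period(wavelength=1.05, angle_deg=50.0)
print(f"period ~ {example_period:.3f} um")
print("propagating orders:", propagating_orders(1.05, example_period))
```

Under these assumptions the period is about 1.371 µm, and only the -1, 0 and +1 orders propagate in air, so an efficient deflector must suppress the 0 and -1 orders.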

## Simulating an existing design

We'll begin by loading, visualizing, and simulating designs from the photonics-opt-testbed repo.

In [None]:
import matplotlib.pyplot as plt
import numpy as onp
from skimage import measure


def load_design(name):
    path = f"../../reference_designs/metagrating/{name}.csv"
    return onp.genfromtxt(path, delimiter=",")


names = ["device1", "device2", "device3", "device4"]
designs = [load_design(name) for name in names]

plt.figure(figsize=(7, 4))
for i, design in enumerate(designs):
    ax = plt.subplot(1, 4, i + 1)
    im = ax.imshow(1 - design, cmap="gray")
    im.set_clim([-2, 1])
    contours = measure.find_contours(design)  # Default level: midway between min and max.
    for c in contours:
        ax.plot(c[:, 1], c[:, 0], "k", lw=1)
    ax.set_xticks([])
    ax.set_yticks([])
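The plots display `1 - design` on the `gray` colormap with color limits `[-2, 1]`. With matplotlib's default linear normalization, a value v is placed at (v - vmin) / (vmax - vmin) along the colormap, where 0 is black and 1 is white. The short calculation below (plain Python that mimics this normalization rather than calling matplotlib) shows that void pixels render white and solid pixels a light gray rather than black, presumably so that the black contour lines remain visible.

```python
def colormap_position(value, vmin=-2.0, vmax=1.0):
    """Position in [0, 1] on the colormap under linear normalization, clipped to range."""
    t = (value - vmin) / (vmax - vmin)
    return min(max(t, 0.0), 1.0)


# The plots show `1 - design`, so void pixels (density 0) have value 1.
for density in (0.0, 0.5, 1.0):
    print(f"density {density}: gray level {colormap_position(1 - density):.3f}")
```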

Now, we'll create a `metagrating` challenge, which provides everything we need to simulate and optimize the metagrating.

In [None]:
from invrs_gym import challenges

challenge = challenges.metagrating()

To simulate the metagrating, we need to pass a `totypes.types.Density2DArray` object to the `challenge.component.response` method. We obtain dummy parameters using `challenge.component.init`, and then overwrite their `array` attribute with the reference design that we want to simulate.

In [None]:
import dataclasses
import jax

dummy_params = challenge.component.init(jax.random.PRNGKey(0))
params = dataclasses.replace(dummy_params, array=load_design("device1"))

# Perform simulation using component response method.
response, aux = challenge.component.response(params)
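`dataclasses.replace` is what makes the substitution above work: it builds a new instance of the same dataclass, copying every field not named in the call. The sketch below uses a hypothetical three-field stand-in, `ToyDensity`, rather than the real `totypes.types.Density2DArray` (which carries more metadata); the point is only that fields such as the bounds survive while `array` is swapped out, and that the original object is left untouched.

```python
import dataclasses

import numpy as onp


@dataclasses.dataclass(frozen=True)
class ToyDensity:
    """Hypothetical stand-in for `totypes.types.Density2DArray` (not the real class)."""

    array: onp.ndarray
    lower_bound: float = 0.0
    upper_bound: float = 1.0


toy_dummy = ToyDensity(array=onp.full((2, 3), 0.5))
toy_design = onp.array([[0.0, 1.0, 1.0], [1.0, 0.0, 0.0]])

# `replace` returns a new instance: only `array` changes, the bounds are copied over.
toy_density_params = dataclasses.replace(toy_dummy, array=toy_design)
print(toy_density_params.lower_bound, toy_density_params.upper_bound)
print(toy_density_params.array)
```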

The `response` contains the transmission and reflection efficiencies into each diffraction order, for both TE- and TM-polarized excitation. We only care about TM transmission into the +1 order, and the `challenge` provides a `metrics` method that extracts this value.

In [None]:
metrics = challenge.metrics(response, params=params, aux=aux)
print(f"TM transmission into +1 order: {metrics['average_efficiency'] * 100:.1f}%")

Now let's simulate all four designs.

In [None]:
for name in names:
    params = dataclasses.replace(dummy_params, array=load_design(name))
    response, aux = challenge.component.response(params)
    metrics = challenge.metrics(response, params=params, aux=aux)
    print(
        f"{name} TM transmission into +1 order: {metrics['average_efficiency'] * 100:.1f}%"
    )

These values are all very close to those reported in the [photonics-opt-testbed](https://github.com/NanoComp/photonics-opt-testbed/tree/main/Metagrating3D), suggesting that our simulation is well converged.

## Metagrating optimization

Now let's optimize a metagrating. As before, we obtain random initial parameters, and then define the loss function. In addition to the loss, the function returns the response and the efficiency as auxiliary outputs, so that we can track how the efficiency improves during optimization.

In [None]:
params = challenge.component.init(jax.random.PRNGKey(0))


def loss_fn(params):
    response, aux = challenge.component.response(params)
    loss = challenge.loss(response)
    metrics = challenge.metrics(response, params=params, aux=aux)
    efficiency = metrics["average_efficiency"]
    return loss, (response, efficiency)
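Returning `loss, (response, efficiency)` is designed to pair with `jax.value_and_grad(..., has_aux=True)`, which is how `step_fn` in the next cell consumes it: JAX differentiates only the first element and passes the second through unchanged. A toy sketch, with a made-up quadratic "efficiency" standing in for the challenge simulation:

```python
import jax
import jax.numpy as jnp


def toy_loss_fn(x):
    # Hypothetical stand-in for the efficiency computed by the challenge.
    efficiency = jnp.sum(x**2)
    loss = -efficiency  # minimizing the loss maximizes the efficiency
    # Return (loss, aux): only `loss` is differentiated; `aux` is passed through.
    return loss, efficiency


x0 = jnp.array([1.0, 2.0])
(toy_value, toy_efficiency), toy_grad = jax.value_and_grad(toy_loss_fn, has_aux=True)(x0)
print(toy_value, toy_efficiency, toy_grad)
```

Here `toy_value` is -5, `toy_efficiency` is 5, and `toy_grad` is the gradient of the loss only, namely -2x.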

To design the metagrating we'll use the `density_lbfgsb` optimizer from the [invrs-opt](https://github.com/invrs-io/opt) package. We initialize the optimizer state, define a `step_fn` that performs a single optimization step, and then call it repeatedly to obtain an optimized design.

In [None]:
import invrs_opt

opt = invrs_opt.density_lbfgsb(beta=4)
state = opt.init(params)  # Initialize optimizer state using the initial parameters.


@jax.jit
def step_fn(state):
    params = opt.params(state)
    (value, (_, efficiency)), grad = jax.value_and_grad(loss_fn, has_aux=True)(params)
    state = opt.update(grad=grad, value=value, params=params, state=state)
    return state, (params, efficiency)


# Call `step_fn` repeatedly to optimize, and store the results of each evaluation.
efficiencies = []
for _ in range(65):
    state, (params, efficiency) = step_fn(state)
    efficiencies.append(efficiency)
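The loop above follows an init / params / update pattern: the optimizer state holds everything needed to produce the next parameters, and each step evaluates the loss and gradient at `opt.params(state)` before calling `opt.update`. To illustrate the pattern in isolation, here is a minimal sketch with a hypothetical projected-gradient-descent optimizer written directly in JAX. It is not how `density_lbfgsb` works internally (L-BFGS-B builds curvature estimates from past gradients); it only shares the same call structure and the [0, 1] bounds on the parameters.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # hypothetical step size for the toy optimizer


def toy_init(params):
    return params  # the toy state is simply the current parameters


def toy_get_params(state):
    return state


def toy_update(grad, state):
    # Take a gradient step, then project back onto the box [0, 1].
    return jnp.clip(state - LEARNING_RATE * grad, 0.0, 1.0)


toy_target = jnp.array([0.2, 0.9, 1.5])  # last entry lies outside [0, 1]


def toy_quadratic_loss(params):
    return jnp.sum((params - toy_target) ** 2)


@jax.jit
def toy_step(state):
    params = toy_get_params(state)
    value, grad = jax.value_and_grad(toy_quadratic_loss)(params)
    return toy_update(grad, state), value


toy_state = toy_init(jnp.full(3, 0.5))
toy_values = []
for _ in range(100):
    toy_state, value = toy_step(toy_state)
    toy_values.append(float(value))
print(toy_state)
```

The unconstrained entries converge to their targets, while the third entry stops at the upper bound of 1.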

Now let's visualize the trajectory of efficiency, and the final design.

In [None]:
ax = plt.subplot(121)
ax.plot(onp.asarray(efficiencies) * 100)
ax.set_xlabel("Step")
ax.set_ylabel("Diffraction efficiency into +1 order (%)")

ax = plt.subplot(122)
im = ax.imshow(1 - params.array, cmap="gray")
im.set_clim([-2, 1])

contours = measure.find_contours(onp.asarray(params.array))
for c in contours:
    ax.plot(c[:, 1], c[:, 0], "k", lw=1)

ax.set_xticks([])
ax.set_yticks([])

print(f"Final efficiency: {efficiencies[-1] * 100:.1f}%")

The final efficiency is around 90%, comparable to the reference designs. Note, however, that the optimized design is not binary: the `density_lbfgsb` optimizer generally does not produce binary solutions, so obtaining a binary design would require a different optimizer.
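To put a number on how far a design is from binary, one option is to measure how many pixels sit near 0 or 1, or to average 4ρ(1 - ρ) over the design, which is 0 for a binary design and 1 when every pixel is 0.5. Both helpers below are hypothetical, written for illustration; they are not part of `invrs_gym` or `invrs_opt`.

```python
import numpy as onp


def near_binary_fraction(density, tol=0.05):
    """Hypothetical helper: fraction of pixels within `tol` of 0 or 1."""
    density = onp.asarray(density)
    near_binary = (density <= tol) | (density >= 1 - tol)
    return float(onp.mean(near_binary))


def gray_level(density):
    """Hypothetical helper: mean of 4*rho*(1 - rho).

    Equals 0 for a perfectly binary design and 1 if every pixel is 0.5.
    """
    density = onp.asarray(density)
    return float(onp.mean(4 * density * (1 - density)))


binary_example = onp.array([[0.0, 1.0], [1.0, 0.0]])
gray_example = onp.full((2, 2), 0.5)
print(near_binary_fraction(binary_example), gray_level(binary_example))
print(near_binary_fraction(gray_example), gray_level(gray_example))
```

Applied to `params.array` after optimization, a `gray_level` well above zero quantifies the non-binary character noted above, and the same functions can be applied to the arrays returned by `load_design` for comparison.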