# Polarization sorter

The polarization sorter challenge is taken from an example in the [fmmax](https://github.com/facebookresearch/fmmax) repo. It entails designing a metasurface located above a group of four pixels in an imaging chip. The metasurface should split light in a polarization-dependent way, so that 0-degree polarization is primarily coupled to the top-left pixel, 45-degree polarization to the top-right pixel, 135-degree polarization to the bottom-left pixel, and 90-degree polarization to the bottom-right pixel. Light is normally incident and has a wavelength of 550 nm.
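The target routing can be written down as data. The sketch below is illustrative and not part of the challenge API: `TARGET_PIXEL`, `jones_vector`, and `overlap` are hypothetical names. It lists the polarizations in the same order used by the plots later in this notebook and describes light linearly polarized at angle θ by its Jones vector (cos θ, sin θ). The power overlap between two such states follows Malus's law, cos²(θa − θb). The 0°/90° and 45°/135° pairs are orthogonal, while 0° and 45°, for example, overlap by one half.

```python
import numpy as np

# Incident polarization angles (degrees) and pixels, in the order used by the
# plots in this notebook. Position k in one tuple is routed to position k in the other.
POLARIZATION_ANGLES = (0, 45, 135, 90)
PIXELS = ("top left", "top right", "bottom left", "bottom right")
TARGET_PIXEL = dict(zip(POLARIZATION_ANGLES, PIXELS))


def jones_vector(angle_deg):
    """Jones vector (Ex, Ey) of light linearly polarized at `angle_deg` degrees."""
    theta = np.deg2rad(angle_deg)
    return np.array([np.cos(theta), np.sin(theta)])


def overlap(angle_a, angle_b):
    """Power overlap |<a|b>|^2 of two linear polarizations, i.e. cos^2(a - b)."""
    return float(np.abs(jones_vector(angle_a) @ jones_vector(angle_b)) ** 2)


for angle, pixel in TARGET_PIXEL.items():
    print(f"{angle:>3} deg -> {pixel} pixel")
print(f"overlap(0, 90) = {overlap(0, 90):.3f}")
print(f"overlap(0, 45) = {overlap(0, 45):.3f}")
```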

## Simulating an existing design

In this notebook, we'll simulate an existing polarization sorter design. Begin by loading the saved design and investigating its parameters.

In [None]:
from totypes import json_utils

with open("../../reference_designs/polarization_sorter/device1.json", "r") as f:
    params = json_utils.pytree_from_json(f.read())

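print("Printing optimization variables:")
for key, value in params.items():
    # Describe sequences by their length and element type; other values by type.
    if isinstance(value, (list, tuple)):
        description = f"length-{len(value)} sequence of {type(value[0])}"
    else:
        description = str(type(value))
    print(f"  {key} is a {description}")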

As we can see, the polarization sorter includes a `density_metasurface` entry whose value is a length-1 sequence containing a `Density2DArray` object. This density defines the pattern of the metasurface layer. The value is a tuple because the `polarization_sorter` challenge can be configured with multiple metasurfaces, which could potentially enable higher performance than a single metasurface. Each metasurface has its own optimizable thickness, so the `thickness_metasurface` value is a tuple of bounded arrays. Finally, the metasurfaces are separated from each other (and from the substrate) by spacer layers, each of which also has its own optimizable thickness.
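Because the parameters nest sequences inside a dict, a recursive walk gives a fuller picture than the one-level loop above. The helper below, `summarize`, is hypothetical and only descends into dicts, lists, and tuples; any other object (such as a `Density2DArray`) is treated as a leaf and reported by its type name, plus its shape if it has a `shape` attribute. The toy dictionary stands in for `params`: its values are plain NumPy arrays rather than totypes objects, it assumes two metasurfaces, and it omits the spacer entries.

```python
import numpy as np


def summarize(tree, prefix="params"):
    """Return one line per leaf of a nested dict/list/tuple: its path and type.

    Hypothetical helper for illustration. Leaves that expose `.shape`
    (e.g. NumPy arrays) are reported together with their shape.
    """
    lines = []
    if isinstance(tree, dict):
        for key, value in tree.items():
            lines.extend(summarize(value, f"{prefix}[{key!r}]"))
    elif isinstance(tree, (list, tuple)):
        for i, value in enumerate(tree):
            lines.extend(summarize(value, f"{prefix}[{i}]"))
    else:
        shape = getattr(tree, "shape", None)
        detail = f" with shape {shape}" if shape is not None else ""
        lines.append(f"{prefix}: {type(tree).__name__}{detail}")
    return lines


# Toy stand-in with two metasurfaces; the values are made up for illustration.
toy = {
    "density_metasurface": (np.zeros((8, 8)), np.zeros((8, 8))),
    "thickness_metasurface": (np.array(0.1), np.array(0.1)),
}
for line in summarize(toy):
    print(line)
```

With the loaded design, `summarize(params)` could be used the same way.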

Next we will plot the metasurface pattern.

In [None]:
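import matplotlib.pyplot as plt
from skimage import measure

density = params["density_metasurface"][0].array

ax = plt.subplot(111)
ax.set_xticks([])
ax.set_yticks([])
# Invert the density so that solid regions (density 1) are drawn gray and void
# regions (density 0) are drawn white; the color limits lighten the gray.
im = ax.imshow(1 - density, cmap="gray")
im.set_clim([-2, 1])
# Outline the boundaries between solid and void regions.
contours = measure.find_contours(density)
for c in contours:
    ax.plot(c[:, 1], c[:, 0], "k", lw=1)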

To simulate the metasurface, we create the polarization sorter challenge and call its `component.response` method:

In [None]:
from invrs_gym import challenges

challenge = challenges.polarization_sorter()

response, aux = challenge.component.response(params)

The `response` is a dataclass that includes the power transmitted into each of the four pixels for each of the four incident polarizations. The power is measured in a monitor plane located 100 nm inside the absorbing silicon that makes up the pixels. Let's make a plot that shows the transmission for each incident polarization angle.
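Before plotting, it can help to reduce the transmission array to a few numbers per polarization. The sketch below assumes that `response.transmission` is a 4×4 array whose rows index the incident polarization (0°, 45°, 135°, 90°) and whose columns index the pixel (top left, top right, bottom left, bottom right), matching the tick labels of the heatmap; under that assumption, the intended pixel for row `i` is column `i`. `summarize_transmission` is a hypothetical helper, and the toy matrix is made up for illustration rather than taken from the design.

```python
import numpy as np

POLARIZATION_ANGLES = (0, 45, 135, 90)
PIXELS = ("top left", "top right", "bottom left", "bottom right")


def summarize_transmission(transmission):
    """Summarize a (4, 4) transmission array, one entry per polarization.

    Assumes rows index the polarization (ordered as POLARIZATION_ANGLES) and
    columns index the pixel (ordered as PIXELS), so row i targets column i.
    """
    transmission = np.asarray(transmission)
    total = transmission.sum(axis=1)  # power reaching any of the four pixels
    on_target = np.diag(transmission)  # power reaching the intended pixel
    brightest = transmission.argmax(axis=1)  # pixel receiving the most power
    return {
        angle: {
            "total": float(total[i]),
            "on_target": float(on_target[i]),
            "brightest_pixel": PIXELS[brightest[i]],
        }
        for i, angle in enumerate(POLARIZATION_ANGLES)
    }


# Made-up matrix: 10% into every pixel plus an extra 30% into the target pixel.
toy = 0.1 * np.ones((4, 4)) + 0.3 * np.eye(4)
summary = summarize_transmission(toy)
for angle, stats in summary.items():
    print(angle, stats)
```

With the simulated design, one could call `summarize_transmission(response.transmission)`.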

In [None]:
import itertools

plt.figure(figsize=(4, 5))
ax = plt.subplot(111)
im = ax.imshow(response.transmission, cmap="coolwarm")
ax.grid(False)
ax.set_xticks([0, 1, 2, 3])
ax.set_xticklabels(
    ["top left pixel", "top right pixel", "bottom left pixel", "bottom right pixel"],
    rotation=45,
    ha="left",
)
ax.set_yticks([0, 1, 2, 3])
ax.set_yticklabels(
    [
        r"$0\degree$ polarization",
        r"$45\degree$ polarization",
        r"$135\degree$ polarization",
        r"$90\degree$ polarization",
    ]
)
ax.xaxis.tick_top()
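# imshow draws transmission[i, j] in row i (y) and column j (x), so each label
# is placed at (x=j, y=i) to line it up with its cell.
for i, j in itertools.product(range(4), range(4)):
    value = response.transmission[i, j] * 100  # transmission in percent
    ax.text(
        j,
        i,
        f"{value:.1f}",
        horizontalalignment="center",
        verticalalignment="center",
        color="k" if (value > 10 and value < 30) else "w",
    )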

We can also compute the challenge metrics, which include quantities such as the efficiency and the polarization ratio.

In [None]:
metrics = challenge.metrics(response, params=params, aux=aux)

print("Metrics from evaluation of polarization sorter design:")
for key, value in metrics.items():
    print(f"  {key} = {value}")

Finally, we can also visualize the fields in the monitor plane.

In [None]:
import numpy as onp

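# Intensity |E|^2, summed over the electric field components (axis 0).
field = onp.sum(onp.abs(aux["efield"]) ** 2, axis=0)
# The last axis of `field` indexes the incident polarization, in this order.
polarization_angles = [0, 45, 135, 90]

fig, axs = plt.subplots(2, 2, figsize=(5, 6), constrained_layout=False)
for i, (ax, angle) in enumerate(zip(axs.flatten(), polarization_angles)):
    im = ax.imshow(field[..., i], cmap="magma")
    # Use a common color scale so the panels can be compared directly.
    im.set_clim([0, onp.amax(field)])
    ax.set_title(f"Polarization\nangle = {angle}$\\degree$")
    ax.set_xticks([])
    ax.set_yticks([])

    # White lines mark the boundaries between the four pixels.
    d = field.shape[0]
    ax.plot([0, d - 1], [d / 2, d / 2], "w", lw=1)
    ax.plot([d / 2, d / 2], [0, d - 1], "w", lw=1)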
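The white lines split each panel into the four pixel quadrants. As a rough, unnormalized companion to the plots, the hypothetical helper below sums a 2D intensity map over each quadrant, following `imshow`'s convention that row 0 is drawn at the top. It is only a sketch: it is not how the challenge computes `response.transmission`, and it assumes the map's rows and columns correspond to the vertical and horizontal directions as displayed.

```python
import numpy as np


def quadrant_powers(intensity):
    """Sum a 2D intensity map over its four quadrants.

    Returns (top left, top right, bottom left, bottom right), with row 0 at
    the top as drawn by imshow. For odd sizes, the middle row/column is
    assigned to the bottom/right quadrants.
    """
    intensity = np.asarray(intensity)
    rows, cols = intensity.shape
    r, c = rows // 2, cols // 2
    return np.array(
        [
            intensity[:r, :c].sum(),
            intensity[:r, c:].sum(),
            intensity[r:, :c].sum(),
            intensity[r:, c:].sum(),
        ]
    )


# Toy map with a bright spot in the top-left and bottom-right corners.
intensity = np.zeros((4, 4))
intensity[0, 0] = 1.0
intensity[3, 3] = 2.0
print(quadrant_powers(intensity))
```

For example, `quadrant_powers(field[..., 0])` would give the quadrant sums for 0° polarization.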

## Polarization sorter optimization

To optimize a polarization sorter, you can follow the recipe from the other challenge notebooks. Note that the polarization sorter challenge is particularly tricky because it has many different optimization variables (metasurface densities as well as metasurface and spacer thicknesses), for which scaling may be important. See the `diffractive_splitter` challenge notebook for a brief discussion of scaling.
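One common way to put variables with very different ranges on an equal footing is to optimize a normalized copy of each bounded variable, mapped affinely onto [0, 1]. This is a general technique offered as an assumption, not necessarily the approach described in the `diffractive_splitter` notebook. By the chain rule, the gradient with respect to the normalized variable equals the physical gradient multiplied by `upper - lower`. The function names and bounds below are hypothetical, and the bounds are passed in explicitly rather than read from the challenge's parameters.

```python
import numpy as np


def to_unit(x, lower, upper):
    """Affinely map values in [lower, upper] onto [0, 1]."""
    return (np.asarray(x) - lower) / (upper - lower)


def from_unit(u, lower, upper):
    """Inverse of `to_unit`: map values in [0, 1] back onto [lower, upper]."""
    return lower + np.asarray(u) * (upper - lower)


def grad_wrt_unit(grad_x, lower, upper):
    """Chain rule: dL/du = dL/dx * dx/du, where dx/du = upper - lower."""
    return np.asarray(grad_x) * (upper - lower)


# Hypothetical bounds for a thickness variable, in arbitrary length units.
lower, upper = 0.05, 0.15
u = to_unit(0.1, lower, upper)
print(u, from_unit(u, lower, upper))
```

An optimizer would then update `u`, and `from_unit` would recover the physical thickness before each simulation.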