In [None]:
import dataclasses
import glob
import time

import jax

# The sorter challenge appears to require 64-bit precision to obtain consistent
# results across platforms.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as onp
from skimage import measure

import invrs_opt

from invrs_gym import challenges
from invrs_gym import utils

In [None]:
# The polarization sorter challenge optimizes both film thicknesses and
# a metasurface density. Gradients with respect to the film thicknesses have
# far larger magnitude than gradients with respect to individual metasurface
# pixel densities. To ensure the optimizer does not focus only on the film
# thicknesses, we rescale the density so that its gradient becomes larger.

def rescale_density(density, scale):
    """Linearly map a density's array from its bounds onto ``[0, scale]``.

    Shrinking the numeric range of the density array (e.g. ``scale=0.001``)
    enlarges the gradient of the loss with respect to that array, which keeps
    the optimizer from focusing only on the film thicknesses.

    Args:
        density: A dataclass instance with ``array``, ``lower_bound`` and
            ``upper_bound`` fields (presumably an ``invrs_gym`` density
            object -- TODO confirm). It is not modified.
        scale: Width of the rescaled range; must be nonzero, and the input
            bounds must differ (otherwise division by zero).

    Returns:
        A copy of ``density`` whose array lies in ``[0, scale]``, with
        ``lower_bound=0`` and ``upper_bound=scale`` so that the bounds
        describe the rescaled array.
    """
    span = density.upper_bound - density.lower_bound
    # Out-of-place division: unlike `/=`, this also works when the array has
    # an integer dtype, and never touches the caller's array.
    rescaled_array = (density.array - density.lower_bound) / (span / scale)
    return dataclasses.replace(
        density,
        array=rescaled_array,
        lower_bound=0,
        # The array now spans [0, scale]. The former value `span * scale`
        # only coincided with this when the original span was exactly 1.
        upper_bound=scale,
    )

def density_initializer(key, seed_density):
    """Build a noisy initial density and shrink its numeric range.

    The density starts at half of its allowed range (``relative_mean=0.5``)
    with noise of relative amplitude 0.1, and is then passed through
    ``rescale_density`` with a scale of 1e-3 so that gradients with respect to
    the density are amplified relative to the thickness gradients.

    Args:
        key: PRNG key forwarded to the noisy initializer.
        seed_density: Template density forwarded to the noisy initializer.

    Returns:
        The rescaled density object.
    """
    noisy = utils.initializers.noisy_density_initializer(
        relative_noise_amplitude=0.1,
        relative_mean=0.5,
        seed_density=seed_density,
        key=key,
    )
    return rescale_density(noisy, 1e-3)

# Select the challenge to be solved: the polarization sorter, using the custom
# rescaled density initializer and a minimum feature width and spacing of 8
# (presumably in pixels -- confirm against the invrs_gym docs).
challenge = challenges.polarization_sorter(
    minimum_spacing=8,
    minimum_width=8,
    density_initializer=density_initializer,
)

# Define the loss function; in this case we simply use the default challenge
# loss. Note that the loss function can return auxiliary quantities.
def loss_fn(params):
    """Simulate ``params`` and return the challenge loss plus extra outputs.

    Returns:
        ``(loss, (response, metrics, aux))`` -- the shape expected by
        ``jax.value_and_grad(..., has_aux=True)``. ``aux`` is the auxiliary
        output of the component simulation, and ``metrics`` is computed by
        the challenge from the response, the params and that aux.
    """
    response, sim_aux = challenge.component.response(params)
    metrics = challenge.metrics(response, params, sim_aux)
    extras = (response, metrics, sim_aux)
    return challenge.loss(response), extras


# Get the initial parameters, and initialize the optimizer.
seed = 2
params = challenge.component.init(jax.random.PRNGKey(seed))
opt = invrs_opt.density_lbfgsb(beta=2)
state = opt.init(params)

# The polarization sorter challenge can be jit-compiled. `has_aux=True` makes
# the gradient function also return `(response, metrics, aux)` from `loss_fn`.
value_and_grad_fn = jax.jit(jax.value_and_grad(loss_fn, has_aux=True))

# Carry out optimization for a fixed number of steps. Each iteration prints
# the simulation time / optimizer-update time, the loss, and the total power
# (reflection plus summed transmission). The first iteration's simulation time
# includes jit compilation.
num_steps = 150
loss_values = []
for i in range(num_steps):
    t0 = time.perf_counter()
    params = opt.params(state)
    # JAX dispatches work asynchronously: without blocking, `t1 - t0` would
    # only measure dispatch, and the real simulation wait would be charged to
    # `opt.update` when it first reads the gradient.
    (value, (response, metrics, aux)), grad = jax.block_until_ready(
        value_and_grad_fn(params)
    )
    t1 = time.perf_counter()
    state = opt.update(grad=grad, value=value, params=params, state=state)

    print(
        f"{i:03} ({t1 - t0:.2f}/{time.perf_counter() - t1:.2f}s): loss={value:.3f}, "
        f"power={response.reflection + jnp.sum(response.transmission, axis=-1)}"
    )
    loss_values.append(value)

# Note: after the loop, `params`, `response`, `metrics` and `aux` describe the
# last *evaluated* parameters; the result of the final `opt.update` lives only
# in `state` and has not been simulated.

In [None]:
# Plot the initial and optimized parameters, and the loss trajectory.
# `params` was overwritten by the optimization loop, so rebuild the starting
# parameters from the same seed (assuming `init` is deterministic in its key).
initial_key = jax.random.PRNGKey(seed)
initial_params = challenge.component.init(initial_key)

from invrs_gym.challenges.sorter import common

plt.figure(figsize=(10, 3))
num_panels = 1 + len(params["density_metasurface"])

# Leftmost panel: the loss recorded at each optimization step.
plt.subplot(1, num_panels, 1)
plt.plot(loss_values)

# One grayscale panel per metasurface density, with colour limits fixed to
# [0, 1]; `rescaled_density_array(..., 0, 1)` presumably maps the density
# onto that range -- confirm in invrs_gym.utils.transforms.
for layer, layer_density in enumerate(params["density_metasurface"]):
    ax = plt.subplot(1, num_panels, layer + 2)
    im = plt.imshow(
        utils.transforms.rescaled_density_array(layer_density, 0, 1),
        cmap="gray",
    )
    im.set_clim([0, 1])
    ax.axis(False)
    ax.set_title(f"Metasurface #{layer}")
    plt.colorbar()
    

# Print the initial and optimized thicknesses: first the cap layer, then each
# (metasurface, spacer) pair, all to three decimal places.
print(
    f"           cap initial={initial_params['thickness_cap'].array:.3f}, "
    f"final={params['thickness_cap'].array:.3f}"
)

layer_thicknesses = zip(
    initial_params["thickness_metasurface"],
    params["thickness_metasurface"],
    initial_params["thickness_spacer"],
    params["thickness_spacer"],
)
for layer, (meta_init, meta_final, spacer_init, spacer_final) in enumerate(
    layer_thicknesses
):
    print(f"metasurface #{layer} initial={meta_init.array:.3f}, final={meta_final.array:.3f}")
    print(f"     spacer #{layer} initial={spacer_init.array:.3f}, final={spacer_final.array:.3f}")

In [None]:
# Plot the transmission into each of the four quadrants

plt.figure(figsize=(8, 3))

# Left half: the transmission matrix, with colour limits fixed to [0, 0.5].
plt.subplot(121)
plt.imshow(response.transmission)
plt.clim([0, 0.5])
plt.colorbar()

# Right half: a 2x2 grid (subplots 3, 4, 7, 8 of a 2x4 layout) showing the
# four trailing-axis slices of the z Poynting flux (presumably one per output
# channel -- confirm against the challenge's aux definition).
sz = aux["poynting_flux_z"]
for channel, subplot_spec in enumerate((243, 244, 247, 248)):
    ax = plt.subplot(subplot_spec)
    ax.imshow(sz[..., channel])
    ax.axis(False)
plt.tight_layout()

In [None]:
# Check for convergence by re-simulating the optimized structure for various
# expansions, i.e. with fewer and with more terms included.

from fmmax import basis

approximate_num_terms = [400, 800, 1200, 1600, 2000, 2400]


def _circular_expansion(num_terms):
    """Build a circularly truncated Fourier expansion for the square unit cell.

    The lattice vectors are the x and y unit vectors scaled by the
    challenge's pitch; ``num_terms`` is the approximate number of terms.
    """
    lattice = basis.LatticeVectors(
        u=basis.X * challenge.component.spec.pitch,
        v=basis.Y * challenge.component.spec.pitch,
    )
    return basis.generate_expansion(
        primitive_lattice_vectors=lattice,
        approximate_num_terms=num_terms,
        truncation=basis.Truncation.CIRCULAR,
    )


# Re-simulate the optimized parameters once per expansion size; each entry is
# the `(response, aux)` pair returned by the component.
responses = [
    challenge.component.response(
        params=params,
        expansion=_circular_expansion(num_terms),
    )
    for num_terms in approximate_num_terms
]

In [None]:
# Plot transmission and z Poynting flux for each expansion size, one figure
# per size. The loop variables deliberately do not reuse the names `response`
# and `aux`: those globals hold the optimized results, and re-running the
# earlier plotting cell after this one must not silently show the last
# convergence-study simulation instead.
for num_terms, (conv_response, conv_aux) in zip(approximate_num_terms, responses):
    plt.figure(figsize=(8, 3))

    # Left half: transmission matrix, colour limits fixed to [0, 0.5].
    plt.subplot(121)
    plt.imshow(conv_response.transmission)
    plt.clim([0, 0.5])
    plt.colorbar()
    plt.title(f"approximate_num_terms={num_terms}", fontsize=10)

    # Right half: the four trailing-axis slices of the z Poynting flux in a
    # 2x2 grid (subplots 3, 4, 7, 8 of a 2x4 layout).
    sz = conv_aux["poynting_flux_z"]
    for channel, subplot_spec in enumerate((243, 244, 247, 248)):
        ax = plt.subplot(subplot_spec)
        ax.imshow(sz[..., channel])
        ax.axis(False)
    plt.tight_layout()