In [None]:
import dataclasses
import glob
import os
import time
from importlib import reload

import jax
jax.config.update("jax_enable_x64", False)
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as onp
from skimage import measure

import invrs_opt

from invrs_gym.challenges.metalens import component, challenge

In [None]:
def plot_fields(params, response, aux, spec):
    """Plot the electric-field magnitude for each wavelength with the design outline.

    One subplot is drawn per wavelength, all sharing a common color scale.

    Args:
        params: Design parameters; contours of `params.array` are overlaid on
            every subplot.
        response: Simulation response; `response.wavelength.size` sets the
            number of subplots.
        aux: Auxiliary simulation outputs containing "efield" (three field
            components) and "field_coordinates"; presumably produced with
            `compute_fields=True`.
        spec: Metalens spec providing `grid_spacing`, `focus_offset` and
            `thickness_ambient`, used to position the design contours.
    """
    cmaps = [
        plt.cm.colors.LinearSegmentedColormap.from_list('b', ["w", "b"], N=256),
        plt.cm.colors.LinearSegmentedColormap.from_list('g', ["w", "g"], N=256),
        plt.cm.colors.LinearSegmentedColormap.from_list('r', ["w", "r"], N=256),
    ]

    ex, ey, ez = aux["efield"]
    x, _, z = aux["field_coordinates"]
    xplot, zplot = jnp.meshgrid(x[:, 0], z, indexing="ij")

    # Magnitude of the full electric field: sqrt(|Ex|^2 + |Ey|^2 + |Ez|^2).
    field_plot = jnp.sqrt(jnp.abs(ex)**2 + jnp.abs(ey)**2 + jnp.abs(ez)**2)
    # NOTE(review): assumes axis 0 indexes wavelength (asserted below) and that
    # the fixed `0` indices select a single x-z slice — confirm against the
    # component's field layout.
    field_plot = field_plot[:, :, 0, :, 0]
    # A shared color limit keeps intensities comparable across wavelengths.
    maxval = onp.amax(onp.abs(field_plot))

    num_wavelengths = response.wavelength.size
    assert num_wavelengths == field_plot.shape[0]

    # The design outline is identical for every subplot, so compute it once.
    contours = measure.find_contours(onp.asarray(params.array))

    plt.figure(figsize=(12, 3 * num_wavelengths))
    for i in range(num_wavelengths):
        # Cycle through the colormaps so any number of wavelengths is plotted.
        cmap = cmaps[i % len(cmaps)]
        ax = plt.subplot(num_wavelengths, 1, i + 1)
        im = ax.pcolormesh(xplot, zplot, field_plot[i, ...], cmap=cmap)
        im.set_clim([0, maxval])
        ax.axis("equal")

        for contour in contours:
            # Convert contour pixel indices to physical units; the z shift
            # presumably aligns the design with the field coordinate frame.
            contour_x = contour[:, 0] * spec.grid_spacing
            contour_z = (
                contour[:, 1] * spec.grid_spacing
                + spec.focus_offset
                + spec.thickness_ambient
            )
            ax.plot(contour_x, contour_z, 'k')

        ax.set_xlim([onp.amin(xplot), onp.amax(xplot)])
        # Reversed limits: z increases downward in the plot.
        ax.set_ylim([onp.amax(zplot), onp.amin(zplot)])
        ax.axis(False)

In [None]:
# Build the density-based L-BFGS-B optimizer and the metalens challenge, then
# derive PRNG keys from a fixed seed so the run is reproducible.
# NOTE(review): `opt_key` is not used by any later cell; only `params_key`
# seeds the initial design.
opt = invrs_opt.density_lbfgsb(beta=4)
mc = challenge.metalens()
root_key = jax.random.PRNGKey(0)
opt_key, params_key = jax.random.split(root_key)

@jax.jit
def step_fn(state):
    """Run one optimization step on the metalens challenge.

    Evaluates the challenge loss and its gradient at the optimizer's current
    parameters, then advances the optimizer state.

    Args:
        state: Optimizer state produced by `opt.init` or a previous step.

    Returns:
        `(next_state, (params, response, value, distance, metrics, aux))`.
        `params` are the parameters evaluated during this step (before the
        update), and `value` is the loss at those parameters.
    """

    def _loss_with_aux(candidate):
        response, aux = mc.component.response(candidate)
        loss = mc.loss(response)
        extras = (
            response,
            mc.distance_to_target(response),
            mc.metrics(response, candidate, aux),
            aux,
        )
        return loss, extras

    current_params = opt.params(state)
    loss_and_grad = jax.value_and_grad(_loss_with_aux, has_aux=True)
    (value, extras), grad = loss_and_grad(current_params)
    response, distance, metrics, aux = extras
    next_state = opt.update(grad=grad, value=value, params=current_params, state=state)
    return next_state, (current_params, response, value, distance, metrics, aux)

# Initialize the optimizer from a random initial design and run the loop.
state = opt.init(mc.component.init(params_key))

NUM_STEPS = 100  # Number of optimization steps.

loss_values = []
for i in range(NUM_STEPS):
    t0 = time.time()
    # JAX dispatches work asynchronously; block until the step's outputs are
    # ready so the reported time reflects the actual cost of the step.
    state, (params, response, loss_value, distance, metrics, aux) = jax.block_until_ready(
        step_fn(state)
    )
    elapsed = time.time() - t0
    print(f"{i:03} ({elapsed:.2f}s): loss={loss_value:.3f}, binarization_degree={metrics['binarization_degree']:.3f}")
    # Store plain Python floats rather than device arrays.
    loss_values.append(float(loss_value))

# Note: after the loop, `params` holds the design evaluated in the last step
# (i.e. before the final optimizer update), matching the last logged loss.

In [None]:
# Re-simulate the design from the last optimization step with field
# computation enabled, so the fields can be visualized.
response, aux = mc.component.response(params, compute_fields=True)

In [None]:
# Plot the simulated fields for each wavelength, overlaid with the design outline.
plot_fields(params, response, aux, spec=mc.component.spec)

In [None]:
def load_reference_design(path):
    """Load a reference metalens design from a CSV density file.

    Args:
        path: Path to a comma-delimited density file. The name of its parent
            directory is returned as the polarization string, and the file
            name prefix ("Mo", "Rasmus" or "wenjin") selects the grid spacing.

    Returns:
        A tuple `(density_array, grid_spacing, polarization_str)`. The density
        array is flipped along axis 1 and cropped to the slices whose mean
        density lies strictly between 0 and 1. It is then bordered by one
        all-zero slice at the start of axis 1 and one all-one slice at the end.

    Raises:
        ValueError: If the file name prefix is not recognized.
    """
    density_array = onp.genfromtxt(path, delimiter=",")
    density_array = density_array[:, ::-1]

    # Crop out portions of the design where density does not vary. Ensure
    # there is a single entirely-solid row on the substrate side, and an
    # entirely-void row on the ambient side.
    column_means = onp.mean(density_array, axis=0)
    is_design_slice = (column_means > 0.0) & (column_means < 1.0)
    density_array = density_array[:, is_design_slice]
    num_rows = density_array.shape[0]
    density_array = onp.concatenate(
        [onp.zeros((num_rows, 1)), density_array, onp.ones((num_rows, 1))],
        axis=1,
    )

    # Use os.path instead of splitting on "/" so that OS-native separators
    # (e.g. backslashes returned by glob on Windows) are handled too.
    polarization_str = os.path.basename(os.path.dirname(path))
    fname = os.path.basename(path)
    if fname.startswith("Mo"):
        grid_spacing = 0.020
    elif fname.startswith(("Rasmus", "wenjin")):
        grid_spacing = 0.010
    else:
        raise ValueError(f"Unrecognized reference design file name: {fname!r}")
    return density_array, grid_spacing, polarization_str


def simulate_reference_design(path, approximate_num_terms=400, num_layers=5, compute_fields=False):
    """Simulate a reference design within the metalens challenge setup.

    The challenge's default spec is adapted so the lens width, thickness and
    grid spacing match the loaded design. The design is then edge-padded along
    axis 0 to fill the simulation grid.

    Args:
        path: CSV file accepted by `load_reference_design`.
        approximate_num_terms: Overrides `approximate_num_terms` in the default
            simulation parameters.
        num_layers: Overrides `num_layers` in the default simulation parameters.
        compute_fields: Forwarded to the component's `response` call.

    Returns:
        `(params, response, aux, spec)`: the padded design parameters, the
        simulation response and auxiliary outputs, and the adapted spec.
    """
    density, grid_spacing, _ = load_reference_design(path)

    base_spec = challenge.METALENS_SPEC
    lens_width = density.shape[0] * grid_spacing
    lens_thickness = density.shape[1] * grid_spacing
    # Center the lens in the domain, then express the offset relative to the
    # PML by subtracting its width.
    pml_lens_offset = (base_spec.width - lens_width) / 2 - base_spec.width_pml

    spec = dataclasses.replace(
        base_spec,
        thickness_lens=lens_thickness,
        width_lens=lens_width,
        pml_lens_offset=pml_lens_offset,
        grid_spacing=grid_spacing,
    )
    sim_params = dataclasses.replace(
        challenge.METALENS_SIM_PARAMS,
        approximate_num_terms=approximate_num_terms,
        num_layers=num_layers,
    )
    metalens = component.MetalensComponent(
        spec=spec, sim_params=sim_params, density_initializer=lambda k, d: d
    )

    # Edge-pad along axis 0 so the design spans the full simulation grid; the
    # padding must split evenly between the two sides.
    num_pad = (spec.grid_shape[0] - density.shape[0]) // 2
    assert 2 * num_pad + density.shape[0] == spec.grid_shape[0]
    initial_params = metalens.init(jax.random.PRNGKey(0))
    padded_density = jnp.pad(density, ((num_pad, num_pad), (0, 0)), mode="edge")
    params = dataclasses.replace(initial_params, array=padded_density)

    response, aux = metalens.response(params, compute_fields=compute_fields)
    return params, response, aux, spec


# Preview the density arrays of all Ex-polarized reference designs side by side.
# glob returns files in arbitrary order; sort so the panel order is reproducible.
fnames = sorted(glob.glob("../reference_designs/metalens/Ex/*.csv"))
if not fnames:
    print("No reference designs found in ../reference_designs/metalens/Ex/")
else:
    # squeeze=False keeps `axes` 2D even when there is a single design.
    fig, axes = plt.subplots(1, len(fnames), figsize=(10, 7), squeeze=False)
    for ax, fname in zip(axes[0], fnames):
        arr, _, _ = load_reference_design(fname)
        ax.imshow(arr)
        ax.set_title(os.path.basename(fname), fontsize=8)
        ax.axis(False)

In [None]:
# Simulate every Ex-polarized reference design and plot its fields.
# Sorted so the figure order is reproducible (glob order is arbitrary).
fnames = sorted(glob.glob("../reference_designs/metalens/Ex/*.csv"))

for fname in fnames:
    # Use distinct names so the optimized design's `params`/`response`/`aux`
    # from earlier cells are not overwritten.
    ref_params, ref_response, ref_aux, ref_spec = simulate_reference_design(
        fname, approximate_num_terms=400, num_layers=20, compute_fields=True
    )

    # plot_fields opens its own figure; calling plt.figure() first would
    # leave an empty figure behind.
    plot_fields(ref_params, ref_response, ref_aux, spec=ref_spec)
    plt.suptitle(os.path.basename(fname))