In [None]:
import dataclasses
import glob
import time

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as onp
from fmmax import basis
from skimage import measure

import invrs_opt

from invrs_gym import challenges

In [None]:
# Select the challenge to be solved: the photon extractor.
challenge = challenges.photon_extractor()


# Loss function for the optimizer. It uses the challenge's default loss, and
# also returns auxiliary quantities so they can be logged without re-simulating.
def loss_fn(params):
    """Evaluates the challenge loss for `params`.

    Returns:
        A tuple `(loss, (response, aux, distance, metrics))`. The nested tuple is
        the auxiliary payload expected when this function is wrapped with
        `jax.value_and_grad(..., has_aux=True)`.
    """
    resp, sim_aux = challenge.component.response(params)
    loss_value = challenge.loss(resp)
    dist = challenge.distance_to_target(resp)
    challenge_metrics = challenge.metrics(resp, params, sim_aux)
    extras = (resp, sim_aux, dist, challenge_metrics)
    return loss_value, extras


# Get the initial parameters, and initialize the optimizer.
params = challenge.component.init(jax.random.PRNGKey(0))
opt = invrs_opt.density_lbfgsb(beta=4)
state = opt.init(params)

# The photon extractor challenge can be jit-compiled.
value_and_grad_fn = jax.jit(jax.value_and_grad(loss_fn, has_aux=True))

# Carry out optimization for a fixed number of steps. Each printed line shows
# the step index, the evaluation time / optimizer-update time, the loss and the
# distance to target. The first evaluation also includes jit compilation.
#
# After the loop, `params` (and `response`, `aux`, `distance`, `metrics`) refer
# to the last *evaluated* parameters; the parameters produced by the final
# `opt.update` (held in `state`) are not evaluated by this loop.
num_steps = 35
loss_values = []
distance_values = []
metrics_values = []
for i in range(num_steps):
    t0 = time.time()
    params = opt.params(state)
    outputs = value_and_grad_fn(params)
    # JAX dispatches computations asynchronously: without blocking here,
    # `t1 - t0` would measure only tracing/dispatch and the simulation time
    # would be attributed to the optimizer update below instead.
    jax.block_until_ready(outputs)
    (value, (response, aux, distance, metrics)), grad = outputs
    t1 = time.time()
    state = opt.update(grad=grad, value=value, params=params, state=state)

    print(
        f"{i:03} ({t1 - t0:.2f}/{time.time() - t1:.2f}s): loss={value:.3f}, distance={distance:.3f}"
    )
    loss_values.append(value)
    distance_values.append(distance)
    metrics_values.append(metrics)

In [None]:
# Re-simulate the last evaluated parameters, this time asking the component to
# also compute the fields; only the aux dict (index 1) is kept for plotting.
aux_with_fields = challenge.component.response(params, compute_fields=True)[1]

In [None]:
# Plot the loss and the distance to target vs. optimization step. Both series
# were recorded by the optimization loop.
step = onp.arange(len(loss_values))

fig, (ax_loss, ax_distance) = plt.subplots(1, 2, figsize=(10, 4))

ax_loss.plot(step, loss_values)
ax_loss.set_xlabel("Step")
ax_loss.set_ylabel("Loss value")

ax_distance.plot(step, distance_values)
ax_distance.set_xlabel("Step")
ax_distance.set_ylabel("Distance to target")

fig.tight_layout()

In [None]:
# Plot the optimized density (left) and the electric-field magnitude (right).
fig, (ax_density, ax_field) = plt.subplots(1, 2, figsize=(10, 5))

# Crop 60 pixels from each edge of the density array for display.
density_plot = params.array[60:-60, 60:-60]
im = ax_density.imshow(density_plot, cmap="gray")
im.set_clim([0, 1])
ax_density.axis(False)

# Overlay iso-contours of the density (skimage's default level is the midpoint
# between the array's min and max).
contours = measure.find_contours(onp.asarray(density_plot))
for c in contours:
    ax_density.plot(c[:, 1], c[:, 0], "r", lw=1)

x, _, z = aux_with_fields["field_coordinates"]
ex, ey, ez = aux_with_fields["efield"]
xplot, zplot = jnp.meshgrid(x, z, indexing="ij")
field_plot = jnp.sqrt(jnp.abs(ex) ** 2 + jnp.abs(ey) ** 2 + jnp.abs(ez) ** 2)

# NOTE(review): index 1 picks one slice along the last axis of the field
# array; confirm which source/polarization it corresponds to.
ax_field.pcolormesh(xplot, zplot, field_plot[:, :, 1], cmap="magma")
ax_field.axis("equal")
ax_field.axis("off")
ax_field.invert_yaxis()
ax_field.set_title(f"Flux enhancement={metrics['enhancement_flux_mean']:.1f}")

fig.tight_layout()

In [None]:
# Check convergence by re-simulating the optimized design with a range of
# Fourier expansion sizes. `approximate_num_terms` is only a target; the count
# actually used is recorded from `expansion.num_terms`.
convergence_num_terms = [400, 800, 1200, 1600, 2000, 2400]

convergence_results = []
for approximate_num_terms in convergence_num_terms:
    expansion = basis.generate_expansion(
        primitive_lattice_vectors=basis.LatticeVectors(
            u=challenge.component.spec.pitch * basis.X,
            v=challenge.component.spec.pitch * basis.Y,
        ),
        approximate_num_terms=approximate_num_terms,
        truncation=basis.Truncation.CIRCULAR,
    )
    conv_response, conv_aux = challenge.component.response(params, expansion=expansion)
    conv_metrics = challenge.metrics(conv_response, params, conv_aux)
    convergence_results.append(
        (expansion.num_terms, conv_response, conv_aux, conv_metrics)
    )
    # Each simulation can be slow; report progress as results come in.
    print(
        f"num_terms={expansion.num_terms}: "
        f"flux enhancement={conv_metrics['enhancement_flux_mean']:.2f}"
    )

In [None]:
# Plot flux enhancement vs. the number of Fourier terms actually used. Unpack
# into new names so the previous cell's `conv_metrics` is not clobbered.
num_terms, _, _, all_conv_metrics = zip(*convergence_results)
enhancement = [m["enhancement_flux_mean"] for m in all_conv_metrics]
plt.plot(num_terms, enhancement, "o-")
plt.xlabel("Fourier orders (N)")
plt.ylabel("Flux enhancement")

plt.tight_layout()

In [None]:
# Load and simulate the reference design, then plot it the same way as the
# optimized design.
reference_paths = glob.glob("../reference_designs/photon_extractor/*.csv")
if len(reference_paths) != 1:
    raise ValueError(
        f"Expected exactly one reference design CSV, found {len(reference_paths)}: "
        f"{reference_paths}"
    )
(fname,) = reference_paths

design = onp.genfromtxt(fname, delimiter=",")
# Pad by 65 pixels on each side; the padded design must match the shape of the
# optimization parameters, since `dataclasses.replace` would not check this.
design = onp.pad(design, ((65, 65), (65, 65)))
if design.shape != tuple(params.array.shape):
    raise ValueError(
        f"Padded reference design has shape {design.shape}, but the parameters "
        f"have shape {tuple(params.array.shape)}."
    )

reference_params = dataclasses.replace(params, array=design)

reference_response, reference_aux = challenge.component.response(
    reference_params, compute_fields=True
)
reference_metrics = challenge.metrics(
    reference_response, reference_params, reference_aux
)

fig, (ax_density, ax_field) = plt.subplots(1, 2, figsize=(10, 5))

# Crop 60 pixels from each edge of the density array for display.
density_plot = reference_params.array[60:-60, 60:-60]
im = ax_density.imshow(density_plot, cmap="gray")
im.set_clim([0, 1])
ax_density.axis(False)

# Overlay iso-contours of the density (skimage's default level).
contours = measure.find_contours(onp.asarray(density_plot))
for c in contours:
    ax_density.plot(c[:, 1], c[:, 0], "r", lw=1)

x, _, z = reference_aux["field_coordinates"]
ex, ey, ez = reference_aux["efield"]
xplot, zplot = jnp.meshgrid(x, z, indexing="ij")
field_plot = jnp.sqrt(jnp.abs(ex) ** 2 + jnp.abs(ey) ** 2 + jnp.abs(ez) ** 2)

# NOTE(review): index 1 picks one slice along the last axis of the field
# array; confirm which source/polarization it corresponds to.
ax_field.pcolormesh(xplot, zplot, field_plot[:, :, 1], cmap="magma")
ax_field.axis("equal")
ax_field.axis("off")
ax_field.invert_yaxis()
ax_field.set_title(f"Flux enhancement={reference_metrics['enhancement_flux_mean']:.1f}")

fig.tight_layout()