In [None]:
import dataclasses
import glob
import time

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as onp
from skimage import measure

import invrs_opt

In [None]:
from totypes import types
from importlib import reload
from invrs_gym.challenges.sorter import polarization_challenge, common
from invrs_gym.utils import initializers
# Reload during interactive development so edits to the challenge source are
# picked up without restarting the kernel. Reload the shared `common` module
# first and its dependent `polarization_challenge` second: reloading the
# dependent first can leave it bound to the stale `common` definitions
# (presumably `polarization_challenge` imports from `common` — confirm).
reload(common)
reload(polarization_challenge)

def rescale_density(density, scale):
    """Return a copy of ``density`` whose values are mapped onto ``[0, scale]``.

    The array is shifted so that ``lower_bound`` maps to 0 and stretched so
    that ``upper_bound`` maps to ``scale``. The returned bounds are updated to
    match the new data: ``lower_bound=0`` and ``upper_bound=scale``. The input
    ``density`` is not modified.

    Args:
        density: A dataclass-like density with ``array``, ``lower_bound`` and
            ``upper_bound`` fields.
        scale: Width of the target range.

    Returns:
        A new density built with ``dataclasses.replace``.

    Raises:
        ValueError: If ``upper_bound <= lower_bound`` (empty or inverted range).
    """
    span = density.upper_bound - density.lower_bound
    if span <= 0:
        raise ValueError(
            f"Density range must be non-empty, got lower_bound={density.lower_bound}, "
            f"upper_bound={density.upper_bound}."
        )
    # Out-of-place arithmetic: works for immutable JAX arrays and integer NumPy
    # arrays, and never aliases the caller's buffer.
    rescaled_array = (density.array - density.lower_bound) / (span / scale)
    return dataclasses.replace(
        density,
        array=rescaled_array,
        lower_bound=0,
        # The array now spans [0, scale], so the upper bound is `scale`
        # (`span * scale` would only be correct when span == 1).
        upper_bound=scale,
    )

def density_initializer(key, seed_density):
    """Build a noisy initial density, then rescale it with ``scale=0.001``.

    The noisy density comes from ``initializers.noisy_density_initializer``
    with a relative mean of 0.5 and a relative noise amplitude of 0.1. The
    result is passed through ``rescale_density`` so the design starts within
    a very small value range.
    """
    noise_options = dict(relative_mean=0.5, relative_noise_amplitude=0.1)
    noisy = initializers.noisy_density_initializer(
        key=key, seed_density=seed_density, **noise_options
    )
    return rescale_density(noisy, scale=0.001)

# Build the polarization-sorter challenge to be solved. Designs are seeded by
# `density_initializer`, and the minimum feature width/spacing constraints are
# both 8 (units not visible here — presumably grid pixels; confirm).
challenge = polarization_challenge.polarization_sorter(
    minimum_width=8,
    minimum_spacing=8,
    density_initializer=density_initializer,
)


def transform_density(density: types.Density2DArray, beta: float) -> types.Density2DArray:
    """Symmetrize, filter-and-project, and rescale a density.

    Steps:
      1. ``types.symmetrize_density``.
      2. ``transform.density_gaussian_filter_and_tanh`` with the given ``beta``
         (evaluated under ``jax.ensure_compile_time_eval``).
      3. Stretch every leaf about the midpoint of the density bounds by
         ``1 / tanh(beta)`` so the full valid range stays reachable.
      4. ``transform.apply_fixed_pixels``.

    Args:
        density: The density to transform.
        beta: Projection sharpness. Must be non-zero, since the rescaling
            divides by ``tanh(beta)``.

    Returns:
        The transformed density.
    """
    # Imported locally: at module level these names are only imported in a
    # later cell, so calling this function on a fresh kernel would otherwise
    # raise NameError.
    from jax import tree_util
    from invrs_opt.lbfgsb import transform

    transformed = types.symmetrize_density(density)
    with jax.ensure_compile_time_eval():
        transformed = transform.density_gaussian_filter_and_tanh(transformed, beta=beta)
    # Scale to ensure that the full valid range of the density array is reachable.
    mid_value = (density.lower_bound + density.upper_bound) / 2
    transformed = tree_util.tree_map(
        lambda array: mid_value + (array - mid_value) / jnp.tanh(beta), transformed
    )
    return transform.apply_fixed_pixels(transformed)


# Define the loss function: here it is simply the challenge's default loss.
# Auxiliary quantities are returned alongside the loss so that
# `jax.value_and_grad(..., has_aux=True)` can pass them through.
def loss_fn(params):
    """Return ``(loss, (response, aux))`` for the given design parameters."""
    response, aux = challenge.component.response(params)
    return challenge.loss(response), (response, aux)


# Get the initial parameters, and initialize the optimizer.
seed = 2
NUM_STEPS = 150
params = challenge.component.init(jax.random.PRNGKey(seed))
opt = invrs_opt.density_lbfgsb(beta=2)
# Alternative without the density transform: opt = invrs_opt.lbfgsb()
state = opt.init(params)
params = opt.params(state)

# Evaluated eagerly. To jit-compile, wrap with `jax.jit(...)`.
# NOTE(review): confirm the polarization sorter supports jit before enabling.
value_and_grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

# Carry out optimization for a fixed number of steps.
loss_values = []
distance_values = []  # NOTE(review): never populated in this notebook.
for i in range(NUM_STEPS):
    # perf_counter is monotonic and high-resolution, unlike time.time().
    t0 = time.perf_counter()
    params = opt.params(state)
    (value, (response, aux)), grad = value_and_grad_fn(params)
    t1 = time.perf_counter()
    state = opt.update(grad=grad, value=value, params=params, state=state)
    t2 = time.perf_counter()

    # Reflection plus transmission summed over the last axis; printed as a
    # sanity check on the response.
    total_power = response.reflection + jnp.sum(response.quadrant_transmission, axis=-1)
    print(
        f"{i:03} ({t1 - t0:.2f}/{t2 - t1:.2f}s): loss={value:.3f}, "
        f"power={total_power}"
    )
    loss_values.append(value)

In [None]:
# Compare the design the optimizer started from with the optimized design.
# `seed` is reused from the optimization cell (not re-hardcoded), so the
# "initial" design is exactly the one the optimizer was initialized with.
initial_params = challenge.component.init(jax.random.PRNGKey(seed))

fig, axes = plt.subplots(1, 2, figsize=(10, 3))
for ax, design, stage in zip(axes, (initial_params, params), ("initial", "final")):
    # NOTE(review): `_density_array` is a private helper of `common`.
    im = ax.imshow(common._density_array(design["density_metasurface"]))
    ax.set_title(f"{stage} metasurface density")
    fig.colorbar(im, ax=ax)

# Right-align layer names to width 11 so the columns line up.
for layer in ("cap", "metasurface", "spacer"):
    thickness_key = f"thickness_{layer}"
    print(
        f"{layer:>11} initial={initial_params[thickness_key].array}, "
        f"final={params[thickness_key].array}"
    )


In [None]:
# Loss trajectory, followed by the final response quantities.
plt.plot(loss_values)
for quantity in (
    response.quadrant_transmission,
    response.quadrant_target_transmission,
    response.reflection,
):
    print(quantity)

# NOTE(review): the optimization loop sums quadrant_transmission over axis -1;
# here it is summed over axis 0 — confirm which axis is intended.
jnp.sum(response.quadrant_transmission, axis=0)

In [None]:
# Show the four slices along the last axis of `aux["poynting_flux_z"]` in a
# 2x2 grid. NOTE(review): what the last axis indexes is not established
# here — confirm.
sz = aux["poynting_flux_z"]

plt.figure(figsize=(5, 5))
for idx in range(4):
    ax = plt.subplot(2, 2, idx + 1)
    ax.imshow(sz[..., idx])
    ax.axis(False)
plt.tight_layout()

In [None]:
# Gradient of the loss with respect to the metasurface density (image) and
# each layer thickness (printed).
plt.imshow(grad["density_metasurface"].array)
plt.colorbar()
for thickness_name in ("thickness_cap", "thickness_metasurface", "thickness_spacer"):
    print(grad[thickness_name])

In [None]:
    # NOTE(review): scratch fragment copied from inside a function body, not a
    # runnable cell. It is indented (IndentationError when run on its own) and
    # uses names never defined in this notebook (`fields`,
    # `fwd_substrate_offset`, `bwd_substrate_offset`, `layer_solve_results`,
    # `sz_fwd_ambient_sum`, `sz_bwd_ambient_sum`). Consider converting it to a
    # raw/markdown cell or removing it.
    with jax.ensure_compile_time_eval():
        # Forward/backward z-directed Poynting flux for the given amplitudes,
        # using the last entry of `layer_solve_results` (presumably the
        # substrate, per the variable names — confirm).
        sz_fwd_N, sz_bwd_N = fields.amplitude_poynting_flux(
            forward_amplitude=fwd_substrate_offset,
            backward_amplitude=bwd_substrate_offset,
            layer_solve_result=layer_solve_results[-1],
        )
        
        # Sum of |flux| over axis -2 (meaning of that axis not visible here).
        sz_fwd_substrate_sum = jnp.sum(jnp.abs(sz_fwd_N), axis=-2)
        sz_bwd_substrate_sum = jnp.sum(jnp.abs(sz_bwd_N), axis=-2)
        # Diagnostics: ambient/substrate flux sums, the mean of `sz` over axes
        # (-3, -2), and two sums that look like power-balance checks — TODO
        # confirm intent.
        printvals = [
            sz_fwd_ambient_sum,
            sz_bwd_ambient_sum,
            sz_fwd_substrate_sum,
            sz_bwd_substrate_sum,
            jnp.mean(sz, axis=(-3, -2)),
            sz_bwd_ambient_sum + sz_fwd_substrate_sum,
            sz_bwd_ambient_sum + jnp.mean(sz, axis=(-3, -2)),
        ]

In [None]:
from jax import tree_util
from totypes import types
from invrs_opt.lbfgsb import transform, lbfgsb

def transform_density(density: types.Density2DArray, beta: float) -> types.Density2DArray:
    """Symmetrize, filter-and-project, rescale, and pin fixed pixels.

    NOTE(review): this redefines (and shadows) the `transform_density` defined
    in an earlier cell. Unlike that one, it does not wrap the filter step in
    `jax.ensure_compile_time_eval()` — keep only one definition.

    The rescaling stretches each leaf about the midpoint of the density
    bounds by ``1 / tanh(beta)``, so ``beta`` must be non-zero.
    """
    symmetric = types.symmetrize_density(density)
    projected = transform.density_gaussian_filter_and_tanh(symmetric, beta=beta)

    # Stretch about the center so the full valid range remains reachable.
    center = (density.lower_bound + density.upper_bound) / 2
    stretch = jnp.tanh(beta)

    def _stretch_leaf(leaf):
        return center + (leaf - center) / stretch

    stretched = tree_util.tree_map(_stretch_leaf, projected)
    return transform.apply_fixed_pixels(stretched)


# Rebuild the parameter pytree from the L-BFGS-B state vector `x` and apply
# the density transform to it (presumably the latent, pre-transform design).
# NOTE(review): this relies on the private helper `lbfgsb._to_pytree` and on
# `state` unpacking into `(params, state_dict)`; both are invrs_opt internals.
# NOTE(review): beta=4 here vs beta=2 in the optimizer — confirm intent.
# Unpack into `state_params`, not `params`, so this cell does not clobber the
# optimized `params` that other cells display.
state_params, lbfgsb_state_dict = state
lbfgsb_state = lbfgsb.ScipyLbfgsbState(**lbfgsb_state_dict)
latent_params = lbfgsb._to_pytree(lbfgsb_state.x, state_params)

transform_density(latent_params["density_metasurface"], beta=4)

In [None]:
# Display the current design parameters. `params` is assigned in several
# cells, so what is shown depends on which of them ran last.
params

In [None]:
# Toy 20x20 image: all zeros except row 5, columns 0-9, which are set to 1.
array = onp.zeros((20, 20))
array[5, 0:10] = 1.0
plt.imshow(array)