# Circuit switch optimization

## Imports

In [None]:
import gdsfactoryplus as gfp
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import sax
from cspdk.si220.cband import PDK
from ipywidgets import interact
from scipy.constants import c as c_m_s

## Config

In [None]:
# Delay-line length used as the default for the interactive plot below.
dL0 = 100  # [um]

# Center wavelength and the wavelength sweep used for plotting.
wl0 = 1.55  # [um]
wl = jnp.linspace(start=1.50, stop=1.60, num=1001)  # [um]

In [None]:
c_um_s = 1e6 * c_m_s  # [um/s] speed of light
wl0 = 1.55  # [um] center wavelength (same value as the config cell)
f = c_um_s / wl0  # [Hz] optical frequency at wl0
df = 100e9  # [Hz] channel frequency spacing

# Convert the frequency spacing into a wavelength spacing: dλ = (c / f²)·df.
channel_spacing = c_um_s / f**2 * df  # [um]
print(f"Channel spacing: {channel_spacing * 1e3:.3f} nm")

# Eight channel wavelengths starting at 1550 nm on a 0.8 nm grid.
wls_nm = 1550 + jnp.arange(8) * 0.8
wls = wls_nm / 1e3  # [um]

## Netlist

In [None]:
PDK.activate()
gfp.register_cells()
# Cell factory for the heater MZI. Named explicitly instead of `filter`
# so the Python builtin is not shadowed for the rest of the notebook.
mzi_heater_factory = PDK.cells["mzi_heater"]
# Use the configured delay length instead of repeating the literal 100.
cell = mzi_heater_factory(dL=dL0)  # add optional lattice arguments here (python pcells only)
netlist = cell.get_netlist(recursive=True)
cell.plot()

In [None]:
# Pass the extracted (recursive) netlist to gdsfactoryplus's `show`.
gfp.show(netlist)

In [None]:
# Echo the netlist as the cell output so it can be inspected.
netlist

## MZI Circuit

In [None]:
# it's best to construct the circuit outside the wrapping function:
# this makes sure the circuit only needs to be constructed once.
# Build the SAX circuit once, outside the wrapper, so it is not rebuilt
# on every call of `mzi`.
_mzi, _ = sax.circuit(netlist, PDK.models)


@jax.jit  # compile the wrapper for faster repeated evaluation
def mzi(wl=1.55, dL=100):
    """Evaluate the MZI circuit's S-parameters.

    Args:
        wl: wavelength(s) [um]. As a top-level argument it is distributed to
            every subcomponent that accepts a ``wl`` setting.
        dL: length applied to the ``sr`` and ``sl`` subcomponents (presumably
            the two delay arms — confirm against the netlist).

    Returns:
        Whatever the compiled SAX circuit returns for these settings. It is
        indexed by port-name pairs such as ``S["o1", "o3"]`` elsewhere in
        this notebook.
    """
    # Per-instance settings: these override only the named subcomponents.
    arm_settings = {
        "sr": {"length": dL},
        "sl": {"length": dL},
    }
    return _mzi(wl=wl, **arm_settings)

In [None]:
# IPython help: display the signature and docstring of the jitted `mzi`.
mzi?

In [None]:
@interact(dL=(10, 100))
def show(dL=dL0):
    """Plot the MZI bar/cross transmission for the given delay length.

    Wrapped by ipywidgets ``interact`` with a ``dL`` slider from 10 to 100.

    Args:
        dL: delay length forwarded to ``mzi`` (same units as ``dL0``, um).
    """
    plt.figure(figsize=(8, 3))
    S = sax.sdict(mzi(wl=wl, dL=dL))
    # Label curves with the actual port names, consistent with the final
    # optimization plot (the old "in0->out0" labels matched no port).
    plt.plot(wl, abs(S["o1", "o3"]) ** 2, label="o1->o3", color="C0")
    plt.plot(wl, abs(S["o1", "o4"]) ** 2, label="o1->o4", ls="--", color="C1")
    plt.ylabel("power")
    plt.grid(True)
    plt.yticks([0.0, 0.25, 0.5, 0.75, 1.0])
    plt.ylim(0.0, 1.0)
    # Zoom the x-axis onto the channel grid `wls` rather than the full sweep.
    plt.xticks(jnp.round(wls, 4))
    plt.xlim(wls.min(), wls.max())
    plt.xlabel("λ [μm]")
    plt.figlegend(ncol=2)
    plt.show()

## Optimization

We'd like to optimize the MZI such that one of the minima of the o1→o4 transmission lies at 1552.4 nm. To do this, we need to define a loss function that evaluates the circuit at 1552.4 nm. This function should take the parameters that you want to optimize as positional arguments:

In [None]:
import jax.example_libraries.optimizers as opt
from tqdm.notebook import trange

wl_target = 1.5524  # [um] wavelength at which o1->o4 should be dark


@jax.jit
def loss_fn(delta_length):
    """Return the o1->o4 power transmission at ``wl_target``.

    Minimizing this places a minimum of the o1->o4 response at the target.

    Args:
        delta_length: delay length passed to ``mzi`` as ``dL``.
    """
    s_params = mzi(wl=wl_target, dL=delta_length)
    cross_power = jnp.abs(s_params["o1", "o4"]) ** 2
    return jnp.mean(cross_power)

In [None]:
# Jitted derivative of `loss_fn` with respect to its first positional
# argument (the delay length).
grad_fn = jax.jit(jax.grad(loss_fn, argnums=0))

In [None]:
# Adam from jax.example_libraries yields (init, update, get_params) functions
# that operate on an opaque optimizer state; 0.1 is the step size.
init_fn, update_fn, params_fn = opt.adam(0.1)

initial_delta_length = 21.0  # starting guess for the delay length
state = init_fn(initial_delta_length)

Given all this, a single training step can be defined:

In [None]:
# Loss and gradient from one forward+backward pass. Calling `loss_fn` and
# `grad_fn` separately evaluated the circuit forward twice per step.
_loss_and_grad_fn = jax.jit(jax.value_and_grad(loss_fn, argnums=0))


def step_fn(step, state):
    """Run one step of the optimizer.

    Args:
        step: integer iteration index forwarded to the optimizer update.
        state: current optimizer state.

    Returns:
        ``(loss, state)``: the loss at the parameters *before* the update,
        and the updated optimizer state.
    """
    params = params_fn(state)
    loss, grad = _loss_and_grad_fn(params)
    state = update_fn(step, grad, state)
    return loss, state

And we can use this step function to start the training of the MZI:

In [None]:
pb = trange(300)  # progress bar over 300 optimizer steps
for step in pb:
    # The first iteration(s) are slow while the jitted functions compile.
    loss, state = step_fn(step, state)
    pb.set_postfix(loss=f"{loss:.6f}")

In [None]:
# Read the optimized delay length out of the final optimizer state and echo it.
delta_length = params_fn(state)
delta_length

Let's see what we've got over a range of wavelengths:

In [None]:
# Use a dedicated sweep array: rebinding the global `wl` would silently change
# the x-data of the interactive `show` plot defined earlier.
wl_sweep = jnp.linspace(1.5, 1.6, 1000)  # [um]
S = mzi(wl=wl_sweep, dL=delta_length)
wl_sweep_nm = wl_sweep * 1e3  # [nm]
plt.plot(wl_sweep_nm, abs(S["o1", "o3"]) ** 2, label="o1->o3")
plt.plot(wl_sweep_nm, abs(S["o1", "o4"]) ** 2, label="o1->o4")
plt.xlabel("λ [nm]")
plt.ylabel("T")
# Mark the optimization target wavelength.
plt.axvline(wl_target * 1e3, ls=":", color="black")
plt.ylim(-0.05, 1.05)
plt.grid(True)
plt.legend()
plt.show()

We have calculated the delta length of the switch that minimizes the o1→o4 transmission at the target wavelength, which (for a low-loss MZI) routes the power to o3.

In [None]:
# Report the optimized delay length together with the target wavelength.
print(f"Final delta length: {delta_length:.2f} um at {wl_target:.4f} um")

In [None]:
# Instantiate the parametric heater-MZI layout with the optimized delay length.
# The JAX scalar is cast to a plain Python float for the cell factory.
cell_parametric = PDK.cells["mzi_heater_parametric"]
c = cell_parametric(dL=float(delta_length))  # add optional lattice arguments here (python pcells only)
c.plot()

In [None]:
# Display the parametric component via its `show()` method (viewer depends on the gdsfactory setup).
c.show()

In [None]:
# Convert the component to a 3D representation (presumably built from the PDK layer stack — confirm) and display it.
s = c.to_3d()
s.show()