# SAX Workshop
> Virtual workshop given for the University of Illinois - 2024.11.14

## Scatter *dictionaries*
The core data structure for specifying scatter parameters in SAX is a dictionary: more specifically, a dictionary that maps a port combination (a 2-tuple) to a scatter parameter (or to an array of scatter parameters when, for example, considering multiple wavelengths). Such a dictionary is called an `SDict` in SAX (`SDict ≈ Dict[Tuple[str,str], float]`).

```
in1          out1
   \        /
    ========
   /        \
in0          out0
```

In [None]:
coupling = 0.5
kappa = coupling**0.5
tau = (1 - coupling) ** 0.5
coupler_dict = {
    ("in0", "out0"): tau,
    ("out0", "in0"): tau,
    ("in0", "out1"): 1j * kappa,
    ("out1", "in0"): 1j * kappa,
    ("in1", "out0"): 1j * kappa,
    ("out0", "in1"): 1j * kappa,
    ("in1", "out1"): tau,
    ("out1", "in1"): tau,
}
coupler_dict
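
As a quick sanity check, an ideal coupler should conserve power: the squared magnitudes of all S-parameters leaving one input port should sum to 1. The sketch below verifies this in plain Python. It rebuilds the forward half of the dictionary so that it runs on its own, and the helper name `output_power` is made up for this illustration (it is not part of SAX).

```python
coupling = 0.5
kappa = coupling**0.5
tau = (1 - coupling) ** 0.5

# forward connections of the coupler, keyed by (input port, output port)
forward = {
    ("in0", "out0"): tau,
    ("in0", "out1"): 1j * kappa,
    ("in1", "out0"): 1j * kappa,
    ("in1", "out1"): tau,
}


def output_power(sdict, port_in):
    """Sum |S|^2 over every entry whose first port is `port_in` (illustrative helper)."""
    return sum(abs(s) ** 2 for (p_in, _), s in sdict.items() if p_in == port_in)


power_in0 = output_power(forward, "in0")  # should be 1 up to rounding
power_in1 = output_power(forward, "in1")  # should be 1 up to rounding
```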

```{note}
#### Advantages of using a dict as S-Matrix representation
- Inherently sparse (you only specify the non-zero connections)
- Explicit in the ports
```

Obviously, it can still be tedious to specify every port combination in the circuit manually. SAX therefore offers `sax.reciprocal()`, which auto-fills the reverse connection if the forward connection exists. For example:

In [None]:
import sax

coupler_dict = sax.reciprocal(
    {
        ("in0", "out0"): tau,
        ("in0", "out1"): 1j * kappa,
        ("in1", "out0"): 1j * kappa,
        ("in1", "out1"): tau,
    }
)

coupler_dict
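
Conceptually, making a dictionary reciprocal just means mirroring every `(p1, p2)` key to `(p2, p1)`. The following is a plain-Python illustration of that idea, not SAX's actual implementation. The name `reciprocal_sketch` is hypothetical, and letting explicitly given entries win over mirrored ones is an assumption of this sketch.

```python
def reciprocal_sketch(sdict):
    """Return a copy of `sdict` with every (p1, p2) entry mirrored to (p2, p1)."""
    mirrored = {(p2, p1): value for (p1, p2), value in sdict.items()}
    # entries that were given explicitly take precedence over mirrored ones
    return {**mirrored, **sdict}


forward = {("in0", "out0"): 0.5**0.5, ("in0", "out1"): 1j * 0.5**0.5}
full = reciprocal_sketch(forward)  # now also contains the two reverse connections
```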

## A first Model: the coupler

Constructing such an `SDict` is easy. Usually, however, we're more interested in parametrized models for our components. To parametrize the coupler `SDict`, just wrap it in a function that takes only keyword arguments with default values:

In [None]:
def coupler(coupling=0.5) -> sax.SDict:
    kappa = coupling**0.5
    tau = (1 - coupling) ** 0.5
    coupler_dict = sax.reciprocal(
        {
            ("in0", "out0"): tau,
            ("in0", "out1"): 1j * kappa,
            ("in1", "out0"): 1j * kappa,
            ("in1", "out1"): tau,
        }
    )
    return coupler_dict


coupler(coupling=0.3)

We just created an ideal (lossless) coupler with an adjustable coupling ratio.

## A second model: the straight waveguide

In [None]:
import jax.numpy as jnp  # JAX-version of numpy


def waveguide(wl=1.55, wl0=1.55, neff=2.34, ng=3.4, length=10.0, loss=0.0) -> sax.SDict:
    dwl = wl - wl0
    dneff_dwl = (ng - neff) / wl0
    neff = neff - dwl * dneff_dwl
    phase = 2 * jnp.pi * neff * length / wl
    transmission = 10 ** (-loss * length / 20) * jnp.exp(1j * phase)
    sdict = sax.reciprocal(
        {
            ("in0", "out0"): transmission,
        }
    )
    return sdict
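
To make the two ingredients of this model concrete: `loss` is expressed in dB per unit length (dB/μm here, since lengths and wavelengths are in μm), so it scales the amplitude by $10^{-\mathrm{loss} \cdot L / 20}$, while `neff` sets the accumulated phase. The plain-Python sketch below (standard library only, no SAX or JAX) repeats the same arithmetic for a single wavelength. The function name `waveguide_transmission` is invented for this example.

```python
import cmath
import math


def waveguide_transmission(wl=1.55, wl0=1.55, neff=2.34, ng=3.4, length=10.0, loss=0.0):
    """Complex amplitude transmission, mirroring the arithmetic of `waveguide` above."""
    dneff_dwl = (ng - neff) / wl0  # magnitude of the first-order dispersion slope
    neff_wl = neff - (wl - wl0) * dneff_dwl  # neff drops with wavelength when ng > neff
    phase = 2 * math.pi * neff_wl * length / wl
    amplitude = 10 ** (-loss * length / 20)  # loss in dB per unit length
    return amplitude * cmath.exp(1j * phase)


t = waveguide_transmission(length=100.0, loss=0.01)  # 0.01 dB/um over 100 um
power_dB = 10 * math.log10(abs(t) ** 2)  # total power transmission in dB
```

The phase term has unit magnitude, so `power_dB` depends only on `loss * length`.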

That's pretty straightforward. Let's now move on to parametrized circuits:

## A first circuit: the simple mzi

Existing models can now be combined into a circuit using `sax.circuit`.

In [None]:
mzi, info = sax.circuit(
    netlist={
        "instances": {
            # use the model as-is:
            "lft": "coupler",
            # bake in default values for the models:
            "top": {"component": "waveguide", "settings": {"length": 25.0}},
            "btm": {"component": "waveguide", "settings": {"length": 15.0}},
            "rgt": {"component": "coupler", "settings": {"coupling": 0.5}},
        },
        "connections": {
            "lft,out0": "btm,in0",
            "btm,out0": "rgt,in0",
            "lft,out1": "top,in0",
            "top,out0": "rgt,in1",
        },
        "ports": {
            "in0": "lft,in0",
            "in1": "lft,in1",
            "out0": "rgt,out0",
            "out1": "rgt,out1",
        },
    },
    models={
        "coupler": coupler,
        "waveguide": waveguide,
    },
)

When defining the netlist for the circuit, we can choose to use a model as-is or to bake in some default settings.

`sax.circuit()` returns two items: `mzi` and `info`. The first is the model *function* of the mzi. The second contains additional information about the circuit (usually you can just ignore it).

Since `mzi` is just a function, we should be able to call it:

In [None]:
mzi()  # indeed
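
Because none of these models has reflections, the circuit is purely feed-forward and its S-parameters can be checked by hand: light going from `in0` to `out0` either takes the bottom arm (straight through both couplers) or the top arm (crossing over in both couplers). The plain-Python sketch below sums those two paths using the default values of the models above. It is an independent cross-check under that assumption, not a description of how SAX evaluates circuits internally, and its result should agree with `mzi()["in0", "out0"]`.

```python
import cmath
import math

wl, neff = 1.55, 2.34  # defaults of `waveguide`; at wl == wl0 there is no dispersion correction
kappa = tau = 0.5**0.5  # default 50/50 coupler


def arm(length):
    """Lossless waveguide transmission at the design wavelength."""
    return cmath.exp(2j * math.pi * neff * length / wl)


# bottom arm (15 um): thru (tau) in lft, then into rgt via in0
# top arm (25 um): cross (1j * kappa) in lft, then into rgt via in1
s_in0_out0 = tau * arm(15.0) * tau + (1j * kappa) * arm(25.0) * (1j * kappa)
s_in0_out1 = tau * arm(15.0) * (1j * kappa) + (1j * kappa) * arm(25.0) * tau
total_power = abs(s_in0_out0) ** 2 + abs(s_in0_out1) ** 2  # lossless, so this should be 1
```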

```{note}
#### Why a function?
You might be wondering... *"why does `sax.circuit()` return a function and not just an S-matrix (or S-dict)?"* The reason is that returning a *function* allows you to experiment with changing circuit parameters without having to build the circuit again. This is useful for sweeps, Monte Carlo simulations, optimizations and more.
```

Let's have a look at the mzi model parameters:

In [None]:
?mzi

The circuit takes four parameters: `lft`, `top`, `btm` and `rgt`. These are the instance names of our circuit! Moreover - as you can see - each instance name maps to the parameters of the associated model.

This means we could parametrize our mzi as follows:

In [None]:
mzi(lft={"coupling": 0.3}, top={"wl": 1.31}, btm={"wl": 1.31}, rgt={"coupling": 0.3})

However, many parameter names and values are repeated. Therefore it's sometimes more useful to call the circuit with globally defined parameters:

In [None]:
mzi(coupling=0.3, wl=1.31)

Globally defined parameters will distribute over any child instance parameters if the child model accepts that parameter. In practice globally defined parameters are often used for ambient parameters like wavelength (`wl`) or temperature (`T`).
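
The distribution rule can be pictured with a small plain-Python sketch: for each instance, keep only those global settings that appear in the signature of its model. This is an illustration of the rule as described here, not SAX's actual code. The stand-in models and the `distribute` helper are hypothetical.

```python
import inspect


# stand-in models: only their signatures matter for this illustration
def coupler(coupling=0.5):
    return {"coupling": coupling}


def waveguide(wl=1.55, length=10.0):
    return {"wl": wl, "length": length}


def distribute(instances, **global_settings):
    """Map each instance name to the global settings its model accepts."""
    resolved = {}
    for name, model in instances.items():
        accepted = inspect.signature(model).parameters
        resolved[name] = {k: v for k, v in global_settings.items() if k in accepted}
    return resolved


instances = {"lft": coupler, "top": waveguide, "btm": waveguide, "rgt": coupler}
settings = distribute(instances, coupling=0.3, wl=1.31)
# couplers only receive `coupling`, waveguides only receive `wl`
```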

```{note}
#### Summary
* The `sax.circuit` function just returns another python function as our *model* for the newly created circuit.
* You can 'bake in' default parameters within the netlist
* You can override parameters when calling the circuit function
* You can override parameters globally when calling a circuit function.
```

## A first simulation: wavelength sweep of the simple mzi

You might be tempted to perform a wavelength sweep with a loop:

```{caution}
Don't do this!
```

In [None]:
%%time
S_matrices = []
wls = jnp.linspace(1.5, 1.6, 100)
for wl in wls:
    S = mzi(wl=wl)
    S_matrices.append(S)

This is fairly inefficient. Moreover, you're now stuck with a list of S-dicts, which are pretty hard to merge. Instead, pass the whole wavelength array at once, as SAX is very good at vectorizing:

In [None]:
%%time
S = mzi(wl=wls)

Each value inside this S-dictionary is now an array of S-parameters (one for each wavelength):

In [None]:
abs(S["in0", "out0"])

```{note}
The array of S-parameters does not need to be 1D! It can have as many dimensions as you'd like (e.g. one dimension per process parameter). This will come in handy for Monte Carlo simulations later on.
```

Let's see what this gives:

In [None]:
import matplotlib.pyplot as plt

plt.plot(wls * 1e3, abs(S["in0", "out0"]) ** 2)
plt.ylim(-0.05, 1.05)
plt.xlabel("λ [nm]")
plt.ylabel("T")
plt.show()
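
The fringe spacing in this plot can be estimated by hand. For an unbalanced MZI with arm-length difference $\Delta L$, the free spectral range is approximately $\lambda^2 / (n_g \Delta L)$. The sketch below evaluates this for the values used so far ($n_g = 3.4$ from the `waveguide` defaults and $\Delta L = 25 - 15 = 10$ μm). It is a first-order estimate around the center wavelength, not an exact result.

```python
wl0 = 1.55  # center wavelength [um]
ng = 3.4  # group index, default of the `waveguide` model
delta_length = 25.0 - 15.0  # top arm minus bottom arm [um]

fsr_um = wl0**2 / (ng * delta_length)
fsr_nm = 1e3 * fsr_um  # roughly 70.7 nm
```

The sweep above spans 100 nm, so it should show a bit more than one full fringe.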

## A first layout: MZI with heater

We can quite easily define our own mzi with heater in GDSFactory:

In [None]:
import gdsfactory as gf


@gf.cell()
def mzi_with_heater(
    delta_length=10,
    height=50,
    width=200,
    heater_length=100,
):
    c = gf.Component()
    heater = c.add_ref(
        gf.components.straight_heater_metal(length=heater_length), name="heater"
    )
    heater.dmove(
        heater["o1"].dcenter,
        ((width - heater_length) / 2, height / 2 + delta_length / 2),
    )
    btmwg = c.add_ref(gf.components.straight(length=100), name="btm_straight")
    btmwg.dmove(btmwg["o1"].dcenter, ((width - heater_length) / 2, -height / 2))
    inp = c.add_ref(gf.components.mmi2x2(), name="inp")
    inp.dmove(inp["o3"].dcenter, (0, inp["o2"].dcenter[1]))
    outp = c.add_ref(gf.components.mmi2x2(), name="outp")
    outp.dmove(outp["o1"].dcenter, (width, outp["o1"].dcenter[1]))
    gf.routing.route_single(
        c, port1=inp["o3"], port2=heater["o1"], cross_section="strip"
    )
    gf.routing.route_single(
        c, port1=inp["o4"], port2=btmwg["o1"], cross_section="strip"
    )
    gf.routing.route_single(
        c, port1=outp["o2"], port2=heater["o2"], cross_section="strip"
    )
    gf.routing.route_single(
        c, port1=outp["o1"], port2=btmwg["o2"], cross_section="strip"
    )
    c.add_ports([inp["o1"], inp["o2"], outp["o3"], outp["o4"]])
    return c

In [None]:
c_mzi1 = mzi_with_heater().dup()
c_mzi1.name = "c_mzi1"  # easier to reference name when working with netlists
c_mzi1

GDSFactory has a built-in way to extract the netlist from a `gf.Component`:

In [None]:
# the recursive=True ensures that *all* netlists are extracted in the case of
# hierarchical components. This is not really relevant here but it's good practice
# to always include it.
netlist = c_mzi1.get_netlist(recursive=True)

This netlist is compatible with SAX. To build our circuit, we need to know which models it requires:

In [None]:
sax.get_required_circuit_models(netlist)

Since we now have a layout to go from, let's create some better models for those components first:

## A better waveguide model

We can use some GDSFactory utilities to create a `neff` model for our waveguide:

In [None]:
from tqdm.notebook import tqdm
import gplugins.tidy3d as gt

widths = jnp.linspace(0.4, 0.6, 5)
neffs = []
ngs = []
for width in tqdm(widths):
    strip = gt.modes.Waveguide(
        wavelength=1.55,
        core_width=width,
        core_thickness=0.22,
        slab_thickness=0.0,
        core_material="si",
        clad_material="sio2",
        group_index_step=0.01,
    )
    neffs.append(strip.n_eff[0])
    ngs.append(strip.n_group[0])
neffs = jnp.real(jnp.array(neffs))
ngs = jnp.real(jnp.array(ngs))

strip.plot_field(field_name="Ex", mode_index=0)
plt.show()

Now we can create a very basic width v neff model:

In [None]:
import scipy.stats

m_neff, b_neff, *_ = scipy.stats.linregress(widths, neffs)
plt.plot(widths, neffs, "o")
plt.plot(widths, m_neff * widths + b_neff)
plt.grid(True)
plt.xlabel("width [μm]")
plt.ylabel("neff")
plt.title("neff @ λ=1.55μm")
plt.show()

And a width v ng model:

In [None]:
m_ng, b_ng, *_ = scipy.stats.linregress(widths, ngs)
plt.plot(widths, ngs, "o")
plt.plot(widths, m_ng * widths + b_ng)
plt.grid(True)
plt.xlabel("width [μm]")
plt.ylabel("ng")
plt.title("ng @ λ=1.55μm")
plt.show()

```{note}
For models that are actually accurate, we should do this mode calculation with more precision, use more width points and probably fit a nonlinear model. However, for this tutorial we'll just roll with the above.
```

We can now create functions for our fits (valid at wl0=1.55μm, the wavelength used in the mode simulations above):

In [None]:
def neff0(width):
    return m_neff * width + b_neff


def ng0(width):
    return m_ng * width + b_ng

We can combine these two into a single neff function which is valid for a wider wavelength range:

In [None]:
def neff(wl, width):
    wl0 = 1.55  # wavelength at which neff0 and ng0 were simulated and fitted
    dwl = wl - wl0
    _neff0 = neff0(width)
    _ng0 = ng0(width)
    dneff_dwl = (_ng0 - _neff0) / wl0
    neff = _neff0 - dwl * dneff_dwl
    return neff
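
The slope used in this function follows from the definition of the group index. A short derivation, with $\lambda_0$ the wavelength at which `neff0` and `ng0` were fitted, and $n_{\mathrm{eff},0}$, $n_{g,0}$ their values there:

```{math}
n_g = n_\mathrm{eff} - \lambda \frac{d n_\mathrm{eff}}{d\lambda}
\quad\Rightarrow\quad
\left.\frac{d n_\mathrm{eff}}{d\lambda}\right|_{\lambda_0}
= -\frac{n_{g,0} - n_{\mathrm{eff},0}}{\lambda_0}
\quad\Rightarrow\quad
n_\mathrm{eff}(\lambda) \approx n_{\mathrm{eff},0} - (\lambda - \lambda_0)\,\frac{n_{g,0} - n_{\mathrm{eff},0}}{\lambda_0}
```

The last expression is the first-order Taylor expansion that `neff` evaluates. Note that `dneff_dwl` in the code stores the slope without its minus sign, which is why it is subtracted.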

We can plot this as well:

In [None]:
wls = jnp.linspace(1.5, 1.6, 100)
widths = jnp.linspace(0.4, 0.6, 3)
neffs = neff(wls[:, None], widths[None, :])
plt.plot(wls, neffs)
plt.legend([f"w={1e3*w:.0f}nm" for w in widths])
plt.grid(True)
plt.xlabel("λ [μm]")
plt.ylabel("neff")
plt.title("λ v neff (for 3 widths)")
plt.show()

We now have a model for our straight:

In [None]:
def straight(wl=1.55, width=0.5, length=10.0, loss=0.0):
    _neff = neff(wl, width)
    phase = 2 * jnp.pi * _neff * length / wl
    transmission = 10 ** (-loss * length / 20) * jnp.exp(1j * phase)
    sdict = sax.reciprocal(
        {
            ("in0", "out0"): transmission,
        }
    )
    return sdict

## Euler Bend

For this tutorial, let's assume the `bend_euler` is exactly the same as the straight, but with a non-zero loss.

In [None]:
def bend_euler(wl=1.55, width=0.5, length=10.0, loss=0.01):
    return straight(wl=wl, width=width, length=length, loss=loss)

## Heater

A heater is often characterized by its $P_\pi$: the power needed to have a $\pi$ phase shift. For the rest we'll assume the same model as the straight:

In [None]:
def heater(wl=1.55, width=0.5, length=100.0, loss=0.0, P=0.0, Ppi=50):
    wl0 = 1.5
    dn_dP = wl0 / (2 * length * Ppi)
    extra_phase = jnp.exp(2j * jnp.pi * length / wl * dn_dP * P)
    return {
        k: v * extra_phase
        for k, v in straight(wl=wl, width=width, length=length, loss=loss).items()
    }
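
With this choice of `dn_dP`, the extra phase equals $\pi \cdot (\lambda_0 / \lambda) \cdot (P / P_\pi)$, so it is exactly $\pi$ when $P = P_\pi$ at the reference wavelength `wl0`. The standalone sketch below (standard library only) confirms this. `heater_phase_factor` is a name invented for this check, and it copies only the phase arithmetic of `heater`, not the underlying `straight` model.

```python
import cmath
import math


def heater_phase_factor(wl=1.55, length=100.0, P=0.0, Ppi=50, wl0=1.5):
    """Extra phase factor of the heater, same arithmetic as in `heater` above."""
    dn_dP = wl0 / (2 * length * Ppi)
    return cmath.exp(2j * math.pi * length / wl * dn_dP * P)


at_reference = heater_phase_factor(wl=1.5, P=50)  # P == Ppi at wl == wl0: a pi shift
shift_at_1550 = cmath.phase(heater_phase_factor(wl=1.55, P=25))  # half of Ppi, in radians
```

Away from `wl0` the shift scales with `wl0 / wl`, so at 1.55 μm the same power gives slightly less phase.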

## MMI2x2

We can propose an MMI model that looks a bit like the following:

In [None]:
from gplugins.sax.models import _mmi_amp


def mmi2x2(wl=1.55) -> sax.SDict:
    wl0 = 1.55
    fwhm = 0.2
    loss_dB = 0.3
    shift = 0.005

    # Convert splitting ratios from power to amplitude by taking the square root
    amplitude_ratio_thru = 0.5**0.5
    amplitude_ratio_cross = 0.5**0.5

    loss_factor_thru = 10 ** (-loss_dB / 20)
    loss_factor_cross = 10 ** (-loss_dB / 20)

    thru = (
        _mmi_amp(wl=wl, wl0=wl0, fwhm=fwhm, loss_dB=loss_dB)
        * amplitude_ratio_thru
        * loss_factor_thru
    )
    cross = (
        1j
        * _mmi_amp(wl=wl, wl0=wl0 + shift, fwhm=fwhm, loss_dB=loss_dB)
        * amplitude_ratio_cross
        * loss_factor_cross
    )

    return sax.reciprocal(
        {
            ("o1", "o3"): thru,
            ("o1", "o4"): cross,
            ("o2", "o3"): cross,
            ("o2", "o4"): thru,
        }
    )

In [None]:
wls = jnp.linspace(1.5, 1.6)
sdict = mmi2x2(wl=wls)
plt.plot(wls, jnp.abs(sdict["o1", "o3"]) ** 2)
plt.plot(wls, jnp.abs(sdict["o2", "o3"]) ** 2)
plt.grid(True)
plt.show()

```{warning}
I once pulled this model out of my hat somewhere because I just wanted something that *kinda* looks like an MMI transfer curve. By now it has found its way into `gplugins` and other places. Please be careful which models you include in your libraries.
```

## Simulate: MZI with heater

In [None]:
c_mzi1

In [None]:
netlist = c_mzi1.get_netlist()

In [None]:
sax.get_required_circuit_models(netlist)