# SAX Quick Start

Let's go over the core functionality of SAX.

## Imports

In [None]:
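import tqdm

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt  # needed for the plots further down

# note: on recent JAX releases this module lives at jax.example_libraries.optimizers
import jax.experimental.optimizers as opt

# sax circuit simulator
import sax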

## Models

A model is defined by one function per S-matrix entry, i.e. per pair of ports. Each function takes a dictionary of parameters as its only argument. For example, for a directional coupler:

In [None]:
def model_directional_coupler_coupling(params):
    return 1j * params["coupling"] ** 0.5

def model_directional_coupler_transmission(params):
    return (1 - params["coupling"]) ** 0.5

These model functions are then combined into a dictionary that defines the full S-matrix of the directional coupler, whose ports are laid out as follows:

```
 p3          p2
   \        /
    ========
   /        \
 p0          p1
```

In [None]:
directional_coupler = {
    ("p0", "p1"): model_directional_coupler_transmission,
    ("p1", "p0"): model_directional_coupler_transmission,
    ("p2", "p3"): model_directional_coupler_transmission,
    ("p3", "p2"): model_directional_coupler_transmission,
    ("p0", "p2"): model_directional_coupler_coupling,
    ("p2", "p0"): model_directional_coupler_coupling,
    ("p1", "p3"): model_directional_coupler_coupling,
    ("p3", "p1"): model_directional_coupler_coupling,
    "default_params": {
        "coupling": 0.5
    },
}

Any S-matrix entry that is not defined (for example `("p0", "p3")`) is taken to be zero. Default parameters for the whole component can be given under the `"default_params"` key of the dictionary. Also note that ALL values in the parameter dictionary should be floats!
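To make the "undefined entries are zero" convention concrete, here is a minimal sketch that evaluates such a dictionary into a dense matrix using only the standard library. The helper `evaluate_smatrix` is hypothetical and not part of SAX; it only illustrates one way the dictionary can be read.

```python
def model_directional_coupler_coupling(params):
    return 1j * params["coupling"] ** 0.5

def model_directional_coupler_transmission(params):
    return (1 - params["coupling"]) ** 0.5

directional_coupler = {
    ("p0", "p1"): model_directional_coupler_transmission,
    ("p1", "p0"): model_directional_coupler_transmission,
    ("p2", "p3"): model_directional_coupler_transmission,
    ("p3", "p2"): model_directional_coupler_transmission,
    ("p0", "p2"): model_directional_coupler_coupling,
    ("p2", "p0"): model_directional_coupler_coupling,
    ("p1", "p3"): model_directional_coupler_coupling,
    ("p3", "p1"): model_directional_coupler_coupling,
    "default_params": {"coupling": 0.5},
}

def evaluate_smatrix(model, ports):
    """Hypothetical helper (not a SAX function): dense S-matrix as nested lists."""
    params = model["default_params"]
    zero = lambda params: 0.0  # stand-in for every undefined port pair
    return [
        [complex(model.get((p_a, p_b), zero)(params)) for p_b in ports]
        for p_a in ports
    ]

ports = ["p0", "p1", "p2", "p3"]
S = evaluate_smatrix(directional_coupler, ports)

# the undefined ("p0", "p3") entry evaluates to zero
print(S[0][3])

# total power associated with each port; close to 1 for this lossless coupler
power = [sum(abs(s) ** 2 for s in row) for row in S]
print(power)
```

Each row sums to (approximately) one because the transmission and coupling amplitudes satisfy `|t|**2 + |k|**2 == 1` for any `coupling` between 0 and 1.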

We can do the same for a waveguide:

```
 in -------- out
```

In [None]:
def model_waveguide_transmission(params):
    neff = params["neff"]
    dwl = params["wl"] - params["wl0"]
    dneff_dwl = (params["ng"] - params["neff"]) / params["wl0"]
    neff = neff - dwl * dneff_dwl
    phase = jnp.exp(
        jnp.log(2 * jnp.pi * neff * params["length"]) - jnp.log(params["wl"])
    )
    return 10 ** (-params["loss"] * params["length"] / 20) * jnp.exp(1j * phase)

waveguide = {
    ("in", "out"): model_waveguide_transmission,
    ("out", "in"): model_waveguide_transmission,
    "default_params": { # remember that ALL params should be floats!
        "length": 25e-6,
        "wl": 1.55e-6,
        "wl0": 1.55e-6,
        "neff": 2.34,
        "ng": 3.4,
        "loss": 0.0,
    },
}
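Written out, the waveguide model above computes the following (the `exp(log(a) - log(b))` construction in the code is simply `a / b`). The symbols are shorthand introduced here: $n_{\mathrm{eff},0}$ is `neff`, $n_g$ is `ng`, $\lambda$ is `wl`, $\lambda_0$ is `wl0`, $L$ is `length` and $\alpha$ is `loss`.

$$
n_{\mathrm{eff}}(\lambda) = n_{\mathrm{eff},0} - (\lambda - \lambda_0)\,\frac{n_g - n_{\mathrm{eff},0}}{\lambda_0},
\qquad
\varphi(\lambda) = \frac{2\pi\, n_{\mathrm{eff}}(\lambda)\, L}{\lambda},
\qquad
S_{\mathrm{in,out}} = 10^{-\alpha L / 20}\, e^{i\varphi(\lambda)}
$$

The amplitude factor $10^{-\alpha L/20}$ corresponds to a power attenuation of $\alpha L$ dB, so the formula implies that `loss` is expressed in dB per unit length.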

That's all you have to do to define a component! Also note that all ports of a component can be obtained with `sax.get_ports`:

In [None]:
sax.get_ports(directional_coupler)

And ports can be renamed with `sax.rename_ports`:

In [None]:
directional_coupler2 = sax.rename_ports(
    model=directional_coupler,
    ports={
        "p0": "in1", 
        "p1": "out1", 
        "p2": "out2", 
        "p3": "in2"
    }
)
directional_coupler2

Note that this NEVER changes anything in place. The original directional coupler dictionary is still intact:

In [None]:
directional_coupler
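The non-mutating behaviour can be illustrated with a small stand-alone sketch. `rename_ports_sketch` is a hypothetical re-implementation written for this example; it is not the SAX source and may differ from what `sax.rename_ports` actually does internally.

```python
def rename_ports_sketch(model, ports):
    """Hypothetical illustration: return a renamed copy, leave `model` untouched."""
    renamed = {}
    for key, value in model.items():
        if isinstance(key, tuple):
            # S-matrix entry: translate both port names, keep unknown names as-is
            key = tuple(ports.get(p, p) for p in key)
        # note: values (including "default_params") are shared, not deep-copied
        renamed[key] = value
    return renamed

original = {
    ("p0", "p1"): lambda params: 1.0,
    ("p1", "p0"): lambda params: 1.0,
    "default_params": {},
}
renamed = rename_ports_sketch(original, {"p0": "in", "p1": "out"})

print(sorted(k for k in renamed if isinstance(k, tuple)))
print(sorted(k for k in original if isinstance(k, tuple)))
```

Building a new dictionary instead of editing keys in the existing one is what keeps the original model usable elsewhere in a circuit.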

## Circuits

Circuits can be created with `sax.circuit`. This function takes three required arguments: `models`, `connections` and `ports`, each of which is a dictionary. The `models` dictionary lists the individual models and the name each one gets in the circuit. Note that a circuit is itself also a model, which allows you to define hierarchical circuits. The `connections` dictionary describes the connections between individual model ports, where a model port is written as `"{modelname}:{portname}"`. Finally, the `ports` dictionary maps each remaining unconnected port, in the same `"{modelname}:{portname}"` format, onto a single `"{portname}"` of the circuit.

```
                              top
                          in ----- out
    in2 <- p3         p2                 p3         p2 -> out2
             \  dc1  /                     \  dc2  /
              =======                       =======
             /       \                     /       \
    in1 <- p0         p1      btm       p0          p1 -> out1
                          in ----- out
```

In [None]:
mzi = sax.circuit(
    models = {
        "dc1": directional_coupler,
        "top": waveguide,
        "dc2": directional_coupler,
        "btm": waveguide,
    },
    connections={
        "dc1:p2": "top:in",
        "dc1:p1": "btm:in",
        "top:out": "dc2:p3",
        "btm:out": "dc2:p0",
    },
    ports={
        "dc1:p3": "in2",
        "dc1:p0": "in1",
        "dc2:p2": "out2",
        "dc2:p1": "out1",
    },
)

As you can see, the `mzi` circuit is just a dictionary of individual functions as well:

In [None]:
mzi

As with the individual components, it is only defined for nonzero connections!

It also has default parameters for each of its subcomponents:

In [None]:
params = mzi["default_params"]
params
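The port bookkeeping of the MZI can be sanity-checked with a few lines of plain Python. This is a sketch under the assumption that every port should appear exactly once, either in `connections` or in `ports`, which is how the MZI above is wired; whether SAX enforces this is not stated here. The helper `check_port_usage` and the `model_ports` listing are hypothetical.

```python
from collections import Counter

# port names per model instance, written out by hand for this sketch
model_ports = {
    "dc1": ["p0", "p1", "p2", "p3"],
    "top": ["in", "out"],
    "dc2": ["p0", "p1", "p2", "p3"],
    "btm": ["in", "out"],
}
connections = {
    "dc1:p2": "top:in",
    "dc1:p1": "btm:in",
    "top:out": "dc2:p3",
    "btm:out": "dc2:p0",
}
ports = {
    "dc1:p3": "in2",
    "dc1:p0": "in1",
    "dc2:p2": "out2",
    "dc2:p1": "out1",
}

def check_port_usage(model_ports, connections, ports):
    """Hypothetical helper: report ports used zero times, more than once, or unknown."""
    all_ports = [f"{name}:{port}" for name, plist in model_ports.items() for port in plist]
    counts = Counter(list(connections.keys()) + list(connections.values()) + list(ports.keys()))
    unused = [p for p in all_ports if counts[p] == 0]
    overused = [p for p in all_ports if counts[p] > 1]
    unknown = [p for p in counts if p not in all_ports]
    return unused, overused, unknown

unused, overused, unknown = check_port_usage(model_ports, connections, ports)
print(unused, overused, unknown)
```

For the MZI all three lists come back empty: 8 of the 12 ports are consumed by the 4 connections and the other 4 are exposed as circuit ports.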

## Simulating the MZI

To simulate the MZI, we first need to update the parameters. We copy the params dictionary, after which the copy can be updated in place:

In [None]:
params = sax.copy_params(params)
params["btm"]["length"] = 1.5e-5 # make the bottom length shorter

Moreover, we want to simulate over a range of wavelengths. To set the wavelength globally for all subcomponents of the circuit, we use `sax.set_global_params`:

In [None]:
params = sax.set_global_params(params, wl=1e-6*jnp.linspace(1.51, 1.59, 500))

This sets the wavelength `wl` parameter for all subcomponents in the circuit.
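Conceptually, setting a global parameter means walking the nested parameter dictionary and overwriting the value in every subcomponent. A minimal sketch of that idea follows; `set_global_sketch` is a hypothetical helper, and the choice to overwrite only where the key already exists is an assumption of this example, not a statement about `sax.set_global_params`.

```python
import copy

def set_global_sketch(params, **kwargs):
    """Hypothetical illustration: return a copy with `kwargs` applied at every nesting level."""
    params = copy.deepcopy(params)  # never touch the caller's dictionary

    def _set(d):
        for value in d.values():
            if isinstance(value, dict):
                _set(value)  # recurse into subcomponents
        for name, new_value in kwargs.items():
            if name in d:  # assumption: only overwrite existing parameters
                d[name] = new_value

    _set(params)
    return params

nested = {
    "dc1": {"coupling": 0.5},
    "top": {"wl": 1.55e-6, "length": 25e-6},
    "btm": {"wl": 1.55e-6, "length": 25e-6},
}
updated = set_global_sketch(nested, wl=1.5e-6)
print(updated["top"]["wl"], updated["btm"]["wl"])
print(nested["top"]["wl"])  # the input dictionary is unchanged
```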

Assume we're interested in simulating the `in1 -> out1` transmission. In this case our function of interest is given by the following:

In [None]:
mzi_in1_out1 = mzi["in1","out1"]

We can just-in-time (jit) compile this function for better performance:

In [None]:
mzi_in1_out1 = jax.jit(mzi["in1", "out1"])

The first time you simulate, the function will be jitted and the simulation will be a bit slower:

In [None]:
%time detected = mzi_in1_out1(params)

The second time, the simulation is much faster:

In [None]:
%time detected = mzi_in1_out1(params)

Even if you change the parameters:

In [None]:
params = sax.set_global_params(params, wl=1e-6*jnp.linspace(1.5, 1.6, 500))
%time detected = mzi_in1_out1(params)

**Unless the shape of one of the parameters changes**, in which case the model needs to be jit-compiled again:

In [None]:
params = sax.set_global_params(params, wl=1e-6*jnp.linspace(1.5, 1.6, 1000))
%time detected = mzi_in1_out1(params)

Luckily, now both shapes yield fast computations (we don't lose the old jit-compiled model):

In [None]:
params = sax.set_global_params(params, wl=1e-6*jnp.linspace(1.5, 1.6, 500))
%time detected = mzi_in1_out1(params)
params = sax.set_global_params(params, wl=1e-6*jnp.linspace(1.5, 1.6, 1000))
%time detected = mzi_in1_out1(params)

Anyway, let's see what this gives:

In [None]:
plt.plot(1e9 * params["top"]["wl"], abs(detected)**2)  # convert m to nm to match the axis label
plt.xlabel("λ [nm]")
plt.ylabel("T")
plt.ylim(-0.05, 1.05)
plt.show()
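The simulated spectrum can be cross-checked against a closed form. For this MZI there are only two paths from `in1` to `out1`: straight through both couplers via the bottom arm, or crossing in both couplers via the top arm. The sketch below sums those two paths with the standard library only, assuming the default parameters used above (top arm 25 µm, bottom arm 15 µm, no loss). The function names are introduced for this example.

```python
import cmath
import math

def waveguide(wl, length, wl0=1.55e-6, neff=2.34, ng=3.4, loss=0.0):
    # same formula as model_waveguide_transmission, for a scalar wavelength
    neff_wl = neff - (wl - wl0) * (ng - neff) / wl0
    phase = 2 * math.pi * neff_wl * length / wl
    return 10 ** (-loss * length / 20) * cmath.exp(1j * phase)

def mzi_in1_out1_closed_form(wl, top_length=25e-6, btm_length=15e-6, coupling=0.5):
    t = (1 - coupling) ** 0.5   # coupler transmission amplitude
    k = 1j * coupling ** 0.5    # coupler coupling amplitude
    bar = t * waveguide(wl, btm_length) * t    # dc1 p0->p1, btm, dc2 p0->p1
    cross = k * waveguide(wl, top_length) * k  # dc1 p0->p2, top, dc2 p3->p1
    return bar + cross

# 1000 wavelengths between 1.5 µm and 1.6 µm
wavelengths = [1e-6 * (1.5 + 0.1 * i / 999) for i in range(1000)]
transmission = [abs(mzi_in1_out1_closed_form(wl)) ** 2 for wl in wavelengths]
print(min(transmission), max(transmission))
```

For 50/50 couplers and lossless arms this reduces to `sin(dphi / 2) ** 2`, where `dphi` is the phase difference between the two arms, so the transmission stays between 0 and 1.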

## Optimization

We'd like to optimize an MZI such that one of the minima is at 1550nm. To do this, we need to define a loss function for the circuit at 1550nm. This function should take the parameters that you want to optimize as positional arguments:

In [None]:
@jax.jit
def loss(delta_length):
    params = sax.set_global_params(mzi["default_params"], wl=1.55e-6)
    params["top"]["length"] = 1.5e-6 + delta_length
    params["btm"]["length"] = 1.5e-6
    detected = mzi["in1", "out1"](params)
    return (abs(detected)**2).mean()

In [None]:
%time loss(10e-6)

We can use this loss function to define a grad function which works on the parameters of the loss function:

In [None]:
grad = jax.jit(jax.grad(loss))

In [None]:
%time grad(10e-6)
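As a plausibility check on what `jax.grad` returns, the loss can be written in closed form for this particular circuit. Assuming the default parameters (50/50 couplers, no loss, `wl == wl0` so that `neff = 2.34`), the `in1 -> out1` power is `sin(pi * neff * delta_length / wl) ** 2`. The sketch below compares the analytic derivative of that expression with a central finite difference; all names are introduced for this example.

```python
import math

WL = 1.55e-6   # wavelength used in the loss
NEFF = 2.34    # default effective index at wl0

def loss_closed_form(delta_length):
    return math.sin(math.pi * NEFF * delta_length / WL) ** 2

def grad_closed_form(delta_length):
    # d/dx sin(k x)^2 = k * sin(2 k x), with k = pi * neff / wl
    k = math.pi * NEFF / WL
    return k * math.sin(2 * k * delta_length)

def grad_finite_difference(f, x, h=1e-11):
    # central difference; h is small compared to the wavelength scale
    return (f(x + h) - f(x - h)) / (2 * h)

analytic = grad_closed_form(10e-6)
numeric = grad_finite_difference(loss_closed_form, 10e-6)
print(analytic, numeric)
```

The gradient is of order `pi * neff / wl`, i.e. millions per meter, which is why a very small optimizer step size is needed when optimizing a length expressed in meters.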

Next, we need to define a JAX optimizer, which is nothing more than three functions: an initialization function that creates the optimizer state, an update function that updates the optimizer state (and with it the model parameters), and a third function that returns the model parameters for a given optimizer state.

In [None]:
initial_delta_length = 10e-6
optim_init, optim_update, optim_params = opt.adam(step_size=1e-7)
optim_state = optim_init(initial_delta_length)
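The three-function pattern is easy to reproduce by hand. The sketch below implements it for plain gradient descent rather than Adam and applies it to a toy quadratic; `sgd_sketch` and `toy_grad` are hypothetical names, and the state layout is an assumption of this example rather than what the JAX optimizers use internally.

```python
def sgd_sketch(step_size):
    """Hypothetical optimizer triple for plain gradient descent (not Adam)."""

    def init(initial_params):
        # the state only needs to hold the current parameters
        return {"params": initial_params}

    def update(step, grads, state):
        # `step` is unused here; schedule-based optimizers would use it
        return {"params": state["params"] - step_size * grads}

    def get_params(state):
        return state["params"]

    return init, update, get_params

def toy_grad(x):
    # gradient of the toy loss (x - 3)^2
    return 2 * (x - 3)

toy_init, toy_update, toy_params = sgd_sketch(step_size=0.1)
toy_state = toy_init(0.0)
for step in range(100):
    toy_state = toy_update(step, toy_grad(toy_params(toy_state)), toy_state)

print(toy_params(toy_state))
```

Keeping the state separate from the functions means the update can be a pure function of its inputs, which is what allows a whole training step to be wrapped in `jax.jit`.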

Given all this, a single training step can be defined:

In [None]:
@jax.jit
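def train_step(step, optim_state):
    # current value of the optimized parameter (here: delta_length)
    params = optim_params(optim_state)
    lossvalue = loss(params)
    gradvalue = grad(params)
    # advance the optimizer state by one step
    optim_state = optim_update(step, gradvalue, optim_state)
    return lossvalue, optim_state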

And we can use this step function to start the training of the MZI:

In [None]:
range_ = tqdm.trange(1000)
for step in range_:
    lossvalue, optim_state = train_step(step, optim_state)
    range_.set_postfix(loss=f"{lossvalue:.6f}")

In [None]:
delta_length = optim_params(optim_state)
delta_length

Let's see what we've got over a range of wavelengths:

In [None]:
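params = sax.set_global_params(mzi["default_params"], wl=1e-6*jnp.linspace(1.5, 1.6, 1000))
# the base length differs from the 1.5e-6 used in the loss; for these identical,
# lossless arms only the length difference affects the transmitted power
params["top"]["length"] = 1.5e-5 + delta_length
params["btm"]["length"] = 1.5e-5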
detected = mzi["in1", "out1"](params)
plt.plot(params["top"]["wl"]*1e9, abs(detected)**2)
plt.xlabel("λ [nm]")
plt.ylabel("T")
plt.ylim(-0.05, 1.05)
plt.plot([1550, 1550], [0,1])
plt.show()

One of the transmission minima of the MZI now sits at 1550nm.
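The optimized value can be compared with the closed-form position of the minima. Under the same assumptions as the loss function (default `neff = 2.34` at 1550nm, 50/50 couplers, no loss), the transmission vanishes whenever `delta_length` is an integer multiple of `wl / neff`. The sketch below picks the multiple closest to the starting value of 10 µm; whether the optimizer ends up in exactly this minimum depends on the step size and number of steps.

```python
import math

WL = 1.55e-6
NEFF = 2.34
initial_delta_length = 10e-6

# interference order closest to the starting point
order = round(NEFF * initial_delta_length / WL)
delta_length_min = order * WL / NEFF

# transmitted power at that length difference (should be numerically zero)
residual = math.sin(math.pi * NEFF * delta_length_min / WL) ** 2

print(order, delta_length_min, residual)
```

With these numbers the nearest minimum lies at roughly 9.94 µm, slightly below the starting value.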

## MZI Chain

Let's now create a chain of MZIs. For this, we first create a subcomponent: a directional coupler with arms:


```
                             top
                         in ----- out -> out2
    in2 <- p3        p2                 
             \  dc  /                  
              ======                  
             /      \                
    in1 <- p0        p1      btm    
                         in ----- out -> out1
```

In [None]:
directional_coupler_with_arms = sax.circuit(
    models = {
        "dc": sax.models.directional_coupler,
        "top": sax.models.waveguide,
        "btm": sax.models.waveguide,
    },
    connections={
        "dc:p2": "top:in",
        "dc:p1": "btm:in",
    },
    ports={
        "dc:p3": "in2",
        "dc:p0": "in1",
        "top:out": "out2",
        "btm:out": "out1",
    },
)

An MZI chain can now be created by cascading these directional couplers with arms:

```
      _    _    _    _             _    _  
    \/   \/   \/   \/     ...    \/   \/   
    /\_  /\_  /\_  /\_           /\_  /\_  
```

In [None]:
def mzi_chain(num_mzis=1):
    chain = sax.circuit(
        models = {f"dc{i}": directional_coupler_with_arms for i in range(num_mzis+1)},
        connections = {
            **{f"dc{i}:out1":f"dc{i+1}:in1" for i in range(num_mzis)},
            **{f"dc{i}:out2":f"dc{i+1}:in2" for i in range(num_mzis)},
        },
        ports = {
            "dc0:in1": "in1",
            "dc0:in2": "in2",
            f"dc{num_mzis}:out1": "out1",
            f"dc{num_mzis}:out2": "out2",
        },
    )
    return chain
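To see what the comprehensions in `mzi_chain` generate, the wiring can be built and printed without calling `sax.circuit`. The helper `chain_wiring` is introduced only for this illustration; it mirrors the dictionaries above and returns them instead of building a circuit.

```python
def chain_wiring(num_mzis):
    """Hypothetical helper: the model names, connections and ports used by mzi_chain."""
    # a chain of N MZIs needs N + 1 couplers-with-arms
    models = [f"dc{i}" for i in range(num_mzis + 1)]
    connections = {
        **{f"dc{i}:out1": f"dc{i+1}:in1" for i in range(num_mzis)},
        **{f"dc{i}:out2": f"dc{i+1}:in2" for i in range(num_mzis)},
    }
    ports = {
        "dc0:in1": "in1",
        "dc0:in2": "in2",
        f"dc{num_mzis}:out1": "out1",
        f"dc{num_mzis}:out2": "out2",
    }
    return models, connections, ports

models, connections, ports = chain_wiring(num_mzis=2)
print(models)
for source, target in connections.items():
    print(source, "->", target)
print(ports)
```

Each additional MZI adds one model and two connections, while the number of external ports stays at four.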

For example, let's create a chain with 15 MZIs:

In [None]:
chain = mzi_chain(num_mzis=15)
params = sax.copy_params(chain["default_params"])
for dc in params:
    params[dc]["btm"]["length"] = 1.5e-5
params = sax.set_global_params(params, wl=1e-6*jnp.linspace(1.5, 1.6, 1000))

We can simulate this again:

In [None]:
%time detected = chain["in1", "out1"](params)

This takes a few seconds to simulate, so maybe it's worth jitting:

In [None]:
chain_in1_out1 = jax.jit(chain["in1", "out1"])

In [None]:
%time detected = chain_in1_out1(params)

Jit-compiling the function took even longer! However, once compiled, the simulation of the MZI chain becomes really fast:

In [None]:
%time detected = chain_in1_out1(params)

Anyway, let's see what this gives:

In [None]:
plt.plot(1e9*params["dc0"]["top"]["wl"], abs(detected)**2)
plt.xlabel("λ [nm]")
plt.ylabel("T")
plt.ylim(-0.05, 1.05)
plt.show()