# SAX circuit simulator

[SAX](https://flaport.github.io/sax/) is a circuit solver written in JAX. Writing your component models in SAX gives you not only the function values but also their gradients, which is useful for circuit optimization.

This tutorial has been adapted from the SAX tutorial.

Note that SAX does not work on Windows, so if you use Windows you'll need to run it from [WSL](https://docs.microsoft.com/en-us/windows/wsl/) or Docker.

You can install SAX with pip:

```
! pip install sax
```

In [16]:
import gdsfactory as gf
import jax.numpy as jnp
import sax
import gdsfactory.simulation.sax as gs
import gdsfactory.simulation.modes as gm

## Scatter *dictionaries*

The core data structure for specifying scatter parameters in SAX is a dictionary, more specifically a dictionary that maps a port combination (2-tuple) to a scatter parameter (or to an array of scatter parameters, for example when considering multiple wavelengths). Such a dictionary is called an `SDict` in SAX (`SDict ≈ Dict[Tuple[str,str], float]`).

Dictionaries are in fact much better suited for characterizing S-parameters than, say, (jax-)numpy arrays, due to the inherently sparse nature of scatter parameters. Moreover, dictionaries allow for string indexing, which makes them much more pleasant to use in this context.

```
o2            o3
   \        /
    ========
   /        \
o1            o4
```

In [9]:
coupling = 0.5
kappa = coupling ** 0.5
tau = (1 - coupling) ** 0.5
coupler_dict = {
    ("o1", "o4"): tau,
    ("o4", "o1"): tau,
    ("o1", "o3"): 1j * kappa,
    ("o3", "o1"): 1j * kappa,
    ("o2", "o4"): 1j * kappa,
    ("o4", "o2"): 1j * kappa,
    ("o2", "o3"): tau,
    ("o3", "o2"): tau,
}
coupler_dict

{('o1', 'o4'): 0.7071067811865476,
 ('o4', 'o1'): 0.7071067811865476,
 ('o1', 'o3'): 0.7071067811865476j,
 ('o3', 'o1'): 0.7071067811865476j,
 ('o2', 'o4'): 0.7071067811865476j,
 ('o4', 'o2'): 0.7071067811865476j,
 ('o2', 'o3'): 0.7071067811865476,
 ('o3', 'o2'): 0.7071067811865476}
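To make the sparsity argument concrete, the sketch below expands an `SDict` into a dense S-matrix with NumPy; every port pair missing from the dictionary becomes an explicit zero. `sdict_to_dense` is a hypothetical helper written for illustration, not a SAX function, and its row/column convention is an arbitrary choice.

```python
import numpy as np


def sdict_to_dense(sdict):
    """Hypothetical helper: expand an SDict into (ports, dense S-matrix)."""
    # Collect every port name that appears in any key, in sorted order.
    ports = sorted({port for pair in sdict for port in pair})
    index = {port: i for i, port in enumerate(ports)}
    dense = np.zeros((len(ports), len(ports)), dtype=complex)
    for (port_in, port_out), value in sdict.items():
        # Row = input port, column = output port (one possible convention).
        dense[index[port_in], index[port_out]] = value
    return ports, dense


# A 50/50 coupler SDict with the same entries as `coupler_dict` above.
t50 = 0.5 ** 0.5
example_sdict = {
    ("o1", "o4"): t50, ("o4", "o1"): t50,
    ("o1", "o3"): 1j * t50, ("o3", "o1"): 1j * t50,
    ("o2", "o4"): 1j * t50, ("o4", "o2"): 1j * t50,
    ("o2", "o3"): t50, ("o3", "o2"): t50,
}

ports, dense = sdict_to_dense(example_sdict)
print(ports)  # ['o1', 'o2', 'o3', 'o4']
print(np.count_nonzero(dense), "of", dense.size, "entries are non-zero")  # 8 of 16
```

Even for this 4-port device, half of the dense entries are zeros that the dictionary never has to store.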

Even so, it can be tedious to specify every port combination manually. SAX therefore offers the `reciprocal` function, which auto-fills the reverse connection if the forward connection exists. For example:

In [10]:
import sax
coupler_dict = sax.reciprocal(
    {
        ("o1", "o4"): tau,
        ("o1", "o3"): 1j * kappa,
        ("o2", "o4"): 1j * kappa,
        ("o2", "o3"): tau,
    }
)

coupler_dict

{('o1', 'o4'): 0.7071067811865476,
 ('o1', 'o3'): 0.7071067811865476j,
 ('o2', 'o4'): 0.7071067811865476j,
 ('o2', 'o3'): 0.7071067811865476,
 ('o4', 'o1'): 0.7071067811865476,
 ('o3', 'o1'): 0.7071067811865476j,
 ('o4', 'o2'): 0.7071067811865476j,
 ('o3', 'o2'): 0.7071067811865476}

## Parametrized Models

Constructing such an `SDict` is easy; however, we are usually more interested in parametrized models of our components. To parametrize the coupler `SDict`, wrap it in a function to obtain a SAX `Model`, that is, a function that takes keyword arguments and returns an `SDict`:

In [11]:
def coupler(coupling=0.5) -> sax.SDict:
    kappa = coupling ** 0.5
    tau = (1 - coupling) ** 0.5
    coupler_dict = sax.reciprocal(
        {
            ("o1", "o4"): tau,
            ("o1", "o3"): 1j * kappa,
            ("o2", "o4"): 1j * kappa,
            ("o2", "o3"): tau,
        }
    )
    return coupler_dict


coupler(coupling=0.3)

{('o1', 'o4'): 0.8366600265340756,
 ('o1', 'o3'): 0.5477225575051661j,
 ('o2', 'o4'): 0.5477225575051661j,
 ('o2', 'o3'): 0.8366600265340756,
 ('o4', 'o1'): 0.8366600265340756,
 ('o3', 'o1'): 0.5477225575051661j,
 ('o4', 'o2'): 0.5477225575051661j,
 ('o3', 'o2'): 0.8366600265340756}
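A quick sanity check for a lossless coupler model is power conservation: for any `coupling`, the through and cross powers leaving `o1` should add up to one. The sketch below re-implements the same formulas in plain Python, filling in the reverse direction by hand instead of calling `sax.reciprocal`, so it runs on its own; `coupler_plain` is a hypothetical name.

```python
def coupler_plain(coupling=0.5):
    """Same formulas as the `coupler` model above, written without SAX."""
    kappa = coupling ** 0.5
    tau = (1 - coupling) ** 0.5
    forward = {
        ("o1", "o4"): tau,
        ("o1", "o3"): 1j * kappa,
        ("o2", "o4"): 1j * kappa,
        ("o2", "o3"): tau,
    }
    # Add the reverse direction, which is what sax.reciprocal fills in.
    reverse = {(p_out, p_in): v for (p_in, p_out), v in forward.items()}
    return {**forward, **reverse}


for coupling_value in (0.0, 0.3, 0.5, 0.9):
    sd_plain = coupler_plain(coupling=coupling_value)
    through = abs(sd_plain["o1", "o4"]) ** 2
    cross = abs(sd_plain["o1", "o3"]) ** 2
    print(coupling_value, round(through + cross, 12))  # total power is 1.0
```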

In [12]:
def waveguide(wl=1.55, wl0=1.55, neff=2.34, ng=3.4, length=10.0, loss=0.0) -> sax.SDict:
    dwl = wl - wl0
    dneff_dwl = (ng - neff) / wl0
    neff = neff - dwl * dneff_dwl
    phase = 2 * jnp.pi * neff * length / wl
    transmission = 10 ** (-loss * length / 20) * jnp.exp(1j * phase)
    sdict = sax.reciprocal(
        {
            ("o1", "o2"): transmission,
        }
    )
    return sdict

## Component Models

### Waveguide model

You can create a dispersive waveguide model in SAX.
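The `waveguide` model defined earlier makes the effective index dispersive with a first-order expansion around `wl0`, choosing the slope so that the group index n_g = n_eff - λ·dn_eff/dλ equals `ng` at `wl0`. The standalone NumPy sketch below re-implements only that expression (not the gdsfactory `straight` model, whose internals may differ) and checks the property with a central finite difference.

```python
import numpy as np


def neff_dispersive(wl, wl0=1.55, neff0=2.34, ng=3.4):
    """First-order effective index, the same expression as in `waveguide` above."""
    slope = (ng - neff0) / wl0  # equals -dneff/dwl
    return neff0 - (wl - wl0) * slope


wl0 = 1.55  # um
h = 1e-4  # finite-difference step, um
dneff_dwl = (neff_dispersive(wl0 + h) - neff_dispersive(wl0 - h)) / (2 * h)
ng_recovered = neff_dispersive(wl0) - wl0 * dneff_dwl
print(ng_recovered)  # close to 3.4
```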

Let's compute the effective index `neff` and group index `ng` for a 500 nm wide straight waveguide at 1550 nm:

In [13]:
m = gm.find_mode_dispersion(wavelength=1.55)
print(m.neff, m.ng)

In [14]:
straight_sc = gf.partial(gs.models.straight, neff=m.neff, ng=m.ng)

In [None]:
gs.plot_model(straight_sc)

In [None]:
gs.plot_model(straight_sc, phase=True)

### Coupler model

In [None]:
gm.find_coupling_vs_gap?

In [None]:
df = gm.find_coupling_vs_gap()
df

For a 200 nm gap, the effective index difference `dn` is about `0.02`, which means that 100% of the power couples over about 38.2 µm.

In [None]:
coupler_sc = gf.partial(gs.models.coupler, dn=0.02, length=0, coupling0=0)
gs.plot_model(coupler_sc)

If we ignore the coupling from the bend region (`coupling0 = 0`), a 3 dB coupler needs half of the coupling length `lc`, which is the length needed to couple 100% of the power.

In [None]:
coupler_sc = gf.partial(gs.models.coupler, dn=0.02, length=38.2 / 2, coupling0=0)
gs.plot_model(coupler_sc)
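The coupling lengths above follow from the usual supermode-beating picture: with an index difference Δn between the even and odd modes, the cross-coupled power grows as K(L) = sin²(π·Δn·L/λ), so full transfer happens at L_c = λ/(2·Δn) and a 3 dB split at L_c/2. The NumPy sketch below encodes only that textbook relation, ignoring bend coupling; it is an assumption that `gs.models.coupler` reduces to exactly this expression.

```python
import numpy as np


def coupled_power(length, dn, wl=1.55):
    """Cross-coupled power K(L) = sin^2(pi * dn * L / wl), ignoring bend coupling."""
    return np.sin(np.pi * dn * length / wl) ** 2


wl_um = 1.55
dn_demo = 0.02
lc_demo = wl_um / (2 * dn_demo)  # length for 100% power transfer, um
print(f"lc = {lc_demo:.2f} um")  # lc = 38.75 um
print(coupled_power(lc_demo, dn_demo), coupled_power(lc_demo / 2, dn_demo))  # ~1.0 ~0.5
```

With Δn = 0.02 exactly this gives about 38.75 µm; the 38.2 µm quoted above corresponds to Δn ≈ 0.0203, so under this simple model the rounded `dn=0.02` gives a split close to, but not exactly, 50/50.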

### FDTD S-parameters model

You can also build a model from FDTD S-parameter simulation data.

In [None]:
from gdsfactory.simulation.get_sparameters_path import get_sparameters_path_lumerical

filepath = get_sparameters_path_lumerical(gf.c.mmi1x2)
mmi1x2 = gf.partial(gs.read.sdict_from_csv, filepath=filepath)
gs.plot_model(mmi1x2)
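Building a model from FDTD data amounts to sampling the S-parameters on the simulated wavelength grid and interpolating them wherever the model is evaluated. The NumPy sketch below shows one way to do this, interpolating magnitude and unwrapped phase separately; the data is synthetic, `interpolate_sparam` is a hypothetical helper, and this is not a description of how `gs.read.sdict_from_csv` works internally. The `xunits=1e-3` argument mirrors the nm-to-µm conversion passed to `sdict_from_csv` later in this notebook.

```python
import numpy as np

# Synthetic stand-in for FDTD samples: a wavelength column in nm (as in a CSV)
# and one complex S-parameter with constant magnitude and linear phase.
wavelength_nm = np.linspace(1500, 1600, 11)
s21_samples = 0.7 * np.exp(-1j * 2 * np.pi * (wavelength_nm - 1500) / 25)


def interpolate_sparam(wl_um, x_nm, s, xunits=1e-3):
    """Interpolate a sampled complex S-parameter at wavelengths `wl_um` (um)."""
    x_um = x_nm * xunits  # convert the sample grid from nm to um
    magnitude = np.interp(wl_um, x_um, np.abs(s))
    # Unwrap the phase first so that interpolation does not jump across +-pi.
    phase = np.interp(wl_um, x_um, np.unwrap(np.angle(s)))
    return magnitude * np.exp(1j * phase)


wl_query = np.linspace(1.5, 1.6, 101)
s21 = interpolate_sparam(wl_query, wavelength_nm, s21_samples)
```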

## Circuit Models

You can combine component models into a circuit using `sax.circuit`, which basically creates a new `Model` function:

Let's define a [Mach-Zehnder interferometer (MZI)](https://en.wikipedia.org/wiki/Mach%E2%80%93Zehnder_interferometer):

```
           _________
          |  top    |
          |         |
    lft===|         |===rgt
          |         |
          |_________|
             bot

               o1    top   o2
                 ----------
o2            o3           o2            o3
   \        /                 \        /
    ========                   ========
   /        \                 /        \
o1     lft    o4           o1    rgt     o4
                 ----------
               o1   bot    o2
```

In [None]:
waveguide = straight_sc
coupler = coupler_sc

mzi = sax.circuit(
    instances={
        "lft": coupler,
        "top": waveguide,
        "bot": waveguide,
        "rgt": coupler,
    },
    connections={
        "lft,o4": "bot,o1",
        "bot,o2": "rgt,o1",
        "lft,o3": "top,o1",
        "top,o2": "rgt,o2",
    },
    ports={
        "o1": "lft,o1",
        "o2": "lft,o2",
        "o4": "rgt,o4",
        "o3": "rgt,o3",
    },
)

The `circuit` function creates a function similar to the ones we created for the waveguide and the coupler, but instead of taking parameters directly, it takes parameter *dictionaries* for each of the instances in the circuit. The keys in these parameter dictionaries should correspond to the keyword arguments of each individual subcomponent.
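To see what `sax.circuit` computes for this netlist, the MZI can also be composed by hand. The NumPy sketch below assumes ideal 50/50 couplers with the same convention as the `coupler` model above and lossless arms with phases `phi_top` and `phi_bot`. Summing the two paths from `lft,o1` gives S(o1→o4) = (e^{i·phi_bot} - e^{i·phi_top})/2, so the `o4` power is sin²(Δφ/2). It illustrates path summation only, not SAX's general solver; `mzi_paths` is a hypothetical name.

```python
import numpy as np


def mzi_paths(phi_top, phi_bot, coupling=0.5):
    """Sum the two optical paths from lft,o1 to rgt,o3 and rgt,o4 (ideal couplers)."""
    kappa = coupling ** 0.5
    tau = (1 - coupling) ** 0.5
    top = np.exp(1j * phi_top)  # lossless arm transmissions
    bot = np.exp(1j * phi_bot)
    # lft: o1->o4 (tau) feeds bot, o1->o3 (i*kappa) feeds top.
    # rgt: bot enters o1, top enters o2; o1->o4 = tau, o2->o4 = i*kappa, etc.
    s_o4 = tau * bot * tau + 1j * kappa * top * 1j * kappa
    s_o3 = tau * bot * 1j * kappa + 1j * kappa * top * tau
    return s_o3, s_o4


dphi = np.linspace(0, 2 * np.pi, 9)
s_o3, s_o4 = mzi_paths(phi_top=dphi, phi_bot=0.0)
print(np.round(np.abs(s_o4) ** 2, 3))  # follows sin^2(dphi / 2)
```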

Let's now do a simulation for the MZI we just constructed:

In [None]:
%time mzi()

In [None]:
import jax
import jax.example_libraries.optimizers as opt
import jax.numpy as jnp
import matplotlib.pyplot as plt  # plotting

mzi2 = jax.jit(mzi)

In [None]:
%time mzi2()

In [None]:
mzi(top={"length": 25.0}, bot={"length": 15.0})

In [None]:
wl = jnp.linspace(1.51, 1.59, 1000)
%time S = mzi(wl=wl, top={"length": 25.0}, bot={"length": 15.0})

In [None]:
plt.plot(wl * 1e3, abs(S["o1", "o3"]) ** 2, label="o3")
plt.plot(wl * 1e3, abs(S["o1", "o4"]) ** 2, label="o4")
plt.xlabel("λ [nm]")
plt.ylabel("T")
plt.ylim(-0.05, 1.05)
plt.legend()
plt.show()

## Optimization

You can optimize an MZI to get T = 0 at port `o4` at 1550 nm.
To do this, you need to define a loss function for the circuit at 1550 nm.
This function should take the parameters that you want to optimize as positional arguments:

In [None]:
@jax.jit
def loss(delta_length):
    S = mzi(wl=1.55, top={"length": 15.0 + delta_length}, bot={"length": 15.0})
    return (abs(S["o1", "o4"]) ** 2).mean()

In [None]:
%time loss(10.0)

You can use this loss function to define a grad function which works on the parameters of the loss function:

In [None]:
grad = jax.jit(
    jax.grad(
        loss,
        argnums=0,  # JAX gradient function for the first positional argument, jitted
    )
)
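The same loss-and-gradient pattern can be tried on a closed-form model without SAX. The sketch below assumes ideal 50/50 couplers and a constant effective index (`NEFF = 2.34`, the default `neff` of the `waveguide` model above, with no dispersion), so the `o4` power is sin²(π·neff·ΔL/λ). It uses `jax.grad` with a plain gradient-descent update rather than the Adam optimizer used in this notebook; the function names are hypothetical.

```python
import jax
import jax.numpy as jnp

NEFF = 2.34  # assumed constant effective index (no dispersion)
WL = 1.55  # um


def mzi_o4_power(delta_length):
    """o4 power of an ideal MZI whose top arm is longer by delta_length (um)."""
    dphi = 2 * jnp.pi * NEFF * delta_length / WL
    return jnp.sin(dphi / 2) ** 2


mzi_o4_grad = jax.jit(jax.grad(mzi_o4_power))

dl = 10.0
learning_rate = 0.01  # small enough for plain gradient descent to be stable here
for _ in range(200):
    dl = dl - learning_rate * mzi_o4_grad(dl)

print(float(dl), float(mzi_o4_power(dl)))
```

Starting from ΔL = 10 µm, this converges to the nearby minimum ΔL = 15·λ/neff ≈ 9.936 µm, where the arm phase difference is a multiple of 2π. The MZI loss has many such minima, so the starting point determines which one the optimizer finds.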

Next, you need to define a JAX optimizer, which on its own is nothing more than three functions:

1. an initialization function with which to initialize the optimizer state;
2. an update function which updates the optimizer state (and with it the model parameters);
3. a function that returns the model parameters given the optimizer state.
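The three-function structure is easiest to see in a stripped-down optimizer. The sketch below returns an (init, update, get_params) triple for plain gradient descent, in the same shape as what `opt.adam` returns; `sgd_triple` is a hypothetical name written for illustration (the `jax.example_libraries.optimizers` module also provides its own `sgd`).

```python
def sgd_triple(step_size):
    """Hypothetical (init, update, get_params) triple for plain gradient descent."""

    def init(params):
        # The optimizer state is simply the current parameter value.
        return params

    def update(step, grad_value, state):
        # `step` is unused by plain gradient descent but kept for the same signature.
        return state - step_size * grad_value

    def get_params(state):
        return state

    return init, update, get_params


# Usage on f(x) = (x - 3)**2, whose gradient is 2 * (x - 3).
demo_init, demo_update, demo_params = sgd_triple(step_size=0.1)
demo_state = demo_init(10.0)
for i in range(100):
    x_current = demo_params(demo_state)
    demo_state = demo_update(i, 2 * (x_current - 3), demo_state)
print(demo_params(demo_state))  # close to 3.0
```

Adam keeps extra moment estimates in its state, which is why the state and the parameters are kept separate behind `get_params`.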

In [None]:
initial_delta_length = 10.0
optim_init, optim_update, optim_params = opt.adam(step_size=0.1)
optim_state = optim_init(initial_delta_length)

In [None]:
def train_step(step, optim_state):
    settings = optim_params(optim_state)
    lossvalue = loss(settings)
    gradvalue = grad(settings)
    optim_state = optim_update(step, gradvalue, optim_state)
    return lossvalue, optim_state

In [None]:
import tqdm

range_ = tqdm.trange(300)
for step in range_:
    lossvalue, optim_state = train_step(step, optim_state)
    range_.set_postfix(loss=f"{lossvalue:.6f}")

In [None]:
delta_length = optim_params(optim_state)
delta_length

In [None]:
S = mzi(wl=wl, top={"length": 15.0 + delta_length}, bot={"length": 15.0})
plt.plot(wl * 1e3, abs(S["o1", "o4"]) ** 2)
plt.xlabel("λ [nm]")
plt.ylabel("T")
plt.ylim(-0.05, 1.05)
plt.plot([1550, 1550], [0, 1])
plt.show()

The minimum of the MZI is perfectly located at 1550nm.

## Model fit

You can fit a SAX model to FDTD S-parameter simulation data.

In [None]:
import tqdm
import jax
import jax.numpy as jnp
import jax.example_libraries.optimizers as opt
import matplotlib.pyplot as plt

import gdsfactory as gf
import gdsfactory.simulation.modes as gm
import gdsfactory.simulation.sax as gs

In [None]:
gf.config.sparameters_path

In [None]:
sd = gs.read.sdict_from_csv(
    gf.config.sparameters_path / "coupler" / "coupler_G224n_L20_S220.csv",
    xkey="wavelength_nm",
    prefix="S",
    xunits=1e-3,
)

In [None]:
coupler_fdtd = gf.partial(
    gs.read.sdict_from_csv,
    filepath=gf.config.sparameters_path / "coupler" / "coupler_G224n_L20_S220.csv",
    xkey="wavelength_nm",
    prefix="S",
    xunits=1e-3,
)

In [None]:
gs.plot_model(coupler_fdtd)

In [None]:
gs.plot_model(coupler_fdtd, ports2=("o3", "o4"))

In [None]:
modes = gm.find_modes_coupler(gap=0.224)
modes

In [None]:
dn = modes[1].neff - modes[2].neff
dn

In [None]:
coupler = gf.partial(gf.simulation.sax.models.coupler, dn=dn, length=20, coupling0=0.3)
gs.plot_model(coupler)

In [None]:
coupler_fdtd = gs.read.sdict_from_csv(
    filepath=gf.config.sparameters_path / "coupler" / "coupler_G224n_L20_S220.csv",
    xkey="wavelength_nm",
    prefix="S",
    xunits=1e-3,
)
S = coupler_fdtd
T_fdtd = abs(S["o1", "o3"]) ** 2
K_fdtd = abs(S["o1", "o4"]) ** 2


@jax.jit
def loss(coupling0, dn, dn1, dn2, dk1, dk2):
    """Returns the sum of the mean absolute errors between the coupler model
    spectra (T and K) and the FDTD S-parameter spectra that we want to fit.

    Args:
        coupling0: coupling from the bend region.
        dn: effective index difference between even and odd mode solver simulations.
        dn1: first derivative of effective index difference vs wavelength.
        dn2: second derivative of effective index difference vs wavelength.
        dk1: first derivative of coupling0 vs wavelength.
        dk2: second derivative of coupling0 vs wavelength.

    .. code::

          coupling0/2        coupling        coupling0/2
        <-------------><--------------------><---------->
         o2 ________                           _______o3
                    \                         /
                     \        length         /
                      ======================= gap
                     /                       \
            ________/                         \________
         o1                                           o4

                      ------------------------> K (coupled power)
                     /
                    / K
           -----------------------------------> T = 1 - K (transmitted power)

    T: o1 -> o4
    K: o1 -> o3
    """
    S = gf.simulation.sax.models.coupler(
        dn=dn, length=20, coupling0=coupling0, dn1=dn1, dn2=dn2, dk1=dk1, dk2=dk2
    )
    T_model = abs(S["o1", "o4"]) ** 2
    K_model = abs(S["o1", "o3"]) ** 2
    return jnp.abs(T_fdtd - T_model).mean() + jnp.abs(K_fdtd - K_model).mean()


loss(coupling0=0.3, dn=0.016, dk1=1.2435, dk2=5.3022, dn1=0.1169, dn2=0.4821)

In [None]:
grad = jax.jit(
    jax.grad(
        loss,
        argnums=0,  # JAX gradient function for the first positional argument, jitted
    )
)

In [None]:
def train_step(step, optim_state, dn, dn1, dn2, dk1, dk2):
    settings = optim_params(optim_state)
    lossvalue = loss(settings, dn, dn1, dn2, dk1, dk2)
    gradvalue = grad(settings, dn, dn1, dn2, dk1, dk2)
    optim_state = optim_update(step, gradvalue, optim_state)
    return lossvalue, optim_state


coupling0 = 0.3
optim_init, optim_update, optim_params = opt.adam(step_size=0.1)
optim_state = optim_init(coupling0)

dn = 0.0166
dn1 = 0.11
dn2 = 0.48
dk1 = 1.2
dk2 = 5

range_ = tqdm.trange(300)
for step in range_:
    lossvalue, optim_state = train_step(step, optim_state, dn, dn1, dn2, dk1, dk2)
    range_.set_postfix(loss=f"{lossvalue:.6f}")

In [None]:
coupling0_fit = optim_params(optim_state)
coupling0_fit

In [None]:
coupler = gf.partial(
    gf.simulation.sax.models.coupler, dn=dn, length=20, coupling0=coupling0_fit
)
gs.plot_model(coupler)

In [None]:
wl = jnp.linspace(1.50, 1.60, 1000)
S = gf.simulation.sax.models.coupler(
    dn=dn, length=20, coupling0=coupling0_fit, dn1=dn1, dn2=dn2, dk1=dk1, dk2=dk2, wl=wl
)
T_model = abs(S["o1", "o4"]) ** 2
K_model = abs(S["o1", "o3"]) ** 2

In [None]:
coupler_fdtd = S = gs.read.sdict_from_csv(
    filepath=gf.config.sparameters_path / "coupler" / "coupler_G224n_L20_S220.csv",
    xkey="wavelength_nm",
    prefix="S",
    xunits=1e-3,
    wl=wl,
)
T_fdtd = abs(S["o1", "o3"]) ** 2
K_fdtd = abs(S["o1", "o4"]) ** 2

In [None]:
plt.plot(wl, T_fdtd, label="T fdtd", c="b")
plt.plot(wl, T_model, label="T fit", c="b", ls="-.")
plt.plot(wl, K_fdtd, label="K fdtd", c="r")
plt.plot(wl, K_model, label="K fit", c="r", ls="-.")
plt.legend()

### Multi-variable optimization

As you can see, fitting only `coupling0` is not enough to get a good fit; we need to fit more than one variable.

In [None]:
grad = jax.jit(
    jax.grad(
        loss,
        argnums=(0, 1, 2, 3, 4, 5),  # gradient w.r.t. all six positional arguments, jitted
    )
)
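Fitting several variables at once relies on calling `jax.grad` with a sequence of `argnums`, in which case it returns a tuple with one gradient per selected argument. A tiny standalone example with a made-up function `toy_fn`:

```python
import jax
import jax.numpy as jnp


def toy_fn(a, b, c):
    return a ** 2 + 3.0 * b + jnp.sin(c)


grad_toy = jax.grad(toy_fn, argnums=(0, 1, 2))
da, db, dc = grad_toy(2.0, 1.0, 0.0)
print(da, db, dc)  # values: 4.0, 3.0, 1.0
```

Because the gradient comes back as a tuple, it has the same structure as the tuple of parameters passed to `optim_init`, which lets the optimizer update all six values together.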

In [None]:
def train_step(step, optim_state):
    coupling0, dn, dn1, dn2, dk1, dk2 = optim_params(optim_state)
    lossvalue = loss(coupling0, dn, dn1, dn2, dk1, dk2)
    gradvalue = grad(coupling0, dn, dn1, dn2, dk1, dk2)
    optim_state = optim_update(step, gradvalue, optim_state)
    return lossvalue, optim_state

In [None]:
coupling0 = 0.3
dn = 0.0166
dn1 = 0.11
dn2 = 0.48
dk1 = 1.2
dk2 = 5.0
optim_init, optim_update, optim_params = opt.adam(step_size=0.01)
optim_state = optim_init((coupling0, dn, dn1, dn2, dk1, dk2))

In [None]:
range_ = tqdm.trange(1000)
for step in range_:
    lossvalue, optim_state = train_step(step, optim_state)
    range_.set_postfix(loss=f"{lossvalue:.6f}")

In [None]:
coupling0_fit, dn_fit, dn1_fit, dn2_fit, dk1_fit, dk2_fit = optim_params(optim_state)
coupling0_fit, dn_fit, dn1_fit, dn2_fit, dk1_fit, dk2_fit

In [None]:
wl = jnp.linspace(1.5, 1.60, 1000)
coupler_fdtd = gs.read.sdict_from_csv(
    filepath=gf.config.sparameters_path / "coupler" / "coupler_G224n_L20_S220.csv",
    wl=wl,
    xkey="wavelength_nm",
    prefix="S",
    xunits=1e-3,
)
S = coupler_fdtd
T_fdtd = abs(S["o1", "o3"]) ** 2
K_fdtd = abs(S["o1", "o4"]) ** 2
S = gf.simulation.sax.models.coupler(
    dn=dn_fit,
    length=20,
    coupling0=coupling0_fit,
    dn1=dn1_fit,
    dn2=dn2_fit,
    dk1=dk1_fit,
    dk2=dk2_fit,
    wl=wl,
)
T_model = abs(S["o1", "o4"]) ** 2
K_model = abs(S["o1", "o3"]) ** 2

plt.plot(wl, T_fdtd, label="T fdtd", c="b")
plt.plot(wl, T_model, label="T fit", c="b", ls="-.")
plt.plot(wl, K_fdtd, label="K fdtd", c="r")
plt.plot(wl, K_model, label="K fit", c="r", ls="-.")
plt.legend()

As you can see, fitting many parameters at once does not necessarily give you a better fit; you have to make sure you fit the right parameters, in this case `dn1`:

In [None]:
wl = jnp.linspace(1.50, 1.60, 1000)
S = gf.simulation.sax.models.coupler(
    dn=dn_fit,
    length=20,
    coupling0=coupling0_fit,
    dn1=dn1_fit - 0.045,
    dn2=dn2_fit,
    dk1=dk1_fit,
    dk2=dk2_fit,
    wl=wl,
)
T_model = abs(S["o1", "o4"]) ** 2
K_model = abs(S["o1", "o3"]) ** 2

plt.plot(wl, T_fdtd, label="T fdtd", c="b")
plt.plot(wl, T_model, label="T fit", c="b", ls="-.")
plt.plot(wl, K_fdtd, label="K fdtd", c="r")
plt.plot(wl, K_model, label="K fit", c="r", ls="-.")
plt.legend()

In [None]:
coupling0 = coupling0_fit
dn = dn_fit
dn2 = dn2_fit
dk1 = dk1_fit
dk2 = dk2_fit


@jax.jit
def loss(dn1):
    """Returns the sum of the mean absolute errors between the coupler model
    spectra (T and K) and the FDTD S-parameter spectra that we want to fit.
    """
    S = gf.simulation.sax.models.coupler(
        dn=dn, length=20, coupling0=coupling0, dn1=dn1, dn2=dn2, dk1=dk1, dk2=dk2
    )
    T_model = jnp.abs(S["o1", "o4"]) ** 2
    K_model = jnp.abs(S["o1", "o3"]) ** 2
    return jnp.abs(T_fdtd - T_model).mean() + jnp.abs(K_fdtd - K_model).mean()


grad = jax.jit(
    jax.grad(
        loss,
        argnums=0,  # JAX gradient function for the first positional argument, jitted
    )
)

dn1 = 0.11
optim_init, optim_update, optim_params = opt.adam(step_size=0.001)
optim_state = optim_init(dn1)


def train_step(step, optim_state):
    settings = optim_params(optim_state)
    lossvalue = loss(settings)
    gradvalue = grad(settings)
    optim_state = optim_update(step, gradvalue, optim_state)
    return lossvalue, optim_state


range_ = tqdm.trange(300)
for step in range_:
    lossvalue, optim_state = train_step(step, optim_state)
    range_.set_postfix(loss=f"{lossvalue:.6f}")

In [None]:
dn1_fit = optim_params(optim_state)
dn1_fit

In [None]:
wl = jnp.linspace(1.50, 1.60, 1000)
S = gf.simulation.sax.models.coupler(
    dn=dn, length=20, coupling0=coupling0, dn1=dn1_fit, dn2=dn2, dk1=dk1, dk2=dk2, wl=wl
)
T_model = abs(S["o1", "o4"]) ** 2
K_model = abs(S["o1", "o3"]) ** 2

coupler_fdtd = gs.read.sdict_from_csv(
    filepath=gf.config.sparameters_path / "coupler" / "coupler_G224n_L20_S220.csv",
    xkey="wavelength_nm",
    prefix="S",
    xunits=1e-3,
    wl=wl,
)
S = coupler_fdtd
T_fdtd = abs(S["o1", "o3"]) ** 2
K_fdtd = abs(S["o1", "o4"]) ** 2

plt.plot(wl, T_fdtd, label="T fdtd", c="b")
plt.plot(wl, T_model, label="T fit", c="b", ls="-.")
plt.plot(wl, K_fdtd, label="K fdtd", c="r")
plt.plot(wl, K_model, label="K fit", c="r", ls="-.")
plt.legend()

## Model fit (linear regression)

For a better fit of the coupler, we can build a linear regression model (a polynomial fit in wavelength) with `sklearn`:
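This is ordinary polynomial least squares: build feature columns x**0, x**1, ..., x**(N-1) and solve for the coefficients. The NumPy-only sketch below shows the idea on synthetic data generated from known coefficients; `poly_features` is a hypothetical stand-in for the `fX` helper used below, not a replacement for the `sklearn` workflow.

```python
import numpy as np


def poly_features(x, order=4):
    """Columns x**0, x**1, ..., x**(order - 1), like the `fX` helper used below."""
    return x[:, None] ** np.arange(order)[None, :]


# Synthetic data generated from a known cubic in wavelength (um).
wl_demo = np.linspace(1.5, 1.6, 50)
true_coef = np.array([0.2, -0.5, 0.3, 0.1])
y_demo = poly_features(wl_demo) @ true_coef

# Ordinary least squares for the polynomial coefficients.
coef, *_ = np.linalg.lstsq(poly_features(wl_demo), y_demo, rcond=None)
print(np.allclose(poly_features(wl_demo) @ coef, y_demo))  # True
```

Over a narrow wavelength span the monomial columns are nearly collinear, so individual coefficients can be poorly determined even when the fitted curve is accurate; centering and scaling the wavelength before building the features is a common way to improve conditioning.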

In [None]:
import sax
import gdsfactory as gf
import gdsfactory.simulation.sax as gs
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from scipy.constants import c
from sklearn.linear_model import LinearRegression

In [None]:
f = jnp.linspace(c / 1.0e-6, c / 2.0e-6, 500) * 1e-12  # THz
wl = c / (f * 1e12) * 1e6  # um

filepath = gf.config.sparameters_path / "coupler" / "coupler_G224n_L20_S220.csv"
coupler_fdtd = gf.partial(
    gs.read.sdict_from_csv, filepath, xkey="wavelength_nm", prefix="S", xunits=1e-3
)
sd = coupler_fdtd(wl=wl)

k = sd["o1", "o3"]
t = sd["o1", "o4"]
s = t + k
a = t - k

Let's fit the symmetric (t + k) and antisymmetric (t - k) transmission.
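Fitting s = t + k and a = t - k instead of t and k loses no information, since t = (s + a)/2 and k = (s - a)/2, which is how the cells under Total recombine the fits. For a symmetric coupler, t ± k are also the eigenvalues of the 2×2 transfer matrix [[t, k], [k, t]], i.e. the transmissions of the even and odd supermodes. A quick NumPy check with random complex values (the `_demo` names avoid clobbering the notebook's `t`, `k`, `s`, `a`):

```python
import numpy as np

rng = np.random.default_rng(0)
t_demo = rng.normal(size=5) + 1j * rng.normal(size=5)
k_demo = rng.normal(size=5) + 1j * rng.normal(size=5)

s_demo = t_demo + k_demo  # symmetric combination
a_demo = t_demo - k_demo  # antisymmetric combination

# Recombine exactly as in the "Total" cells.
print(np.allclose(0.5 * (s_demo + a_demo), t_demo))  # True
print(np.allclose(0.5 * (s_demo - a_demo), k_demo))  # True

# t + k and t - k are the eigenvalues of the symmetric 2x2 matrix [[t, k], [k, t]].
M = np.array([[t_demo[0], k_demo[0]], [k_demo[0], t_demo[0]]])
eigenvalues = np.linalg.eigvals(M)
```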

### Symmetric

In [None]:
plt.plot(wl, jnp.abs(s))
plt.grid(True)
plt.xlabel("Frequency [THz]")
plt.ylabel("Transmission")
plt.title("symmetric (transmission + coupling)")
plt.legend()
plt.show()

In [None]:
plt.plot(wl, jnp.abs(a))
plt.grid(True)
plt.xlabel("Frequency [THz]")
plt.ylabel("Transmission")
plt.title("anti-symmetric (transmission - coupling)")
plt.legend()
plt.show()

In [None]:
r = LinearRegression()
# Artificially create more 'features' (wl**0, wl**1, ..., wl**7) for a polynomial fit.
fX = lambda x, _order=8: x[:, None] ** (jnp.arange(_order)[None, :])
X = fX(wl)
r.fit(X, jnp.abs(s))
asm, bsm = r.coef_, r.intercept_
fsm = lambda x: fX(x) @ asm + bsm  # fit of the symmetric magnitude

plt.plot(wl, jnp.abs(s), label="fdtd")
plt.plot(wl, fsm(wl), label="fit")
plt.grid(True)
plt.xlabel("Wavelength [µm]")
plt.ylabel("Transmission")
plt.legend()
plt.show()

In [None]:
r = LinearRegression()
r.fit(X, jnp.unwrap(jnp.angle(s)))
asp, bsp = r.coef_, r.intercept_
fsp = lambda x: fX(x) @ asp + bsp  # fit of the symmetric phase

plt.plot(wl, jnp.unwrap(jnp.angle(s)), label="fdtd")
plt.plot(wl, fsp(wl), label="fit")
plt.grid(True)
plt.xlabel("Wavelength [µm]")
plt.ylabel("Phase [rad]")
plt.legend()
plt.show()

In [None]:
fs = lambda x: fsm(x) * jnp.exp(1j * fsp(x))

### Anti-Symmetric

In [None]:
r = LinearRegression()
r.fit(X, jnp.abs(a))
aam, bam = r.coef_, r.intercept_
fam = lambda x: fX(x) @ aam + bam  # fit of the antisymmetric magnitude

plt.plot(wl, jnp.abs(a), label="fdtd")
plt.plot(wl, fam(wl), label="fit")
plt.grid(True)
plt.xlabel("Wavelength [µm]")
plt.ylabel("Transmission")
plt.legend()
plt.show()

In [None]:
r = LinearRegression()
r.fit(X, jnp.unwrap(jnp.angle(a)))
aap, bap = r.coef_, r.intercept_
fap = lambda x: fX(x) @ aap + bap  # fit of the antisymmetric phase

plt.plot(wl, jnp.unwrap(jnp.angle(a)), label="fdtd")
plt.plot(wl, fap(wl), label="fit")
plt.grid(True)
plt.xlabel("Wavelength [µm]")
plt.ylabel("Phase [rad]")
plt.legend()
plt.show()

In [None]:
fa = lambda x: fam(x) * jnp.exp(1j * fap(x))

### Total

In [None]:
t_ = 0.5 * (fs(wl) + fa(wl))

plt.plot(wl, jnp.abs(t), label="fdtd")
plt.plot(wl, jnp.abs(t_), label="model")
plt.xlabel("Wavelength [µm]")
plt.ylabel("Transmission")
plt.legend()

In [None]:
k_ = 0.5 * (fs(wl) - fa(wl))

plt.plot(wl, jnp.abs(k), label="fdtd")
plt.plot(wl, jnp.abs(k_), label="model")
plt.xlabel("Wavelength [µm]")
plt.ylabel("Coupling")
plt.legend()

In [None]:
@jax.jit
def coupler(wl=1.5):
    wl = jnp.asarray(wl)
    wl_shape = wl.shape
    wl = wl.ravel()
    t = (0.5 * (fs(wl) + fa(wl))).reshape(*wl_shape)
    k = (0.5 * (fs(wl) - fa(wl))).reshape(*wl_shape)
    sdict = {
        ("o1", "o4"): t,
        ("o1", "o3"): k,
        ("o2", "o3"): k,
        ("o2", "o4"): t,
    }
    return sax.reciprocal(sdict)

In [None]:
f = jnp.linspace(c / 1.0e-6, c / 2.0e-6, 500) * 1e-12  # THz
wl = c / (f * 1e12) * 1e6  # um

filepath = gf.config.sparameters_path / "coupler" / "coupler_G224n_L20_S220.csv"
coupler_fdtd = gf.partial(
    gs.read.sdict_from_csv, filepath, xkey="wavelength_nm", prefix="S", xunits=1e-3
)
sd = coupler_fdtd(wl=wl)
sd_ = coupler(wl=wl)

T = jnp.abs(sd["o1", "o4"]) ** 2
K = jnp.abs(sd["o1", "o3"]) ** 2
T_ = jnp.abs(sd_["o1", "o4"]) ** 2
K_ = jnp.abs(sd_["o1", "o3"]) ** 2
dP = jnp.unwrap(jnp.angle(sd["o1", "o3"]) - jnp.angle(sd["o1", "o4"]))
dP_ = jnp.unwrap(jnp.angle(sd_["o1", "o3"]) - jnp.angle(sd_["o1", "o4"]))

plt.figure(figsize=(12, 3))
plt.plot(wl, T, label="T (fdtd)", c="C0", ls=":", lw=6)
plt.plot(wl, T_, label="T (model)", c="C0")

plt.plot(wl, K, label="K (fdtd)", c="C1", ls=":", lw=6)
plt.plot(wl, K_, label="K (model)", c="C1")

plt.ylim(-0.05, 1.05)
plt.grid(True)
# Label the primary axes before twinx(), which hides the twin's x-axis.
plt.xlabel("Wavelength [µm]")
plt.ylabel("Transmission")

plt.twinx()
plt.plot(wl, dP, label="ΔΦ (fdtd)", color="C2", ls=":", lw=6)
plt.plot(wl, dP_, label="ΔΦ (model)", color="C2")
plt.ylabel("ΔΦ [rad]")

plt.figlegend(bbox_to_anchor=(1.08, 0.9))
plt.savefig("fdtd_vs_model.png", bbox_inches="tight")
plt.show()