# Multimode simulations
> SAX can handle multiple modes too!

In [1]:
from itertools import combinations_with_replacement, product

import jax.numpy as jnp
import sax

## Ports and modes per port

Let's denote a combination of a port and a mode by a string of the following format: `"{port}@{mode}"`. We can obtain all possible (unordered) pairs of port-mode combinations with `itertools.product` and `itertools.combinations_with_replacement`:

In [2]:
ports = ["in0", "out0"]
modes = ["te", "tm"]
# Label every (port, mode) combination as "port@mode", then take all unordered
# pairs of those labels (a label may also be paired with itself).
portmodes = list(
    combinations_with_replacement(
        [f"{port}@{mode}" for port, mode in product(ports, modes)], 2
    )
)
portmodes

[('in0@te', 'in0@te'),
 ('in0@te', 'in0@tm'),
 ('in0@te', 'out0@te'),
 ('in0@te', 'out0@tm'),
 ('in0@tm', 'in0@tm'),
 ('in0@tm', 'out0@te'),
 ('in0@tm', 'out0@tm'),
 ('out0@te', 'out0@te'),
 ('out0@te', 'out0@tm'),
 ('out0@tm', 'out0@tm')]

If we disregard backreflection (i.e. we drop every pair whose two entries belong to the same port), this list can be simplified further:

In [3]:
# Keep only pairs whose two entries live on *different* ports; every pair that
# stays on a single port (any mode combination) is dropped.
portmodes_without_backreflection = [
    (left, right)
    for left, right in portmodes
    if left.partition("@")[0] != right.partition("@")[0]
]
portmodes_without_backreflection

[('in0@te', 'out0@te'),
 ('in0@te', 'out0@tm'),
 ('in0@tm', 'out0@te'),
 ('in0@tm', 'out0@tm')]

Sometimes cross-polarization terms can also be ignored (i.e. we keep only pairs in which both entries carry the same mode):

In [4]:
# Keep only pairs in which both entries carry the same mode (the text after "@").
portmodes_without_crosspolarization = [
    (left, right)
    for left, right in portmodes
    if left.split("@")[1] == right.split("@")[1]
]
portmodes_without_crosspolarization

[('in0@te', 'in0@te'),
 ('in0@te', 'out0@te'),
 ('in0@tm', 'in0@tm'),
 ('in0@tm', 'out0@tm'),
 ('out0@te', 'out0@te'),
 ('out0@tm', 'out0@tm')]

## Multimode waveguide

Let's create a waveguide with two ports (`"in0"`, `"out0"`) and two modes (`"te"`, `"tm"`) without backreflection. Let's assume there is 5% cross-polarization and that the `"tm"`->`"tm"` transmission is 10% worse than the `"te"`->`"te"` transmission. In this model the percentages scale the field amplitude, not the power. Naturally, in more realistic waveguide models these percentages will be length-dependent, but this is just a dummy model serving as an example.

In [5]:
def waveguide(
    wl=1.55,
    wl0=1.55,
    neff=2.34,
    ng=3.4,
    length=10.0,
    loss=0.0,
    cross_pol=0.05,
    tm_excess_loss=0.10,
):
    """A simple two-mode (te/tm) straight waveguide model without backreflection.

    Args:
        wl: wavelength at which the model is evaluated.
        wl0: center wavelength at which ``neff`` is defined.
        neff: waveguide effective index at ``wl0``.
        ng: waveguide group index (used for linear neff dispersion).
        length: waveguide length. The phase is ``2*pi*neff*length/wl``, so
            ``length`` must be given in the same unit as ``wl`` (presumably µm
            with the default ``wl=1.55``).
        loss: waveguide loss in dB per unit of ``length``.
        cross_pol: field-amplitude fraction coupled into the other mode
            (te->tm and tm->te). The default of 0.05 gives 5% cross-polarization.
        tm_excess_loss: extra field-amplitude penalty of the tm->tm
            transmission relative to te->te. The default is 0.10.

    Returns:
        A reciprocal SDict keyed by ``("port@mode", "port@mode")`` tuples.
        Only in0 <-> out0 entries are present, so there is no backreflection.
    """
    dwl = wl - wl0
    # First-order (linear) dispersion of neff around wl0, derived from ng.
    dneff_dwl = (ng - neff) / wl0
    neff = neff - dwl * dneff_dwl
    phase = 2 * jnp.pi * neff * length / wl
    # loss * length is the total loss in dB; dividing by 20 (not 10) converts it
    # to a field-amplitude factor.
    transmission = 10 ** (-loss * length / 20) * jnp.exp(1j * phase)
    te_te = 1.0 - cross_pol  # amplitude left in te after cross-polarization
    tm_tm = te_te - tm_excess_loss  # tm->tm is worse than te->te by tm_excess_loss
    sdict = sax.reciprocal(
        {
            ("in0@te", "out0@te"): te_te * transmission,
            ("in0@te", "out0@tm"): cross_pol * transmission,
            ("in0@tm", "out0@tm"): tm_tm * transmission,
            ("in0@tm", "out0@te"): cross_pol * transmission,
        }
    )
    return sdict


waveguide()

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


{('in0@te',
  'out0@te'): Array(0.77972656+0.54270285j, dtype=complex64, weak_type=True),
 ('in0@te',
  'out0@tm'): Array(0.04103824+0.02856331j, dtype=complex64, weak_type=True),
 ('in0@tm',
  'out0@tm'): Array(0.69765013+0.48557627j, dtype=complex64, weak_type=True),
 ('in0@tm',
  'out0@te'): Array(0.04103824+0.02856331j, dtype=complex64, weak_type=True),
 ('out0@te',
  'in0@te'): Array(0.77972656+0.54270285j, dtype=complex64, weak_type=True),
 ('out0@tm',
  'in0@te'): Array(0.04103824+0.02856331j, dtype=complex64, weak_type=True),
 ('out0@tm',
  'in0@tm'): Array(0.69765013+0.48557627j, dtype=complex64, weak_type=True),
 ('out0@te',
  'in0@tm'): Array(0.04103824+0.02856331j, dtype=complex64, weak_type=True)}

## Multimode MZI

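We can now combine models into a circuit in much the same way as before; we just need to add the `modes=` keyword. Note that the circuit below uses the built-in `coupler` and `straight` models from `sax.get_models()` rather than the `waveguide` model defined above: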

In [6]:
# Single-mode models (e.g. "coupler") will be automatically converted to
# multimode models without cross polarization when `modes=` is given.
mzi, _ = sax.circuit(
    netlist=dict(
        instances=dict(
            lft="coupler",
            top=dict(component="straight", settings=dict(length=25.0)),
            btm=dict(component="straight", settings=dict(length=15.0)),
            rgt="coupler",
        ),
        # Connection keys contain commas, so a literal dict is required here.
        connections={
            "lft,out0": "btm,in0",
            "btm,out0": "rgt,in0",
            "lft,out1": "top,in0",
            "top,out0": "rgt,in1",
        },
        ports=dict(
            in0="lft,in0",
            in1="lft,in1",
            out0="rgt,out0",
            out1="rgt,out1",
        ),
    ),
    models=sax.get_models(),
    modes=("te", "tm"),
)

In [7]:
# Evaluate the multimode MZI; the resulting keys are ("port@mode", "port@mode") pairs.
mzi()

{('in0@te', 'in0@te'): Array(0.+0.j, dtype=complex64),
 ('in0@te', 'in1@te'): Array(0.+0.j, dtype=complex64),
 ('in0@te', 'in1@tm'): Array(0.+0.j, dtype=complex64),
 ('in0@te', 'in0@tm'): Array(0.+0.j, dtype=complex64),
 ('in1@te', 'in0@te'): Array(0.+0.j, dtype=complex64),
 ('in1@te', 'in1@te'): Array(0.+0.j, dtype=complex64),
 ('in1@te', 'in1@tm'): Array(0.+0.j, dtype=complex64),
 ('in1@te', 'in0@tm'): Array(0.+0.j, dtype=complex64),
 ('in1@tm', 'in0@te'): Array(0.+0.j, dtype=complex64),
 ('in1@tm', 'in1@te'): Array(0.+0.j, dtype=complex64),
 ('in1@tm', 'in1@tm'): Array(0.+0.j, dtype=complex64),
 ('in1@tm', 'in0@tm'): Array(0.+0.j, dtype=complex64),
 ('in0@tm', 'in0@te'): Array(0.+0.j, dtype=complex64),
 ('in0@tm', 'in1@te'): Array(0.+0.j, dtype=complex64),
 ('in0@tm', 'in1@tm'): Array(0.+0.j, dtype=complex64),
 ('in0@tm', 'in0@tm'): Array(0.+0.j, dtype=complex64),
 ('in0@te', 'out1@tm'): Array(0.+0.j, dtype=complex64),
 ('in0@te', 'out0@te'): Array(-0.28073177+0.10396835j, dtype=com

We can convert this model back to a single-mode `SDict` as follows:

In [8]:
# Keep only the "te" mode; the resulting keys are plain port names without "@te".
mzi_te = sax.singlemode(mzi, mode="te")
mzi_te()

{('in0', 'in0'): Array(0.+0.j, dtype=complex64),
 ('in0', 'in1'): Array(0.+0.j, dtype=complex64),
 ('in1', 'in0'): Array(0.+0.j, dtype=complex64),
 ('in1', 'in1'): Array(0.+0.j, dtype=complex64),
 ('in0', 'out0'): Array(-0.28073177+0.10396835j, dtype=complex64),
 ('in0', 'out1'): Array(0.8947488-0.33136806j, dtype=complex64),
 ('in1', 'out0'): Array(0.8947488-0.33136806j, dtype=complex64),
 ('in1', 'out1'): Array(0.28073177-0.10396835j, dtype=complex64),
 ('out0', 'in0'): Array(-0.28073177+0.10396835j, dtype=complex64),
 ('out0', 'in1'): Array(0.8947488-0.33136806j, dtype=complex64),
 ('out0', 'out0'): Array(0.+0.j, dtype=complex64),
 ('out0', 'out1'): Array(0.+0.j, dtype=complex64),
 ('out1', 'in0'): Array(0.8947488-0.33136806j, dtype=complex64),
 ('out1', 'in1'): Array(0.28073177-0.10396835j, dtype=complex64),
 ('out1', 'out0'): Array(0.+0.j, dtype=complex64),
 ('out1', 'out1'): Array(0.+0.j, dtype=complex64)}