# Multimode simulations


In [2]:
from itertools import combinations_with_replacement, product

import jax
import jax.numpy as jnp
import sax


## Ports and modes per port

Let’s denote a combination of a port and a mode by a string of the following format: `{port}@{mode}`. We can enumerate all unordered pairs of port-mode combinations (including a port-mode paired with itself) using `itertools.product` and `itertools.combinations_with_replacement`:

In [3]:
ports = ["in0", "out0"]
modes = ["te", "tm"]

# Every unordered pair of "port@mode" labels, self-pairs included (those are
# reflections into the same port-mode). The itertools helpers are imported in
# the top import cell; `combinations` was imported here before but never used.
portmodes = [
    (f"{p1}@{m1}", f"{p2}@{m2}")
    for (p1, m1), (p2, m2) in combinations_with_replacement(product(ports, modes), 2)
]
portmodes



[('in0@te', 'in0@te'),
 ('in0@te', 'in0@tm'),
 ('in0@te', 'out0@te'),
 ('in0@te', 'out0@tm'),
 ('in0@tm', 'in0@tm'),
 ('in0@tm', 'out0@te'),
 ('in0@tm', 'out0@tm'),
 ('out0@te', 'out0@te'),
 ('out0@te', 'out0@tm'),
 ('out0@tm', 'out0@tm')]

In [4]:
# Discard every pair whose two labels name the same port (the text before "@"):
# such entries describe reflection back into that port.
portmodes_without_backreflection = [
    (src, dst)
    for src, dst in portmodes
    if not src.split("@")[0] == dst.split("@")[0]
]
portmodes_without_backreflection



[('in0@te', 'out0@te'),
 ('in0@te', 'out0@tm'),
 ('in0@tm', 'out0@te'),
 ('in0@tm', 'out0@tm')]

Sometimes cross-polarization terms can also be ignored:

In [5]:
# Retain only pairs whose two labels carry the same mode (the text after "@"),
# i.e. drop the te<->tm cross-polarization entries.
portmodes_without_crosspolarization = [
    (src, dst)
    for src, dst in portmodes
    if src.split("@")[1] == dst.split("@")[1]
]
portmodes_without_crosspolarization

[('in0@te', 'in0@te'),
 ('in0@te', 'out0@te'),
 ('in0@tm', 'in0@tm'),
 ('in0@tm', 'out0@tm'),
 ('out0@te', 'out0@te'),
 ('out0@tm', 'out0@tm')]

## Multimode waveguide

Let’s create a waveguide with two ports (`"in0"`, `"out0"`) and two modes (`"te"`, `"tm"`) without backreflection. Let’s assume there is 5% cross-polarization and that the `"tm"`→`"tm"` transmission is 10% worse than the `"te"`→`"te"` transmission (both applied as coefficients on the field amplitude). Naturally, in more realistic waveguide models these percentages will be length-dependent, but this is just a dummy model serving as an example.

In [6]:
def waveguide(
    wl=1.55,
    wl0=1.55,
    neff=2.34,
    ng=3.4,
    length=10.0,
    loss=0.0,
    te_amp=0.95,
    tm_amp=0.85,
    xpol_amp=0.05,
):
    """A dummy two-port, two-mode ("te"/"tm") straight waveguide model.

    Only in0 <-> out0 transmission entries are defined (no backreflection);
    ``sax.reciprocal`` adds the reversed-direction keys.

    Args:
        wl: wavelength (same length unit as ``wl0`` and ``length``)
        wl0: center wavelength at which ``neff`` is defined
        neff: waveguide effective index at ``wl0``
        ng: waveguide group index (used for linear neff dispersion)
        length: waveguide length, in the same unit as ``wl``, because the
            phase is computed as ``2*pi*neff*length/wl``
        loss: waveguide loss in dB per unit of ``length``
        te_amp: te->te coefficient applied to the field amplitude
            (default 0.95: 5% assumed lost to cross-polarization)
        tm_amp: tm->tm coefficient applied to the field amplitude
            (default 0.85: 0.10 below the te->te default)
        xpol_amp: te<->tm cross-polarization coefficient on the amplitude

    Returns:
        SDict keyed by ("port@mode", "port@mode") tuples.
    """
    dwl = wl - wl0
    dneff_dwl = (ng - neff) / wl0
    # Linear dispersion: effective index evaluated at the requested wavelength.
    neff_wl = neff - dwl * dneff_dwl
    phase = 2 * jnp.pi * neff_wl * length / wl
    # Power loss in dB -> amplitude factor, hence the division by 20.
    transmission = 10 ** (-loss * length / 20) * jnp.exp(1j * phase)
    sdict = sax.reciprocal(
        {
            ("in0@te", "out0@te"): te_amp * transmission,
            ("in0@te", "out0@tm"): xpol_amp * transmission,
            ("in0@tm", "out0@tm"): tm_amp * transmission,
            ("in0@tm", "out0@te"): xpol_amp * transmission,
        }
    )
    return sdict


waveguide()



{('in0@te',
  'out0@te'): DeviceArray(0.77972656+0.54270285j, dtype=complex64, weak_type=True),
 ('in0@te',
  'out0@tm'): DeviceArray(0.04103824+0.02856331j, dtype=complex64, weak_type=True),
 ('in0@tm',
  'out0@tm'): DeviceArray(0.69765013+0.48557627j, dtype=complex64, weak_type=True),
 ('in0@tm',
  'out0@te'): DeviceArray(0.04103824+0.02856331j, dtype=complex64, weak_type=True),
 ('out0@te',
  'in0@te'): DeviceArray(0.77972656+0.54270285j, dtype=complex64, weak_type=True),
 ('out0@tm',
  'in0@te'): DeviceArray(0.04103824+0.02856331j, dtype=complex64, weak_type=True),
 ('out0@tm',
  'in0@tm'): DeviceArray(0.69765013+0.48557627j, dtype=complex64, weak_type=True),
 ('out0@te',
  'in0@tm'): DeviceArray(0.04103824+0.02856331j, dtype=complex64, weak_type=True)}

## Multimode MZI

We can now combine these models into a circuit in much the same way as before. We just need to add the `modes=` keyword argument:

In [7]:
# The couplers are single-mode models; they will be automatically converted to
# multimode models without cross polarization because `modes` is passed below.
mzi = sax.circuit(
    instances=dict(
        lft=sax.models.coupler,
        top=sax.partial(waveguide, length=25.0),
        btm=sax.partial(waveguide, length=15.0),
        rgt=sax.models.coupler,
    ),
    connections={
        "lft:out0": "btm:in0",
        "btm:out0": "rgt:in0",
        "lft:out1": "top:in0",
        "top:out0": "rgt:in1",
    },
    ports=dict(
        in0="lft:in0",
        in1="lft:in1",
        out0="rgt:out0",
        out1="rgt:out1",
    ),
    modes=("te", "tm"),
)

In [8]:
# Evaluate the multimode MZI with no arguments; the resulting SDict is keyed by
# ("port@mode", "port@mode") tuples, as the output below shows.
mzi()

{('in0@te', 'in1@te'): DeviceArray(0.+0.j, dtype=complex64, weak_type=True),
 ('in1@te', 'in0@te'): DeviceArray(0.+0.j, dtype=complex64, weak_type=True),
 ('in0@tm', 'in1@tm'): DeviceArray(0.+0.j, dtype=complex64, weak_type=True),
 ('in1@tm', 'in0@tm'): DeviceArray(0.+0.j, dtype=complex64, weak_type=True),
 ('out0@te', 'out1@te'): DeviceArray(0.+0.j, dtype=complex64, weak_type=True),
 ('out1@te', 'out0@te'): DeviceArray(0.+0.j, dtype=complex64, weak_type=True),
 ('out0@tm', 'out1@tm'): DeviceArray(0.+0.j, dtype=complex64, weak_type=True),
 ('out1@tm', 'out0@tm'): DeviceArray(0.+0.j, dtype=complex64, weak_type=True),
 ('in1@te', 'in1@te'): DeviceArray(0.+0.j, dtype=complex64, weak_type=True),
 ('in1@te', 'in1@tm'): DeviceArray(0.+0.j, dtype=complex64, weak_type=True),
 ('in1@te', 'in0@tm'): DeviceArray(0.+0.j, dtype=complex64, weak_type=True),
 ('in0@te', 'in0@te'): DeviceArray(0.+0.j, dtype=complex64, weak_type=True),
 ('in0@te', 'in1@tm'): DeviceArray(0.+0.j, dtype=complex64, weak_typ

We can convert this multimode model back to a single-mode model (whose SDict is keyed by plain port names) as follows:

In [9]:
# Single-mode view of the MZI for the "te" mode; calling it yields an SDict
# keyed by plain port names (no "@mode" suffix), as shown in the output below.
mzi_te = sax.singlemode(mzi, mode="te")
mzi_te()



{('in0', 'in1'): DeviceArray(0.+0.j, dtype=complex64, weak_type=True),
 ('in1', 'in0'): DeviceArray(0.+0.j, dtype=complex64, weak_type=True),
 ('out0', 'out1'): DeviceArray(0.+0.j, dtype=complex64, weak_type=True),
 ('out1', 'out0'): DeviceArray(0.+0.j, dtype=complex64, weak_type=True),
 ('in1', 'in1'): DeviceArray(0.+0.j, dtype=complex64, weak_type=True),
 ('in0', 'in0'): DeviceArray(0.+0.j, dtype=complex64, weak_type=True),
 ('out0', 'out0'): DeviceArray(0.+0.j, dtype=complex64, weak_type=True),
 ('out0',
  'in1'): DeviceArray(0.8500066-0.31481263j, dtype=complex64, weak_type=True),
 ('out0',
  'in0'): DeviceArray(-0.26669368+0.09877402j, dtype=complex64, weak_type=True),
 ('out1', 'out1'): DeviceArray(0.+0.j, dtype=complex64, weak_type=True),
 ('out1',
  'in1'): DeviceArray(0.26669368-0.09877402j, dtype=complex64, weak_type=True),
 ('out1',
  'in0'): DeviceArray(0.8500066-0.31481263j, dtype=complex64, weak_type=True),
 ('in1',
  'out0'): DeviceArray(0.8500066-0.31481263j, dtype=comp

Or, if we don’t supply a mode to select, the mean of each S-parameter over all modes is used.

In [None]:
# Single-mode view of the MZI without selecting a mode explicitly.
# NOTE(review): the text above says the mean over all modes is used when `mode`
# is omitted; confirm against the installed `sax.singlemode`, whose default may
# instead be mode="te". This cell was also never executed (In [None]), so
# Restart & Run All has not verified it.
mzi_mean = sax.singlemode(mzi)
mzi_mean()

