In [24]:
import meow as mw
import jax.numpy as np
import matplotlib.pyplot as plt

## Preparation
Let's set up a structure and perform a mode simulation.

In [25]:
# Total extent of the structure along z; reused below to size the cells.
# NOTE(review): lengths are presumably in micrometers — confirm against meow's conventions.
length = 10.0
# Axis-aligned box: 0.44 wide in x, 0.22 high in y, spanning [0, length] in z.
box = mw.Box(
    x_min=-0.22,
    x_max=0.22,
    y_min=0,
    y_max=0.22,
    z_min=0.0,
    z_max=length,
)

In [26]:
# A single structure: the box geometry filled with meow's silicon material.
struct = mw.Structure(material=mw.silicon, geometry=box)

In [27]:
# Simulation environment shared by all cross-sections.
# NOTE(review): wl is presumably the wavelength (µm) and T the temperature (°C) — confirm.
env = mw.Environment(wl=1.55, T=25.0)

In [28]:
# Build cells from the structure using three equal lengths (length/3 each) and a
# common transverse mesh of 101 x 101 points on [-1, 1] x [-1, 1].
cells = mw.create_cells(
    structures=[struct],
    mesh=mw.Mesh2d(
        x = np.linspace(-1, 1, 101),
        y = np.linspace(-1, 1, 101),
    ),
    Ls = [length/3]*3
)
# One cross-section per cell, all evaluated in the same environment.
css = [mw.CrossSection(cell=cell, env=env) for cell in cells]

In [29]:
# Request two modes for every cross-section; modes[i] belongs to css[i].
modes = [mw.compute_modes(cs, num_modes=2) for cs in css]

In [30]:
from meow.eme import compute_interface_s_matrices, compute_propagation_s_matrices
from meow.eme import compute_interface_s_matrix
from meow.eme.sax import _get_netlist
import sax
from sax.backends import circuit_backends
# Pick sax's "klu" circuit backend; it is called with instances/connections/ports
# keyword arguments when two S-matrices are cascaded below.
evaluate_circuit = circuit_backends["klu"]

In [31]:
# Propagation S-matrices (looked up below as "p_<i>") and interface S-matrices
# (looked up below as "i_<i>_<i+1>"); reciprocity is not enforced on the interfaces.
propagations = compute_propagation_s_matrices(modes)
interfaces = compute_interface_s_matrices(
    modes, enforce_reciprocity=False,
)

# Interface S-matrix of every mode set with itself.
unities = [compute_interface_s_matrix(mode, mode, enforce_reciprocity=False) for mode in modes]


## TODO: get rid of `unities` to increase efficiency

In [32]:
def _connect_two(l, r):
    """Cascade two S-matrices: the right ports of ``l`` feed the left ports of ``r``.

    Args:
        l: S-matrix on the left-hand side (anything ``sax.sdense`` accepts).
        r: S-matrix on the right-hand side (anything ``sax.sdense`` accepts).

    Returns:
        Whatever ``evaluate_circuit`` returns for the two-instance netlist. Its
        external ports are the "left" ports of ``l`` and the "right" ports of
        ``r``, keeping their original names.
    """
    # TODO: there must be an easier way to do this...
    # Only the port maps are needed here; the dense matrices are discarded.
    _, port_map_l = sax.sdense(l)
    _, port_map_r = sax.sdense(r)

    def _ports_on(port_map, side):
        # Ports are classified by substring match on their name.
        return [name for name in port_map.keys() if side in name]

    # Inner ports get wired together pairwise after sorting by name.
    inner_l = sorted(_ports_on(port_map_l, "right"))
    inner_r = sorted(_ports_on(port_map_r, "left"))

    # Outer ports stay exposed on the combined component.
    outer_l = _ports_on(port_map_l, "left")
    outer_r = _ports_on(port_map_r, "right")

    connections = {}
    for name_l, name_r in zip(inner_l, inner_r):
        connections[f"l,{name_l}"] = f"r,{name_r}"

    ports = {name: f"l,{name}" for name in outer_l}
    ports.update({name: f"r,{name}" for name in outer_r})

    return evaluate_circuit(
        instances={"l": l, "r": r},
        connections=connections,
        ports=ports,
    )

In [33]:
def pi_pairs(propagations, interfaces, unities):
    """Build one S-matrix per cell: its propagation followed by the next interface.

    Args:
        propagations: mapping indexed with the keys "p_0", "p_1", ...
        interfaces: mapping indexed with the keys "i_0_1", "i_1_2", ...
        unities: accepted for call compatibility; not used by this function.

    Returns:
        A list with one entry per propagation. The entry whose index equals
        ``len(interfaces)`` (the last cell, which has no interface after it) is
        the bare propagation matrix; every other entry is the propagation
        cascaded with the interface to the following cell.
    """
    n_interfaces = len(interfaces)
    cell_matrices = []
    for idx in range(len(propagations)):
        matrix = propagations[f"p_{idx}"]
        if idx != n_interfaces:
            interface = interfaces[f"i_{idx}_{idx+1}"]
            matrix = _connect_two(matrix, interface)
        cell_matrices.append(matrix)
    return cell_matrices

In [34]:
# One S-matrix per cell (propagation + interface to the next cell; the last is propagation only).
pairs = pi_pairs(propagations, interfaces, unities)

In [35]:
def l2r_matrices(pairs):
    """Accumulate the cell S-matrices from left to right.

    Args:
        pairs: non-empty sequence of per-cell S-matrices, ordered left to right.

    Returns:
        A list of the same length as ``pairs`` whose entry ``i`` is the cascade
        of ``pairs[0]`` through ``pairs[i]``.
    """
    cumulative = [pairs[0]]
    for idx in range(1, len(pairs)):
        # Extend the previous cascade by one more cell on its right.
        cumulative.append(_connect_two(cumulative[idx - 1], pairs[idx]))
    return cumulative
    
# l2rs[i] is the cascade of pairs[0] .. pairs[i].
l2rs = l2r_matrices(pairs)

In [36]:
def r2l_matrices(pairs):
    """Accumulate the cell S-matrices from right to left.

    Args:
        pairs: non-empty sequence of per-cell S-matrices, ordered left to right.

    Returns:
        A list of the same length as ``pairs`` whose entry ``i`` is the cascade
        of ``pairs[i]`` through ``pairs[-1]``.
    """
    Ss=[pairs[-1]]

    # Start at the second-to-last pair: the last one already seeds the list, so
    # iterating from pairs[-1] would cascade it with itself and yield one entry
    # too many. This mirrors the `pairs[1:]` used by the left-to-right variant.
    for p in pairs[-2::-1]:
        Ss.append(_connect_two(p, Ss[-1]))

    # Entries were collected right-to-left; flip so index i matches pairs[i].
    return Ss[::-1]

# Cumulative S-matrices built up starting from the right-most pair.
r2ls = r2l_matrices(pairs)

Let's assume excitation only from the left: unit amplitude in the first mode, and nothing entering from the right.

In [39]:
# Left excitation: one entry per mode of the first cell, unit amplitude in mode 0.
# `np` is jax.numpy, so the element is set with the functional `.at[...].set(...)` update.
excitation_l = np.zeros(len(modes[0]))
excitation_l = excitation_l.at[0].set(1)
# Right excitation: one entry per mode of the last cell, all zero.
excitation_r = np.zeros(len(modes[-1]))
# NOTE(review): `amplitudes` is not filled anywhere in the visible cells — confirm it is still needed.
amplitudes = []

In [40]:
# Inspect the first left-to-right S-matrix: a (dense matrix, port-name -> index map) tuple.
l2rs[0]

(Array([[ 7.14851583e-22+4.92011923e-22j,  9.04807018e-11-3.31219244e-11j,
          9.54920193e-01+2.96862637e-01j, -8.60058235e-12-2.67372245e-12j],
        [-1.04431652e-10+3.82288953e-11j,  2.63447864e-22-9.66346452e-22j,
          7.15731899e-12-5.46728267e-12j,  7.94677102e-01-6.07032375e-01j],
        [ 9.54920193e-01+2.96862637e-01j, -7.15733724e-12+5.46729661e-12j,
          8.67806729e-22-1.76601327e-44j,  9.63525780e-11+1.17997879e-26j],
        [ 8.60056042e-12+2.67371564e-12j,  7.94677102e-01-6.07032375e-01j,
         -1.11208895e-10+1.36191614e-26j,  1.00161382e-21-8.11315872e-43j]],      dtype=complex128),
 {'left@0': 0, 'left@1': 1, 'right@0': 2, 'right@1': 3})

In [51]:
def split_square_matrix(matrix, idx):
    """Split a square matrix into four blocks at row/column ``idx``.

    Args:
        matrix: 2-D array with equal numbers of rows and columns.
        idx: index at which both the rows and the columns are split.

    Returns:
        A tuple of two lists, grouped by block column:
        ``([top_left, bottom_left], [top_right, bottom_right])``.

    Raises:
        ValueError: if the matrix is not square.

    Attention: it has not been verified that this block ordering matches the
    port convention of the S-matrices it is applied to. Testing needed.
    """
    n_rows, n_cols = matrix.shape[0], matrix.shape[1]
    if n_rows != n_cols:
        raise ValueError("Matrix has to be square")

    head = slice(None, idx)
    tail = slice(idx, None)
    left_column = [matrix[head, head], matrix[tail, head]]
    right_column = [matrix[head, tail], matrix[tail, tail]]
    return left_column, right_column

In [52]:
# For every (l2r, r2l) pair, solve for the amplitudes leaving l2r to the right
# (`forward`) and leaving r2l to the left (`backward`) under the given excitations.
# zip stops at the shorter of the two lists.
forwards=[]
backwards=[]
for l2r, r2l in zip(l2rs, r2ls):
    # Each entry is a (dense matrix, port map) tuple; index 0 is the matrix.
    s_l2r = l2r[0]
    s_r2l = r2l[0]

    # m: number of "right" ports of l2r; n: its remaining ports.
    m = len([k for k in l2r[1].keys() if "right" in k])
    n = s_l2r.shape[0] - m
    l = s_r2l.shape[0] - m  # not used below
    [u11, u21],[u12, u22] = split_square_matrix(s_l2r, n)
    # NOTE(review): s_r2l is split at m, which assumes r2l has as many left
    # ports as l2r has right ports — confirm for unequal mode counts.
    [v11, v21],[v12, v22] = split_square_matrix(s_r2l, m)

    # Solve (I - u22 v11) forward = u21 excitation_l + u22 v12 excitation_r.
    RHS = u21@excitation_l + u22@v12@excitation_r
    LHS = np.diag(np.ones(m)) - u22@v11
    forward = np.linalg.solve(LHS, RHS)
    # Backward amplitude follows from the forward one via the r2l blocks.
    backward = v12@excitation_r + v11@forward

    forwards.append(forward)
    backwards.append(backward)



In [53]:
# Display the forward amplitudes: one array per (l2r, r2l) pair.
forwards


[Array([9.54920193e-01+2.96862637e-01j, 8.60056042e-12+2.67371564e-12j],      dtype=complex128),
 Array([8.23745149e-01+5.66960254e-01j, 1.58768229e-11+2.01029208e-12j],      dtype=complex128),
 Array([6.1830156e-01+7.85940952e-01j, 1.3837260e-11-8.04021244e-12j],      dtype=complex128)]

In [54]:
# Display the backward amplitudes: one array per (l2r, r2l) pair.
backwards

[Array([ 3.84826286e-21+2.14923857e-21j, -2.13484375e-10+4.88540246e-11j],      dtype=complex128),
 Array([ 1.81303514e-21+4.66606579e-22j, -1.07699331e-10-2.77177288e-11j],      dtype=complex128),
 Array([0.+0.j, 0.+0.j], dtype=complex128)]