In [None]:
import gdsfactory as gf
import matplotlib.pyplot as plt
import meow as mw
import numpy as np

In [None]:
def example_extrusions(
    t_soi: float = 0.4,
):
    """Build the extrusion rules used to turn the GDS layout into structures.

    A single rule is registered under the key ``(1, 0)`` (presumably the GDS
    layer/datatype pair -- confirm against ``mw.extrude_gds``): silicon
    spanning the heights ``0.0`` to ``0.0 + t_soi``, with ``mesh_order=1``.

    Args:
        t_soi: the SOI thickness, i.e. the height of the extruded silicon.

    Returns:
        A dict mapping the key ``(1, 0)`` to a one-element list holding the
        ``mw.GdsExtrusionRule`` described above.
    """
    h_bottom = 0.0
    silicon_rule = mw.GdsExtrusionRule(
        material=mw.silicon,
        h_min=h_bottom,
        h_max=h_bottom + t_soi,
        mesh_order=1,
    )
    return {(1, 0): [silicon_rule]}

In [None]:
# Geometry parameters handed to gdsfactory's mmi2x2 (reused by later cells to
# size the EME cells and the propagation axis).
l_taper, l_center, w_center = 20, 20, 3

# The extrusion rules do not depend on the layout, so build them up front.
extrusion_rules = example_extrusions()

mmi = gf.components.mmi2x2(
    length_mmi=l_center,
    width_mmi=w_center,
    length_taper=l_taper,
)

# Wrap the MMI in a fresh component so its reference can be shifted to start
# at x = 0; from here on `mmi` refers to this wrapper component.
c = gf.Component()
ref = c.add_ref(mmi)
ref.xmin = 0
mmi = c

structs = mw.extrude_gds(mmi, extrusion_rules)

# Optional 3D preview of the extruded structures:
# mw.visualize(structs, scale=(1, 1, 0.2))

In [None]:
# Plot the wrapper component (the MMI reference shifted so that xmin = 0).
mmi.plot()

In [None]:
# Simulation window and discretisation settings.
eps = 1e-10  # length unit for the very thin cells inserted below
w_sim = w_center + 2  # window width: MMI width plus 2 extra
h_sim = 2.0  # window height
mesh = 100  # mesh intervals per transverse direction (mesh + 1 grid points)
num_cells = 5  # cells per taper

# One taper, divided into `num_cells` equal cells, with `eps` trimmed off its
# first and last cell.
taper_Ls = [l_taper / num_cells for _ in range(num_cells)]
taper_Ls[0] -= eps
taper_Ls[-1] -= eps

# The central section, flanked by two thin cells of length 2 * eps.
# NOTE(review): the thin cells presumably sit at the abrupt taper/MMI
# junctions -- confirm.
center_Ls = [2 * eps, l_center - 2 * eps, 2 * eps]

# Full list of cell lengths: thin end cell, taper, center, taper, thin end cell.
Ls = [eps, *taper_Ls, *center_Ls, *taper_Ls, eps]
print(Ls)

# Transverse mesh, centered on 0 in both directions.
mesh2d = mw.Mesh2d(
    x=np.linspace(-w_sim / 2, w_sim / 2, mesh + 1),
    y=np.linspace(-h_sim / 2, h_sim / 2, mesh + 1),
)

cells = mw.create_cells(
    structures=structs,
    mesh=mesh2d,
    Ls=Ls,
    ez_interfaces=True,
)

# Optional: inspect every cell.
# for cell in cells:
#  mw.visualize(cell)

In [None]:
# A single environment (wl=1.55, T=25.0) is shared by all cross-sections.
# NOTE(review): units are presumably micrometers / degrees Celsius -- confirm.
env = mw.Environment(wl=1.55, T=25.0)
# One CrossSection per cell, in the same order as `cells`.
css = [mw.CrossSection(cell=cell, env=env) for cell in cells]

In [None]:
# Number of modes requested from the solver for each cross-section.
num_modes = 6
# Solve the first cross-section only. `modes` is a single collection of modes
# here; a later cell rebinds it to one collection per cross-section.
modes = mw.compute_modes(css[0], num_modes=num_modes)

mw.visualize(modes)

In [None]:
# Interface S-matrix between the first mode alone and the first two modes.
# When the notebook is run in order, `modes` still holds the modes of css[0].
mw.compute_interface_s_matrix(modes[:1], modes[:2])

In [None]:
from tqdm.notebook import tqdm

# Solve `num_modes` modes for every cross-section (tqdm shows progress).
# This rebinds `modes` to a list with one entry per cross-section.
modes = [mw.compute_modes(cs, num_modes=num_modes) for cs in tqdm(css)]

In [None]:
# Print the `neff` of every mode, grouped per cross-section.
print([[mode.neff for mode in modes_] for modes_ in modes])

In [None]:
# S-matrix of the whole stack of cells, together with its port map.
S, port_map = mw.compute_s_matrix(modes)
print(port_map)
mw.visualize(S)

In [None]:
from meow.port import Port

# Each side of the device gets two ports, one covering x < 0 and one covering
# x > 0 (both unbounded in y), each requesting 2 modes. `ports[0]` is used
# with the first cross-section and `ports[-1]` with the last one.
x_halves = [(-np.inf, 0), (0, np.inf)]
ports = tuple(
    [
        Port(
            extend_x=x_range,
            extend_y=(-np.inf, np.inf),
            fg_structures=structs,
            num_modes=2,
        )
        for x_range in x_halves
    ]
    for _ in range(2)
)

from meow.cross_section import CrossSection
from meow.port import Ports
from meow.mode import Modes, inner_product
from meow.eme.propagate import _connect_two
from typing import Tuple
import sax

def compute_port_modes(cs: CrossSection, ports: Ports):
    """Collect the modes of every port in ``ports`` for the cross-section ``cs``.

    Args:
        cs: the cross-section handed to each port's ``compute_modes``.
        ports: the ports to solve, in order.

    Returns:
        A flat list holding the modes of the first port, then those of the
        second port, and so on.
    """
    return [mode for port in ports for mode in port.compute_modes(cs)]

def overlap_matrix(modes_l: Modes, modes_r: Modes):
    """Build the overlap matrix between two sets of modes.

    Every pair (left mode ``m``, right mode ``n``) contributes one entry
    ``("left@m", "right@n")`` whose value is their ``inner_product``. The
    resulting dict is passed through ``sax.reciprocal`` before being returned.
    It is used to de-embed the inner S-matrix with respect to port modes.

    Args:
        modes_l: modes on the left side of the overlap.
        modes_r: modes on the right side of the overlap.

    Returns:
        The result of ``sax.reciprocal`` applied to the forward overlaps.
    """
    forward = {}
    # Right index in the outer loop, left index in the inner loop.
    for n, mode_r in enumerate(modes_r):
        for m, mode_l in enumerate(modes_l):
            forward[(f"left@{m}", f"right@{n}")] = inner_product(mode_l, mode_r)
    return sax.reciprocal(forward)

def outer_S_matrix(modes: Modes, ports: Tuple[Ports, Ports], inner_S):
    """De-embed the inner S-matrix with respect to the given ports.

    Args:
        modes: sequence with one collection of modes per cross-section; only
            the first and last entries are used.
        ports: pair of port collections; ``ports[0]`` is solved on the first
            cross-section and ``ports[-1]`` on the last one.
        inner_S: the S-matrix to be sandwiched between the two overlap
            matrices.

    Returns:
        ``sax.sdense`` of (left overlap) -> ``inner_S`` -> (right overlap),
        chained with ``_connect_two``.
    """
    first_modes, last_modes = modes[0], modes[-1]

    # Port modes are solved on the cross-section the boundary modes belong to.
    left_port_modes = compute_port_modes(first_modes[0].cs, ports[0])
    right_port_modes = compute_port_modes(last_modes[0].cs, ports[-1])

    overlap_left = overlap_matrix(left_port_modes, first_modes)
    overlap_right = overlap_matrix(last_modes, right_port_modes)

    inner_and_right = _connect_two(inner_S, overlap_right)
    return sax.sdense(_connect_two(overlap_left, inner_and_right))

In [None]:
from meow.port import PortCell

# Cross-section that the first mode of the first cell was solved on.
cs = modes[0][0].cs
# Build a PortCell for the first port of `ports[0]`, forwarding every field of
# that cross-section's cell (via `.dict()`) as keyword arguments.
pc = PortCell(port=ports[0][0], **cs.cell.dict())

In [None]:
import meow.eme.propagate as prop

In [None]:
# Sample the field along the full device length (two tapers + center section).
z = np.linspace(0, l_taper * 2 + l_center, 800)
# Height at which the field is sampled; 0.2 is halfway through the default
# silicon layer built by example_extrusions (0.0 .. 0.4).
y = 0.2

# Excitation amplitudes, one entry per mode of the first / last cell.
# `np` is NumPy here, so use in-place item assignment: the
# `.at[...].set(...)` syntax only exists on jax.numpy arrays and raises
# AttributeError on a numpy.ndarray.
ex_l = np.zeros(len(modes[0]))
ex_l[0] = 1  # excite only the first mode from the left
ex_r = np.zeros(len(modes[-1]))
# ex_r[1] = 0.3  # optionally also excite the second mode from the right

Ex, x = prop.propagate_modes(modes, ex_l, ex_r, y, z)

In [None]:
# Imaginary part of Ex over the (z, x) plane, on a color scale symmetric
# around zero. Requires `import matplotlib.pyplot as plt` in the import cell.
X, Y = np.meshgrid(z, x)
lim = np.max(np.abs(Ex.imag))

fig, ax = plt.subplots()
im = ax.pcolormesh(
    X, Y, Ex.T.imag, shading="nearest", vmin=-lim, vmax=lim, cmap="RdBu"
)
fig.colorbar(im, ax=ax)
ax.set(xlabel="z", ylabel="x", title="Im(Ex)")

# Written to the current working directory at high resolution.
fig.savefig("test.png", dpi=1200)
plt.show()

In [None]:
# Magnitude of Ex over the (z, x) plane. The color scale is clipped at the
# 99th percentile of |Ex| so that a few extreme values do not wash out the plot.
# Requires `import matplotlib.pyplot as plt` in the import cell.
X, Y = np.meshgrid(z, x)
Ex_abs = np.abs(Ex)

fig, ax = plt.subplots()
im = ax.pcolormesh(
    X,
    Y,
    Ex_abs.T,
    shading="nearest",
    cmap="jet",
    vmax=np.quantile(Ex_abs, 0.99),
)
fig.colorbar(im, ax=ax)
ax.set(xlabel="z", ylabel="x", title="|Ex|")
plt.show()