In [None]:
import gdsfactory as gf
import jax.numpy as np
import matplotlib.pyplot as plt
import meow as mw
import meow.eme.propagate as prop
from tqdm.notebook import tqdm

In [None]:
def example_extrusions(
    t_slab: float = 0.0,
    t_soi: float = 0.4,
    t_ox: float = 0.0,
):
    """Build the GDS-layer -> extrusion-rule mapping used by this notebook.

    Only GDS layer ``(1, 0)`` is mapped. It is extruded as a single silicon
    volume that starts at height 0.0 and is ``t_soi`` thick.

    Args:
        t_slab: the slab thickness (accepted, but not used by the rule below)
        t_soi: the SOI thickness; sets the top of the silicon extrusion
        t_ox: the oxide layer thickness (accepted, but not used by the rule below)

    Returns:
        A dict keyed by GDS layer tuple; each value is a list of
        ``mw.GdsExtrusionRule`` objects for that layer.
    """
    core_bottom = 0.0
    silicon_core = mw.GdsExtrusionRule(
        material=mw.silicon,
        h_min=core_bottom,
        h_max=core_bottom + t_soi,
        mesh_order=1,
    )
    return {(1, 0): [silicon_core]}

In [None]:
# Layout parameters forwarded to gdsfactory's mmi2x2 below. `l_taper` and
# `l_center` are reused later to size the propagation grid along z.
l_taper = 20
l_center = 20
w_center = 3

mmi = gf.components.mmi2x2(
    length_taper=l_taper, length_mmi=l_center, width_mmi=w_center
)

# Wrap the MMI in a fresh component so it can be repositioned: the reference
# is moved so that its left edge (xmin) sits at 0.
c = gf.Component()
ref = c.add_ref(mmi)
ref.xmin = 0
# NOTE: from here on `mmi` names the wrapper component `c`, not the original
# mmi2x2 cell. Later cells that use `mmi` therefore see the shifted layout.
mmi = c

# Extrude the shifted layout with the default rules (silicon on layer (1, 0)).
extrusion_rules = example_extrusions()
structs = mw.extrude_gds(mmi, extrusion_rules)

# Optional 3D preview of the extruded structures (disabled).
# mw.visualize(structs, scale=(1, 1, 0.2))

In [None]:
# Plot the repositioned component that was handed to `mw.extrude_gds`.
mmi.plot()

In [None]:
# Offset applied to the cell edges: the left group is shifted by +eps and the
# right group by -eps.
eps = 1e-2

# Cell edges are derived from the layout parameters instead of repeating their
# numeric values, so the cells stay aligned with the geometry if `l_taper` or
# `l_center` change. The layout starts at z = 0 (the reference's xmin is 0):
#   [0, l_taper]                               -> 10 cells (11 edges)
#   [l_taper, l_taper + l_center]              -> 1 cell between the two groups
#   [l_taper + l_center, 2*l_taper + l_center] -> 10 cells (11 edges)
left_cell_edges = np.linspace(0, l_taper, 11) + eps
right_cell_edges = np.linspace(l_taper + l_center, 2 * l_taper + l_center, 11) - eps

# The first and last edges are deliberately duplicated, which yields a
# zero-length cell at each end of the device. Those two cells are the ones
# whose mode basis is replaced later in the notebook.
cell_edges = np.concatenate(
    [left_cell_edges[:1], left_cell_edges, right_cell_edges, right_cell_edges[-1:]]
)

# Transverse mesh shared by every cell.
mesh = mw.Mesh2D(
    x=np.linspace(-2, 2, 101),
    y=np.linspace(-1, 1, 101),
)

# One cell per pair of consecutive edges.
cells = []
for z_min, z_max in zip(cell_edges[:-1], cell_edges[1:]):
    cell = mw.Cell(
        structures=structs,
        mesh=mesh,
        z_min=z_min,
        z_max=z_max,
    )
    cells.append(cell)

# Cross-sections are evaluated for a single environment (wl=1.55, T=25.0).
env = mw.Environment(wl=1.55, T=25.0)
css = [mw.CrossSection.from_cell(cell=cell, env=env) for cell in cells]

# Visual check of every cross-section (one plot per cell).
for cs in css:
    mw.visualize(cs)

In [None]:
# Solve for `num_modes` modes of the first cross-section only. At this point
# `modes` is a flat list of modes for css[0].
num_modes = 16
modes = mw.compute_modes(css[0], num_modes=num_modes)

# Difference and sum of the first two modes; the same +/- combinations are
# installed later as the mode basis of the first and last cell.
mw.visualize(modes[0] - modes[1])
mw.visualize(modes[0] + modes[1])
# NOTE(review): opens a new matplotlib figure before the per-mode plots; whether
# `mw.visualize` draws into it or creates its own figure is not visible here.
plt.figure()
for mode in modes:
    mw.visualize(mode)

In [None]:
# Interface S-matrix between two mode sets taken from the same cross-section:
# the first mode alone on one side, the first two modes on the other.
# Relies on `modes` still being the flat list for css[0], so this cell must run
# before `modes` is recomputed per cross-section.
_S = mw.compute_interface_s_matrix(modes[:1], modes[:2])
mw.visualize(_S)

In [None]:
# Rebinds `modes`: from the flat list for css[0] to one list of modes per
# cross-section, in the same order as `css`. `tqdm` only adds a progress bar.
modes = [mw.compute_modes(cs, num_modes=num_modes) for cs in tqdm(css)]

In [None]:
# Print, for each cross-section, the real part of the effective index of every
# computed mode (one array per cross-section).
for modes_ in modes:
    neffs_real = [np.real(mode.neff) for mode in modes_]
    print(np.array(neffs_real))

In [None]:
# Keep, per cross-section, only the modes whose effective index has a real part
# above 1.45. `neff` is handled as a complex quantity elsewhere in this notebook
# (its real part is taken before printing), so the real part is compared
# explicitly here: ordering comparisons on complex values are either undefined
# (Python `complex` raises TypeError) or depend on the array library.
modes = [
    [mode for mode in modes_ if np.real(mode.neff) > 1.45] for modes_ in modes
]

In [None]:
# Replace the mode basis of the first and last cell (the two zero-length cells
# created by the duplicated outer edges) with the sum and difference of their
# first two modes. Any further modes of those two cells are dropped.
# Requires at least two modes to remain in modes[0] and modes[-1] after the
# neff filtering step.
# NOTE: not idempotent. Re-running this cell combines the already-combined
# modes again, so recompute and re-filter `modes` before running it a second time.
modes[0] = [modes[0][0] + modes[0][1], modes[0][0] - modes[0][1]]
modes[-1] = [modes[-1][0] + modes[-1][1], modes[-1][0] - modes[-1][1]]

In [None]:
# Inspect the recombined port modes: the first (sum) mode of the input cell,
# then both the sum and the difference mode of the output cell.
mw.visualize(modes[0][0])
mw.visualize(modes[-1][0])
mw.visualize(modes[-1][1])

In [None]:
# S-matrix of the whole cell stack, computed from the per-cell mode lists.
# NOTE(review): `port_map` presumably maps port names to row/column indices of
# `S` — confirm against the meow documentation.
S, port_map = mw.compute_s_matrix(modes, cells)
print(port_map)
mw.visualize(S)

In [None]:
# Sampling positions for the field: a fixed height `y` and 800 points along z
# covering the full device length (two tapers plus the centre section).
y = 0.2
z = np.linspace(0, 2 * l_taper + l_center, 800)

# Excitation vectors have one entry per mode of the first / last cell.
# Left side: unit amplitude in mode 0 only. Right side: no excitation.
ex_l = np.zeros(len(modes[0])).at[0].set(1)
ex_r = np.zeros(len(modes[-1]))
# Alternative: additionally excite mode 1 from the right.
# ex_r = ex_r.at[1].set(0.3)

Ex, x = prop.propagate_modes(modes, cells, ex_l, ex_r, y, z)

In [None]:
# Heat map of the imaginary part of the propagated field over the (z, x) plane.
# A dedicated figure/axes pair is used so this plot never shares a figure (or a
# colorbar) with another plot when the notebook runs outside the inline backend.
X, Y = np.meshgrid(z, x)  # X holds the z coordinates, Y holds the x coordinates
lim = np.max(np.abs(Ex.imag))
fig, ax = plt.subplots()
# Symmetric limits keep zero at the centre of the diverging colormap.
pcm = ax.pcolormesh(
    X, Y, Ex.T.imag, shading="nearest", vmin=-lim, vmax=lim, cmap="RdBu"
)
ax.set_xlabel("z")
ax.set_ylabel("x")
ax.set_title("Im(Ex)")
fig.colorbar(pcm, ax=ax);

In [None]:
# Heat map of the field magnitude over the (z, x) plane.
# A dedicated figure/axes pair is used so this plot never shares a figure (or a
# colorbar) with another plot when the notebook runs outside the inline backend.
X, Y = np.meshgrid(z, x)  # X holds the z coordinates, Y holds the x coordinates
fig, ax = plt.subplots()
pcm = ax.pcolormesh(
    X,
    Y,
    np.abs(Ex.T),
    shading="nearest",
    cmap="jet",
    # Clip the colour scale at the 99th percentile so a few very large values
    # do not wash out the rest of the map.
    vmax=np.quantile(np.abs(Ex), 0.99),
)
ax.set_xlabel("z")
ax.set_ylabel("x")
ax.set_title("|Ex|")
fig.colorbar(pcm, ax=ax);

In [None]:
# Show every retained mode of the cell with index 9. With the cell edges used in
# this notebook (l_taper = 20, eps = 1e-2) that cell spans z = 16.01 .. 18.01,
# i.e. it lies in the input taper just before the centre section.
for mode in modes[9]:
    mw.visualize(mode)