In [None]:
import gdsfactory as gf
import jax.numpy as np
import matplotlib.pyplot as plt
import meow as mw
import meow.eme.propagate as prop
import sax
from tqdm.notebook import tqdm

In [None]:
def example_extrusions(
    t_slab: float = 0.0,
    t_soi: float = 0.22,
    t_ox: float = 0.0,
):
    """Create simple extrusion rules for a single silicon layer.

    Only GDS layer ``(1, 0)`` gets a rule: its polygons are extruded in
    ``mw.silicon`` from height ``0.0`` up to ``t_soi``.

    Args:
        t_slab: slab thickness. NOTE: currently unused; no slab rule is
            created, so passing a value has no effect.
        t_soi: thickness of the silicon layer extruded from layer ``(1, 0)``.
        t_ox: oxide layer thickness. NOTE: currently unused; no oxide rule is
            created, so passing a value has no effect.

    Returns:
        dict mapping the layer tuple ``(1, 0)`` to a one-element list holding
        its ``mw.GdsExtrusionRule``.
    """
    extrusions = {
        (1, 0): [
            mw.GdsExtrusionRule(
                material=mw.silicon,
                h_min=0.0,
                h_max=0.0 + t_soi,
                mesh_order=1,
            ),
        ],
    }
    return extrusions

In [None]:
# Taper device: length `l_taper`; the two end widths are passed positionally
# to gf.components.taper.
l_taper = 20
w_in, w_out = 0.5, 1.5
taper = gf.components.taper(l_taper, w_in, w_out)

# Wrap the taper in a fresh component whose left edge sits at x = 0.
# NOTE(review): the wrapper keeps the name `mmi` because later cells refer to
# it, even though the device is a taper rather than an MMI.
c = gf.Component()
ref = c.add_ref(taper)
ref.xmin = 0
mmi = c

# Turn the 2D layout into meow structures using the rules above.
extrusion_rules = example_extrusions()
structs = mw.extrude_gds(mmi, extrusion_rules)

# Optional 3D preview: mw.visualize(structs, scale=(1, 1, 0.2))

In [None]:
# Plot the 2D layout of the wrapped taper component (`mmi`).
mmi.plot()

In [None]:
# Split the propagation axis into equal-length EME cells: 3 edges -> 2 cells.
# Every edge is shifted by a small `eps`.
# NOTE(review): presumably this keeps boundaries off z = 0 exactly — confirm.
eps = 1e-2
cell_edges = np.linspace(0, l_taper, 3) + eps

# Transverse grid shared by every cross-section.
mesh = mw.Mesh2D(
    x=np.linspace(-2, 2, 401),
    y=np.linspace(-1, 1, 401),
)

cells = [
    mw.Cell(structures=structs, mesh=mesh, z_min=z_lo, z_max=z_hi)
    for z_lo, z_hi in zip(cell_edges[:-1], cell_edges[1:])
]

# Material/optical environment used to build each cell's cross-section.
env = mw.Environment(wl=1.55, T=25.0)
css = [mw.CrossSection.from_cell(cell=cell, env=env) for cell in cells]

# Visual sanity check of every cross-section.
for cs in css:
    mw.visualize(cs)

In [None]:
# Number of modes requested from mw.compute_modes for each cross-section.
num_modes = 50

In [None]:
# Solve for `num_modes` modes of every cross-section; tqdm shows progress.
modes = []
for cross_section in tqdm(css):
    modes.append(mw.compute_modes(cross_section, num_modes=num_modes))

In [None]:
# Keep a handle on the unfiltered per-cross-section mode lists so filtering
# can be redone from scratch without re-running the mode solver.
modes_tmp = modes

In [None]:
# Restore the unfiltered mode lists before the filtering step below.
modes = modes_tmp

In [None]:
# Keep only modes with a non-negative imaginary effective index.
# NOTE(review): presumably this discards growing/unphysical modes under meow's
# sign convention — confirm.
# Filter from the unfiltered `modes_tmp` (not `modes`) so re-running this cell
# on its own always gives the same result.
modes = [
    [mode for mode in cs_modes if mode.neff.imag >= 0]
    for cs_modes in modes_tmp
]
assert all(modes), "filtering removed every mode of at least one cross-section"

In [None]:
# Optional (disabled): keep only modes whose effective index exceeds 1.45.
# NOTE(review): neff appears to be complex (the filter cell above reads
# `neff.imag`), so this comparison likely needs `mode.neff.real > 1.45` — confirm.
# modes = [[mode for mode in modes_ if mode.neff > 1.45] for modes_ in modes]

In [None]:
# S-matrix of the whole cell stack from meow, converted to a sax S-dict keyed
# by (port, port) name pairs.
S, port_map = mw.compute_s_matrix(modes, cells)
s_dict = sax.sdict((S, port_map))

# Magnitude of the S-parameter between ports "left@0" and "right@0"
# (NOTE(review): presumably mode 0 on each side — confirm port naming).
np.abs(s_dict[("left@0", "right@0")])

In [None]:
# Sample the propagated field on a z window around the internal cell boundary
# (z ≈ 10 for the current l_taper / cell_edges), at a fixed transverse y.
z = np.linspace(9, 11, 800)
y = 0.2

# Excitation amplitudes (presumably modal, per port side): nothing from the
# left; a single mode of the right-most cross-section.
RIGHT_MODE_INDEX = 2
RIGHT_MODE_AMPLITUDE = 0.3

# JAX's `.at[...].set` skips out-of-bounds updates instead of raising. An index
# past the (filtered) right-hand mode list would otherwise silently produce an
# all-zero excitation and a blank field plot.
assert RIGHT_MODE_INDEX < len(modes[-1]), (
    f"cannot excite right mode {RIGHT_MODE_INDEX}: the last cross-section "
    f"only has {len(modes[-1])} modes after filtering"
)

ex_l = np.zeros(len(modes[0]))
ex_r = np.zeros(len(modes[-1])).at[RIGHT_MODE_INDEX].set(RIGHT_MODE_AMPLITUDE)

Ex, x = prop.propagate_modes(modes, cells, ex_l, ex_r, y, z)

In [None]:
# Real part of Ex over the (z, x) window on a colour scale symmetric about 0.
X, Y = np.meshgrid(z, x)
lim = np.max(np.abs(Ex))

fig, ax = plt.subplots(figsize=(10, 3))
# Named `quad` so the global Mesh2D `mesh` is not shadowed.
quad = ax.pcolormesh(
    X, Y, Ex.T.real, shading="nearest", vmin=-lim, vmax=lim, cmap="RdBu"
)
fig.colorbar(quad, ax=ax, label="Re(Ex)")
ax.set(xlabel="z", ylabel="x", title=f"Re(Ex) at y = {y}")
plt.show()

In [None]:
# |Ex| over the (z, x) window. The upper colour limit is the 99th percentile so
# a few hot spots don't wash out the rest; the lower limit is pinned at 0.
X, Y = np.meshgrid(z, x)

fig, ax = plt.subplots(figsize=(10, 3))
# Named `quad` so the global Mesh2D `mesh` is not shadowed.
quad = ax.pcolormesh(
    X,
    Y,
    np.abs(Ex.T),
    shading="nearest",
    cmap="magma",
    vmin=0,
    vmax=np.quantile(np.abs(Ex), 0.99),
)
fig.colorbar(quad, ax=ax, label="|Ex|")
ax.set(xlabel="z", ylabel="x", title=f"|Ex| at y = {y}")
plt.show()