In [None]:
import gdsfactory as gf
import jax.numpy as np
import matplotlib.pyplot as plt
import meow as mw
from tqdm.auto import tqdm

In [None]:
def example_extrusions(
    t_slab: float = 0.0,
    t_soi: float = 0.4,
    t_ox: float = 0.0,
):
    """Build a minimal set of GDS extrusion rules: one silicon core layer.

    Only GDS layer (1, 0) is extruded, as silicon spanning heights from 0 up
    to ``t_soi``.

    Args:
        t_slab: the slab thickness. NOTE: currently not used by this function.
        t_soi: thickness of the silicon extruded from layer (1, 0).
        t_ox: the oxide layer thickness. NOTE: currently not used by this
            function.

    Returns:
        dict mapping the layer tuple ``(1, 0)`` to a one-element list holding
        the ``mw.GdsExtrusionRule`` for the silicon core.
    """
    core_bottom = 0.0
    core_rule = mw.GdsExtrusionRule(
        material=mw.silicon,
        h_min=core_bottom,
        h_max=core_bottom + t_soi,
        mesh_order=1,
    )
    return {(1, 0): [core_rule]}

In [None]:
# MMI geometry (lengths presumably in µm, gdsfactory's default unit)
l_taper = 20  # length of each access taper
l_center = 20  # length of the central multimode section
w_center = 3  # width of the central multimode section

mmi = gf.components.mmi2x2(
    length_taper=l_taper,
    length_mmi=l_center,
    width_mmi=w_center,
)

# Wrap the MMI in a parent component and shift it so its left edge sits at
# x = 0, matching the cell lengths and z-grid below that both start at 0.
c = gf.Component()
ref = c << mmi
ref.xmin = 0
mmi = c  # from here on `mmi` refers to the shifted wrapper component

extrusion_rules = example_extrusions()
structs = mw.extrude_gds(mmi, extrusion_rules)

# Optional 3D preview of the extruded structures:
# mw.visualize(structs, scale=(1, 1, 0.2))

In [None]:
# Plot the 2D layout of the MMI component built above.
mmi.plot()

In [None]:
# Discretisation of the device along the propagation direction.
eps = 1e-10  # length of the very thin cells at the ports and junctions
w_sim = w_center + 2  # simulation window width: MMI width plus a margin
h_sim = 2.0  # simulation window height
mesh = 100  # number of grid intervals per transverse axis
num_cells = 10  # number of cells each taper is split into

# Each taper is split into `num_cells` equal cells; its first and last cells
# are shortened by `eps` to make room for the thin cells next to them.
taper_Ls = [l_taper / num_cells] * num_cells
taper_Ls[-1] -= eps
taper_Ls[0] -= eps

# Layout: port | taper | junction | MMI body | junction | taper | port.
# The thin cells (eps at the ports, 2*eps at the taper/MMI junctions)
# presumably pin a cross-section right at each port and interface; the eps
# adjustments cancel, so Ls sums to 2 * l_taper + l_center (up to rounding).
Ls = [
    eps,
    *taper_Ls,
    2 * eps,
    l_center - 2 * eps,
    2 * eps,
    *taper_Ls,
    eps,
]
print(Ls)

# Uniform transverse grid symmetric about 0 in both x and y:
# `mesh + 1` points per axis give `mesh` intervals.
sim_mesh = mw.Mesh2d(
    x=np.linspace(-w_sim / 2, w_sim / 2, mesh + 1),
    y=np.linspace(-h_sim / 2, h_sim / 2, mesh + 1),
)

# Split the extruded structures into consecutive cells of lengths `Ls`.
cells = mw.create_cells(structures=structs, mesh=sim_mesh, Ls=Ls)

# To inspect each cell individually:
# for cell in cells:
#     mw.visualize(cell)

In [None]:
# Optical environment: wavelength and temperature (presumably µm and °C).
env = mw.Environment(wl=1.55, T=25.0)

# One cross-section per cell, all evaluated in the same environment.
css = [mw.CrossSection(cell=each_cell, env=env) for each_cell in cells]

# Show the cross-sections at the two ends of the device.
mw.visualize(css[0])
mw.visualize(css[-1])

In [None]:
num_modes = 16

# Solve the modes of the first cross-section only, as a quick look before
# solving every cell. NOTE: `modes` is rebound further down to a list of
# per-cell mode lists.
modes = mw.compute_modes(css[0], num_modes=num_modes)

# Difference and sum of the first two modes; presumably these are the
# even/odd supermodes of the two access waveguides, so each combination
# should be concentrated in one waveguide -- TODO confirm from the plots.
mw.visualize(modes[0] - modes[1])
mw.visualize(modes[0] + modes[1])

# Start a fresh matplotlib figure before plotting the individual modes.
plt.figure()
for mode in modes:
    mw.visualize(mode)

In [None]:
# Sanity check: interface S-matrix between the first mode and the first two
# modes of the first cross-section (`modes` as computed for `css[0]` above).
mw.compute_interface_s_matrix(modes[:1], modes[:2])

In [None]:
# Solve the modes of every cross-section: `modes[i]` holds the modes of cell i.
# NOTE: this rebinds `modes` (previously the modes of `css[0]` only) to a list
# of per-cell mode lists. `tqdm` comes from `tqdm.auto` in the import cell, so
# the progress bar also works without ipywidgets.
modes = [
    mw.compute_modes(cs, num_modes=num_modes)
    for cs in tqdm(css, desc="computing modes")
]

In [None]:
# Effective indices of all solved modes, grouped per cell.
print([[m.neff for m in cell_modes] for cell_modes in modes])

In [None]:
# Keep only modes whose effective index exceeds the cutoff (1.45 is presumably
# chosen as roughly the index of silica, to drop weakly guided / radiation-like
# modes -- confirm). Re-running this cell is harmless: the filter is idempotent.
neff_cutoff = 1.45
modes = [
    [m for m in cell_modes if m.neff > neff_cutoff]
    for cell_modes in modes
]

In [None]:
def _localized_port_modes(port_modes):
    """Replace a port's modes by the sum and difference of its first two modes.

    Presumably the first two modes are the even/odd supermodes of the two
    access waveguides, so their sum/difference each sit in one waveguide --
    TODO confirm. Any further modes in ``port_modes`` are dropped.

    Raises:
        ValueError: if fewer than two modes survived the neff filter.
    """
    if len(port_modes) < 2:
        raise ValueError(
            f"expected at least 2 port modes after filtering, got {len(port_modes)}"
        )
    first, second = port_modes[0], port_modes[1]
    return [first + second, first - second]


# NOTE: this rebinds `modes[0]` and `modes[-1]` in place. Re-running this cell
# without first re-running the filter cell above combines the already
# combined modes a second time.
modes[0] = _localized_port_modes(modes[0])
modes[-1] = _localized_port_modes(modes[-1])

In [None]:
# Inspect the combined (sum/difference) port modes: the first one at the
# input end and both at the output end.
mw.visualize(modes[0][0])
mw.visualize(modes[-1][0])
mw.visualize(modes[-1][1])

In [None]:
# S-matrix of the whole device from the per-cell modes; `port_map` presumably
# maps port/mode labels to S-matrix indices, so print it to read `S`.
S, port_map = mw.compute_s_matrix(modes)
print(port_map)
mw.visualize(S)

In [None]:
import meow.eme.propagate as prop

In [None]:
# Development helper: reload `meow.eme.propagate` (imported as `prop`) so local
# edits to it take effect without restarting the kernel.
import importlib

importlib.reload(prop)

In [None]:
# Sample points along the full device length (taper + center + taper).
z = np.linspace(0, 2 * l_taper + l_center, 800)
# Height passed to the field propagation; presumably the y of the sampled
# x-z plane (0.2 is mid-core for the default 0.4 silicon extrusion).
y = 0.2

# Excitation amplitudes: unit amplitude in the first mode at the left end,
# nothing injected from the right.
ex_l = np.zeros(len(modes[0])).at[0].set(1)
ex_r = np.zeros(len(modes[-1]))
# To also inject from the right, e.g.: ex_r = ex_r.at[1].set(0.3)

Ex, x = prop.propagate_modes(modes, ex_l, ex_r, y, z)

In [None]:
# Field snapshot: Im(Ex) on the z-x grid, with a symmetric colour range so
# zero maps to the neutral middle of the "RdBu" colormap.
X, Y = np.meshgrid(z, x)
lim = np.max(np.abs(Ex.imag))

fig, ax = plt.subplots()
quad = ax.pcolormesh(
    X, Y, Ex.T.imag, shading="nearest", vmin=-lim, vmax=lim, cmap="RdBu"
)
fig.colorbar(quad, ax=ax)
ax.set(xlabel="z (µm)", ylabel="x (µm)", title=f"Im(Ex) at y = {y}")
# Use a filename distinct from the |Ex| plot so neither overwrites the other.
fig.savefig("Ex_imag.png", dpi=1200)

In [None]:
# Field magnitude |Ex| on the z-x grid.
X, Y = np.meshgrid(z, x)

fig, ax = plt.subplots()
quad = ax.pcolormesh(
    X,
    Y,
    np.abs(Ex.T),
    shading="nearest",
    cmap="jet",
    # clip at the 99th percentile so a few hot spots don't wash out the map
    vmax=np.quantile(np.abs(Ex), 0.99),
)
fig.colorbar(quad, ax=ax)
ax.set(xlabel="z (µm)", ylabel="x (µm)", title=f"|Ex| at y = {y}")
# Use a filename distinct from the Im(Ex) plot so neither overwrites the other.
fig.savefig("Ex_abs.png", dpi=1200)

In [None]:
# Inspect the modes of one intermediate cell. With num_cells = 10, index 9 is
# one of the cells of the first taper (index 0 is the thin port cell).
cell_index = 9
for mode in modes[cell_index]:
    mw.visualize(mode)