# Surface Models
> Let's build some analytical surface models using MEOW and SAX

In [None]:
import json
from pathlib import Path

import matplotlib.pyplot as plt
import meow as mw
import numpy as np
import pandas as pd
import sax
import xarray as xr
from tqdm.notebook import tqdm

## Silicon Refractive index
Let's create a rudimentary silicon refractive index model:

```{note}
This model is not physically realistic; it only serves as an illustration.
```

In [None]:
def silicon_index(wl, T):
    """Return a toy silicon refractive index at wavelength ``wl`` and temperature ``T``.

    The index is ``A / wl + B`` (a simple 1/wl dispersion term) plus a linear
    thermo-optic shift relative to 25 degrees.

    Args:
        wl: wavelength in µm (scalar or numpy array).
        T: temperature in °C.

    Returns:
        The refractive index, with the same shape as ``wl`` broadcast with ``T``.
    """
    inv_wl_coeff = 0.2411478522088102
    offset = 3.3229394315868976
    thermo_optic = 0.00082342342  # probably exaggerated
    reference_T = 25.0

    dispersion = inv_wl_coeff / wl + offset
    thermal_shift = thermo_optic * (T - reference_T)
    return dispersion + thermal_shift

In [None]:
# Plot the toy silicon material index (not an effective index) versus
# wavelength for a few temperatures; wavelengths are in µm, shown in nm.
wls = np.linspace(1.0, 3.0, 21)
for T in [25.0, 35.0, 45.0]:
    plt.plot(1e3 * wls, silicon_index(wls, T), label=f"T = {T:.0f} °C")
plt.xlabel("Wavelength [nm]")
plt.ylabel("Refractive index n")
plt.title("Silicon refractive index dispersion")
plt.legend()
plt.grid(True)
plt.show()

## Waveguide Modes

> NOTE: this example fits a simple neural-network neff model over wavelength and temperature. To see an example of a grid interpolation over wavelength and width, see the 'Layout Aware' example.

We can use [meow](https://github.com/flaport/meow) to calculate the modes in our waveguide.

In [None]:
def find_waveguide_modes(
    wl: float = 1.55,
    T: float = 25.0,
    n_box: float = 1.4,
    n_clad: float = 1.4,
    n_core: float | None = None,
    t_slab: float = 0.1,
    t_soi: float = 0.22,
    w_core: float = 0.45,
    du: float = 0.02,
    n_modes: int = 10,
    cache_path: str | Path = "modes",
    *,
    replace_cached: bool = False,
):
    """Compute (or load from an on-disk JSON cache) the modes of a waveguide cross-section.

    The geometry is built from four structures:

    - a box of index ``n_box`` below ``y = 0``;
    - a slab of thickness ``t_slab`` on top of the box (a rib waveguide;
      with ``t_slab = 0`` it degenerates to a strip);
    - a core of width ``w_core`` and height ``t_soi``, centred at ``x = 0``;
    - a cladding of index ``n_clad`` filling the region above the box.

    Args:
        wl: wavelength in µm.
        T: temperature in °C; passed to the meow environment and to
            ``silicon_index``.
        n_box: refractive index of the buried oxide.
        n_clad: refractive index of the cladding.
        n_core: index of slab and core. Defaults to ``silicon_index(wl, T)``.
        t_slab: slab thickness, same unit as ``wl``.
        t_soi: core height, same unit as ``wl``.
        w_core: core width, same unit as ``wl``.
        du: mesh step in both x and y.
        n_modes: number of modes requested from ``mw.compute_modes``.
        cache_path: directory where results are cached (created if missing).
        replace_cached: recompute even if a cached result exists.

    Returns:
        The list of modes computed by ``mw.compute_modes``, or reloaded
        from the cache.
    """
    length = 10.0  # extent of all structures (and the cell) along z
    delta = 10 * du  # margin so structures extend past the mesh edges
    env = mw.Environment(wl=wl, T=T)
    if n_core is None:
        n_core = silicon_index(wl, T)

    cache_path = Path(cache_path).resolve()
    cache_path.mkdir(parents=True, exist_ok=True)
    # Encode the inputs with enough precision that distinct parameter sets
    # (e.g. wl=1.551 vs wl=1.554) never map onto the same cache file.
    fn = (
        f"{wl=:.6f}-{T=:.6f}-{n_box=:.6f}-{n_clad=:.6f}-{n_core=:.6f}-"
        f"{t_slab=:.6f}-{t_soi=:.6f}-{w_core=:.6f}-{du=:.6f}-{n_modes=}.json"
    )
    path = cache_path / fn
    if not replace_cached and path.exists():
        return [mw.Mode.model_validate(mode) for mode in json.loads(path.read_text())]

    # The indices are wavelength-independent; the sample grid only has to
    # bracket the requested wavelength (sweeps in this notebook reach 3 µm).
    wl_grid = np.asarray([min(1.0, wl), max(2.0, wl)])
    m_core = mw.SampledMaterial(
        name="slab",
        n=np.asarray([n_core, n_core]),
        params={"wl": wl_grid},
        meta={"color": (0.9, 0, 0, 0.9)},
    )
    m_clad = mw.SampledMaterial(name="clad", n=np.asarray([n_clad, n_clad]), params={"wl": wl_grid})
    m_box = mw.SampledMaterial(name="box", n=np.asarray([n_box, n_box]), params={"wl": wl_grid})

    x_min, x_max = -2 * w_core - delta, 2 * w_core + delta
    box = mw.Structure(
        material=m_box,
        geometry=mw.Box(x_min=x_min, x_max=x_max, y_min=-2 * t_soi - delta, y_max=0.0, z_min=0.0, z_max=length),
    )
    slab = mw.Structure(
        material=m_core,
        geometry=mw.Box(x_min=x_min, x_max=x_max, y_min=0.0, y_max=t_slab, z_min=0.0, z_max=length),
    )
    clad = mw.Structure(
        material=m_clad,
        geometry=mw.Box(x_min=x_min, x_max=x_max, y_min=0.0, y_max=3 * t_soi + delta, z_min=0.0, z_max=length),
    )
    core = mw.Structure(
        material=m_core,
        geometry=mw.Box(x_min=-w_core / 2, x_max=w_core / 2, y_min=0.0, y_max=t_soi, z_min=0.0, z_max=length),
    )

    mesh = mw.Mesh2D(
        x=np.arange(-2 * w_core, 2 * w_core, du),
        y=np.arange(-2 * t_soi, 3 * t_soi, du),
    )
    # The cladding overlaps the slab/core region; presumably structures listed
    # later take precedence, so slab and core are placed after clad.
    cell = mw.Cell(structures=[box, clad, slab, core], mesh=mesh, z_min=0.0, z_max=length)
    cross_section = mw.CrossSection.from_cell(cell=cell, env=env)
    modes = mw.compute_modes(cross_section, num_modes=n_modes)

    path.write_text(json.dumps([json.loads(mode.model_dump_json()) for mode in modes]))

    return modes

We can now easily calculate the modes of a rib waveguide (with the default `t_slab=0.1`, a thin silicon slab sits below the core):

In [None]:
# Single-point solve at 1.5 µm and 25 °C with the default geometry.
modes = find_waveguide_modes(T=25.0, wl=1.5)

In [None]:
# Visualize the first mode returned by find_waveguide_modes.
mw.visualize(modes[0])

In [None]:
# Sweep wavelength (µm) and temperature (°C) and record the real part of the
# effective index of the first mode returned at each grid point.
wls = np.linspace(1.0, 3.0, 21)
Ts = np.linspace(25, 35, 11)
neffs = np.zeros((len(wls), len(Ts)))
pb = tqdm(wls)
for i, wl in enumerate(pb):
    for j, T in enumerate(Ts):
        pb.set_postfix(T=f"{T:.2f}C")
        modes = find_waveguide_modes(wl=wl, T=T, w_core=0.5, replace_cached=False)
        neffs[i, j] = np.real(modes[0].neff)

In [None]:
# Compare the effective-index dispersion at the lowest and highest swept
# temperatures; a legend distinguishes the two series.
for j in (0, -1):
    plt.plot(1e3 * wls, neffs[:, j], ls="none", marker=".", label=f"T = {Ts[j]:.0f} °C")
plt.xlabel("Wavelength [nm]")
plt.ylabel("neff")
plt.title("neff dispersion")
plt.legend()
plt.grid(True)
plt.show()

We can put the result in an xarray:

In [None]:
# Label the (wavelength, temperature) grid explicitly; axis 0 of neffs
# follows wls and axis 1 follows Ts.
xarr = xr.DataArray(
    neffs,
    dims=("wl", "T"),
    coords={"wl": wls, "T": Ts},
)
xarr

We can convert this xarray to a stacked dataframe:

In [None]:
# Stack the gridded neff values into a flat table; the neff values end up in
# a column named "target" (presumably one row per (wl, T) pair).
df = sax.to_df(
    xarr,
    target_name="target",
)
df

In [None]:
# NOTE(review): looks like a development leftover. It re-executes the
# sax.fit module, presumably so edits to sax's source are picked up without
# restarting the kernel. It appears unnecessary for this example; consider
# removing it.
from importlib import reload

reload(sax.fit)

In [None]:
# Fit a small neural network mapping the remaining columns of df (the
# covariates) onto the "target" column. The parameter notes below are
# inferred from their names; confirm against the sax.fit documentation.
result = sax.fit.neural_fit(
    df,
    "target",
    hidden_dims=(20,),  # presumably a single hidden layer of 20 units
    num_epochs=2000,
    learning_rate=0.002,
    transform_covariates=True,  # presumably rescales inputs before training
    random_seed=12,  # presumably fixes initialization for reproducibility
)

In [None]:
# Display whatever neural_fit stored under the "model" key of its result.
result["model"]

In [None]:
# Evaluate the fitted model on the same table it was trained on and store the
# predictions next to the simulated targets.
df["pred"] = sax.fit.predict_neural_model(result, df)

In [None]:
# Overlay simulated neff values and neural-fit predictions (all swept
# temperatures together). Wavelength is converted from µm to nm to match the
# other plots in this notebook.
plt.plot(1e3 * df.wl, df.target, ".", label="simulation")
plt.plot(1e3 * df.wl, df.pred, ".", label="neural fit")
plt.xlabel("Wavelength [nm]")
plt.ylabel("neff")
plt.title("neff: simulation vs. neural fit")
plt.legend()
plt.grid(True)
plt.show()

In [None]:
# NOTE(review): looks like scratch code trying out DataFrame.attrs metadata.
# The "hey" entry is not used anywhere visible in this notebook.
df.attrs.update({"hey": "hoi"})
df.attrs