# Data Parsers
> Let's parse some data.

## Imports

In [None]:
from functools import cache

import altair as alt
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import requests
import sax

## Lumerical Parser

The [SiEPIC ebeam PDK](https://github.com/SiEPIC/SiEPIC_EBeam_PDK) has a bunch of data files in Lumerical format. Let's download one of them:

In [None]:
# Raw S-parameter file of the ebeam directional coupler (gap=200nm, Lc=0um) from the SiEPIC PDK.
url = "https://raw.githubusercontent.com/SiEPIC/SiEPIC_EBeam_PDK/refs/heads/master/Lumerical_EBeam_CML/EBeam/source_data/ebeam_dc_te1550/dc_gap%3D200nm_Lc%3D0um.sparam"
# Fail fast on a hung connection or an HTTP error instead of handing an
# error page to the parser further down.
response = requests.get(url, timeout=30)
response.raise_for_status()
content = response.text
# Preview the first 1000 characters of the downloaded text.
print(content[:1000])

In [None]:
# Parse the downloaded Lumerical text into a dataframe; the trailing bare
# `df` makes the notebook display it as the cell output.
df = sax.parsers.parse_lumerical_dat(content)
df

We see that the parsed result is a dataframe in [tidy format](https://aeturrell.github.io/python4DS/data-tidy.html) with the following columns:

| freq | mag | phase | port_out | port_in | mode_out | mode_in |
|------|-----|-------|----------|---------|----------|---------|

In this case it's a single mode dataframe:

In [None]:
# Show the distinct mode labels on the input and output side. The `=`
# specifier in the f-string prints the expression text next to its value.
print(f"{df.mode_in.unique()=}")
print(f"{df.mode_out.unique()=}")

So if we want we can drop those columns:

In [None]:
# Parse again from the raw text and discard the two mode columns in one go.
df = sax.parsers.parse_lumerical_dat(content).drop(columns=["mode_in", "mode_out"])
df

The plotting library [altair](https://github.com/vega/altair) is a perfect fit for visualizing dataframes in tidy format:

In [None]:
# Derive a wavelength column from the frequency column (added to `df` itself).
df["wl"] = sax.C_UM_S / df["freq"]

# Only plot the responses for light entering at port_1.
from_port_1 = df.query("port_in=='port_1'")

# Pin the x axis to the wavelength range of the data and give the y axis a
# small margin around [0, 1].
wl_axis = alt.X("wl", scale=alt.Scale(domain=(df["wl"].min(), df["wl"].max())))
mag_axis = alt.Y("mag", scale=alt.Scale(domain=(-0.05, 1.05)))

# One line per output port, stretched to the container width, with pan/zoom.
chart = alt.Chart(from_port_1).mark_line().encode(x=wl_axis, y=mag_axis, color="port_out")
chart = chart.properties(width="container").interactive()
chart

## Transforming into xarray

Very often we would like to represent this as an xarray (think of it as a multi-dimensional dataframe):

In [None]:
# Work on a derived copy of `df`: add the wavelength and amplitude columns.
df_model = df.assign(
    wl=sax.C_UM_S / df["freq"],
    amp=np.sqrt(df["mag"]),
)
# Keep only the coordinate columns (wl, port_out, port_in) and the targets.
df_model = df_model[["wl", "amp", "phi", "port_out", "port_in"]]
# `amp` and `phi` become the targets; the remaining columns are coordinates.
xarr = sax.to_xarray(df_model, target_names=["amp", "phi"])
xarr

## Interpolating an xarray

To interpolate over the float coordinates of the xarray:

In [None]:
# Interpolate along the float coordinate `wl` only; no port is selected here.
sax.interpolate_xarray(xarr, wl=1.55)

String coordinates cannot be interpolated over, but they can be selected:

In [None]:
# `wl` is interpolated (passed here as a one-element list), while the string
# coordinates `port_in` / `port_out` are selected by name.
sax.interpolate_xarray(xarr, wl=[1.555], port_in="port_1", port_out="port_1")

Or, to get all outputs for a given input:

In [None]:
# Only `port_in` is selected; `port_out` is left unspecified so all outputs
# for input port_1 are kept.
sax.interpolate_xarray(xarr, wl=[1.555], port_in="port_1")

## Creating a model

Using all of the above, we can create a model. The common boilerplate can be divided into two steps:

In [None]:
# 1. The cached data loader:


@cache
def load_dc_xarray():
    """Download, parse and reshape the directional-coupler data into an xarray.

    The result is memoized with ``functools.cache``, so the download and the
    parsing only happen on the first call.

    Returns:
        The xarray built by ``sax.to_xarray`` from the ``wl``, ``port_out``
        and ``port_in`` columns, with ``amp`` and ``phi`` as targets.

    Raises:
        requests.RequestException: if the download times out or the server
            answers with an error status.
    """
    url = "https://raw.githubusercontent.com/SiEPIC/SiEPIC_EBeam_PDK/refs/heads/master/Lumerical_EBeam_CML/EBeam/source_data/ebeam_dc_te1550/dc_gap%3D200nm_Lc%3D0um.sparam"
    # fail fast on a hung connection or an HTTP error instead of parsing an
    # error page as if it were data
    response = requests.get(url, timeout=30)
    response.raise_for_status()
    content = response.text
    # or for local data probably more something like this:
    # path = Path(__file__).parent / "relative" / "path" / "to" / "data.dat"
    # content = Path(path).read_text()
    df = sax.parsers.parse_lumerical_dat(content)

    # do the necessary transformations to get the dataframe ready to be transformed into an xarray:
    # only keep columns that should be used
    # (i.e. columns that uniquely predict the target, without duplication, i.e. no freq and wl together)
    df["wl"] = sax.C_UM_S / df["freq"]
    df["amp"] = np.sqrt(df["mag"])
    df = df[["wl", "amp", "phi", "port_out", "port_in"]]

    # now we can transform to xarray; use the dataframe prepared in this
    # function, not a notebook-level variable, so the loader is self-contained
    xarr = sax.to_xarray(df, target_names=["amp", "phi"])

    # and return it
    return xarr


# 2. The model function
def dc_model(
    wl=1.5,
) -> sax.SDict:
    """Directional-coupler model returning an ``SDict`` interpolated from data.

    All non-port, non-target columns of the data should be exposed as keyword
    arguments; for this dataset that is only ``wl``.

    Args:
        wl: wavelength(s) at which the data is interpolated.

    Returns:
        A dict keyed by ``(port_in, port_out)`` name pairs over ``in0``,
        ``in1``, ``out0`` and ``out1``, with ``amp * exp(1j * phi)`` as values.
    """
    # load the (cached) data eagerly, even if the model is being traced by jax
    with jax.ensure_compile_time_eval():
        data = load_dc_xarray()

    # model port name -> port name used in the data file
    port_names = {
        "in0": "port_1",
        "in1": "port_2",
        "out0": "port_4",
        "out1": "port_3",
    }

    def s_param(raw_in, raw_out):
        # don't forget to add more keyword arguments here if your data supports it!
        picked = sax.interpolate_xarray(
            data, wl=wl, port_in=str(raw_in), port_out=str(raw_out)
        )
        return picked["amp"] * jnp.exp(1j * picked["phi"])

    return {
        (name_in, name_out): s_param(raw_in, raw_out)
        for name_in, raw_in in port_names.items()
        for name_out, raw_out in port_names.items()
    }

In [None]:
# Evaluate the model with its default argument (wl=1.5).
dc_model()

## SDense for performance

A model returning an `SDict` is usually the easiest to work with. However, we can also return an `SDense`, which in this case should be more performant, as only one xarray interpolation is necessary:

In [None]:
def dc_model2(
    wl=1.5,
) -> sax.SDense:
    """Directional-coupler model returning an ``SDense`` interpolated from data.

    All non-port, non-target columns of the data should be exposed as keyword
    arguments; for this dataset that is only ``wl``.

    Args:
        wl: wavelength(s) at which the data is interpolated.

    Returns:
        A ``(S, port_map)`` tuple: ``S`` is ``amp * exp(1j * phi)`` for all
        port combinations at once and ``port_map`` maps each port name to its
        index along the port dimensions.
    """
    # load the (cached) data eagerly, even if the model is being traced by jax
    with jax.ensure_compile_time_eval():
        xarr = load_dc_xarray()

    # by not specifying ports, the array will be interpolated directly:
    # NOTE! for this to work, you should confirm that the last three dimensions
    # of your xarray (`xarr.dims`) are port_in, port_out, targets
    interpolated = sax.interpolate_xarray(xarr, wl=wl)
    S = interpolated["amp"] * jnp.exp(1j * interpolated["phi"])
    # NOTE(review): the keys are the port names as stored in the data, not the
    # in0/in1/out0/out1 names used by `dc_model` - rename them here if the two
    # models should be interchangeable.
    port_map = {k: i for i, k in enumerate(xarr.coords["port_in"].values)}
    # also confirm that if we define port_map with port_out instead, we get the same dict!
    # port_map = {k: i for i, k in enumerate(xarr.coords['port_out'].values)}
    return S, port_map  # this is an SDense!

In [None]:
# Evaluate the SDense variant with its default argument (wl=1.5).
dc_model2()