In [None]:
# default_exp typing_

# Typing

> SAX types

In [None]:
# hide
import matplotlib.pyplot as plt
from fastcore.test import test_eq
from pytest import approx, raises

# Redirect stderr to the null device for the rest of this notebook session
# (presumably to keep warnings out of the rendered docs — TODO confirm).
# NOTE(review): the devnull file handle is never closed.
import os, sys; sys.stderr = open(os.devnull, "w")

In [None]:
# export
from __future__ import annotations

import functools
import inspect
from collections.abc import Callable as CallableABC
from typing import Any, Callable, Dict, Tuple, TypedDict, Union, cast, overload

import jax.numpy as jnp
import numpy as np
from natsort import natsorted

## Core Types

#### Array

An `Array` is either a JAX array or a NumPy array:

In [None]:
# exports
# Type alias: a JAX array or a NumPy array.
Array = Union[jnp.ndarray, np.ndarray]

#### Int

An `Int` is either a built-in `int` or an `Array` [of dtype `int`]

In [None]:
# exports
# Type alias: a Python int or an Array (the integer dtype is intended, not enforced by the alias).
Int = Union[int, Array]

#### Float

A `Float` is either a built-in `float` or an `Array` [of dtype `float`]:

In [None]:
# exports
# Type alias: a Python float or an Array (the float dtype is intended, not enforced by the alias).
Float = Union[float, Array]

#### ComplexFloat

A `ComplexFloat` is either a built-in `complex` or a `Float` (and hence also an `Array` [of dtype `complex` or `float`]):

In [None]:
# exports
# Type alias: includes Float, so every Float is also accepted as a ComplexFloat.
ComplexFloat = Union[complex, Float]

#### Settings

A `Settings` dictionary is a nested mapping from setting names [`str`] either to `ComplexFloat` values OR to another, lower-level `Settings` dictionary.

In [None]:
# exports
# Recursive alias: the string "Settings" is a forward reference to this alias itself.
Settings = Union[Dict[str, ComplexFloat], Dict[str, "Settings"]]

Settings dictionaries are used to parametrize a SAX `Model` or a `circuit`. The settings dictionary should have the same hierarchy levels as the circuit:

> Example:

In [None]:
# Example settings for a circuit with instances "lft", "top" and "rgt":
# top-level scalar keys are global, nested dicts apply to a single instance.
mzi_settings = {
    "wl": 1.5,  # global settings
    "lft": {"coupling": 0.5},  # settings for the left coupler
    "top": {"neff": 3.4},  # settings for the top waveguide
    "rgt": {"coupling": 0.3},  # settings for the right coupler
}

#### SDict

An `SDict` is a sparse dictionary based representation of an S-matrix, mapping port name tuples such as `('in0', 'out0')` to `ComplexFloat`.

In [None]:
# exports
# Keys are (port name, port name) tuples; values are the matching S-matrix entries.
SDict = Dict[Tuple[str, str], ComplexFloat]

> Example:

In [None]:
# Minimal SDict with a single ("in0", "out0") entry.
_sdict: SDict = {
    ("in0", "out0"): 3.0,
}

#### SCoo

An `SCoo` is a sparse matrix based representation of an S-matrix consisting of three arrays and a port map. The three arrays represent the input port indices [`int`], output port indices [`int`] and the S-matrix values [`ComplexFloat`] of the sparse matrix. The port map maps a port name [`str`] to a port index [`int`]. Only these four elements **together** and in this specific **order** are considered a valid `SCoo` representation!

In [None]:
# exports
# (Si: input port indices, Sj: output port indices, Sx: values, port name -> index map)
SCoo = Tuple[Array, Array, ComplexFloat, Dict[str, int]]

> Example:

In [None]:
# Three stored entries: (0, 0) = 3.0, (1, 1) = 4.0 and (2, 0) = 1.0.
Si = jnp.arange(3, dtype=int)
Sj = jnp.array([0, 1, 0], dtype=int)
Sx = jnp.array([3.0, 4.0, 1.0])
port_map = {"in0": 0, "in1": 2, "out0": 1}
_scoo: SCoo = Si, Sj, Sx, port_map

#### SDense

An `SDense` is a dense matrix representation of an S-matrix. It's represented by an NxN `ComplexFloat` array and a port map (mapping port names onto port indices).

In [None]:
# exports
# (dense S-matrix array, port name -> index map)
SDense = Tuple[Array, Dict[str, int]]

> Example:

In [None]:
# 3x3 dense S-matrix holding the values 0..8 in row-major order.
Sd = jnp.arange(9, dtype=float).reshape(3, 3)
port_map = {"in0": 0, "in1": 2, "out0": 1}
_sdense: SDense = Sd, port_map

#### SType

An `SType` is either an `SDict` **OR** an `SCoo` **OR** an `SDense`:

In [None]:
# exports
# Any of the three S-matrix representations.
SType = Union[SDict, SCoo, SDense]

> Example:

In [None]:
# Each representation is assignable to SType (re-annotating `obj` is intentional).
obj: SType = _sdict
obj: SType = _scoo
obj: SType = _sdense

#### Model

A `Model` is any keyword-only function that returns an `SType`:

In [None]:
# exports
# Parameters are left unconstrained (`...`); only the return type is specified.
Model = Callable[..., SType]

#### ModelFactory

A `ModelFactory` is any keyword-only function that returns a `Model`:

In [None]:
# exports
# Parameters are left unconstrained (`...`); the callable must return a Model.
ModelFactory = Callable[..., Model]

> Note: SAX sometimes needs to figure out the difference between a `ModelFactory` and a normal `Model` *before* running the function. To do this, SAX will check the return annotation of the function. Any function with a `-> Model` or `-> Callable` annotation will be considered a `ModelFactory`. Any function without this annotation will be considered a normal Model: **don't forget the return annotation of your Model Factory!** To ensure a correct annotation and to ensure forward compatibility, it's recommended to decorate your `ModelFactory` with the `modelfactory` decorator.

#### GeneralModel

A `GeneralModel` is either a `Model` or a `LogicalNetlist` (defined below):

In [None]:
# exports
# "LogicalNetlist" is a forward reference to the TypedDict defined further down this module.
GeneralModel = Union[Model, "LogicalNetlist"]

#### Models

`Models` is a mapping between model names [`str`] and `GeneralModel`:

In [None]:
# exports
# Maps a model name to a Model or a LogicalNetlist.
Models = Dict[str, GeneralModel]

> Note: sometimes 'component' is used to refer to a `Model` or `GeneralModel`. This is because other tools (such as GDSFactory) prefer that terminology.

## Netlist Types

#### Instance

A netlist `Instance` is a mapping with two keys: `"component"`, which should map to a key in a `Models` dictionary, and `"settings"`, which holds all the settings needed to instantiate a component:

In [None]:
# exports
# Functional TypedDict syntax; with the default total=True both keys are required.
Instance = TypedDict(
    "Instance",
    {
        "component": str,
        "settings": Settings,
    },
)

> Note: in SAX, a better name for `"component"` in the instance definition would probably be `"model"` or `"model_name"`. However we chose `"component"` here to have a 1-to-1 map between SAX netlists and GDSFactory netlists.

#### GeneralInstance

A general instance can be any of the following (`LogicalNetlist` and `Netlist` will be defined below):

In [None]:
# exports
# `str` is the shorthand form (just a component name); the quoted names are forward references.
GeneralInstance = Union[str, Instance, "LogicalNetlist", "Netlist"]

> For example, this is allowed:

In [None]:
# Both the shorthand string and the full Instance dict are valid GeneralInstances.
inst: GeneralInstance = "my_component_model"
inst: GeneralInstance = {
    "component": "my_component_model",
    "settings": {},
}

> ... and this is not (will be flagged by a static type checker like pyright or mypy):

In [None]:
# Deliberately invalid: "extra_arg" is not a key of Instance (runs fine, fails static checks).
inst: GeneralInstance = {
    "component": "my_component_model",
    "settings": {},
    "extra_arg": "invalid",
}

#### Instances

`Instances` is a mapping from instance names [`str`] to a `GeneralInstance`:

In [None]:
# exports
# Dict is invariant, so Dict[str, str] is listed separately: it would not be
# assignable to Dict[str, GeneralInstance] even though str is a GeneralInstance.
Instances = Union[Dict[str, str], Dict[str, GeneralInstance]]

#### Netlist

A `Netlist` is a collection of `"instances"`, `"connections"` and `"ports"`:

In [None]:
# exports

# Connections map "instance,port" to "instance,port" (e.g. "lft,out0": "top,in0");
# ports map an external port name to an "instance,port" string.
Netlist = TypedDict(
    "Netlist",
    {
        "instances": Instances,
        "connections": Dict[str, str],
        "ports": Dict[str, str],
    },
)

> Example:

In [None]:
# Example netlist mixing shorthand (string) and full (dict) instance definitions.
mzi_netlist: Netlist = {
    "instances": {
        "lft": "mmi1x2",  # shorthand if no settings need to be given
        "top": {  # full instance definition
            "component": "waveguide",
            "settings": {
                "length": 100.0,
            },
        },
        "rgt": "mmi2x2",  # shorthand if no settings need to be given
    },
    "connections": {
        "lft,out0": "top,in0",
        "top,out0": "rgt,in0",
        "top,out1": "rgt,in1",
    },
    "ports": {
        "in0": "lft,in0",
        "out0": "rgt,out0",
        "out1": "rgt,out1",
    },
}

#### LogicalNetlist

A `LogicalNetlist` is a subset of the more general `Netlist`. It only contains the logical connections and instance names, not the actual instances. This data structure is mostly for internal use.

In [None]:
# exports

# Same keys as Netlist, but instances map only to component names (plain strings).
LogicalNetlist = TypedDict(
    "LogicalNetlist",
    {
        "instances": Dict[str, str],
        "connections": Dict[str, str],
        "ports": Dict[str, str],
    },
)

> Example:

In [None]:
# Example LogicalNetlist: every instance maps to a component name only.
mzi_logical_netlist: LogicalNetlist = {
    "instances": {
        "lft": "mmi1x2",
        "top": "waveguide",
        "rgt": "mmi2x2",
    },
    "connections": {
        "lft,out0": "top,in0",
        "top,out0": "rgt,in0",
        "top,out1": "rgt,in1",
    },
    "ports": {
        "in0": "lft,in0",
        "out0": "rgt,out0",
        "out1": "rgt,out1",
    },
}

## Validation and runtime type-checking:

> Note: the type-checking functions below are **NOT** very tight and hence should be used within the right context!

In [None]:
# export
def is_float(x: Any) -> bool:
    """Check if an object is a `Float`.

    A `Float` is a built-in `float` (including subclasses such as `np.float64`
    scalars) or a numpy/jax array whose dtype is a floating-point type.

    Args:
        x: the object to check.

    Returns:
        True if `x` is a `Float`, False otherwise.
    """
    if isinstance(x, float):
        return True
    # `issubdtype` covers every floating dtype the platform actually provides
    # (float16/32/64, longdouble, ...) without naming `np.float128`, which does
    # not exist on every platform (e.g. Windows) and would raise AttributeError.
    if isinstance(x, np.ndarray):
        return bool(np.issubdtype(x.dtype, np.floating))
    if isinstance(x, jnp.ndarray):
        # jnp.issubdtype also recognizes jax-specific floats such as bfloat16
        return bool(jnp.issubdtype(x.dtype, jnp.floating))
    return False

In [None]:
# Only Python floats and float-dtype arrays count; ints and complex values do not.
assert is_float(3.0)
assert not is_float(3)
assert not is_float(3.0 + 2j)
assert not is_float(jnp.array(3.0, dtype=complex))
assert not is_float(jnp.array(3, dtype=int))

In [None]:
# export
def is_complex(x: Any) -> bool:
    """Check if an object is strictly complex-valued.

    Returns True for a built-in `complex` or a numpy/jax array with a
    complex floating-point dtype; real floats are NOT accepted here (use
    `is_complex_float` for that).

    Args:
        x: the object to check.

    Returns:
        True if `x` is complex-valued, False otherwise.
    """
    if isinstance(x, complex):
        return True
    # `issubdtype` also covers complex dtypes beyond complex64/complex128
    # (e.g. numpy's clongdouble) instead of a hard-coded list.
    if isinstance(x, np.ndarray):
        return bool(np.issubdtype(x.dtype, np.complexfloating))
    if isinstance(x, jnp.ndarray):
        return bool(jnp.issubdtype(x.dtype, jnp.complexfloating))
    return False

In [None]:
# Only complex scalars and complex-dtype arrays count; real floats are rejected.
assert not is_complex(3.0)
assert not is_complex(3)
assert is_complex(3.0 + 2j)
assert is_complex(jnp.array(3.0, dtype=complex))
assert not is_complex(jnp.array(3, dtype=int))

In [None]:
# export
def is_complex_float(x: Any) -> bool:
    """Return True if `x` is a real `Float` or a complex value/array.

    Args:
        x: the object to check.

    Returns:
        True when either `is_float(x)` or `is_complex(x)` holds.
    """
    if is_float(x):
        return True
    return is_complex(x)

In [None]:
# Real floats and complex values both qualify; integers do not.
assert is_complex_float(3.0)
assert not is_complex_float(3)
assert is_complex_float(3.0 + 2j)
assert is_complex_float(jnp.array(3.0, dtype=complex))
assert not is_complex_float(jnp.array(3, dtype=int))

In [None]:
# export
def is_sdict(x: Any) -> bool:
    """Return True if `x` looks like an `SDict`.

    Only the container type is checked (any `dict` qualifies); keys and
    values are not inspected.
    """
    looks_like_sdict = isinstance(x, dict)
    return looks_like_sdict

In [None]:
# Only the dict-based representation is recognized as an SDict.
assert not is_sdict(object())
assert is_sdict(_sdict)
assert not is_sdict(_scoo)
assert not is_sdict(_sdense)

In [None]:
# export
def is_scoo(x: Any) -> bool:
    """Return True if `x` looks like an `SCoo`.

    Only the shape of the container is checked: a tuple or list with exactly
    four elements (Si, Sj, Sx, port map). The elements themselves are not
    validated.
    """
    if not isinstance(x, (tuple, list)):
        return False
    return len(x) == 4

In [None]:
# Only the four-element (Si, Sj, Sx, port_map) representation is an SCoo.
assert not is_scoo(object())
assert not is_scoo(_sdict)
assert is_scoo(_scoo)
assert not is_scoo(_sdense)

In [None]:
# export
def is_sdense(x: Any) -> bool:
    """Return True if `x` looks like an `SDense`.

    Only the shape of the container is checked: a tuple or list with exactly
    two elements (dense matrix, port map). The elements are not validated.
    """
    if not isinstance(x, (tuple, list)):
        return False
    return len(x) == 2

In [None]:
# Only the two-element (matrix, port_map) representation is an SDense.
assert not is_sdense(object())
assert not is_sdense(_sdict)
assert not is_sdense(_scoo)
assert is_sdense(_sdense)

In [None]:
# export
def is_model(model: Any) -> bool:
    """check if a callable is a `Model` (a callable returning an `SType`)"""
    if not callable(model):
        return False
    try:
        sig = inspect.signature(model)
    except ValueError:
        return False
    for param in sig.parameters.values():
        if param.default == inspect.Parameter.empty:
            return False  # a proper SAX model does not have any positional arguments.
    if _is_callable_annotation(sig.return_annotation):  # model factory
        return False
    return True

def _is_callable_annotation(annotation: Any) -> bool:
    """check if an annotation is `Callable`-like"""
    if isinstance(annotation, str):
        # happens when
        # from __future__ import annotations
        # was imported at the top of the file...
        return annotation.startswith("Callable") or annotation.endswith("Model")
        # TODO: this is not a very robust check...
    try:
        return annotation.__origin__ == CallableABC
    except AttributeError:
        return False

In [None]:
# hide
# Bare typing.Callable is callable-like; a Dict-based alias such as SDict is not.
assert _is_callable_annotation(Callable)
assert not _is_callable_annotation(SDict)

In [None]:
# A model must give every parameter a default value.
def good_model(x=jnp.array(3.0), y=jnp.array(4.0)) -> SDict:
    return {("in0", "out0"): jnp.array(3.0)}
assert is_model(good_model)

# `positional_argument` has no default, so this is rejected.
def bad_model(positional_argument, x=jnp.array(3.0), y=jnp.array(4.0)) -> SDict:
    return {("in0", "out0"): jnp.array(3.0)}
assert not is_model(bad_model)

In [None]:
# export
def is_model_factory(model: Any) -> bool:
    """Check if a callable is a `ModelFactory` (a callable returning a `Model`).

    Only the return annotation is inspected; the function is never called.

    Args:
        model: the object to check.

    Returns:
        True if `model` is callable and its return annotation is `Callable`-like
        (e.g. `-> Model` or `-> Callable`), False otherwise.
    """
    if not callable(model):
        return False
    try:
        sig = inspect.signature(model)
    except (ValueError, TypeError):
        # no introspectable signature -> no return annotation to inspect
        return False
    return _is_callable_annotation(sig.return_annotation)

> Note: For a `Callable` to be considered a `ModelFactory` in SAX, it **MUST** have a `Callable` or `Model` return annotation. Otherwise SAX will view it as a `Model` and things might break!

In [None]:
# Only the return annotation is inspected; the (empty) bodies are never run.
def func() -> Model:
    ...
    
assert is_model_factory(func) # yes, we only check the annotation for now...

def func():
    ...
    
assert not is_model_factory(func) # yes, we only check the annotation for now...

In [None]:
# export
def validate_model(model: Callable):
    """Raise if `model` is not a valid SAX `Model`.

    A SAX model may only take keyword arguments, i.e. every parameter must
    have a default value.

    Args:
        model: the callable to validate.

    Returns:
        None when the model is valid.

    Raises:
        ValueError: if any parameter of `model` lacks a default value.
    """
    params = inspect.signature(model).parameters
    missing_defaults = [
        name for name, p in params.items() if p.default is inspect.Parameter.empty
    ]
    if not missing_defaults:
        return
    raise ValueError(
        f"model '{model}' takes positional arguments {', '.join(missing_defaults)} "
        "and hence is not a valid SAX Model! A SAX model should ONLY take keyword arguments (or no arguments at all)."
    )

In [None]:
# validate_model returns None for a model whose parameters all have defaults.
def good_model(x=jnp.array(3.0), y=jnp.array(4.0)) -> SDict:
    return {("in0", "out0"): jnp.array(3.0)}


assert validate_model(good_model) is None

In [None]:
# ... and raises ValueError when a parameter has no default.
def bad_model(positional_argument, x=jnp.array(3.0), y=jnp.array(4.0)) -> SDict:
    return {("in0", "out0"): jnp.array(3.0)}


with raises(ValueError):
    validate_model(bad_model)

In [None]:
# export
def is_instance(instance: Any) -> bool:
    """Return True if `instance` looks like a netlist `Instance`.

    This is a loose check: any dict containing a "component" key qualifies;
    the "settings" key and value types are not inspected.
    """
    return isinstance(instance, dict) and "component" in instance

In [None]:
# export
def is_netlist(netlist: Any) -> bool:
    """Check if an object looks like a `Netlist`.

    This is a loose check: any dict that has all of the "instances",
    "connections" and "ports" keys qualifies; the values are not inspected.

    Args:
        netlist: the object to check.

    Returns:
        True if `netlist` is a dict with all three netlist keys, False otherwise.
    """
    if not isinstance(netlist, dict):
        return False
    return all(key in netlist for key in ("instances", "connections", "ports"))

In [None]:
# export
def is_stype(stype: Any) -> bool:
    """Return True if `stype` looks like any of SDict, SCoo or SDense.

    The checks are tried in that order and stop at the first match.
    """
    return any(check(stype) for check in (is_sdict, is_scoo, is_sdense))

In [None]:
# export
def is_singlemode(S: Any) -> bool:
    """Check if an `SType` is single-mode.

    An `SType` is considered single-mode when none of its port names
    contains an "@" character.

    Args:
        S: the object to check.

    Returns:
        False if `S` is not an `SType`; otherwise True iff no port name contains "@".
    """
    if not is_stype(S):
        return False
    ports = _get_ports(S)
    return not any(("@" in p) for p in ports)


def _get_ports(S: SType):
    """Return the natsorted tuple of port names of an `SType`.

    For an `SDict` the ports are collected from both positions of every key;
    for an `SCoo`/`SDense` they are the keys of the trailing port map.

    Raises:
        TypeError: if `S` is a tuple/list whose last element is not a port-map dict.
    """
    if is_sdict(S):
        S = cast(SDict, S)
        ports_set = {p1 for p1, _ in S} | {p2 for _, p2 in S}
        return tuple(natsorted(ports_set))
    *_, ports_map = S
    # explicit check instead of `assert`, which is stripped under `python -O`
    if not isinstance(ports_map, dict):
        raise TypeError(
            "expected the last element of an SCoo/SDense to be a port map dict, "
            f"got {type(ports_map).__name__}"
        )
    return tuple(natsorted(ports_map.keys()))

In [None]:
# export
def is_multimode(S: Any) -> bool:
    """Check if an `SType` is multi-mode.

    An `SType` is considered multi-mode when every one of its port names
    contains an "@" character.

    Args:
        S: the object to check.

    Returns:
        False if `S` is not an `SType`; otherwise True iff all port names contain "@".
    """
    if not is_stype(S):
        return False

    ports = _get_ports(S)
    return all(("@" in p) for p in ports)

In [None]:
# export
def is_mixedmode(S: Any) -> bool:
    """Return True if `S` is neither single-mode nor multi-mode (hence invalid).

    Note that a non-`SType` input is neither, and therefore reported as mixed.
    """
    return not (is_singlemode(S) or is_multimode(S))

## SAX return type helpers

> a.k.a SDict, SDense, SCoo helpers

Convert an `SDict`, `SCoo` or `SDense` into an `SDict` (or convert a model generating any of these types into a model generating an `SDict`):

In [None]:
# exporti

@overload
def sdict(S: Model) -> Model:
    """Overload: a `Model` in, an `SDict`-returning `Model` out."""


@overload
def sdict(S: SType) -> SDict:
    """Overload: any `SType` in, an `SDict` out."""

In [None]:
# export
def sdict(S: Union[Model, SType]) -> Union[Model, SType]:
    """Convert an `SType` (or a `Model` producing one) to `SDict` form.

    Args:
        S: an `SCoo`, `SDense` or `SDict`, or a `Model` returning any of these.

    Returns:
        The `SDict` equivalent; an `SDict` input is returned unchanged. For a
        `Model` input, a wrapped model whose output is converted to `SDict`.

    Raises:
        ValueError: if `S` is none of the supported types.
    """
    if is_model(S):
        inner_model = cast(Model, S)

        @functools.wraps(inner_model)
        def sdict_model(**kwargs):
            return sdict(inner_model(**kwargs))

        return sdict_model

    if is_scoo(S):
        return _scoo_to_sdict(*cast(SCoo, S))
    if is_sdense(S):
        return _sdense_to_sdict(*cast(SDense, S))
    if is_sdict(S):
        return cast(SDict, S)
    raise ValueError("Could not convert arguments to sdict.")


def _scoo_to_sdict(Si: Array, Sj: Array, Sx: Array, ports_map: Dict[str, int]) -> SDict:
    sdict = {}
    inverse_ports_map = {int(i): p for p, i in ports_map.items()}
    for i, (si, sj) in enumerate(zip(Si, Sj)):
        sdict[
            inverse_ports_map.get(int(si), ""), inverse_ports_map.get(int(sj), "")
        ] = Sx[..., i]
    sdict = {(p1, p2): v for (p1, p2), v in sdict.items() if p1 and p2}
    return sdict


def _sdense_to_sdict(S: Array, ports_map: Dict[str, int]) -> SDict:
    sdict = {}
    for p1, i in ports_map.items():
        for p2, j in ports_map.items():
            sdict[p1, p2] = S[..., i, j]
    return sdict

In [None]:
# SDict input is returned as-is; SCoo keeps only its stored entries; SDense
# yields an entry for every (port, port) pair.
assert sdict(_sdict) is _sdict
assert sdict(_scoo) == {
    ("in0", "in0"): 3.0,
    ("in1", "in0"): 1.0,
    ("out0", "out0"): 4.0,
}
assert sdict(_sdense) == {
    ("in0", "in0"): 0.0,
    ("in0", "out0"): 1.0,
    ("in0", "in1"): 2.0,
    ("out0", "in0"): 3.0,
    ("out0", "out0"): 4.0,
    ("out0", "in1"): 5.0,
    ("in1", "in0"): 6.0,
    ("in1", "out0"): 7.0,
    ("in1", "in1"): 8.0,
}

Convert an `SDict`, `SCoo` or `SDense` into an `SCoo` (or convert a model generating any of these types into a model generating an `SCoo`):

In [None]:
# exporti

@overload
def scoo(S: Callable) -> Callable:
    """Overload: a model in, an `SCoo`-returning model out."""


@overload
def scoo(S: SType) -> SCoo:
    """Overload: any `SType` in, an `SCoo` out."""

In [None]:
# export

def scoo(S: Union[Callable, SType]) -> Union[Callable, SCoo]:
    """Convert an `SType` (or a `Model` producing one) to `SCoo` form.

    Args:
        S: an `SDict`, `SDense` or `SCoo`, or a `Model` returning any of these.

    Returns:
        The `SCoo` equivalent; an `SCoo` input is returned unchanged. For a
        `Model` input, a wrapped model whose output is converted to `SCoo`.

    Raises:
        ValueError: if `S` is none of the supported types.
    """
    if is_model(S):
        inner_model = cast(Model, S)

        @functools.wraps(inner_model)
        def scoo_model(**kwargs):
            return scoo(inner_model(**kwargs))

        return scoo_model

    if is_scoo(S):
        return cast(SCoo, S)
    if is_sdense(S):
        return _sdense_to_scoo(*cast(SDense, S))
    if is_sdict(S):
        return _sdict_to_scoo(cast(SDict, S))
    raise ValueError("Could not convert arguments to scoo.")


def _sdense_to_scoo(S: Array, ports_map: Dict[str, int]) -> SCoo:
    Sj, Si = jnp.meshgrid(jnp.arange(S.shape[-1]), jnp.arange(S.shape[-2]))
    return Si.ravel(), Sj.ravel(), S.reshape(*S.shape[:-2], -1), ports_map


def _sdict_to_scoo(sdict: SDict) -> SCoo:
    all_ports = {}
    for p1, p2 in sdict:
        all_ports[p1] = None
        all_ports[p2] = None
    ports_map = {p: i for i, p in enumerate(all_ports)}
    Sx = jnp.stack(jnp.broadcast_arrays(*sdict.values()), -1)
    Si = jnp.array([ports_map[p] for p, _ in sdict])
    Sj = jnp.array([ports_map[p] for _, p in sdict])
    return Si, Sj, Sx, ports_map

In [None]:
# SCoo input is returned as-is; SDense yields all N*N index pairs in row-major order.
assert scoo(_scoo) is _scoo
assert scoo(_sdict) == (0, 1, 3.0, {"in0": 0, "out0": 1})
Si, Sj, Sx, port_map = scoo(_sdense)  # type: ignore
np.testing.assert_array_equal(Si, jnp.array([0, 0, 0, 1, 1, 1, 2, 2, 2]))
np.testing.assert_array_equal(Sj, jnp.array([0, 1, 2, 0, 1, 2, 0, 1, 2]))
np.testing.assert_array_almost_equal(Sx, jnp.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]))
assert port_map == {"in0": 0, "in1": 2, "out0": 1}

Convert an `SDict`, `SCoo` or `SDense` into an `SDense` (or convert a model generating any of these types into a model generating an `SDense`):

In [None]:
# exporti


@overload
def sdense(S: Callable) -> Callable:
    """Overload: a model in, an `SDense`-returning model out."""


@overload
def sdense(S: SType) -> SDense:
    """Overload: any `SType` in, an `SDense` out."""

In [None]:
# export

def sdense(S: Union[Callable, SType]) -> Union[Callable, SDense]:
    """Convert an `SType` (or a `Model` producing one) to `SDense` form.

    Args:
        S: an `SDict`, `SCoo` or `SDense`, or a `Model` returning any of these.

    Returns:
        The `SDense` equivalent; an `SDense` input is returned unchanged. For a
        `Model` input, a wrapped model whose output is converted to `SDense`.

    Raises:
        ValueError: if `S` is none of the supported types.
    """
    if is_model(S):
        inner_model = cast(Model, S)

        @functools.wraps(inner_model)
        def sdense_model(**kwargs):
            return sdense(inner_model(**kwargs))

        return sdense_model

    if is_sdict(S):
        return _sdict_to_sdense(cast(SDict, S))
    if is_scoo(S):
        return _scoo_to_sdense(*cast(SCoo, S))
    if is_sdense(S):
        return cast(SDense, S)
    raise ValueError("Could not convert arguments to sdense.")


def _scoo_to_sdense(
    Si: Array, Sj: Array, Sx: Array, ports_map: Dict[str, int]
) -> SDense:
    n_col = len(ports_map)
    S = jnp.zeros((*Sx.shape[:-1], n_col, n_col), dtype=complex)
    S = S.at[..., Si, Sj].add(Sx)
    return S, ports_map


def _sdict_to_sdense(sdict: SDict) -> SDense:
    """Convert an SDict to SDense form by way of the SCoo representation."""
    return _scoo_to_sdense(*_sdict_to_scoo(sdict))

In [None]:
# SDense input is returned as-is; SCoo entries are scattered into a complex zero matrix.
assert sdense(_sdense) is _sdense
Sd, port_map = sdense(_scoo)  # type: ignore
Sd_ = jnp.array([[3.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j],
                 [0.0 + 0.0j, 4.0 + 0.0j, 0.0 + 0.0j],
                 [1.0 + 0.0j, 0.0 + 0.0j, 0.0 + 0.0j]])

np.testing.assert_array_almost_equal(Sd, Sd_)
assert port_map == {"in0": 0, "in1": 2, "out0": 1}

In [None]:
# export

def modelfactory(func):
    """Decorator that marks `func` as a `ModelFactory`.

    If the return annotation of `func` is not already `Callable`-like, the
    function's `__signature__` is replaced by one returning `Model`, so that
    SAX recognizes it as a factory without calling it. `func` itself is
    returned (mutated in place, not wrapped).
    """
    signature = inspect.signature(func)
    if not _is_callable_annotation(signature.return_annotation):
        func.__signature__ = signature.replace(return_annotation=Model)
    return func