# Typing

> SAX types

In [None]:
import jax.numpy as jnp
import numpy as np
import pytest

import sax

In [None]:
# Real values are accepted as `float`: Python floats, Python ints and
# integer-dtype jax arrays...
assert sax.try_into[float](3.0)
assert sax.try_into[float](3)
assert sax.try_into[float](jnp.array(3, dtype=int))
# ...whereas complex values are rejected, including a complex-dtype jax array
# whose imaginary part is zero.
assert not sax.try_into[float](3.0 + 2j)
assert not sax.try_into[float](jnp.array(3.0, dtype=complex))

In [None]:
# Python scalars (float, int and complex) are all accepted as `complex`...
assert sax.try_into[complex](3.0)
assert sax.try_into[complex](3)
assert sax.try_into[complex](3.0 + 2j)
# ...but jax arrays are rejected, even one that already has a complex dtype.
assert not sax.try_into[complex](jnp.array(3.0, dtype=complex))

# NOTE: discrepancy with the float case above, where an integer jax array *is*
# accepted. Here it is rejected; presumably a pydantic validation quirk.
assert not sax.try_into[complex](jnp.array(3, dtype=int))

In [None]:
# Smallest possible SDict: a single (port, port) key mapped to a plain float.
_sdict: sax.SDict = {("in0", "out0"): 3.0}

In [None]:
# Sparse (COO) example: entry k sits at row Si[k], column Sj[k] with value
# Sx[k]; port_map names those row/column indices.
port_map = {"in0": 0, "in1": 2, "out0": 1}
Si = jnp.arange(3, dtype=int)  # rows 0, 1, 2
Sj = jnp.array([0, 1, 0], dtype=int)
Sx = jnp.array([3.0, 4.0, 1.0])
_scoo: sax.SCoo = (Si, Sj, Sx, port_map)

In [None]:
# Dense example: a 3x3 matrix holding 0..8 in row-major order, paired with the
# port map that names its rows/columns.
Sd = jnp.arange(9, dtype=float).reshape(3, 3)
port_map = {"in0": 0, "in1": 2, "out0": 1}
_sdense = (Sd, port_map)

In [None]:
# Only the genuine SDict validates; the other S-matrix formats and an
# arbitrary object are all rejected.
assert sax.try_into[sax.SDict](_sdict)
assert not sax.try_into[sax.SDict](_scoo)
assert not sax.try_into[sax.SDict](_sdense)
assert not sax.try_into[sax.SDict](object())

In [None]:
# An arbitrary object *instance* (not the `object` type itself, matching the
# SDict checks), an SDict and an SDense are rejected; only the SCoo validates.
assert not sax.try_into[sax.SCoo](object())
assert not sax.try_into[sax.SCoo](_sdict)
assert sax.try_into[sax.SCoo](_scoo)
assert not sax.try_into[sax.SCoo](_sdense)

In [None]:
# An arbitrary object *instance* (not the `object` type itself, matching the
# SDict checks), an SDict and an SCoo are rejected; only the SDense validates.
assert not sax.try_into[sax.SDense](object())
assert not sax.try_into[sax.SDense](_sdict)
assert not sax.try_into[sax.SDense](_scoo)
assert sax.try_into[sax.SDense](_sdense)

In [None]:
def good_model(x=3.0, y=4.0) -> sax.SDict:
    """Toy model in which every parameter has a default value.

    ``x`` and ``y`` are intentionally unused; the function always returns a
    single-entry SDict.
    """
    s_in0_out0 = jnp.array(3.0)
    return {("in0", "out0"): s_in0_out0}

In [None]:
# `good_model` (all parameters defaulted) is accepted as a SAX Model.
assert sax.try_into[sax.Model](good_model)

In [None]:
def bad_model(positional_argument, x=3.0, y=4.0) -> sax.SDict:
    """Like ``good_model``, but with a leading parameter that has no default.

    Presumably that required positional parameter is why ``try_into[sax.Model]``
    rejects this function. All arguments are unused.
    """
    s_in0_out0 = jnp.array(3.0)
    return {("in0", "out0"): s_in0_out0}

In [None]:
# `bad_model` (has a required positional parameter) is rejected as a SAX Model.
assert not sax.try_into[sax.Model](bad_model)

> Note: For a `Callable` to be considered a `ModelFactory` in SAX, it **MUST** have a `Callable` or `Model` return annotation. Otherwise SAX will view it as a `Model` and things might break!

In [None]:
def func() -> sax.Model:
    """Do-nothing stub whose ``sax.Model`` return annotation is all that matters."""


# Only the return annotation is inspected for now, so this empty stub still
# counts as a ModelFactory.
assert sax.try_into[sax.ModelFactory](func)

In [None]:
def func() -> None:
    """Do-nothing stub annotated to return ``None`` (not a Model or Callable)."""


# Only the return annotation is inspected for now; a ``None`` annotation means
# this stub is not considered a ModelFactory.
assert not sax.try_into[sax.ModelFactory](func)

## SAX return type helpers

> a.k.a. `SDict`, `SDense`, `SCoo` helpers

Convert an `SDict`, `SCoo` or `SDense` into an `SDict` (or convert a model generating any of these types into a model generating an `SDict`):

In [None]:
# Display the SDict defined earlier (shown as the notebook cell output).
_sdict

In [None]:
# Store the SDict values as complex128 arrays; with that representation
# `sax.sdict` is expected to hand the dict back unchanged.
_sdict = {
    port_pair: jnp.asarray(value, dtype=jnp.complex128)
    for port_pair, value in _sdict.items()
}
assert sax.sdict(_sdict) == _sdict

# SCoo -> SDict: one entry per (Si[k], Sj[k], Sx[k]) triple, with indices
# translated back to port names via the port map (listed here in COO order).
assert sax.sdict(_scoo) == {
    ("in0", "in0"): 3.0,
    ("out0", "out0"): 4.0,
    ("in1", "in0"): 1.0,
}

# SDense -> SDict: each of the 9 matrix elements becomes an entry keyed by its
# (row port, column port) names.
assert sax.sdict(_sdense) == {
    ("in0", "in0"): 0.0,
    ("in0", "out0"): 1.0,
    ("in0", "in1"): 2.0,
    ("out0", "in0"): 3.0,
    ("out0", "out0"): 4.0,
    ("out0", "in1"): 5.0,
    ("in1", "in0"): 6.0,
    ("in1", "out0"): 7.0,
    ("in1", "in1"): 8.0,
}

Convert an `SDict`, `SCoo` or `SDense` into an `SCoo` (or convert a model generating any of these types into a model generating an `SCoo`):

In [None]:
# Display the SCoo form of the dense example (shown as the notebook cell output).
sax.scoo(_sdense)

In [None]:
# A single-entry SDict becomes one (row, column, value) triple plus its ports.
assert sax.scoo(_sdict) == (0, 1, 3.0, {"in0": 0, "out0": 1})

# SDense -> SCoo lists all 9 entries in row-major order. The expected port map
# shows the ports re-indexed alphabetically (in0, in1, out0), hence the
# permuted values in Sx.
Si, Sj, Sx, port_map = sax.scoo(_sdense)
assert port_map == {"in0": 0, "in1": 1, "out0": 2}
np.testing.assert_array_equal(Si, np.repeat(np.arange(3), 3))  # 0,0,0,1,1,1,2,2,2
np.testing.assert_array_equal(Sj, np.tile(np.arange(3), 3))  # 0,1,2,0,1,2,0,1,2
np.testing.assert_array_almost_equal(
    Sx, jnp.array([0.0, 2.0, 1.0, 6.0, 8.0, 7.0, 3.0, 5.0, 4.0])
)

Convert an `SDict`, `SCoo` or `SDense` into an `SDense` (or convert a model generating any of these types into a model generating an `SDense`):

In [None]:
Sd, port_map = sax.sdense(_scoo)

# The port map is kept as-is; the three COO entries of `_scoo` are scattered
# into an otherwise-zero 3x3 complex matrix.
assert port_map == {"in0": 0, "in1": 2, "out0": 1}
Sd_ = jnp.array(
    [
        [3 + 0j, 0j, 0j],
        [0j, 4 + 0j, 0j],
        [1 + 0j, 0j, 0j],
    ]
)
np.testing.assert_array_almost_equal(Sd, Sd_)