# Backend

> SAX Backends

In [2]:
from nbdev import show_doc
from sax.typing_ import SDense, SDict, sdict

import os
import sys

# Redirect everything written to stderr to the null device for the rest of the
# kernel session. This also hides Python warnings, which are written to
# sys.stderr by default. The null-device handle is never closed.
sys.stderr = open(os.devnull, "w")

In [3]:
from __future__ import annotations

import warnings
from typing import Any, Dict

try:
    import jax
    import jax.numpy as jnp
    JAX_AVAILABLE = True
except ImportError:
    import numpy as jnp
    JAX_AVAILABLE = False
    
try:
    import klujax
    KLUJAX_AVAILABLE = True
except ImportError:
    KLUJAX_AVAILABLE = False

from sax.backends.additive import analyze_circuit_additive, evaluate_circuit_additive
from sax.backends.filipsson_gunnar import analyze_circuit_fg, evaluate_circuit_fg
from sax.typing_ import SType, sdict

if JAX_AVAILABLE and KLUJAX_AVAILABLE:
    from sax.backends.klu import analyze_circuit_klu, evaluate_circuit_klu

#### circuit_backends

In [4]:

# Registry of available backends: backend name -> (analyze_fn, evaluate_fn).
# "fg" and "filipsson_gunnar" are two names for the same backend.
circuit_backends = dict(
    fg=(analyze_circuit_fg, evaluate_circuit_fg),
    filipsson_gunnar=(analyze_circuit_fg, evaluate_circuit_fg),
    additive=(analyze_circuit_additive, evaluate_circuit_additive),
)

if not (JAX_AVAILABLE and KLUJAX_AVAILABLE):
    # The KLU backend cannot be imported: fall back on Filipsson-Gunnar and
    # tell the user how to get the faster backend.
    circuit_backends["default"] = (analyze_circuit_fg, evaluate_circuit_fg)
    warnings.warn("klujax not found. Please install klujax for better performance during circuit evaluation!")
else:
    # Both jax and klujax are present, so the KLU backend was imported above
    # and becomes the default.
    circuit_backends["klu"] = (analyze_circuit_klu, evaluate_circuit_klu)
    circuit_backends["default"] = circuit_backends["klu"]

SAX makes it easy to interchange the backend used to evaluate a circuit. A SAX backend consists of two steps, a static analysis step and an evaluation step:

:::{eval-rst}
.. autofunction:: sax.backends.__init__.analyze_circuit
:::


In [5]:
def analyze_circuit(
    connections: Dict[str, str],
    ports: Dict[str, str],
    *,
    backend: str = "default",
) -> Any:
    """Statically analyze a circuit with one of the registered backends.

    This is the first of the two backend steps. Everything that depends only on
    the circuit topology is computed here once, so that the evaluation step can
    be repeated (e.g. for new component S-matrices) without redoing it.

    Args:
        connections: internal connections, mapping ``"instance,port"`` to
            ``"instance,port"`` (e.g. ``{"dc1,out0": "wg,in0"}``).
        ports: external ports, mapping an external port name to the
            ``"instance,port"`` it exposes (e.g. ``{"in0": "dc1,in0"}``).
        backend: name of the backend in ``circuit_backends`` to use. Defaults
            to ``"default"``, which is the KLU backend when jax and klujax are
            installed and the Filipsson-Gunnar backend otherwise.

    Returns:
        A backend-specific ``analyzed`` object. It is only meaningful as input
        to the evaluate function of the same backend
        (``circuit_backends[backend][1]``).

    Raises:
        ValueError: if ``backend`` is not a registered backend name.
    """
    try:
        analyze_fn, _ = circuit_backends[backend]
    except KeyError:
        # Report the valid choices instead of a bare KeyError on the registry.
        raise ValueError(
            f"Unknown backend {backend!r}; "
            f"available backends: {sorted(circuit_backends)}."
        ) from None
    return analyze_fn(connections, ports)

:::{eval-rst}
.. autofunction:: sax.backends.__init__.evaluate_circuit
:::


In [6]:
def evaluate_circuit(
    analyzed: Any,
    instances: Dict[str, SType],
    *,
    backend: str = "default",
) -> SType:
    """Evaluate a previously analyzed circuit for the given component S-matrices.

    This is the second of the two backend steps. It combines the component
    S-matrices according to the static ``analyzed`` object.

    Args:
        analyzed: the object returned by the analyze function of the *same*
            backend (``circuit_backends[backend][0]``).
        instances: mapping of instance name to that instance's S-matrix in any
            ``SType`` format (e.g. an ``SDict`` or an ``SDense``).
        backend: name of the backend in ``circuit_backends`` to use. Defaults
            to ``"default"``, which is the KLU backend when jax and klujax are
            installed and the Filipsson-Gunnar backend otherwise.

    Returns:
        The S-matrix of the whole circuit in the format produced by the
        backend. Convert it with ``sdict`` if an ``SDict`` is needed.

    Raises:
        ValueError: if ``backend`` is not a registered backend name.
    """
    try:
        _, evaluate_fn = circuit_backends[backend]
    except KeyError:
        # Report the valid choices instead of a bare KeyError on the registry.
        raise ValueError(
            f"Unknown backend {backend!r}; "
            f"available backends: {sorted(circuit_backends)}."
        ) from None
    return evaluate_fn(analyzed, instances)

The `analyze_circuit` step should statically analyze the connections and ports and return an `analyzed` object. This object contains all the static information needed for the circuit computation that does not have to be recalculated when any of the circuit's parameters change. See [KLU backend](./08b_backends_klu.ipynb) for a non-trivial implementation of the circuit analysis.

The `evaluate_circuit` step evaluates the circuit for the given `SType` instances, using the `analyzed` object returned by `analyze_circuit`.

> Example

Let's create an MZI `SDict` using the default backend's `analyze_circuit` and `evaluate_circuit`:

In [7]:
# Waveguide: transmission 0.5 + 0.86603j, i.e. |t| ≈ 1 with a phase of ≈ π/3.
wg_sdict: SDict = {
    ("in0", "out0"): 0.5 + 0.86603j,
    ("out0", "in0"): 0.5 + 0.86603j,
}

# 50/50 directional coupler: through amplitude τ = 1/√2, cross amplitude κ = i/√2.
τ, κ = 0.5 ** 0.5, 1j * 0.5 ** 0.5
dc_sdense: SDense = (
    jnp.array([[0, 0, τ, κ],
               [0, 0, κ, τ],
               [τ, κ, 0, 0],
               [κ, τ, 0, 0]]),
    {"in0": 0, "in1": 1, "out0": 2, "out1": 3},
)

instances = {
    "dc1": dc_sdense,
    "wg": wg_sdict,
    "dc2": dc_sdense,
}
connections = {
    "dc1,out0": "wg,in0",
    "wg,out0": "dc2,in0",
    "dc1,out1": "dc2,in1",
}
ports = {
    "in0": "dc1,in0",
    "in1": "dc1,in1",
    "out0": "dc2,out0",
    "out1": "dc2,out1",
}

analyzed = analyze_circuit(connections, ports)
mzi_sdict = sdict(evaluate_circuit(analyzed, instances))

# Sanity checks: every component above is reciprocal (symmetric S-matrix) and
# lossless (up to the rounding in `wg_sdict`), so the MZI must be too.
# Missing keys are treated as zero transmission.
mzi_ports = ["in0", "in1", "out0", "out1"]
for p1 in mzi_ports:
    for p2 in mzi_ports:
        s12 = mzi_sdict.get((p1, p2), 0.0)
        s21 = mzi_sdict.get((p2, p1), 0.0)
        assert abs(s12 - s21) < 1e-6, f"not reciprocal: {p1} <-> {p2}"
for p_in in ["in0", "in1"]:
    transmitted_power = sum(abs(mzi_sdict.get((p_in, p_out), 0.0)) ** 2 for p_out in ["out0", "out1"])
    assert abs(transmitted_power - 1.0) < 1e-4, f"power not conserved from {p_in}: {transmitted_power}"

display(mzi_sdict)

{('in0', 'in0'): Array(0.+0.j, dtype=complex128),
 ('in0', 'in1'): Array(0.+0.j, dtype=complex128),
 ('in0', 'out0'): Array(-0.25+0.433015j, dtype=complex128),
 ('in0', 'out1'): Array(-0.433015+0.75j, dtype=complex128),
 ('in1', 'in0'): Array(0.+0.j, dtype=complex128),
 ('in1', 'in1'): Array(0.+0.j, dtype=complex128),
 ('in1', 'out0'): Array(-0.433015+0.75j, dtype=complex128),
 ('in1', 'out1'): Array(0.25-0.433015j, dtype=complex128),
 ('out0', 'in0'): Array(-0.25+0.433015j, dtype=complex128),
 ('out0', 'in1'): Array(-0.433015+0.75j, dtype=complex128),
 ('out0', 'out0'): Array(0.+0.j, dtype=complex128),
 ('out0', 'out1'): Array(0.+0.j, dtype=complex128),
 ('out1', 'in0'): Array(-0.433015+0.75j, dtype=complex128),
 ('out1', 'in1'): Array(0.25-0.433015j, dtype=complex128),
 ('out1', 'out0'): Array(0.+0.j, dtype=complex128),
 ('out1', 'out1'): Array(0.+0.j, dtype=complex128)}