In [None]:
# default_exp backends.__init__

# Backend

> SAX Backends

In [None]:
# hide
import jax.numpy as jnp
from nbdev import show_doc
from sax.typing_ import SDense, SDict, sdict

import os, sys; sys.stderr = open(os.devnull, "w")

In [None]:
# exporti
from __future__ import annotations

# Probe the optional dependencies. The "klu" backend is only kept in
# `circuit_backends` when both of these imports succeed.
try:
    import jax
except ImportError:
    JAX_AVAILABLE = False
else:
    JAX_AVAILABLE = True

try:
    import klujax
except ImportError:
    KLUJAX_AVAILABLE = False
else:
    KLUJAX_AVAILABLE = True

from sax.backends.default import analyze_circuit, evaluate_circuit
from sax.backends.klu import analyze_circuit_klu, evaluate_circuit_klu
from sax.backends.additive import analyze_circuit_additive, evaluate_circuit_additive

#### circuit_backends

In [None]:
# exports

# Registry of circuit backends: name -> (analyze step, evaluate step).
# Entries are inserted in the order default, klu, additive; "klu" is only
# registered when both jax and klujax could be imported.
circuit_backends = {"default": (analyze_circuit, evaluate_circuit)}
if JAX_AVAILABLE and KLUJAX_AVAILABLE:
    circuit_backends["klu"] = (analyze_circuit_klu, evaluate_circuit_klu)
circuit_backends["additive"] = (analyze_circuit_additive, evaluate_circuit_additive)

SAX makes it easy to swap out the backend used to evaluate a circuit. A SAX backend consists of a static analysis step and an evaluation step:

In [None]:
# hide_input
# Render the signatures of the default backend's two steps. Both functions are
# already in scope from the `sax.backends.default` import in the exporti cell,
# so no re-import is needed here.
show_doc(analyze_circuit, doc_string=False)
show_doc(evaluate_circuit, doc_string=False)

The `analyze_circuit` step should statically analyze the connections and ports and return an `analyzed` object. This object contains all the static information needed to compute the circuit that does not have to be recalculated when any of the circuit's parameters change.

The `evaluate_circuit` step takes that `analyzed` object and evaluates the circuit for the given `SType` instances.

> Example

Let's create a Mach-Zehnder interferometer (MZI) `SDict` using the default backend's `evaluate_circuit`, and then compare it with the result of the `klu` backend:

In [None]:
# Two-port element "wg": a single complex transmission coefficient,
# identical in both directions, stored as an SDict.
wg_transmission = 0.5 + 0.86603j
wg_sdict: SDict = {
    ("in0", "out0"): wg_transmission,
    ("out0", "in0"): wg_transmission,
}

# Four-port 50/50 coupler "dc", stored as an SDense: (S-matrix, port -> index map).
# |τ|² = |κ|² = 0.5; the cross-coupled term carries an extra factor of 1j.
τ = 0.5 ** 0.5
κ = 1j * τ
dc_matrix = jnp.array([
    [0, 0, τ, κ],
    [0, 0, κ, τ],
    [τ, κ, 0, 0],
    [κ, τ, 0, 0],
])
dc_port_map = {"in0": 0, "in1": 1, "out0": 2, "out1": 3}
dc_sdense: SDense = (dc_matrix, dc_port_map)

# Netlist: two couplers, joined through "wg" on one arm and directly on the other.
instances = {"dc1": dc_sdense, "wg": wg_sdict, "dc2": dc_sdense}
connections = {
    "dc1,out0": "wg,in0",
    "wg,out0": "dc2,in0",
    "dc1,out1": "dc2,in1",
}
ports = {
    "in0": "dc1,in0",
    "in1": "dc1,in1",
    "out0": "dc2,out0",
    "out1": "dc2,out1",
}

# Default backend: static analysis of the netlist, then evaluation with the instances.
analyzed = analyze_circuit(connections, ports)
mzi_sdict = evaluate_circuit(analyzed, instances)
display(mzi_sdict)

In [None]:
# Same circuit, evaluated with the KLU backend. The evaluator's result is
# converted with `sdict` so it can be compared key-by-key with `mzi_sdict`.
analyzed = analyze_circuit_klu(connections, ports)
klu_result = evaluate_circuit_klu(analyzed, instances)
mzi_sdict_klu = sdict(klu_result)
display(mzi_sdict_klu)

In [None]:
# hide
# The default and KLU backends must agree on every S-parameter of the MZI.
# Print the per-key absolute difference, then fail loudly if any exceeds `tol`
# (previously this cell only printed, so a disagreement could never fail a run).
tol = 1e-5
max_abs_diff = 0.0
for k in mzi_sdict:
    # jnp.max keeps this working if the values carry batch dimensions.
    diff = float(jnp.max(jnp.abs(mzi_sdict[k] - mzi_sdict_klu[k])))
    print(k, diff)
    max_abs_diff = max(max_abs_diff, diff)
assert max_abs_diff < tol, f"default and klu backends disagree: max |Δ| = {max_abs_diff:.2e}"