In [22]:
import itertools

import equinox as eqx
import jax.numpy as jnp
import matplotlib.pyplot as plt
import seaborn as sns
from rich.pretty import pprint

from squint.circuit import Circuit
from squint.ops.base import SharedGate
from squint.ops.dv import (
    Conditional,
    DiscreteVariableState,
    HGate,
    Phase,
    XGate,
    MaximallyMixedState,
)
from squint.ops.noise import ErasureChannel, BitFlipChannel
from squint.diagram import draw

In [25]:
# Assemble a small noisy circuit on the "mixed" backend.
# Stages, in the order they are added:
#   1. state preparation: wire 0 maximally mixed, wires 1..n-1 in basis state n=(0,)
#   2. H on wire 0, then a chain of Conditional(XGate) gates linking neighbouring wires
#   3. a Phase(phi=PHASE) op on wire 0, shared onto wires 1..n-1, registered as "phase"
#   4. H on every wire
#   5. BitFlipChannel(p=BIT_FLIP_PROB) on every wire, then an ErasureChannel on wire 0

n = 2  # number of qubits
PHASE = 0.1 * jnp.pi  # phase angle phi for the shared Phase gate
BIT_FLIP_PROB = 0.4  # bit-flip probability p applied to each qubit

circuit = Circuit(backend="mixed")

# 1. State preparation. n=(0,) presumably selects the |0> basis state — TODO confirm.
circuit.add(MaximallyMixedState(wires=(0,)))
for i in range(1, n):
    circuit.add(DiscreteVariableState(wires=(i,), n=(0,)))

# 2. H + nearest-neighbour conditional-X chain (the usual GHZ-preparation pattern,
#    although wire 0 starts maximally mixed here rather than in |0>).
#    NOTE(review): assumes Conditional treats wires[0] as control, wires[1] as target.
circuit.add(HGate(wires=(0,)))
for i in range(n - 1):
    circuit.add(Conditional(gate=XGate, wires=(i, i + 1)))

# 3. One phase parameter applied to wire 0 and (presumably) copied by SharedGate
#    onto wires 1..n-1; "phase" is the key the op is registered under.
circuit.add(
    SharedGate(op=Phase(wires=(0,), phi=PHASE), wires=tuple(range(1, n))),
    "phase",
)

# 4. H on every wire.
for i in range(n):
    circuit.add(HGate(wires=(i,)))

# 5. Noise: independent bit flips on every wire, then erasure on wire 0.
for i in range(n):
    circuit.add(BitFlipChannel(wires=(i,), p=BIT_FLIP_PROB))

circuit.add(ErasureChannel(wires=(0,)))

pprint(circuit)

In [27]:
# Render the circuit assembled in the previous cell; the diagram object stays bound
# to `diagram` for later use, and .show()'s result remains the cell's last expression.
(diagram := draw(circuit)).show()