# Single qubit: phase estimation
This example shows a one-qubit interference experiment.

In [8]:
import itertools

import equinox as eqx
import jax.numpy as jnp
import matplotlib.pyplot as plt
import seaborn as sns
from rich.pretty import pprint

from squint.circuit import Circuit
from squint.ops.dv import DiscreteVariableState, HGate, RZGate
from squint.ops.noise import BitFlipChannel

In [9]:
circuit = Circuit(backend="mixed")

#          ____      ___________      ____
# |0> --- | H | --- | Rz(\phi) | --- | H | ----
#         ----      -----------      ----

circuit.add(DiscreteVariableState(wires=(0,), n=(0,)))
circuit.add(HGate(wires=(0,)))
circuit.add(RZGate(wires=(0,), phi=0.1 * jnp.pi), "phase")
circuit.add(HGate(wires=(0,)))
circuit.add(BitFlipChannel(wires=(0,), p=0.1))

pprint(circuit)

In [10]:
pprint(circuit.subscripts)

In [11]:
circuit.ops["phase"].phi

Array(0.31415927, dtype=float64, weak_type=True)

In [12]:
params, static = eqx.partition(circuit, eqx.is_inexact_array)

pprint(params)
pprint(static)

In [13]:
sim = circuit.compile(params, static, dim=2, optimize="greedy")

In [None]:
ket = sim.amplitudes.forward(params)
dket = sim.amplitudes.grad(params)
prob = sim.probabilities.forward(params)
dprob = sim.probabilities.grad(params)

print(f"{ket.shape}, {ket.dtype}")
print(f"{prob.shape}, {prob.dtype}")

(2, 2), complex64
(2,), float32


In [None]:
sim = circuit.compile(params, static, dim=2, optimize="greedy")

get = lambda pytree: jnp.array([pytree.ops["phase"].phi])

phis = jnp.linspace(-jnp.pi, jnp.pi, 100)
params = eqx.tree_at(
    lambda pytree: pytree.ops["phase"].phi, params, jnp.expand_dims(phis, axis=1)
)

probs = eqx.filter_vmap(sim.probabilities.forward)(params)
cfims = eqx.filter_vmap(sim.probabilities.cfim, in_axes=(None, 0))(get, params)
qfims = eqx.filter_vmap(sim.amplitudes.qfim, in_axes=(None, 0))(get, params)

colors = sns.color_palette("crest", n_colors=jnp.prod(jnp.array(probs.shape[1:])))
fig, ax = plt.subplots()
for i, idx in enumerate(
    itertools.product(*[list(range(ell)) for ell in probs.shape[1:]])
):
    ax.plot(phis, probs[:, *idx], label=f"{idx}", color=colors[i])
ax.legend()
ax.set(xlabel=r"Phase, $\varphi$", ylabel=r"Probability, $p(\mathbf{x} | \varphi)$")


fig, ax = plt.subplots()
ax.plot(phis, qfims.squeeze(), color=colors[i])
ax.set(
    xlabel=r"Phase, $\varphi$",
    ylabel=r"$\mathcal{I}_\varphi^Q$",
    ylim=[0, 1.05 * jnp.max(qfims)],
)

fig, ax = plt.subplots()
ax.plot(phis, cfims.squeeze(), color=colors[i])
ax.set(
    xlabel=r"Phase, $\varphi$",
    ylabel=r"$\mathcal{I}_\varphi^C$",
    ylim=[0, 1.05 * jnp.max(cfims)],
)

ValueError: vmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())