# Scheme 1 Proposal

This notebook implements Scheme 1, which uses a multimode, fixed-energy Fock state as the ancilla to estimate $|g|$ and $\varphi$ of the weak thermal star-photon state.

We optimize over the tunable parameters of the scheme, including the superposition amplitudes of the ancilla state. 

In [12]:
import dataclasses
import itertools
from typing import Any

import equinox as eqx
import jax.numpy as jnp
import numpy as np
from tqdm.auto import tqdm

from squint.circuit import Circuit
from squint.diagram import draw
from squint.ops.fock import (
    BeamSplitter,
    FixedEnergyFockState,
    FockState,
    TwoModeWeakThermalState,
)
from squint.ops.noise import ErasureChannel

## Definition of the optical circuit
We first define the optical circuit as a `squint.Circuit` object, which contains symbolic descriptions of the states, operators, and channels in sequential/causal order. 

The star is modelled as a weak thermal source, 
$$
\rho^{\text{star}}_{s_0, s_1} = 
\begin{bmatrix}
1 - \epsilon & 0 & 0 & 0 \\
0 & \epsilon / 2 &  |g|\exp(i\varphi) & 0 \\
0 & |g|\exp(-i\varphi) & \epsilon / 2 & 0 \\
0 & 0 & 0 & 0 \\
\end{bmatrix}
.
$$

The variational ancilla resource state is a multimode-entangled state with a fixed total photon number $N$,
$$
|{\psi}\rangle^{\text{ancilla}}_{a_0, a_1} = 
\sum_{n+m=N}  \alpha_{n,m}  |{n}\rangle_{a_0}|{m}\rangle_{a_1} ,
$$
where the sum runs over all $n, m \geq 0$ with $n + m = N$.

This ancilla state is distributed to the left and right telescopes through lossy channels, each modelled as a beamsplitter with the vacuum state on the remaining port,
$$
|{\psi}\rangle^{\text{loss}}_{d_0, d_1} = |{0}\rangle_{d_0} |{0}\rangle_{d_1} .
$$

In [13]:
def telescope(
    n_ancilla_modes: int = 1,
    n_ancilla_photons_per_mode: int = 1,
    *,
    epsilon: float = 1.0,
    g: float = 1.0,
    phi: float = 0.1,
    loss_r: float = 0.0,
):
    """Build the Scheme 1 telescope circuit.

    Wires: ``sl``/``sr`` carry the star light at the left/right telescope,
    ``al{i}``/``ar{i}`` carry the i-th ancilla pair, and ``dl{i}``/``dr{i}`` are
    vacuum "dump" modes used to model loss; the dump modes are erased at the end.

    Operations are added in this (sequential) order: star state, ancilla states,
    vacuum dump states, loss beamsplitters, the left mesh ``ul0, ul1, ...``, the
    right mesh ``ur0, ur1, ...``, and finally the erasure channels ``ptrace{i}``.

    Args:
        n_ancilla_modes: Number of ancilla mode pairs, each prepared in its own
            ``FixedEnergyFockState`` across one left and one right wire.
        n_ancilla_photons_per_mode: ``n`` passed to each ``FixedEnergyFockState``.
        epsilon, g, phi: Initial parameters of the ``TwoModeWeakThermalState``
            star source.
            NOTE(review): for the density matrix written in the markdown above,
            positivity requires ``|g| <= epsilon / 2``; the defaults
            ``epsilon=g=1.0`` violate that unless squint parametrizes ``g``
            differently -- confirm.
        loss_r: Initial ``r`` of every loss beamsplitter (default ``0.0``).

    Returns:
        ``(circuit, dim)``: the mixed-backend ``squint.Circuit`` and the Fock-space
        cutoff used for every mode.

    Raises:
        ValueError: if ``n_ancilla_modes`` or ``n_ancilla_photons_per_mode`` is negative.
    """
    if n_ancilla_modes < 0 or n_ancilla_photons_per_mode < 0:
        raise ValueError(
            "n_ancilla_modes and n_ancilla_photons_per_mode must be non-negative, got "
            f"{n_ancilla_modes} and {n_ancilla_photons_per_mode}"
        )

    # Largest photon number any single mode can reach: all ancilla photons
    # (assuming each FixedEnergyFockState holds n photons in total across its pair)
    # plus at most one star photon. The cutoff must include that level, hence +1.
    max_photons = n_ancilla_modes * n_ancilla_photons_per_mode + 1
    dim = max_photons + 1

    wire_star_left = "sl"
    wire_star_right = "sr"
    wires_ancilla_left = tuple(f"al{i}" for i in range(n_ancilla_modes))
    wires_ancilla_right = tuple(f"ar{i}" for i in range(n_ancilla_modes))
    wires_dump_left = tuple(f"dl{i}" for i in range(n_ancilla_modes))
    wires_dump_right = tuple(f"dr{i}" for i in range(n_ancilla_modes))
    wires_ancilla = wires_ancilla_left + wires_ancilla_right
    wires_dump = wires_dump_left + wires_dump_right

    circuit = Circuit(backend="mixed")

    def add_mesh(prefix, wire_star, wires_ancilla_side):
        """Add beamsplitters ``{prefix}0, {prefix}1, ...``: ancilla-star first, then ancilla-ancilla."""
        couplings = [(wire, wire_star) for wire in wires_ancilla_side]
        couplings.extend(itertools.combinations(wires_ancilla_side, 2))
        for k, wires in enumerate(couplings):
            circuit.add(BeamSplitter(wires=wires), f"{prefix}{k}")

    # star modes
    circuit.add(
        TwoModeWeakThermalState(
            wires=(wire_star_left, wire_star_right), epsilon=epsilon, g=g, phi=phi
        ),
        "star",
    )

    # ancilla modes: the i-th state spans the i-th left/right ancilla wires
    for i, pair in enumerate(zip(wires_ancilla_left, wires_ancilla_right, strict=True)):
        circuit.add(
            FixedEnergyFockState(wires=pair, n=n_ancilla_photons_per_mode),
            f"ancilla{i}",
        )

    # loss modes: vacuum on every dump wire
    for i, wire_dump in enumerate(wires_dump):
        circuit.add(FockState(wires=(wire_dump,), n=(0,)), f"vac{i}")

    # loss beamsplitters: each ancilla wire is mixed with its own dump wire
    for i, (wire_ancilla, wire_dump) in enumerate(
        zip(wires_ancilla, wires_dump, strict=True)
    ):
        circuit.add(BeamSplitter(wires=(wire_ancilla, wire_dump), r=loss_r), f"loss{i}")

    # Per-telescope interferometers. An alternative would be a single
    # LinearOpticalUnitaryGate per side acting on (ancilla wires..., star wire).
    add_mesh("ul", wire_star_left, wires_ancilla_left)
    add_mesh("ur", wire_star_right, wires_ancilla_right)

    # trace out the dump (loss) modes
    for i, wire_dump in enumerate(wires_dump):
        circuit.add(ErasureChannel(wires=(wire_dump,)), f"ptrace{i}")

    return circuit, dim

With the circuit defined, we verify that it was constructed properly (`circuit.verify()`) and draw its tensor-network diagram (`draw(circuit)`). The forward, gradient, and Fisher-information methods are compiled later, inside the phase sweep, with `circuit.compile(...)`. We also define the `get` function, which selects the parameters with respect to which the Fisher information (matrix) is computed; here, that is the star phase $\varphi$.

In [14]:
# Build one small instance (1 ancilla pair, 2 photons) and report its Fock cutoff.
circuit, dim = telescope(1, 2)
print("dim =", dim)

dim = 4


In [15]:
# Check that the circuit was constructed properly, then build its tensor-network diagram.
circuit.verify()
fig = draw(circuit)
# fig.show()  # presumably renders the diagram; left disabled here

In [16]:
@dataclasses.dataclass()
class TelescopeAncilla:
    """CFIM sweep result for a single ancilla configuration.

    One record is intended per (n_ancilla_modes, n_ancilla_photons_per_mode)
    setting evaluated in the phase sweep.
    """

    # number of ancilla mode pairs passed to `telescope`
    n_ancilla_modes: int
    # `n` of each fixed-energy ancilla pair state
    n_ancilla_photons_per_mode: int
    # Fock-space cutoff returned by `telescope`
    dim: int
    # classical Fisher information (matrix) at each swept phase
    cfims: Any


def get(pytree):
    """Select the parameters whose Fisher information is computed: the star phase ``phi``."""
    return jnp.array([pytree.ops["star"].phi])


# Number of star phases sampled uniformly over [-pi, pi] for the CFIM sweep.
N_PHIS = 100
phis = jnp.linspace(-jnp.pi, jnp.pi, N_PHIS)


def update(phi, params):
    """Return ``params`` with the star phase set to ``phi`` (``eqx.tree_at`` is out-of-place)."""
    return eqx.tree_at(lambda pytree: pytree.ops["star"].phi, params, phi)


data = []

In [17]:
# Sweep the star phase for each ancilla configuration and collect one CFIM curve per setting.
# `data` is reset here so that re-running this cell does not accumulate duplicate records.
# This is slow: one CFIM evaluation per phase per configuration. A batched alternative
# would be jax.lax.map over `phis`, which requires `import jax`.
data = []
for n_ancilla_modes, n_ancilla_photons_per_mode in ((1, 1), (1, 2), (1, 3)):
    circuit, dim = telescope(n_ancilla_modes, n_ancilla_photons_per_mode)
    params, static = eqx.partition(circuit, eqx.is_inexact_array)
    # NOTE(review): the full `circuit` is passed as the first argument here; check the
    # squint `Circuit.compile` signature to confirm whether `params` was intended.
    sim = circuit.compile(circuit, static, dim=dim).jit()

    desc = f"modes={n_ancilla_modes}, photons/mode={n_ancilla_photons_per_mode}"
    cfims = jnp.array(
        [
            sim.probabilities.cfim(get, update(phi, params))
            for phi in tqdm(phis, desc=desc)
        ]
    )

    # Previously the record was built but never stored, leaving `data` empty for plotting.
    data.append(
        TelescopeAncilla(
            n_ancilla_modes=n_ancilla_modes,
            n_ancilla_photons_per_mode=n_ancilla_photons_per_mode,
            dim=dim,
            cfims=np.array(cfims),
        )
    )

1 1


KeyboardInterrupt: 

In [None]:
# NOTE(review): disabled plotting cell. It needs `plt` (matplotlib.pyplot) and `sns`
# (seaborn), neither of which is imported in this notebook, so it cannot run as-is
# if uncommented. It plots each record's `d.cfims` against `phis`.
# colors = itertools.cycle(sns.color_palette('deep', n_colors=3))
# styles = itertools.cycle(["--", "-.", ":", "solid", "dashed", "dashdot", "dotted"])
# fig, ax = plt.subplots()

# for d in data:
# #     print(cfims.shape)
#     label = r"$m_{\text{mode}}="+f"{d.n_ancilla_modes},"+r" n_{\text{photon}}=" + f"{d.n_ancilla_photons_per_mode}" + r"$"
#     ax.plot(
#         phis,
#         d.cfims.squeeze(),
# #         color=next(colors),
# #         ls=next(styles),
# #         alpha=0.8,
#         label=label,
#     )
# ax.legend()
# ax.set(xlabel=r"Phase, $\varphi$", ylabel=r"CFIM")

In [None]:
# NOTE(review): stale scratch cell. It treats entries of `data` as raw arrays
# (`data[0].squeeze()`, `cfims[:, 0, 0]`), but `data` entries are intended to be
# `TelescopeAncilla` records whose array lives in `.cfims`. `label` is only defined
# on a commented-out line inside the loop. It also needs `plt`/`sns`, which are not
# imported.
# colors = itertools.cycle(sns.color_palette('deep', n_colors=3))
# styles = itertools.cycle(["--", "-.", ":", "solid", "dashed", "dashdot", "dotted"])
# fig, ax = plt.subplots()
# print(data[0].squeeze())
# plt.plot(phis, data[0].squeeze())

# for cfims in data:
#     print(cfims.shape)
#     # label = r"$m_{\text{mode}}="+f"{d.n_ancilla_modes},"+r" n_{\text{photon}}=" + f"{d.n_ancilla_photons_per_mode}" + r"$"
#     ax.plot(
#         phis,
#         cfims[:, 0, 0],
#         color=next(colors),
#         ls=next(styles),
#         alpha=0.8,
#         label=label,
#     )
# # ax.legend()
# # ax.set(xlabel=r"Phase, $\varphi$", ylabel=r"CFIM")