# Fock space
## GJC quantum-enhanced telescope example

This example implements the original GJC scheme and computes its classical Fisher information for the encoded phase.
A stellar photon arrives in a superposition of being collected by the left and right telescopes, and the phase shift is encoded in the relative phase between these two paths.
An ancilla photon is distributed between the telescopes over a quantum network. It enables a quantum interference measurement between the stellar-photon and ancilla modes at each telescope.
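This short derivation shows the output statistics the circuit below should produce. It rests on assumed conventions that have not been checked against `squint`:

- The phase maps $a_0^\dagger \to e^{i\phi} a_0^\dagger$.
- Each beam splitter on modes $(i, j)$ maps $a_i^\dagger \to (a_i^\dagger + a_j^\dagger)/\sqrt{2}$ and $a_j^\dagger \to (a_i^\dagger - a_j^\dagger)/\sqrt{2}$.

Photon numbers are listed for modes 0 to 3. The stellar photon occupies modes 0 and 2, and the ancilla occupies modes 1 and 3.

```latex
\begin{aligned}
|\psi_\text{in}\rangle &= \tfrac{1}{2}\left(e^{i\phi} a_0^\dagger + a_2^\dagger\right)\left(a_1^\dagger + a_3^\dagger\right)|0000\rangle, \\
P(2000) = P(0200) = P(0020) = P(0002) &= \tfrac{1}{8}, \\
P(1010) = P(0101) &= \tfrac{1}{4}\cos^2(\phi/2), \\
P(1001) = P(0110) &= \tfrac{1}{4}\sin^2(\phi/2), \\
F(\phi) = \sum_x \frac{\left(\partial_\phi P(x)\right)^2}{P(x)}
  &= 2\cdot\tfrac{1}{4}\sin^2(\phi/2) + 2\cdot\tfrac{1}{4}\cos^2(\phi/2) = \tfrac{1}{2}.
\end{aligned}
```

The four bunched outcomes occur when both photons reach the same telescope and Hong-Ou-Mandel interference sends them to a single output. Together they have probability 1/2, independent of $\phi$, and carry no phase information. Only the coincidence events contribute, which is why $F = 1/2$. Other beam-splitter or phase conventions can shift the coincidence probabilities by a constant phase offset without changing $F$.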

<img src="assets/gjc_schematic.png" width=400 />
Source: D. Gottesman, T. Jennewein, and S. Croke, Phys. Rev. Lett. 109, 070503 (2012), doi:10.1103/PhysRevLett.109.070503

```python
import itertools

import equinox as eqx
import jax.numpy as jnp
import matplotlib.pyplot as plt
import seaborn as sns
from rich.pretty import pprint

from squint.circuit import Circuit
from squint.ops.fock import BeamSplitter, FockState, Phase
from squint.utils import partition_op, print_nonzero_entries
```

```python
cut = 4  # Fock-space dimension per mode (photon numbers 0 to cut - 1)
circuit = Circuit()

# add the stellar photon, in an even superposition of spatial modes 0 and 2 (left and right telescopes)
circuit.add(
    FockState(
        wires=(0, 2),
        n=[(1 / jnp.sqrt(2).item(), (1, 0)), (1 / jnp.sqrt(2).item(), (0, 1))],
    )
)
# the stellar photon accumulates a phase shift phi in the left telescope (mode 0)
circuit.add(Phase(wires=(0,), phi=0.01), "phase")

# add the resource (ancilla) photon, in an even superposition of spatial modes 1 and 3
circuit.add(
    FockState(
        wires=(1, 3),
        n=[(1 / jnp.sqrt(2).item(), (1, 0)), (1 / jnp.sqrt(2).item(), (0, 1))],
    )
)

# add the linear optical circuit at each telescope: a beam splitter interferes the
# stellar-photon mode with the ancilla mode (modes 0, 1 on the left; modes 2, 3 on the right)
circuit.add(BeamSplitter(wires=(0, 1)))
circuit.add(BeamSplitter(wires=(2, 3)))
pprint(circuit)

# split out the parameters that can be varied (here, only the "phase" op's phi value)
# from all the static parameters (wires, etc.)
params, static = partition_op(circuit, "phase")

# compile the circuit description into function calls that compute, e.g., the quantum state,
# the probabilities, and the partial derivatives of both with respect to the parameters
sim = circuit.compile(params, static, dim=cut, optimize="greedy").jit()

# helper that selects the "phase" op's phi value from the circuit object, or from any pytree
# with the same structure computed from it (e.g., the gradients); needed to extract d/dphi below
get = lambda pytree: jnp.array([pytree.ops["phase"].phi])
```

```text
2025-03-06 18:49:28.310 | INFO     | squint.circuit:compile:114 -   Complete contraction:  ab,ac,de,cfdg,bhei->fghi
         Naive scaling:  9
     Optimized scaling:  6
      Naive FLOP count:  1.311e+6
  Optimized FLOP count:  1.242e+4
   Theoretical speedup:  1.056e+2
  Largest intermediate:  2.560e+2 elements
--------------------------------------------------------------------------------
scaling        BLAS                current                             remaining
--------------------------------------------------------------------------------
   3           GEMM              ac,ab->cb                 de,cfdg,bhei,cb->fghi
   5           TDOT          cfdg,de->cfge                    bhei,cb,cfge->fghi
   5           GEMM          cfge,cb->fgeb                       bhei,fgeb->fghi
   6           TDOT        fgeb,bhei->fghi                            fghi->fghi
```

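The log above reports the tensor-contraction order chosen when the circuit is compiled. One plausible reading of the index labels is below; it is inferred from the labels, not from the `squint` source:

- `ab` is the stellar-photon state on modes 0 and 2.
- `ac` is the phase on mode 0.
- `de` is the ancilla state on modes 1 and 3.
- `cfdg` and `bhei` are the two beam splitters.

The sketch uses NumPy's `einsum_path` on random placeholder tensors of the same shapes, with `cut = 4` per index. It shows how a greedy path is found and checks that contracting along it matches the naive contraction.

```python
import numpy as np

cut = 4  # dimension of every index, matching `dim=cut` in the compile call
rng = np.random.default_rng(0)

# Random placeholder tensors with the shapes implied by the subscripts
# (hypothetical stand-ins, not the actual squint operators).
subscripts = "ab,ac,de,cfdg,bhei->fghi"
shapes = [(cut, cut), (cut, cut), (cut, cut), (cut,) * 4, (cut,) * 4]
tensors = [rng.standard_normal(s) for s in shapes]

# Ask NumPy for a greedy pairwise contraction order; `report` is a printable
# summary (scaling, FLOP counts, per-step table) similar to the log above.
path, report = np.einsum_path(subscripts, *tensors, optimize="greedy")
print(report)

# Contract along the chosen path and compare with the unoptimized contraction.
optimized = np.einsum(subscripts, *tensors, optimize=path)
naive = np.einsum(subscripts, *tensors, optimize=False)
print(optimized.shape, np.allclose(optimized, naive))
```

Contracting pairwise in a good order keeps intermediates small. That is where the reported drop from naive scaling 9 (all nine indices at once) to scaling 6 comes from.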

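The `get` helper can be applied to the gradients because, in JAX, the gradient with respect to a pytree is returned as a pytree of the same structure. This minimal sketch uses plain JAX with a toy nested-dict stand-in for the varied parameters. The dict layout and `toy_probability` are hypothetical, not `squint` objects, so the accessor uses dict indexing instead of `.ops["phase"].phi`.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the varied parameters: a nested dict with a single leaf, phi.
params = {"ops": {"phase": {"phi": jnp.array(0.01)}}}


def toy_probability(params):
    """Toy scalar model: the probability sin^2(phi/2)/4 of one coincidence outcome."""
    phi = params["ops"]["phase"]["phi"]
    return jnp.sin(phi / 2) ** 2 / 4


# jax.grad returns a pytree with the same structure as `params` ...
grads = jax.grad(toy_probability)(params)

# ... so one accessor selects phi from the parameters and d/dphi from the gradients.
get = lambda tree: jnp.array([tree["ops"]["phase"]["phi"]])
print(get(params), get(grads))
```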
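```python
# derivative of the state amplitudes with respect to phi (not used below)
dket = sim.amplitudes.grad(params)
# outcome probabilities in the Fock basis, and their derivatives with respect to phi
prob = sim.prob.forward(params)
grad = sim.prob.grad(params)

print_nonzero_entries(prob)
```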

```text
Basis: [0 0 0 2], Value: 0.12499997764825821
Basis: [0 0 1 1], Value: 1.778163348918507e-15
Basis: [0 0 2 0], Value: 0.12499997764825821
Basis: [0 1 0 1], Value: 0.24999374151229858
Basis: [0 1 1 0], Value: 6.249947546166368e-06
Basis: [0 2 0 0], Value: 0.12499997764825821
Basis: [1 0 0 1], Value: 6.249947546166368e-06
Basis: [1 0 1 0], Value: 0.24999377131462097
Basis: [1 1 0 0], Value: 1.778163348918507e-15
Basis: [2 0 0 0], Value: 0.12499997764825821
```

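As an independent check, this sketch re-derives the probabilities without `squint`. It tracks the two-photon state as a polynomial in creation operators acting on the vacuum. It assumes the phase maps $a_0^\dagger \to e^{i\phi} a_0^\dagger$ and each beam splitter maps $a_i^\dagger \to (a_i^\dagger + a_j^\dagger)/\sqrt{2}$, $a_j^\dagger \to (a_i^\dagger - a_j^\dagger)/\sqrt{2}$. `squint`'s own conventions may differ. All function names are hypothetical.

```python
import cmath
import math
from collections import defaultdict

# A state is a dict {sorted tuple of mode indices: amplitude}; e.g. {(0, 3): 0.5}
# stands for 0.5 * a0^dag a3^dag |vacuum>.


def multiply(p, q):
    """Product of two creation-operator polynomials."""
    out = defaultdict(complex)
    for modes_p, c_p in p.items():
        for modes_q, c_q in q.items():
            out[tuple(sorted(modes_p + modes_q))] += c_p * c_q
    return dict(out)


def beam_splitter(state, i, j):
    """50:50 beam splitter on modes (i, j), assuming
    a_i^dag -> (a_i^dag + a_j^dag)/sqrt(2), a_j^dag -> (a_i^dag - a_j^dag)/sqrt(2)."""
    s = 1 / math.sqrt(2)
    image = {i: {(i,): s, (j,): s}, j: {(i,): s, (j,): -s}}
    out = defaultdict(complex)
    for modes, c in state.items():
        term = {(): c}
        for m in modes:  # substitute each creation operator by its image
            term = multiply(term, image.get(m, {(m,): 1.0}))
        for k, v in term.items():
            out[k] += v
    return dict(out)


def fock_probabilities(state, n_modes=4):
    """Map each monomial to its Fock state, using (a^dag)^k |0> = sqrt(k!) |k>."""
    probs = {}
    for modes, c in state.items():
        n = tuple(modes.count(m) for m in range(n_modes))
        amp = c * math.prod(math.sqrt(math.factorial(k)) for k in n)
        probs[n] = abs(amp) ** 2
    return probs


def gjc_probabilities(phi):
    s = 1 / math.sqrt(2)
    stellar = {(0,): s * cmath.exp(1j * phi), (2,): s}  # phase on mode 0 (left telescope)
    ancilla = {(1,): s, (3,): s}
    state = multiply(stellar, ancilla)
    state = beam_splitter(state, 0, 1)  # left telescope
    state = beam_splitter(state, 2, 3)  # right telescope
    return fock_probabilities(state)


probs = gjc_probabilities(0.01)
for basis, p in sorted(probs.items()):
    if p > 1e-12:  # skip outcomes suppressed by Hong-Ou-Mandel interference
        print(f"Basis: {list(basis)}, Value: {p}")
```

Under these assumptions, the nonzero values at `phi = 0.01` agree with the table above up to the float32 rounding in the JAX output. The `[0 0 1 1]` and `[1 1 0 0]` outcomes come out as exact zeros here, consistent with reading their ~1e-15 values above as numerical noise.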

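```python
# compute the classical Fisher information, F = sum_x (dp(x)/dphi)^2 / p(x);
# the small constant guards against division by zero for outcomes with p(x) = 0
cfi = jnp.sum(get(grad) ** 2 / (prob + 1e-14))
print(f"The classical Fisher information for `phi` is {cfi}")

# the same quantity is also available from the `sim` object, as a classical Fisher
# information matrix (here 1x1, since `get` selects a single parameter)
cfim = sim.prob.cfim(get, params)
print(f"The classical Fisher information is {cfim}")
```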


```text
The classical Fisher information for `phi` is 0.5
The classical Fisher information is [[0.5]]
```

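The same number can be checked without automatic differentiation. This sketch evaluates $F(\phi) = \sum_x (\partial_\phi p(x))^2 / p(x)$ with a central finite difference instead of the autodiff behind `sim.prob.grad`. It uses closed-form GJC outcome probabilities as an assumed model; these match the probabilities printed above for $\phi = 0.01$ under one choice of phase and beam-splitter conventions. The helper names are hypothetical.

```python
import math


def gjc_probabilities_closed_form(phi):
    """Assumed closed-form outcome probabilities of the GJC circuit,
    keyed by photon numbers in modes (0, 1, 2, 3)."""
    bunched = {n: 1 / 8 for n in [(2, 0, 0, 0), (0, 2, 0, 0), (0, 0, 2, 0), (0, 0, 0, 2)]}
    c2 = math.cos(phi / 2) ** 2 / 4
    s2 = math.sin(phi / 2) ** 2 / 4
    return {**bunched, (1, 0, 1, 0): c2, (0, 1, 0, 1): c2, (1, 0, 0, 1): s2, (0, 1, 1, 0): s2}


def classical_fisher_information(prob_fn, phi, h=1e-5, eps=1e-14):
    """F(phi) = sum_x (dp_x/dphi)^2 / p_x, with dp_x/dphi from a central difference.
    `eps` plays the same role as the 1e-14 regularizer in the notebook cell."""
    p = prob_fn(phi)
    p_plus, p_minus = prob_fn(phi + h), prob_fn(phi - h)
    total = 0.0
    for x, p_x in p.items():
        dp = (p_plus[x] - p_minus[x]) / (2 * h)
        total += dp**2 / (p_x + eps)
    return total


for phi in [0.01, 0.5, 1.0]:
    F = classical_fisher_information(gjc_probabilities_closed_form, phi)
    print(f"phi = {phi}: F = {F:.6f}")
```

Away from $\phi = 0$ this gives $F \approx 1/2$ for every $\phi$. At exactly $\phi = 0$, the $\sin^2$ outcomes have $p = 0$ and zero slope, so the regularized sum returns 0 instead of the limiting value 1/2.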