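# All-to-all $XX$ Ising interactions

Builds an $n$-qubit circuit that couples every pair of qubits with an $R_{XX}(\pi/4)$ gate, encodes a shared phase $\varphi$, applies a final layer of Hadamard gates, and computes the classical Fisher information of the measurement outcomes with respect to $\varphi$.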


In [16]:
import itertools

import equinox as eqx
import jax.numpy as jnp
import matplotlib.pyplot as plt
import seaborn as sns
from rich.pretty import pprint

from squint.circuit import Circuit
from squint.ops.base import SharedGate
from squint.ops.dv import DiscreteVariableState, HGate, Phase, RXXGate

In [27]:
# Per-wire local dimension (2 -> qubits) and number of wires.
dim = 2
n = 4

# Gate parameters: R_XX coupling angle for every pair, and the initial encoded phase.
rxx_angle = jnp.pi / 4
phi_init = 0.1 * jnp.pi

circuit = Circuit()

# Initial state: every wire starts in the |0> basis state.
for wire in range(n):
    circuit.add(DiscreteVariableState(wires=(wire,), n=(0,)))

# All-to-all coupling: one R_XX gate per unordered pair of wires.
for pair in itertools.combinations(range(n), 2):
    circuit.add(RXXGate(wires=pair, angle=rxx_angle))

# Phase encoding: a Phase op on wire 0 wrapped in a SharedGate over wires 1..n-1,
# registered under the key "phase" (later cells read ops["phase"].op.phi).
# NOTE(review): presumably SharedGate applies the same phi to all listed wires — confirm in squint docs.
shared_phase = SharedGate(
    op=Phase(wires=(0,), phi=phi_init),
    wires=tuple(range(1, n)),
)
circuit.add(shared_phase, "phase")

# Measurement basis change: a Hadamard on every wire.
for wire in range(n):
    circuit.add(HGate(wires=(wire,)))

# Separate the inexact-array leaves (trainable parameters) from the static structure.
params, static = eqx.partition(circuit, eqx.is_inexact_array)

pprint(circuit)

In [None]:
# Compile the circuit into a jitted simulator (contraction path chosen greedily).
sim = circuit.compile(params, static, dim=dim, optimize="greedy").jit()

# Regulariser added to probabilities so outcomes with p(x|phi) == 0 give 0 rather
# than 0/0 = nan in the Fisher-information sum.
PROB_EPS = 1e-14


def get(pytree):
    """Return the ``phi`` leaf of the ``"phase"`` op, wrapped in a leading length-1 axis.

    Works on any pytree with the circuit's structure, e.g. ``params`` or the
    gradient pytree returned by ``sim.prob.grad``.
    """
    return jnp.array([pytree.ops["phase"].op.phi])


prob = sim.prob.forward(params)
dprob = sim.prob.grad(params)

# Classical Fisher information w.r.t. phi: sum over outcomes x of (dp(x|phi)/dphi)^2 / p(x|phi).
cfi = jnp.sum(get(dprob) ** 2 / (prob + PROB_EPS))

print(get(dprob))
print(cfi)

[32m2025-03-06 20:11:02.557[0m | [1mINFO    [0m | [36msquint.circuit[0m:[36mcompile[0m:[36m114[0m - [1m  Complete contraction:  a,b,c,d,aebf,egch,gidj,fkhl,kmjn,lonp,iq,mr,os,pt,qu,rv,sw,tx->uvwx
         Naive scaling:  24
     Optimized scaling:  6
      Naive FLOP count:  3.020e+8
  Optimized FLOP count:  9.440e+2
   Theoretical speedup:  3.199e+5
  Largest intermediate:  1.600e+1 elements
--------------------------------------------------------------------------------
scaling        BLAS                current                             remaining
--------------------------------------------------------------------------------
   4           GEMM            aebf,a->ebf    b,c,d,egch,gidj,fkhl,kmjn,lonp,iq,mr,os,pt,qu,rv,sw,tx,ebf->uvwx
   4    GEMV/EINSUM            egch,c->egh    b,d,gidj,fkhl,kmjn,lonp,iq,mr,os,pt,qu,rv,sw,tx,ebf,egh->uvwx
   4    GEMV/EINSUM            gidj,d->gij    b,fkhl,kmjn,lonp,iq,mr,os,pt,qu,rv,sw,tx,ebf,egh,gij->uvwx
   3    GEMV/EINSUM       

[[[[[-0.0772543   0.07725423]
    [ 0.07725422 -0.07725418]]

   [[ 0.07725428 -0.0772542 ]
    [-0.07725422  0.07725421]]]


  [[[ 0.07725429 -0.07725421]
    [-0.07725421  0.07725421]]

   [[-0.07725419  0.07725419]
    [ 0.07725417 -0.07725424]]]]]
16.000004


In [None]:
# NOTE(review): `phis`, `probs`, `qfims` and `cfims` are not defined in any cell
# visible in this notebook. The phase-sweep cell that produced them appears to be
# missing, so this cell raises NameError after Restart Kernel -> Run All.
# TODO: restore the sweep cell above this one.
# Assumed layout (from the indexing below): `phis` is 1-D; `probs` has the sweep on
# axis 0 followed by one axis per measured wire; `qfims`/`cfims` squeeze to one
# value per phi — confirm against the sweep code.


def plot_fisher_information(phis, values, ylabel, color):
    """Plot a Fisher-information curve against phase on a new figure.

    The y-axis starts at 0 and extends 5% above the largest value.
    Returns the ``(fig, ax)`` pair.
    """
    fig, ax = plt.subplots()
    ax.plot(phis, values.squeeze(), color=color)
    ax.set(
        xlabel=r"Phase, $\varphi$",
        ylabel=ylabel,
        ylim=[0, 1.05 * jnp.max(values)],
    )
    return fig, ax


# Collapse all outcome axes into one, so column k is p(x_k | phi) for a single outcome.
# Row-major (C-order) flattening matches the ordering of itertools.product below.
outcome_shape = probs.shape[1:]
probs_by_outcome = probs.reshape(probs.shape[0], -1)
colors = sns.color_palette("crest", n_colors=probs_by_outcome.shape[1])

fig, ax = plt.subplots()
outcome_labels = itertools.product(*(range(ell) for ell in outcome_shape))
for column, (idx, color) in enumerate(zip(outcome_labels, colors)):
    ax.plot(phis, probs_by_outcome[:, column], label=f"{idx}", color=color)
ax.legend()
ax.set(xlabel=r"Phase, $\varphi$", ylabel=r"Probability, $p(\mathbf{x} | \varphi)$")

# Both Fisher-information curves use the last colour of the palette.
fig_qfi, ax_qfi = plot_fisher_information(
    phis, qfims, r"$\mathcal{I}_\varphi^Q$", colors[-1]
)
fig_cfi, ax_cfi = plot_fisher_information(
    phis, cfims, r"$\mathcal{I}_\varphi^C$", colors[-1]
)