In [15]:
# Added to silence some warnings.
from jax.config import config
config.update("jax_enable_x64", True)

import jax
import jax.numpy as jnp
import pennylane as qml

n = 6
k = 2

dev = qml.device("default.qubit", wires=n)

@qml.qnode(dev, interface="jax")
def probe(theta, phi):

    # probe state preparation
    for j in range(k):
        for i in range(n):
            qml.RX(theta[i, 3*j], wires=i)
            qml.RY(theta[i, 3*j + 1], wires=i)
            qml.RZ(theta[i, 3*j + 2], wires=i)
        for i in range(0, n-1, 2):
            qml.CNOT(wires=[i, i+1])
        for i in range(1, n-1, 2):
            qml.CNOT(wires=[i, i+1])

    # interaction
    for i in range(n):
        qml.RX(phi, wires=i)

    return qml.state()


@qml.qnode(dev, interface="jax")
def sensor(theta, phi, mu):

    # probe state preparation
    for j in range(k):
        for i in range(n):
            qml.RX(theta[i, 3*j], wires=i)
            qml.RY(theta[i, 3*j + 1], wires=i)
            qml.RZ(theta[i, 3*j + 2], wires=i)
        for i in range(0, n-1, 2):
            qml.CNOT(wires=[i, i+1])
        for i in range(1, n-1, 2):
            qml.CNOT(wires=[i, i+1])

    # interaction
    for i in range(n):
        qml.RX(phi, wires=i)

    # measurement
    for i in range(n):
        qml.RX(mu[i, 0], wires=i)
        qml.RY(mu[i, 1], wires=i)
        qml.RZ(mu[i, 2], wires=i)


@qml.qnode(dev, interface="jax")
def sensor(theta, phi, mu):

    # probe state preparation
    for j in range(k):
        for i in range(n):
            qml.RX(theta[i, 3*j], wires=i)
            qml.RY(theta[i, 3*j + 1], wires=i)
            qml.RZ(theta[i, 3*j + 2], wires=i)
        for i in range(0, n-1, 2):
            qml.CNOT(wires=[i, i+1])
        for i in range(1, n-1, 2):
            qml.CNOT(wires=[i, i+1])

    # interaction
    for i in range(n):
        qml.RX(phi, wires=i)

    # measurement
    for i in range(n):
        qml.RX(mu[i, 0], wires=i)
        qml.RY(mu[i, 1], wires=i)
        qml.RZ(mu[i, 2], wires=i)

    return qml

In [17]:
key = jax.random.PRNGKey(0)
phi = jax.random.uniform(key).astype("complex")
theta = jax.random.uniform(key, shape=[n, 3*k]).astype("complex")

jit_probe = jax.jit(probe)
jit_probe_grad = jax.jacrev(probe, argnums=1, holomorphic=True)

print(probe(theta, phi))
print(jit_probe(theta, phi).shape)
# print(jit_probe_grad(theta, phi).shape)

# def qfim(phi):
#     ket = jit_probe(theta, phi)
#     dket = jit_probe_grad(theta, phi)
#
#     print(ket)
#     print(dket)
#
# qfim(phi)

[ 0.08254194-0.17372768j -0.12142696-0.31693351j  0.04699802-0.17269352j
 -0.10776463-0.10816239j -0.22202476-0.18709174j -0.19245916-0.07931027j
 -0.04786103+0.00631811j -0.0756872 -0.0777465j  -0.05939747+0.03369411j
 -0.05298905+0.01969888j -0.158381  -0.09177549j -0.09934472-0.01820109j
 -0.04228061-0.00507773j -0.06411352+0.02955867j -0.01311911-0.06545985j
 -0.12653213-0.05735766j -0.06085876+0.0107406j  -0.04975247-0.02571916j
 -0.13774352-0.16944144j -0.11256319-0.0754626j   0.02272343-0.05824919j
 -0.06047701-0.04045654j  0.05406281-0.09400927j -0.06912738-0.17822378j
 -0.01664856+0.01573966j -0.0317322 +0.04160776j -0.05560944+0.05326311j
 -0.01920945+0.04972324j -0.02326593+0.00117269j -0.00267403+0.01992863j
 -0.02946624+0.00718339j -0.02907542+0.04061723j -0.03411911+0.01232871j
 -0.05437835+0.06758511j -0.05845088+0.06220948j -0.01545415+0.0662085j
 -0.04014621+0.03100614j -0.00410277+0.03949367j -0.02453146+0.02020252j
 -0.02325712+0.05374831j -0.05107998+0.04002754j -0.

In [96]:
# print(theta.shape)
# print(jit_probe(theta, phi))
print(jit_probe_grad(theta, phi))

[ 0.2008803 -0.07330582j  0.12037285-0.00547297j  0.00480808+0.00863459j
 ... -0.038725  -0.00655311j -0.02425729-0.00051095j
 -0.02017862+0.00521551j]


In [81]:
key = jax.random.PRNGKey(0)
phi = 0.1
theta = jax.random.uniform(key, shape=[n, 3*k])
mu = jax.random.uniform(key, shape=[n, 1])

jit_sensor = jax.jit(sensor)
jit_sensor_grad = jax.jacrev(sensor, argnums=1, holomorphic=False)
jit_sensor(theta, phi, mu)
jit_sensor_grad(theta, phi, mu)
print(jit_sensor(theta, phi, mu))
print(jit_sensor_grad(theta, phi, mu))

[0.00897753 0.08999533 0.04398529 0.02453177 0.06512392 0.07417258
 0.00558877 0.1133059  0.07624569 0.08061748 0.03336137 0.00339381
 0.05665476 0.22336002 0.01196069 0.08872509]
[-0.02181035 -0.18341337 -0.06686712  0.03509145 -0.03718322  0.02874459
  0.02464864  0.06584543  0.14751854 -0.10011345 -0.03290145  0.02451776
 -0.07024069  0.28668713 -0.01331841 -0.08720549]
