In [None]:
import tensorcircuit as tc
import jax.numpy as np
import optax
from tensorcircuit . experimental import qng
from functools import partial

In [None]:
# Select JAX as the numerical backend; the returned ``backend`` object is used below
# for jit, autodiff (jacrev / value_and_grad) and random initialisation.
backend = tc.set_backend("jax")
# Simulate with complex128 (double-precision) amplitudes.
tc.set_dtype("complex128")
# Tensor-network contraction-path strategy. Options: "auto", "greedy", "branch", "plain", "tng", "custom".
tc.set_contractor("auto")

In [None]:
n, k = 4, 3  # number of qubits; k sets the parameter-array width (k + 1 columns)
params = backend.implicit_randn([3 * n, k + 1], dtype="complex")

def sensor(params, phi):
    """GHZ-probe sensor; returns the density matrix after imprinting ``phi``.

    Prepares an n-qubit GHZ state (H on qubit 0, CNOT fan-out from qubit 0),
    applies Rz(phi) to every qubit and returns ``densitymatrix()``.

    ``params`` is accepted so the signature matches the variational sensors,
    but it is not used: this probe is fixed.

    NOTE: Rz is diagonal, so the computational-basis probabilities of this
    circuit do not depend on ``phi``; the CFI computed from them is 0.
    """
    # DMCircuit provides ``densitymatrix()``.
    # NOTE(review): switched from tc.Circuit, which does not appear to expose
    # densitymatrix() (the later sensor revision also uses DMCircuit) - confirm.
    dmc = tc.DMCircuit(n)

    # probe state: GHZ preparation
    dmc.h(0)
    for i in range(1, n):
        dmc.cnot(0, i)

    # interaction: imprint phi on every qubit
    for i in range(n):
        dmc.rz(i, theta=phi)

    return dmc.densitymatrix()


def cfi(params, phi):
    """Classical Fisher information of a computational-basis measurement w.r.t. ``phi``.

    F(phi) = sum_x (d p_x / d phi)**2 / p_x, where p_x is the (real) diagonal of
    ``sensor(params, phi)``. Outcomes with p_x == 0 contribute 0 instead of 0/0 = nan.
    """

    def prob(params, phi):
        # outcome probabilities = real part of the density-matrix diagonal
        dm = sensor(params, phi)
        return backend.real(backend.diagonal(dm))

    p = prob(params, phi)
    # evaluate the Jacobian w.r.t. phi once and reuse it
    dprob = backend.jacrev(lambda phi: prob(params=params, phi=phi))(phi).squeeze()
    # "double where": keeps both the value and its gradient finite for zero-probability outcomes
    support = p > 0
    safe_p = np.where(support, p, 1.0)
    return backend.sum(np.where(support, dprob**2 / safe_p, 0.0))


phi = backend.implicit_randn()  # random value of the phase to be estimated
print(cfi(params, phi))

In [None]:
n, k = 4, 3  # number of qubits, number of probe layers
params = backend.implicit_randn([3 * n, k + 1])

def sensor(params, phi):
    """Variational density-matrix sensor; returns ``dmc.densitymatrix()``.

    params: shape [3 * n, k + 1]. Columns 0..k-1 are the probe layers: for qubit i,
        rows 3i, 3i+1, 3i+2 give the (theta, alpha, phi) angles of its ``r`` rotation
        in that layer. The last column holds the measurement ``r`` rotation angles.
    phi: the parameter to estimate, imprinted by Rx(phi) on every qubit.
    """
    dmc = tc.DMCircuit(n)

    params_probe = params[:, 0:-1]  # [3 * n, k]
    params_measure = params[:, -1:]  # [3 * n, 1]

    # probe state: k layers of rotations, each followed by a CNOT fan-out from qubit 0.
    # The column is indexed by ``layer`` so every layer gets its own angles
    # (column k does not exist in params_probe, which has only k columns).
    for layer in range(k):
        for i in range(n):
            dmc.r(
                i,
                theta=params_probe[3 * i, layer],
                alpha=params_probe[3 * i + 1, layer],
                phi=params_probe[3 * i + 2, layer],
            )
        for i in range(1, n):
            dmc.cnot(0, i)

    # interaction
    for i in range(n):
        dmc.rx(i, theta=phi)

    # measurement basis rotation
    for i in range(n):
        dmc.r(
            i,
            theta=params_measure[3 * i, 0],
            alpha=params_measure[3 * i + 1, 0],
            phi=params_measure[3 * i + 2, 0],
        )

    return dmc.densitymatrix()


def cfi(params, phi):
    """Classical Fisher information of a computational-basis measurement w.r.t. ``phi``.

    F(phi) = sum_x (d p_x / d phi)**2 / p_x, where p_x is the (real) diagonal of
    ``sensor(params, phi)``. Outcomes with p_x == 0 contribute 0 instead of 0/0 = nan.
    """

    def prob(params, phi):
        # outcome probabilities = real part of the density-matrix diagonal
        dm = sensor(params, phi)
        return backend.real(backend.diagonal(dm))

    p = prob(params, phi)
    dprob = backend.jacrev(lambda phi: prob(params=params, phi=phi))(phi).squeeze()
    # "double where": keeps both the value and its gradient (w.r.t. params) finite
    # for zero-probability outcomes
    support = p > 0
    safe_p = np.where(support, p, 1.0)
    return backend.sum(np.where(support, dprob**2 / safe_p, 0.0))

# Purity check: for a Hermitian rho, sum_ij |rho_ij|**2 = Tr(rho**2), which is 1 for a pure state.
print(
    backend.sum(
        backend.abs(sensor(params, phi)) ** 2
    )
)

In [None]:
# Negated CFI so that minimising the objective maximises the Fisher information.
# cfi's second parameter is named ``phi`` (passing ``_phi=`` raised a TypeError).
cfi_val_grad_jit = backend.jit(backend.value_and_grad(lambda params: -cfi(params=params, phi=phi)))
val, grad = cfi_val_grad_jit(params)

In [None]:
# Adagrad optimiser (tensorcircuit wrapper around optax) for the CFI training loop.
opt = tc.backend.optimizer(optax.adagrad(learning_rate=0.95))


In [None]:
# Create the optimiser here, next to the fresh parameters, so re-running this cell
# restarts training. ``opt.update`` returns only params, so the wrapper keeps the
# Adagrad state internally; reusing an old ``opt`` would continue from that state.
opt = tc.backend.optimizer(optax.adagrad(learning_rate=0.95))
params = backend.implicit_randn([3 * n, k + 1])

for i in range(250):
    # val is the negated CFI (the objective being minimised)
    val, grad = cfi_val_grad_jit(params)
    params = opt.update(grad, params)
    print(f"Step {i} | CFI {val}")

In [None]:
# %timeit noisy(4)  # TODO: `noisy` is not defined in this notebook; define it before re-enabling this timing.

In [None]:
# number of qubits, number of probe layers
# (``params`` is created after ``sensor`` is defined; an earlier complex-valued
# assignment here was always overwritten before use and has been dropped)
n, k = 6, 4


def sensor(params, phi):
    """Build the pure-state sensing circuit and return it (not its state).

    For each of the k layers, every qubit q gets an ``r`` rotation whose angles
    come from column ``layer`` of ``params`` (rows 3q, 3q+1, 3q+2 give theta,
    alpha, phi), followed by a nearest-neighbour CNOT chain. The signal ``phi``
    is then imprinted by an Rz rotation on every qubit.
    """
    circuit = tc.Circuit(n)

    for layer in range(k):
        for qubit in range(n):
            row = 3 * qubit
            circuit.r(
                qubit,
                theta=params[row, layer],
                alpha=params[row + 1, layer],
                phi=params[row + 2, layer],
            )
        for target in range(1, n):
            circuit.cnot(target - 1, target)

    # interaction: imprint the signal phase on every qubit
    for qubit in range(n):
        circuit.rz(qubit, theta=phi)
    return circuit

# NOTE(review): phi is made complex, presumably so that backend.jacrev returns the full
# complex derivative of the state; for real inputs the jacrev demo at the end of this
# notebook returns only real parts - confirm.
phi = np.array([1.12314]).astype("complex")
# k probe-layer columns (this sensor has no measurement column)
params = backend.implicit_randn([3 * n, k])

dmc = sensor(params, phi)
# not the last expression of the cell, so this render is not displayed
dmc.draw(output="text")


def qfi(_params, phi):
    """Pure-state quantum Fisher information of ``sensor(_params, phi)`` w.r.t. ``phi``.

    F_Q = 4 * (<d psi|d psi> - |<psi|d psi>|**2), where |d psi> = d|psi>/d phi.

    ``phi`` is expected to be a 1-element array, so the Jacobian is a column of
    shape (2**n, 1). NOTE(review): ``phi`` is passed as a complex array in this
    notebook, presumably so ``backend.jacrev`` returns the full complex derivative
    rather than only its real part - confirm.
    Returns a real scalar.
    """
    psi = sensor(_params, phi).state()[:, None]  # (2**n, 1)
    d_psi = backend.jacrev(lambda phi: sensor(params=_params, phi=phi).state())(phi)  # (2**n, 1)
    d_psi_dag = backend.conj(d_psi.T)  # row vector <d psi|
    norm_term = d_psi_dag @ d_psi  # <d psi|d psi>
    overlap = d_psi_dag @ psi  # <d psi|psi>
    # |<psi|d psi>|**2 as overlap * conj(overlap): explicit, and smooth where the overlap is 0
    overlap_sq = backend.real(overlap * backend.conj(overlap))
    fi = 4 * (backend.real(norm_term) - overlap_sq)
    return fi[0, 0]


In [None]:
# Text rendering of the QFI sensing circuit.
dmc.draw(output="text")

In [None]:
# Negated QFI so that minimising the objective maximises the quantum Fisher information.
qfi_val_grad_jit = backend.jit(
    backend.value_and_grad(lambda p: -qfi(_params=p, phi=phi))
)
val, grad = qfi_val_grad_jit(params)
print(val, grad)

In [None]:
# Fresh optimiser and random starting point, created together in this cell.
opt = tc.backend.optimizer(optax.adagrad(learning_rate=0.35))
params = backend.implicit_randn([3 * n, k])

for i in range(250):
    # val is the negated QFI (the objective being minimised)
    val, grad = qfi_val_grad_jit(params)
    params = opt.update(grad, params)
    print(f"Step {i} | QFI {val}")

# Classical Fisher information (CFI) with a variational probe and measurement

In [62]:
n, k = 6, 4  # number of qubits, number of probe layers

def sensor(params, phi):
    """Pure-state sensor with a variational probe and a variational measurement basis.

    params: shape [3 * n, k + 1]. Column ``i`` (0..k-1) is probe layer ``i``: for
        qubit j, rows 3j, 3j+1, 3j+2 give the (theta, alpha, phi) angles of its ``r``
        rotation. The last column holds the measurement ``u`` angles (rows 3j and
        3j+1; row 3j+2 is not read).
    phi: the parameter to estimate, imprinted by Rz(phi) on every qubit.
    Returns the ``tc.Circuit`` (call ``.state()`` for the wavefunction).

    Raises ValueError if k < 1.
    """
    if k < 1:
        raise ValueError("sensor needs at least one probe layer (k >= 1)")

    dmc = tc.Circuit(n)

    # probe: k layers of single-qubit rotations, each followed by a CNOT chain
    for i in range(k):
        for j in range(n):
            dmc.r(j, theta=params[3*j, i], alpha=params[3*j + 1, i], phi=params[3*j + 2, i])

        for j in range(1, n):
            dmc.cnot(j-1, j)

    # final rotation layer: reuses the angles of the last probe layer (column k - 1),
    # since params has no dedicated column for it. The index is spelled out instead of
    # relying on the loop variable surviving the loop above.
    # NOTE(review): confirm the sharing is intended; independent angles would need
    # params of shape [3 * n, k + 2].
    last = k - 1
    for j in range(n):
        dmc.r(j, theta=params[3*j, last], alpha=params[3*j + 1, last], phi=params[3*j + 2, last])

    # interaction
    for j in range(n):
        dmc.rz(j, theta=phi)

    # measurement basis rotation (only theta and phi of ``u`` are set)
    for j in range(n):
        dmc.u(j, theta=params[3*j, -1], phi=params[3*j + 1, -1])

    return dmc


# phi stays real here: the CFI only differentiates real-valued probabilities.
phi = np.array([1.12314])
# k probe-layer columns + 1 measurement column
params = backend.implicit_randn([3 * n, k + 1])

dmc = sensor(params, phi)

def cfi(_params, _phi):
    """Classical Fisher information of a computational-basis measurement w.r.t. ``_phi``.

    F = sum_x (d p_x / d phi)**2 / p_x with p_x = |<x|psi(_params, _phi)>|**2, where the
    derivative is taken at ``_phi``. Outcomes with p_x == 0 contribute 0 instead of nan.
    """

    def probs(_params, _phi):
        return backend.abs(sensor(_params, _phi).state()) ** 2

    pr = probs(_params, _phi)
    dpr_phi = backend.jacrev(lambda _phi: probs(_params=_params, _phi=_phi))
    # differentiate at the argument ``_phi`` (previously the global ``phi`` was used)
    d_pr = dpr_phi(_phi).squeeze()
    # "double where": keeps value and gradient finite for zero-probability outcomes
    support = pr > 0
    safe_pr = np.where(support, pr, 1.0)
    fim = backend.sum(np.where(support, d_pr * d_pr / safe_pr, 0.0))
    return fim

# CFI at the randomly initialised parameters.
print(cfi(params, phi))


0.6394717310835949


In [63]:
# Text rendering of the CFI sensing circuit.
dmc.draw(output="text")

In [64]:
# Negated CFI so that minimising the objective maximises the Fisher information.
cfi_val_grad_jit = backend.jit(
    backend.value_and_grad(lambda p: -cfi(_params=p, _phi=phi))
)
val, grad = cfi_val_grad_jit(params)
print(val, grad)

-0.6394717310835949 [[ 5.18865168e-01 -6.11652255e-01  8.16755183e-03  5.30412495e-02
  -4.49732132e-02]
 [ 1.15116306e-01  3.28251682e-02  1.25711054e-01  2.40929276e-02
   1.61329283e-16]
 [-1.63690925e-01  1.16701452e-02  1.72127262e-01  4.79627214e-03
   0.00000000e+00]
 [-1.94282010e-02 -1.98114067e-01  1.32799104e-01  1.51566472e-02
   9.41578299e-03]
 [-2.73651898e-01 -9.14896876e-02 -1.37297630e-01  2.43475050e-01
   7.63278329e-17]
 [ 1.41687365e-03  3.02134128e-03  6.01618551e-04  3.70453298e-03
   0.00000000e+00]
 [ 3.03009659e-01 -6.77434623e-01 -7.13985264e-01 -1.42582804e-01
  -4.15211141e-01]
 [ 1.50111970e-02 -1.13205642e-01 -2.80193567e-01 -1.37435943e-01
   2.60208521e-17]
 [-3.83593887e-03 -1.34522796e-01  5.09209409e-02  1.66091055e-01
   0.00000000e+00]
 [ 4.89116937e-01  1.63508490e-01  2.25313216e-01 -4.13548231e-01
  -2.58194387e-01]
 [ 3.97523820e-01  1.12787612e-01  8.18160027e-02 -7.06880689e-01
  -1.38777878e-16]
 [-7.92494714e-02  2.84399599e-01 -7.29515683

In [65]:
# Fresh optimiser and random starting point, created together in this cell.
opt = tc.backend.optimizer(optax.adagrad(learning_rate=0.2))
params = backend.implicit_randn([3 * n, k + 1])

for i in range(250):
    # val is the negated CFI (the objective being minimised)
    val, grad = cfi_val_grad_jit(params)
    params = opt.update(grad, params)
    print(f"Step {i} | CFI {val}")

Step 0 | CFI -1.1092060578649465
Step 1 | CFI -2.100920917650086
Step 2 | CFI -2.367207042897983
Step 3 | CFI -1.9396719011168324
Step 4 | CFI -3.8829945799987304
Step 5 | CFI -2.7282084158076576
Step 6 | CFI -2.874060961304079
Step 7 | CFI -3.4453112270270347
Step 8 | CFI -5.630369510119214
Step 9 | CFI -5.3125408866388515
Step 10 | CFI -8.753380867283273
Step 11 | CFI -7.103218086282496
Step 12 | CFI -7.412387424200567
Step 13 | CFI -8.847269389994562
Step 14 | CFI -12.215263568060617
Step 15 | CFI -7.462722219955374
Step 16 | CFI -11.387174194741418
Step 17 | CFI -13.371275382200118
Step 18 | CFI -14.707636824265853
Step 19 | CFI -16.206505119985547
Step 20 | CFI -21.001552133521415
Step 21 | CFI -19.844634044159196
Step 22 | CFI -21.458649202889767
Step 23 | CFI -22.52630314948898
Step 24 | CFI -24.31573212616739
Step 25 | CFI -23.389846101514493
Step 26 | CFI -20.31944387881403
Step 27 | CFI -9.895113577607367
Step 28 | CFI -20.422009488238256
Step 29 | CFI -27.40264805791638
Step

# Mixed-state CFI under phase damping

In [95]:
n, k = 4, 4  # number of qubits, number of probe layers


def sensor(params, phi, gamma):
    """Noisy (density-matrix) variational sensor; returns the ``tc.DMCircuit``.

    k probe layers, each made of ``r`` rotations (column ``i`` of ``params``; rows
    3j, 3j+1, 3j+2 give theta, alpha, phi for qubit j), a CNOT chain, and phase
    damping of strength ``gamma[0]`` on every qubit. These are followed by a final
    rotation layer, Rz(phi[0]) on every qubit, and a ``u`` measurement rotation
    taken from the last column.

    params: shape [3 * n, k + 1]; phi, gamma: 1-element arrays.
    Raises ValueError if k < 1.
    """
    if k < 1:
        raise ValueError("sensor needs at least one probe layer (k >= 1)")

    dmc = tc.DMCircuit(n)

    for i in range(k):
        for j in range(n):
            dmc.r(j, theta=params[3 * j, i], alpha=params[3 * j + 1, i], phi=params[3 * j + 2, i])

        for j in range(1, n):
            dmc.cnot(j - 1, j)

        for j in range(n):
            dmc.phasedamping(j, gamma=gamma[0])

    # final rotation layer: reuses the angles of the last probe layer (column k - 1),
    # since params has no dedicated column for it. The index is spelled out instead of
    # relying on the loop variable surviving the loop above.
    # NOTE(review): confirm the sharing is intended; independent angles would need
    # params of shape [3 * n, k + 2].
    last = k - 1
    for j in range(n):
        dmc.r(j, theta=params[3 * j, last], alpha=params[3 * j + 1, last], phi=params[3 * j + 2, last])

    # interaction
    for j in range(n):
        dmc.rz(j, theta=phi[0])

    # measurement basis rotation (only theta and phi of ``u`` are set)
    for j in range(n):
        dmc.u(j, theta=params[3 * j, -1], phi=params[3 * j + 1, -1])

    return dmc


gamma = np.array([0.0])  # phase-damping strength (0 = no dephasing)
phi = np.array([1.12314])  # value of the phase to be estimated
params = backend.implicit_randn([3 * n, k + 1])  # k probe layers + 1 measurement column

dmc = sensor(params, phi, gamma)


def cfi(_params, _phi, _gamma):
    """Classical Fisher information w.r.t. ``_phi`` of a computational-basis measurement
    on the noisy sensor state at dephasing strength ``_gamma``.

    F = sum_x (d p_x / d phi)**2 / p_x with p_x the density-matrix diagonal; the
    derivative is taken at ``_phi``. Outcomes with p_x == 0 contribute 0 instead of nan.
    """

    def probs(_params, _phi, _gamma):
        return backend.abs(backend.diagonal(sensor(_params, _phi, _gamma).densitymatrix()))

    pr = probs(_params, _phi, _gamma)
    dpr_phi = backend.jacrev(lambda _phi: probs(_params=_params, _phi=_phi, _gamma=_gamma))
    # differentiate at the argument ``_phi`` (previously the global ``phi`` was used)
    d_pr = dpr_phi(_phi).squeeze()
    # "double where": keeps value and gradient finite for zero-probability outcomes
    support = pr > 0
    safe_pr = np.where(support, pr, 1.0)
    fim = backend.sum(np.where(support, d_pr * d_pr / safe_pr, 0.0))
    return fim

# CFI at the randomly initialised parameters and the chosen dephasing strength.
print(cfi(params, phi, gamma))

def neg_cfi(_params, _phi, _gamma):
    """Negated CFI, so that minimising this objective maximises the Fisher information."""
    information = cfi(_params, _phi, _gamma)
    return -information

0.7559174938461792


In [93]:
# Text rendering of the noisy sensing circuit.
dmc.draw(output="text")

In [94]:
# Negated CFI so that minimising the objective maximises the Fisher information.
# This cfi also requires the dephasing strength, so ``_gamma`` must be passed
# (omitting it raised a TypeError on the first call).
cfi_val_grad_jit = backend.jit(
    backend.value_and_grad(lambda params: -cfi(_params=params, _phi=phi, _gamma=gamma))
)

In [86]:
# Fresh optimiser and random starting point, created together in this cell.
opt = tc.backend.optimizer(optax.adagrad(learning_rate=0.2))
params = backend.implicit_randn([3 * n, k + 1])

for i in range(250):
    # val is the negated CFI (the objective being minimised)
    val, grad = cfi_val_grad_jit(params)
    params = opt.update(grad, params)
    print(f"Step {i} | CFI {val}")

Step 0 | CFI -0.441705329509956
Step 1 | CFI -1.0312341507063127
Step 2 | CFI -2.4729042134941497
Step 3 | CFI -3.499292325706916
Step 4 | CFI -1.7221474242026285
Step 5 | CFI -2.586935689190976
Step 6 | CFI -5.093449601988809
Step 7 | CFI -6.085003139746293
Step 8 | CFI -5.240416849811725
Step 9 | CFI -4.484116325624765
Step 10 | CFI -5.919639263812128
Step 11 | CFI -7.354064023088309
Step 12 | CFI -8.020980449260312
Step 13 | CFI -10.556941245156962
Step 14 | CFI -7.948263701269647
Step 15 | CFI -4.475623416714635
Step 16 | CFI -11.317801269989971
Step 17 | CFI -12.497689876295496
Step 18 | CFI -11.050902613415928
Step 19 | CFI -10.105636429202734
Step 20 | CFI -11.488079682309
Step 21 | CFI -12.59503152537006
Step 22 | CFI -12.530274726287084
Step 23 | CFI -13.168919090268009
Step 24 | CFI -13.655790344632027
Step 25 | CFI -14.1799688774284
Step 26 | CFI -14.501963850410322
Step 27 | CFI -14.676715131582764
Step 28 | CFI -14.72075044862501
Step 29 | CFI -14.71667281096704
Step 30 | 

In [None]:
# Jitted value-and-gradient of the negated mixed-state CFI w.r.t. the circuit parameters
# (argument 0). phi and gamma are ordinary arguments, so one compiled function serves
# every dephasing strength.
neg_cfi_val_grad_jit = backend.jit(backend.value_and_grad(neg_cfi))


def optimal_information_under_dephasing(gamma, steps=250, learning_rate=0.2):
    """Maximise the CFI about ``phi`` over the circuit parameters at dephasing strength ``gamma``.

    Runs Adagrad for ``steps`` iterations from fresh random parameters of shape
    [3 * n, k + 1] and returns the largest CFI value observed during the run.
    """
    gamma_arr = np.array([gamma])
    optimizer = tc.backend.optimizer(optax.adagrad(learning_rate=learning_rate))
    params_g = backend.implicit_randn([3 * n, k + 1])
    best = 0.0
    for _ in range(steps):
        val, grad = neg_cfi_val_grad_jit(params_g, phi, gamma_arr)
        params_g = optimizer.update(grad, params_g)
        best = max(best, -float(val))  # val is the negated CFI
    return best


# Sweep the dephasing strength; a separate loop name keeps the global ``gamma`` intact.
dephasing_cfi = {}
for gamma_val in np.linspace(0, 1, 10):
    dephasing_cfi[float(gamma_val)] = optimal_information_under_dephasing(gamma_val)
    print(f"gamma {float(gamma_val):.3f} | optimal CFI {dephasing_cfi[float(gamma_val)]}")

In [105]:
import jax


def func(a, b):
    """Toy function a**2 * b, used to check ``jax.jacrev`` with a tuple of argnums."""
    a_squared = a * a
    return a_squared * b


# Expected: d/da (a**2 * b) = 2 * a * b = 0.2 at (a, b) = (0.1, 1.0)
jax.jacrev(func, argnums=(0,))(0.1, 1.0)


(Array(0.2, dtype=float64, weak_type=True),)

In [119]:
# (No ``params`` assignment here: a throw-away one used to overwrite the notebook's
# trained parameters.)
def func(a, b, c):
    """State of one qubit after ``r(theta=a, alpha=b, phi=c)`` is applied to |0>."""
    circuit = tc.Circuit(1)
    circuit.r(0, theta=a, alpha=b, phi=c)
    return circuit.state()


# NOTE(review): the recorded output is a 3-tuple of real arrays that match the real parts
# of d state/d a, d state/d b and d state/d c. So this call appears to differentiate
# w.r.t. all three arguments despite argnums=[0], and to keep only real parts for real
# inputs - confirm against tensorcircuit's ``jacrev`` documentation.
df = backend.jacrev(func, [0])
print(df(1.0, 1.0, 1.0))

(Array([-0.84147098,  0.3825737 ], dtype=float64, weak_type=True), Array([0.       , 0.3825737], dtype=float64, weak_type=True), Array([0.       , 0.3825737], dtype=float64, weak_type=True))
