In [10]:
import jax.numpy as jnp
from jax.scipy.linalg import expm
from scipy.stats import unitary_group
from jax import random
from matplotlib import pyplot as plt

from optimization import *
from matrix_utils import *
from gates import *
from jax_circuits import *
from topology import *

import qiskit as qs

Introduce a simple parameterization of a generic $3\times 3$ unitary matrix through its Hermitian logarithm: $U = e^{iH}$, where the Hermitian matrix $H$ is specified by 9 real parameters.

In [11]:
def h3(h):
    """Hermitian 3x3 matrix built from 9 real parameters.

    Layout of ``h``: the diagonal is (h[0], h[5], h[8]); the upper off-diagonal
    entries are h[1] + i*h[2] at (0, 1), h[3] + i*h[4] at (0, 2) and
    h[6] + i*h[7] at (1, 2); the lower triangle is their complex conjugate.
    Returns a complex64 array.
    """
    # Real part: symmetric matrix.
    symmetric = jnp.array([[h[0], h[1], h[3]],
                           [h[1], h[5], h[6]],
                           [h[3], h[6], h[8]]])
    # Imaginary part: antisymmetric matrix, so the sum is Hermitian by construction.
    upper = jnp.array([[0., h[2], h[4]],
                       [0., 0., h[7]],
                       [0., 0., 0.]])
    antisymmetric = upper - upper.T
    return (symmetric + 1j*antisymmetric).astype(jnp.complex64)


def u3(h):
    """Unitary exp(i*H) with H = h3(h); it is unitary because H is Hermitian."""
    hermitian = h3(h)
    return expm(1j*hermitian)

Let's check that a random unitary can be learned in this parametrization.

In [12]:
# Random 3x3 target unitary (fixed seed for reproducibility).
u_target = unitary_group.rvs(3, random_state=131)

# Fit the 9-parameter u3 parameterization to the target in 10 repeated runs.
# NOTE(review): assumes each result dict keeps its loss history under 'loss'.
res = unitary_learn(u3, u_target, 9, num_repeats=10)

# Fraction of runs whose last recorded loss is below 1e-5.
success_hist = [run['loss'][-1] < 1e-5 for run in res]
print(sum(success_hist) / len(success_hist))

100%|██████████| 10/10 [00:17<00:00,  1.76s/it]

1.0





Introduce qutrit gates.

In [13]:
# Qutrit operators used below. All of them are diagonal in the computational
# basis |0>, |1>, |2>, so each one is written down through its diagonal.
def _diag_c64(*entries):
    """Diagonal complex64 matrix with the given entries on its diagonal."""
    return jnp.diag(jnp.array(entries, dtype=jnp.complex64))


# Projectors onto the single levels |0>, |1>, |2>.
P0 = _diag_c64(1, 0, 0)
P1 = _diag_c64(0, 1, 0)
P2 = _diag_c64(0, 0, 1)

# Pauli Z restricted to span(|0>, |1>); level |2> is left untouched.
Z01 = _diag_c64(1, -1, 1)

# Single-qutrit identity.
Id3 = _diag_c64(1, 1, 1)

# CZ between levels 0,1 of the first qutrit and 0,1 of the second: Z01 acts on
# the second qutrit exactly when the first one is in |1> (Id3 - P1 == P0 + P2).
CZ = jnp.kron(Id3 - P1, Id3) + jnp.kron(P1, Z01)


def U3_CP(phi):
    """Controlled-phase gate acting in the same subspace as CZ.

    Only the basis state |1>|1> picks up a phase, exp(-i*phi/2); every other
    basis state (including all states involving level 2) is unchanged.

    NOTE(review): Qiskit's CPhaseGate(phi) puts exp(+i*phi) on |11>, while this
    gate uses exp(-i*phi/2), the same |11> phase as U3_RZZ. Confirm which
    convention is intended before relying on this gate.
    """
    phase_11 = jnp.exp(-1j*phi/2)
    p11 = jnp.kron(P1, P1)
    return jnp.kron(Id3, Id3) - p11 + phase_11*p11


def U3_RZZ(phi):
    """RZZ(phi) = exp(-i*phi/2 * kron(Z, Z)) on levels 0,1 of both qutrits.

    |00> and |11> pick up exp(-i*phi/2), and |01> and |10> pick up
    exp(+i*phi/2). Every basis state with at least one qutrit in |2> is
    left unchanged.
    """
    same_parity = jnp.kron(P0, P0) + jnp.kron(P1, P1)
    opposite_parity = jnp.kron(P0, P1) + jnp.kron(P1, P0)
    # Identity minus the qubit-subspace projector: all states touching level 2.
    leaked = jnp.kron(Id3, Id3) - jnp.kron(P0 + P1, P0 + P1)
    return jnp.exp(-1j*phi/2)*same_parity + jnp.exp(1j*phi/2)*opposite_parity + leaked

In [14]:
# Sanity check for the projector-based construction of the qutrit gates:
# U2_RZZ builds the ordinary two-qubit RZZ gate from projectors in the same way,
# and U2_RZZ_qs builds it with Qiskit, so the two can be compared numerically.
# I checked and they agree.

def U2_RZZ_qs(phi):
    """Two-qubit RZZ(phi) unitary obtained from a Qiskit circuit.

    ``reverse_bits`` flips Qiskit's qubit ordering so that qubit 0 corresponds
    to the first ``kron`` factor, as in the projector-based construction.
    """
    # Imported locally: `Operator` is not in the notebook's explicit imports and
    # would otherwise have to leak in through one of the star imports.
    from qiskit.quantum_info import Operator

    qc = qs.QuantumCircuit(2)
    qc.rzz(phi, 0, 1)
    return Operator(qc.reverse_bits()).data


def U2_RZZ(phi):
    """Two-qubit RZZ(phi) = exp(-i*phi/2 * kron(Z, Z)), built from projectors.

    |00> and |11> pick up exp(-i*phi/2); |01> and |10> pick up exp(+i*phi/2).
    """
    p0 = jnp.array([[1, 0], [0, 0]])
    p1 = jnp.array([[0, 0], [0, 1]])
    return (jnp.exp(-1j*phi/2)*jnp.kron(p0, p0) + jnp.exp(1j*phi/2)*jnp.kron(p0, p1)
            + jnp.exp(1j*phi/2)*jnp.kron(p1, p0) + jnp.exp(-1j*phi/2)*jnp.kron(p1, p1))

Here is the ansatz, which consists of interleaved CZ gates and arbitrary single-qutrit gates. $d$ is the number of CZ gates; the number of single-qutrit gates is $2(d+1)$.

In [15]:
def qutrit_anz(d):
    """Build the two-qutrit ansatz with ``d`` CZ gates.

    The circuit starts with a layer of two arbitrary single-qutrit gates
    (``u3`` on each qutrit). It then repeats ``d`` times: apply ``CZ``, followed
    by a fresh layer of two single-qutrit gates. In total it uses ``2*(d+1)``
    single-qutrit gates and ``9*2*(d+1)`` real parameters.

    Args:
        d: number of CZ gates.

    Returns:
        A function ``u_anz(h)`` mapping a flat parameter vector of length
        ``18*(d+1)`` to the 9x9 matrix of the circuit.
    """
    # Explicit import: `lax` is not among the notebook's explicit imports, and
    # relying on a star import to provide it is fragile.
    from jax import lax

    def u_anz(h):
        # h[i, q] holds the 9 parameters of the gate on qutrit q in layer i.
        h = h.reshape(d+1, 2, 9)
        u = jnp.kron(u3(h[0, 0]), u3(h[0, 1]))

        def apply_layer(i, u):
            # CZ followed by the i-th layer of single-qutrit gates.
            u = CZ @ u
            return jnp.kron(u3(h[i, 0]), u3(h[i, 1])) @ u

        # With d == 0 there are no CZ layers; skip tracing the loop entirely.
        if d > 0:
            u = lax.fori_loop(1, d+1, apply_layer, u)

        return u

    return u_anz

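Let's check that the ansatz with $d=2$ CZ gates can learn a tensor product of two random single-qubit unitaries, each embedded into levels 0,1 of a qutrit (level 2 is left untouched).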

In [16]:
def _embed_in_qutrit(u):
    """Embed a 2x2 unitary into levels 0,1 of a qutrit, acting trivially on level 2."""
    return jnp.block([[u, jnp.zeros((2, 1))],
                      [jnp.zeros((1, 2)), jnp.ones((1, 1))]])


# Two random single-qubit unitaries (fixed seeds), lifted to qutrit operators.
u0 = _embed_in_qutrit(unitary_group.rvs(2, random_state=13))
u1 = _embed_in_qutrit(unitary_group.rvs(2, random_state=14))

u01 = jnp.kron(u0, u1)
# Unitarity check of the embedded tensor product.
jnp.allclose(jnp.identity(9), u01 @ u01.conj().T, atol=1e-5)

DeviceArray(True, dtype=bool)

In [17]:
d = 2
u_target = u01


def loss_f(h):
    """Squared Frobenius distance between the d-CZ ansatz and the target."""
    return (jnp.abs(qutrit_anz(d)(h) - u_target)**2).sum()


num_params = (d+1)*2*9
res = mynimize_repeated(loss_f, num_params, num_repeats=10)
# A run counts as successful if its last recorded loss is below 1e-5.
success_hist = [run['loss'][-1] < 1e-5 for run in res]
print(sum(success_hist)/len(success_hist))

100%|██████████| 10/10 [00:34<00:00,  3.45s/it]

0.4





Note that the problem of local minima is present already in this simple case: only 4 out of 10 runs reached a final loss below $10^{-5}$.

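Now the main question: can $RZZ(\phi)$, acting on levels 0,1 of both qutrits, be implemented with a single CZ gate ($d=1$)? The loss only compares the two unitaries on input states from the qubit subspace, i.e. on the columns selected by the projector $(P_0+P_1)\otimes(P_0+P_1)$.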

In [175]:
d = 1
phi = 0.25*jnp.pi
u_target = U3_RZZ(phi)

# Projector onto the qubit subspace: levels 0,1 of both qutrits.
qubit_subspace = jnp.kron(P0+P1, P0+P1)


def loss_f(h):
    """Squared Frobenius distance between ansatz and target, restricted (by
    right-multiplication with the projector) to inputs from the qubit subspace."""
    restricted_diff = (qutrit_anz(d)(h) - u_target) @ qubit_subspace
    return (jnp.abs(restricted_diff)**2).sum()


num_params = (d+1)*2*9
# 50 initial parameter vectors from PRNG keys 100..149.
initial_params = [random_angles(num_params, key=random.PRNGKey(seed)) for seed in range(100, 150)]

res = mynimize_repeated(loss_f, num_params, learning_rate=0.01, initial_params_batch=initial_params)
# A run counts as successful if its loss ever dropped below 1e-4.
success_hist = [jnp.min(run['loss']) < 1e-4 for run in res]
print(sum(success_hist)/len(success_hist))

100%|██████████| 50/50 [03:35<00:00,  4.30s/it]

0.0





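The learning was not successful in any of the 50 attempts. As a sanity check, an $RZZ$ gate is learned without problems with 2 CZ gates (note that this check uses $\phi = 0.15\pi$ rather than the $\phi = 0.25\pi$ used above).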

In [18]:
d = 2
phi = 0.15*jnp.pi
u_target = U3_RZZ(phi)

# Projector onto the qubit subspace: levels 0,1 of both qutrits.
qubit_subspace = jnp.kron(P0+P1, P0+P1)


def loss_f(h):
    """Squared Frobenius distance between ansatz and target, restricted (by
    right-multiplication with the projector) to inputs from the qubit subspace."""
    restricted_diff = (qutrit_anz(d)(h) - u_target) @ qubit_subspace
    return (jnp.abs(restricted_diff)**2).sum()


num_params = (d+1)*2*9
# 10 initial parameter vectors from PRNG keys 0..9.
initial_params = [random_angles(num_params, key=random.PRNGKey(seed)) for seed in range(10)]

res = mynimize_repeated(loss_f, num_params, learning_rate=0.01, initial_params_batch=initial_params)
# A run counts as successful if its loss ever dropped below 1e-4.
success_hist = [jnp.min(run['loss']) < 1e-4 for run in res]
print(sum(success_hist)/len(success_hist))

100%|██████████| 10/10 [00:44<00:00,  4.47s/it]

0.9



