In [None]:
import numpy as np
import qiskit
from qiskit.quantum_info import state_fidelity
from numpy import linalg as LA
import qib
import matplotlib.pyplot as plt
import scipy
import h5py

import sys
sys.path.append("../../src/brickwall_ansatz")
from utils import construct_heisenberg_local_term, X, I2
from precision import HP as prec
from ansatz import ansatz
import rqcopt as oc
from scipy.sparse.linalg import expm_multiply
from qiskit.quantum_info import random_statevector
from optimize import optimize


def random_unitary_4():
    Z = (np.random.randn(4,4) + 1j*np.random.randn(4,4)) / np.sqrt(2)
    Q, R = np.linalg.qr(Z)
    diag = np.diag(R)
    lambdas = diag / np.abs(diag)
    Q = Q * lambdas
    return Q


N, t = 4, 3
latt = qib.lattice.IntegerLattice((N, ), pbc=True)
field = qib.field.Field(qib.field.ParticleType.QUBIT, latt)
#J, h = (np.random.randn(3,1)[:, 0].real, np.random.randn(3,1)[:, 0].real)
J, h = ((1, -1, 1), (1, 2, -2))
hamil = qib.HeisenbergHamiltonian(field, J, h).as_matrix().todense()
#hamil = np.zeros((2**N, 2**N))

def random_hermitian(n, normalize=True):
    A = np.random.randn(n, n) + 1j * np.random.randn(n, n)
    H = (A + A.conj().T) / 2
    evals = np.linalg.eigvalsh(H)
    norm = np.max(np.abs(evals)) if normalize else 1
    return H / norm

#perms = [[0, 1], [2, 3], [1, 2], [3, 0], [0, 1], [2, 3]]
L = 3
perms = [[0, 1, 2, 3] if i%2==0 else [1, 2, 3, 0] for i in range(L)]
layers = len(perms)
V = lambda : scipy.linalg.expm(-1j*t*random_hermitian(4))
Vlist = [V() for i in range(layers)]
G = ansatz(Vlist, N, perms)
#G = ansatz(Vlist_reduced, N, perms)
#distorted = list(Vlist_reduced+eps*np.array([random_unitary_4() for i in range(layers)]))


errs = []
for _ in range(10):
    #Vlist_reduced = [random_unitary_4() for i in range(layers)]
    Vlist_reduced = [V() for i in range(layers)]
    Vlist_trap, f_iter, err_iter = optimize(N,
                                               G,
                                               #scipy.linalg.expm(-1j*t*hamil),
                                               len(Vlist_reduced), 1,
                                               #list(np.array(Vlist)+eps*np.array(Vlist_reduced))
                                               #list(Vlist+eps*np.array(Vlist_reduced))
                                               Vlist_reduced
                                               , perms, niter=100,
                                               
                                               rho_trust=1e-1, radius_init=0.01, maxradius=0.1,
                                               tcg_abstol=1e-12, tcg_reltol=1e-10, tcg_maxiter=100
                                              )
    errs.append(err_iter[-1])


Real part:  0.3856774907343007
[2.089141076676167, 1.970559753764581, 2.7839380597264127]
6.84363889016716
Real part:  0.3856774907343007
Real part:  0.33706498013164277
Radius  0.02
Real part:  0.33706498013164277
[2.0874182343193897, 1.9826393945549325, 2.7949004385961005]
6.864958067470422
Real part:  0.33706498013164277
Real part:  0.23927784570956906
Radius  0.04
Real part:  0.23927784570956906
[2.0834792757245766, 2.0052522672649893, 2.8153496277828682]
6.904081170772434
Real part:  0.23927784570956906
Real part:  0.04166236369774173
Radius  0.08
Real part:  0.04166236369774173
[2.0744628402546277, 2.0440236568807175, 2.850165188354615]
6.96865168548996
Real part:  0.04166236369774173
Real part:  -0.36000698997463754
Radius  0.1
Real part:  -0.36000698997463754
[2.0616249744066177, 2.0940575790046365, 2.8944809199571515]
7.050163473368405
Real part:  -0.36000698997463754
Real part:  -0.8699699431913458
Radius  0.1
Real part:  -0.8699699431913458
[2.0878104332781806, 2.10619501872

In [None]:
errs

In [75]:
from hessian import ansatz_hessian_matrix
import mpmath as mp

n = len(Vlist_trap)
tol=1e-8
#G = scipy.linalg.expm(-1j*t*hamil)
HM = -ansatz_hessian_matrix(Vlist_trap, N, G, perms, flatten=True, unprojected=False)
HM = np.array(HM, dtype=complex)
evals = np.linalg.eigvalsh(HM)
num_neg  = np.sum(evals < -tol)
num_pos  = np.sum(evals >  tol)
num_zero = np.sum(np.abs(evals) <= tol)
print("min eigenvalue:", evals[0])
print("max eigenvalue:", evals[-1])
print("num_neg, num_zero, num_pos:", num_neg, num_zero, num_pos)

"""mp.mp.dps = 20
A = mp.matrix([list(row) for row in HM])
evals = mp.eig(A, left=False, right=False)
for i, ev in enumerate(evals):
    print(f"λ[{i}] = {mp.nstr(ev, n=20)}")
tol = mp.mpf('1e-10') 
real_evals = []
for ev in evals:
    real_evals.append(mp.re(ev))
num_neg  = sum(ev < -tol for ev in real_evals)
num_pos  = sum(ev >  tol for ev in real_evals)
num_zero = sum(abs(ev) <= tol for ev in real_evals)
print("min eigenvalue:", real_evals[-1])
print("max eigenvalue:", real_evals[0])
print("num_neg, num_zero, num_pos:", num_neg, num_zero, num_pos)
"""

min eigenvalue: -3.4153221122267234e-10
max eigenvalue: 101.28682377826101
num_neg, num_zero, num_pos: 0 49 79


'mp.mp.dps = 20\nA = mp.matrix([list(row) for row in HM])\nevals = mp.eig(A, left=False, right=False)\nfor i, ev in enumerate(evals):\n    print(f"λ[{i}] = {mp.nstr(ev, n=20)}")\ntol = mp.mpf(\'1e-10\') \nreal_evals = []\nfor ev in evals:\n    real_evals.append(mp.re(ev))\nnum_neg  = sum(ev < -tol for ev in real_evals)\nnum_pos  = sum(ev >  tol for ev in real_evals)\nnum_zero = sum(abs(ev) <= tol for ev in real_evals)\nprint("min eigenvalue:", real_evals[-1])\nprint("max eigenvalue:", real_evals[0])\nprint("num_neg, num_zero, num_pos:", num_neg, num_zero, num_pos)\n'

In [36]:
n = len(Vlist)
#G = ansatz(Vlist, N, perms)
HM = -ansatz_hessian_matrix(Vlist, N, G, perms, flatten=True, unprojected=False)
HM = np.array(HM, dtype=complex)
evals = np.linalg.eigvalsh(HM)
num_neg  = np.sum(evals < -tol)
num_pos  = np.sum(evals >  tol)
num_zero = np.sum(np.abs(evals) <= tol)
print("min eigenvalue:", evals[0])
print("max eigenvalue:", evals[-1])
print("num_neg, num_zero, num_pos:", num_neg, num_zero, num_pos)

"""mp.mp.dps = 20
A = mp.matrix([list(row) for row in HM])
evals = mp.eig(A, left=False, right=False)
for i, ev in enumerate(evals):
    print(f"λ[{i}] = {mp.nstr(ev, n=20)}")
tol = mp.mpf('1e-10')
real_evals = []
for ev in evals:
    real_evals.append(mp.re(ev))
num_neg  = sum(ev < -tol for ev in real_evals)
num_pos  = sum(ev >  tol for ev in real_evals)
num_zero = sum(abs(ev) <= tol for ev in real_evals)
print("min eigenvalue:", real_evals[-1])
print("max eigenvalue:", real_evals[0])
print("num_neg, num_zero, num_pos:", num_neg, num_zero, num_pos)"""

min eigenvalue: -2.5313664932949775
max eigenvalue: 3.6070082689936704
num_neg, num_zero, num_pos: 18 2 28


'mp.mp.dps = 20\nA = mp.matrix([list(row) for row in HM])\nevals = mp.eig(A, left=False, right=False)\nfor i, ev in enumerate(evals):\n    print(f"λ[{i}] = {mp.nstr(ev, n=20)}")\ntol = mp.mpf(\'1e-10\')\nreal_evals = []\nfor ev in evals:\n    real_evals.append(mp.re(ev))\nnum_neg  = sum(ev < -tol for ev in real_evals)\nnum_pos  = sum(ev >  tol for ev in real_evals)\nnum_zero = sum(abs(ev) <= tol for ev in real_evals)\nprint("min eigenvalue:", real_evals[-1])\nprint("max eigenvalue:", real_evals[0])\nprint("num_neg, num_zero, num_pos:", num_neg, num_zero, num_pos)'

In [268]:
from utils import (
	applyG_tensor, applyG_block_tensor,
	partial_trace_keep, antisymm_to_real, antisymm, I2, X, Y, Z,
	project_unitary_tangent, real_to_antisymm, partial_trace_keep_tensor
	)


CNOT = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, -1]])
def ansatz(Vlist):
    ret = np.kron(Vlist[1], Vlist[1]) @ CNOT @ np.kron(Vlist[0], Vlist[0])
    return ret

def ansatz_grad(V, U_tilde):    
    U_working1 = U_tilde.copy()
    U_working2 = U_tilde.copy()
    
    U_working1 = np.kron(I2, V) @ U_working1
    U_working2 = np.kron(V, I2) @ U_working2
    T1 = partial_trace_keep(U_working1, [0], 2)
    T2 = partial_trace_keep(U_working2, [1], 2)
    return T1 + T2

def ansatz_grad_vector(Vlist, G, flatten=True, unprojected=False):
    grad = []
    U_tilde1 = G.conj().T @ np.kron(Vlist[1], Vlist[1]) @ CNOT
    U_tilde2 = CNOT @ np.kron(Vlist[0], Vlist[0]) @   G.conj().T
            
    grad.append(ansatz_grad(Vlist[0], U_tilde1).conj().T)
    grad.append(ansatz_grad(Vlist[1], U_tilde2).conj().T)

    if unprojected:
        return grad
    # Project onto tangent space.
    if flatten:
        return np.stack([
            antisymm_to_real(antisymm(Vlist[j].conj().T @ grad[j]))
            for j in range(len(grad))
        ]).reshape(-1)
    else:
        return np.stack([
            antisymm_to_real(antisymm(Vlist[j].conj().T @ grad[j]))
            for j in range(len(grad))
        ])


def ansatz_hessian_matrix(Vlist, G, flatten=True, unprojected=False):
	eta = len(Vlist)
	Hess = np.zeros((eta, 4, eta, 4), dtype=complex)

	for k in range(eta):
		for j in range(4):
			if unprojected:
				Z = np.zeros((2, 2))
				Z.flat[j] = 1.0
			else:
				Z = np.zeros(4)
				Z[j] = 1
				Z = real_to_antisymm(np.reshape(Z, (2, 2)))

			dVZj = ansatz_hess(Vlist, Vlist[k] @ Z, k, G, unprojected=unprojected)
			
			for i in range(eta):
				Hess[i, :, k, j] = dVZj[i].reshape(-1) if unprojected else \
						antisymm_to_real(antisymm( Vlist[i].conj().T @ dVZj[i] )).reshape(-1)

	if flatten:
		return Hess.reshape((eta*4, eta*4))
	else:
		return Hess


def ansatz_hess(Vlist, Z, k, G, unprojected=False):
    dVlist = [None for i in range(len(Vlist))]

    # k = 0, i= 1
    for i in range(k+1, len(Vlist)):
        U_tilde = G.conj().T
        U_tilde = ansatz_grad_directed(Vlist[k], U_tilde, Z)
        U_tilde = CNOT @ U_tilde
        dVi = ansatz_grad(Vlist[i], U_tilde).conj().T
        dVlist[i] = dVi if unprojected else project_unitary_tangent(Vlist[i], dVi)

    # i runs over Vlist, k>i
    for i in range(k):
        U_tilde = CNOT.copy()
        U_tilde = ansatz_grad_directed(Vlist[k], U_tilde, Z)
        U_tilde = G.conj().T @ U_tilde
        dVi = ansatz_grad(Vlist[i], U_tilde).conj().T
        dVlist[i] = dVi if unprojected else project_unitary_tangent(Vlist[i], dVi)
        
	# i=k case.
    i = k
    U_tilde = CNOT.copy() if k==0 else np.eye(4)
    for j in range(i+1, len(Vlist)):
        U_tilde = np.kron(Vlist[j], Vlist[j]) @ U_tilde  
    U_tilde = G.conj().T @ U_tilde
    for j in range(i):
        U_tilde = np.kron(Vlist[j], Vlist[j]) @ U_tilde  
    U_tilde = (CNOT if k==1 else np.eye(4)) @ U_tilde
        
    G_ = ansatz_hess_single_layer(Vlist[k], Z, U_tilde).conj().T
    # Projection.
    if not unprojected:
        V = Vlist[k]
        G_ = project_unitary_tangent(V, G_)
        grad = ansatz_grad(V, U_tilde).conj().T
        G_ -= 0.5 * (Z @ grad.conj().T @ V + V @ grad.conj().T @ Z)
        if not np.allclose(Z, project_unitary_tangent(V, Z)):
            G_ -= 0.5 * (Z @ V.conj().T + V @ Z.conj().T) @ grad
    dVlist[k] = G_
    return dVlist

def ansatz_grad_directed(V, U_tilde, Z):
	G = np.zeros((2**2, 2**2), dtype=complex)
	for i in range(2):
		U = np.eye(2**2)
		for j in range(i):
			U = applyV(V, U, j)
		U = applyV(Z, U, i)
		for j in range(i+1, 2):
			U = applyV(V, U, j)
		G += U
	U_tilde = G @ U_tilde
	return U_tilde

def ansatz_hess_single_layer(V, Z, U_tilde_):
	G = np.zeros_like(V)
	for i in range(2):
		for z in range(2):
			U_tilde = U_tilde_.copy()
			if z==i:
				continue

			for j in range(i):
				if z==j:
					U_tilde = applyV(Z, U_tilde, j)
				else:
					U_tilde = applyV(V, U_tilde, j)

			for j in range(i+1, 2):
				if z==j:
					U_tilde = applyV(Z, U_tilde, j) # Reversed left to right multiplication here!
				else:
					U_tilde = applyV(V, U_tilde, j) # Reversed left to right multiplication here!

			# Take partial trace wrt all qubits but k, l.
			T = partial_trace_keep(U_tilde, [i], 2)
			G += T
	return G


def applyV(V, U, j):
    return np.kron(V, I2) @ U if j==0 else np.kron(I2, V) @ U

In [269]:

def f_(vlist, U):
    f_base = -np.trace(U.conj().T @ ansatz(vlist)).real
    return f_base
    

def grad_numerical(Glist, U, epsilon=1e-6, flatten=True):
    grads = []
    for _ in range(len(Glist)):
        W = Glist[_]
        d = Glist[_].shape[0]
        grad_complex = np.zeros((d, d), dtype=complex)

        for i in range(d):
            for j in range(d):
                # Real perturbation
                dW_real = np.zeros_like(W, dtype=complex)
                dW_real[i, j] = epsilon

                Glist_plus_real  = Glist[:_] + [W + dW_real] + Glist[_+1:]
                Glist_minus_real = Glist[:_] + [W - dW_real] + Glist[_+1:]
                
                f_plus  = f_(Glist_plus_real, U)
                f_minus = f_(Glist_minus_real, U)
                df_real = (f_plus - f_minus) / (2 * epsilon)

                # Imaginary perturbation
                dW_imag = np.zeros_like(W, dtype=complex)
                dW_imag[i, j] = 1j * epsilon

                Glist_plus_imag  = Glist[:_] + [W + dW_imag] + Glist[_+1:]
                Glist_minus_imag = Glist[:_] + [W - dW_imag] + Glist[_+1:]
                
                f_plus  = f_(Glist_plus_imag, U)
                f_minus = f_(Glist_minus_imag, U)
                df_imag = (f_plus - f_minus) / (2 * epsilon)
    
                grad_complex[i, j] = df_real + 1j * df_imag
        grads.append(grad_complex)
    
    stack = np.stack([ antisymm_to_real(antisymm(W.conj().T @ grads[j])) for j, W in enumerate(Glist)])
    if flatten:
        return stack.reshape(-1)
    return stack


np.linalg.norm(grad_numerical(list(Vlist_reduced), U, epsilon=1e-6)+ansatz_grad_vector(list(Vlist_reduced), U))

1.7565780911329416e-10

In [270]:

def numerical_hessian(Glist, U, i, j, epsilon=1e-6):
    """Numerically compute d/dW1 of projected gradient dL/dV1 (Riemannian)."""
    numerical_H = []

    for _ in range(4):
        Z_real = np.zeros(4)
        Z_real[_] = 1.0
        Z = real_to_antisymm(Z_real.reshape(2, 2))

        Gj_plus  = Glist[j] @ scipy.linalg.expm(+epsilon*Z)
        Gj_minus = Glist[j] @ scipy.linalg.expm(-epsilon*Z)

        if i==j:
            grad_plus  = ansatz_grad_vector(Glist[:j]+[Gj_plus]+Glist[j+1:], U, unprojected=True, flatten=False)[i]
            grad_minus = ansatz_grad_vector(Glist[:j]+[Gj_minus]+Glist[j+1:],  U, unprojected=True, flatten=False)[i]
            dgrad = (grad_plus - grad_minus) / (2 * epsilon)  # shape (16,)
            G = dgrad.reshape(2, 2)

            V = Glist[j]
            Z = V @ Z
            G = project_unitary_tangent(V, G)
            grad = ansatz_grad_vector(Glist, U, flatten=False, unprojected=True)[i]
            G -= 0.5 * (Z @ grad.conj().T @ V + V @ grad.conj().T @ Z)
            if not np.allclose(Z, project_unitary_tangent(V, Z)):
                G -= 0.5 * (Z @ V.conj().T + V @ Z.conj().T) @ grad
            G = antisymm_to_real(antisymm( V.conj().T @ G ))
        else:
            grad_plus  = ansatz_grad_vector(Glist[:j]+[Gj_plus]+Glist[j+1:], U, unprojected=False, flatten=False)[i]
            grad_minus = ansatz_grad_vector(Glist[:j]+[Gj_minus]+Glist[j+1:], U, unprojected=False, flatten=False)[i]
            dgrad = (grad_plus - grad_minus) / (2 * epsilon)  # shape (16,)
            G = dgrad.reshape(2, 2)
            
        numerical_H.append(G)
    
    return np.array(numerical_H)  # shape: (4, 4, 4)

i, j = 1, 1
Glist = [random_unitary_2() for _ in range(2)]
H = ansatz_hessian_matrix(Glist, U, unprojected=False, flatten=False)
grad = []
for _ in range(4):
    grad.append(H[i, :, j, _].reshape(2,2))
analytical = np.array(grad)

numerical = numerical_hessian(Glist, U, i, j)

print("Difference norm:", np.linalg.norm(numerical - analytical))

Difference norm: 2.6826186046417294e-10


In [280]:
from rqcopt.trust_region import riemannian_trust_region_optimize
from utils import (polar_decomp, real_to_antisymm, real_to_skew, 
    separability_penalty, grad_separability_penalty, reduce_list)
import scipy

def f_(vlist, U):
    f_base = -np.trace(U.conj().T @ ansatz(vlist)).real
    return f_base
    

def optimize(U, Vlist_start, **kwargs):
    n = len(Vlist_start)
    def f(vlist):
        f_base = f_(vlist, U)
        print("Real part: ", f_base)
        return f_base

    def gradfunc(vlist):
        gradfunc1 = -ansatz_grad_vector(vlist, U, flatten=False).real
        magn = [np.linalg.norm(G, ord=2) for G in gradfunc1]
        print(magn)
        print(np.sum(np.array(magn)))
        return gradfunc1.reshape(-1)

    def hessfunc(vlist):
        hessfunc1 = -ansatz_hessian_matrix(vlist, U, flatten=False).real
        return hessfunc1.reshape((n*4, n*4))
        #return np.zeros((n*4, n*4))

    def errfunc(vlist): 
        return f(vlist)

    kwargs["gfunc"] = errfunc
    Vlist, f_iter, err_iter = riemannian_trust_region_optimize(
        f, retract_unitary_list, gradfunc, hessfunc, np.stack(Vlist_start), **kwargs)
    return Vlist, f_iter, err_iter


def retract_unitary_list(vlist, eta):
    n = len(vlist)
    eta = np.reshape(eta, (n, 2, 2))
    dvlist = [vlist[j] @ real_to_antisymm(eta[j]) for j in range(n)]
    return np.stack([polar_decomp(vlist[j] + dvlist[j])[0] for j in range(n)])

In [307]:
def random_unitary_4():
    Z = (np.random.randn(4,4) + 1j*np.random.randn(4,4)) / np.sqrt(2)
    Q, R = np.linalg.qr(Z)
    diag = np.diag(R)
    lambdas = diag / np.abs(diag)
    Q = Q * lambdas
    return Q

def random_unitary_2():
    Z = (np.random.randn(2,2) + 1j*np.random.randn(2,2)) / np.sqrt(2)
    Q, R = np.linalg.qr(Z)
    diag = np.diag(R)
    lambdas = diag / np.abs(diag)
    Q = Q * lambdas
    return Q

#U = random_unitary_4()
errs = []
for _ in range(1):
    Vlist_reduced = [random_unitary_2(), random_unitary_2()]
    #U = ansatz(Vlist_reduced)
    Vlist_trap, f_iter, err_iter = optimize(U, Vlist_reduced, niter=100,
                                               rho_trust=1e-1, radius_init=0.01, maxradius=0.1,
                                               tcg_abstol=1e-12, tcg_reltol=1e-10, tcg_maxiter=100
                                              )
    errs.append(err_iter[-1])

Real part:  -0.7724675248632049
[1.0223967231144382, 1.4664561184160845]
2.4888528415305227
Real part:  -0.7724675248632049
Real part:  -0.7926302188488372
Real part:  -0.7926302188488372
[1.014641810750247, 1.4498483368579993]
2.464490147608246
Real part:  -0.7926302188488372
Real part:  -0.8322544659605999
Real part:  -0.8322544659605999
[0.9979380659901527, 1.414888527626892]
2.4128265936170448
Real part:  -0.8322544659605999
Real part:  -0.9085158533290703
Real part:  -0.9085158533290703
[0.959640436262075, 1.3382604007092769]
2.297900836971352
Real part:  -0.9085158533290703
Real part:  -1.0478217831053231
Real part:  -1.0478217831053231
[0.8647095739869105, 1.1630252190046462]
2.0277347929915566
Real part:  -1.0478217831053231
Real part:  -1.1963144138902704
Real part:  -1.1963144138902704
[0.7322835979509451, 0.9456267323556438]
1.677910330306589
Real part:  -1.1963144138902704
Real part:  -1.319670912962445
Real part:  -1.319670912962445
[0.6378165335639506, 0.8288911421101037]

In [308]:
tol=1e-8
HM = -ansatz_hessian_matrix(Vlist_trap, U, flatten=True, unprojected=False).real
HM = np.array(HM, dtype=complex)
evals = np.linalg.eigvalsh(HM)
num_neg  = np.sum(evals < -tol)
num_pos  = np.sum(evals >  tol)
num_zero = np.sum(np.abs(evals) <= tol)
print("min eigenvalue:", evals[0])
print("max eigenvalue:", evals[-1])
print("num_neg, num_zero, num_pos:", num_neg, num_zero, num_pos)
evals

min eigenvalue: -4.271004124110731e-16
max eigenvalue: 8.000000000000004
num_neg, num_zero, num_pos: 0 3 5


array([-4.27100412e-16,  4.16168549e-16,  1.27512487e-15,  4.00000000e+00,
        4.00000000e+00,  4.00000000e+00,  4.00000000e+00,  8.00000000e+00])

In [289]:
from sympy import *
import numpy as np


t1, p1, l1, a1, t2, p2, l2, a2 = symbols('t1 p1 l1 a1 t2 p2 l2 a2', real=True)
θ = (t1, p1, l1, a1, t2, p2, l2, a2)

def U2(theta, phi, lam, alpha):
    c = cos(theta/2)
    s = sin(theta/2)
    eip   = exp(I*phi)
    eilam = exp(I*lam)
    eial  = exp(I*alpha)

    su2 = Matrix([
        [      c,         -eilam*s],
        [ eip*s,  eip*eilam*c    ]
    ])
    return eial * su2   # full U(2); drop eial for SU(2)

def W(theta_params):
    t1, p1, l1, a1, t2, p2, l2, a2 = theta_params
    A1 = U2(t1, p1, l1, a1)
    A2 = U2(t2, p2, l2, a2)
    K1 = kronecker_product(A1, A1)
    K2 = kronecker_product(A2, A2)
    return K2 * CNOT * K1

U_sym = Matrix(U)
f_complex = -trace(U_sym.H * W(θ))
f_real = re(f_complex).expand(complex=True)

params = [t1, p1, l1, a1, t2, p2, l2, a2]
grad = [f_real.diff(p) for p in params]

sol = nsolve(grad, params, [0.1]*8, tol=1e-8)
print(sol)


KeyboardInterrupt

