In [9]:
import numpy as np
from scipy.stats import unitary_group
from scipy.linalg import eig, expm
from scipy.optimize import minimize
from functools import reduce 
from xmps.spin import U4
from xmps.tensor import rotate_to_hermitian

# tm_left:
"""
                   0      0      j
                   | -U2- |      |
A[i, σ1,σ2, j] =   |      | -U1- |
                   |      |      |
                   i      σ1      σ2
                   
"""

# tm_right:
"""
                   i      0      0
                   |      | -U2- |
A[i, σ1,σ2, j] =   | -U1- |      |
                   |      |      |
                   σ1     σ2     j
"""

Z = np.array([
    [1,0],
    [0,-1]
])

X = np.array([
    [0,1],
    [1,0]
])


I = np.eye(2)

def tensor(tensors):
    return reduce(lambda t1,t2: np.kron(t1,t2), tensors)

# Entangling Gate HcX to put the right environment directly in
E = np.array([
    [1, 0, 0, 1],
    [0, 1, 1, 0],
    [0, 1, -1,0],
    [1, 0, 0,-1]
]).reshape(2,2,2,2) / np.sqrt(2)

In [20]:
def state(U2,U1):
    return np.einsum(
        U1, [1,2,4,5],
        U2, [0,4,6,7],
        E, [5,3,8,9],
        [0,1,2,3,6,7,8,9]
    )

def Map(U2, U1):
    """
    A[i,σ1,σ2,j] = Σα U1[σ1,σ2,α,j]U2[i,α,0,0]
    """
    return np.transpose(np.tensordot(
        U1.reshape(2,2,2,2),
        U2[:,0].reshape(2,2),
        (2,1)
    ),(3,0,1,2)).reshape(2,4,2)

def single_transfer_matrix(s, s̄):
    s̄ = s̄.conj()
    ss̄ = np.transpose(np.tensordot(s̄,s,(1,1)),(0,2,1,3)).reshape(4,4)
    return ss̄

def right_env(U2,U1,Ū2,Ū1):
    s = Map(U2, U1)
    s̄ = Map(Ū2, Ū1)
    tm = single_transfer_matrix(s, s̄)
    η, v = eig(tm)
    vr = v[:,np.argmax(np.abs(η))].reshape(2,2)
    vr = rotate_to_hermitian(vr) / np.sign(vr[0,0])
    return vr, η

def left_env(U2,U1,Ū2,Ū1):
    s = Map(U2, U1)
    s̄ = Map(Ū2, Ū1)
    tm = single_transfer_matrix(s, s̄)
    η, v = eig(tm.T)
    vl = v[:,np.argmax(np.abs(η))].reshape(2,2)
    vl = rotate_to_hermitian(vl) / np.sign(vl[0,0])
    return vl, η

def overlap(U2,U1,Ū2,Ū1, Ut):
    s = state(U2.reshape(2,2,2,2), U1.reshape(2,2,2,2)).reshape(16,16)
    s̄ = state(Ū2.reshape(2,2,2,2), Ū1.reshape(2,2,2,2)).reshape(16,16)
    vl, _ = left_env(U2,U1,Ū2,Ū1)
    vr, _ = right_env(U2,U1,Ū2,Ū1)
    
    middle = tensor([vl,Ut,vr])
    
    unitary = s̄.conj().T @ middle @ s
    o = (2*np.abs(unitary[0,0]))**2
    return o

U2 = unitary_group.rvs(4)
U1 = unitary_group.rvs(4)

s = state(U2.reshape(2,2,2,2), U1.reshape(2,2,2,2)).reshape(16,16)
assert np.allclose(s @ s.conj().T, np.eye(16))

o = overlap(U2, U1, U2, U1, np.eye(4))
print(o)

0.5852026780464697


In [21]:
right_env(U2, U1, U2, U1)

(array([[ 0.5702766 +0.00000000e+00j, -0.43357706+1.36191630e-01j],
        [-0.43357705-1.36191630e-01j,  0.51157614-7.95521652e-25j]]),
 array([ 1.00000000e+00+1.38642701e-24j, -8.52176614e-14+1.06914537e-12j,
         8.52516132e-14-1.06911516e-12j, -4.27268806e-17-3.02172034e-17j]))

In [22]:
left_env(U2, U1, U2, U1)

(array([[ 7.07106781e-01+0.00000000e+00j, -9.32242394e-17-1.22050156e-16j],
        [-9.32242397e-17+1.22050156e-16j,  7.07106781e-01+7.21292788e-25j]]),
 array([ 1.00000000e+00-1.95894453e-25j,  4.92057152e-17+6.24227575e-17j,
        -2.63016617e-17-4.76042718e-17j, -1.17190495e-16-1.48184858e-17j]))