In [16]:
import numpy as np
from functools import reduce
from scipy.stats import unitary_group

def depth3tensor(U3,U2,U1):
    """
    Build the depth-3 tensor A[(σ1σ2),(i0i1),(j0j1)] out of three 4x4 gates.

    Every gate is viewed as a rank-4 tensor with legs of dimension 2, and the
    gates are chained as

        A[(σ1σ2),(i0i1),(j0j1)] = Σαβ U1[σ1,σ2,α,j1] U2[i1,α,β,j0] U3[i0,β,0,0]

    The last two legs of U3 are fixed to 0.  The result has shape (4, 4, 4).
    """
    gate2 = U2.reshape(2, 2, 2, 2)
    gate3 = U3.reshape(2, 2, 2, 2)

    # Sum over β (leg 2 of U2 against leg 1 of U3).
    # Legs afterwards: (i1, α, j0, i0, 0, 0)
    upper = np.tensordot(gate2, gate3, axes=(2, 1))

    gate1 = U1.reshape(2, 2, 2, 2)

    # Sum over α (leg 2 of U1 against leg 1 of the partial contraction).
    # Legs afterwards: (σ1, σ2, j1, i1, j0, i0, 0, 0)
    chained = np.tensordot(gate1, upper, axes=(2, 1))

    # Bring the legs into the order (σ1, σ2, i0, i1, j0, j1, 0, 0) and fuse
    # them pairwise into four axes of dimension 4.
    fused = np.transpose(chained, (0, 1, 5, 3, 4, 2, 6, 7)).reshape(4, 4, 4, 4)

    # Keep only the slice where the two trailing legs of U3 are both 0.
    return fused[:, :, :, 0]

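# Input order is [U_N, ..., U_2, U_1]; gates are contracted pairwise from the front:
# U_{N-1} with U_N first, then U_{N-2} with that result, and so on down to U_1.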
def depthNtensor(N_Us):
    """
    Contract N 4x4 gates into a single tensor A[(σ1σ2), i, j].

    This is the arbitrary-depth version of ``depth3tensor``: every gate is
    viewed as a rank-4 tensor with legs of dimension 2, and leg 2 of each gate
    is summed against leg 1 of the gate that precedes it in ``N_Us``.

    Parameters
    ----------
    N_Us : iterable of array_like
        The gates in the order [U_N, ..., U_2, U_1].  Each must contain exactly
        16 elements.  Any iterable is accepted (list, tuple, generator, ...).

    Returns
    -------
    numpy.ndarray
        Array of shape (4, d, d) with d = 2**(N-1):
        axis 0 fuses legs 0 and 1 of U_1 (the physical pair σ1σ2);
        axis 1 fuses leg 0 of U_N, U_{N-1}, ..., U_2 (in that order);
        axis 2 fuses leg 3 of U_{N-1}, ..., U_2, U_1 (in that order).
        Legs 2 and 3 of U_N are fixed to 0.

    Raises
    ------
    TypeError
        If ``N_Us`` is empty.
    ValueError
        If a gate cannot be reshaped to (2, 2, 2, 2).
    """
    # Materialise once: this lets len() work for generators and avoids
    # rebinding the argument to a one-shot iterator.
    gates = [np.asarray(U).reshape(2, 2, 2, 2) for U in N_Us]

    # depth N
    N = len(gates)
    if N == 0:
        # TypeError is kept on purpose: it is what the bare reduce() call
        # raised for an empty input before this explicit check existed.
        raise TypeError("depthNtensor() needs at least one gate, got an empty sequence")

    # number of indices left after the N-1 contractions
    ni = 2 + 2 * N

    # bond dimension
    d = 2 ** (N - 1)

    # iₐ indices: odd positions ni-3, ..., 5, 3 (leg 0 of U_N, ..., U_2)
    i_indices = list(range(ni - 3, 2, -2))

    # jₐ indices: even positions ni-4, ..., 4, 2 (leg 3 of U_{N-1}, ..., U_1)
    j_indices = list(range(ni - 4, 1, -2))

    # physical indices [0, 1], then the auxiliary indices, then the two
    # trailing indices that are set to 0
    indices = [0, 1] + i_indices + j_indices + [ni - 2, ni - 1]

    # Tensordot all the tensors together in the order:
    #    1: U_{N-1}, U_N = A1
    #    2: U_{N-2}, A1  = A2
    #    3: U_{N-3}, A2  = A3, ...
    contracted = reduce(lambda acc, gate: np.tensordot(gate, acc, [2, 1]), gates)

    # Reorder the legs as specified by `indices`, fuse them into
    # (4, d, d, 4) and keep the slice where the trailing pair is (0, 0).
    return contracted.transpose(indices).reshape(4, d, d, 4)[..., 0]
    

# Draw three random 4x4 unitaries from one seeded generator so that a fresh
# Restart & Run All reproduces exactly the same matrices.  A shared
# RandomState (rather than the same integer seed per call) keeps the three
# draws distinct.
_rng = np.random.RandomState(0)
U1 = unitary_group.rvs(4, random_state=_rng)
U2 = unitary_group.rvs(4, random_state=_rng)
U3 = unitary_group.rvs(4, random_state=_rng)

def tensor(U2, U1):
    """
    Build the depth-2 tensor A[(σ1σ2),i,j] out of two 4x4 gates:

        A[σ1σ2,i,j] = Σα U1[σ1,σ2,α,j] U2[i,α,0,0]

    The result has shape (4, 2, 2).
    """
    gate1 = U1.reshape(2, 2, 2, 2)

    # Column 0 of U2 holds U2[(i α),(0 0)]; view it as a matrix indexed (i, α).
    first_column = U2[:, 0].reshape(2, 2)

    # Sum over α; the legs come out in the order (σ1, σ2, j, i).
    contracted = np.tensordot(gate1, first_column, axes=(2, 1))

    # Swap the last two legs to get (σ1, σ2, i, j), then fuse σ1σ2.
    return contracted.swapaxes(2, 3).reshape(4, 2, 2)


# Consistency check: the general construction at depth 2 must reproduce the
# hand-written depth-2 tensor.  Assert it so a regression fails loudly instead
# of relying on someone reading the printed difference.
np.testing.assert_allclose(depthNtensor([U2, U1]), tensor(U2, U1), atol=1e-12)
print(tensor(U2,U1) - depthNtensor([U2,U1]))

[[[0.+0.j 0.+0.j]
  [0.+0.j 0.+0.j]]

 [[0.+0.j 0.+0.j]
  [0.+0.j 0.+0.j]]

 [[0.+0.j 0.+0.j]
  [0.+0.j 0.+0.j]]

 [[0.+0.j 0.+0.j]
  [0.+0.j 0.+0.j]]]


In [7]:
# Scratch check of the (un-reversed) j-index range used in depthNtensor,
# range(2, ni-3, 2), evaluated for ni = 10 (i.e. N = 4 gates).
list(range(2,10-3,2))

[2, 4, 6]