In [1]:
import numpy as np
from scipy.linalg import schur, eigh,expm
from time import time
from randdiag import *

In [2]:
def Generate_M():
    """Return a random 4x4 complex Hermitian matrix.

    A 4x4 matrix with independent standard-normal real and imaginary parts
    is added to its own conjugate transpose and the sum is divided by
    4*sqrt(2).

    Returns:
        numpy.ndarray: complex array of shape (4, 4) equal to its conjugate
        transpose.
    """
    dim = 4
    # Draw order (real part first, then imaginary) matters for seeded runs.
    re_part = np.random.normal(size=[dim, dim])
    im_part = np.random.normal(size=[dim, dim])
    gaussian = re_part + im_part * 1j
    hermitian_sum = gaussian + gaussian.T.conjugate()
    return hermitian_sum / (4 * np.sqrt(2))

def Generate_Haar(n):
    """Draw an n x n unitary matrix distributed according to the Haar measure.

    A complex Gaussian matrix (independent standard-normal real and imaginary
    parts) is QR-factorized. The QR factorization is only unique up to a
    diagonal unitary, and the phase convention picked by the underlying
    LAPACK routine makes the raw Q factor non-Haar. Multiplying each column
    of Q by the phase of the matching diagonal entry of R fixes the
    convention to "R has a positive real diagonal", which yields a
    Haar-distributed unitary.

    Args:
        n: dimension of the unitary.

    Returns:
        numpy.ndarray: complex unitary array of shape (n, n).
    """
    # Draw order (real part first, then imaginary) is kept for seeded runs.
    Lambda_real = np.random.normal(size=[n, n])
    Lambda_im = np.random.normal(size=[n, n])
    Lambda = Lambda_real + Lambda_im * 1j
    Q, R = np.linalg.qr(Lambda)
    # Unit-modulus phases of diag(R); the diagonal is nonzero with probability 1.
    r_diag = np.diagonal(R)
    phases = r_diag / np.abs(r_diag)
    # Broadcasting over the last axis scales column k of Q by phases[k].
    return Q * phases

def Generate_U(j,L):
    """Embed a random 4x4 unitary gate into a 2**L-dimensional space.

    The gate is expm(1j * M) with M drawn from Generate_M(). It is padded
    with an identity of size 2**(j-1) on the left and an identity of size
    2**(L-j-1) on the right via Kronecker products, so the result is
    2**L x 2**L.

    Args:
        j: position of the gate; both exponents above must be non-negative,
            i.e. 1 <= j <= L-1.
        L: total number of two-level sites.

    Returns:
        numpy.ndarray: complex array of shape (2**L, 2**L).
    """
    gate = expm(1j * Generate_M())
    pad_before = np.eye(2**(j-1))
    pad_after = np.eye(2**(L-j-1))
    embedded = np.kron(pad_before, gate)
    return np.kron(embedded, pad_after)
def offdiagonal_frobenius(A):
    """Return the Frobenius norm of the off-diagonal part of a square matrix.

    Args:
        A: square 2-D array.

    Returns:
        The Frobenius norm of A with its main diagonal subtracted out; zero
        exactly when A is diagonal.
    """
    diagonal_part = np.diag(np.diagonal(A))
    residual = A - diagonal_part
    return np.linalg.norm(residual, 'fro')

In [3]:
# Seed NumPy's global RNG (used by the np.random calls in this notebook) so a
# fresh re-run produces the same test matrix and the same error statistics.
# Run times will of course still vary between runs.
SEED = 0
np.random.seed(SEED)

L = 11  # number of two-level sites; all matrices below are 2**L x 2**L

# Tensor product of L independent random 2x2 unitaries.
U_0 = Generate_Haar(2)
for _ in range(L - 1):
    U_0 = np.kron(U_0, Generate_Haar(2))

# Product of one random two-site gate per position j = 1..L-1, multiplied
# together in a random order.
permuted = np.random.permutation(range(1, L))
U_int = np.eye(2**L)
for j in permuted:
    U_int = U_int @ Generate_U(j, L)
U = U_int @ U_0
print(U.shape)
repeats = 100

# --- Baseline: complex Schur decomposition --------------------------------
# schur() involves no randomness, so repeating it on the same U only averages
# the run time; the error is the same on every repeat.
rt_schur = 0
err_schur = []
for _ in range(repeats):
    start = time()
    T, Z = schur(U, 'complex')
    rt_schur += time() - start
    # The error is measured outside the timed region.
    err_schur.append(offdiagonal_frobenius(Z.conj().T @ U @ Z))
# report_stats comes from randdiag; presumably returns (mean, std, min, max)
# of the list -- confirm against that module.
mean_schur, std_schur, min_schur, max_schur = report_stats(err_schur)
print("Schur:\n Run time {:.2f}, Mean: {:.2e}, Std: {:.2e}, Min: {:.2e}, Max: {:.2e}".format(
    rt_schur / repeats, mean_schur, std_schur, min_schur, max_schur))

# --- Randomized diagonalization -------------------------------------------
# Split U into its Hermitian part H and skew-Hermitian part S (so 1j*S is
# Hermitian), take a random real combination of H and 1j*S, and use the
# eigenvectors of that single Hermitian matrix as the diagonalizing basis.
rt_rjd = 0
err_rjd = []
for _ in range(repeats):
    start = time()
    H = (U + U.conj().T) / 2
    S = (U - U.conj().T) / 2
    mu = np.random.normal(0, 1, 2)
    A_mu = mu[0] * H + mu[1] * 1j * S
    _, Q = eigh(A_mu)
    rt_rjd += time() - start
    err_rjd.append(offdiagonal_frobenius(Q.conj().T @ U @ Q))
# The stacked pair is not used by the algorithm above. It was previously built
# inside the timed loop on every repeat, which inflated the measured run time.
# It is built once here only so the name AA still exists for any later cell
# that may rely on it.
AA = np.array([H, 1j * S])
mean_rjd, std_rjd, min_rjd, max_rjd = report_stats(err_rjd)
print("RandDiag:\n Run time {:.2f}, Mean: {:.2e}, Std: {:.2e}, Min: {:.2e}, Max: {:.2e}".format(
    rt_rjd / repeats, mean_rjd, std_rjd, min_rjd, max_rjd))


(2048, 2048)
Schur:
 Rum time 19.57, Mean: 8.86e-13, Std: 4.04e-28, Min: 8.86e-13, Max: 8.86e-13
RandDiag:
 Rum time 4.14, Mean: 2.44e-09, Std: 7.51e-09, Min: 1.26e-10, Max: 5.83e-08
