In [None]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

import scipy.integrate as si
import scipy.optimize as so
import scipy.linalg as sl

import time

from numba import njit
from tqdm import trange

In [None]:
# Tensor that maps products of CI coefficients to the AO 1-RDM (it is contracted
# later as 'ni,nj,ijab->nab'), and the CI-space matrix that couples to the field
# (presumably the dipole matrix in the CI basis — confirm).
P = np.load('casscf22_s2_heh+_sto-3g_tensor.npz')
dimat = np.array(
    [
        [-1.0724436, 0.0, 1.6277593, 0.2589907],
        [0.0, 0.1114341, 0.0, 0.0],
        [1.6277593, 0.0, -0.2712937, 0.0676768],
        [0.2589907, 0.0, 0.0676768, 1.6780414],
    ]
)

In [None]:
# Which trajectory / system to load; all run files live under `prefix`.
trajectory_number = 135
mol = 'heh+'
method = 'tdcasscf'
prefix = './dt0.008268/'
trajnum = trajectory_number
traj = str(trajnum).zfill(3)  # zero-padded to three digits
fname = f'{prefix}time_coeffs.{method}_{mol}_sto-3g_{traj}_dt=0.008268au.txt'
# Time-dependent CI coefficients, one row per time step.
tdcicoeffs = np.loadtxt(fname, dtype=np.complex128)

In [None]:
ham = np.load('casscf22_s2_heh+_sto-3g_hamiltonian.npz')
# Subtract min(ham) (taken over all entries) from every diagonal element. This is a
# multiple of the identity, so it only changes the global phase of the propagated
# state, never the densities built from it.
ham = ham - np.diag(np.full(ham.shape[0], np.min(ham)))

In [None]:
# Run metadata for the same trajectory (dt_au, freq, emax, ncyc are read below).
# Built from `method`/`mol` like the coefficient filename, so the two files stay in sync.
runfile = np.load(prefix + method + '_' + mol + '_sto-3g_' + str(trajnum).zfill(3) + '_dt=0.008268au.npz')
mydt = runfile['dt_au'].item()
print(mydt)

In [None]:
# Time grid (a.u.), one entry per row of tdcicoeffs.
offset = 0
tvec = runfile['dt_au'] * np.arange(offset, offset + tdcicoeffs.shape[0])
# The field is switched off after `ncyc` full optical cycles.
tmeoff = runfile['ncyc']*2*np.pi/runfile['freq']
# Sinusoidal field of amplitude emax, nonzero only for 0 <= t <= tmeoff.
ef = (
    (tvec >= 0)
    * (tvec <= tmeoff)
    * (runfile['emax'])
    * np.sin(runfile['freq'] * tvec)
)

In [None]:
# Instantaneous CI-space Hamiltonian H_n = ham - ef[n] * dimat at every time step and
# its one-step propagator U_n = exp(-1j * mydt * H_n), built from the eigendecomposition.
shape = ham.shape[0]  # CI-space dimension, derived from the Hamiltonian (no manual toggle)
hamCI = (1+0j) * (np.expand_dims(ham, 0) - np.einsum('i,jk->ijk', ef, dimat))
numhamCI = hamCI.shape[0]
# np.linalg.eigh diagonalizes the whole (numhamCI, shape, shape) stack in one call.
alldd, allvv = np.linalg.eigh(hamCI)
phases = np.exp(-1j * mydt * alldd)  # (numhamCI, shape)
# U_n = V_n diag(phases_n) V_n^H; scaling the columns of V_n replaces the explicit np.diag.
allprop = (allvv * phases[:, np.newaxis, :]) @ np.conj(np.swapaxes(allvv, -1, -2))
        

In [None]:
# Ground-truth CI propagation: start in CI state 0 and apply the exact one-step
# propagator expm(-1j * (ham - ef[i] * dimat) * dt) at every step.
# NOTE(review): tdcicoeffs only supplies the number of steps here; the initial state is
# assumed to be CI state 0 — confirm it matches tdcicoeffs[0].
newtdcicoeffs = np.zeros((tdcicoeffs.shape[0],shape),dtype=np.complex128)
newtdcicoeffs[0,0] = 1.0
for i in range(newtdcicoeffs.shape[0]-1):
    newtdcicoeffs[i+1,:] = sl.expm(-1j*(ham-ef[i]*dimat)*runfile['dt_au']) @ newtdcicoeffs[i,:]

bigtens = P
# AO overlap matrix. The original bare `S == S.T` discarded its result; check it for real.
S = np.array([[1.0,0.538415],[0.538415,1.0]])
assert np.allclose(S, S.T), "AO overlap matrix S must be symmetric"

newrdmAO_p = np.einsum('ni,nj,ijab->nab',newtdcicoeffs,np.conjugate(newtdcicoeffs),bigtens)
# Tr(P(t) S) at every step; presumably this should equal the electron count — confirm.
traces_p = np.einsum('ijj->i', newrdmAO_p@S)
print(np.mean(np.abs(traces_p)))

In [None]:
# AO 1-RDM at every step, and the pure-state CI density rho[n, i, j] = c_i * conj(c_j).
newrdmAO = np.einsum('ni,nj,ijab->nab', newtdcicoeffs, newtdcicoeffs.conj(), bigtens)
tdciden = np.einsum('ni,nj->nij', newtdcicoeffs, newtdcicoeffs.conj())

In [None]:
# A pure-state density satisfies rho @ rho == rho, so this residual should print ~0.
print(
    np.linalg.norm(
        np.einsum('nij,njk->nik', tdciden, tdciden) - tdciden
    )
)

In [None]:
# Flatten the CI->AO tensor to (nCI**2, nAO**2) so the AO 1-RDM becomes a plain matrix
# product with the row-major-flattened CI density. Dimensions come from the data instead
# of hard-coded 4/2. The reshape is a no-op when the cell is re-run.
bigtens = bigtens.reshape((tdciden.shape[1]**2, -1)).astype(np.complex128)
matmulrdmAO = (tdciden.reshape((tdciden.shape[0], -1)) @ bigtens).reshape(newrdmAO.shape)

# Should print ~0: the matrix-product route agrees with the einsum route.
print( np.mean(np.abs(matmulrdmAO - newrdmAO)) )

In [None]:
# Dimensions: drcCI = number of CI states, drc = AO 1-RDM dimension (derived from data).
# mat2vec(n) below builds the matrix that maps a real parameter vector to a flattened
# n x n Hermitian matrix.
drcCI = newtdcicoeffs.shape[1]
drc = newrdmAO.shape[1]
def mat2vec(n):
    """Build the linear map from a real parameter vector to a flattened n x n Hermitian matrix.

    Returns a complex (n**2, n**2) array M. For a real vector x, (M @ x).reshape(n, n) is
    Hermitian:
      * the first n*(n+1)//2 entries of x are the real parts of the upper triangle
        (diagonal included, row-major order);
      * the remaining n*(n-1)//2 entries are the imaginary parts of the strict upper
        triangle (row-major order), entering as +1j*y above and -1j*y below the diagonal.
    The last diagonal element is x[n*(n+1)//2 - 1] minus the other diagonal parameters,
    so the trace of the resulting matrix equals x[n*(n+1)//2 - 1].
    """
    n_real = (n + 1) * n // 2
    n_imag = (n - 1) * n // 2

    def packed_col(lo, hi):
        # column of (lo, hi), lo <= hi, in the row-major packed upper triangle
        return lo * n + hi - lo * (lo + 1) // 2

    realmat = np.zeros((n**2, n_real), dtype=np.int16)
    imagmat = np.zeros((n**2, n_imag), dtype=np.int16)
    for row in range(n**2):
        r, c = divmod(row, n)
        lo, hi = min(r, c), max(r, c)
        realmat[row, packed_col(lo, hi)] = 1
        if r != c:
            # the strict-upper packing drops lo + 1 diagonal slots before (lo, hi)
            imagmat[row, packed_col(lo, hi) - (lo + 1)] = 1 if r < c else -1

    # Constant-trace constraint: the last diagonal entry subtracts every other diagonal parameter.
    for k in range(n - 1):
        realmat[-1, packed_col(k, k)] = -1

    return np.hstack([realmat, 1j * imagmat])

# Parameter -> flattened CI-density map for the drcCI x drcCI density.
smat = mat2vec(n=drcCI)

In [None]:
def ind_to_pair_upper(n, i, j):
    """Row-major index of (i, j) in the packed upper triangle (diagonal included) of an n x n matrix.

    The pair is symmetrised, so (i, j) and (j, i) map to the same index.
    """
    row, col = (i, j) if i < j else (j, i)
    return row * (2 * n - row - 1) // 2 + col
n = drcCI

# (i, j) with i <= j -> column of the real (upper triangle incl. diagonal) part of smat.
upper_mapping = {}
for i in range(n):
    for j in range(i, n):
        k = ind_to_pair_upper(n, i, j)
        upper_mapping[(i,j)] = k

# (i, j) with i < j -> column of the imaginary (strict upper triangle) part of smat.
# The imaginary columns follow the n*(n+1)//2 real ones; the offset is derived from n
# so the mapping is valid for any CI dimension, not only n == 4.
upper_wo_diag_mapping = {}
k = n * (n + 1) // 2
for i in range(n):
    for j in range(i+1, n):
        upper_wo_diag_mapping[(i,j)] = k
        k+=1

In [None]:
# CI states kept vs. CI states dropped (presumably states whose density entries stay
# zero along the trajectory — confirm).
good_cols = np.array([ 0, 2, 3])
zero_cols = np.array([ 1])
# Real-part columns of smat whose (i, j) pair touches a dropped state.
bad_inds = []
del_upper = 0
for key, col in upper_mapping.items():
    if any(c in key for c in zero_cols):
        bad_inds.append(col)
        del_upper += 1
print(del_upper)

In [None]:
# Imaginary-part columns of smat whose (i, j) pair touches a dropped state.
del_lower = 0
for key, col in upper_wo_diag_mapping.items():
    if any(c in key for c in zero_cols):
        bad_inds.append(col)
        del_lower += 1
print(del_lower)

In [None]:
# Number of real-part columns that survive removal of the dropped states.
split = (drcCI + 1) * drcCI // 2 - del_upper

In [None]:
good_inds = np.delete(np.arange(0,drcCI**2),bad_inds)
# Sanity check: good_inds and bad_inds must partition range(drcCI**2) exactly (no
# duplicates, nothing missing). np.array_equal reports False on a length mismatch,
# whereas list == ndarray can raise instead.
total_inds = sorted(good_inds.tolist() + list(bad_inds))

print( np.array_equal(total_inds, np.arange(0,drcCI**2)) )

In [None]:
# smat column of the last diagonal element (n-1, n-1). Its coefficient is fixed to 1.0 in
# the reconstruction, so it is excluded from the least-squares columns.
magicind = (drcCI + 1) * drcCI // 2 - 1
good_inds_del = good_inds[good_inds != magicind]

In [None]:
# `ell` (memory length) is assigned per stride inside the sweep loop further down;
# `ells` is not defined yet at this point on a fresh kernel, so nothing is set here.

In [None]:
import cupy as cp

In [None]:
# Flattened CI->AO tensor on the GPU.
bigtensCP = cp.asarray(bigtens, dtype=bigtens.dtype)

In [None]:
# Transposed flattened CI->AO tensor on the GPU.
bigtensTCP = cp.asarray(np.transpose(bigtens))

In [None]:
# One-step propagators for every time step, on the GPU.
allpropCP = cp.asarray(allprop, dtype=allprop.dtype)

In [None]:
# Parameter -> flattened CI density map, on the GPU.
smatCP = cp.asarray(smat, dtype=smat.dtype)

In [None]:
# Ground-truth AO 1-RDMs, on the GPU.
newrdmAOCP = cp.asarray(newrdmAO, dtype=newrdmAO.dtype)

In [None]:
# Memory-model sweep over history strides.
# For each stride, the first ell+1 AO 1-RDMs are seeded from ground truth. At every later
# step j, the CI density at step j is fitted to the strided AO 1-RDM history (lags 0,
# stride, 2*stride, ..., up to ell). It is then propagated one step with allpropCP[j].
numsteps = 20001
strides = [2,3,4,5,6,7,8]
#strides = [10]
ells = np.array(strides)*160
print(ells)
# allpropCP[j] is needed for j <= numsteps-2 and ground truth for steps < numsteps.
assert numsteps <= newrdmAO.shape[0], 'numsteps exceeds the length of the ground-truth trajectory'
nblk = drc**2  # rows per lag: one flattened (drc x drc) AO 1-RDM


def _strided_memory_matrix(j, stride, nlags):
    """Stack the CI-density -> AO 1-RDM maps for lags 0, stride, ..., (nlags-1)*stride.

    Block b (rows b*nblk:(b+1)*nblk) is bigtensTCP @ kron(C.T, C^H), where
    C = allpropCP[j-1] @ allpropCP[j-2] @ ... @ allpropCP[j-b*stride];
    block 0 is bigtensTCP itself. Block b is fitted against the stored AO 1-RDM at
    step j - b*stride.
    """
    mat = cp.zeros((nblk*nlags, drcCI**2), dtype=cp.complex128)
    mat[:nblk,:] = bigtensTCP
    Cmat = None
    for i in range(1, (nlags-1)*stride + 1):
        myexp = allpropCP[j-i,:,:]
        Cmat = myexp if Cmat is None else Cmat @ myexp
        if i % stride != 0:
            continue
        b = i // stride
        mat[b*nblk:(b+1)*nblk,:] = bigtensTCP @ cp.kron(Cmat.T, Cmat.conj().T)
    return mat


MSEs = []
cond_nums = []
myrdmAOs = []
for stride, ell in zip(strides, ells):
    print(stride,ell)
    nlags = ell // stride + 1       # lags 0, stride, ..., max_lag
    max_lag = (nlags - 1) * stride  # equals ell whenever ell is a multiple of stride
    myrdmAO = cp.zeros((numsteps, nblk), dtype=cp.complex128)
    #myrdmAO stores the tranpose of the true rdmAO
    myrdmAO[:ell+1,:] = cp.transpose(cp.asarray(newrdmAO[:ell+1,:,:]),(0,2,1)).reshape((-1,nblk))
    ss = None  # singular values from the most recent step of *this* stride
    for j in range(ell,numsteps-1):
        if j % 2000 == 0:
            print(j)
        bigmat_strided = _strided_memory_matrix(j, stride, nlags)
        # stored RDMs at steps j, j-stride, ..., j-max_lag (same order as the blocks)
        btrue = cp.flipud(myrdmAO[j-max_lag:j+1:stride,:]).reshape((-1))
        mprime = bigmat_strided @ smatCP
        msub = mprime[:,good_inds_del]

        # monitor singular values (U and V are not needed)
        ss = cp.linalg.svd(msub, compute_uv=False)
        if cp.min(ss) < 1e-15:
            print("Warning: singular value < 1e-15 detected at time step " + str(j))
            break

        # reconstruct full TDCI density. The coefficient fixed to 1.0 below belongs to smat
        # column magicind (assuming it is in good_inds). That column's only nonzero is the
        # last row, so its contribution to mprime is bigmat_strided[:,-1], moved to the RHS.
        xxapprox = cp.real( cp.linalg.pinv(msub,1e-12) @ (btrue - bigmat_strided[:,-1]) )
        # NOTE(review): the imaginary-part coefficients are negated here; presumably tied to
        # myrdmAO holding the transposed RDM — confirm.
        xxapprox2 = cp.concatenate([xxapprox[:split-1],cp.array([1.0]),-xxapprox[split-1:]])

        recon = (smatCP[:,good_inds] @ xxapprox2).reshape((drcCI,drcCI))

        # propagate in full TDCI density space via one step of MMUT!
        reconprop = allpropCP[j,:,:] @ recon @ allpropCP[j,:,:].conj().T

        # compute new rdm
        myrdmAO[j+1,:] = (reconprop.reshape((-1)) @ bigtensCP).conj()
    # NOTE: after an early break the remaining rows stay zero and are included below.
    diff = myrdmAO[ell+1:numsteps].reshape((-1,drc,drc)).conj() - newrdmAOCP[ell+1:numsteps]
    # Mean squared magnitude of the error. cp.square on the complex difference would give
    # a complex number whose terms can cancel.
    MSE = cp.mean(cp.abs(diff)**2).item()
    cond_num = (cp.max(ss)/cp.min(ss)).item() if ss is not None else float('nan')
    MSEs.append(MSE)
    cond_nums.append(cond_num)
    myrdmAOs.append(myrdmAO)
    print('Stride: '+str(stride)+ ' MSE: ' +str(MSE))

In [None]:
# Real parts of the AO 1-RDM: memory model (last stride of the sweep) vs ground truth,
# plotted against time. myrdmAO holds the transposed RDM; real parts of a Hermitian
# matrix are symmetric, so no conjugation is needed here.
nplot = min(numsteps, newrdmAO.shape[0])
t_plot = tvec[:nplot]
model_rdm = myrdmAO.get()[:nplot,:].reshape((nplot,drc,drc))
fig, axs = plt.subplots(drc, drc, figsize=(16, 14), squeeze=False)
for i in range(drc):
    for j in range(drc):
        ax = axs[i,j]
        ax.plot(t_plot, np.real(model_rdm[:,i,j]), color='red', label='Memory Model')
        ax.plot(t_plot, np.real(newrdmAO[:nplot,i,j]), color='black', label='Ground Truth')
        ax.set_title('Re(P'+str(i)+str(j)+')')
        ax.set_xlabel('t (a.u.)')
axs[-1,-1].legend()
fig.suptitle('Real HeH+ in STO-3G')
plt.tight_layout()


In [None]:
# Imaginary parts of the AO 1-RDM: memory model (last stride of the sweep) vs ground truth.
# myrdmAO holds the transposed RDM; conj() of that recovers P itself, assuming P is
# Hermitian. Grid size follows drc, so the reshape matches the stored data.
nplot = min(numsteps, newrdmAO.shape[0])
t_plot = tvec[:nplot]
model_rdm = myrdmAO.get()[:nplot,:].conj().reshape((nplot,drc,drc))
fig, axs = plt.subplots(drc, drc, figsize=(16, 14), squeeze=False)
for i in range(drc):
    for j in range(drc):
        ax = axs[i,j]
        ax.plot(t_plot, np.imag(model_rdm[:,i,j]), color='red', label='Memory Model')
        ax.plot(t_plot, np.imag(newrdmAO[:nplot,i,j]), color='black', label='Ground Truth')
        ax.set_title('Im(P'+str(i)+str(j)+')')
        ax.set_xlabel('t (a.u.)')
axs[-1,-1].legend()
fig.suptitle('Imaginary HeH+ in STO-3G')
plt.tight_layout()
