In [1]:
from jax import config
config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax import jit, lax

import numpy as np
import matplotlib
import matplotlib.pyplot as plt

import scipy.integrate as si
import scipy.optimize as so
import scipy.linalg as sl

import time

from tqdm import trange
import os 
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE']='false'

In [42]:
mol = 'heh+'
basis = 'sto-3g'

In [46]:
# construct prefix used to load and save files
if basis=='sto-3g':
    prefix='casscf22_s2_'
elif basis=='6-31g':
    prefix='casscf24_s15_'
else:
    print("Error: basis set not recognized! Must choose either sto-3g or 6-31g")
    sys.exit(1)

In [47]:
P = np.load('./logfiles/'+prefix+mol+'_'+basis+'_tensor.npy')
dimat = np.load('./logfiles/'+prefix+mol+'_'+basis+'_CI_dimat.npz')

In [48]:
trajshape0 = 24218
mydt = 0.008268

# E field parameters
freq = 1.5
ncyc = 5
emax = 0.5

In [49]:
# load diagonalized CI Hamiltonian
ham = np.load('./logfiles/'+prefix+mol+'_'+basis+'_hamiltonian.npy')
drcCI = ham.shape[0]

# why did we do this?
ham = ham - np.diag([np.min(ham)]*drcCI)

In [50]:
offset = 0
tvec = np.arange(offset,offset+trajshape0)*mydt
tmeoff = ncyc*2*np.pi/freq
ef = (tvec>=0)*(tvec<=tmeoff)*emax*np.sin(freq*tvec)

In [51]:
# here we compute all one-step propagators!

# the hamCI Hamiltonian consists of two pieces:
# 1) the core CI diagonal matrix (loaded from Gaussian)
# 2) an electric field term (amplitude * dipole moment matrix in z direction)

# to compute each propagator, we diagonalize hamCI at each time step

hamCI = np.expand_dims(ham,0) - np.einsum('i,jk->ijk',ef,dimat)
hamCI = (1+0j)*hamCI
alldd = np.zeros((trajshape0,drcCI),dtype=np.float64)
allvv = np.zeros((trajshape0,drcCI,drcCI),dtype=np.complex128)
allprop = np.zeros((trajshape0,drcCI,drcCI),dtype=np.complex128)
for i in range(trajshape0):
    dd, vv = np.linalg.eigh(hamCI[i,:,:])
    alldd[i,:] = dd
    allvv[i,:,:] = vv
    allprop[i,:,:] = vv @ np.diag(np.exp(-1j*mydt*alldd[i,:])) @ vv.conj().T
    

In [52]:
newtdcicoeffs = np.zeros((trajshape0,drcCI),dtype=np.complex128)
newtdcicoeffs[0,0] = 1.0 
for i in range(trajshape0-1):
    newtdcicoeffs[i+1,:] = allprop[i,:,:] @ newtdcicoeffs[i,:]

bigtens = P

# load overlap matrix from disk
goodlines = []
with open('./logfiles/'+prefix+mol+'_'+basis+'.log','r') as f:
    startmat = 0
    cnt = 0
    for line in f:
        if line.rstrip() == ' *** Overlap ***':
            startmat = 1
        if startmat==1:
            if line.rstrip() == ' *** Kinetic Energy ***':
                break
            if cnt >= 2:
                goodlines.append(line.rstrip())
            cnt += 1

drc = len(goodlines)
S = np.zeros((drc, drc))
for j in range(drc):
    for k in range(j+1):
        myarr = np.fromstring(goodlines[j].replace('D','e'), sep=' ')
        S[j,k] = myarr[k+1]

for j in range(drc):
    for k in range(j+1,drc):
        S[j,k] = S[k,j]

In [53]:
newrdmAO = np.einsum('ni,nj,ijab->nab',newtdcicoeffs,np.conjugate(newtdcicoeffs),bigtens)
traces_p = np.einsum('ijj->i', newrdmAO@S)
print(np.mean(np.abs(traces_p)))

1.9999975931544807


In [54]:
tdciden = np.einsum('ni,nj->nij',newtdcicoeffs,np.conj(newtdcicoeffs))
print( np.linalg.norm(np.einsum('nij,njk->nik',tdciden,tdciden) - tdciden) )

5.149606664557452e-12


In [55]:
bigtens = bigtens.reshape((drcCI**2,drc**2)).astype(np.complex128)
matmulrdmAO = np.einsum('ij,jk->ik',
                        np.transpose(tdciden,axes=[0,1,2]).reshape((-1,drcCI**2)),
                        bigtens).reshape((-1,drc,drc))

print( np.mean(np.abs(matmulrdmAO - newrdmAO)) )

0.0


In [56]:
ells = np.arange(70,72,2)

In [57]:
# matrix to convert an (nxn) Hermitian matrix to its vectorized form
def mat2vec(n):
    # real part
    realmat = np.zeros((n**2,(n+1)*n//2),dtype=np.int16)
    for i in range(n):
        for j in range(n):
            row = i*n + j
            if i<=j:
                col = i*n + j - i*(i+1)//2
            else:
                col = j*n + i - j*(j+1)//2
            realmat[row,col]=1
    j = 0
    i = 0
    # taking care of constant trace and we had no idea all along
    while j < (realmat.shape[1] - 1):
        realmat[-1,j] = -1
        j += n - i*1
        i += 1
    #
    # imaginary part
    imagmat = np.zeros((n**2,(n-1)*n//2),dtype=np.int16)
    for i in range(n):
        for j in range(n):
            row = i*n + j
            if i<j:
                col = i*n + j - (i+1)*(i+2)//2
                imagmat[row,col]=1
            if i>j:
                col = j*n + i - (j+1)*(j+2)//2
                imagmat[row,col]=-1
    symmat = np.hstack([realmat, 1j*imagmat])
    return symmat

smat = mat2vec(drcCI)

In [58]:
def ind_to_pair_upper(n, i, j):
    if i < j:
        return i * (2 * n - i - 1) // 2 + j
    else:
        return j * (2 * n - j - 1) // 2 + i
n = drcCI 

upper_mapping = {}

#this is the diag upper representation
for i in range(n):
    for j in range(i, n):
        k = ind_to_pair_upper(n, i, j)
        upper_mapping[(i,j)] = k
        
def ind_to_pair_upper(n, i, j):
    if i < j:
        return i * (2 * n - i - 1) // 2 + j
    # else:
    #     return j * (2 * n - j - 1) // 2 + i
n = drcCI 

upper_wo_diag_mapping = {}

#this is the offdiag upper representation
k=drcCI*(drcCI+1)//2
for i in range(n):
    for j in range(i+1, n):
        upper_wo_diag_mapping[(i,j)] = k
        k+=1

In [59]:
# this part here is problem-dependent!

# the next two rows are for H2 in STO-3G
if mol=='h2':
    if basis=='sto-3g':
        good_cols = np.array([ 0, 2, 3])
        zero_cols = np.array([ 1])
    elif basis=='6-31g':
        good_cols = np.array([ 0,  2,  4,  5,  7,  9, 11, 12, 14, 15])
        zero_cols = np.array([ 1,  3,  6,  8,  10,  13])
elif mol=='heh+':
    if basis=='sto-3g':
        good_cols = np.array([ 0, 2, 3])
        zero_cols = np.array([ 1])
    elif basis=='6-31g':
        good_cols = np.array([ 0,  2,  4,  6,  7,  9, 11, 12, 14, 15])
        zero_cols = np.array([ 1,  3,  5,  8,  10,  13])

bad_inds = []
del_upper = 0

for key in list(upper_mapping.keys()):
    for c in zero_cols:
        if c in key:
            #print(key, ' bad')
            bad_inds.append(upper_mapping[key])
            del_upper+=1
            break
print(del_upper)

4


In [60]:
del_lower = 0
for key in list(upper_wo_diag_mapping.keys()):
    for c in zero_cols:
        if c in key:
            #print(key, ' bad')
            bad_inds.append(upper_wo_diag_mapping[key])
            del_lower+=1
            break
print(del_lower)

3


In [61]:
split = drcCI*(drcCI+1)//2-del_upper

In [62]:
good_inds = np.delete(np.arange(0,drcCI**2),bad_inds)
total_inds = good_inds.tolist()
total_inds.extend(bad_inds)
total_inds.sort()

print( all(total_inds == np.arange(0,drcCI**2)) )

True


In [63]:
magicind = drcCI*(drcCI+1)//2 - 1
good_inds_del = good_inds[good_inds!=magicind]

In [64]:
bigtensJNP = jnp.array(bigtens)
bigtensTJNP = jnp.array(bigtens.T)
allpropJNP = jnp.array(allprop)
smatJNP = jnp.array(smat)

In [65]:
def firststep(rdmAO):
    j = ell
    bigmat = []
    bigmat.append( bigtensTJNP )
    allCmatT = []
    for i in range(1,ell+1):
        myexp = allpropJNP[j-i,:,:]
        if i==1:
            Cmat = myexp
        else:
            Cmat = Cmat @ myexp
        # note that transpose
        allCmatT.append(Cmat.T)
    
    newstack = jnp.stack(allCmatT)
    
    for i in range(1, ell+1):
        CmatT = newstack[i-1]
        Amat = CmatT.conj()
        bigmat.append( bigtensTJNP @ jnp.kron( CmatT, Amat ) )

    bigmat = jnp.concatenate(bigmat,axis=0)
    
    btrue = jnp.flipud(rdmAO[j-ell:(j+1),:]).reshape((-1))
    mprime = bigmat @ smatJNP

    # reconstruct full TDCI density
    xxapprox = jnp.real( jnp.linalg.pinv(mprime[:,good_inds_del],1e-12) @ (btrue - bigmat[:,-1]) )
    xxapprox2 = jnp.concatenate([xxapprox[:split-1],jnp.array([1.0]),-xxapprox[split-1:]])
    
    recon = (smatJNP[:,good_inds] @ xxapprox2).reshape((drcCI,drcCI))
    assert (recon == recon.conjugate().transpose()).all()

    # propagate in full TDCI density space via one step of MMUT!
    reconprop = allpropJNP[j,:,:] @ recon @ allpropJNP[j,:,:].conj().T

    # compute new rdm
    nextrdmAO = (reconprop.reshape((-1)) @ bigtensJNP).conj()
    return nextrdmAO, newstack

In [66]:
def loopbody(j, intup):
    rdmAO, oldstack, sv, residuals = intup
    #cond_num = []
    bigmat = []
    bigmat.append( bigtensTJNP )
    
    # note that the "icb" and "ida" here means that we are storing Cmat.T
    allpropJNPds = lax.dynamic_slice(allpropJNP,[j-1-ell,0,0],[ell,drcCI,drcCI])
    newstack = jnp.einsum('ab,icb,idc->ida',
                         allpropJNP[j-1,:,:],
                         oldstack,
                         jnp.flipud(allpropJNPds).conj(),optimize=True)

    for i in range(1, ell+1):
        CmatT = newstack[i-1]
        Amat = CmatT.conj()
        bigmat.append( bigtensTJNP @ jnp.kron( CmatT, Amat ) )

    bigmat = jnp.concatenate(bigmat,axis=0)
    rdmAOds = lax.dynamic_slice(rdmAO,[j-ell,0],[ell+1,drc**2])
    btrue = jnp.flipud(rdmAOds).reshape((-1))
    mprime = bigmat @ smatJNP
    ss = jnp.linalg.svd(mprime[:,good_inds_del], compute_uv=False)
    
    # reconstruct full TDCI density
    # xxapprox = jnp.real( jnp.linalg.pinv(mprime[:,good_inds_del],1e-12) @ (btrue - bigmat[:,-1]) )
    xxapproxtup = jnp.linalg.lstsq(mprime[:,good_inds_del], btrue - bigmat[:,-1], 1e-12 )
    xxapprox = jnp.real( xxapproxtup[0] )
    xxapprox2 = jnp.concatenate([xxapprox[:split-1],jnp.array([1.0]),-xxapprox[split-1:]])
    
    recon = (smatJNP[:,good_inds] @ xxapprox2).reshape((drcCI,drcCI))
    
    # propagate in full TDCI density space via one step of MMUT!
    reconprop = allpropJNP[j,:,:] @ recon @ allpropJNP[j,:,:].conj().T

    # compute new rdm
    rdmAO = rdmAO.at[j+1].set( (reconprop.reshape((-1)) @ bigtensJNP).conj() )
    sv = sv.at[j].set(jnp.max(ss)/jnp.min(ss))
    residuals = residuals.at[j].set(jnp.linalg.norm(mprime[:,good_inds_del] @ xxapprox - (btrue - bigmat[:,-1])))
    #sv.append(jnp.max(ss)/jnp.min(ss))
    return (rdmAO, newstack, sv, residuals)

In [67]:
rdmAOs = []
cond_nums = []
MSEs = []
residuals = []

In [68]:
jloopbody = jit(loopbody)

In [69]:
for ell in ells:
    start = time.time()
    
    numsteps = newrdmAO.shape[0]

    myrdmAOinitblock = jnp.transpose(newrdmAO[:ell+1,:,:],(0,2,1)).reshape((-1,drc**2))
    firstnewrdmAO, newstack = firststep(myrdmAOinitblock)

    myrdmAO = jnp.concatenate([myrdmAOinitblock, jnp.expand_dims(firstnewrdmAO,0), 
                               jnp.zeros((numsteps-(ell+2), drc**2), dtype=np.complex128)])
    sv = jnp.zeros((numsteps-1-ell-1))
    residual = jnp.zeros((numsteps-1-ell-1))
    outtup = lax.fori_loop(ell+1,numsteps-1,jloopbody,(myrdmAO,newstack,sv,residual))
    myrdmAO = outtup[0]
    MSE = jnp.mean(jnp.square(myrdmAO[ell+1:].reshape((-1,drc,drc)).conj() - newrdmAO[ell+1:]))
    MSEs.append(MSE)
    cond_num = outtup[2]
    myresidual = outtup[3]
    end = time.time()
    print(end-start)
    rdmAOs.append(myrdmAO)
    cond_nums.append(cond_num)
    residuals.append(myresidual[-1])
    print(MSEs[-1])
    print(residuals[-1])

53.26267766952515
(1.0816623811612049e-05-1.3883493039405498e-22j)
3.6344613515752498e-12


In [70]:
print(jnp.mean(jnp.square(myrdmAO[ell+1:].reshape((-1,drc,drc)).conj() - newrdmAO[ell+1:])))

(1.0816623811612049e-05-1.3883493039405498e-22j)
