In [1]:
from jax import config
config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax import jit, lax

import numpy as np
import matplotlib
import matplotlib.pyplot as plt

import scipy.integrate as si
import scipy.optimize as so
import scipy.linalg as sl

import time

from tqdm import trange
import os 
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE']='false'

In [6]:
mol = 'h2'
basis = 'sto-3g'

In [7]:
# construct prefix used to load and save files
if basis=='sto-3g':
    prefix='casscf22_s2_'
elif basis=='6-31g':
    prefix='casscf24_s15_'
else:
    print("Error: basis set not recognized! Must choose either sto-3g or 6-31g")
    sys.exit(1)

In [8]:
bigtens = np.load('./logfiles/'+prefix+mol+'_'+basis+'_tensor.npy')
dimat = np.load('./logfiles/'+prefix+mol+'_'+basis+'_CI_dimat.npz')

In [9]:
# time step
mydt = 0.008268

# number of time steps to test for
# here i am going for a final time of 200.0 (atomic units)
trajshape0 = int(np.ceil(200.0 / mydt))

# E field parameters
freq = 1.5
ncyc = 5
emax = 0.5

In [10]:
# load diagonalized CI Hamiltonian
ham = np.load('./logfiles/'+prefix+mol+'_'+basis+'_hamiltonian.npy')

# drcCI is our own in-house acronym for "density rows/columns in configuration interaction (basis)" 
# basically, we need some descriptive variable name (preferably not N or M or something generic)
# that tells us how many rows/columns are in the following *configuration interaction* objects:
# - Hamiltonian matrices
# - state vectors $\mathbf{a}(t)$
# - density matrices
drcCI = ham.shape[0]

# why did we do this?
ham = ham - np.diag([np.min(ham)]*drcCI)

In [11]:
# defining some extra time-stepping parameters
offset = 0
tvec = np.arange(offset,offset+trajshape0)*mydt

# the electric field is switched on at t=0 and switched off at t=tmeoff (time off)
tmeoff = ncyc*2*np.pi/freq
ef = (tvec>=0)*(tvec<=tmeoff)*emax*np.sin(freq*tvec)

In [12]:
# here we compute all one-step propagators!

# the hamCI Hamiltonian consists of two pieces:
# 1) the core CI diagonal matrix (loaded from Gaussian)
# 2) an electric field term (amplitude * dipole moment matrix in z direction)

# to compute each propagator, we diagonalize hamCI at each time step

hamCI = np.expand_dims(ham,0) - np.einsum('i,jk->ijk',ef,dimat)
hamCI = (1+0j)*hamCI
alldd = np.zeros((trajshape0,drcCI),dtype=np.float64)
allvv = np.zeros((trajshape0,drcCI,drcCI),dtype=np.complex128)
allprop = np.zeros((trajshape0,drcCI,drcCI),dtype=np.complex128)
for i in range(trajshape0):
    dd, vv = np.linalg.eigh(hamCI[i,:,:])
    alldd[i,:] = dd
    allvv[i,:,:] = vv
    allprop[i,:,:] = vv @ np.diag(np.exp(-1j*mydt*alldd[i,:])) @ vv.conj().T
    

In [13]:
# propagate TDCI solution
newtdcicoeffs = np.zeros((trajshape0,drcCI),dtype=np.complex128)

# the initial density matrix is a purely diagonal matrix
# with 1.0 in the upper-left corner, and 0.0 everywhere else
newtdcicoeffs[0,0] = 1.0 

# because we have already computed and stored all 1-stop propagators,
# let us use them here
for i in range(trajshape0-1):
    newtdcicoeffs[i+1,:] = allprop[i,:,:] @ newtdcicoeffs[i,:]

# load overlap matrix from disk (the Gaussian .log file) rather than having it hardcoded...
goodlines = []
with open('./logfiles/'+prefix+mol+'_'+basis+'.log','r') as f:
    startmat = 0
    cnt = 0
    for line in f:
        if line.rstrip() == ' *** Overlap ***':
            startmat = 1
        if startmat==1:
            if line.rstrip() == ' *** Kinetic Energy ***':
                break
            if cnt >= 2:
                goodlines.append(line.rstrip())
            cnt += 1

drc = len(goodlines)
S = np.zeros((drc, drc))
for j in range(drc):
    for k in range(j+1):
        myarr = np.fromstring(goodlines[j].replace('D','e'), sep=' ')
        S[j,k] = myarr[k+1]

for j in range(drc):
    for k in range(j+1,drc):
        S[j,k] = S[k,j]

In [14]:
# at once, let us compute all true 1RDM matrices in the AO basis
newrdmAO = np.einsum('ni,nj,ijab->nab',newtdcicoeffs,np.conjugate(newtdcicoeffs),bigtens)

# these 1RDM matrices should have constant trace equal to the number of electrons
traces_p = np.einsum('ijj->i', newrdmAO@S)

# for H2 and HeH+, this number should be extremely close to 2, the number of electrons
print(np.mean(np.abs(traces_p)))

1.9999843712344565


In [15]:
# at once, let us compute all full density matrices in the CI basis
tdciden = np.einsum('ni,nj->nij',newtdcicoeffs,np.conj(newtdcicoeffs))

# let us check how idempotent all of these density matrices are
# they should be idempotent, because
#   \mathbf{a}(t) \mathbf{a}(t)^\dagger \mathbf{a}(t) \mathbf{a}(t)^\dagger 
# = \mathbf{a}(t) ( \mathbf{a}(t)^\dagger \mathbf{a}(t) ) \mathbf{a}(t)^\dagger 
# = \mathbf{a}(t) (1) \mathbf{a}(t)^\dagger
# = \mathbf{a}(t) \mathbf{a}(t)^\dagger
print( np.linalg.norm(np.einsum('nij,njk->nik',tdciden,tdciden) - tdciden) )

1.1281347107310403e-11


In [16]:
# reshape the 4-index tensor into a 2-index matrix
bigtens = bigtens.reshape((drcCI**2,drc**2)).astype(np.complex128)

In [17]:
# mat2vec produces a matrix that can be used to *represent*
# an n x n Hermitian matrix with constant trace via
# a real vector of size n**2

# note:
# smat = mat2vec(n) will generate a complex matrix of size n**2 x n**2

# multiplying smat by a real vector of size n**2 will generate a complex vector of size n**2,
# which when reshaped into an n x n matrix, will be Hermitian

# important technical note:
# in the real vector of size n**2, the entry at Python index location (n * (n+1))//2 - 1
# tells you what the trace of the Hermitian matrix will be!
#
# i refer to "(n * (n+1))//2 - 1" as the "magic index" below...
#

# example usage for the simplest possible Hermitian matrix, which is of size 2x2:
# 
# here n = 2, so the "magic index" (n * (n+1))//2 - 1 = 2
#
# let us say that our desired trace is 2.0 and that we want to represent the Hermitian matrix
# 
# [[ 1.0         2.0 - 3.0j ]
# [ 2.0 + 3.0j  1.0        ]]
#
# here is how we can do it:
#
# smat = mat2vec(2)
# realvec = np.array([1.0, 2.0, 2.0, -3.0])
# hermvec = smat @ realvec
# print(hermvec.reshape((2,2)))
# 
# the output of this code snippet will be
#
# [[1.+0.j 2.-3.j]
#  [2.+3.j 1.+0.j]]
#
# as desired!
#

def mat2vec(n):
    # each column of the realmat is a flattened element of the basis
    # of the vector space of all Hermitian matrices of size n x n
    # 
    # in particular, these basis elements handle the *symmetric* or *real*
    # part of the Hermitian matrix, which is why realmat's non-zero entries are 1
    realmat = np.zeros((n**2,(n+1)*n//2),dtype=np.int16)
    for i in range(n):
        for j in range(n):
            row = i*n + j
            if i<=j:
                col = i*n + j - i*(i+1)//2
            else:
                col = j*n + i - j*(j+1)//2
            realmat[row,col]=1

    # here we take care of constant trace!
    # - you will see that we index realmat[-1, j], and
    #   the "-1" is Python shorthand for "last row"
    # - what we are doing here is basically setting up the equation
    #   H_{n,n} = traceconst - H_{1,1} - H_{2,2} - ... - H_{n-1,n-1}
    # - the "traceconst" term is already taken into account via the construction above;
    #   before we run the "while" loop below, realmat[-1,:] will be zero except for one
    #   entry of 1.0 in column number (n+1)*n//2 - 1, which is of course the "magic index"
    #   that we referred to above
    # - the job of the while loop below is to replace some zeros with -1, 
    #   at locations that correspond to diagonal entries of H
    # - you can figure out what indices these are by working through the while loop logic:
    #   - j=0 corresponds to H_{1,1}
    #   - j=n corresponds to H_{2,2} because we do not store the lower-triangular entry H_{2,1}
    #   - j=2*n-1 corresponds to H_{3,3} because we do not store either H_{3,1} or H_{3,2}
    # - try it for n=4 and you will see the relationship if you write down the triangular array
    #    0 1 2 3 
    #      4 5 6 
    #        7 8 
    #          9
    j = 0
    i = 0
    while j < (realmat.shape[1] - 1):
        realmat[-1,j] = -1
        j += n - i*1
        i += 1
    
    # each column of the imagmat is a flattened element of the basis
    # of the vector space of all Hermitian matrices of size n x n
    #
    # in particular, these basis elements handle the *antisymmetric* or *imaginary*
    # part of the Hermitian matrix, which is why imagmat has both 1 and -1 entries
    imagmat = np.zeros((n**2,(n-1)*n//2),dtype=np.int16)
    for i in range(n):
        for j in range(n):
            row = i*n + j
            if i<j:
                col = i*n + j - (i+1)*(i+2)//2
                imagmat[row,col]=1
            if i>j:
                col = j*n + i - (j+1)*(j+2)//2
                imagmat[row,col]=-1
    symmat = np.hstack([realmat, 1j*imagmat])
    return symmat

smat = mat2vec(drcCI)

In [18]:
# after this code block, upper_mapping will map tuples to integers
# so, row/column indices of the upper-triangular part of a matrix
# *including the diagonal*
# get mapped to flattened vectorial indices
#
# philosophical note: as we include the diagonal here,
# we can think of this as mapping tuples to indices 
# for the **real** part of a Hermitian matrix
#
upper_mapping = {}
k = 0
for i in range(drcCI):
    for j in range(i, drcCI):
        upper_mapping[(i,j)] = k
        k += 1

# after this code block, upper_mapping will map tuples to integers
# so, row/column indices of the upper-triangular part of a matrix
# *excluding the diagonal*
# get mapped to flattened vectorial indices
#
# philosophical note: as we exclude the diagonal here,
# we can think of this as mapping tuples to indices 
# for the **imaginary** part of a Hermitian matrix
#
upper_wo_diag_mapping = {}
k = drcCI*(drcCI+1)//2
for i in range(drcCI):
    for j in range(i+1, drcCI):
        upper_wo_diag_mapping[(i,j)] = k
        k += 1

# if we have done this correctly, the following two numbers should agree,
# because there a complex Hermitian drcCI x drcCI matrix is determined by
# a total of drcCI**2 real numbers
# print(k)
# print(drcCI**2)

In [19]:
# this part here is problem-dependent!
# 
# zero_cols = list of columns of the *full CI density matrix* that are identically zero
# 
# good_cols = all other columns
# 
if mol=='h2':
    if basis=='sto-3g':
        good_cols = np.array([ 0,  2,  3])
        zero_cols = np.array([ 1])
    elif basis=='6-31g':
        good_cols = np.array([ 0,  2,  4,  5,  7,  9,  11,  12,  14,  15])
        zero_cols = np.array([ 1,  3,  6,  8,  10,  13])
elif mol=='heh+':
    if basis=='sto-3g':
        good_cols = np.array([ 0,  2,  3])
        zero_cols = np.array([ 1])
    elif basis=='6-31g':
        good_cols = np.array([ 0,  2,  4,  6,  7,  9,  11,  12,  14,  15])
        zero_cols = np.array([ 1,  3,  5,  8,  10,  13])

In [20]:
# note: because the full CI density matrix is
#       \mathbf{a}(t) \mathbf{a}(t)^\dagger,
#       if a column is identically zero then so is the corresponding row!
#       in short, this happend because a particular *entry* of the \mathbf{a}(t)
#       vector is itself identically zero (i.e., zero for all time t)
#
# another note: the words "upper" and "lower" below originate in the ordering of indices
# 
# in the vectorial real representation of a complex n x n Hermitian matrix,
# we FIRST have n*(n+1)//2 indices (the "upper" ones) that give us the real part, and
# we THEN have n*(n-1)//2 indices (the "lower" ones) that gives us the imaginary part

bad_inds = []
del_upper = 0

# loop over the (i,j) tuples in the upper_mapping dictionary
for key in list(upper_mapping.keys()):
    # loop over the zero columns defined above
    for c in zero_cols:
        # if **either** the row i or column j (in the (i,j) tuple) equals the zero column c,
        # then the vectorized index upper_mapping[(i,j)] is a "bad index";
        # for each such match, we increment our count of deleted "upper" indices
        if c in key:
            bad_inds.append(upper_mapping[key])
            del_upper+=1
            break

del_lower = 0

# loop over the (i,j) tuples in the upper_wo_diag_mapping dictionary
for key in list(upper_wo_diag_mapping.keys()):
    # loop over the zero columns defined above
    for c in zero_cols:
        # if **either** the row i or column j (in the (i,j) tuple) equals the zero column c,
        # then the vectorized index upper_mapping[(i,j)] is a "bad index";
        # for each such match, we increment our count of deleted "lower" indices
        if c in key:
            bad_inds.append(upper_wo_diag_mapping[key])
            del_lower+=1
            break


In [30]:
# before we start deleting entries from our real representation,
# this is the "magic index" that enforces the constant trace condition
magicind = drcCI*(drcCI+1)//2 - 1

# we form an array of good indices by start with 
# a vector of all indices from 0 to drcCI**2 - 1 
# and then deleting the bad indices
good_inds = np.delete(np.arange(0,drcCI**2),bad_inds)

# for convenience, we define an array consisting of all good indices
# not including the magic index
good_inds_del = good_inds[good_inds!=magicind]

# since we have deleted "del_upper" number of entries from the 
# upper triangular portion of our real representation, we must shift
# the magic index by that amount
split = magicind - del_upper

In [31]:
# these next lines of code form a sanity check
# if we combine the good and bad indices, and then sort them,
# we should end up with a vector of all indices from 0 to drcCI**2 - 1
total_inds = good_inds.tolist()
total_inds.extend(bad_inds)
total_inds.sort()
print( all(total_inds == np.arange(0,drcCI**2)) )

True


In [32]:
# copy arrays to GPU
bigtensJNP = jnp.array(bigtens)
bigtensTJNP = jnp.array(bigtens.T)
allpropJNP = jnp.array(allprop)
smatJNP = jnp.array(smat)

In [33]:
def firststep(rdmAO):
    j = ell
    bigmat = []
    bigmat.append( bigtensTJNP )
    allCmatT = []
    for i in range(1,ell+1):
        myexp = allpropJNP[j-i,:,:]
        if i==1:
            Cmat = myexp
        else:
            Cmat = Cmat @ myexp
        # note that transpose
        allCmatT.append(Cmat.T)
    
    newstack = jnp.stack(allCmatT)
    
    for i in range(1, ell+1):
        CmatT = newstack[i-1]
        Amat = CmatT.conj()
        bigmat.append( bigtensTJNP @ jnp.kron( CmatT, Amat ) )

    bigmat = jnp.concatenate(bigmat,axis=0)
    
    btrue = jnp.flipud(rdmAO[j-ell:(j+1),:]).reshape((-1))
    mprime = bigmat @ smatJNP

    # reconstruct full TDCI density
    xxapprox = jnp.real( jnp.linalg.pinv(mprime[:,good_inds_del],1e-12) @ (btrue - bigmat[:,-1]) )
    xxapprox2 = jnp.concatenate([xxapprox[:split],jnp.array([1.0]),-xxapprox[split:]])
    
    recon = (smatJNP[:,good_inds] @ xxapprox2).reshape((drcCI,drcCI))
    assert (recon == recon.conjugate().transpose()).all()

    # propagate in full TDCI density space via one step of MMUT!
    reconprop = allpropJNP[j,:,:] @ recon @ allpropJNP[j,:,:].conj().T

    # compute new rdm
    nextrdmAO = (reconprop.reshape((-1)) @ bigtensJNP).conj()
    return nextrdmAO, newstack

In [34]:
def loopbody(j, intup):
    rdmAO, oldstack, sv, residuals = intup
    #cond_num = []
    bigmat = []
    bigmat.append( bigtensTJNP )
    
    # note that the "icb" and "ida" here means that we are storing Cmat.T
    allpropJNPds = lax.dynamic_slice(allpropJNP,[j-1-ell,0,0],[ell,drcCI,drcCI])
    newstack = jnp.einsum('ab,icb,idc->ida',
                         allpropJNP[j-1,:,:],
                         oldstack,
                         jnp.flipud(allpropJNPds).conj(),optimize=True)

    for i in range(1, ell+1):
        CmatT = newstack[i-1]
        Amat = CmatT.conj()
        bigmat.append( bigtensTJNP @ jnp.kron( CmatT, Amat ) )

    bigmat = jnp.concatenate(bigmat,axis=0)
    rdmAOds = lax.dynamic_slice(rdmAO,[j-ell,0],[ell+1,drc**2])
    btrue = jnp.flipud(rdmAOds).reshape((-1))
    mprime = bigmat @ smatJNP
    ss = jnp.linalg.svd(mprime[:,good_inds_del], compute_uv=False)
    
    # reconstruct full TDCI density
    # xxapprox = jnp.real( jnp.linalg.pinv(mprime[:,good_inds_del],1e-12) @ (btrue - bigmat[:,-1]) )
    xxapproxtup = jnp.linalg.lstsq(mprime[:,good_inds_del], btrue - bigmat[:,-1], 1e-12 )
    xxapprox = jnp.real( xxapproxtup[0] )
    xxapprox2 = jnp.concatenate([xxapprox[:split],jnp.array([1.0]),-xxapprox[split:]])
    
    recon = (smatJNP[:,good_inds] @ xxapprox2).reshape((drcCI,drcCI))
    
    # propagate in full TDCI density space via one step of MMUT!
    reconprop = allpropJNP[j,:,:] @ recon @ allpropJNP[j,:,:].conj().T

    # compute new rdm
    rdmAO = rdmAO.at[j+1].set( (reconprop.reshape((-1)) @ bigtensJNP).conj() )
    sv = sv.at[j].set(jnp.max(ss)/jnp.min(ss))
    residuals = residuals.at[j].set(jnp.linalg.norm(mprime[:,good_inds_del] @ xxapprox - (btrue - bigmat[:,-1])))
    #sv.append(jnp.max(ss)/jnp.min(ss))
    return (rdmAO, newstack, sv, residuals)

In [35]:
# initialize empty lists to store things
rdmAOs = []
cond_nums = []
MSEs = []
residuals = []

In [36]:
# JIT compilation
jloopbody = jit(loopbody)

In [37]:
# which time delays do we want to consider
ells = np.arange(70,72,2)

In [38]:
for ell in ells:
    start = time.time()
    
    numsteps = newrdmAO.shape[0]

    myrdmAOinitblock = jnp.transpose(newrdmAO[:ell+1,:,:],(0,2,1)).reshape((-1,drc**2))
    firstnewrdmAO, newstack = firststep(myrdmAOinitblock)

    myrdmAO = jnp.concatenate([myrdmAOinitblock, jnp.expand_dims(firstnewrdmAO,0), 
                               jnp.zeros((numsteps-(ell+2), drc**2), dtype=np.complex128)])
    sv = jnp.zeros((numsteps-1-ell-1))
    residual = jnp.zeros((numsteps-1-ell-1))
    outtup = lax.fori_loop(ell+1,numsteps-1,jloopbody,(myrdmAO,newstack,sv,residual))
    myrdmAO = outtup[0]
    MSE = jnp.mean(jnp.square(myrdmAO[ell+1:].reshape((-1,drc,drc)).conj() - newrdmAO[ell+1:]))
    MSEs.append(MSE)
    cond_num = outtup[2]
    myresidual = outtup[3]
    end = time.time()
    print(end-start)
    rdmAOs.append(myrdmAO)
    cond_nums.append(cond_num)
    residuals.append(myresidual[-1])
    print(MSEs[-1])
    print(residuals[-1])

46.29167413711548
(7.830718994836973e-13+5.665857986325539e-25j)
2.0627426044773962e-12


In [39]:
print(jnp.mean(jnp.square(myrdmAO[ell+1:].reshape((-1,drc,drc)).conj() - newrdmAO[ell+1:])))

(7.830718994836973e-13+5.665857986325539e-25j)
