In [1]:
import os
# Disable XLA's up-front GPU memory preallocation. This must be in the
# environment before jax initializes its backend, hence before the import below.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"

import jax
# Enable 64-bit floats/complex. Later cells explicitly allocate complex128
# arrays; without x64 these would silently be downcast to complex64.
jax.config.update("jax_enable_x64", True)

In [2]:
import jax.numpy as jnp
from jax import jit, grad, jacobian, lax, vmap

import numpy as np
import matplotlib.pyplot as plt

In [3]:
import time
import scipy.optimize as so

In [4]:
# Piecewise-constant propagation grid and penalty weight.
numsteps = 2000                   # number of propagation steps
dt = 1e-2                         # step size
tvec = jnp.arange(numsteps) * dt  # control sample times t_k = k*dt
rho = 1e6                         # weight of the terminal-state penalty in cost()

In [5]:
mol = 'heh+'

# basis = 'sto-3g'
# prefix = 'casscf22_s2_'

basis = '6-31g'
prefix = 'casscf24_s15_'

# if basis=='sto-3g':
#     prefix = 'casscf22_s2_'
# elif basis=='6-31g':
#     prefix = 'casscf24_s15_'

In [6]:
# load Hamiltonian
h0 = np.load('./data/'+prefix+mol+'_'+basis+'_hamiltonian.npz')
n = h0.shape[0]

# load dipole moment matrix
m = np.load('./data/'+prefix+mol+'_'+basis+'_CI_dimat.npz')

# load initial and final states
P0T = np.load('./data/'+mol+'_'+basis+'_P0T.npz')
thisalpha = jnp.array(P0T['alpha'])
thisbeta = jnp.array(P0T['beta'])

print("alpha = " + str(thisalpha))
print("beta = " + str(thisbeta))

alpha = [1. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
beta = [0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1.]


In [7]:
# (d/dx) \exp(-1j*dt*(h0 + x m))
# where you pass in the eigenvectors and eigenvalues of (h0 + x m)
def firstderiv(evecs, evals, degtol=1e-6):
    """Derivative of the one-step propagator exp(-1j*dt*(h0 + x*m)) w.r.t. x.

    Works in the eigenbasis V of H = h0 + x*m. With d = -1j*dt*evals and
    A = V^H (-1j*dt*m) V, the derivative is V (A * L) V^H. Here L[i, j] is the
    divided difference (exp(d_j) - exp(d_i)) / (d_j - d_i), with
    L[i, i] = exp(d_i).

    Args:
        evecs: eigenvector matrix of h0 + x*m (columns are eigenvectors).
        evals: eigenvalues of h0 + x*m.
        degtol: pairs with |d_j - d_i| < degtol (the diagonal included) are
            treated as degenerate. For those, L uses the mean of the two
            exponentials, which agrees with the divided difference to
            O((d_j - d_i)^2) and avoids 0/0 and cancellation error.

    Returns:
        The derivative matrix in the original basis.
    """
    amat = evecs.conj().T @ (-1j*dt*m) @ evecs
    dvec = -1j*dt*evals
    expd = jnp.exp(dvec)
    # default 'xy' meshgrid: dj[i, j] = d_j and di[i, j] = d_i
    dj, di = jnp.meshgrid(dvec, dvec)
    ej, ei = jnp.meshgrid(expd, expd)
    diff = dj - di
    close = jnp.abs(diff) < degtol
    # keep the denominator non-zero in the branch that jnp.where discards
    safediff = jnp.where(close, 1.0, diff)
    derivmat = jnp.where(close, 0.5*(ej + ei), (ej - ei)/safediff)
    qmat = evecs @ (amat * derivmat) @ evecs.conj().T
    return qmat

In [8]:
# Off-diagonal mask: zeros on the diagonal, ones elsewhere. secondderiv below
# reads it as a module-level global.
mask = jnp.ones((n,n)) - jnp.eye(n)

# (d^2/dx^2) \exp(-1j*dt*(h0 + x m))
# where you pass in the eigenvectors and eigenvalues of (h0 + x m)
def secondderiv(evecs, rawevals):
    """Second derivative of exp(-1j*dt*(h0 + x*m)) with respect to x.

    Works in the eigenbasis of h0 + x*m. With d = -1j*dt*rawevals and
    a = (V^H m V) * (-1j*dt), the matrix assembled in `udagru` has entries
        R[i,k] = 2 * sum_j a[i,j] * a[j,k] * exp[d_i, d_j, d_k],
    where exp[., ., .] is the second divided difference of exp. The sum is
    split by which of i, j, k coincide (see the comments below). The result
    is rotated back with evecs. Reads the globals dt, m, n and mask.

    NOTE(review): the 1/(d_i - d_j) factors are guarded only on the diagonal
    (via `+ jnp.eye(n)`); exactly degenerate eigenvalues at i != j would
    produce inf/nan.
    """
    evals = (-1j*dt)*rawevals
    a = (evecs.conj().T @ m @ evecs) * (-1j*dt)
    expevals = jnp.exp(evals)
    # 'ij' indexing: X1[i, k] takes the i-th value, X2[i, k] the k-th value
    evals1, evals2 = jnp.meshgrid(evals, evals, indexing='ij')
    expevals1, expevals2 = jnp.meshgrid(expevals, expevals, indexing='ij')
    # first D_{ii}=D_{kk} term
    # (i = j = k: 2 * a_ii^2 * exp(d_i)/2 = a_ii^2 * exp(d_i), on the diagonal)
    diagterm1 = expevals1*jnp.diag(jnp.diag(a*a))
    # second D_{ii}=D_{kk} term
    # numer1/denom1 at [i, j] is exp[d_i, d_j, d_i] for i != j (0/1 on the diagonal)
    numer1 = -expevals1 + evals1*expevals1 - evals2*expevals1 + expevals2
    denom1 = (evals1-evals2)**2 + jnp.eye(n)
    maska = mask * a
    # i = k, j != i: 2 * sum_j exp[d_i, d_j, d_i] * a_ij * a_ji, diagonal only
    diagterm2 = jnp.eye(n) * 2*((numer1/denom1 * maska) @ maska)
    # first D_{ii}!=D_{kk} term
    # j = i, i != k. Entry [i, k] of (jnp.diag(a)*a).conj().T is
    # conj(a_ii * a_ki), which equals a_ii * a_ik only because a is
    # anti-Hermitian -- assumes m is Hermitian (TODO confirm).
    frac1 = numer1/denom1 * mask
    term1 = frac1*2*(jnp.diag(a)*a).conj().T
    # second D_{ii}!=D_{kk} term
    # j = k, i != k. numer2/denom2 at [i, k] is MINUS exp[d_i, d_k, d_k],
    # which is why term2 is subtracted when udagru is assembled.
    numer2 = -expevals1 + evals1*expevals2 - evals2*expevals2 + expevals2
    denom2 = (evals1-evals2)**2 + jnp.eye(n)
    frac2 = numer2/denom2 * mask
    term2 = frac2*2*(a*jnp.diag(a))
    # third D_{ii}!=D_{kk} term
    # i, j, k pairwise distinct. matij[i, j] = 1/(d_i - d_j) off the diagonal.
    # The three updates below each add MINUS one exponential's share of
    # 2 * sum_j a_ij a_jk exp[d_i, d_j, d_k] (exp(d_j), exp(d_i), exp(d_k)
    # respectively), hence term3 is subtracted in udagru.
    matij = mask*(1.0/((evals1-evals2) + jnp.eye(n)))
    matind1a = (expevals1 * matij) * a
    matind2a = (expevals2 * matij) * a
    term3 = 2*mask*((matind2a) @ (matij*a))
    term3 -= 2*matij*( matind1a @ maska )
    term3 -= 2*matij*( maska @ matind2a ) 
    # put it all together
    # udagru stands for "U^{\dagger} R U"
    udagru = term1 - term2 - term3 + diagterm1 + diagterm2
    return evecs @ udagru @ evecs.conj().T

In [9]:
def gradal(l,expderiv,a,matexp):
    """Propagate a perturbation introduced at time index l through the trajectory.

    Returns G with len(matexp) rows, where
      G[k] = 0                      for k < l,
      G[k] = expderiv @ a           for k == l,
      G[k] = matexp[k] @ G[k-1]     for k > l.
    In adjhess this is the derivative of the state after step k with respect
    to the control value f[l].

    Args:
        l: integer time index (may be a traced value, e.g. under vmap).
        expderiv: derivative of the step-l propagator with respect to f[l].
        a: state vector at step l.
        matexp: stack of one-step propagators, indexed by time step.

    Returns:
        Complex array of shape (len(matexp), len(a)).
    """
    ea = (expderiv @ a).astype(jnp.complex128)
    zero = jnp.zeros_like(ea)
    first = jnp.where(l == 0, ea, zero)

    # lax.scan keeps this as one loop primitive. A Python loop here is unrolled
    # into len(matexp) separate ops when traced under jit/vmap.
    def step(prev, k):
        nxt = jnp.where(k < l, zero, jnp.where(k == l, ea, matexp[k] @ prev))
        return nxt, nxt

    _, rest = lax.scan(step, first, jnp.arange(1, matexp.shape[0]))
    return jnp.concatenate([first[None], rest], axis=0)

In [10]:
def onematexp(evecs,expevals):
    """Rebuild V diag(w) V^H from eigenvectors V (columns) and values w.

    Scaling the columns of V by w is equivalent to V @ diag(w), but avoids
    materializing the diagonal matrix and an extra dense matrix product.
    Leading batch dimensions on both arguments are supported.
    """
    return (evecs * expevals[..., None, :]) @ jnp.swapaxes(evecs.conj(), -1, -2)

In [11]:
# Versions of the per-step helpers batched over the leading (time) axis.
manyeigh = vmap(jnp.linalg.eigh)                 # stack of Hamiltonians -> (evals, evecs)
vonematexp = vmap(onematexp)                     # stack of one-step propagators
vfd = vmap(firstderiv, in_axes=(0, 0))           # first-derivative propagators
vsd = vmap(secondderiv, in_axes=(0, 0))          # second-derivative propagators
vgradal = vmap(gradal, in_axes=(0, 0, 0, None))  # matexp is shared, not batched

In [12]:
# freqvec = jnp.array([0.62831853, 1.25663706, 1.88495559, 2.51327412, 3.14159265, 3.76991118, 
#                      4.39822972, 5.02654825, 5.65486678])
# nf = freqvec.shape[0]
# numparams = 1 + 2*nf
# def fmodelraw(theta, t):
#     return theta[0] + jnp.sum(theta[1:(1+nf)]*jnp.sin(freqvec*t)) + jnp.sum(theta[(1+nf):(1+2*nf)]*jnp.cos(freqvec*t))

In [13]:
# PARAMETERS THAT DEFINE NEURAL NET CONTROL (FIELD STRENGTH)
layerwidths = [1, 4, 4, 4, 1]
nlayers = len(layerwidths)-1

# each layer has a fan_in x fan_out weight matrix plus fan_out biases
numweights = sum(fin*fout for fin, fout in zip(layerwidths[:-1], layerwidths[1:]))
numparams = numweights + sum(layerwidths[1:])

print("number of neural network parameters = " + str(numparams))

def fmodelraw(theta, t):
    """Scalar control value f(t) produced by a small fully connected network.

    theta is a flat vector holding all weight matrices first (layer by layer,
    each reshaped to (fan_in, fan_out)), followed by all bias vectors (layer
    by layer). The scalar time t feeds the first layer, which uses softplus.
    Later hidden layers use softplus, except the one right before the linear
    output layer, which uses sin.
    """
    cursor = 0
    weights = []
    for fin, fout in zip(layerwidths[:-1], layerwidths[1:]):
        weights.append(theta[cursor:cursor + fin*fout].reshape((fin, fout)))
        cursor += fin*fout
    biases = []
    for fout in layerwidths[1:]:
        biases.append(theta[cursor:cursor + fout])
        cursor += fout

    hidden = jax.nn.softplus(t * weights[0] + biases[0])
    for layer in range(1, nlayers-1):
        act = jnp.sin if layer == nlayers-2 else jax.nn.softplus
        hidden = act(hidden @ weights[layer] + biases[layer])

    out = hidden @ weights[nlayers-1] + biases[nlayers-1]
    return out[0, 0]

number of neural network parameters = 53


In [14]:
# WEIGHT INITIALIZATION (named "xavier"; see the docstring for the actual scale)
def xavier(rng=None):
    """Random initial parameter vector in the layout fmodelraw expects.

    Each layer's weights are drawn i.i.d. uniform on
    [-1/sqrt(fan_in), 1/sqrt(fan_in)]. Note that this is a fan-in-only scale,
    not the Glorot/Xavier bound sqrt(6/(fan_in + fan_out)). All biases are
    zero. Weights come first, then biases.

    Args:
        rng: optional numpy Generator/RandomState for reproducible draws.
            When None (the default), numpy's global random state is used,
            exactly as before.

    Returns:
        1-D float array of length numparams.
    """
    draw = np.random if rng is None else rng
    blocks = []
    for i in range(nlayers):
        bound = 1.0/np.sqrt(layerwidths[i])
        blocks.append(draw.uniform(low=-bound, high=bound, size=layerwidths[i]*layerwidths[i+1]))
    blocks.append(np.zeros(numparams-numweights))
    return np.concatenate(blocks)

In [15]:
# Derivatives of the scalar control with respect to theta, at a single time.
gradfraw = jacobian(fmodelraw)   # d f / d theta
hessfraw = jacobian(gradfraw)    # d^2 f / d theta^2

# Vectorize over the time vector; theta is shared by every time sample.
fmodel = vmap(fmodelraw, in_axes=(None, 0))
gradf = vmap(gradfraw, in_axes=(None, 0))
hessf = vmap(hessfraw, in_axes=(None, 0))

In [16]:
# given initial condition and forcing f, return trajectory a
def propSchro(theta, a0):
    """Propagate a0 through numsteps piecewise-constant steps under control theta.

    Step k applies exp(-1j*dt*(h0 + f(t_k)*m)), built from an eigendecomposition.
    Returns an array of numsteps+1 complex states: row 0 is a0 and row k+1 is
    the state after step k.
    """
    fvals = fmodel(theta, tvec)
    hams = jnp.asarray(h0)[None] + fvals[:, None, None]*jnp.asarray(m)[None]
    lam, vecs = manyeigh(hams)
    steps = vonematexp(vecs, jnp.exp(-1j*dt*lam))

    traj0 = jnp.concatenate([a0[None], jnp.zeros((numsteps, n), dtype=jnp.complex128)])

    def advance(k, traj):
        return traj.at[k + 1].set(steps[k] @ traj[k])

    # forward trajectory
    return lax.fori_loop(0, numsteps, advance, traj0)

# given forcing f, IC a0, FC alpha, return cost
def cost(theta, a0, alpha):
    """Objective 0.5*sum_k f(t_k)^2 + 0.5*rho*||a_final - alpha||^2."""
    final = propSchro(theta, a0)[-1]
    err = final - alpha
    misfit = jnp.real(jnp.sum(err * err.conj()))
    effort = 0.5*jnp.sum(fmodel(theta, tvec)**2)
    return effort + 0.5*rho*misfit

In [17]:
# adjoint method
def adjgrad(theta, a0, alpha):
    """Gradient of cost(theta, a0, alpha) by the discrete adjoint method.

    The forward sweep records the states a_0 .. a_{N-1} and the final state.
    The backward sweep produces the conjugated adjoints
        lam_{k+1} = rho*conj(a_N - alpha)^T U_{N-1} ... U_{k+1}.
    The penalty gradient is Re(sum_k lam_{k+1} (dU_k/df_k) a_k df_k/dtheta).
    Returns a real vector with one entry per parameter.
    """
    ctrl = fmodel(theta, tvec)
    dctrl = gradf(theta, tvec)
    hams = jnp.asarray(h0)[None] + ctrl[:, None, None]*jnp.asarray(m)[None]
    evals, evecs = manyeigh(hams)
    props = vonematexp(evecs, jnp.exp(-1j*dt*evals))

    # forward sweep: emit the state *before* each step, i.e. a_0 .. a_{N-1}
    def fwd(state, U):
        return U @ state, state
    final, states = lax.scan(fwd, jnp.asarray(a0, dtype=jnp.complex128), props)

    # backward sweep over the same propagators; with reverse=True, lams[k]
    # is the conjugated adjoint lam_{k+1} that multiplies step k
    def bwd(lam, U):
        return lam @ U, lam
    _, lams = lax.scan(bwd, rho*(final - alpha).conj(), props, reverse=True)

    derivs = vfd(evecs, evals)
    pen = jnp.einsum('ai,aij,al,aj->l', lams, derivs, dctrl, states)
    return ctrl @ dctrl + jnp.real(pen)

In [18]:
# second-order adjoint method
def adjhess(theta, a0, alpha):
    """Hessian of cost(theta, a0, alpha) via a second-order adjoint method.

    Differentiating the adjoint gradient once more gives four penalty
    contributions, each summed over time steps j:
      term1  -- derivative of the (conjugated) adjoint, carried by mu;
      term2a -- second derivative of the control (hf);
      term2b -- second derivative of the one-step propagator (vsd);
      term3  -- derivative of the state a[j] (grada).
    The control-effort part is gf^T gf + sum_k f_k hf_k.

    Returns a real (numparams, numparams) array.

    NOTE: vgradal materializes a (numsteps, numsteps, n) complex array, so
    memory grows quadratically with numsteps.
    """
    f = fmodel(theta, tvec)
    gf = gradf(theta, tvec)
    hf = hessf(theta, tvec)
    manyhams = jnp.expand_dims(h0,0) + jnp.expand_dims(f,(1,2))*jnp.expand_dims(m,0)
    allevals, allevecs = manyeigh(manyhams)
    expevals = jnp.exp(-1j*dt*allevals)
    matexp = vonematexp(allevecs,expevals)

    a = jnp.concatenate([jnp.expand_dims(a0,0), jnp.zeros((numsteps, n), dtype=jnp.complex128)])
    def amatbody(k, am):
        return am.at[k+1].set( matexp[k] @ am[k] )

    # forward trajectory: a[k+1] = matexp[k] @ a[k]
    a = lax.fori_loop(0, numsteps, amatbody, a)

    # initialize lambda from the terminal residual
    resid = a[-1] - alpha

    # we are storing "lambda conjugate" throughout this calculation
    alllamb = jnp.concatenate([jnp.expand_dims(rho*resid.conj(),0), jnp.zeros((numsteps, n), dtype=jnp.complex128)])
    def lambbody(i, al):
        k = (numsteps-1) - i
        return al.at[i+1].set( al[i] @ matexp[k] )

    # backward trajectory; after the flip,
    # alllamb[k+1] = rho*conj(resid) @ matexp[numsteps-1] @ ... @ matexp[k+1]
    alllamb = lax.fori_loop(0, numsteps, lambbody, alllamb)
    alllamb = jnp.flipud(alllamb)

    # first critical calculation: derivative of matexp[k] w.r.t. f[k]
    allexpderivs = vfd(allevecs, allevals)

    # Sensitivities of the trajectory. vgradal returns [l, k, :] = d a[k+1] / d f[l];
    # after the transpose and contraction with gf, grada[k, p, :] = d a[k+1] / d theta_p.
    # Use the default integer dtype for the time indices: int16 would silently
    # wrap around once numsteps exceeds 32767.
    lvec = jnp.arange(numsteps)
    grada = vgradal(lvec, allexpderivs, a[:-1], matexp)
    grada = jnp.transpose(grada,(1,0,2))
    grada = jnp.einsum('ijk,jl->ilk',grada,gf)

    # create and propagate mu, the theta-derivative of the adjoint;
    # as with lambda, we store and propagate "mu conjugate"
    allmu0 = rho*grada[numsteps-1,:,:].conj()
    allmu = jnp.concatenate([jnp.expand_dims(allmu0,0),
                             jnp.zeros((numsteps, numparams, n), dtype=jnp.complex128)])
    def mubody(kk, amu):
        k = (numsteps-1) - kk
        prevmu1 = amu[kk] @ matexp[k]
        # source term from differentiating matexp[k] (through f[k]) w.r.t. theta
        prevmu2 = jnp.outer(gf[k], alllamb[k+1] @ allexpderivs[k])
        return amu.at[kk+1].set( prevmu1+prevmu2 )

    # backward trajectory
    allmu = lax.fori_loop(0, numsteps, mubody, allmu)
    allmu = jnp.flipud(allmu)

    # second critical calculation: second derivative of matexp[k] w.r.t. f[k]
    allexpderivs2 = vsd(allevecs, allevals)

    # compute Hessian
    # gradapad[j] = d a[j] / d theta (zero for the fixed initial state a[0])
    gradapad = jnp.concatenate([jnp.zeros((1,numparams,n),dtype=jnp.complex128), grada[:-1,:,:]])
    # einsum indices: j -> numsteps; l, m -> numparams; k, a -> n
    term1 = jnp.einsum('jlk,jka,jm,ja->lm',allmu[1:],allexpderivs,gf,a[:-1])
    term2a = jnp.einsum('jk,jka,jlm,ja->lm',alllamb[1:],allexpderivs,hf,a[:-1])
    term2b = jnp.einsum('jk,jka,jl,jm,ja->lm',alllamb[1:],allexpderivs2,gf,gf,a[:-1])
    term3 = jnp.einsum('jk,jka,jm,jla->lm',alllamb[1:],allexpderivs,gf,gradapad)
    pcc = term1 + term2a + term2b + term3
    hcc = jnp.einsum('ai,aj->ij',gf,gf) + jnp.einsum('a,aij->ij',f,hf)
    thehess = hcc + jnp.real(pcc)

    return thehess

In [19]:
# JIT-compiled objective, adjoint gradient and adjoint Hessian.
jcost, jadjgrad, jadjhess = (jit(fn) for fn in (cost, adjgrad, adjhess))

In [20]:
# Initial state: first basis vector. Target state: last basis vector.
mya0, myalpha = jnp.eye(n)[0], jnp.eye(n)[-1]

In [21]:
# Load previously optimized parameters, then report the parameter shape, cost,
# adjoint-gradient norm, and terminal-state error of the resulting trajectory.
# ("14441" in the file name presumably encodes layerwidths [1, 4, 4, 4, 1].)
# The archive is opened in a context manager so its file handle is released.
with np.load('./NNoutput/nnresult_'+mol+'_'+basis+'_14441.npz') as _nnresult:
    thetastar = _nnresult['thetastar']
print(thetastar.shape)
print(cost(thetastar, mya0, myalpha))
print(np.linalg.norm(adjgrad(thetastar, mya0, myalpha)))
traj = propSchro(thetastar, a0=mya0)
print(np.linalg.norm(traj[-1,:]-thisbeta))

(53,)
841.1578342871533
4969.208630834285
0.0011683124078297393


In [22]:
# thetatest = xavier() # jnp.array(0.1*np.random.normal(size=numparams))
# mycost = jcost(thetatest, mya0, myalpha)
# mygrad = jadjgrad(thetatest, mya0, myalpha)
# myhess = jadjhess(thetatest, mya0, myalpha)

In [23]:
# def obj(x):
#     jx = jnp.array(x)
#     return jcost(jx,mya0,myalpha).item()

In [24]:
# def gradobj(x):
#     jx = jnp.array(x)
#     return np.array(jadjgrad(jx,mya0,myalpha))

In [25]:
# def hessobj(x):
#     jx = jnp.array(x)
#     return np.array(jadjhess(jx,mya0,myalpha))

In [26]:
# for j in range(1000):
#     thetatest = jnp.array(0.25*np.random.normal(size=numparams))
#     thisobj = obj(thetatest)
#     if thisobj < 200000:
#         print(thisobj)
#         break

In [27]:
# np.savez('Feb16resultsGreen.npz',thetatest=thetatest,xstarx=xstar.x)

In [28]:
# thetatest = xavier()

# start = time.time()
# xstar = so.minimize(obj, x0=np.array(thetatest), method='trust-constr', jac=gradobj, hess=hessobj,
#                     options={'gtol':1e-16,'xtol':1e-16,'verbose':2,'maxiter':10000})
# end = time.time()
# print(end-start)

# np.savez('nnresult_'+mol+'_'+basis+'_14441.npz',thetastar=xstar.x,thetainit=np.array(thetatest))

In [29]:
# thetastar = xstar.x
# Re-propagate with the loaded optimum; keep a writable NumPy copy for plotting.
traj = propSchro(thetastar, a0=mya0)
trajNP = np.array(traj, copy=True)
print(np.shape(trajNP))

(2001, 16)


In [30]:
# # matplotlib, with Agg to save to disk
# import matplotlib
# matplotlib.use('Agg')

# # set plot font+size
# font = {'weight' : 'bold', 'size' : 16}
# matplotlib.rc('font', **font)

# plt.rcParams['pdf.fonttype'] = 42

In [31]:
# # autogenerate labels
# labels=[]
# for i in range(n):
#     labels.append(r'| $a_{'+str(i+1)+'}(t)$ |')

# # time vector
# plottvec = np.arange(numsteps+1)*dt

# # this only works because we've hard-coded i==0 and i==11 for the case of 2x2 and 6x6 matrices
# labeled = False
# plt.figure(figsize=(9,6))
# for i in range(n):
#     if i==0:
#         plt.plot(plottvec, np.abs(trajNP[:,i]), label=labels[i], color='#d01c8b', zorder=10, linewidth=2)
#     elif i==(n-1):
#         plt.plot(plottvec, np.abs(trajNP[:,i]), label=labels[i], color='#4dac26', zorder=10, linewidth=2)
#     else:
#         plt.plot(plottvec, np.abs(trajNP[:,i]), color='silver')

# plt.legend(loc='upper center', bbox_to_anchor=(.5, 1.12), ncol=3, fancybox=False, shadow=False, frameon=False)

# plt.xlabel('time')
# plt.savefig('NNcontroltraj_'+mol+'_'+basis+'_14441.pdf',bbox_inches = "tight")
# plt.close()

In [32]:
# fm = fmodel(thetastar, plottvec)

In [33]:
# plt.figure(figsize=(8,6))
# plt.plot(plottvec, fm)
# plt.xlabel('time (a.u.)')
# plt.ylabel('control f(t)')
# plt.savefig('NNcontrolsig_'+mol+'_'+basis+'_14441.pdf',bbox_inches = "tight")
# plt.close()

In [34]:
# def cheaphessobj(x):
#     jx = jnp.array(x)
#     return np.array(adjhess(jx,mya0,myalpha))

In [35]:
# np.save('hessevals_'+mol+'_'+basis+'.npy',np.linalg.eigvalsh(cheaphessobj(thetastar)))

In [36]:
# # set plot font+size
# font = {'weight' : 'bold', 'size' : 14}
# matplotlib.rc('font', **font)

# numparams = 53

# plt.figure(figsize=(5,6))
# allevals = np.zeros((2,2,numparams))
# mols = ['heh+', 'h2']
# bases = ['sto-3g', '6-31g']
# for ii in range(2):
#     for jj in range(2):
#         allevals[ii,jj,:] = np.load('./NNoutput/hessevals_'+mols[ii]+'_'+bases[jj]+'.npy')
#         plt.plot(np.arange(numparams)+1,allevals[ii,jj,:],label=mols[ii]+', '+bases[jj])

# # plt.yscale('log')
# plt.yscale('symlog')
# plt.title('Neural network with ' + r"$|\theta| = 53$")
# plt.ylabel('eigenvalues of final Hessian')
# # plt.legend(loc='center right',bbox_to_anchor=(1.75, 0.5))
# plt.legend(loc='upper left')
# plt.grid()
# negexp = np.log(-np.min(allevals))/np.log(10)+1
# posexp = np.log(np.max(allevals))/np.log(10)+1
# plt.yticks(np.concatenate([np.flipud(-10**np.arange(negexp)),10**np.arange(posexp)]))
# plt.xticks([1, 10, 20, 30, 40, 53])
# plt.savefig('NNeigs_14441.pdf',bbox_inches = "tight")
# plt.close()

In [37]:
# np.linalg.eigvalsh(hessobj(thetastar))