In [1]:
import os
# Ask XLA not to preallocate a large share of accelerator memory up front.
# The variable is set before `import jax` so it is already in the environment
# when JAX starts up.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"

import jax
# Enable 64-bit types; the results printed in this notebook are float64.
jax.config.update("jax_enable_x64", True)

In [2]:
import jax.numpy as jnp
from jax import jit, grad, jacobian, lax

import numpy as np
import matplotlib.pyplot as plt

In [3]:
import time

In [4]:
# dimension of the state vector and of the matrices h0 and m
n = 3

# Seed NumPy's global RNG so that h0, m and every random vector drawn in later
# cells are reproducible under Restart Kernel -> Run All.
np.random.seed(0)

# h0: diagonal matrix whose entries are sorted draws from pi * Uniform[0, 1)
h0 = jnp.array(np.diag(np.sort(np.pi*np.random.uniform(size=n))))
# m: real symmetric matrix obtained by symmetrizing an n x n Gaussian matrix
mraw = np.random.normal(size=n**2).reshape((n,n))
m = jnp.array(0.5*(mraw + mraw.T))

# sanity check: both matrices are symmetric, so both printed values should be 0.0
print(jnp.mean(jnp.abs(h0 - h0.T)))
print(jnp.mean(jnp.abs(m - m.T)))

0.0
0.0


In [5]:
# time step that multiplies the Hamiltonian in every one-step propagator exp(-1j*dt*H)
dt = 1e-3

In [6]:
# number of time steps; the propagation loops read f[0], ..., f[numsteps-1]
numsteps = 100

In [7]:
# given initial condition and forcing f, return trajectory a
def propSchro(f, a0):
    """Propagate a0 through numsteps steps of exp(-1j*dt*(h0 + f[k]*m)).

    Each step diagonalizes the Hermitian matrix h0 + f[k]*m with eigh and
    applies the resulting matrix exponential to the current state.

    Returns the stacked states, a0 first, with numsteps+1 rows.
    """
    states = [a0]
    for step in range(numsteps):
        evals, evecs = jnp.linalg.eigh(h0 + f[step]*m)
        phases = jnp.diag(jnp.exp(-1j*dt*evals))
        # one-step propagator assembled from the eigendecomposition
        propagator = evecs @ phases @ evecs.conj().T
        states.append(propagator @ states[-1])

    return jnp.stack(states, axis=0)

In [8]:
# given initial condition and forcing f, return trajectory a
def propSchroRI(f, a0):
    """Same trajectory as propSchro, returned as separate real and imaginary parts.

    The propagation itself is delegated to propSchro so the two functions
    cannot drift apart.

    Returns an array whose leading axis has length 2: index 0 holds the real
    part of the trajectory and index 1 the imaginary part.
    """
    a = propSchro(f, a0)
    return jnp.stack([a.real, a.imag], axis=0)

In [9]:
# autodiff Jacobian of the real/imaginary-split trajectory with respect to f (argument 0)
jaxgrada = jacobian(propSchroRI, 0)

In [10]:
# given forcing f, IC a0, FC alpha, return cost
rho = 1
def cost(f, a0, alpha):
    """Control cost: 0.5*|f|^2 plus 0.5*rho times the squared final-state miss.

    The final state is the last row of propSchro(f, a0); the miss is its
    difference from the target alpha.
    """
    trajectory = propSchro(f, a0)
    miss = trajectory[-1] - alpha
    control_energy = 0.5*jnp.sum(f**2)
    # sum of miss * conj(miss) is real up to rounding; keep only the real part
    terminal_penalty = jnp.real(jnp.sum(miss * miss.conj()))
    return control_energy + 0.5*rho*terminal_penalty

In [11]:
# initial condition: first row of the n x n identity, i.e. the first standard basis vector
mya0 = jnp.eye(n)[0]
# target final state: real Gaussian random vector of length n
myalpha = jnp.array(np.random.normal(size=n))

In [12]:
# smoke test: propagate under one random forcing, then evaluate the cost for a
# second, independently drawn random forcing (the two calls do not share f)
a = propSchro( jnp.array(np.random.normal(size=numsteps)), mya0 )
cost( jnp.array(np.random.normal(size=numsteps)), mya0, myalpha )

Array(54.44616496, dtype=float64)

In [13]:
# autodiff gradient of cost with respect to the forcing f (argument 0)
gradcost = grad(cost, 0)

In [14]:
# autodiff Hessian of cost with respect to f, built as the Jacobian of gradcost
hesscost = jacobian(gradcost, 0)

In [15]:
# (d/dx) \exp(-1j*dt*(h0 + x m))
# where you pass in the eigenvectors and eigenvalues of (h0 + x m)
def firstderiv(evecs, evals):
    """First derivative of exp(-1j*dt*(h0 + x*m)) with respect to the scalar x.

    Parameters
    ----------
    evecs : eigenvector matrix of h0 + x*m (columns are eigenvectors)
    evals : matching eigenvalues of h0 + x*m

    Returns
    -------
    Square matrix of the same size as m.

    With d = -1j*dt*evals, the derivative in the eigenbasis is the elementwise
    product of the rotated perturbation with the divided-difference matrix
        (exp(d_j) - exp(d_i)) / (d_j - d_i)   where d_i != d_j,
        exp(d_i)                              where d_i == d_j.
    The second case is the limit of the first. It covers the diagonal and also
    any pair of coinciding eigenvalues, which would otherwise give 0/0.
    """
    # perturbation direction -1j*dt*m expressed in the eigenbasis
    amat = evecs.conj().T @ (-1j*dt*m) @ evecs
    dvec = -1j*dt*evals
    expd = jnp.exp(dvec)
    # diff[i, j] = d_j - d_i, built by broadcasting so no global size is needed
    diff = dvec[None, :] - dvec[:, None]
    same = (diff == 0)
    # replace zero denominators by 1 so the quotient never evaluates 0/0
    safediff = jnp.where(same, 1.0, diff)
    quotient = (expd[None, :] - expd[:, None]) / safediff
    derivmat = jnp.where(same, expd[:, None], quotient)
    # rotate back out of the eigenbasis
    qmat = evecs @ (amat * derivmat) @ evecs.conj().T
    return qmat

In [16]:
# adjoint method
rho = 1
def adjgrad(f, a0, alpha):
    """Gradient of cost(f, a0, alpha) with respect to f, via the adjoint method.

    Parameters
    ----------
    f : forcing; entries f[0], ..., f[numsteps-1] are used
    a0 : initial state
    alpha : target final state

    Returns
    -------
    Real array of length numsteps: f plus the real part of the adjoint term
    for each step.

    Each one-step propagator is built once in the forward sweep and reused in
    the backward sweep.
    """
    # forward sweep: keep eigendecompositions, propagators and states
    allevals = []
    allevecs = []
    allprops = []
    a = [a0]
    for k in range(numsteps):
        evals, evecs = jnp.linalg.eigh(h0 + f[k]*m)
        prop = evecs @ jnp.diag(jnp.exp(-1j*dt*evals)) @ evecs.conj().T
        allevals.append( evals )
        allevecs.append( evecs )
        allprops.append( prop )
        a.append( prop @ a[k] )

    # forward trajectory
    a = jnp.stack(a, axis=0)

    # initialize lambda
    resid = a[-1] - alpha
    # we are storing "lambda conjugate" throughout this calculation,
    # as a row vector that is multiplied by the propagator from the left
    alllamb = [rho*resid.conj()]
    for k in range(numsteps-1,-1,-1):
        alllamb.append( alllamb[-1] @ allprops[k] )

    # backward trajectory: after the flip, row j is lambda conjugate at time j
    alllamb = jnp.flipud(jnp.stack(alllamb, axis=0))

    # because we have stored "lambda conjugate", no extra conjugation is needed here
    ourgrad = [
        alllamb[k+1] @ firstderiv(allevecs[k], allevals[k]) @ a[k]
        for k in range(numsteps)
    ]

    return f + jnp.real(jnp.array(ourgrad))

# simple finite-difference gradient checker
def fdgrad(kk, a0, alpha, myeps=1e-6, f=None):
    """Central finite-difference estimate of d cost / d f[kk].

    Parameters
    ----------
    kk : index of the forcing entry to perturb
    a0 : initial state
    alpha : target final state
    myeps : half-width of the central difference
    f : forcing at which to differentiate. When omitted, the notebook-level
        variable `ftest` is used, which must already be defined. This keeps
        existing calls working.
    """
    if f is None:
        f = ftest
    f = jnp.asarray(f)
    # unit vector selecting entry kk, sized to match f
    pertvec = jnp.eye(f.shape[0])[kk]
    cplus = cost( f + myeps*pertvec, a0, alpha )
    cminus = cost( f - myeps*pertvec, a0, alpha )
    return (cplus - cminus)/(2*myeps)

In [17]:
# random forcing used as the test point for the gradient and Hessian comparisons
ftest = jnp.array(np.random.normal(size=numsteps))
# gradient from the hand-written adjoint method
adjres = adjgrad( ftest, mya0, myalpha )
# print(adjres)
# gradient from JAX autodiff
jaxres = gradcost( ftest, mya0, myalpha )
# print(jaxres)
# mean absolute difference between the two gradients
print(jnp.mean(jnp.abs(adjres - jaxres)))

  x_bar = _convert_element_type(x_bar, x.aval.dtype, x.aval.weak_type)


2.633310236532793e-17


In [18]:
def myexp(x):
    """Return exp(-1j*dt*(h0 + x*m)) flattened and split into real/imaginary rows.

    The matrix exponential is built from the eigendecomposition of h0 + x*m.
    Row 0 of the result is the flattened real part and row 1 the flattened
    imaginary part.
    """
    evals, evecs = jnp.linalg.eigh(h0 + x*m)
    phases = jnp.diag(jnp.exp(-1j*dt*evals))
    flat = (evecs @ phases @ evecs.conj().T).reshape((-1))
    return jnp.stack([flat.real, flat.imag])

In [19]:
# module-level off-diagonal mask; myhess builds its own local `mask` and does
# not read this one
mask = jnp.ones((n,n)) - jnp.eye(n)

def myhess(x):
    """Second derivative of exp(-1j*dt*(h0 + x*m)) with respect to the scalar x.

    The eigendecomposition of h0 + x*m is computed here. The derivative is
    assembled in the eigenbasis (`udagru`) and rotated back with the
    eigenvectors before being returned as an n x n complex matrix.

    NOTE(review): the off-diagonal denominators below are powers of eigenvalue
    differences, so they vanish if two eigenvalues coincide; that case is not
    handled.
    """
    z = h0 + x*m
    evals, evecs = jnp.linalg.eigh(z)
    # rebind evals to the scaled (complex) values -1j*dt*evals
    evals *= (-1j*dt)
    # perturbation m in the eigenbasis, carrying the same -1j*dt factor
    a = (evecs.conj().T @ m @ evecs) * (-1j*dt)
    expevals = jnp.exp(evals)
    # with indexing='ij': evals1[i,j] = evals[i] and evals2[i,j] = evals[j]
    evals1, evals2 = jnp.meshgrid(evals, evals, indexing='ij')
    # ones off the diagonal, zeros on it
    mask = jnp.ones((n,n)) - jnp.eye(n)
    expevals1, expevals2 = jnp.meshgrid(expevals, expevals, indexing='ij')
    # first D_{ii}=D_{kk} term
    # a*a is elementwise; diag(diag(.)) keeps only its diagonal
    diagterm1 = expevals1*jnp.diag(jnp.diag(a*a))
    # second D_{ii}=D_{kk} term
    numer1 = -expevals1 + evals1*expevals1 - evals2*expevals1 + expevals2
    # + eye(n) keeps the diagonal of the denominator nonzero; those entries are
    # multiplied by a zero from mask/maska wherever the quotient is used
    denom1 = (evals1-evals2)**2 + jnp.eye(n)
    # a with its diagonal zeroed
    maska = mask * a
    # multiplying by the identity keeps only the diagonal of the matrix product
    diagterm2 = np.eye(n) * 2*((numer1/denom1 * maska) @ maska)
    # first D_{ii}!=D_{kk} term
    frac1 = numer1/denom1 * mask
    # jnp.diag(a) is the 1-D diagonal of a, broadcast against a before the transpose
    term1 = frac1*2*(jnp.diag(a)*a).T
    # second D_{ii}!=D_{kk} term
    numer2 = -expevals1 + evals1*expevals2 - evals2*expevals2 + expevals2
    # same expression as denom1
    denom2 = (evals1-evals2)**2 + jnp.eye(n)
    frac2 = numer2/denom2 * mask
    term2 = frac2*2*(a*jnp.diag(a))
    # third D_{ii}!=D_{kk} term
    # matij[i,j] = 1/(evals[i]-evals[j]) off the diagonal and 0 on it
    matij = mask*(1.0/((evals1-evals2) + jnp.eye(n)))
    matind1a = (expevals1 * matij) * a
    matind2a = (expevals2 * matij) * a
    term3 = 2*mask*((matind2a) @ (matij*a))
    term3 -= 2*matij*( matind1a @ maska )
    term3 -= 2*matij*( maska @ matind2a ) 
    # put it all together
    # udagru stands for "U^{\dagger} R U"
    udagru = term1 - term2 - term3 + diagterm1 + diagterm2
    # rotate from the eigenbasis back to the original basis
    return evecs @ udagru @ evecs.conj().T

In [20]:
# evaluate myexp once, then rebuild the complex n x n matrix from its
# real (row 0) and imaginary (row 1) parts
uri = myexp(2.0)
u = (uri[0,:] + uri[1,:]*1j).reshape((n,n))

In [21]:
# unitarity check: mean absolute deviation of u u^dagger from the identity
jnp.mean(jnp.abs(u @ u.conj().T - jnp.eye(n)))

Array(2.58568317e-16, dtype=float64)

In [22]:
# autodiff first and second derivatives of myexp with respect to its scalar argument
dmyexp = jacobian(myexp)
ddmyexp = jacobian(dmyexp)

In [23]:
# evaluate the second-order autodiff derivative once, then combine its
# real (row 0) and imaginary (row 1) parts into one complex vector
ddri = ddmyexp( 1.5 )
jaxres = ddri[0,:] + ddri[1,:]*1j

  x_bar = _convert_element_type(x_bar, x.aval.dtype, x.aval.weak_type)


In [24]:
# hand-derived second derivative at the same point x = 1.5
myres = myhess( 1.5 )

In [25]:
# mean absolute difference between the autodiff and hand-derived second derivatives
jnp.mean(jnp.abs(jaxres.reshape((n,n)) - myres))

Array(2.32297734e-16, dtype=float64)

In [26]:
# module-level off-diagonal mask; secondderiv builds its own local `mask` and
# does not read this one
mask = jnp.ones((n,n)) - jnp.eye(n)

# (d^2/dx^2) \exp(-1j*dt*(h0 + x m))
# where you pass in the eigenvectors and eigenvalues of (h0 + x m)
def secondderiv(evecs, rawevals):
    """Second derivative of exp(-1j*dt*(h0 + x*m)) with respect to x.

    Parameters
    ----------
    evecs : eigenvector matrix of h0 + x*m
    rawevals : matching (unscaled) eigenvalues of h0 + x*m

    Returns
    -------
    n x n complex matrix: the eigenbasis expression `udagru` rotated back with
    evecs.

    NOTE(review): the off-diagonal denominators are powers of eigenvalue
    differences and vanish if two eigenvalues coincide; that case is not
    handled.
    """
    # scaled eigenvalues, so that exp(evals) are the propagator's eigenvalues
    evals = (-1j*dt)*rawevals
    # perturbation m in the eigenbasis, carrying the same -1j*dt factor
    a = (evecs.conj().T @ m @ evecs) * (-1j*dt)
    expevals = jnp.exp(evals)
    # with indexing='ij': evals1[i,j] = evals[i] and evals2[i,j] = evals[j]
    evals1, evals2 = jnp.meshgrid(evals, evals, indexing='ij')
    # ones off the diagonal, zeros on it
    mask = jnp.ones((n,n)) - jnp.eye(n)
    expevals1, expevals2 = jnp.meshgrid(expevals, expevals, indexing='ij')
    # first D_{ii}=D_{kk} term
    # elementwise square of a, restricted to its diagonal
    diagterm1 = expevals1*jnp.diag(jnp.diag(a*a))
    # second D_{ii}=D_{kk} term
    numer1 = -expevals1 + evals1*expevals1 - evals2*expevals1 + expevals2
    # + eye(n) avoids a zero on the diagonal; the diagonal quotient entries are
    # multiplied by a zero from mask/maska wherever they are used
    denom1 = (evals1-evals2)**2 + jnp.eye(n)
    # a with its diagonal zeroed
    maska = mask * a
    # multiplying by the identity keeps only the diagonal of the matrix product
    diagterm2 = np.eye(n) * 2*((numer1/denom1 * maska) @ maska)
    # first D_{ii}!=D_{kk} term
    frac1 = numer1/denom1 * mask
    # jnp.diag(a) is the 1-D diagonal of a, broadcast against a before the transpose
    term1 = frac1*2*(jnp.diag(a)*a).T
    # second D_{ii}!=D_{kk} term
    numer2 = -expevals1 + evals1*expevals2 - evals2*expevals2 + expevals2
    # same expression as denom1
    denom2 = (evals1-evals2)**2 + jnp.eye(n)
    frac2 = numer2/denom2 * mask
    term2 = frac2*2*(a*jnp.diag(a))
    # third D_{ii}!=D_{kk} term
    # matij[i,j] = 1/(evals[i]-evals[j]) off the diagonal and 0 on it
    matij = mask*(1.0/((evals1-evals2) + jnp.eye(n)))
    matind1a = (expevals1 * matij) * a
    matind2a = (expevals2 * matij) * a
    term3 = 2*mask*((matind2a) @ (matij*a))
    term3 -= 2*matij*( matind1a @ maska )
    term3 -= 2*matij*( maska @ matind2a ) 
    # put it all together
    # udagru stands for "U^{\dagger} R U"
    udagru = term1 - term2 - term3 + diagterm1 + diagterm2
    # rotate from the eigenbasis back to the original basis
    return evecs @ udagru @ evecs.conj().T

In [27]:
# second-order adjoint method
rho = 1
def adjgradhess(f, a0, alpha):
    """Hessian of cost(f, a0, alpha) with respect to f, via a second-order adjoint method.

    Parameters
    ----------
    f : forcing; entries f[0], ..., f[numsteps-1] are used
    a0 : initial state of length n
    alpha : target final state of length n

    Returns
    -------
    Real (numsteps, numsteps) array. Only the Hessian is returned; use adjgrad
    for the gradient.

    Each one-step propagator is built once in the forward sweep and reused by
    the backward sweeps and by the sensitivity recursion.
    """
    # forward sweep: keep eigendecompositions, propagators and states
    allevals = []
    allevecs = []
    allprops = []
    a = [a0]
    for k in range(numsteps):
        evals, evecs = jnp.linalg.eigh(h0 + f[k]*m)
        prop = evecs @ jnp.diag(jnp.exp(-1j*dt*evals)) @ evecs.conj().T
        allevals.append( evals )
        allevecs.append( evecs )
        allprops.append( prop )
        a.append( prop @ a[k] )

    # forward trajectory
    a = jnp.stack(a, axis=0)

    # initialize lambda
    resid = a[-1] - alpha
    # we are storing "lambda conjugate" throughout this calculation,
    # as a row vector that is multiplied by the propagator from the left
    alllamb = [rho*resid.conj()]
    for k in range(numsteps-1,-1,-1):
        alllamb.append( alllamb[-1] @ allprops[k] )

    # backward trajectory: after the flip, row j is lambda conjugate at time j
    alllamb = jnp.flipud(jnp.stack(alllamb, axis=0))

    # first and second derivatives of every one-step propagator w.r.t. its forcing
    allexpderivs = [firstderiv(allevecs[k], allevals[k]) for k in range(numsteps)]
    allexpderivs2 = [secondderiv(allevecs[k], allevals[k]) for k in range(numsteps)]

    # Hessian part of the calculation

    # compute gradient of a w.r.t. f
    # grada[k,l,:] is the derivative of the state a[k+1] with respect to f[l];
    # it is zero for k < l because f[l] cannot affect states up to a[l]
    zerovec = jnp.zeros(n, dtype=jnp.complex128)
    gradcols = []
    for l in range(numsteps):
        gradvec = allexpderivs[l] @ a[l]
        col = [zerovec]*l + [gradvec]
        for k in range(l+1,numsteps):
            gradvec = allprops[k] @ gradvec
            col.append(gradvec)
        gradcols.append(jnp.stack(col, axis=0))
    # stack the per-l columns along axis 1 so the result is indexed [k, l, :]
    grada = jnp.stack(gradcols, axis=1)

    # create and propagate mu
    # as before, let us store and propagate "mu conjugate"
    allmu = [rho*grada[numsteps-1,:,:].conj()]
    for k in range(numsteps-1,-1,-1):
        prevmu = allmu[-1] @ allprops[k]
        # row k picks up an extra source term from the derivative of step k
        prevmu = prevmu.at[k].set( prevmu[k] + alllamb[k+1] @ allexpderivs[k] )
        allmu.append(prevmu)

    # backward trajectory
    allmu = jnp.flipud(jnp.stack(allmu, axis=0))

    # compute Hessian one row at a time
    ourhess = []
    for k in range(numsteps):
        # because we have stored "mu conjugate", no extra conjugation is needed here
        thisrow = jnp.real(allmu[k+1,:,:] @ allexpderivs[k] @ a[k])
        # a[0] = a0 does not depend on f, so this term only exists for k >= 1
        if k >= 1:
            thisrow = thisrow + jnp.real(alllamb[k+1] @ allexpderivs[k] @ grada[k-1,:,:].T)
        # diagonal entry: second derivative of step k's propagator
        thisrow = thisrow.at[k].set( thisrow[k] + jnp.real(alllamb[k+1] @ allexpderivs2[k] @ a[k]) )
        ourhess.append(thisrow)

    # the identity is the Hessian of the 0.5*sum(f**2) part of the cost
    thehess = jnp.eye(numsteps) + jnp.stack(ourhess, axis=0)

    return thehess

In [28]:
# debug with NumPy
# allmuNP = np.zeros((numsteps+1, numsteps, n), dtype=np.complex128)
# allmuNP[numsteps,:,:] = np.array(rho*grada[numsteps-1,:,:])
# for k in range(numsteps-1, -1, -1):
#     allmuNP[k,:,:] = allmuNP[k+1] @ np.array(allevecs[k] @ jnp.diag(jnp.exp(-1j*dt*allevals[k])) @ allevecs[k].conj().T)
#     allmuNP[k,k,:] += np.array(alllamb[k+1].T @ allexpderivs[k])

# jaxresraw = jaxgrada( ftest, mya0 )
# jaxres = jaxresraw[0,:,:] + 1j*jaxresraw[1,:,:]
# jaxres = jnp.transpose(jaxres[1:,:,:], axes=(0,2,1))

In [29]:
# NOTE: this rebinds the name `myhess` from the function defined in an earlier
# cell to the Hessian array, so the earlier `myhess( 1.5 )` cell cannot be
# re-run after this one without redefining the function.
start = time.perf_counter()
# JAX dispatches work asynchronously; wait for the result before stopping the clock
myhess = adjgradhess(ftest, mya0, myalpha).block_until_ready()
end = time.perf_counter()
print(end-start)

9.149816274642944


In [30]:
# symmetry check: mean absolute difference between the adjoint Hessian and its transpose
print(jnp.mean(jnp.abs(myhess - myhess.T)))

1.7032138503389472e-21


In [31]:
# time the autodiff Hessian for comparison with the adjoint version
start = time.perf_counter()
# JAX dispatches work asynchronously; wait for the result before stopping the clock
jaxhess = hesscost(ftest, mya0, myalpha).block_until_ready()
end = time.perf_counter()
print(end-start)

  x_bar = _convert_element_type(x_bar, x.aval.dtype, x.aval.weak_type)
  x_bar = _convert_element_type(x_bar, x.aval.dtype, x.aval.weak_type)


10.308220624923706


In [32]:
# mean absolute difference between the adjoint-method Hessian and the autodiff Hessian
jnp.mean(jnp.abs(myhess-jaxhess))

Array(8.07243871e-19, dtype=float64)