In [1]:
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"

import jax
jax.config.update("jax_enable_x64", True)

In [2]:
import jax.numpy as jnp
from jax import jit, grad, jacobian, lax

import numpy as np
import matplotlib.pyplot as plt

In [3]:
n = 10

h0 = jnp.array(np.diag(np.sort(np.pi*np.random.uniform(size=n))))
mraw = np.random.normal(size=n**2).reshape((n,n))
m = jnp.array(0.5*(mraw + mraw.T))

print(jnp.mean(jnp.abs(h0 - h0.T)))
print(jnp.mean(jnp.abs(m - m.T)))

0.0
0.0


In [4]:
dt = 1e-3

In [5]:
numsteps = 100

In [6]:
# given initial condition and forcing f, return trajectory a
def propSchro(f, a0):
    a = [a0]
    for k in range(numsteps):
        thisham = h0 + f[k]*m
        evals, evecs = jnp.linalg.eigh(thisham)
        a.append( evecs @ jnp.diag(jnp.exp(-1j*dt*evals)) @ evecs.conj().T @ a[k] )
    
    a = jnp.stack(a, axis=0)
    return a

In [7]:
# given forcing f, IC a0, FC alpha, return cost
rho = 1
def cost(f, a0, alpha):
    a = propSchro(f, a0)
    resid = a[-1] - alpha
    pen = jnp.real(jnp.sum(resid * resid.conj()))
    return 0.5*jnp.sum(f**2) + 0.5*rho*pen

In [8]:
mya0 = jnp.eye(n)[0]

In [9]:
a = propSchro( jnp.array(np.random.normal(size=numsteps)), mya0 )

In [10]:
myalpha = jnp.array(np.random.normal(size=n))

In [11]:
cost( jnp.array(np.random.normal(size=numsteps)), mya0, myalpha )

Array(37.4341355, dtype=float64)

In [12]:
gradcost = grad(cost, 0)

In [13]:
# adjoint method
rho = 1
def adjgrad(f, a0, alpha):
    a = [a0]
    allevals = []
    allevecs = []
    for k in range(numsteps):
        thisham = h0 + f[k]*m
        evals, evecs = jnp.linalg.eigh(thisham)
        allevals.append( evals )
        allevecs.append( evecs )
        a.append( evecs @ jnp.diag(jnp.exp(-1j*dt*evals)) @ evecs.conj().T @ a[k] )
    
    # forward trajectory
    a = jnp.stack(a, axis=0)
    print(a.shape)
    
    # initialize lambda
    resid = a[-1] - alpha
    # we are storing "lambda conjugate" throughout this calculation
    alllamb = [rho*resid.conj()]
    for k in range(numsteps-1,-1,-1):
        kk = (numsteps-1) - k
        alllamb.append( alllamb[kk] @ allevecs[k] @ jnp.diag(jnp.exp(-1j*dt*allevals[k])) @ allevecs[k].conj().T )
    
    # backward trajectory
    alllamb = jnp.flipud(jnp.stack(alllamb, axis=0))
    
    # set up mask
    mask = jnp.ones((n, n)) - jnp.eye(n)
    
    # first critical calculation
    allexpderivs = []
    for k in range(numsteps):
        amat = allevecs[k].conj().T @ m @ allevecs[k]
        dvec = -1j*dt*allevals[k]
        dvec1, dvec2 = jnp.meshgrid(dvec, dvec)
        mask = jnp.ones((n,n)) - jnp.eye(n)
        numer = jnp.exp(dvec1) - jnp.exp(dvec2)
        denom = (dvec1 - dvec2)*mask + jnp.eye(n)
        derivmat = mask*numer/denom + jnp.diag(jnp.exp(dvec))
        qmat = allevecs[k] @ (amat * derivmat) @ allevecs[k].conj().T
        allexpderivs.append( qmat )
    
    # output gradient we want
    ourgrad = []
    for k in range(numsteps):
        # because we have stored "lambda conjugate", we just need transpose here to get the dagger
        ourgrad.append( alllamb[k+1].T @ allexpderivs[k] @ a[k] )
    
    return f + jnp.real((-1j*dt)*jnp.array(ourgrad))

# simple finite-difference gradient checker
def fdgrad(kk, a0, alpha, myeps=1e-6):
    pertvec = jnp.eye(numsteps)[kk]
    cplus = cost( ftest + myeps*pertvec, a0, alpha )
    cminus = cost( ftest - myeps*pertvec, a0, alpha )
    return (cplus - cminus)/(2*myeps)

In [14]:
# UNIT TESTING for DERIVATIVE OF MATRIX EXPONENTIAL
#
# dvec = jnp.array([2.0, 4.5])
# dvec1, dvec2 = jnp.meshgrid(dvec, dvec)
# mask = jnp.ones((n,n)) - jnp.eye(n)
# numer = jnp.exp(dvec1) - jnp.exp(dvec2)
# denom = (dvec1 - dvec2)*mask + jnp.eye(n)
# test = mask*numer/denom + jnp.diag(jnp.exp(dvec))
# print(test)
# 
# npres = np.zeros((n,n))
# for i in range(n):
#     for l in range(n):
#         if dvec[i]==dvec[l]:
#             npres[i,l] = np.exp(dvec[i])
#         else:
#             npres[i,l] = (np.exp(dvec[i]) - np.exp(dvec[l]))/(dvec[i] - dvec[l])
# 
# print(npres)

In [15]:
ftest = jnp.array(np.random.normal(size=numsteps))
adjres = adjgrad( ftest, mya0, myalpha )
# print(adjres)
jaxres = gradcost( ftest, mya0, myalpha )
# print(jaxres)
print(jnp.mean(jnp.abs(adjres - jaxres)))

(101, 10)


  x_bar = _convert_element_type(x_bar, x.aval.dtype, x.aval.weak_type)


5.715393436300787e-16


In [16]:
def myexp(x):
    thisham = h0 + x*m
    evals, evecs = jnp.linalg.eigh(thisham)
    out = evecs @ jnp.diag(jnp.exp(evals)) @ evecs.conj().T
    return out.reshape((-1))

In [17]:
dmyexp = jacobian(myexp)

In [18]:
ddmyexp = jacobian(dmyexp)

In [19]:
mask = jnp.ones((n,n)) - jnp.eye(n)

def myhess(x):
    z = h0 + x*m
    evals, evecs = jnp.linalg.eigh(z)
    a = evecs.conj().T @ m @ evecs
    expevals = jnp.exp(evals)
    evals1, evals2 = jnp.meshgrid(evals, evals, indexing='ij')
    mask = jnp.ones((n,n)) - jnp.eye(n)
    expevals1, expevals2 = jnp.meshgrid(expevals, expevals, indexing='ij')
    # first D_{ii}=D_{kk} term
    diagterm1 = expevals1*jnp.diag(jnp.diag(a*a))
    # second D_{ii}=D_{kk} term
    numer1 = -expevals1 + evals1*expevals1 - evals2*expevals1 + expevals2
    denom1 = (evals1-evals2)**2 + jnp.eye(n)
    maska = mask * a
    diagterm2 = np.eye(n) * 2*((numer1/denom1 * maska) @ maska)
    # first D_{ii}!=D_{kk} term
    frac1 = numer1/denom1 * mask
    term1 = frac1*2*(jnp.diag(a)*a).T
    # second D_{ii}!=D_{kk} term
    numer2 = -expevals1 + evals1*expevals2 - evals2*expevals2 + expevals2
    denom2 = (evals1-evals2)**2 + jnp.eye(n)
    frac2 = numer2/denom2 * mask
    term2 = frac2*2*(a*jnp.diag(a))
    # third D_{ii}!=D_{kk} term
    matij = mask*(1.0/((evals1-evals2) + jnp.eye(n)))
    matind1a = (expevals1 * matij) * a
    matind2a = (expevals2 * matij) * a
    term3 = 2*mask*((matind2a) @ (matij*a))
    term3 -= 2*matij*( matind1a @ maska )
    term3 -= 2*matij*( maska @ matind2a ) 
    # put it all together
    udagru = term1 - term2 - term3 + diagterm1 + diagterm2
    return evecs @ udagru @ evecs.conj().T

In [20]:
myhess(0.25)

Array([[ 2.09734077e+01,  6.75226274e+00,  7.23642597e+00,
         1.57597700e+01, -1.72248924e-01,  3.91927912e+01,
         5.42834015e+01, -1.27183395e+01, -9.14628302e+00,
         1.99314862e+01],
       [ 6.75226274e+00,  4.54955693e+01,  1.75591618e+01,
         4.96467101e+00,  3.39840241e+01,  4.54518975e+00,
        -1.93443627e+01,  4.17957166e+01, -1.27499081e+01,
        -5.59076781e+00],
       [ 7.23642597e+00,  1.75591618e+01,  1.55032913e+01,
         8.56288537e+00,  6.90618605e+00,  1.11375141e+01,
         5.24708360e+00,  1.63112591e+01, -7.47277438e+00,
        -3.32168101e+00],
       [ 1.57597700e+01,  4.96467101e+00,  8.56288537e+00,
         1.75021831e+01, -1.35832951e+00,  2.84901749e+01,
         3.93863057e+01, -3.72106627e+00, -9.24057937e+00,
         1.10999632e+01],
       [-1.72248924e-01,  3.39840241e+01,  6.90618605e+00,
        -1.35832951e+00,  3.65899321e+01, -1.58075238e+01,
        -3.84321843e+01,  3.29242891e+01, -1.17389924e+01,
        -1.

In [21]:
jnp.mean(jnp.abs(myhess(0.25) - ddmyexp(0.25).reshape((n,n))))

Array(1.64379621e-14, dtype=float64)

In [22]:
# simple finite-difference gradient checker
def fdhess(x, myeps=1e-6):
    cplus = dmyexp( x + myeps )
    cminus = dmyexp( x - myeps )
    return (cplus - cminus)/(2*myeps)

In [23]:
jnp.mean(jnp.abs(myhess(0.25) - fdhess(0.25).reshape((n,n))))

Array(1.88839858e-08, dtype=float64)

In [24]:
jmyhess = jit(myhess)

In [25]:
jddmyexp = jit(ddmyexp)

In [26]:
import time

In [31]:
start = time.time()
myres = jmyhess(-0.123)
end = time.time()
print(end-start)

0.0007960796356201172


In [32]:
start = time.time()
myres = jddmyexp(-0.123)
end = time.time()
print(end-start)

0.0006320476531982422
