In [1]:
# use a random projection to try to reduce parameters
# also try layer-wise training

import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"]="false"

import jax
jax.config.update("jax_enable_x64", True)

In [2]:
import jax.numpy as jnp
from jax import jit, grad, jacobian, lax

import numpy as np
import matplotlib.pyplot as plt

In [3]:
h0 = jnp.array([[1.0,0.0,0.0],[0.0,2.0,0.0],[0.0,0.0,jnp.pi]])
m = jnp.array([[0.5,-2.0,1.0],[-2.0,0.7,0.3],[1.0,0.3,0.4]])
print(jnp.mean(jnp.abs(h0 - h0.T)))
print(jnp.mean(jnp.abs(m - m.T)))

0.0
0.0


In [4]:
dt = 1e-3

In [5]:
numsteps = 100
n = 3

In [6]:
# given initial condition and forcing f, return trajectory a
def propSchro(f, a0):
    a = [a0]
    for k in range(numsteps):
        thisham = h0 + f[k]*m
        evals, evecs = jnp.linalg.eigh(thisham)
        a.append( evecs @ jnp.diag(jnp.exp(-1j*dt*evals)) @ evecs.conj().T @ a[k] )
    
    a = jnp.stack(a, axis=0)
    return a

In [7]:
# given forcing f, IC a0, FC alpha, return cost
rho = 1
def cost(f, a0, alpha):
    a = propSchro(f, a0)
    resid = a[-1] - alpha
    pen = jnp.real(jnp.sum(resid * resid.conj()))
    return 0.5*jnp.sum(f**2) + 0.5*rho*pen

In [8]:
mya0 = jnp.array([1.0,0.0,0.0]) 

In [9]:
a = propSchro( jnp.array(np.random.normal(size=numsteps)), mya0 )

In [10]:
myalpha = jnp.array([0.25,-0.5,0.75])

In [11]:
cost( jnp.array(np.random.normal(size=numsteps)), mya0, myalpha )

Array(42.37704045, dtype=float64)

In [12]:
gradcost = grad(cost, 0)

In [13]:
# adjoint method
rho = 1
def adjgrad(f, a0, alpha):
    a = [a0]
    allevals = []
    allevecs = []
    for k in range(numsteps):
        thisham = h0 + f[k]*m
        evals, evecs = jnp.linalg.eigh(thisham)
        allevals.append( evals )
        allevecs.append( evecs )
        a.append( evecs @ jnp.diag(jnp.exp(-1j*dt*evals)) @ evecs.conj().T @ a[k] )
    
    # forward trajectory
    a = jnp.stack(a, axis=0)
    print(a.shape)
    
    # initialize lambda
    resid = a[-1] - alpha
    # we are storing "lambda conjugate" throughout this calculation
    alllamb = [rho*resid.conj()]
    for k in range(numsteps-1,-1,-1):
        kk = (numsteps-1) - k
        alllamb.append( alllamb[kk] @ allevecs[k] @ jnp.diag(jnp.exp(-1j*dt*allevals[k])) @ allevecs[k].conj().T )
    
    # backward trajectory
    alllamb = jnp.flipud(jnp.stack(alllamb, axis=0))
    
    # set up mask
    mask = jnp.ones((n, n)) - jnp.eye(n)
    
    # first critical calculation
    allexpderivs = []
    for k in range(numsteps):
        amat = allevecs[k].conj().T @ m @ allevecs[k]
        dvec = -1j*dt*allevals[k]
        dvec1, dvec2 = jnp.meshgrid(dvec, dvec)
        mask = jnp.ones((n,n)) - jnp.eye(n)
        numer = jnp.exp(dvec1) - jnp.exp(dvec2)
        denom = (dvec1 - dvec2)*mask + jnp.eye(n)
        derivmat = mask*numer/denom + jnp.diag(jnp.exp(dvec))
        qmat = allevecs[k] @ (amat * derivmat) @ allevecs[k].conj().T
        allexpderivs.append( qmat )
    
    # output gradient we want
    ourgrad = []
    for k in range(numsteps):
        # because we have stored "lambda conjugate", we just need transpose here to get the dagger
        ourgrad.append( alllamb[k+1].T @ allexpderivs[k] @ a[k] )
    
    return f + jnp.real((-1j*dt)*jnp.array(ourgrad))

# simple finite-difference gradient checker
def fdgrad(kk, a0, alpha, myeps=1e-6):
    pertvec = jnp.eye(numsteps)[kk]
    cplus = cost( ftest + myeps*pertvec, a0, alpha )
    cminus = cost( ftest - myeps*pertvec, a0, alpha )
    return (cplus - cminus)/(2*myeps)

In [14]:
# UNIT TESTING for DERIVATIVE OF MATRIX EXPONENTIAL
#
# dvec = jnp.array([2.0, 4.5])
# dvec1, dvec2 = jnp.meshgrid(dvec, dvec)
# mask = jnp.ones((n,n)) - jnp.eye(n)
# numer = jnp.exp(dvec1) - jnp.exp(dvec2)
# denom = (dvec1 - dvec2)*mask + jnp.eye(n)
# test = mask*numer/denom + jnp.diag(jnp.exp(dvec))
# print(test)
# 
# npres = np.zeros((n,n))
# for i in range(n):
#     for l in range(n):
#         if dvec[i]==dvec[l]:
#             npres[i,l] = np.exp(dvec[i])
#         else:
#             npres[i,l] = (np.exp(dvec[i]) - np.exp(dvec[l]))/(dvec[i] - dvec[l])
# 
# print(npres)

In [15]:
ftest = jnp.array(np.random.normal(size=numsteps))
adjres = adjgrad( ftest, mya0, myalpha )
# print(adjres)
jaxres = gradcost( ftest, mya0, myalpha )
# print(jaxres)
print(jnp.mean(jnp.abs(adjres - jaxres)))

(101, 3)


  x_bar = _convert_element_type(x_bar, x.aval.dtype, x.aval.weak_type)


1.7699492045608611e-16


In [16]:
def myexp(x):
    thisham = h0 + x*m
    evals, evecs = jnp.linalg.eigh(thisham)
    out = evecs @ jnp.diag(jnp.exp(evals)) @ evecs.conj().T
    return out.reshape((-1))

In [22]:
myexp(-0.5)

Array([ 4.60151186,  4.58493099, -5.01474672,  4.58493099,  7.66890076,
       -3.622522  , -5.01474672, -3.622522  , 20.71251129], dtype=float64)

In [18]:
dmyexp = jacobian(myexp)

In [23]:
dmyexp(-0.5)

Array([ -8.91549742, -10.86554809,  11.77529408, -10.86554809,
        -6.00133093,  10.57918381,  11.77529408,  10.57918381,
         0.32808872], dtype=float64, weak_type=True)

In [20]:
ddmyexp = jacobian(dmyexp)

In [24]:
ddmyexp(-0.5)

Array([ 22.7560428 ,  13.95536512, -12.14195109,  13.95536512,
        22.64835165, -15.41213593, -12.14195109, -15.41213593,
        19.4284515 ], dtype=float64, weak_type=True)

In [27]:
mask = jnp.ones((n,n)) - jnp.eye(n)

mask3 = np.zeros((3,3,3))
for i in range(3):
    for j in range(3):
        for k in range(3):
            if i!=j and i!=k and j!=k:
                mask3[i,j,k] = 1

mask3 = jnp.array(mask3)

newmask3 = np.zeros((3,3,3))
for i in range(3):
    for j in range(3):
        for k in range(3):
            if i!=j and i==k:
                newmask3[i,j,k] = 1
                
newmask3 = jnp.array(newmask3)

def myhess(x):
    z = h0 + x*m
    evals, evecs = jnp.linalg.eigh(z)
    a = evecs.conj().T @ m @ evecs
    expevals = jnp.exp(evals)
    evals1, evals2 = jnp.meshgrid(evals, evals, indexing='ij')
    mask = jnp.ones((n,n)) - jnp.eye(n)
    expevals1, expevals2 = jnp.meshgrid(expevals, expevals, indexing='ij')
    # first D_{ii}=D_{kk} term
    diagterm1 = expevals1*jnp.diag(jnp.diag(a*a))
    # first D_{ii}!=D_{kk} term
    numer1 = -expevals1 + evals1*expevals1 - evals2*expevals1 + expevals2
    denom1 = (evals1-evals2)**2 * mask + jnp.eye(n)
    frac1 = numer1/denom1 * mask
    term1 = frac1*2*(jnp.diag(a)*a).T
    # second D_{ii}!=D_{kk} term
    numer2 = -expevals1 + evals1*expevals2 - evals2*expevals2 + expevals2
    denom2 = (evals1-evals2)**2 * mask + jnp.eye(n)
    frac2 = numer2/denom2 * mask
    term2 = frac2*2*(a*jnp.diag(a))
    # third D_{ii}!=D_{kk} term
    paren = 2*jnp.einsum('ij,jk->ijk',a,a)
    evals1, evals2, evals3 = jnp.meshgrid(evals, evals, evals, indexing='ij')
    expevals1, expevals2, expevals3 = jnp.meshgrid(expevals, expevals, expevals, indexing='ij')
    numer3 = evals1*(expevals2-expevals3)-evals2*(expevals1-expevals3)+evals3*(expevals1-expevals2)
    denom3 = (evals1-evals2)*(evals1-evals3)*(evals2-evals3) + jnp.ones((3,3,3)) - mask3
    term3 = jnp.sum( (numer3/denom3 * mask3) * paren, axis=1 )
    # second D_{ii}=D_{kk} term
    diagnumer2 = -expevals1 + evals1*expevals1 - evals2*expevals1 + expevals2
    diagdenom2 = (evals1-evals2)**2 + jnp.ones((3,3,3)) - newmask3
    diagfrac2 = (diagnumer2/diagdenom2)*newmask3
    diagterm2 = jnp.sum( diagfrac2*paren, axis=1 )
    # put it all together
    udagru = term1 - term2 - term3 + diagterm1 + diagterm2
    return evecs @ udagru @ evecs.conj().T

In [44]:
jnp.linalg.norm(myhess(0.5) - ddmyexp(0.5).reshape((n,n)))

Array(1.56480806e-13, dtype=float64)

In [31]:
jmyhess = jit(myhess)

In [32]:
jddmyexp = jit(ddmyexp)

In [33]:
import time

In [47]:
start = time.time()
myres = jmyhess(-1.0)
end = time.time()
print(end-start)

0.0005421638488769531


In [48]:
start = time.time()
myres = jddmyexp(-1.0)
end = time.time()
print(end-start)

0.00052642822265625
