In [1]:
from jax.config import config
config.update("jax_enable_x64", True)

import jax
import jax.numpy as jnp
from jax import jit, grad, jacobian, lax, vmap, pmap

import numpy as np
import matplotlib
import matplotlib.pyplot as plt

import scipy.integrate as si
import scipy.optimize as so

import time

In [2]:
# given time series observations, our goal here is to find f(x)
# such that the observations can be predicted using the ODE
# dx/dt = \dot{x} = f(x)

In [3]:
# typical supervised learning task
# you have a lot of (x, y) pairs
# you want a function y = f(x)
# our problem is different, we never directly observe the thing we're trying to learn (which is f(x))1q
dilfac = 1

In [4]:
alltraj = np.load('heh+_training_data.npz')
alltraj.shape

(90, 20000, 2, 2)

In [5]:
#flatten density matrices into a vector and reshape to be in form 
trajs = []
for i in range(alltraj.shape[0]):
    traj = alltraj[i]
    steps  = []
    for j in range(400):
        vec = np.zeros((2,),dtype=np.complex128)
        denmat = traj[j]
        v1 = denmat[0][0]
        v2 = denmat[1][1]
        vec[0] = v1
        vec[1] = v2
        steps.append(vec)
    trajs.append(steps)
trajs = np.array(trajs)
trajs.shape

(90, 400, 2)

In [6]:
# now pretend we don't know the vector field but we want to learn it
# set up a neural network model for f
nlayers = 4

# units per hidden layer
uphl = 64

# dimension of vector field
vfd = 2

# set up neural network parameters
layerwidths = [2*vfd,uphl,uphl,uphl,vfd]
numparams = 0
numweights = 0
numbiases = 0
for j in range(nlayers):
    numparams += layerwidths[j]*layerwidths[j+1] + layerwidths[j+1]
    numweights += layerwidths[j]*layerwidths[j+1]
    numbiases += layerwidths[j+1]

# print out total number of parameters
print(numparams)

# definition of actual neural network function
def neuralf(x, theta):
    filt = []
    si = 0
    ei = layerwidths[0]*layerwidths[1]
    filt.append( theta[si:ei].reshape((layerwidths[0],layerwidths[1])) )
    si += layerwidths[0]*layerwidths[1]
    ei += layerwidths[1]*layerwidths[2]
    filt.append( theta[si:ei].reshape((layerwidths[1],layerwidths[2])) )
    si += layerwidths[1]*layerwidths[2]
    ei += layerwidths[2]*layerwidths[3]
    filt.append( theta[si:ei].reshape((layerwidths[2],layerwidths[3])) )
    si += layerwidths[2]*layerwidths[3]
    ei += layerwidths[3]*layerwidths[4]
    filt.append( theta[si:ei].reshape((layerwidths[3],layerwidths[4])) )
    bias = []
    si += layerwidths[3]*layerwidths[4]
    ei += layerwidths[1]
    bias.append( theta[si:ei] )
    si += layerwidths[1]
    ei += layerwidths[2]
    bias.append( theta[si:ei] )
    si += layerwidths[2]
    ei += layerwidths[3]
    bias.append( theta[si:ei] )
    si += layerwidths[3]
    ei += layerwidths[4]
    bias.append( theta[si:ei] )
    inplyr = jnp.concatenate([x.real,x.imag])
    h1 = jax.nn.selu( inplyr @ filt[0] + bias[0] )
    h2 = jax.nn.selu( h1 @ filt[1] + bias[1] )
    h3 = jax.nn.selu( h2 @ filt[2] + bias[2] )
    h4 = h3 @ filt[3] + bias[3]
    return h4

8770


In [7]:
from jax.experimental import ode

In [8]:
# move training data from NumPy to JAX
jalltraj = jnp.array(trajs)
print(jalltraj.shape)

(90, 400, 2)


In [9]:
numsteps = 400

In [10]:
# define JAX function to compute one predicted trajectory
def predtraj(y0, theta):
    def rhsfunc(y, t):
        return neuralf(y, theta)
    
    intdt = 0.08268/dilfac
    intnpts = (numsteps-1)*dilfac + 1
    inttint = np.arange(intnpts)*intdt
    return ode.odeint(rhsfunc, y0, inttint)


In [11]:
# now parallelize the above function to handle many initial conditions but one theta
vpredtraj = vmap(predtraj, in_axes=(0, None))

In [12]:
# now define loss (or objective function)
def obj(y, theta):
    # compute all predicted trajectories for fixed theta
    predfine = predtraj(y[0,:], theta)
    pred = predfine[::dilfac, :]
    # now compute mean squared errors between predictions and training data
    return jnp.mean(jnp.square(pred - y))


In [35]:
# use JAX to compile objective function
jobj = jit(obj)

# use JAX to compute gradient and compile it
jgradobj = jit(grad(obj,1,holomorphic=True))

In [36]:
dfdx = jacobian(neuralf, 0, holomorphic=True)
dfdtheta = jacobian(neuralf, 1, holomorphic=True)

In [37]:
vdfdtheta = vmap(dfdtheta, in_axes=(0,None))

In [82]:
# adjoint method to compute gradient of objective function
# here y stands for one trajectory of observations, of shape (numsteps, 2)
def newlagwithgrad(y, theta):
    intdt = 0.08268/dilfac
    intnpts = (numsteps-1)*dilfac + 1
    inttint = np.arange(intnpts)*intdt
    
    # solve forward problem and compute residual
    def rhsfunc(y, t):
        return neuralf(y, theta)
    
    xfine = ode.odeint(rhsfunc, y[0,:], inttint)
    
    x = xfine[::dilfac,:]
    resid = 2*(x - y)
    
    # compute and save loss for later
    obj = jnp.mean(jnp.square(resid))
    
    # trapezoid rule quadrature weights
    w = jnp.concatenate([jnp.array([0.5]),jnp.ones(intnpts-2),jnp.array([0.5])])
    
    # backward-in-time loop body
    # Heun's method backward in time
    inth = -intdt
    
    def bodylamb(j, lambmfine):
       
        #import pdb; pdb.set_trace()
        feval1 = -lambmfine[intnpts-j, :] @ dfdx(xfine[intnpts-j, :], theta.astype(np.complex128)) #@ dfdy(xfine[intnpts-j, :].real,xfine[intnpts-j, :].imag, theta)
        lambtilde = lambmfine[intnpts-j, :] + inth*feval1
        feval2 = -lambtilde @ dfdx(xfine[intnpts-j-1, :], theta.astype(np.complex128)) #@ dfdy(xfine[intnpts-j-1,:].real ,xfine[intnpts-j-1, :].imag , theta)
        prevlamb = lambmfine[intnpts-j, :] + (inth/2.0)*(feval1 + feval2)
        prevlamb += ((intnpts-j-1) % dilfac == 0) * resid[(intnpts-j-1)//dilfac, :]
        return lambmfine.at[intnpts-j-1].set( prevlamb )
    
    # loop initialization
    lambmat = jnp.concatenate([ jnp.zeros((intnpts-1, vfd)), resid[[numsteps-1],:] ])
    
    # actually loop
    lambout = lax.fori_loop(1, intnpts, bodylamb, lambmat)
    
    # straight out of the notes
    allg = vdfdtheta(xfine, theta.astype(np.complex128))
    gradtheta = intdt*jnp.einsum('ij,ijk,i->k',lambout,allg,w)
    gradtheta *= (1.0/(2*numsteps))
    return obj, gradtheta

jnlwg = jit(newlagwithgrad)

In [83]:
jnewlagwithgrad = jit(newlagwithgrad)

In [101]:
# random initializer
def glorotinit():
    theta = []
    sd = np.sqrt(2.0 / (layerwidths[0]+layerwidths[1]))
    theta.append( np.random.normal(size=layerwidths[0]*layerwidths[1])*sd )
    sd = np.sqrt(2.0 / (layerwidths[1]+layerwidths[2]))
    theta.append( np.random.normal(size=layerwidths[1]*layerwidths[2])*sd )
    sd = np.sqrt(2.0 / (layerwidths[2]+layerwidths[3]))
    theta.append( np.random.normal(size=layerwidths[2]*layerwidths[3])*sd )
    sd = np.sqrt(2.0 / (layerwidths[3]+layerwidths[4]))
    theta.append( np.random.normal(size=layerwidths[3]*layerwidths[4])*sd )
    theta.append( np.zeros(numbiases) )
    theta = np.concatenate(theta)
    return theta

theta0 = glorotinit()

In [102]:
start = time.time()
myobj, mygrad = jnewlagwithgrad(jalltraj[0], theta0)
end = time.time()
print(end-start)

start = time.time()
jaxobj = jobj(jalltraj[0], theta0)
jaxgrad = jgradobj(jalltraj[0], theta0.astype(np.complex128))
end = time.time()
print(end-start)

print(jnp.mean(jnp.abs(jaxgrad-mygrad)))

0.0673370361328125
3.8646726608276367
19.78648017321262


In [103]:
jvnewlagwithgrad = jit(vmap(newlagwithgrad, in_axes=(0,None)))

In [104]:
# for debugging purposes
#tobj, tgrad = jvnewlagwithgrad(jalltraj, theta0)

In [117]:
# wrap functions so that we can use them in standard NumPy/SciPy
def objgradSP(theta):
    jaxobj, jaxgrads = jvnewlagwithgrad(jalltraj, theta)
    return jnp.mean(jaxobj).item(), np.array(jnp.mean(jaxgrads,0))

In [118]:
# trust region optimizer with SR1 update
res = so.minimize(fun=objgradSP, jac=True, x0=theta0, method='trust-constr', 
                  hess=so.SR1(), options={'gtol': 1e-3, 'xtol': 1e-3, 'verbose': 2})

# L-BFGS-B optimizer
# res = so.minimize(fun=objSP, jac=gradSP, x0=theta0, method='L-BFGS-B',
#                   options={'gtol': 1e-8, 'ftol':1e-8, 'iprint':1})

TypeError: '<' not supported between instances of 'complex' and 'float'

In [None]:
allpreds = vpredtraj(jalltraj[:,0,:], jnp.array(res.x))

In [None]:
allpreds.shape

In [None]:
# for a grid of initial conditions, generate trajectories
# this is a phase portrait of the LEARNED DYNAMICAL SYSTEM
plt.figure(figsize=(10,6))
for i in range(allpreds.shape[0]):
    plt.plot(allpreds[i,:,0],allpreds[i,:,1])

plt.xlim([-2*xL,2*xL])
plt.ylim([-1.5*xL,1.5*xL])
plt.show()


In [None]:
# this is a phase portrait of the GROUND TRUTH DYNAMICAL SYSTEM
plt.figure(figsize=(10,6))
for i in range(allpreds.shape[0]):
    plt.plot(alltraj[i,0,:],alltraj[i,1,:])

plt.xlim([-2*xL,2*xL])
plt.ylim([-1.5*xL,1.5*xL])
plt.show()
