In [1]:
from jax.config import config
config.update("jax_enable_x64", True)

import jax
import jax.numpy as jnp
from jax import jit, grad, jacobian, lax, vmap, pmap

import numpy as np
import matplotlib
import matplotlib.pyplot as plt

import scipy.integrate as si
import scipy.optimize as so

import time

In [7]:
alltraj = np.load('zdot_tdcasscf_dt0.008268.npz')

In [8]:
# now pretend we don't know the vector field but we want to learn it
# set up a neural network model for f
nlayers = 4

# units per hidden layer
uphl = 8

# dimension of vector field
vfd = 4

# set up neural network parameters
layerwidths = [vfd,uphl,uphl,uphl,vfd]
numparams = 0
numweights = 0
numbiases = 0
for j in range(nlayers):
    numparams += layerwidths[j]*layerwidths[j+1] + layerwidths[j+1]
    numweights += layerwidths[j]*layerwidths[j+1]
    numbiases += layerwidths[j+1]

# print out total number of parameters
print(numparams)

# definition of actual neural network function
def neuralf(x, theta):
    filt = []
    si = 0
    ei = layerwidths[0]*layerwidths[1]
    filt.append( theta[si:ei].reshape((layerwidths[0],layerwidths[1])) )
    si += layerwidths[0]*layerwidths[1]
    ei += layerwidths[1]*layerwidths[2]
    filt.append( theta[si:ei].reshape((layerwidths[1],layerwidths[2])) )
    si += layerwidths[1]*layerwidths[2]
    ei += layerwidths[2]*layerwidths[3]
    filt.append( theta[si:ei].reshape((layerwidths[2],layerwidths[3])) )
    si += layerwidths[2]*layerwidths[3]
    ei += layerwidths[3]*layerwidths[4]
    filt.append( theta[si:ei].reshape((layerwidths[3],layerwidths[4])) )
    bias = []
    si += layerwidths[3]*layerwidths[4]
    ei += layerwidths[1]
    bias.append( theta[si:ei] )
    si += layerwidths[1]
    ei += layerwidths[2]
    bias.append( theta[si:ei] )
    si += layerwidths[2]
    ei += layerwidths[3]
    bias.append( theta[si:ei] )
    si += layerwidths[3]
    ei += layerwidths[4]
    bias.append( theta[si:ei] )
    inplyr = x
    h1 = jax.nn.selu( inplyr @ filt[0] + bias[0] )
    h2 = jax.nn.selu( h1 @ filt[1] + bias[1] )
    h3 = jax.nn.selu( h2 @ filt[2] + bias[2] )
    h4 = h3 @ filt[3] + bias[3]
    return h4

220


In [9]:
from jax.experimental import ode

In [10]:
# move training data from NumPy to JAX
jalltraj = jnp.array(alltraj)
print(jalltraj.shape)

(301, 400, 4)


In [11]:
dt = 0.008268
dilfac = 1
numsteps = 400

In [12]:
# define JAX function to compute one predicted trajectory
def predtraj(y0, theta):
    def rhsfunc(y, t):
        return neuralf(y, theta)
    
    intdt = (dt)/dilfac
    intnpts = (numsteps-1)*dilfac + 1
    inttint = np.arange(intnpts)*intdt
    return ode.odeint(rhsfunc, y0, inttint)


In [13]:
# now parallelize the above function to handle many initial conditions but one theta
vpredtraj = vmap(predtraj, in_axes=(0, None))

In [14]:
# now define loss (or objective function)
def obj(y, theta):
    # compute all predicted trajectories for fixed theta
    predfine = predtraj(y[0,:], theta)
    pred = predfine[::dilfac, :]
    # now compute mean squared errors between predictions and training data
    return jnp.mean(jnp.square(pred - y))


In [15]:
# use JAX to compile objective function
jobj = jit(obj)

# use JAX to compute gradient and compile it
jgradobj = jit(grad(obj,1))

In [16]:
dfdx = jacobian(neuralf, 0)
dfdtheta = jacobian(neuralf, 1)

In [17]:
vdfdx = vmap(dfdx, in_axes=(0,None))
vdfdtheta = vmap(dfdtheta, in_axes=(0,None))

In [18]:
# adjoint method to compute gradient of objective function
# here y stands for one trajectory of observations, of shape (numsteps, 2)
def newlagwithgrad(y, theta):
    intdt = (dt)/dilfac
    intnpts = (numsteps-1)*dilfac + 1
    inttint = np.arange(intnpts)*intdt
    
    # solve forward problem and compute residual
    def rhsfunc(y, t):
        return neuralf(y, theta)
    
    xfine = ode.odeint(rhsfunc, y[0,:], inttint)
    x = xfine[::dilfac,:]
    resid = 2*(x - y)
    
    # compute and save loss for later
    obj = jnp.mean(jnp.square(resid))
    
    # trapezoid rule quadrature weights
    w = jnp.concatenate([jnp.array([0.5]),jnp.ones(intnpts-2),jnp.array([0.5])])
    
    # backward-in-time loop body
    # Heun's method backward in time
    inth = -intdt
    def bodylamb(j, lambmfine):
        feval1 = -lambmfine[intnpts-j, :] @ dfdx(xfine[intnpts-j, :], theta)
        lambtilde = lambmfine[intnpts-j, :] + inth*feval1
        feval2 = -lambtilde @ dfdx(xfine[intnpts-j-1, :], theta)
        prevlamb = lambmfine[intnpts-j, :] + (inth/2.0)*(feval1 + feval2)
        prevlamb += ((intnpts-j-1) % dilfac == 0) * resid[(intnpts-j-1)//dilfac, :]
        return lambmfine.at[intnpts-j-1].set( prevlamb )
    
    # loop initialization
    lambmat = jnp.concatenate([ jnp.zeros((intnpts-1, vfd)), resid[[numsteps-1],:] ])
    
    # actually loop
    lambout = lax.fori_loop(1, intnpts, bodylamb, lambmat)
    
    # straight out of the notes
    allg = vdfdtheta(xfine, theta)
    gradtheta = intdt*jnp.einsum('ij,ijk,i->k',lambout,allg,w)
    gradtheta *= (1.0/(2*numsteps))
    return obj, gradtheta

In [19]:
jnewlagwithgrad = jit(newlagwithgrad)

In [20]:
# random initializer
def glorotinit():
    theta = []
    sd = np.sqrt(2.0 / (layerwidths[0]+layerwidths[1]))
    theta.append( np.random.normal(size=layerwidths[0]*layerwidths[1])*sd )
    sd = np.sqrt(2.0 / (layerwidths[1]+layerwidths[2]))
    theta.append( np.random.normal(size=layerwidths[1]*layerwidths[2])*sd )
    sd = np.sqrt(2.0 / (layerwidths[2]+layerwidths[3]))
    theta.append( np.random.normal(size=layerwidths[2]*layerwidths[3])*sd )
    sd = np.sqrt(2.0 / (layerwidths[3]+layerwidths[4]))
    theta.append( np.random.normal(size=layerwidths[3]*layerwidths[4])*sd )
    theta.append( np.zeros(numbiases) )
    theta = np.concatenate(theta)
    return theta

theta0 = glorotinit()

In [21]:
jalltraj[0].shape

(400, 4)

In [22]:
# #uncomment if you want a finite difference differentiation of the loss wrt theta
perturbedtheta = np.copy(theta0)
fdtot = []
for ind in range(perturbedtheta.shape[0]):
    perturbedtheta[ind] += 1e-6
    losstheta = jobj(jalltraj[0],theta0)
    lossperturbedtheta = jobj(jalltraj[0],perturbedtheta)
    fd = (lossperturbedtheta - losstheta)/1e-6
    fdtot.append(fd)
    perturbedtheta[ind] -= 1e-6
fdtot = np.array(fdtot)

In [23]:
start = time.time()
myobj, mygrad = jnewlagwithgrad(jalltraj[0], theta0)
end = time.time()
print(end-start)

start = time.time()
jaxobj = jobj(jalltraj[0], theta0)
jaxgrad = jgradobj(jalltraj[0], theta0)
end = time.time()
print(end-start)

print(jnp.mean(jnp.abs(jaxgrad-mygrad)))

1.699483871459961
2.1925599575042725
0.44577407120510754


In [34]:
print(jnp.mean(jnp.abs(jaxgrad-mygrad)))
print(jnp.mean(jnp.abs(jaxgrad-fdtot)))
print(jnp.mean(jnp.abs(fdtot-mygrad)))

0.44577407120510754
2.2103532370146133
2.175547445552911


In [25]:
jvnewlagwithgrad = jit(vmap(newlagwithgrad, in_axes=(0,None)))

In [26]:
# for debugging purposes
# tobj, tgrad = jvnewlagwithgrad(jalltraj, theta0)

In [27]:
# wrap functions so that we can use them in standard NumPy/SciPy
def objgradSP(theta):
    jaxobj, jaxgrads = jvnewlagwithgrad(jalltraj, theta)
    return jnp.mean(jaxobj).item(), np.array(jnp.mean(jaxgrads,0))

In [28]:
# trust region optimizer with SR1 update
res = so.minimize(fun=objgradSP, jac=True, x0=theta0, method='trust-constr', 
                  hess=so.SR1(), options={'gtol': 1e-20, 'xtol': 1e-20, 'verbose': 2})

# L-BFGS-B optimizer
# res = so.minimize(fun=siobj, jac=sigrad, x0=theta0, method='L-BFGS-B',
#                   options={'gtol': 1e-8, 'ftol':1e-8, 'iprint':1})

| niter |f evals|CG iter|  obj func   |tr radius |   opt    |  c viol  |
|-------|-------|-------|-------------|----------|----------|----------|
|   1   |   1   |   0   | +1.3883e+00 | 1.00e+00 | 7.09e+00 | 0.00e+00 |
|   2   |   2   |   1   | +1.3883e+00 | 2.80e-01 | 7.09e+00 | 0.00e+00 |
|   3   |   3   |   2   | +1.3883e+00 | 7.58e-02 | 7.09e+00 | 0.00e+00 |
|   4   |   4   |   3   | +6.3535e-01 | 1.52e-01 | 4.80e+00 | 0.00e+00 |
|   5   |   5   |   5   | +6.3535e-01 | 3.54e-02 | 4.80e+00 | 0.00e+00 |
|   6   |   6   |   7   | +1.6909e-01 | 2.48e-01 | 6.56e-01 | 0.00e+00 |
|   7   |   7   |  10   | +1.6909e-01 | 2.48e-02 | 6.56e-01 | 0.00e+00 |
|   8   |   8   |  13   | +1.5873e-01 | 3.99e-02 | 4.27e-01 | 0.00e+00 |
|   9   |   9   |  16   | +1.5746e-01 | 3.99e-02 | 2.83e-01 | 0.00e+00 |
|  10   |  10   |  19   | +1.5675e-01 | 3.99e-02 | 2.13e-01 | 0.00e+00 |
|  11   |  11   |  21   | +1.4466e-01 | 7.98e-02 | 3.65e-01 | 0.00e+00 |
|  12   |  12   |  26   | +1.3824e-01 | 8.22e-02 | 

In [29]:
allpreds = vpredtraj(jalltraj[:,0,:], jnp.array(res.x))

In [30]:
allpreds.shape

(301, 400, 4)