In [1]:
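# Plan for this notebook:
# - code up the SIRD model (an SIR model extended with a deaths compartment)
# - fit its parameters using the adjoint method
# - see how well the fit does on simulated data
# - see how well the fit does on real data
# - the noise parameter might need some tuning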

In [2]:
import numpy as onp

In [3]:
# Load the comma-separated trajectory data from soldat.csv and report its
# shape as (rows, columns).
soldat = onp.loadtxt('soldat.csv',delimiter=',')
print(soldat.shape)

(1001, 4)


In [4]:
from jax.config import config
config.update("jax_enable_x64", True)

import jax.numpy as np
import jax.nn
from jax import grad, jit, jacobian, random, vmap, lax
from jax.ops import index, index_update
from jax.experimental import ode

from sklearn import linear_model

%matplotlib inline
import matplotlib.pyplot as plt

In [5]:
# Divisor of the infection term in sird (beta*S*I/nn);
# presumably the total population size -- TODO confirm.
nn = 1000
# Number of time samples on the grid (the loaded data has 1001 rows).
npts = 1001
# Alias of npts, used only to build the time grid below.
n = npts
# Number of model parameters in theta.
lentheta = 3
# Dimension of the state vector (S, I, R, D).
d = 4
# Spacing of the uniform time grid.
dt = 0.1
# Time grid 0, dt, 2*dt, ..., (n-1)*dt.
tint = np.arange(n)*dt



In [6]:
# theta holds the logarithms of the rates: exp(theta) = (beta, gamma, mu)
def sird(x, t, theta):
    """Right-hand side of the SIRD ordinary differential equation.

    Parameters
    ----------
    x : array of length 4
        Current state, ordered as (S, I, R, D). Only S and I enter the rates.
    t : scalar
        Time. Not used by the model; present so the function has the
        (state, time, ...) signature used by the ODE solver.
    theta : array of length 3
        Log-rates; they are exponentiated here, so the effective rates
        (beta, gamma, mu) are always positive.

    Returns
    -------
    array of length 4
        Time derivatives [dS/dt, dI/dt, dR/dt, dD/dt].
    """
    rates = np.exp(theta)
    susceptible, infected = x[0], x[1]

    # Flow S -> I, normalised by the module-level constant nn.
    new_infections = rates[0]*susceptible*infected/nn
    # Flows I -> R and I -> D.
    recoveries = rates[1]*infected
    deaths = rates[2]*infected

    return np.array([
        -new_infections,
        new_infections - recoveries - deaths,
        recoveries,
        deaths,
    ])

In [7]:
# Smoke test: evaluate the right-hand side once at an arbitrary state,
# time and log-parameter vector.
sird(np.array([3.,4.,5.,6.]),0.5,np.array([1.0,0.5,0.1]))

DeviceArray([ -0.03261938, -10.98294937,   6.59488508,   4.42068367], dtype=float64)

In [8]:
# JIT-compiled right-hand side of the SIRD system.
fsird = jit(sird)

# Jacobian of the right-hand side with respect to the state x (argument 0),
# in plain and JIT-compiled form.
mygradsird = jacobian(sird, argnums=0)
fmygradsird = jit(mygradsird)

# Jacobian of the right-hand side with respect to the log-parameters theta
# (argument 2), in plain and JIT-compiled form.
mygradsirdtheta = jacobian(sird, argnums=2)
fmygradsirdtheta = jit(mygradsirdtheta)


In [9]:
def newlagwithgrad(xinit, curtheta):
    """Evaluate the Lagrangian of the SIRD fit and its adjoint-based gradients.

    Besides its arguments, this function reads the module-level observation
    array `y` (compared against the forward solution, so it must have the
    same (d, npts) layout) and the module-level grid constants `tint`, `dt`,
    `npts`, `d` and `lentheta`. `y` must therefore be defined before the
    first call.

    Parameters
    ----------
    xinit : array of length d
        Initial state handed to the forward ODE solve.
    curtheta : array of length lentheta
        Current log-parameters, passed straight to `fsird` and its Jacobians.

    Returns
    -------
    lag : scalar
        Sum over grid points 0..npts-2 of lambda_i . (xdot_i - f(x_i)) * dt,
        plus 0.5 * sum((x - y)**2).
    gradx0 : array of length d
        `-lambminus[npts-2]`.
    gradtheta : array of length lentheta
        Minus the sum over grid points 0..npts-2 of
        lambda_i . (df/dtheta)(x_i) * dt.
    x : array of shape (d, npts)
        Forward solution of the ODE on `tint` (one column per time sample).
    """
    # solves the forward ODE using our current estimates of xinit and curtheta
    fsirdI = lambda state, t: fsird(state, t, curtheta)
    x = lax.stop_gradient(ode.odeint(fsirdI, t=tint, rtol=1e-9, atol=1e-9, y0=xinit)).T

    # set up adjoint ODE
    # NOTE(review): the state Jacobian is evaluated at the adjoint variable
    # itself, not at the forward solution x(t), and the resulting one-step
    # propagator over [0, dt] is reused for every time step below.
    # Confirm that this approximation is intended.
    fadj = lambda lam, t: -np.matmul(lam, fmygradsird(lam, t, curtheta))
    icmat = np.eye(d)
    adjtint = np.array([0, dt])

    # function that solves the adjoint ODE once for one initial condition
    @jit
    def solonce(y0):
        adjsol = lax.stop_gradient(ode.odeint(fadj, t=adjtint, rtol=1e-9, atol=1e-9, y0=y0))
        return adjsol[1,:]

    # this is to solve the adjoint ODE for all initial conditions in the icmat **at once**
    propagator = vmap(solonce, in_axes=(0))(icmat) # + (1e-6)*np.eye(d)
    # stepping the multiplier backwards in time uses the inverse propagator
    backprop = lax.stop_gradient(np.linalg.inv(propagator))

    # residual between the observations and the forward solution, shape (d, npts)
    yminusx = lax.stop_gradient(y - x)

    # Backward recursion for the multipliers. Rows are stored in reverse time
    # order: row k belongs to time index npts-1-k. Each step propagates the
    # previous row and adds the residual at the newly reached time index.
    @jit
    def growlamb(i, lamb):
        lambplus = np.matmul(lamb[i,:], backprop)
        outlamb = index_update(lamb, index[i+1, :], lambplus + yminusx[:,(npts-2-i)])
        return outlamb

    # row 0 is the residual at the final time index; the rest is filled by the loop
    initlamb = np.vstack([np.expand_dims(yminusx[:,(npts-1)],0), np.zeros((npts-1, d))])
    lambminus = lax.fori_loop(0, npts-1, growlamb, initlamb)

    # compute current value of lagrangian
    # Finite-difference time derivative of x: forward difference at the first
    # sample, central differences in the interior, and a backward difference
    # at the last sample (x[npts-1] - x[npts-2]). The loops below only read
    # columns 0..npts-2, so the last column does not affect the outputs.
    allxdot = np.hstack([(x[:,[1]]-x[:,[0]]), (x[:,2:] - x[:,:-2])/2, (x[:,[npts-1]]-x[:,[npts-2]])])/dt

    # accumulates lambda_i . (xdot_i - f(x_i)) * dt; lambminus[npts-1-i] is the
    # multiplier row for time index i
    @jit
    def goodfun(i, lag):
        f = fsird(x[:, i], tint[i], curtheta)
        lag1 = lag + np.dot(lambminus[npts-1-i], allxdot[:,i]-f)*dt
        return lag1

    lag = lax.fori_loop(0, npts-1, goodfun, 0.0)
    # data-misfit term
    lag += np.sum(np.square(x - y))/2.0

    # compute gradients using lamb (solution of adjoint ODE)
    # gradient of L with respect to parameters theta
    initgradtheta = np.zeros(lentheta)

    @jit
    def gt1i(i, gt):
        g = fmygradsirdtheta(x[:, i], tint[i], curtheta).reshape((d, lentheta))  # nabla_theta f
        gradtheta = gt - np.matmul(lambminus[npts-1-i],g)*dt
        return gradtheta

    gradtheta = lax.fori_loop(0, npts-1, gt1i, initgradtheta)
    # NOTE(review): with the reverse-time storage above, row npts-2 is the
    # multiplier for time index 1 and row npts-1 the one for time index 0;
    # confirm that row npts-2 is the intended value for the x0 gradient.
    gradx0 = -lambminus[npts-2]

    return lag, gradx0, gradtheta, x

In [10]:
# JIT-compiled objective/gradient evaluation. newlagwithgrad reads the
# module-level array `y`, so `y` must be defined before lagwithgrad is
# called for the first time.
lagwithgrad = jit(newlagwithgrad)

In [19]:
# adjoint solver with GD (gradient descent)
# Observations with one column per time sample; newlagwithgrad reads this
# module-level name.
y = soldat.T

# take as initial guess x = y
# Random initial guess for the log-parameters: every entry is <= 0, so every
# initial rate exp(theta0) is <= 1.
# NOTE(review): the NumPy RNG is not seeded here, so theta0 and the trace
# printed below change from run to run.
theta0 = -4*onp.abs(onp.random.normal(size=lentheta))
print(theta0)
# Initial state is the first observed sample.
# NOTE(review): y[:,0] is a slice of soldat.T -- presumably a view, so an
# in-place update of x0 would also modify the data; confirm before
# re-enabling the commented-out x0 update at the end of the loop.
x0 = y[:,0] # + 1e-8*onp.random.rand(d)

# Number of gradient-descent iterations and the fixed step size.
maxiters = 30000
step = 1e-9

# Working copies, so the initial guesses above are left untouched.
x = x0.copy()
theta = theta0.copy()

# Evaluate once before the loop to report the initial gradient norm.
lag, gradx0, gradtheta, xest = lagwithgrad(x, theta)
print(onp.linalg.norm(gradtheta))

# mys = 1e-2
for i in range(maxiters):
    lag, gradx0, gradtheta, xest = lagwithgrad(x, theta)
    # Every 1000 iterations, log the Lagrangian and the current rates exp(theta).
    if i % 1000 == 0:
        print(lag, onp.exp(theta))

    # Plain gradient-descent step on the log-parameters only.
    theta -= step*gradtheta
    # The initial state x is held fixed; its update is disabled:
    # x0 -= step*gradx0



[-2.91982837 -0.02253266 -1.10660533]
431715.2375958681
636098698.9880548 [0.05394294 0.9777193  0.33067961]
635979044.044479 [0.06433861 1.24946965 0.25869111]
635902108.1676689 [0.07783262 1.50739896 0.21437367]
635841572.6798483 [0.09608699 1.74800841 0.18482067]
635780492.4281824 [0.12245382 1.97289995 0.16371181]
635700557.1266948 [0.16449799 2.18489957 0.14778625]
635560459.4906447 [0.24372799 2.38741639 0.13520529]
635156635.1909875 [0.45941204 2.58642636 0.12474232]
130005930.29482687 [2.34798569 0.92865073 0.23872606]
2088.965546889067 [0.40107127 0.03499123 0.0050688 ]
0.00011085715008056455 [0.40000004 0.035      0.005     ]
-1.245929431122475e-07 [0.4   0.035 0.005]
-1.2832043571753997e-07 [0.4   0.035 0.005]
-1.2832043571753997e-07 [0.4   0.035 0.005]
-1.2832043571753997e-07 [0.4   0.035 0.005]
-1.2832043571753997e-07 [0.4   0.035 0.005]
-1.2832043571753997e-07 [0.4   0.035 0.005]
-1.2832043571753997e-07 [0.4   0.035 0.005]
-1.2832043571753997e-07 [0.4   0.035 0.005]
-1.28