In [89]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [90]:
# Load Dependencies
import matplotlib.pyplot as plt

import torch

import pyro
import pyro.distributions as dist

import pyciemss.utils as utils

import warnings
warnings.filterwarnings('ignore')

In [91]:
from pyciemss.ODE.frontend import compile_pp
from pyciemss.ODE.models import SVIIvR

In [92]:
# Setup Parameters

observed_tspan = utils.get_tspan(1, 7, 7)
new_tspan = utils.get_tspan(7, 89, 83)
full_tspan = utils.get_tspan(1, 89, 89)

num_samples = 500

# Total population, N.
N = 1.0
# Initial number of infected and recovered individuals, I0 and R0.
V0, I0, Iv0, R0 = 0., 81.0/100000, 0., 0.
# Everyone else, S0, is susceptible to infection initially.
S0 = N - I0 - Iv0 - V0 - R0

# 18 - 24 year olds
I_obs = torch.tensor([81.47, 84.3, 86.44, 89.66, 93.32, 94.1, 96.31])/100000

initial_state = tuple(torch.as_tensor(s) for s in  (S0, V0, I0, Iv0, R0))
final_observed_state = tuple(torch.as_tensor(s) for s in  (S0, V0, I_obs[-1], Iv0, R0))

In [93]:
model = SVIIvR(N)

In [94]:
solution, observations = model(initial_state, full_tspan)

In [95]:
solution

(tensor([0.9992, 0.9974, 0.9956, 0.9938, 0.9919, 0.9901, 0.9882, 0.9863, 0.9843,
         0.9824, 0.9804, 0.9783, 0.9762, 0.9741, 0.9719, 0.9696, 0.9673, 0.9649,
         0.9624, 0.9599, 0.9572, 0.9544, 0.9516, 0.9485, 0.9454, 0.9420, 0.9385,
         0.9349, 0.9310, 0.9268, 0.9225, 0.9179, 0.9130, 0.9077, 0.9022, 0.8963,
         0.8900, 0.8833, 0.8761, 0.8685, 0.8604, 0.8518, 0.8427, 0.8330, 0.8227,
         0.8117, 0.8002, 0.7880, 0.7752, 0.7618, 0.7478, 0.7331, 0.7178, 0.7020,
         0.6856, 0.6688, 0.6515, 0.6339, 0.6160, 0.5978, 0.5795, 0.5611, 0.5427,
         0.5244, 0.5062, 0.4882, 0.4706, 0.4533, 0.4364, 0.4200, 0.4041, 0.3887,
         0.3739, 0.3597, 0.3461, 0.3331, 0.3207, 0.3089, 0.2978, 0.2871, 0.2771,
         0.2676, 0.2586, 0.2502, 0.2422, 0.2347, 0.2276, 0.2209, 0.2147]),
 tensor([0.0000, 0.0016, 0.0033, 0.0049, 0.0065, 0.0082, 0.0098, 0.0114, 0.0130,
         0.0146, 0.0162, 0.0178, 0.0194, 0.0210, 0.0226, 0.0242, 0.0258, 0.0274,
         0.0290, 0.0306, 0.0321, 0

In [96]:
import json

prior_path = "./test/models/SVIIvR_simple/prior.json"
petri_path = "./test/models/SVIIvR_simple/petri.json"

with open(prior_path) as f:
    prior_json = json.load(f)

petri_G = utils.load(petri_path)
petri_G = utils.add_state_indicies(petri_G)

model_compiled = compile_pp(petri_G, prior_json)

In [97]:
model_compiled.param_prior()
model_compiled.deriv(0., initial_state)

(tensor(-0.0086), tensor(0.), tensor(-0.0002), tensor(0.0086), tensor(0.0002))

In [99]:
from pyro.poutine import trace

tr = trace(model).get_trace(initial_state, full_tspan)
tr_compiled = trace(model_compiled).get_trace(initial_state, utils.get_tspan(1, 3, 3))


In [100]:
tr.nodes

OrderedDict([('_INPUT',
              {'name': '_INPUT',
               'type': 'args',
               'args': ((tensor(0.9992),
                 tensor(0.),
                 tensor(0.0008),
                 tensor(0.),
                 tensor(0.)),
                tensor([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13., 14.,
                        15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28.,
                        29., 30., 31., 32., 33., 34., 35., 36., 37., 38., 39., 40., 41., 42.,
                        43., 44., 45., 46., 47., 48., 49., 50., 51., 52., 53., 54., 55., 56.,
                        57., 58., 59., 60., 61., 62., 63., 64., 65., 66., 67., 68., 69., 70.,
                        71., 72., 73., 74., 75., 76., 77., 78., 79., 80., 81., 82., 83., 84.,
                        85., 86., 87., 88., 89.])),
               'kwargs': {}}),
             ('noise',
              {'type': 'sample',
               'name': 'noise',
               

In [101]:
# tr_compiled.nodes

In [102]:
solution, observations = model_compiled(initial_state, utils.get_tspan(1, 100, 100))

In [103]:
solution[2]

tensor([8.1000e-04, 5.9845e-04, 4.4208e-04, 3.2653e-04, 2.4114e-04, 1.7806e-04,
        1.3146e-04, 9.7037e-05, 7.1620e-05, 5.2853e-05, 3.8998e-05, 2.8770e-05,
        2.1222e-05, 1.5652e-05, 1.1542e-05, 8.5105e-06, 6.2741e-06, 4.6248e-06,
        3.4085e-06, 2.5117e-06, 1.8507e-06, 1.3634e-06, 1.0043e-06, 7.3963e-07,
        5.4465e-07, 4.0102e-07, 2.9523e-07, 2.1731e-07, 1.5994e-07, 1.1769e-07,
        8.6596e-08, 6.3707e-08, 4.6861e-08, 3.4465e-08, 2.5345e-08, 1.8636e-08,
        1.3700e-08, 1.0071e-08, 7.4018e-09, 5.4394e-09, 3.9968e-09, 2.9364e-09,
        2.1570e-09, 1.5843e-09, 1.1635e-09, 8.5433e-10, 6.2725e-10, 4.6046e-10,
        3.3798e-10, 2.4804e-10, 1.8202e-10, 1.3355e-10, 9.7972e-11, 7.1865e-11,
        5.2707e-11, 3.8652e-11, 2.8341e-11, 2.0778e-11, 1.5231e-11, 1.1164e-11,
        8.1815e-12, 5.9952e-12, 4.3925e-12, 3.2179e-12, 2.3571e-12, 1.7263e-12,
        1.2642e-12, 9.2564e-13, 6.7768e-13, 4.9608e-13, 3.6310e-13, 2.6574e-13,
        1.9445e-13, 1.4228e-13, 1.0409e-