In [132]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [133]:
# Load Dependencies
import matplotlib.pyplot as plt

import torch

import pyro
import pyro.distributions as dist
# from pyro.optim import Adam

import pyciemss.utils as utils

import warnings
warnings.filterwarnings('ignore')

In [134]:
from pyciemss.ODE.frontend import compile_pp
from pyciemss.ODE.models import SVIIvR

In [144]:
# Setup Parameters
#
# Time grids. Assumes utils.get_tspan(start, end, n) returns n points spanning
# [start, end] (TODO confirm); the counts below give one point per day.
FIRST_DAY, LAST_OBSERVED_DAY, LAST_DAY = 1, 7, 89
observed_tspan = utils.get_tspan(FIRST_DAY, LAST_OBSERVED_DAY, LAST_OBSERVED_DAY - FIRST_DAY + 1)
new_tspan = utils.get_tspan(LAST_OBSERVED_DAY, LAST_DAY, LAST_DAY - LAST_OBSERVED_DAY + 1)
full_tspan = utils.get_tspan(FIRST_DAY, LAST_DAY, LAST_DAY - FIRST_DAY + 1)

num_samples = 500

# Compartments are fractions of the total population, so N is 1 and the raw
# figures (apparently per 100,000 people) are divided by PER_100K.
N = 1.0
PER_100K = 100000

# Initial state: only the infected fraction (I0) starts non-zero; V0, Iv0 and
# the recovered fraction (R0) all start at zero.
V0, I0, Iv0, R0 = 0., 81.0 / PER_100K, 0., 0.
# Whoever is left over starts susceptible.
S0 = N - I0 - Iv0 - V0 - R0

# Observed infections for 18 - 24 year olds, one value per observed day.
I_obs = torch.tensor([81.47, 84.3, 86.44, 89.66, 93.32, 94.1, 96.31]) / PER_100K

# Compartment order is (S, V, I, Iv, R).
initial_state = tuple(map(torch.as_tensor, (S0, V0, I0, Iv0, R0)))
# NOTE(review): this swaps in the last observed I but keeps the day-1 S0, so the
# compartments sum to slightly more than N -- confirm that is intended.
final_observed_state = tuple(map(torch.as_tensor, (S0, V0, I_obs[-1], Iv0, R0)))

In [145]:
# Instantiate the hand-written SVIIvR model with total population N; it is
# traced alongside the Petri-net-compiled model further down.
model = SVIIvR(N)

In [146]:
# Run the hand-written model from initial_state over the full time span.
# NOTE(review): `solution` / `observations` are rebound later by the compiled
# model's run; use distinct names if both results need to be kept.
solution, observations = model(initial_state, full_tspan)

In [147]:
import json
from pyciemss.utils import petri_utils

# Parameter priors consumed by compile_pp below.
prior_path = "test/scenario2/SVIIvR_prior.json"

# JSON is UTF-8 by specification; pass the encoding explicitly so loading does
# not depend on the platform's default locale encoding.
with open(prior_path, encoding="utf-8") as prior_file:
    prior_json = json.load(prior_file)

# NOTE(review): the prior above comes from the SVIIvR scenario, but the Petri
# net below is the CHIME-SIR starter-kit model (an alternative is the SVIIvR net
# at "test/scenario2/SVIIvR.json"). Confirm the prior's parameters match this net.
petri_path = "test/models/starter_kit_examples/CHIME-SIR/model_petri.json"
petri_G = utils.add_state_indicies(utils.load(petri_path))

model_compiled = compile_pp(petri_G, prior_json)

In [148]:
# Sanity-check the compiled model: param_prior() presumably draws parameter
# values from the prior (TODO confirm), after which deriv is evaluated at t = 0.
model_compiled.param_prior()
# Three compartments, assumed to match the net's state order (S, I, R).
test_state = tuple(map(torch.as_tensor, (S0, I0, R0)))
model_compiled.deriv(0., test_state)

(tensor(-9.8193e-05), tensor(-0.0002), tensor(0.0003))

In [149]:
from pyro.poutine import trace

# Record execution traces of both models for side-by-side inspection.
tr = trace(model).get_trace(initial_state, full_tspan)
short_tspan = utils.get_tspan(1, 3, 3)  # short span for the compiled model
tr_compiled = trace(model_compiled).get_trace(test_state, short_tspan)


In [150]:
# Uncomment to inspect the nodes recorded in the trace of the hand-written model:
# tr.nodes

In [151]:
# Uncomment to inspect the nodes recorded in the trace of the compiled model:
# tr_compiled.nodes

In [154]:
# Simulate the compiled model from test_state on a get_tspan(1, 100, 100) grid.
# Note: this rebinds `solution` / `observations` from the hand-written model's run.
compiled_tspan = utils.get_tspan(1, 100, 100)
solution, observations = model_compiled(test_state, compiled_tspan)

In [157]:
# Third component of the compiled model's solution. With test_state ordered
# (S, I, R) this is presumably R -- it starts at 0.0, matching R0.
solution[2]

tensor([0.0000e+00, 1.6369e-04, 3.4204e-04, 5.3633e-04, 7.4797e-04, 9.7851e-04,
        1.2296e-03, 1.5031e-03, 1.8009e-03, 2.1251e-03, 2.4782e-03, 2.8625e-03,
        3.2808e-03, 3.7360e-03, 4.2313e-03, 4.7702e-03, 5.3563e-03, 5.9938e-03,
        6.6868e-03, 7.4401e-03, 8.2587e-03, 9.1480e-03, 1.0114e-02, 1.1162e-02,
        1.2300e-02, 1.3535e-02, 1.4873e-02, 1.6324e-02, 1.7896e-02, 1.9598e-02,
        2.1439e-02, 2.3430e-02, 2.5581e-02, 2.7905e-02, 3.0411e-02, 3.3113e-02,
        3.6023e-02, 3.9153e-02, 4.2517e-02, 4.6129e-02, 5.0001e-02, 5.4146e-02,
        5.8578e-02, 6.3309e-02, 6.8352e-02, 7.3717e-02, 7.9415e-02, 8.5455e-02,
        9.1844e-02, 9.8588e-02, 1.0569e-01, 1.1315e-01, 1.2098e-01, 1.2915e-01,
        1.3768e-01, 1.4654e-01, 1.5573e-01, 1.6523e-01, 1.7502e-01, 1.8508e-01,
        1.9538e-01, 2.0589e-01, 2.1659e-01, 2.2744e-01, 2.3841e-01, 2.4947e-01,
        2.6056e-01, 2.7167e-01, 2.8275e-01, 2.9377e-01, 3.0470e-01, 3.1549e-01,
        3.2613e-01, 3.3658e-01, 3.4681e-