In [158]:
# Re-import edited project modules (e.g. pyciemss) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [159]:
# Load dependencies: standard library first, then third-party, then project-local.
import warnings

import matplotlib.pyplot as plt
import torch
import pyro
import pyro.distributions as dist

import pyciemss.utils as utils

# Hide all warnings so cell output stays readable.
# NOTE(review): this filter is process-wide and also hides torch/pyro deprecation
# and numerical warnings; consider narrowing it with `category=` / `module=`.
warnings.filterwarnings('ignore')

In [160]:
from pyciemss.ODE.frontend import compile_pp
from pyciemss.ODE.models import SVIIvR

In [161]:
# Setup Parameters

# Seed the RNGs pyro relies on (Python `random`, NumPy, torch) so that prior
# samples, and every output derived from them, reproduce on Restart & Run All.
SEED = 0
pyro.set_rng_seed(SEED)

observed_tspan = utils.get_tspan(1, 7, 7)
new_tspan = utils.get_tspan(7, 89, 83)
full_tspan = utils.get_tspan(1, 89, 89)

num_samples = 500

# Total population, N. Compartment values below are fractions of N.
N = 1.0
# Initial values of the other SVIIvR compartments: V0, I0, Iv0 and R0
# (I0 is 81 per 100,000; the rest start empty).
V0, I0, Iv0, R0 = 0., 81.0/100000, 0., 0.
# Everyone else, S0, is susceptible to infection initially.
S0 = N - I0 - Iv0 - V0 - R0

# Observed I for 18 - 24 year olds, converted from per-100,000 to a fraction of N.
I_obs = torch.tensor([81.47, 84.3, 86.44, 89.66, 93.32, 94.1, 96.31])/100000

initial_state = tuple(torch.as_tensor(s) for s in  (S0, V0, I0, Iv0, R0))
# The initial compartments must account for the whole population.
assert abs(float(sum(initial_state)) - N) < 1e-6, "initial_state does not sum to N"

# State at the last observation (presumably the start point for forecasting over
# new_tspan, which begins at day 7 — confirm).
# NOTE(review): only I is replaced by I_obs[-1]; S and R are carried over from
# day 1, so these compartments sum to slightly more than N. Confirm intended.
final_observed_state = tuple(torch.as_tensor(s) for s in  (S0, V0, I_obs[-1], Iv0, R0))

In [162]:
# Hand-written SVIIvR model from pyciemss.ODE.models, built for total population N.
model = SVIIvR(N)

In [163]:
# Simulate the hand-written model once from initial_state over full_tspan.
# NOTE(review): `solution` / `observations` are reassigned later by the compiled-model run.
solution, observations = model(initial_state, full_tspan)

In [164]:
import json
from pyciemss.utils import petri_utils

# Prior specification passed to compile_pp below (file name indicates SVIIvR).
prior_path = "test/scenario2/SVIIvR_prior.json"

# JSON is UTF-8 by specification; don't depend on the platform's default encoding.
with open(prior_path, encoding="utf-8") as f:
    prior_json = json.load(f)

# NOTE(review): this loads the CHIME-SIR petri net rather than the SVIIvR net
# ("test/scenario2/SVIIvR.json") that matches prior_json. Confirm the pairing is
# intended; the cells below build a 3-compartment (S, I, R) state for it.
petri_path = "test/models/starter_kit_examples/CHIME-SIR/model_petri.json"
petri_G = utils.load(petri_path)
petri_G = utils.add_state_indicies(petri_G)

# Compile the petri net plus priors into a model; later cells call it and trace
# it with pyro.poutine.trace.
model_compiled = compile_pp(petri_G, prior_json)

In [165]:
# Set the compiled model's parameters via param_prior() (presumably a draw from
# prior_json — confirm), then evaluate its derivative once at t = 0.
# The state is a 3-tuple in the order (S, I, R); the result is shown as cell output.
model_compiled.param_prior()
test_state = tuple(map(torch.as_tensor, (S0, I0, R0)))
model_compiled.deriv(0., test_state)

(tensor(-9.0484e-05), tensor(3.3713e-05), tensor(5.6770e-05))

In [166]:
from pyro.poutine import trace

# Record pyro execution traces for both models so their sites can be compared
# (inspect via tr.nodes / tr_compiled.nodes). The compiled model is traced over a
# short 3-point time span only.
tr = trace(model).get_trace(initial_state, full_tspan)
tr_compiled = trace(model_compiled).get_trace(test_state, utils.get_tspan(1, 3, 3))


In [167]:
# tr.nodes

In [168]:
# tr_compiled.nodes

In [169]:
# Simulate the compiled model from test_state over get_tspan(1, 100, 100).
# NOTE(review): this overwrites `solution` / `observations` from the hand-written SVIIvR run.
solution, observations = model_compiled(test_state, utils.get_tspan(1, 100, 100))

In [170]:
# Trajectory of the third state of the compiled run — presumably R, assuming
# `solution` follows test_state's (S, I, R) ordering (TODO confirm).
solution[2]

tensor([0.0000e+00, 1.0255e-04, 2.2230e-04, 3.6211e-04, 5.2534e-04, 7.1590e-04,
        9.3834e-04, 1.1980e-03, 1.5009e-03, 1.8545e-03, 2.2669e-03, 2.7480e-03,
        3.3090e-03, 3.9629e-03, 4.7251e-03, 5.6129e-03, 6.6468e-03, 7.8500e-03,
        9.2496e-03, 1.0876e-02, 1.2766e-02, 1.4958e-02, 1.7498e-02, 2.0440e-02,
        2.3839e-02, 2.7762e-02, 3.2281e-02, 3.7473e-02, 4.3425e-02, 5.0229e-02,
        5.7980e-02, 6.6778e-02, 7.6723e-02, 8.7912e-02, 1.0044e-01, 1.1437e-01,
        1.2979e-01, 1.4672e-01, 1.6517e-01, 1.8514e-01, 2.0656e-01, 2.2934e-01,
        2.5336e-01, 2.7844e-01, 3.0440e-01, 3.3102e-01, 3.5807e-01, 3.8533e-01,
        4.1256e-01, 4.3955e-01, 4.6609e-01, 4.9202e-01, 5.1718e-01, 5.4144e-01,
        5.6472e-01, 5.8693e-01, 6.0802e-01, 6.2798e-01, 6.4679e-01, 6.6445e-01,
        6.8099e-01, 6.9644e-01, 7.1082e-01, 7.2419e-01, 7.3659e-01, 7.4806e-01,
        7.5867e-01, 7.6847e-01, 7.7749e-01, 7.8580e-01, 7.9344e-01, 8.0046e-01,
        8.0690e-01, 8.1281e-01, 8.1822e-