In [82]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [83]:
# Load Dependencies
import matplotlib.pyplot as plt

import torch

import pyro
import pyro.distributions as dist
# from pyro.optim import Adam

import pyciemss.utils as utils

import warnings
warnings.filterwarnings('ignore')

In [84]:
from pyciemss.ODE.frontend import compile_pp
from pyciemss.ODE.models import SVIIvR

In [85]:
# Setup Parameters

observed_tspan = utils.get_tspan(1, 7, 7)
new_tspan = utils.get_tspan(7, 89, 83)
full_tspan = utils.get_tspan(1, 89, 89)

num_samples = 500

# Total population, N.
N = 100000.0
# Initial number of infected and recovered individuals, I0 and R0.
V0, I0, Iv0, R0 = 0., 81.0, 0., 0.
# Everyone else, S0, is susceptible to infection initially.
S0 = N - I0 - Iv0 - V0 - R0

# 18 - 24 year olds
I_obs = torch.tensor([81.47, 84.3, 86.44, 89.66, 93.32, 94.1, 96.31])

initial_state = tuple(torch.as_tensor(s) for s in  (S0, V0, I0, Iv0, R0))
final_observed_state = tuple(torch.as_tensor(s) for s in  (S0, V0, I_obs[-1], Iv0, R0))

In [86]:
model = SVIIvR(N)

In [87]:
solution, observations = model(initial_state, full_tspan)

In [88]:
import json
from pyciemss.utils import petri_utils

prior_path = "test/scenario2/SVIIvR_prior.json"

with open(prior_path) as f:
    prior_json = json.load(f)

# petri_path = "test/scenario2/SVIIvR.json"
petri_path = "test/models/starter_kit_examples/CHIME-SIR/model_petri.json"
petri_G = utils.load(petri_path)
petri_G = utils.add_state_indicies(petri_G)

model_compiled = compile_pp(petri_G, prior_json)

In [89]:
model_compiled.param_prior()
test_state = tuple(torch.as_tensor(s) for s in  (S0, I0, R0))
model_compiled.deriv(0., test_state)

(tensor(-1765958.5000), tensor(1765953.8750), tensor(4.5973))

In [98]:
from pyro.poutine import trace

tr = trace(model).get_trace(initial_state, full_tspan)
tr_compiled = trace(model_compiled).get_trace(test_state, utils.get_tspan(1, 3, 3))


In [101]:
# tr.nodes

In [100]:
# tr_compiled.nodes

In [102]:
model_compiled(test_state, utils.get_tspan(1, 3, 3))

((tensor([ 9.9919e+04, -1.3580e+06,  3.5664e+11]),
  tensor([ 8.1000e+01,  1.4580e+06, -3.5664e+11]),
  tensor([0.0000e+00, 2.1750e+01, 3.9152e+05])),
 (tensor([ 9.9922e+04, -1.3580e+06,  3.5664e+11]),
  tensor([ 6.9609e+01,  1.4580e+06, -3.5664e+11]),
  tensor([-3.3614e+00,  1.9346e+01,  3.9152e+05])))

In [69]:
observations_compiled

(tensor([ 9.9917e+04, -1.4802e+06,  4.5664e+11]),
 tensor([ 7.9246e+01,  1.5802e+06, -4.5664e+11]),
 tensor([-2.6109e+00,  1.8242e+01,  3.8308e+05]))