In [1]:
import os

from pyro.distributions import Uniform

from pyciemss.ODE.base import PetriNetODESystem, BetaNoisePetriNetODESystem
from pyciemss.ODE.events import Event, ObservationEvent, LoggingEvent, StartEvent, DynamicStopEvent
import pyciemss

In [2]:
from typing import Type

StartEvent

pyciemss.ODE.events.StartEvent

In [3]:
STARTERKIT_PATH = "test/models/starter_kit_examples/"
MIRA_PATH = "test/models/evaluation_examples/"

filename = "CHIME-SIR/model_petri.json"
filename = os.path.join(STARTERKIT_PATH, filename)

model = BetaNoisePetriNetODESystem.from_mira(filename)

start_event = StartEvent(0.0, {"S": 0.99, "I": 0.01, "R": 0.0})
model.load_event(start_event)

tspan = range(1, 10)
logging_events = [LoggingEvent(t) for t in tspan]
model.load_events(logging_events)

# tspan = torch.linspace(1, 10, 10)
# model.load_logging_events(tspan)

solution = model()

# See that the solution returns a dictionary where each value has length 10, one for each logging event.
assert len(solution["I"]) == len(solution["R"]) == len(solution["S"]) == len(tspan)

solution

{'I': tensor([0.0092, 0.0085, 0.0077, 0.0070, 0.0063, 0.0056, 0.0050, 0.0044, 0.0039]),
 'R': tensor([0.0099, 0.0189, 0.0272, 0.0348, 0.0415, 0.0476, 0.0530, 0.0578, 0.0620]),
 'S': tensor([0.9809, 0.9726, 0.9651, 0.9583, 0.9522, 0.9468, 0.9420, 0.9378, 0.9341])}

In [4]:
# Remove logging events
model.remove_logging_events()

# Add observations
observation_events = [ObservationEvent(1.1, {"S": 0.9, "I": 0.09, "R": 0.01}), 
                      ObservationEvent(2.1, {"S": 0.8, "I": 0.18, "R": 0.02}),
                      ObservationEvent(3.1, {"S": 0.7, "I": 0.27, "R": 0.03}),
                      ObservationEvent(4.1, {"S": 0.6, "I": 0.36, "R": 0.04})]

model.load_events(observation_events)

In [5]:
model._static_events

[StartEvent(time=0.0, initial_state={'S': tensor(0.9900), 'I': tensor(0.0100), 'R': tensor(0.)}),
 ObservationEvent(time=1.100000023841858, observation={'S': tensor(0.9000), 'I': tensor(0.0900), 'R': tensor(0.0100)}),
 ObservationEvent(time=2.0999999046325684, observation={'S': tensor(0.8000), 'I': tensor(0.1800), 'R': tensor(0.0200)}),
 ObservationEvent(time=3.0999999046325684, observation={'S': tensor(0.7000), 'I': tensor(0.2700), 'R': tensor(0.0300)}),
 ObservationEvent(time=4.099999904632568, observation={'S': tensor(0.6000), 'I': tensor(0.3600), 'R': tensor(0.0400)})]

In [6]:
# Show that inference works.

from pyro.infer.autoguide import AutoNormal
from pyro.poutine import block
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
import pyro

guide = AutoNormal(model)
optim = Adam({'lr': 0.03})
loss_f = Trace_ELBO(num_particles=1)
verbose = True

svi = SVI(model, guide, optim, loss=loss_f)

pyro.clear_param_store()

for j in range(100):
    # calculate the loss and take a gradient step
    # Passing a data argument to svi.step() will condition the model on the data.
    loss = svi.step()
    if verbose:
        if j % 25 == 0:
            print("[iteration %04d] loss: %.4f" % (j + 1, loss))

[iteration 0001] loss: 16.4383
[iteration 0026] loss: 13.4627
[iteration 0051] loss: 12.5196
[iteration 0076] loss: 12.4522


In [7]:
# Remove the observation events
model.remove_observation_events()

In [8]:
# Load some static parameter intervention events
model.load_event(StaticParameterInterventionEvent(2.99, "beta", 0.0))
model.load_event(StaticParameterInterventionEvent(4.11, "beta", 10.))

# Load the logging events again
model.load_events(logging_events)

In [9]:
model._static_events

[StartEvent(time=0.0, initial_state={'S': tensor(0.9900), 'I': tensor(0.0100), 'R': tensor(0.)}),
 LoggingEvent(time=1),
 LoggingEvent(time=2),
 StaticParameterInterventionEvent(time=2.990000009536743, parameter=beta, value=0.0),
 LoggingEvent(time=3),
 LoggingEvent(time=4),
 StaticParameterInterventionEvent(time=4.110000133514404, parameter=beta, value=10.0),
 LoggingEvent(time=5),
 LoggingEvent(time=6),
 LoggingEvent(time=7),
 LoggingEvent(time=8),
 LoggingEvent(time=9)]

In [10]:
model._observation_var_names

[]

In [11]:
# use pyro predictive
from pyro.infer import Predictive

# Get the return value of model.
predictions = Predictive(model, guide=guide, num_samples=10)()

In [12]:
predictions

{'a_beta': tensor([0.9531, 1.0604, 0.9963, 1.0584, 1.0436, 0.9849, 1.0022, 0.9444, 0.9956,
         0.9495]),
 'a_gamma': tensor([0.9833, 0.9789, 1.0076, 1.0023, 0.9844, 0.9864, 1.0258, 1.0240, 0.9999,
         1.0351]),
 'I_sol': tensor([[0.0096, 0.0091, 0.0085, 0.0032, 0.6377, 0.2993, 0.1124, 0.0421, 0.0158],
         [0.0107, 0.0113, 0.0116, 0.0044, 0.6382, 0.2893, 0.1092, 0.0411, 0.0154],
         [0.0097, 0.0094, 0.0089, 0.0033, 0.6310, 0.2906, 0.1066, 0.0390, 0.0142],
         [0.0104, 0.0107, 0.0108, 0.0040, 0.6325, 0.2853, 0.1052, 0.0387, 0.0142],
         [0.0104, 0.0108, 0.0109, 0.0041, 0.6374, 0.2900, 0.1088, 0.0407, 0.0152],
         [0.0098, 0.0096, 0.0092, 0.0034, 0.6373, 0.2955, 0.1107, 0.0413, 0.0154],
         [0.0096, 0.0092, 0.0086, 0.0031, 0.6254, 0.2869, 0.1034, 0.0371, 0.0133],
         [0.0091, 0.0082, 0.0073, 0.0026, 0.6233, 0.2931, 0.1058, 0.0381, 0.0137],
         [0.0098, 0.0095, 0.0091, 0.0033, 0.6333, 0.2920, 0.1079, 0.0398, 0.0146],
         [0.0091, 0.008

In [13]:
model()

{'I': tensor([0.0102, 0.0102, 0.0101, 0.0041, 0.6588, 0.3154, 0.1273, 0.0513, 0.0206]),
 'R': tensor([0.0092, 0.0185, 0.0278, 0.0338, 0.2332, 0.6838, 0.8726, 0.9487, 0.9793]),
 'S': tensor([9.8066e-01, 9.7130e-01, 9.6210e-01, 9.6210e-01, 1.0798e-01, 7.6387e-04,
         9.5966e-05, 4.1588e-05, 2.9698e-05])}

In [None]:
# SCAFFOLDING FOR DYNAMIC EVENT HANDLING BELOW

import pyro
from torch import Tensor
from torchdiffeq import odeint_event, odeint

Time = Tensor
State = tuple[Tensor, ...]

class BaseODEModel(pyro.nn.PyroModule):

    def __init__(self, static_events: list[Event]):
        super().__init__()
        # This is a list of events that are always sorted by time.
        # TODO: probably pre-sort this list just in case.
        self.static_events = static_events

    def deriv(self, t: Time, state: State) -> State:
        raise NotImplementedError

    @pyro.nn.pyro_method
    def param_prior(self):
        raise NotImplementedError

    @pyro.nn.pyro_method
    def initial_conditions_prior(self):
        raise NotImplementedError

    # @pyro.nn.pyro_method
    # def observation_model(self, state: State, tspan: Time, ?) -> ?: …

    def solve(self, initial_state: State, initial_time: Time) -> tuple[State, State]:
        
        current_state = initial_state
        current_time = initial_time

        solution = torch.tensor([])

        for i, static_event in enumerate(self.static_events):
            
            # TODO: change below
            # Note: Immediate goal is to get self.solve() to generate ode solutions that stop at the event times.
            # Note: Next we will add log likelihood at those events using an observation model.

            # Note: each static_event is an Event object, which has a forward() method that returns a Tensor.
            # TODO: Chad, could you please make this actually make sense.
            event_time, event_solution = odeint(self.deriv, current_state, current_time, event_fn=static_event)
    
            # TODO: Add log likelihood for event_solution.
            # Note: This will be done by adding an observation model to the ODE model and calling it with the event_solution as an argument.

            solution = torch.cat([solution, event_solution], dim=-1)
            # current = event_solution[-1]


    def forward(self, initial_state: State, initial_time: Time) -> tuple[State, State]:
        # Sample parameters from the prior. These parameters are generated as attributes of the model.
        self.param_prior()

        # TODO: Sample initial conditions from the prior instead of taking them as input.

        # Solve the ODE, taking into account any intervention and conditioning events.
        return self.solve(initial_state, initial_time)




In [None]:
event_time, event_solution = odeint(self.deriv, current_state, current_time, event_fn=static_event)

In [None]:

# TODO: Chad, initialize a stop event here.
# stop_event = ...

# TODO: Chad, create a model the subclasses BaseODEModel and has all of the above missing methods implemented.
# deriv, param_prior, initial_conditions_prior, observation_model
# class ActualODEModel(BaseODEModel):
#   ...

# TODO: Chad, instantiate the model here.
# model = ActualODEModel([stop_event])

# TODO: Chad, Call model.forward() here.
# solution, _ = model(initial_state, tspan)

In [7]:
observations = [ObservationEvent(torch.tensor([2.1]), {"I": torch.tensor([0.1])})]


In [8]:
conditioned_model = pyciemss.condition(model, observations)

AttributeError: module 'pyciemss' has no attribute 'condition'