In [1]:
import os
import torch

from pyro.distributions import Uniform

from pyciemss.ODE.base import PetriNetODESystem, BetaNoisePetriNetODESystem
from pyciemss.ODE.events import Event, ObservationEvent, LoggingEvent, StartEvent, DynamicStopEvent
import pyciemss

In [2]:
STARTERKIT_PATH = "test/models/starter_kit_examples/"
MIRA_PATH = "test/models/evaluation_examples/"

filename = "CHIME-SIR/model_petri.json"
filename = os.path.join(STARTERKIT_PATH, filename)

model = BetaNoisePetriNetODESystem.from_mira(filename)
model.load_start_event(0.0, {"S": 0.9, "I": 0.1, "R": 0.0})

tspan = torch.linspace(1, 10, 10)
model.load_logging_events(tspan)

solution = model()

# See that the solution returns a dictionary where each value has length 10, one for each logging event.
assert len(solution["I"]) == len(solution["R"]) == len(solution["S"]) == len(tspan)

solution

  logging_events = [LoggingEvent(torch.tensor(t)) for t in times]


{'I': tensor([0.0861, 0.0688, 0.0520, 0.0378, 0.0267, 0.0185, 0.0127, 0.0086, 0.0058,
         0.0039]),
 'R': tensor([0.0965, 0.1764, 0.2385, 0.2846, 0.3176, 0.3407, 0.3567, 0.3675, 0.3749,
         0.3799]),
 'S': tensor([0.8175, 0.7548, 0.7095, 0.6776, 0.6557, 0.6407, 0.6306, 0.6238, 0.6193,
         0.6162])}

In [3]:
# Remove the logging events.
model.delete_logging_events()

# Add observation events.
# TODO: implement a pyciemss.condition operation that returns a model with observation_events
observations = {1.3: {"R": 0.2, "I":0.15}, 2.3: {"I": 0.1}}
model.load_observation_events(observations)

model()


{'I': tensor([]), 'R': tensor([]), 'S': tensor([])}

In [4]:
# Show that inference works.

from pyro.infer.autoguide import AutoNormal
from pyro.poutine import block
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
import pyro

guide = AutoNormal(model)
optim = Adam({'lr': 0.03})
loss_f = Trace_ELBO(num_particles=1)
verbose = True

svi = SVI(model, guide, optim, loss=loss_f)

pyro.clear_param_store()

for j in range(100):
    # calculate the loss and take a gradient step
    # Passing a data argument to svi.step() will condition the model on the data.
    loss = svi.step()
    if verbose:
        if j % 25 == 0:
            print("[iteration %04d] loss: %.4f" % (j + 1, loss))

[iteration 0001] loss: 7.1361
[iteration 0026] loss: 6.8324
[iteration 0051] loss: 5.1047
[iteration 0076] loss: 3.1802


In [5]:
# use pyro predictive
from pyro.infer import Predictive

# Remove the observation events and add logging events.

model.delete_observation_events()
model.load_logging_events(torch.linspace(1, 10, 10))


# Get the return value of model.
predictions = Predictive(model, guide=guide, num_samples=10)()

In [6]:
predictions

{'a_beta': tensor([1.0571, 0.9763, 0.9913, 0.9369, 1.0276, 0.9485, 0.9565, 1.0730, 1.0156,
         1.0680]),
 'a_gamma': tensor([0.9417, 0.9620, 0.9400, 1.0504, 1.0620, 1.0572, 1.0083, 0.9713, 1.0538,
         1.0467]),
 'I_sol': tensor([[0.0962, 0.0849, 0.0698, 0.0545, 0.0409, 0.0298, 0.0213, 0.0151, 0.0105,
          0.0073],
         [0.0884, 0.0729, 0.0570, 0.0429, 0.0314, 0.0225, 0.0159, 0.0112, 0.0078,
          0.0054],
         [0.0914, 0.0777, 0.0623, 0.0478, 0.0356, 0.0259, 0.0186, 0.0131, 0.0092,
          0.0065],
         [0.0784, 0.0581, 0.0413, 0.0286, 0.0195, 0.0131, 0.0087, 0.0058, 0.0038,
          0.0025],
         [0.0835, 0.0648, 0.0477, 0.0339, 0.0235, 0.0160, 0.0108, 0.0072, 0.0048,
          0.0032],
         [0.0786, 0.0583, 0.0415, 0.0287, 0.0195, 0.0131, 0.0087, 0.0058, 0.0038,
          0.0025],
         [0.0831, 0.0648, 0.0483, 0.0349, 0.0246, 0.0171, 0.0117, 0.0080, 0.0054,
          0.0037],
         [0.0946, 0.0820, 0.0663, 0.0509, 0.0376, 0.0270, 0.019

In [None]:
# SCAFFOLDING FOR DYNAMIC EVENT HANDLING BELOW

import pyro
from torch import Tensor
from torchdiffeq import odeint_event, odeint

Time = Tensor
State = tuple[Tensor, ...]

class BaseODEModel(pyro.nn.PyroModule):

    def __init__(self, static_events: list[Event]):
        super().__init__()
        # This is a list of events that are always sorted by time.
        # TODO: probably pre-sort this list just in case.
        self.static_events = static_events

    def deriv(self, t: Time, state: State) -> State:
        raise NotImplementedError

    @pyro.nn.pyro_method
    def param_prior(self):
        raise NotImplementedError

    @pyro.nn.pyro_method
    def initial_conditions_prior(self):
        raise NotImplementedError

    # @pyro.nn.pyro_method
    # def observation_model(self, state: State, tspan: Time, ?) -> ?: …

    def solve(self, initial_state: State, initial_time: Time) -> tuple[State, State]:
        
        current_state = initial_state
        current_time = initial_time

        solution = torch.tensor([])

        for i, static_event in enumerate(self.static_events):
            
            # TODO: change below
            # Note: Immediate goal is to get self.solve() to generate ode solutions that stop at the event times.
            # Note: Next we will add log likelihood at those events using an observation model.

            # Note: each static_event is an Event object, which has a forward() method that returns a Tensor.
            # TODO: Chad, could you please make this actually make sense.
            event_time, event_solution = odeint(self.deriv, current_state, current_time, event_fn=static_event)
    
            # TODO: Add log likelihood for event_solution.
            # Note: This will be done by adding an observation model to the ODE model and calling it with the event_solution as an argument.

            solution = torch.cat([solution, event_solution], dim=-1)
            # current = event_solution[-1]


    def forward(self, initial_state: State, initial_time: Time) -> tuple[State, State]:
        # Sample parameters from the prior. These parameters are generated as attributes of the model.
        self.param_prior()

        # TODO: Sample initial conditions from the prior instead of taking them as input.

        # Solve the ODE, taking into account any intervention and conditioning events.
        return self.solve(initial_state, initial_time)




In [None]:
event_time, event_solution = odeint(self.deriv, current_state, current_time, event_fn=static_event)

In [None]:

# TODO: Chad, initialize a stop event here.
# stop_event = ...

# TODO: Chad, create a model the subclasses BaseODEModel and has all of the above missing methods implemented.
# deriv, param_prior, initial_conditions_prior, observation_model
# class ActualODEModel(BaseODEModel):
#   ...

# TODO: Chad, instantiate the model here.
# model = ActualODEModel([stop_event])

# TODO: Chad, Call model.forward() here.
# solution, _ = model(initial_state, tspan)

In [7]:
observations = [ObservationEvent(torch.tensor([2.1]), {"I": torch.tensor([0.1])})]


In [8]:
conditioned_model = pyciemss.condition(model, observations)

AttributeError: module 'pyciemss' has no attribute 'condition'