In [18]:
import os

from pyciemss.PetriNetODE.base import MiraPetriNetODESystem, BetaNoisePetriNetODESystem
from pyciemss.PetriNetODE.events import Event, StartEvent, LoggingEvent, ObservationEvent, StaticParameterInterventionEvent
import pyciemss

from pyciemss.PetriNetODE.interfaces import load_petri_model, setup_model, reset_model, intervene, sample, calibrate, optimize

In [32]:
# Relative directories that hold the example Petri-net model files.
STARTERKIT_PATH = "test/models/starter_kit_examples/"
MIRA_PATH = "test/models/evaluation_examples/"

# Model file to load, resolved against the starter-kit directory.
filename = "CHIME-SIR/model_petri.json"
filename = os.path.join(STARTERKIT_PATH, filename)

model = BetaNoisePetriNetODESystem.from_mira(filename)
# model = MiraPetriNetODESystem.from_mira(filename)

# Start the simulation at t=0.0 with the given S/I/R initial state.
start_event = StartEvent(0.0, {"S": 0.99, "I": 0.01, "R": 0.0})
model.load_event(start_event)

# One logging event per integer time t = 1, ..., 9.
tspan = range(1, 10)
logging_events = [LoggingEvent(t) for t in tspan]
model.load_events(logging_events)

solution = model()

# The solution is a dictionary in which each value has length 9, one entry per logging event
# (range(1, 10) yields nine time points).
assert len(solution["I"]) == len(solution["R"]) == len(solution["S"]) == len(tspan)

solution

{'I': tensor([0.0092, 0.0084, 0.0076, 0.0069, 0.0062, 0.0055, 0.0049, 0.0044, 0.0038]),
 'R': tensor([0.0093, 0.0179, 0.0257, 0.0327, 0.0391, 0.0447, 0.0498, 0.0543, 0.0583]),
 'S': tensor([0.9815, 0.9737, 0.9667, 0.9604, 0.9548, 0.9497, 0.9453, 0.9413, 0.9379])}

In [33]:
# Drop the logging events that were loaded for the forward simulation.
model.remove_logging_events()

# Build one observation event per row; each row is (time, S, I, R).
observation_events = [
    ObservationEvent(obs_time, {"S": s_obs, "I": i_obs, "R": r_obs})
    for obs_time, s_obs, i_obs, r_obs in (
        (1.1, 0.9, 0.09, 0.01),
        (2.1, 0.8, 0.18, 0.02),
        (3.1, 0.7, 0.27, 0.03),
        (4.1, 0.6, 0.36, 0.04),
    )
]

model.load_events(observation_events)

In [34]:
# Inspect the model's private list of static events after the observations were loaded.
model._static_events

[StartEvent(time=0.0, initial_state={'S': tensor(0.9900), 'I': tensor(0.0100), 'R': tensor(0.)}),
 ObservationEvent(time=1.100000023841858, observation={'S': tensor(0.9000), 'I': tensor(0.0900), 'R': tensor(0.0100)}),
 ObservationEvent(time=2.0999999046325684, observation={'S': tensor(0.8000), 'I': tensor(0.1800), 'R': tensor(0.0200)}),
 ObservationEvent(time=3.0999999046325684, observation={'S': tensor(0.7000), 'I': tensor(0.2700), 'R': tensor(0.0300)}),
 ObservationEvent(time=4.099999904632568, observation={'S': tensor(0.6000), 'I': tensor(0.3600), 'R': tensor(0.0400)})]

In [35]:
# Demonstrate that stochastic variational inference (SVI) runs on the model.

from pyro.infer.autoguide import AutoNormal
from pyro.poutine import block
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
import pyro

guide = AutoNormal(model)
optim = Adam({'lr': 0.03})
loss_f = Trace_ELBO(num_particles=1)
verbose = True

svi = SVI(model, guide, optim, loss=loss_f)

pyro.clear_param_store()

for j in range(100):
    # Compute the loss and take one gradient step.
    # svi.step() is called without arguments here: the observations were
    # loaded into the model as events rather than passed in as data.
    loss = svi.step()
    # Report progress on every 25th iteration when verbose output is enabled.
    if verbose and j % 25 == 0:
        print("[iteration %04d] loss: %.4f" % (j + 1, loss))

[iteration 0001] loss: 25.2661
[iteration 0026] loss: 21.5414
[iteration 0051] loss: 21.2547
[iteration 0076] loss: 20.8130


In [36]:
# Remove the observation events that were loaded for inference before setting up the next run.
model.remove_observation_events()

In [37]:
# Load two static parameter intervention events on "beta":
# value 0.0 at time 2.99 and value 10.0 at time 4.11.
model.load_event(StaticParameterInterventionEvent(2.99, "beta", 0.0))
model.load_event(StaticParameterInterventionEvent(4.11, "beta", 10.))

# Load the logging events (t = 1, ..., 9) again; they were removed before inference.
model.load_events(logging_events)

In [38]:
# Inspect the private static-event list again: the interventions and logging events
# are shown interleaved in time order.
model._static_events

[StartEvent(time=0.0, initial_state={'S': tensor(0.9900), 'I': tensor(0.0100), 'R': tensor(0.)}),
 LoggingEvent(time=1),
 LoggingEvent(time=2),
 StaticParameterInterventionEvent(time=2.990000009536743, parameter=beta, value=0.0),
 LoggingEvent(time=3),
 LoggingEvent(time=4),
 StaticParameterInterventionEvent(time=4.110000133514404, parameter=beta, value=10.0),
 LoggingEvent(time=5),
 LoggingEvent(time=6),
 LoggingEvent(time=7),
 LoggingEvent(time=8),
 LoggingEvent(time=9)]

In [39]:
# Inspect the private list of observed variable names; it is empty at this point.
model._observation_var_names

[]

In [40]:
# Use Pyro's Predictive utility together with the guide fitted above.
from pyro.infer import Predictive

# Draw 10 samples; the result is a dictionary of sampled values keyed by name.
predictions = Predictive(model, guide=guide, num_samples=10)()

In [41]:
# Display the sampled values.
predictions

{'a_gamma': tensor([0.9626, 0.9235, 0.9970, 0.9383, 0.9102, 0.9466, 1.0690, 0.9930, 0.9729,
         1.0212]),
 'I_sol': tensor([[6.1250e-01, 2.4397e-01, 9.3330e-02, 3.5642e-02, 1.3620e-02, 5.2049e-03,
          1.9889e-03, 7.6000e-04, 2.9041e-04],
         [6.2359e-01, 2.5746e-01, 1.0238e-01, 4.0657e-02, 1.6153e-02, 6.4175e-03,
          2.5496e-03, 1.0129e-03, 4.0240e-04],
         [6.0297e-01, 2.3273e-01, 8.6052e-02, 3.1751e-02, 1.1725e-02, 4.3298e-03,
          1.5989e-03, 5.9041e-04, 2.1801e-04],
         [6.1936e-01, 2.5226e-01, 9.8855e-02, 3.8681e-02, 1.5144e-02, 5.9286e-03,
          2.3209e-03, 9.0859e-04, 3.5569e-04],
         [6.2744e-01, 2.6225e-01, 1.0567e-01, 4.2527e-02, 1.7121e-02, 6.8931e-03,
          2.7751e-03, 1.1172e-03, 4.4976e-04],
         [6.1701e-01, 2.4939e-01, 9.6929e-02, 3.7614e-02, 1.4604e-02, 5.6705e-03,
          2.2016e-03, 8.5477e-04, 3.3187e-04],
         [5.8374e-01, 2.1110e-01, 7.2706e-02, 2.4964e-02, 8.5837e-03, 2.9516e-03,
          1.0149e-03, 3.

In [None]:
# SCAFFOLDING FOR DYNAMIC EVENT HANDLING BELOW

import pyro
import torch
from torch import Tensor
from torchdiffeq import odeint_event, odeint

# Type aliases used in the annotations of the scaffolding below.
# A time value is represented as a tensor.
Time = Tensor
# A state is a tuple of tensors of arbitrary length.
State = tuple[Tensor, ...]

class BaseODEModel(pyro.nn.PyroModule):
    """Scaffold for an ODE model that is integrated piecewise, from one static event to the next.

    ``deriv``, ``param_prior`` and ``initial_conditions_prior`` raise
    ``NotImplementedError`` here and are meant to be overridden. ``forward``
    calls ``param_prior`` and then delegates to ``solve``.
    """

    def __init__(self, static_events: list[Event]):
        """Store the events at which integration stops.

        Args:
            static_events: events visited by ``solve`` in list order. The list is
                expected to be sorted by time already; it is not sorted here.
        """
        super().__init__()
        # This is a list of events that are always sorted by time.
        # TODO: probably pre-sort this list just in case.
        self.static_events = static_events

    def deriv(self, t: Time, state: State) -> State:
        """Return the time derivative of ``state`` at time ``t``; must be overridden."""
        raise NotImplementedError

    @pyro.nn.pyro_method
    def param_prior(self):
        """Sample the model parameters; must be overridden."""
        raise NotImplementedError

    @pyro.nn.pyro_method
    def initial_conditions_prior(self):
        """Sample the initial conditions; must be overridden (not yet used by ``forward``)."""
        raise NotImplementedError

    # @pyro.nn.pyro_method
    # def observation_model(self, state: State, tspan: Time, ?) -> ?: …

    def solve(self, initial_state: State, initial_time: Time) -> tuple[State, State]:
        """Integrate ``deriv`` from the initial condition, stopping at every static event in turn.

        Each segment starts from the time and state at which the previous segment
        stopped; the first segment starts from ``initial_time`` / ``initial_state``.

        Args:
            initial_state: state at ``initial_time``, given as a tuple of tensors.
            initial_time: time at which integration starts. ``odeint_event``
                documents its start time as a scalar -- TODO confirm callers pass one.

        Returns:
            A pair ``(event_times, event_states)``. ``event_times`` holds one time
            tensor per static event, in the order visited. ``event_states`` holds one
            tensor per state variable, stacked along dim 0 over the events. Both are
            empty tuples when there are no static events.
            NOTE(review): this layout is a proposal for the scaffold -- confirm it
            against what the observation model will need.
        """
        current_state = initial_state
        current_time = initial_time

        event_times = []
        event_states = []

        for static_event in self.static_events:
            # Immediate goal: generate ODE solutions that stop at the event times.
            # Next step: add log likelihood at those events using an observation model.

            # TODO: confirm that every Event can be used directly as a torchdiffeq
            # event function, i.e. is callable as event(t, state) and returns a
            # tensor whose zero marks the event.
            event_time, event_solution = odeint_event(
                self.deriv, current_state, current_time, event_fn=static_event
            )

            # For a tuple-valued state the solution is a tuple of tensors whose
            # last entry along dim 0 is the state at the event time. The next
            # segment must start from there, not from the initial condition.
            current_state = tuple(component[-1] for component in event_solution)
            current_time = event_time

            # TODO: Add log likelihood for the state at the event.
            # This will be done by adding an observation model to the ODE model
            # and calling it with the event state as an argument.

            event_times.append(event_time)
            event_states.append(current_state)

        # Regroup from "one state tuple per event" to "one stacked tensor per variable".
        stacked_states = tuple(torch.stack(history) for history in zip(*event_states))
        return tuple(event_times), stacked_states

    def forward(self, initial_state: State, initial_time: Time) -> tuple[State, State]:
        """Sample the parameters, then solve the ODE across all static events.

        Args:
            initial_state: state at ``initial_time``.
            initial_time: time at which integration starts.

        Returns:
            Whatever ``solve`` returns for the given initial condition.
        """
        # Sample parameters from the prior. These parameters are generated as attributes of the model.
        self.param_prior()

        # TODO: Sample initial conditions from the prior instead of taking them as input.

        # Solve the ODE, taking into account any intervention and conditioning events.
        return self.solve(initial_state, initial_time)


