In [1]:
import os

from pyciemss.PetriNetODE.base import MiraPetriNetODESystem, ScaledBetaNoisePetriNetODESystem
from pyciemss.PetriNetODE.events import Event, StartEvent, LoggingEvent, ObservationEvent, StaticParameterInterventionEvent
import pyciemss

from pyciemss.PetriNetODE.interfaces import load_petri_model, setup_model, reset_model, intervene, sample, calibrate, optimize

In [7]:
MIRA_PATH = "test/models/evaluation_examples/scenario_1/"

# Load the scenario-1 SIR model from its MIRA JSON file.
# NOTE(review): MIRA_PATH is relative to the kernel's working directory (presumably the repo root).
filename = os.path.join(MIRA_PATH, "scenario1_sir_mira.json")
model = ScaledBetaNoisePetriNetODESystem.from_mira(filename)
model


ScaledBetaNoisePetriNetODESystem(
	beta = Uniform(low: 0.09000000357627869, high: 0.10999999940395355),
	gamma = Uniform(low: 0.18000000715255737, high: 0.2199999988079071),
	pseudocount = 1
)

In [8]:


# Initial state at t=0: 99% susceptible, 1% infected (values sum to 1, presumably population fractions).
start_event = StartEvent(0.0, {"susceptible_population": 0.99, "infected_population": 0.01, "immune_population": 0.0})
model.load_event(start_event)

# Log the state at integer times 1..9 (range excludes its stop value, so there are 9 logging events).
tspan = range(1, 10)
logging_events = [LoggingEvent(t) for t in tspan]
model.load_events(logging_events)

solution = model()

# The solution is a dict mapping each state variable to a tensor with one entry per
# logging event, i.e. len(tspan) == 9 entries each.
assert (
    len(solution["susceptible_population"])
    == len(solution["infected_population"])
    == len(solution["immune_population"])
    == len(tspan)
), "expected one solution entry per logging event"

solution

{'immune_population': tensor([0.0019, 0.0037, 0.0052, 0.0066, 0.0078, 0.0089, 0.0099, 0.0107, 0.0115]),
 'infected_population': tensor([0.0089, 0.0079, 0.0071, 0.0063, 0.0056, 0.0050, 0.0044, 0.0039, 0.0035]),
 'susceptible_population': tensor([0.9891, 0.9884, 0.9877, 0.9871, 0.9866, 0.9861, 0.9857, 0.9853, 0.9850])}

In [9]:
# Drop the logging events; this run is driven by observations instead.
model.remove_logging_events()

# Four hand-written (time, S, I, R) observations, just after integer times 1..4.
observation_events = [
    ObservationEvent(
        obs_time,
        {"susceptible_population": s, "infected_population": i, "immune_population": r},
    )
    for obs_time, s, i, r in (
        (1.1, 0.9, 0.09, 0.01),
        (2.1, 0.8, 0.18, 0.02),
        (3.1, 0.7, 0.27, 0.03),
        (4.1, 0.6, 0.36, 0.04),
    )
]

model.load_events(observation_events)
model

ScaledBetaNoisePetriNetODESystem(
	beta = Uniform(low: 0.09000000357627869, high: 0.10999999940395355),
	gamma = Uniform(low: 0.18000000715255737, high: 0.2199999988079071),
	pseudocount = 1
)

In [10]:
# Inspect the (private) event list: the StartEvent followed by the four ObservationEvents, in time order.
model._static_events

[StartEvent(time=0.0, initial_state={'susceptible_population': tensor(0.9900), 'infected_population': tensor(0.0100), 'immune_population': tensor(0.)}),
 ObservationEvent(time=1.100000023841858, observation={'susceptible_population': tensor(0.9000), 'infected_population': tensor(0.0900), 'immune_population': tensor(0.0100)}),
 ObservationEvent(time=2.0999999046325684, observation={'susceptible_population': tensor(0.8000), 'infected_population': tensor(0.1800), 'immune_population': tensor(0.0200)}),
 ObservationEvent(time=3.0999999046325684, observation={'susceptible_population': tensor(0.7000), 'infected_population': tensor(0.2700), 'immune_population': tensor(0.0300)}),
 ObservationEvent(time=4.099999904632568, observation={'susceptible_population': tensor(0.6000), 'infected_population': tensor(0.3600), 'immune_population': tensor(0.0400)})]

In [11]:
# Show that inference works.

from pyro.infer.autoguide import AutoNormal
from pyro.poutine import block
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
import pyro

# Seed Pyro's RNGs so the loss trace below is reproducible on Restart & Run All.
pyro.set_rng_seed(0)

guide = AutoNormal(model)
optim = Adam({'lr': 0.03})
loss_f = Trace_ELBO(num_particles=1)
verbose = True
num_steps = 100  # number of SVI gradient steps
log_every = 25  # print the loss every `log_every` steps when verbose

svi = SVI(model, guide, optim, loss=loss_f)

# Clear any previously stored Pyro parameters before training.
pyro.clear_param_store()

for j in range(num_steps):
    # Calculate the loss and take a gradient step.
    # svi.step() takes no data argument here: the observations are the
    # ObservationEvents loaded into `model` earlier in the notebook.
    loss = svi.step()
    if verbose and j % log_every == 0:
        print("[iteration %04d] loss: %.4f" % (j + 1, loss))

[iteration 0001] loss: 36.5507
[iteration 0026] loss: 35.0466
[iteration 0051] loss: 33.5636
[iteration 0076] loss: 30.1845


In [12]:
# Remove the observation events so that subsequent runs of the model are not
# driven by them (the cells below simulate with interventions and logging instead).
model.remove_observation_events()

In [13]:
# Load some static parameter intervention events:
# set beta to 0.0 at t=2.99, then to 10. at t=4.11.
model.load_event(StaticParameterInterventionEvent(2.99, "beta", 0.0))
model.load_event(StaticParameterInterventionEvent(4.11, "beta", 10.))

# Load the logging events again (the same LoggingEvent objects created for tspan earlier).
model.load_events(logging_events)

In [14]:
# Inspect the (private) event list: the interventions are interleaved with the logging events by time.
model._static_events

[StartEvent(time=0.0, initial_state={'susceptible_population': tensor(0.9900), 'infected_population': tensor(0.0100), 'immune_population': tensor(0.)}),
 LoggingEvent(time=1),
 LoggingEvent(time=2),
 StaticParameterInterventionEvent(time=2.990000009536743, parameter=beta, value=0.0),
 LoggingEvent(time=3),
 LoggingEvent(time=4),
 StaticParameterInterventionEvent(time=4.110000133514404, parameter=beta, value=10.0),
 LoggingEvent(time=5),
 LoggingEvent(time=6),
 LoggingEvent(time=7),
 LoggingEvent(time=8),
 LoggingEvent(time=9)]

In [15]:
# The model's parameters: a dict of mira.modeling.ModelParameter objects keyed by name.
model.G.parameters

{'beta': <mira.modeling.ModelParameter at 0x7fb9c540b010>,
 'gamma': <mira.modeling.ModelParameter at 0x7fb9c540b2e0>}

In [16]:
from pyciemss.PetriNetODE.base import get_name

# get_name recovers each parameter's name from its MIRA ModelParameter object;
# for this model it should agree with the keys of model.G.parameters.
parameter_names = [get_name(p) for p in model.G.parameters.values()]
assert parameter_names == list(model.G.parameters), "get_name disagrees with the parameter dict keys"
parameter_names

['beta', 'gamma']

In [17]:
# Forward run with the beta interventions and logging events loaded above
# (the observation events were removed earlier, so nothing is conditioned on here).
model()

{'immune_population': tensor([0.0020, 0.0038, 0.0054, 0.0068, 0.0800, 0.2577, 0.4017, 0.5177, 0.6112]),
 'infected_population': tensor([0.0089, 0.0078, 0.0069, 0.0056, 0.8868, 0.7422, 0.5983, 0.4823, 0.3888]),
 'susceptible_population': tensor([9.8911e-01, 9.8832e-01, 9.8763e-01, 9.8763e-01, 3.3194e-02, 8.7191e-06,
         1.1156e-08, 1.3891e-11, 8.2939e-11])}

In [18]:
# Use Pyro's Predictive to draw posterior predictive samples.
from pyro.infer import Predictive

# Predictive draws latent sites from the trained `guide`, re-runs `model` under them,
# and returns a dict of samples keyed by sample-site name; each value has a leading
# dimension of size num_samples (10 here).
predictions = Predictive(model, guide=guide, num_samples=10)()

In [19]:
# Posterior predictive samples, keyed by sample-site name (10 samples per site).
predictions

{'gamma': tensor([0.2131, 0.2103, 0.2031, 0.2098, 0.2073, 0.2040, 0.2022, 0.2031, 0.1951,
         0.1888]),
 'immune_population_sol': tensor([[0.1062, 0.2776, 0.4163, 0.5283, 0.6189, 0.6920, 0.7511, 0.7989, 0.8375],
         [0.1048, 0.2745, 0.4120, 0.5235, 0.6139, 0.6871, 0.7464, 0.7945, 0.8335],
         [0.1015, 0.2665, 0.4013, 0.5113, 0.6011, 0.6744, 0.7343, 0.7831, 0.8230],
         [0.1046, 0.2740, 0.4114, 0.5228, 0.6131, 0.6863, 0.7457, 0.7938, 0.8328],
         [0.1035, 0.2712, 0.4077, 0.5186, 0.6087, 0.6820, 0.7415, 0.7899, 0.8293],
         [0.1019, 0.2675, 0.4027, 0.5129, 0.6028, 0.6760, 0.7358, 0.7846, 0.8243],
         [0.1012, 0.2656, 0.4001, 0.5099, 0.5997, 0.6730, 0.7328, 0.7818, 0.8217],
         [0.1016, 0.2666, 0.4014, 0.5114, 0.6012, 0.6745, 0.7344, 0.7832, 0.8230],
         [0.0978, 0.2576, 0.3892, 0.4974, 0.5865, 0.6598, 0.7201, 0.7697, 0.8105],
         [0.0949, 0.2505, 0.3795, 0.4862, 0.5746, 0.6478, 0.7084, 0.7586, 0.8001]]),
 'infected_population_sol': tensor

In [None]:
# SCAFFOLDING FOR DYNAMIC EVENT HANDLING BELOW

import pyro
from torch import Tensor
from torchdiffeq import odeint_event, odeint

# Type aliases for the scaffold below.
State = tuple[Tensor, ...]  # a tuple of tensors — presumably one per state variable (TODO confirm)
Time = Tensor  # times are passed to the ODE solver as tensors

class BaseODEModel(pyro.nn.PyroModule):
    """Scaffolding for an ODE model whose solve is driven by a list of events.

    Subclasses provide the vector field (`deriv`) and the priors
    (`param_prior`, `initial_conditions_prior`). `solve` integrates the ODE
    from one static event to the next and records the state at each event.

    NOTE(review): work in progress — observation likelihoods and parameter
    interventions are not applied yet (see the TODOs in `solve`).
    """

    def __init__(self, static_events: list[Event]):
        super().__init__()
        # `solve` visits the events in list order, so this list must be sorted by time.
        # TODO: probably pre-sort this list just in case.
        self.static_events = static_events

    def deriv(self, t: Time, state: State) -> State:
        """Return d(state)/dt at time `t`. Subclasses must override this."""
        raise NotImplementedError

    @pyro.nn.pyro_method
    def param_prior(self):
        """Sample the model parameters from their prior. Subclasses must override this.

        `forward` calls this before solving and expects the sampled parameters
        to be stored as attributes of the model.
        """
        raise NotImplementedError

    @pyro.nn.pyro_method
    def initial_conditions_prior(self):
        """Sample the initial state from its prior. Subclasses must override this.

        Not called yet; see the TODO in `forward`.
        """
        raise NotImplementedError

    # @pyro.nn.pyro_method
    # def observation_model(self, state: State, tspan: Time, ?) -> ?: …

    def solve(self, initial_state: State, initial_time: Time) -> tuple[Time, State]:
        """Integrate the ODE from `initial_time` through each static event in order.

        Each segment starts from the state at the previous event (or from
        `initial_state` for the first one) and is integrated with
        `torchdiffeq.odeint_event` until that event fires.

        Args:
            initial_state: tuple of state tensors at `initial_time`.
            initial_time: start time; converted to a floating-point tensor.

        Returns:
            `(event_times, event_states)`: the times at which the events fired,
            stacked into one tensor, and a tuple with one tensor per state
            variable whose leading dimension indexes the events. Both are empty
            along that dimension when there are no static events.
        """
        import torch  # the cell only imports `Tensor` from torch; `torch` itself is needed here

        current_state = initial_state
        current_time = torch.as_tensor(initial_time)
        if not current_time.is_floating_point():
            # torchdiffeq requires a floating-point time tensor.
            current_time = current_time.to(torch.get_default_dtype())

        event_times = []
        event_states = []

        for static_event in self.static_events:
            # TODO: treat event kinds differently (e.g. a StaticParameterInterventionEvent should
            # change a parameter, an ObservationEvent should add a likelihood term).
            # NOTE(review): odeint_event integrates until event_fn(t, state) crosses zero; this
            # assumes Event.forward follows that (t, state) -> Tensor contract — confirm. An event
            # whose time equals current_time (e.g. a StartEvent at t0) may need special handling.
            event_time, event_solution = odeint_event(
                self.deriv, current_state, current_time, event_fn=static_event
            )

            # Each component of event_solution is indexed by time along dim 0 and its last
            # entry is the state at the event; the next segment restarts from there.
            current_state = tuple(component[-1] for component in event_solution)
            current_time = event_time

            # TODO: Add log likelihood for the state at this event.
            # Note: This will be done by adding an observation model to the ODE model and calling it with the event state as an argument.

            event_times.append(event_time)
            event_states.append(current_state)

        if not event_states:
            # torch.stack rejects an empty list, so build empty results explicitly.
            return current_time.new_empty((0,)), tuple(
                component.new_empty((0, *component.shape)) for component in initial_state
            )

        # Stack per state variable so each tensor gets a leading event dimension.
        return torch.stack(event_times), tuple(
            torch.stack(component) for component in zip(*event_states)
        )

    def forward(self, initial_state: State, initial_time: Time) -> tuple[Time, State]:
        """Sample parameters from the prior, then solve the ODE through all static events.

        Returns whatever `solve` returns: `(event_times, event_states)`.
        """
        # Sample parameters from the prior. These parameters are generated as attributes of the model.
        self.param_prior()

        # TODO: Sample initial conditions from the prior instead of taking them as input.

        # Solve the ODE, taking into account any intervention and conditioning events.
        return self.solve(initial_state, initial_time)


