# Introduction
------------

## Extremely Brief Probabilistic Programming Background
Probabilistic programming languages (PPLs) such as `Pyro`, `Turing.jl`, `Gen.jl`, and `Stan` (partially) automate the difficult tasks of probabilistic modeling and inference.

TL;DR: Bayes' theorem is very flexible and powerful, but hard to work with!

These technologies build on techniques from programming language theory, Bayesian statistics, and probabilistic machine learning to provide efficient, model-agnostic approximate solutions to hard probabilistic inference problems.
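For reference, this is Bayes' theorem for model parameters $\theta$ and observed data $y$. The evidence $p(y)$ in the denominator is an integral over every possible parameter value, which is rarely tractable for realistic models; approximating the posterior without computing that integral exactly is the kind of work PPLs automate.

$$
p(\theta \mid y) = \frac{p(y \mid \theta)\, p(\theta)}{p(y)},
\qquad
p(y) = \int p(y \mid \theta)\, p(\theta)\, d\theta
$$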

## General ASKEM Goal 1: 
Represent dynamical systems models with uncertainty as probabilistic programs in a probabilistic programming language.

## General ASKEM Goal 2:
Represent ASKEM questions in terms of primitive query operations that PPLs (partially) automate for any probabilistic program. 

In `Pyro` and `CausalPyro`, the probabilistic programming languages we're using and developing, these primitives include operations like `sample`, `condition`, and `intervene`.

## Ensemble Challenge Goal:
Demonstrate how probabilistic programming lets us think about probabilistic models the way we think about software.

First, we'll demonstrate how to build probabilistic ensemble models compositionally from probabilistic ODE models, with very few lines of code and without losing any probabilistic modeling functionality.

Second, we'll show how to use our ASKEM-specific abstractions for building ensemble models in a no-code workflow. Using these abstractions, we'll make example forecasts based on historical data.

-----------

# A Brief Tour of Ensemble Modeling in PyCIEMSS

```python
import pyro

# Excerpt: `...` elides the constructor and helper methods, and `Solution` is a type defined elsewhere in PyCIEMSS.
class DynamicalSystem(pyro.nn.PyroModule):

    ...

    def forward(self, *args, **kwargs) -> Solution:
        '''
        Joint distribution over model parameters, trajectories, and noisy observations.
        '''
        # Set up anything the dynamical system needs before solving.
        self.setup_before_solve()

        # Sample parameters from the prior.
        self.param_prior()

        # "Solve" the dynamical system.
        solution = self.get_solution(*args, **kwargs)

        # Add the observation likelihoods for probabilistic inference (if applicable).
        self.add_observation_likelihoods(solution)

        return self.log_solution(solution)
```

```python
import torchdiffeq

class ODESystem(DynamicalSystem):

    ...

    def param_prior(self):
        # Sample all of the parameters from the prior and store them as attributes of the ODESystem.
        ...

    def get_solution(self, *args, **kwargs):
        # Run an off-the-shelf ODE solver, using the parameters generated by the call to self.param_prior().
        # Make sure to evaluate at all points the user wants logged, and at all observations.
        ...
        local_solution = torchdiffeq.odeint(self.deriv, initial_state, local_tspan)
        ...
```
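As a rough illustration of how `param_prior` and `get_solution` split the work, here is a hypothetical toy SIR model in plain Python. It is not PyCIEMSS code: `ToySIR`, its parameter ranges, and the fixed-step forward-Euler loop (used here in place of `torchdiffeq.odeint`, whose default `dopri5` method is adaptive) are all assumptions made for the sketch.

```python
import random


class ToySIR:
    """Hypothetical stand-in for an ODESystem subclass (not PyCIEMSS code)."""

    def __init__(self, beta_range=(0.3, 0.4), gamma_range=(0.1, 0.15)):
        self.beta_range = beta_range
        self.gamma_range = gamma_range

    def param_prior(self):
        # Draw parameters from Uniform priors and store them as attributes.
        self.beta = random.uniform(*self.beta_range)
        self.gamma = random.uniform(*self.gamma_range)

    def deriv(self, t, state):
        # Same (t, state) argument order that torchdiffeq.odeint uses for its func.
        s, i, r = state
        new_infections = self.beta * s * i
        recoveries = self.gamma * i
        return (-new_infections, new_infections - recoveries, recoveries)

    def get_solution(self, initial_state, tspan, dt=0.01):
        # Forward Euler, recording the state at every requested time (tspan[0] is the start).
        t, state = tspan[0], tuple(initial_state)
        solution = [state]
        for t_next in tspan[1:]:
            while t < t_next - 1e-12:
                h = min(dt, t_next - t)
                state = tuple(x + h * dx for x, dx in zip(state, self.deriv(t, state)))
                t += h
            solution.append(state)
        return solution


random.seed(0)
sir = ToySIR()
sir.param_prior()
trajectory = sir.get_solution((0.99, 0.01, 0.0), [0.0, 1.0, 5.0, 10.0])
```

Because `param_prior` stores its draws as attributes, the `get_solution` call that follows uses one consistent parameter sample, which is what `DynamicalSystem.forward` relies on when it calls the two in sequence.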

```python
from typing import Callable, Dict, List

import pyro
import torch
from pyro.contrib.autoname import scope

class EnsembleSystem(DynamicalSystem):

    def __init__(self,
                 models: List[DynamicalSystem],
                 dirichlet_alpha: torch.Tensor,
                 solution_mappings: List[Callable[[Dict], Dict]]) -> None:
        super().__init__()
        self.models = models
        self.dirichlet_alpha = dirichlet_alpha
        self.solution_mappings = solution_mappings

    ...

    def param_prior(self):
        # The prior distribution over parameters in an ensemble is just the prior distribution over each constituent model's parameters.
        # Each model's sample sites are prefixed with `model_{i}/` so their names don't collide.
        for i, model in enumerate(self.models):
            with scope(prefix=f'model_{i}'):
                model.param_prior()

    def get_solution(self, *args, **kwargs):
        # Sample model weights from a Dirichlet distribution.
        model_weights = pyro.sample('model_weights', pyro.distributions.Dirichlet(self.dirichlet_alpha))

        # Solve each model in self.models, mapping its outputs to the shared state representation.
        solutions = [mapping(model.get_solution(*args, **kwargs)) for model, mapping in zip(self.models, self.solution_mappings)]

        # Combine the ensemble solutions, weighted by `model_weights`.
        solution = {k: sum(model_weights[i] * v[k] for i, v in enumerate(solutions)) for k in solutions[0].keys()}

        return solution
```
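The weighting step in `get_solution` can be reproduced in plain Python to see what it computes. This sketch assumes each model's output has already been passed through its solution mapping, and it substitutes a stdlib Dirichlet sampler (normalized independent Gamma draws) for `pyro.distributions.Dirichlet`; `sample_dirichlet`, `combine`, and the per-model values are hypothetical, made up for illustration.

```python
import random


def sample_dirichlet(alpha):
    # A Dirichlet draw is a vector of independent Gamma(alpha_i, 1) draws normalized to sum to 1.
    draws = [random.gammavariate(a, 1.0) for a in alpha]
    total = sum(draws)
    return [d / total for d in draws]


def combine(solutions, weights):
    # Weighted sum per shared state variable, mirroring the dict comprehension in EnsembleSystem.get_solution.
    return {k: sum(w * sol[k] for w, sol in zip(weights, solutions)) for k in solutions[0]}


random.seed(0)
# Hypothetical mapped outputs of three models at a single time point.
mapped_solutions = [{"Infected": 0.010}, {"Infected": 0.012}, {"Infected": 0.008}]
model_weights = sample_dirichlet([10 / 3, 10 / 3, 10 / 3])
combined = combine(mapped_solutions, model_weights)
```

Because the weights sum to one, each combined value is a convex combination of the models' mapped values, so it always lies between the smallest and largest model prediction. With tensors in place of floats, the same expression applies elementwise across time points.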


---

# A Brief Tour of ASKEM-specific interfaces

Here we show how PyCIEMSS can be used in a no-code (or low-code) modeling workflow.

```python
# First, let's load the dependencies and set up the notebook environment.
import os

from pyciemss.PetriNetODE.interfaces import load_petri_model
from pyciemss.Ensemble.interfaces import setup_model, reset_model, intervene, sample, calibrate, optimize
```

## Load Models

```python
MIRA_PATH = "test/models/april_ensemble_demo/"

filename1 = "BIOMD0000000955_template_model.json"
filename1 = os.path.join(MIRA_PATH, filename1)
model1 = load_petri_model(filename1, add_uncertainty=True)

# TODO: put this into the interfaces
start_state1 = {k[0]: v.data['initial_value'] for k, v in model1.G.variables.items()}

model1
```

```
ScaledBetaNoisePetriNetODESystem(
	beta = Uniform(low: 0.00989999994635582, high: 0.01209999993443489),
	gamma = Uniform(low: 0.41040000319480896, high: 0.5016000270843506),
	delta = Uniform(low: 0.00989999994635582, high: 0.01209999993443489),
	alpha = Uniform(low: 0.5130000114440918, high: 0.6269999742507935),
	epsilon = Uniform(low: 0.15389999747276306, high: 0.18809999525547028),
	zeta = Uniform(low: 0.11249999701976776, high: 0.13750000298023224),
	XXlambdaXX = Uniform(low: 0.03060000017285347, high: 0.03739999979734421),
	eta = Uniform(low: 0.11249999701976776, high: 0.13750000298023224),
	rho = Uniform(low: 0.03060000017285347, high: 0.03739999979734421),
	theta = Uniform(low: 0.33390000462532043, high: 0.4081000089645386),
	kappa = Uniform(low: 0.015300000086426735, high: 0.018699999898672104),
	mu = Uniform(low: 0.015300000086426735, high: 0.018699999898672104),
	nu = Uniform(low: 0.024299999698996544, high: 0.02969999983906746),
	xi = Uniform(low: 0.015300000086426735, high: 
```

```python
filename2 = "BIOMD0000000960_template_model.json"
filename2 = os.path.join(MIRA_PATH, filename2)
model2 = load_petri_model(filename2, add_uncertainty=True)

# TODO: put this into the interfaces
start_state2 = {k[0]: v.data['initial_value'] for k, v in model2.G.variables.items()}

model2
```

```
ScaledBetaNoisePetriNetODESystem(
	mira_param_0 = Uniform(low: 0.0, high: 0.10000000149011612),
	mira_param_1 = Uniform(low: 0.0, high: 0.10000000149011612),
	mira_param_2 = Uniform(low: 0.0, high: 0.10000000149011612),
	mira_param_3 = Uniform(low: 0.02098800055682659, high: 0.025652000680565834),
	mira_param_4 = Uniform(low: 0.37501201033592224, high: 0.45834800601005554),
	mira_param_5 = Uniform(low: 0.4526999890804291, high: 0.5533000230789185),
	mira_param_6 = Uniform(low: 0.23669999837875366, high: 0.28929999470710754),
	mira_param_7 = Uniform(low: 0.0027000000700354576, high: 0.0032999999821186066),
	mira_param_8 = Uniform(low: 1.4759999513626099, high: 1.8040000200271606),
	mira_param_9 = Uniform(low: 0.0, high: 0.10000000149011612),
	mira_param_10 = Uniform(low: 0.007199999876320362, high: 0.008799999952316284),
	mira_param_11 = Uniform(low: 0.12690000236034393, high: 0.1551000028848648),
	pseudocount = 1.0
)
```

```python
filename3 = "BIOMD0000000983_template_model.json"
filename3 = os.path.join(MIRA_PATH, filename3)
model3 = load_petri_model(filename3, add_uncertainty=True)

# TODO: put this into the interfaces
start_state3 = {k[0]: v.data['initial_value'] for k, v in model3.G.variables.items()}
start_state3['Deceased'] = 0.0

model3
```

```
ScaledBetaNoisePetriNetODESystem(
	mira_param_0 = Uniform(low: 7.614000097078133e-09, high: 9.306000414710525e-09),
	mira_param_1 = Uniform(low: 1.5228000194156266e-08, high: 1.861200082942105e-08),
	mira_param_2 = Uniform(low: 1.9035000242695332e-09, high: 2.326500103677631e-09),
	mira_param_3 = Uniform(low: 3.8070000485390665e-09, high: 4.653000207355262e-09),
	mira_param_4 = Uniform(low: 0.019285714253783226, high: 0.023571427911520004),
	mira_param_5 = Uniform(low: 0.035999998450279236, high: 0.04399999976158142),
	mira_param_6 = Uniform(low: 0.14399999380111694, high: 0.17599999904632568),
	mira_param_7 = Uniform(low: 0.09000000357627869, high: 0.10999999940395355),
	mira_param_8 = Uniform(low: 0.22499999403953552, high: 0.2750000059604645),
	mira_param_9 = Uniform(low: 0.044999998062849045, high: 0.054999999701976776),
	mira_param_10 = Uniform(low: 0.015300000086426735, high: 0.018699999898672104),
	pseudocount = 1.0
)
```

```python
from math import isclose

# Ratios between each model's initial infected population and model 1's.
solution_ratio21 = start_state2['Infectious'] / start_state1['Infected']
solution_ratio31 = (start_state3['Infected_reported'] + start_state3['Infected_unreported']) / start_state1['Infected']

# Map each model's outputs onto a shared "Infected" variable, rescaled to match model 1's initial value.
solution_mapping1 = lambda x: {"Infected": x["Infected"]}
solution_mapping2 = lambda x: {"Infected": x["Infectious"] / solution_ratio21}
solution_mapping3 = lambda x: {"Infected": (x["Infected_reported"] + x["Infected_unreported"]) / solution_ratio31}

# Check that every solution mapping produces the same set of output variables.
assert (set(solution_mapping1(start_state1).keys())
        == set(solution_mapping2(start_state2).keys())
        == set(solution_mapping3(start_state3).keys()))

# Check that the mapped initial states agree across all three models.
assert isclose(solution_mapping1(start_state1)["Infected"], solution_mapping2(start_state2)["Infected"])
assert isclose(solution_mapping1(start_state1)["Infected"], solution_mapping3(start_state3)["Infected"])
```
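The rescaling in the solution mappings can be checked on a smaller example. The sketch below uses two hypothetical models with made-up initial counts (not the BioModels values): dividing by the ratio of initial infected counts makes both mapped trajectories start at the same value, and because the divisor is a constant, it rescales every later time point by the same factor.

```python
from math import isclose

# Hypothetical initial states for two models with different infection compartments.
start_a = {"Infected": 200.0}
start_b = {"Infected_reported": 30.0, "Infected_unreported": 20.0}

# Ratio of model b's initial infected count to model a's (here 50 / 200 = 0.25).
ratio_ba = (start_b["Infected_reported"] + start_b["Infected_unreported"]) / start_a["Infected"]

map_a = lambda x: {"Infected": x["Infected"]}
map_b = lambda x: {"Infected": (x["Infected_reported"] + x["Infected_unreported"]) / ratio_ba}

print(map_a(start_a), map_b(start_b))  # {'Infected': 200.0} {'Infected': 200.0}
```

Note that this only forces agreement at the start time; after that, each model's mapped trajectory follows its own dynamics.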

```python
# Set up the ensemble.

models = [model1, model2, model3]
weights = [1/3, 1/3, 1/3]
start_time = 0.0

start_states = [start_state1, start_state2, start_state3]
solution_mappings = [solution_mapping1, solution_mapping2, solution_mapping3]

total_population = 1.0
dirichlet_concentration = 10.0

ensemble = setup_model(models,
                       weights,
                       solution_mappings,
                       start_time,
                       start_states,
                       total_population,
                       dirichlet_concentration=dirichlet_concentration)
ensemble
```

```
Ensemble of 3 models. 

 	Dirichlet Alpha: tensor([3.3333, 3.3333, 3.3333]). 

 	Models: [ScaledBetaNoisePetriNetODESystem(
	beta = Uniform(low: 0.00989999994635582, high: 0.01209999993443489),
	gamma = Uniform(low: 0.41040000319480896, high: 0.5016000270843506),
	delta = Uniform(low: 0.00989999994635582, high: 0.01209999993443489),
	alpha = Uniform(low: 0.5130000114440918, high: 0.6269999742507935),
	epsilon = Uniform(low: 0.15389999747276306, high: 0.18809999525547028),
	zeta = Uniform(low: 0.11249999701976776, high: 0.13750000298023224),
	XXlambdaXX = Uniform(low: 0.03060000017285347, high: 0.03739999979734421),
	eta = Uniform(low: 0.11249999701976776, high: 0.13750000298023224),
	rho = Uniform(low: 0.03060000017285347, high: 0.03739999979734421),
	theta = Uniform(low: 0.33390000462532043, high: 0.4081000089645386),
	kappa = Uniform(low: 0.015300000086426735, high: 0.018699999898672104),
	mu = Uniform(low: 0.015300000086426735, high: 0.018699999898672104),
	nu = Uniform(low: 0.02429
```
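The printed `Dirichlet Alpha` appears consistent with multiplying `weights` by `dirichlet_concentration` (10 * 1/3 ≈ 3.3333); that reading is inferred from the output above, not from the `setup_model` source. Under that assumption, the standard Dirichlet moments show what the concentration controls: the mean weights stay equal to `weights`, while a larger concentration shrinks the spread around them. `dirichlet_summary` is a hypothetical helper written for this sketch.

```python
def dirichlet_summary(weights, concentration):
    # Assumed relationship (inferred from the printed output): alpha_i = concentration * weight_i.
    alpha = [concentration * w for w in weights]
    alpha0 = sum(alpha)
    # Standard Dirichlet moments: E[w_i] = alpha_i / alpha0 and
    # Var[w_i] = alpha_i * (alpha0 - alpha_i) / (alpha0**2 * (alpha0 + 1)).
    mean = [a / alpha0 for a in alpha]
    var = [a * (alpha0 - a) / (alpha0 ** 2 * (alpha0 + 1)) for a in alpha]
    return alpha, mean, var


alpha, mean, var = dirichlet_summary([1 / 3, 1 / 3, 1 / 3], 10.0)
print([round(a, 4) for a in alpha])  # [3.3333, 3.3333, 3.3333]
```

With a concentration of 10, each weight has variance 2/99 (a standard deviation of roughly 0.14), so the prior still allows the ensemble weights to move substantially away from 1/3 during calibration.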

```python
# Sample from the ensemble.

timepoints = [1.0, 5.0, 10.0]
num_samples = 8
ensemble_solution = sample(ensemble, timepoints, num_samples)
ensemble_solution
```

```
{'model_0/beta': tensor([0.0107, 0.0104, 0.0101, 0.0102, 0.0111, 0.0113, 0.0117, 0.0118]),
 'model_0/gamma': tensor([0.4954, 0.4395, 0.4965, 0.4339, 0.4438, 0.4888, 0.4860, 0.4574]),
 'model_0/delta': tensor([0.0102, 0.0116, 0.0109, 0.0108, 0.0113, 0.0120, 0.0120, 0.0116]),
 'model_0/alpha': tensor([0.5835, 0.5886, 0.5627, 0.6017, 0.5670, 0.6064, 0.5804, 0.5496]),
 'model_0/epsilon': tensor([0.1765, 0.1819, 0.1636, 0.1759, 0.1606, 0.1814, 0.1845, 0.1847]),
 'model_0/zeta': tensor([0.1246, 0.1287, 0.1300, 0.1148, 0.1285, 0.1145, 0.1261, 0.1281]),
 'model_0/XXlambdaXX': tensor([0.0315, 0.0317, 0.0346, 0.0355, 0.0336, 0.0351, 0.0323, 0.0352]),
 'model_0/eta': tensor([0.1280, 0.1305, 0.1225, 0.1262, 0.1324, 0.1174, 0.1344, 0.1350]),
 'model_0/rho': tensor([0.0330, 0.0322, 0.0307, 0.0307, 0.0364, 0.0345, 0.0311, 0.0358]),
 'model_0/theta': tensor([0.3977, 0.3801, 0.3349, 0.3627, 0.4000, 0.3397, 0.3913, 0.3896]),
 'model_0/kappa': tensor([0.0161, 0.0183, 0.0184, 0.0169, 0.0165, 0.0160, 0.018
```
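The sample keys carry the `model_{i}/` prefixes added by `scope` in `EnsembleSystem.param_prior`, so it can be handy to regroup them per model. The helper below is a hypothetical convenience, not part of PyCIEMSS; it assumes `/` is the scope separator (as the printed keys suggest) and uses plain lists, copied from the first two draws above, in place of tensors.

```python
from collections import defaultdict


def group_by_model(samples):
    """Hypothetical helper: split 'prefix/site' keys into one dict per prefix."""
    grouped = defaultdict(dict)
    for key, value in samples.items():
        prefix, sep, site = key.partition("/")
        if sep:
            grouped[prefix][site] = value
        else:
            # Sites sampled outside any scope (e.g. 'model_weights') go under an empty prefix.
            grouped[""][key] = value
    return dict(grouped)


# First two draws of two model_0 parameters, copied from the output above.
samples = {
    "model_0/beta": [0.0107, 0.0104],
    "model_0/gamma": [0.4954, 0.4395],
}
per_model = group_by_model(samples)
print(sorted(per_model["model_0"]))  # ['beta', 'gamma']
```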

```python
data = [(1.1, {"Infected": 0.003}), (1.2, {"Infected": 0.005})]

inferred_parameters = calibrate(ensemble, data, num_iterations=100, verbose=True)
```

```
iteration 0: loss = 121.81931269168854
iteration 25: loss = 83.86965811252594
iteration 50: loss = 56.698450922966
iteration 75: loss = 45.0703843832016
```
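The `data` argument above is a list of `(time, {variable: value})` pairs. Since the ensemble's solution only contains the variables produced by the solution mappings (here just `"Infected"`), a quick pre-flight check like the hypothetical one below can catch observations of variables the ensemble never outputs before running `calibrate`. `check_observations` is invented for this sketch, and its requirement that times be non-decreasing is an assumption, not a documented constraint of `calibrate`.

```python
def check_observations(data, observed_vars):
    """Hypothetical validation for data shaped like [(time, {variable: value}), ...]."""
    times = [t for t, _ in data]
    if times != sorted(times):
        raise ValueError("observation times should be non-decreasing")
    for t, obs in data:
        unknown = set(obs) - set(observed_vars)
        if unknown:
            raise ValueError(f"time {t}: variables {sorted(unknown)} are not ensemble outputs")
    return True


data = [(1.1, {"Infected": 0.003}), (1.2, {"Infected": 0.005})]
print(check_observations(data, {"Infected"}))  # True
```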


```python
forecasts = sample(ensemble, timepoints, num_samples, inferred_parameters)
forecasts
```

```
{'model_0/beta': tensor([0.0116, 0.0104, 0.0109, 0.0109, 0.0106, 0.0111, 0.0118, 0.0114]),
 'model_0/gamma': tensor([0.4782, 0.4203, 0.4472, 0.4657, 0.4172, 0.4556, 0.4303, 0.4709]),
 'model_0/delta': tensor([0.0115, 0.0113, 0.0116, 0.0113, 0.0105, 0.0118, 0.0114, 0.0108]),
 'model_0/alpha': tensor([0.5262, 0.5824, 0.5639, 0.5621, 0.5608, 0.5868, 0.5840, 0.5672]),
 'model_0/epsilon': tensor([0.1820, 0.1556, 0.1728, 0.1761, 0.1654, 0.1749, 0.1602, 0.1680]),
 'model_0/zeta': tensor([0.1243, 0.1143, 0.1306, 0.1294, 0.1241, 0.1320, 0.1191, 0.1159]),
 'model_0/XXlambdaXX': tensor([0.0350, 0.0328, 0.0348, 0.0366, 0.0329, 0.0368, 0.0357, 0.0331]),
 'model_0/eta': tensor([0.1330, 0.1187, 0.1320, 0.1229, 0.1165, 0.1274, 0.1274, 0.1272]),
 'model_0/rho': tensor([0.0358, 0.0318, 0.0346, 0.0334, 0.0325, 0.0322, 0.0332, 0.0347]),
 'model_0/theta': tensor([0.3914, 0.3517, 0.3714, 0.3847, 0.3689, 0.3951, 0.3489, 0.3693]),
 'model_0/kappa': tensor([0.0159, 0.0171, 0.0172, 0.0168, 0.0172, 0.0187, 0.016
```