# A tour of PyCIEMSS interfaces and functionality

### Load dependencies and interfaces

In [1]:
import os
import pyciemss
import torch

### Select models and data

In [2]:
MODEL_PATH = "https://raw.githubusercontent.com/DARPA-ASKEM/simulation-integration/main/data/models/"
DATA_PATH = "https://raw.githubusercontent.com/DARPA-ASKEM/simulation-integration/main/data/datasets/"

model1 = os.path.join(MODEL_PATH, "SEIRHD_NPI_Type1_petrinet.json")
model2 = os.path.join(MODEL_PATH, "SEIRHD_NPI_Type2_petrinet.json")
model3 = os.path.join(MODEL_PATH, "SIR_stockflow.json")

dataset1 = os.path.join(DATA_PATH, "traditional.csv")

### Set parameters for sampling

In [3]:
start_time = 0.0
end_time = 100.
logging_step_size = 10.0
num_samples = 3

## Sample interface
Take `num_samples` number of samples from the (prior) distribution invoked by the chosen model.

### Sample from model 1

In [4]:
result1 = pyciemss.sample(model1, end_time, logging_step_size, num_samples, start_time=start_time)
result1["unprocessed_result"]

{'persistent_beta_c': tensor([0.7177, 0.3379, 0.1606]),
 'persistent_kappa': tensor([0.5890, 0.3083, 0.4639]),
 'persistent_gamma': tensor([0.3956, 0.2760, 0.2185]),
 'persistent_hosp': tensor([0.0828, 0.1585, 0.0967]),
 'persistent_death_hosp': tensor([0.0229, 0.0875, 0.0369]),
 'persistent_I0': tensor([11.4742,  5.2734, 10.8078]),
 'D_state': tensor([[9.2631e-02, 3.6183e-01, 9.0445e-01, 1.9749e+00, 4.0833e+00, 8.2349e+00,
          1.6407e+01, 3.2483e+01, 6.4063e+01],
         [3.8682e-01, 1.2385e+00, 2.2912e+00, 3.5222e+00, 4.9526e+00, 6.6135e+00,
          8.5418e+00, 1.0780e+01, 1.3379e+01],
         [1.2958e-01, 6.5710e-01, 2.2127e+00, 6.7344e+00, 1.9866e+01, 5.7969e+01,
          1.6827e+02, 4.8541e+02, 1.3797e+03]]),
 'E_state': tensor([[6.3783e+01, 1.2562e+02, 2.4743e+02, 4.8730e+02, 9.5955e+02, 1.8889e+03,
          3.7158e+03, 7.3006e+03, 1.3206e+04],
         [2.7773e+01, 3.2162e+01, 3.7340e+01, 4.3352e+01, 5.0332e+01, 5.8435e+01,
          6.7842e+01, 7.8762e+01, 7.7455e+0

In [5]:
result1['data'].head()

Unnamed: 0,timepoint_id,sample_id,persistent_beta_c_param,persistent_kappa_param,persistent_gamma_param,persistent_hosp_param,persistent_death_hosp_param,persistent_I0_param,D_state_state,E_state_state,H_state_state,I_state_state,R_state_state,S_state_state,infected_observable_state,exposed_observable_state,hospitalized_observable_state,dead_observable_state
0,0,0,0.717676,0.588992,0.395557,0.082779,0.022923,11.474162,0.092631,63.783447,3.867057,34.406448,91.571045,19339846.0,34.406448,63.783447,3.867057,0.092631
1,1,0,0.717676,0.588992,0.395557,0.082779,0.022923,11.474162,0.361827,125.622398,8.241694,67.780319,281.638306,19339580.0,67.780319,125.622398,8.241694,0.361827
2,2,0,0.717676,0.588992,0.395557,0.082779,0.022923,11.474162,0.904454,247.428192,16.318035,133.50264,656.542053,19339008.0,133.50264,247.428192,16.318035,0.904454
3,3,0,0.717676,0.588992,0.395557,0.082779,0.022923,11.474162,1.974892,487.297546,32.150703,262.932709,1395.010498,19337860.0,262.932709,487.297546,32.150703,1.974892
4,4,0,0.717676,0.588992,0.395557,0.082779,0.022923,11.474162,4.083286,959.550537,63.316792,517.769897,2849.331543,19335644.0,517.769897,959.550537,63.316792,4.083286


### Sample from model 2

In [6]:
result2 = pyciemss.sample(model2, end_time, logging_step_size, num_samples, start_time=start_time)
result2['data'].head()

Unnamed: 0,timepoint_id,sample_id,persistent_beta_c_param,persistent_beta_nc_param,persistent_kappa_param,persistent_gamma_param,persistent_hosp_param,persistent_death_hosp_param,persistent_I0_param,D_state_state,E_state_state,H_state_state,I_state_state,R_state_state,S_state_state,infected_observable_state,exposed_observable_state,hospitalized_observable_state,dead_observable_state
0,0,0,0.754962,0.478901,0.545487,0.282183,0.119526,0.067823,5.930447,0.295847,71.128304,4.766628,46.065563,71.311386,19339848.0,46.065563,71.128304,4.766628,0.295847
1,1,0,0.754962,0.478901,0.545487,0.282183,0.119526,0.067823,5.930447,1.493463,200.243073,14.373006,129.784225,288.610687,19339408.0,129.784225,200.243073,14.373006,1.493463
2,2,0,0.754962,0.478901,0.545487,0.282183,0.119526,0.067823,5.930447,4.922004,563.887207,40.605663,365.484894,901.352783,19338150.0,365.484894,563.887207,40.605663,4.922004
3,3,0,0.754962,0.478901,0.545487,0.282183,0.119526,0.067823,5.930447,14.584277,1587.552002,114.354279,1029.048828,2626.837891,19334646.0,1029.048828,1587.552002,114.354279,14.584277
4,4,0,0.754962,0.478901,0.545487,0.282183,0.119526,0.067823,5.930447,41.786915,4466.624023,321.875366,2895.887695,7483.921387,19324826.0,2895.887695,4466.624023,321.875366,41.786915


## Ensemble Sample Interface
Sample from an ensemble of model 1 and model 2 

In [7]:
model_paths = [model1, model2]
solution_mappings = [lambda x : x, lambda x : x] # Conveniently, these two models operate on exactly the same state space, with the same names.

ensemble_result = pyciemss.ensemble_sample(model_paths, solution_mappings, end_time, logging_step_size, num_samples, start_time=start_time)
ensemble_result['data'].head()

Unnamed: 0,timepoint_id,sample_id,model_0/persistent_beta_c_param,model_0/persistent_kappa_param,model_0/persistent_gamma_param,model_0/persistent_hosp_param,model_0/persistent_death_hosp_param,model_0/persistent_I0_param,model_1/persistent_beta_c_param,model_1/persistent_beta_nc_param,...,model_0/H_state_state,model_0/I_state_state,model_0/R_state_state,model_0/S_state_state,model_1/D_state_state,model_1/E_state_state,model_1/H_state_state,model_1/I_state_state,model_1/R_state_state,model_1/S_state_state
0,0,0,0.512159,0.770848,0.191189,0.148833,0.085974,5.542153,0.205691,0.394095,...,8.5108,128.430603,87.901299,19339592.0,0.1975,50.069599,5.234837,30.438091,60.586609,19339896.0
1,1,0,0.512159,0.770848,0.191189,0.148833,0.085974,5.542153,0.205691,0.394095,...,78.131119,1151.855469,904.834595,19336008.0,0.856885,105.135063,12.102801,63.976738,205.066162,19339656.0
2,2,0,0.512159,0.770848,0.191189,0.148833,0.085974,5.542153,0.205691,0.394095,...,700.40686,10317.504883,8228.245117,19303866.0,2.279881,220.848709,25.574177,134.392319,509.507111,19339124.0
3,3,0,0.512159,0.770848,0.191189,0.148833,0.085974,5.542153,0.205691,0.394095,...,6233.259277,91446.203125,73491.335938,19019282.0,5.274168,463.884094,53.74041,282.291962,1149.133179,19338090.0
4,4,0,0.512159,0.770848,0.191189,0.148833,0.085974,5.542153,0.205691,0.394095,...,52261.085938,740655.875,627085.3125,16771191.0,11.564219,974.216614,112.8759,592.877075,2492.606689,19335858.0


## Calibrate interface
Calibrate a model to a dataset by mapping model state varibale or observables to columns in the dataset

In [8]:
data_mapping = {"Infected": "I"} # data_mapping = "column_name": "observable/state_variable"
num_iterations = 10
calibrated_results = pyciemss.calibrate(model1, dataset1, data_mapping=data_mapping, num_iterations=num_iterations)
parameter_estimates = calibrated_results["inferred_parameters"]
calibrated_results

{'inferred_parameters': AutoGuideList(
   (0): AutoDelta()
   (1): AutoLowRankMultivariateNormal()
 ),
 'loss': 244.39011216163635}

In [9]:
parameter_estimates()

{'persistent_beta_c': tensor(0.4845, grad_fn=<ExpandBackward0>),
 'persistent_kappa': tensor(0.3275, grad_fn=<ExpandBackward0>),
 'persistent_gamma': tensor(0.3196, grad_fn=<ExpandBackward0>),
 'persistent_hosp': tensor(0.0939, grad_fn=<ExpandBackward0>),
 'persistent_death_hosp': tensor(0.0549, grad_fn=<ExpandBackward0>),
 'persistent_I0': tensor(5.9766, grad_fn=<ExpandBackward0>)}

## Pass the parameter estimates to `sample` to sample from the calibrated model

In [10]:
calibrated_sample_results = pyciemss.sample(model1, end_time, logging_step_size, num_samples, 
                start_time=start_time, inferred_parameters=parameter_estimates)
calibrated_sample_results

{'data':     timepoint_id  sample_id  persistent_beta_c_param  persistent_kappa_param  \
 0              0          0                 0.467816                0.335138   
 1              1          0                 0.467816                0.335138   
 2              2          0                 0.467816                0.335138   
 3              3          0                 0.467816                0.335138   
 4              4          0                 0.467816                0.335138   
 5              5          0                 0.467816                0.335138   
 6              6          0                 0.467816                0.335138   
 7              7          0                 0.467816                0.335138   
 8              8          0                 0.467816                0.335138   
 9              0          1                 0.485538                0.309315   
 10             1          1                 0.485538                0.309315   
 11             2   

In [11]:
# TODO:
# - Add intervention example
# - Add examples for calibrate_ensemble and optimize interfaces as they become available
# - Plot results

## Sample interface with intervention

In [12]:
start_time = 0.0
end_time = 40.
logging_step_size = 1.0
num_samples = 5
result = pyciemss.sample(model3, end_time, logging_step_size, num_samples, start_time=start_time, 
                         static_parameter_interventions={torch.tensor(1.): {"p_cbeta": torch.tensor(0.35)}}, solver_method="euler")
result["data"]

Unnamed: 0,timepoint_id,sample_id,persistent_p_cbeta_param,persistent_p_tr_param,I_state_state,R_state_state,S_state_state
0,0,0,0.326866,19.755219,1.275920,0.050620,999.673462
1,1,0,0.326866,19.755219,1.657314,0.115206,999.227478
2,2,0,0.326866,19.755219,2.152454,0.199098,998.648438
3,3,0,0.326866,19.755219,2.795087,0.308055,997.896851
4,4,0,0.326866,19.755219,3.628849,0.449541,996.921631
...,...,...,...,...,...,...,...
190,34,4,0.345899,16.822075,530.974487,382.251221,87.774292
191,35,4,0.345899,16.822075,515.706116,413.815369,71.478516
192,36,4,0.345899,16.822075,497.938385,444.471863,58.589737
193,37,4,0.345899,16.822075,478.538818,474.072144,48.389011


### Optimize interface
Get infections below 300 individuals at 100 days for SIR model with minimum change to current value for intervention parameter

In [13]:
import numpy as np
from typing import Dict, List

def obs_nday_average_qoi(
    samples: Dict[str, torch.Tensor], contexts: List, ndays: int = 7
) -> np.ndarray:
    """
    Return estimate of last n-day average of each sample.
    samples is is the output from a Pyro Predictive object.
    samples[VARIABLE] is expected to have dimension (nreplicates, ntimepoints)
    Note: last ndays timepoints is assumed to represent last n-days of simulation.
    """
    dataQoI = samples[contexts[0]].detach().numpy()

    return np.mean(dataQoI[:, -ndays:], axis=1)

start_time = 0.0
end_time = 40.
logging_step_size = 1.0
observed_params = ["I_state"]
intervention_time = torch.tensor(1.)
intervened_params = "p_cbeta"
p_cbeta_current = 0.35
initial_guess_interventions = 0.15
bounds_interventions = [[0.1], [0.5]]

risk_bound = 300.
qoi = lambda x: obs_nday_average_qoi(x, observed_params, 1)
objfun = lambda x: np.abs(p_cbeta_current-x)
static_parameter_interventions = {intervention_time: intervened_params}

opt_result = pyciemss.optimize(model3, end_time, logging_step_size, qoi, risk_bound, static_parameter_interventions, objfun, 
                               initial_guess_interventions=initial_guess_interventions, bounds_interventions=bounds_interventions, 
                               start_time=0.0, n_samples_ouu=int(1e2), maxiter=1, maxfeval=20, solver_method="euler")
print(f'Optimal policy for intervening on {static_parameter_interventions[list(static_parameter_interventions.keys())[0]]} is ', opt_result["policy"])

 40%|████      | 16/40 [01:45<02:37,  6.57s/it]

Optimal policy for intervening on p_cbeta is  [0.2237]





#### Sample using optimal policy as intervention

In [14]:
num_samples = 100
result = pyciemss.sample(model3, end_time, logging_step_size, num_samples, start_time=start_time, 
                         static_parameter_interventions={intervention_time: {intervened_params: opt_result["policy"]}}, 
                         solver_method="euler")
result["data"]

Unnamed: 0,timepoint_id,sample_id,persistent_p_cbeta_param,persistent_p_tr_param,I_state_state,R_state_state,S_state_state
0,0,0,0.349256,12.813910,1.270867,0.078040,999.651123
1,1,0,0.349256,12.813910,1.455598,0.177219,999.367188
2,2,0,0.349256,12.813910,1.667089,0.290814,999.042114
3,3,0,0.349256,12.813910,1.909187,0.420914,998.669922
4,4,0,0.349256,12.813910,2.186285,0.569907,998.243835
...,...,...,...,...,...,...,...
3895,34,99,0.313551,17.372927,169.823914,69.390106,761.786011
3896,35,99,0.313551,17.372927,188.959747,79.165306,732.874939
3897,36,99,0.313551,17.372927,209.030960,90.041985,701.927063
3898,37,99,0.313551,17.372927,229.788452,102.073982,669.137573
