# A tour of PyCIEMSS interfaces and functionality

### Load dependencies and interfaces

In [2]:
import os
import pyciemss
import torch
smoke_test = ('CI' in os.environ)

### Select models and data

In [3]:
MODEL_PATH = "https://raw.githubusercontent.com/DARPA-ASKEM/simulation-integration/main/data/models/"
DATA_PATH = "https://raw.githubusercontent.com/DARPA-ASKEM/simulation-integration/main/data/datasets/"

model1 = os.path.join(MODEL_PATH, "SEIRHD_NPI_Type1_petrinet.json")
model2 = os.path.join(MODEL_PATH, "SEIRHD_NPI_Type2_petrinet.json")
model3 = os.path.join(MODEL_PATH, "SIR_stockflow.json")

dataset1 = os.path.join(DATA_PATH, "traditional.csv")

### Set parameters for sampling

In [4]:
start_time = 0.0
end_time = 100.
logging_step_size = 10.0
num_samples = 3 if smoke_test else 1000

## Sample interface
Take `num_samples` number of samples from the (prior) distribution invoked by the chosen model.

### Sample from model 1

In [5]:
result1 = pyciemss.sample(model1, end_time, logging_step_size, num_samples, start_time=start_time)
result1["unprocessed_result"]

{'persistent_beta_c': tensor([0.5248, 0.7828, 0.7588]),
 'persistent_kappa': tensor([0.5060, 0.5787, 0.1391]),
 'persistent_gamma': tensor([0.4790, 0.1410, 0.2166]),
 'persistent_hosp': tensor([0.1439, 0.0362, 0.0762]),
 'persistent_death_hosp': tensor([0.0729, 0.0901, 0.0796]),
 'persistent_I0': tensor([ 6.1897, 12.8359, 12.1238]),
 'D_state': tensor([[4.0696e-01, 1.2093e+00, 2.1402e+00, 3.1672e+00, 4.2936e+00, 5.5280e+00,
          6.8805e+00, 8.3626e+00, 9.9865e+00],
         [1.1735e-01, 1.0571e+00, 7.2866e+00, 4.8363e+01, 3.1678e+02, 1.9757e+03,
          9.7853e+03, 2.7899e+04, 4.5882e+04],
         [1.5962e-01, 3.9959e-01, 5.7416e-01, 6.8746e-01, 7.5935e-01, 8.0474e-01,
          8.3337e-01, 8.5143e-01, 8.6282e-01]]),
 'E_state': tensor([[3.3231e+01, 3.6408e+01, 3.9895e+01, 4.3716e+01, 4.7903e+01, 5.2490e+01,
          5.7515e+01, 6.3021e+01, 6.1560e+01],
         [1.7722e+02, 1.1699e+03, 7.7183e+03, 5.0640e+04, 3.2048e+05, 1.6333e+06,
          3.4077e+06, 1.6671e+06, 3.9021e+0

In [6]:
result1['data'].head()

Unnamed: 0,timepoint_id,sample_id,persistent_beta_c_param,persistent_kappa_param,persistent_gamma_param,persistent_hosp_param,persistent_death_hosp_param,persistent_I0_param,D_state,E_state,H_state,I_state,R_state,S_state,infected_observable_state,exposed_observable_state,hospitalized_observable_state,dead_observable_state
0,0,0,0.524803,0.506024,0.479009,0.143926,0.072884,6.189723,0.406956,33.23101,4.754951,17.009682,66.670609,19339920.0,17.009682,33.23101,4.754951,0.406956
1,1,0,0.524803,0.506024,0.479009,0.143926,0.072884,6.189723,1.209271,36.407856,6.030499,18.64551,149.939301,19339828.0,18.64551,36.407856,6.030499,1.209271
2,2,0,0.524803,0.506024,0.479009,0.143926,0.072884,6.189723,2.14016,39.895275,6.719253,20.431562,241.846176,19339728.0,20.431562,39.895275,6.719253,2.14016
3,3,0,0.524803,0.506024,0.479009,0.143926,0.072884,6.189723,3.167211,43.716206,7.377875,22.388531,342.645447,19339622.0,22.388531,43.716206,7.377875,3.167211
4,4,0,0.524803,0.506024,0.479009,0.143926,0.072884,6.189723,4.293585,47.902828,8.086552,24.532637,453.110413,19339502.0,24.532637,47.902828,8.086552,4.293585


### Sample from model 2

In [7]:
result2 = pyciemss.sample(model2, end_time, logging_step_size, num_samples, start_time=start_time)
result2['data'].head()

Unnamed: 0,timepoint_id,sample_id,persistent_beta_c_param,persistent_beta_nc_param,persistent_kappa_param,persistent_gamma_param,persistent_hosp_param,persistent_death_hosp_param,persistent_I0_param,D_state,E_state,H_state,I_state,R_state,S_state,infected_observable_state,exposed_observable_state,hospitalized_observable_state,dead_observable_state
0,0,0,0.297661,0.453676,0.754566,0.397328,0.080995,0.096925,9.568054,0.423141,101.485405,4.755975,49.33733,107.440117,19339780.0,49.33733,101.485405,4.755975,0.423141
1,1,0,0.297661,0.453676,0.754566,0.397328,0.080995,0.096925,9.568054,2.245101,326.518188,16.089222,158.754715,466.292603,19338996.0,158.754715,326.518188,16.089222,2.245101
2,2,0,0.297661,0.453676,0.754566,0.397328,0.080995,0.096925,9.568054,8.173079,1050.408691,51.870396,510.736267,1621.462158,19336756.0,510.736267,1050.408691,51.870396,8.173079
3,3,0,0.297661,0.453676,0.754566,0.397328,0.080995,0.096925,9.568054,27.251232,3377.420654,166.845978,1642.418823,5337.130859,19329478.0,1642.418823,3377.420654,166.845978,27.251232
4,4,0,0.297661,0.453676,0.754566,0.397328,0.080995,0.096925,9.568054,88.583427,10841.516602,536.099976,5274.518066,17278.066406,19306026.0,5274.518066,10841.516602,536.099976,88.583427


## Ensemble Sample Interface
Sample from an ensemble of model 1 and model 2 

In [8]:
model_paths = [model1, model2]
solution_mappings = [lambda x : x, lambda x : x] # Conveniently, these two models operate on exactly the same state space, with the same names.

ensemble_result = pyciemss.ensemble_sample(model_paths, solution_mappings, end_time, logging_step_size, num_samples, start_time=start_time)
ensemble_result['data'].head()

Unnamed: 0,timepoint_id,sample_id,model_0/weight_param,model_1/weight_param,model_0/persistent_beta_c_param,model_0/persistent_kappa_param,model_0/persistent_gamma_param,model_0/persistent_hosp_param,model_0/persistent_death_hosp_param,model_0/persistent_I0_param,...,D_state,E_state,H_state,I_state,R_state,S_state,infected_state,exposed_state,hospitalized_state,dead_state
0,0,0,0.971867,0.028133,0.454145,0.250756,0.281342,0.090347,0.069111,9.584448,...,0.187443,22.044773,2.285573,20.535141,52.936188,19339942.0,20.535141,22.044773,2.285573,0.187443
1,1,0,0.971867,0.028133,0.454145,0.250756,0.281342,0.090347,0.069111,9.584448,...,0.514453,18.843756,2.365145,17.67697,106.300575,19339922.0,17.67697,18.843756,2.365145,0.514453
2,2,0,0.971867,0.028133,0.454145,0.250756,0.281342,0.090347,0.069111,9.584448,...,0.821363,16.234955,2.083929,15.228139,152.485168,19339854.0,15.228139,16.234955,2.083929,0.821363
3,3,0,0.971867,0.028133,0.454145,0.250756,0.281342,0.090347,0.069111,9.584448,...,1.089267,13.998522,1.801614,13.130044,192.319611,19339816.0,13.130044,13.998522,1.801614,1.089267
4,4,0,0.971867,0.028133,0.454145,0.250756,0.281342,0.090347,0.069111,9.584448,...,1.32072,12.071818,1.554059,11.322845,226.67128,19339836.0,11.322845,12.071818,1.554059,1.32072


## Calibrate interface
Calibrate a model to a dataset by mapping model state varibale or observables to columns in the dataset

In [9]:
data_mapping = {"Infected": "I"} # data_mapping = "column_name": "observable/state_variable"
num_iterations = 10 if smoke_test else 1000
calibrated_results = pyciemss.calibrate(model1, dataset1, data_mapping=data_mapping, num_iterations=num_iterations)
parameter_estimates = calibrated_results["inferred_parameters"]
calibrated_results

{'inferred_parameters': AutoGuideList(
   (0): AutoDelta()
   (1): AutoLowRankMultivariateNormal()
 ),
 'loss': 248.0983643233776}

In [9]:
parameter_estimates()

{'persistent_beta_c': tensor(0.4845, grad_fn=<ExpandBackward0>),
 'persistent_kappa': tensor(0.3275, grad_fn=<ExpandBackward0>),
 'persistent_gamma': tensor(0.3196, grad_fn=<ExpandBackward0>),
 'persistent_hosp': tensor(0.0939, grad_fn=<ExpandBackward0>),
 'persistent_death_hosp': tensor(0.0549, grad_fn=<ExpandBackward0>),
 'persistent_I0': tensor(5.9766, grad_fn=<ExpandBackward0>)}

## Pass the parameter estimates to `sample` to sample from the calibrated model

In [10]:
calibrated_sample_results = pyciemss.sample(model1, end_time, logging_step_size, num_samples, 
                start_time=start_time, inferred_parameters=parameter_estimates)
calibrated_sample_results

{'data':     timepoint_id  sample_id  persistent_beta_c_param  persistent_kappa_param  \
 0              0          0                 0.410941                0.220197   
 1              1          0                 0.410941                0.220197   
 2              2          0                 0.410941                0.220197   
 3              3          0                 0.410941                0.220197   
 4              4          0                 0.410941                0.220197   
 5              5          0                 0.410941                0.220197   
 6              6          0                 0.410941                0.220197   
 7              7          0                 0.410941                0.220197   
 8              8          0                 0.410941                0.220197   
 9              0          1                 0.461822                0.209891   
 10             1          1                 0.461822                0.209891   
 11             2   

In [11]:
# TODO:
# - Add intervention example
# - Add examples for calibrate_ensemble and optimize interfaces as they become available
# - Plot results

## Sample interface with intervention

In [12]:
start_time = 0.0
end_time = 40.
logging_step_size = 1.0
num_samples = 5 if smoke_test else 1000
result = pyciemss.sample(model3, end_time, logging_step_size, num_samples, start_time=start_time, 
                         static_parameter_interventions={torch.tensor(1.): {"p_cbeta": torch.tensor(0.35)}}, solver_method="euler")
result["data"]

Unnamed: 0,timepoint_id,sample_id,persistent_p_cbeta_param,persistent_p_tr_param,I_state_state,R_state_state,S_state_state
0,0,0,0.326866,19.755219,1.275920,0.050620,999.673462
1,1,0,0.326866,19.755219,1.657314,0.115206,999.227478
2,2,0,0.326866,19.755219,2.152454,0.199098,998.648438
3,3,0,0.326866,19.755219,2.795087,0.308055,997.896851
4,4,0,0.326866,19.755219,3.628849,0.449541,996.921631
...,...,...,...,...,...,...,...
190,34,4,0.345899,16.822075,530.974487,382.251221,87.774292
191,35,4,0.345899,16.822075,515.706116,413.815369,71.478516
192,36,4,0.345899,16.822075,497.938385,444.471863,58.589737
193,37,4,0.345899,16.822075,478.538818,474.072144,48.389011


### Optimize interface
Get infections below 300 individuals at 100 days for SIR model with minimum change to current value for intervention parameter

In [11]:
import numpy as np
from typing import Dict, List

def obs_nday_average_qoi(
    samples: Dict[str, torch.Tensor], contexts: List, ndays: int = 7
) -> np.ndarray:
    """
    Return estimate of last n-day average of each sample.
    samples is is the output from a Pyro Predictive object.
    samples[VARIABLE] is expected to have dimension (nreplicates, ntimepoints)
    Note: last ndays timepoints is assumed to represent last n-days of simulation.
    """
    dataQoI = samples[contexts[0]].detach().numpy()

    return np.mean(dataQoI[:, -ndays:], axis=1)

start_time = 0.0
end_time = 40.
logging_step_size = 1.0
observed_params = ["I_state"]
intervention_time = torch.tensor(1.)
intervened_params = "p_cbeta"
p_cbeta_current = 0.35
initial_guess_interventions = 0.15
bounds_interventions = [[0.1], [0.5]]

risk_bound = 300.
qoi = lambda x: obs_nday_average_qoi(x, observed_params, 1)
objfun = lambda x: np.abs(p_cbeta_current-x)
static_parameter_interventions = {intervention_time: intervened_params}

opt_result = pyciemss.optimize(model3, end_time, logging_step_size, qoi, risk_bound, static_parameter_interventions, objfun, 
                               initial_guess_interventions=initial_guess_interventions, bounds_interventions=bounds_interventions, 
                               start_time=0.0, n_samples_ouu=int(1e2), maxiter=1, maxfeval=20, solver_method="euler")
print(f'Optimal policy for intervening on {static_parameter_interventions[list(static_parameter_interventions.keys())[0]]} is ', opt_result["policy"])

 40%|████      | 16/40 [01:00<01:31,  3.80s/it]

Optimal policy for intervening on p_cbeta is  tensor([0.2237], dtype=torch.float64)





#### Sample using optimal policy as intervention

In [12]:
num_samples = 10 if smoke_test else 100
result = pyciemss.sample(model3, end_time, logging_step_size, num_samples, start_time=start_time, 
                         static_parameter_interventions={intervention_time: {intervened_params: opt_result["policy"]}}, 
                         solver_method="euler")
result["data"]

Unnamed: 0,timepoint_id,sample_id,persistent_p_cbeta_param,persistent_p_tr_param,I_state,R_state,S_state
0,0,0,0.349256,12.813910,1.270867,0.078040,999.651123
1,1,0,0.349256,12.813910,1.455598,0.177219,999.367188
2,2,0,0.349256,12.813910,1.667089,0.290814,999.042114
3,3,0,0.349256,12.813910,1.909187,0.420914,998.669922
4,4,0,0.349256,12.813910,2.186285,0.569907,998.243835
...,...,...,...,...,...,...,...
385,34,9,0.316448,13.111476,104.724236,61.065868,835.209900
386,35,9,0.316448,13.111476,116.283775,69.053085,815.663147
387,36,9,0.316448,13.111476,128.611313,77.921944,794.466736
388,37,9,0.316448,13.111476,141.636490,87.731010,771.632507
