In [1]:
import os
from pyciemss.Ensemble.interfaces import (
    load_and_sample_petri_ensemble, load_and_calibrate_and_sample_ensemble_model
)
from pyciemss.visuals import plots

In [2]:
DEMO_PATH = "notebook/integration_demo/"
# ASKENET_PATH_1 = "https://raw.githubusercontent.com/DARPA-ASKEM/Model-Representations/main/petrinet/examples/sir_typed.json"
ASKENET_PATH_1 = "test/models/AMR_examples/SIDARTHE.amr.json"
ASKENET_PATH_2 = "test/models/AMR_examples/SIDARTHE.amr.json"

ASKENET_PATHS = [ASKENET_PATH_1, ASKENET_PATH_2]

## load_and_sample_ensemble_model

In [4]:
weights = [0.5, 0.5]
num_samples = 2
timepoints = [0.0, 1.0, 2.0, 3.0, 4.0, 40, 50]
solution_mappings = [{"S":"Cases", "I":"Hospitalizations", "D":"Deaths"}, {"S":"Cases", "I":"Hospitalizations", "D":"Deaths"}]

# Run sampling
result = load_and_sample_petri_ensemble(
    ASKENET_PATHS, weights, solution_mappings, num_samples, timepoints, 
    time_unit="days"
)

ensemble_samples = result["data"]
q_ensemble = result["quantiles"]
# Save results
ensemble_samples.to_csv(os.path.join(DEMO_PATH, "results_petri_ensemble/sample_results.csv"), index=False)
q_ensemble.to_csv(os.path.join(DEMO_PATH, "results_petri_ensemble/quantile_results.csv"), index=False)

In [5]:
ensemble_samples

Unnamed: 0,timepoint_id,sample_id,model_0/beta_param,model_0/gamma_param,model_0/delta_param,model_0/alpha_param,model_0/epsilon_param,model_0/zeta_param,model_0/lambda_param,model_0/eta_param,...,model_1/nu_param,model_1/xi_param,model_1/tau_param,model_1/sigma_param,model_0_weight,model_1_weight,S_sol,I_sol,D_sol,timepoint_days
0,0,0,0.012129,0.427158,0.009722,0.57175,0.177298,0.111706,0.037496,0.106768,...,0.025326,0.015622,0.011164,0.019525,0.007747,0.992253,3.666667e-07,3.333333e-08,6.324589e-32,0.0
1,1,0,0.012129,0.427158,0.009722,0.57175,0.177298,0.111706,0.037496,0.106768,...,0.025326,0.015622,0.011164,0.019525,0.007747,0.992253,9.598541e-07,1.540768e-07,2.136914e-11,1.0
2,2,0,0.012129,0.427158,0.009722,0.57175,0.177298,0.111706,0.037496,0.106768,...,0.025326,0.015622,0.011164,0.019525,0.007747,0.992253,1.882615e-06,4.510297e-07,1.611468e-10,2.0
3,3,0,0.012129,0.427158,0.009722,0.57175,0.177298,0.111706,0.037496,0.106768,...,0.025326,0.015622,0.011164,0.019525,0.007747,0.992253,3.265895e-06,9.679806e-07,5.753538e-10,3.0
4,4,0,0.012129,0.427158,0.009722,0.57175,0.177298,0.111706,0.037496,0.106768,...,0.025326,0.015622,0.011164,0.019525,0.007747,0.992253,5.308928e-06,1.780367e-06,1.494729e-09,4.0
5,5,0,0.012129,0.427158,0.009722,0.57175,0.177298,0.111706,0.037496,0.106768,...,0.025326,0.015622,0.011164,0.019525,0.007747,0.992253,0.4704257,0.3177823,0.002365492,40.0
6,6,0,0.012129,0.427158,0.009722,0.57175,0.177298,0.111706,0.037496,0.106768,...,0.025326,0.015622,0.011164,0.019525,0.007747,0.992253,0.5548024,0.4775309,0.01122028,50.0
7,0,1,0.011331,0.522915,0.008906,0.673092,0.151185,0.107392,0.03777,0.119198,...,0.028566,0.017226,0.009229,0.016204,0.969952,0.030048,3.666667e-07,3.333333e-08,6.78608e-32,0.0
8,1,1,0.011331,0.522915,0.008906,0.673092,0.151185,0.107392,0.03777,0.119198,...,0.028566,0.017226,0.009229,0.016204,0.969952,0.030048,1.032201e-06,1.670246e-07,2.261145e-11,1.0
9,2,1,0.011331,0.522915,0.008906,0.673092,0.151185,0.107392,0.03777,0.119198,...,0.028566,0.017226,0.009229,0.016204,0.969952,0.030048,2.12324e-06,5.057558e-07,1.742668e-10,2.0


In [6]:
result = load_and_sample_petri_ensemble(
    ASKENET_PATHS, weights, solution_mappings, num_samples, timepoints,
    time_unit="days", visual_options={"subset":".*_sol"}
)

plots.ipy_display(result["visual"])




In [7]:
result = load_and_sample_petri_ensemble(
    ASKENET_PATHS, weights, solution_mappings, num_samples, timepoints,
    time_unit="days", visual_options={"subset":".*_sol"}
)

plots.ipy_display(result["visual"])




## load_and_calibrate_and_sample_ensemble_model

In [8]:
data_path = os.path.join(DEMO_PATH, "data.csv")
num_samples = 2
timepoints = [0.0, 1.0, 2.0, 3.0, 4.0]

# Run the calibration and sampling
result = load_and_calibrate_and_sample_ensemble_model(
    ASKENET_PATHS,
    data_path,
    weights,
    solution_mappings,
    num_samples,
    timepoints,
    verbose=True,
    total_population=1000,
    num_iterations=350,
    time_unit="days",
    visual_options={"title": "Calibrated Ensemble", "subset":".*_sol"}
)

# Save results
result["data"].to_csv(
    os.path.join(DEMO_PATH, "results_petri_ensemble/calibrated_sample_results.csv"), index=False
)
result["quantiles"].to_csv(
    os.path.join(DEMO_PATH, "results_petri_ensemble/calibrated_quantile_results.csv"), index=False
)
plots.ipy_display(result["visual"])

iteration 0: loss = 88.95849859714508
iteration 25: loss = 60.6032851934433
iteration 50: loss = 39.23359388113022
iteration 75: loss = 25.380016148090363
iteration 100: loss = 17.3934445977211
iteration 125: loss = 19.12727040052414
iteration 150: loss = 11.075169920921326
iteration 175: loss = 5.466326415538788
iteration 200: loss = 15.445661842823029
iteration 225: loss = 3.8957112431526184
iteration 250: loss = 7.482335865497589
iteration 275: loss = 3.592634856700897
iteration 300: loss = 6.30310183763504


KeyboardInterrupt: 