In [7]:
import os
from pyciemss.Ensemble.interfaces import (
    load_and_sample_petri_ensemble, load_and_calibrate_and_sample_ensemble_model
)

In [8]:
DEMO_PATH = "../../notebook/integration_demo/"
ASKENET_PATH_1 = "https://raw.githubusercontent.com/DARPA-ASKEM/Model-Representations/main/petrinet/examples/sir_typed.json"
ASKENET_PATH_2 = "../../test/models/AMR_examples/SIDARTHE.amr.json"

ASKENET_PATHS = [ASKENET_PATH_1, ASKENET_PATH_2]

In [9]:
def solution_mapping1(model1_solution: dict) -> dict:
    return model1_solution


def solution_mapping2(model2_solution: dict) -> dict:
    mapped_solution = {}
    mapped_solution["S"] = (
        model2_solution["Susceptible"]
        + model2_solution["Recognized"]
        + model2_solution["Threatened"]
    )

    mapped_solution["I"] = (
        model2_solution["Infected"]
        + model2_solution["Ailing"]
        + model2_solution["Diagnosed"]
    )

    # Model 1 doesn't include dead people, and implicitly assumes that everyone who is infected will recover.
    mapped_solution["R"] = (
        model2_solution["Healed"] + model2_solution["Extinct"]
    )

    return mapped_solution


solution_mappings = [solution_mapping1, solution_mapping2]

## load_and_sample_ensemble_model

In [10]:
weights = [0.5, 0.5]
num_samples = 100
timepoints = [0.0, 1.0, 2.0, 3.0, 4.0]

# Run sampling
ensemble_samples, q_ensemble = load_and_sample_petri_ensemble(
    ASKENET_PATHS, weights, solution_mappings, num_samples, timepoints
)

# Save results
ensemble_samples.to_csv(os.path.join(DEMO_PATH, "results_petri_ensemble/sample_results.csv"), index=False)
q_ensemble.to_csv(os.path.join(DEMO_PATH, "results_petri_ensemble/quantile_results.csv"), index=False)



a
(100, 5) 100 (23, 5)
(100, 5) 100 (23, 5)
(100, 5) 100 (23, 5)


In [5]:
ensemble_samples

Unnamed: 0,timepoint_id,sample_id,model_0/beta_param,model_0/gamma_param,model_1/beta_param,model_1/gamma_param,model_1/delta_param,model_1/alpha_param,model_1/epsilon_param,model_1/zeta_param,...,model_1/mu_param,model_1/nu_param,model_1/xi_param,model_1/tau_param,model_1/sigma_param,model_0_weight,model_1_weight,I_sol,R_sol,S_sol
0,0,0,0.026223,0.167017,0.011631,0.409733,0.012467,0.559308,0.188073,0.142098,...,0.015597,0.026466,0.019557,0.010510,0.014439,0.178197,0.821803,0.178200,2.976398e-12,179.019211
1,1,0,0.026223,0.167017,0.011631,0.409733,0.012467,0.559308,0.188073,0.142098,...,0.015597,0.026466,0.019557,0.010510,0.014439,0.178197,0.821803,0.154793,2.776307e-02,179.014877
2,2,0,0.026223,0.167017,0.011631,0.409733,0.012467,0.559308,0.188073,0.142098,...,0.015597,0.026466,0.019557,0.010510,0.014439,0.178197,0.821803,0.134460,5.187901e-02,179.010941
3,3,0,0.026223,0.167017,0.011631,0.409733,0.012467,0.559308,0.188073,0.142098,...,0.015597,0.026466,0.019557,0.010510,0.014439,0.178197,0.821803,0.116798,7.282706e-02,179.007782
4,4,0,0.026223,0.167017,0.011631,0.409733,0.012467,0.559308,0.188073,0.142098,...,0.015597,0.026466,0.019557,0.010510,0.014439,0.178197,0.821803,0.101456,9.102336e-02,179.004929
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
495,0,99,0.026316,0.102675,0.011822,0.541434,0.009999,0.673947,0.193376,0.104431,...,0.015125,0.029262,0.015572,0.010282,0.014376,0.175958,0.824042,0.175961,1.806686e-12,176.781631
496,1,99,0.026316,0.102675,0.011822,0.541434,0.009999,0.673947,0.193376,0.104431,...,0.015125,0.029262,0.015572,0.010282,0.014376,0.175958,0.824042,0.163022,1.739416e-02,176.777100
497,2,99,0.026316,0.102675,0.011822,0.541434,0.009999,0.673947,0.193376,0.104431,...,0.015125,0.029262,0.015572,0.010282,0.014376,0.175958,0.824042,0.151035,3.350934e-02,176.773041
498,3,99,0.026316,0.102675,0.011822,0.541434,0.009999,0.673947,0.193376,0.104431,...,0.015125,0.029262,0.015572,0.010282,0.014376,0.175958,0.824042,0.139931,4.843965e-02,176.769226


In [6]:
q_ensemble

Unnamed: 0,timepoint_id,target,type,quantile,value
0,0,I_sol,quantile,0.01,0.000507
1,1,I_sol,quantile,0.01,0.000456
2,2,I_sol,quantile,0.01,0.000410
3,3,I_sol,quantile,0.01,0.000369
4,4,I_sol,quantile,0.01,0.000332
...,...,...,...,...,...
340,0,S_sol,quantile,0.99,996.604308
341,1,S_sol,quantile,0.99,996.579152
342,2,S_sol,quantile,0.99,996.557291
343,3,S_sol,quantile,0.99,996.538525


## load_and_calibrate_and_sample_ensemble_model

In [5]:
data_path = os.path.join(DEMO_PATH, "data.csv")
num_samples = 100
timepoints = [0.0, 1.0, 2.0, 3.0, 4.0]

# Run the calibration and sampling
calibrated_samples, calibrated_q_ensemble = load_and_calibrate_and_sample_ensemble_model(
    ASKENET_PATHS,
    data_path,
    weights,
    solution_mappings,
    num_samples,
    timepoints,
    verbose=True,
    total_population=1000,
    num_iterations=350,
)

# Save results
calibrated_samples.to_csv(
    os.path.join(DEMO_PATH, "results_petri_ensemble/calibrated_sample_results.csv"), index=False
)
calibrated_q_ensemble.to_csv(
    os.path.join(DEMO_PATH, "results_petri_ensemble/calibrated_quantile_results.csv"), index=False
)



iteration 0: loss = 63.51079273223877
iteration 25: loss = 40.8916078209877
iteration 50: loss = 37.46487545967102
iteration 75: loss = 25.852425813674927
iteration 100: loss = 21.470157742500305
iteration 125: loss = 17.58678960800171
iteration 150: loss = 15.267673671245575
iteration 175: loss = 16.10206639766693
iteration 200: loss = 15.993813276290894
iteration 225: loss = 15.860064268112183
iteration 250: loss = 13.152705788612366
iteration 275: loss = 16.014341235160828
iteration 300: loss = 15.707162857055664
iteration 325: loss = 12.456510186195374
