In [2]:
import os
from pyciemss.Ensemble.interfaces import (
    load_and_sample_petri_ensemble, load_and_calibrate_and_sample_ensemble_model
)

In [3]:
DEMO_PATH = "../../notebook/integration_demo/"
ASKENET_PATH_1 = "https://raw.githubusercontent.com/DARPA-ASKEM/Model-Representations/main/petrinet/examples/sir_typed.json"
ASKENET_PATH_2 = "../../test/models/AMR_examples/SIDARTHE.amr.json"

ASKENET_PATHS = [ASKENET_PATH_1, ASKENET_PATH_2]

In [4]:
def solution_mapping1(model1_solution: dict) -> dict:
    return model1_solution


def solution_mapping2(model2_solution: dict) -> dict:
    mapped_solution = {}
    mapped_solution["S"] = (
        model2_solution["Susceptible"]
        + model2_solution["Recognized"]
        + model2_solution["Threatened"]
    )

    mapped_solution["I"] = (
        model2_solution["Infected"]
        + model2_solution["Ailing"]
        + model2_solution["Diagnosed"]
    )

    # Model 1 doesn't include dead people, and implicitly assumes that everyone who is infected will recover.
    mapped_solution["R"] = (
        model2_solution["Healed"] + model2_solution["Extinct"]
    )

    return mapped_solution


solution_mappings = [solution_mapping1, solution_mapping2]

## load_and_sample_ensemble_model

In [5]:
weights = [0.5, 0.5]
num_samples = 100
timepoints = [0.0, 1.0, 2.0, 3.0, 4.0]

# Run sampling
ensemble_samples, q_ensemble = load_and_sample_petri_ensemble(
    ASKENET_PATHS, weights, solution_mappings, num_samples, timepoints
)

# Save results
ensemble_samples.to_csv(os.path.join(DEMO_PATH, "results_petri_ensemble/sample_results.csv"), index=False)



(100, 5) 100 (23, 5)
(100, 5) 100 (23, 5)
(100, 5) 100 (23, 5)


In [6]:
ensemble_samples

Unnamed: 0,timepoint_id,sample_id,model_0/beta_param,model_0/gamma_param,model_1/beta_param,model_1/gamma_param,model_1/delta_param,model_1/alpha_param,model_1/epsilon_param,model_1/zeta_param,...,model_1/mu_param,model_1/nu_param,model_1/xi_param,model_1/tau_param,model_1/sigma_param,model_0_weight,model_1_weight,I_sol,R_sol,S_sol
0,0,0,0.026803,0.119216,0.012491,0.387874,0.011683,0.669896,0.153438,0.109742,...,0.017868,0.022851,0.01528,0.008842,0.013721,0.479329,0.520671,0.479331,5.714417e-12,479.849792
1,1,0,0.026803,0.119216,0.012491,0.387874,0.011683,0.669896,0.153438,0.109742,...,0.017868,0.022851,0.01528,0.008842,0.013721,0.479329,0.520671,0.437009,5.458255e-02,479.837494
2,2,0,0.026803,0.119216,0.012491,0.387874,0.011683,0.669896,0.153438,0.109742,...,0.017868,0.022851,0.01528,0.008842,0.013721,0.479329,0.520671,0.398424,1.043458e-01,479.826355
3,3,0,0.026803,0.119216,0.012491,0.387874,0.011683,0.669896,0.153438,0.109742,...,0.017868,0.022851,0.01528,0.008842,0.013721,0.479329,0.520671,0.363246,1.497156e-01,479.816376
4,4,0,0.026803,0.119216,0.012491,0.387874,0.011683,0.669896,0.153438,0.109742,...,0.017868,0.022851,0.01528,0.008842,0.013721,0.479329,0.520671,0.331175,1.910788e-01,479.806671
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
495,0,99,0.027413,0.177764,0.012784,0.439427,0.012093,0.507124,0.185306,0.116331,...,0.014990,0.022754,0.01803,0.009766,0.014784,0.417890,0.582110,0.417893,7.428625e-12,418.472595
496,1,99,0.027413,0.177764,0.012784,0.439427,0.012093,0.507124,0.185306,0.116331,...,0.014990,0.022754,0.01803,0.009766,0.014784,0.417890,0.582110,0.359548,6.897052e-02,418.462006
497,2,99,0.027413,0.177764,0.012784,0.439427,0.012093,0.507124,0.185306,0.116331,...,0.014990,0.022754,0.01803,0.009766,0.014784,0.417890,0.582110,0.309349,1.283116e-01,418.452820
498,3,99,0.027413,0.177764,0.012784,0.439427,0.012093,0.507124,0.185306,0.116331,...,0.014990,0.022754,0.01803,0.009766,0.014784,0.417890,0.582110,0.266159,1.793678e-01,418.445129


In [7]:
q_ensemble

Unnamed: 0,timepoint_id,quantile,I_sol,R_sol,S_sol
0,0,0.01,0.000736,1.323472e-14,1.731733
1,1,0.01,0.000633,1.225338e-04,1.731713
2,2,0.01,0.000545,2.275207e-04,1.731696
3,3,0.01,0.000469,3.176221e-04,1.731683
4,4,0.01,0.000403,3.950255e-04,1.731671
...,...,...,...,...,...
110,0,0.99,0.999622,1.714123e-11,999.621978
111,1,0.99,0.904725,1.594760e-01,999.595990
112,2,0.99,0.822314,2.972693e-01,999.573170
113,3,0.99,0.747413,4.163281e-01,999.550798


## load_and_calibrate_and_sample_ensemble_model

In [5]:
data_path = os.path.join(DEMO_PATH, "data.csv")
num_samples = 100
timepoints = [0.0, 1.0, 2.0, 3.0, 4.0]

# Run the calibration and sampling
calibrated_samples = load_and_calibrate_and_sample_ensemble_model(
    ASKENET_PATHS,
    data_path,
    weights,
    solution_mappings,
    num_samples,
    timepoints,
    verbose=True,
    total_population=1000,
    num_iterations=350,
)

# Save results
calibrated_samples.to_csv(
    os.path.join(DEMO_PATH, "results_petri_ensemble/calibrated_sample_results.csv"), index=False
)



iteration 0: loss = 66.16035544872284
iteration 25: loss = 42.31934851408005
iteration 50: loss = 33.17540234327316
iteration 75: loss = 27.15283751487732
iteration 100: loss = 21.151540756225586
iteration 125: loss = 17.88287889957428
iteration 150: loss = 19.490506768226624
iteration 175: loss = 16.146777868270874
iteration 200: loss = 16.498568773269653
iteration 225: loss = 15.53571480512619
iteration 250: loss = 15.07484495639801
iteration 275: loss = 14.595208406448364
iteration 300: loss = 14.831865012645721
iteration 325: loss = 15.439243197441101
