In [1]:
import os
from pyciemss.PetriNetODE.interfaces import (
    load_and_sample_petri_model,
    load_and_calibrate_and_sample_petri_model
)

In [4]:
# Relative path of the demo folder; the cells below read input data from it
# and write their result CSVs underneath it.
DEMO_PATH = "../../notebook/integration_demo/"

# URL of the typed SIR Petri-net example JSON passed to the pyciemss
# interface functions as the model to load.
ASKENET_PATH = (
    "https://raw.githubusercontent.com/DARPA-ASKEM/Model-Representations"
    "/main/petrinet/examples/sir_typed.json"
)

## load_and_sample_petri_model

In [5]:
num_samples = 3
timepoints = [0.5, 1.0, 2.0, 3.0, 4.0]

# Run sampling without an intervention
samples = load_and_sample_petri_model(ASKENET_PATH, num_samples, timepoints=timepoints)

# Save results. to_csv does not create missing parent directories (assuming
# `samples` is a pandas DataFrame), so make sure the output folder exists
# first; otherwise a fresh checkout without results_petri/ fails here.
results_dir = os.path.join(DEMO_PATH, "results_petri")
os.makedirs(results_dir, exist_ok=True)
samples.to_csv(os.path.join(results_dir, "sample_results.csv"), index=False)

In [6]:
# Interventions to apply during sampling.
# NOTE(review): each tuple looks like (time, parameter name, new value) --
# confirm against the pyciemss interface documentation. The entries are not
# listed in time order; left as-is in case that is intentional.
interventions = [(1.1, "beta", 1.0), (2.1, "gamma", 0.1), (1.3, "beta", 2.0), (1.4, "gamma", 0.3)]

# Same model, sample count and timepoints as the un-intervened run, plus the
# interventions above.
intervened_samples = load_and_sample_petri_model(
    ASKENET_PATH, num_samples, timepoints=timepoints, interventions=interventions
)

# Save results. to_csv does not create missing parent directories (assuming
# `intervened_samples` is a pandas DataFrame), so make sure the output folder
# exists before writing.
results_dir = os.path.join(DEMO_PATH, "results_petri")
os.makedirs(results_dir, exist_ok=True)
intervened_samples.to_csv(
    os.path.join(results_dir, "sample_results_w_interventions.csv"),
    index=False,
)

## load_and_calibrate_and_sample_petri_model

In [7]:
# Observations used for calibration, read from the demo folder.
data_path = os.path.join(DEMO_PATH, "data.csv")

# NOTE: these rebind `num_samples` and `timepoints`, which the sampling cell
# above set to 3 and [0.5, 1.0, 2.0, 3.0, 4.0]. Re-running the earlier
# intervention cell after this one will pick up the values below instead.
num_samples = 100
timepoints = [0.0, 1.0, 2.0, 3.0, 4.0]

# Run the calibration and sampling.
# verbose=True presumably enables the "iteration N: loss = ..." progress lines
# recorded in this cell's output.
calibrated_samples = load_and_calibrate_and_sample_petri_model(
    ASKENET_PATH,
    data_path,
    num_samples,
    timepoints=timepoints,
    verbose=True,
)

# Save results. to_csv does not create missing parent directories (assuming
# `calibrated_samples` is a pandas DataFrame), so make sure the output folder
# exists before writing.
results_dir = os.path.join(DEMO_PATH, "results_petri")
os.makedirs(results_dir, exist_ok=True)
calibrated_samples.to_csv(
    os.path.join(results_dir, "calibrated_sample_results.csv"), index=False
)

iteration 0: loss = 62.12709617614746
iteration 25: loss = 60.099618911743164
iteration 50: loss = 58.367534935474396
iteration 75: loss = 57.764896750450134
iteration 100: loss = 56.60384750366211
iteration 125: loss = 56.626110792160034
iteration 150: loss = 56.81937026977539
iteration 175: loss = 56.734710931777954
iteration 200: loss = 56.507239818573
iteration 225: loss = 56.81278920173645
iteration 250: loss = 56.38322448730469
iteration 275: loss = 56.98680853843689
iteration 300: loss = 56.82800793647766
iteration 325: loss = 56.91287016868591
iteration 350: loss = 56.571919679641724
iteration 375: loss = 56.62337040901184
iteration 400: loss = 56.986464977264404
iteration 425: loss = 56.08261704444885
iteration 450: loss = 56.555689573287964
iteration 475: loss = 56.814436197280884
iteration 500: loss = 56.36864924430847
iteration 525: loss = 56.35692620277405
iteration 550: loss = 56.53825569152832
iteration 575: loss = 56.52056694030762
iteration 600: loss = 56.5714080333709