In [13]:
%reload_ext autoreload
%autoreload 2

import json
import os

import torch
import pyro
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pyro.distributions as dist
from typing import Dict, Optional, Tuple

import mira
import mira.metamodel.io
import mira.modeling
import mira.sources.petri

from pyciemss.ODE.askem_primitives import sample, infer_parameters, intervene, optimization
from pyciemss.ODE.base import ODE, PetriNetODESystem, Time, State, Solution, Observation, get_name
from pyciemss.utils import state_flux_constraint

In [16]:
MODEL_PATH = "../../../program-milestones/6-month-milestone/evaluation/scenario_3/ta_2/"

# Registered models: name -> (Petri-net JSON file, MIRA metamodel JSON file).
# Both file names are resolved relative to MODEL_PATH by load_mira_model.
MODEL_FILES = dict(
    biomd958=("scenario3_biomd958.json", "scenario3_biomd958_mira.json"),
    biomd960=("scenario3_biomd960.json", "scenario3_biomd960_mira.json"),
)

def read_obs(filename):
    """Read a CSV of observations into a ``{column_name: values}`` dict.

    Args:
        filename: Anything ``pd.read_csv`` accepts (a path or a file-like object).

    Returns:
        A dict with one entry per CSV column, in column order. The column
        named ``'date'`` is parsed with ``pd.to_datetime``; every other column
        is converted with ``torch.tensor`` from the column's underlying values
        (so non-numeric columns other than ``'date'`` will fail to convert).
    """
    frame = pd.read_csv(filename)

    def _convert(name):
        # Dates stay as pandas datetimes; everything else becomes a tensor.
        values = frame[name].values
        return pd.to_datetime(values) if name == 'date' else torch.tensor(values)

    return {name: _convert(name) for name in frame.columns}

def load_mira_model(
    model_name: str, model_path: Optional[str] = None
) -> "Tuple[mira.modeling.Model, mira.metamodel.TemplateModel]":
    """Load a registered model both as a MIRA ``Model`` and as its MIRA metamodel.

    Args:
        model_name: Key into ``MODEL_FILES`` selecting the
            (Petri-net JSON, MIRA metamodel JSON) file pair.
        model_path: Directory containing both files. ``None`` (the default)
            means the module-level ``MODEL_PATH``, read at call time.

    Returns:
        ``(mira_model, mira_metamodel)``:
        - ``mira_model``: the Petri-net JSON converted with
          ``mira.sources.petri.template_model_from_petri_json`` and wrapped in
          ``mira.modeling.Model`` (callers wrap this in ``PetriNetODESystem``).
        - ``mira_metamodel``: the result of
          ``mira.metamodel.io.model_from_json_file`` on the metamodel JSON.

    Raises:
        KeyError: if ``model_name`` is not a key of ``MODEL_FILES``.
    """
    try:
        model_file, metamodel_file = MODEL_FILES[model_name]
    except KeyError:
        raise KeyError(
            f"Unknown model {model_name!r}; expected one of {sorted(MODEL_FILES)}"
        ) from None
    if model_path is None:
        model_path = MODEL_PATH

    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(os.path.join(model_path, model_file), "r", encoding="utf-8") as f:
        model_json = json.load(f)
    mira_model = mira.modeling.Model(
        mira.sources.petri.template_model_from_petri_json(model_json)
    )
    mira_metamodel = mira.metamodel.io.model_from_json_file(
        os.path.join(model_path, metamodel_file)
    )
    return mira_model, mira_metamodel

In [17]:
# Reproducibility: pyro.set_rng_seed seeds torch, numpy and Python's `random`;
# clearing the param store keeps a re-run from reusing earlier parameters.
pyro.set_rng_seed(0)
pyro.clear_param_store()

# Experiment configuration.
DATA_FILE = "usa-IRDVHN_age.csv"  # resolved relative to the current working directory
MODEL_NAME = "biomd958"           # key into MODEL_FILES
NUM_ITERATIONS = 101
NUM_TIMESTEPS = 100
T1 = 0.
T2 = 100.

# Load the model and the observations. The MIRA metamodel is unused in this cell.
mira_model, mira_metamodel = load_mira_model(MODEL_NAME)
model = PetriNetODESystem(mira_model)
data = read_obs(DATA_FILE)

# Initial state comes from the model's defaults; the time grid is
# NUM_TIMESTEPS evenly spaced points on [T1, T2], endpoints included.
# NOTE(review): with these values the spacing is (T2 - T1) / (NUM_TIMESTEPS - 1)
# ≈ 1.0101, not 1.0 — confirm this matches the sampling of the rows in DATA_FILE.
initial_state = model.default_initial_state
tspan = torch.linspace(T1, T2, NUM_TIMESTEPS)


In [None]:
# Run pyciemss parameter inference for `model` against the observations in `data`,
# starting from `initial_state` on the `tspan` time grid, for NUM_ITERATIONS iterations.
# NOTE(review): the third positional argument is an empty list; its meaning depends on
# the signature of pyciemss.ODE.askem_primitives.infer_parameters — consider passing
# it by keyword so the call is self-describing.
# NOTE(review): this cell is marked `In [None]` in the saved notebook, i.e. it was
# never executed, so no posterior result was recorded.
posterior = infer_parameters(model, NUM_ITERATIONS, [], data, initial_state, tspan)