# This is a notebook for synthesizing data to test calibration

In order to check that `calibrate` is returning a result that makes sense, we are going to:  

1. `sample` a model
2. use that output to generate synthetic data
3. then calibrate the model to that synthetic dataset
4. sanity check that the parameters/results are reasonable compared to the parameters used to create the synthetic data

See [this issue](https://github.com/ciemss/pyciemss/issues/448).

### Load dependencies

In [1]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pyciemss
from pyciemss.interfaces import calibrate

Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),
(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)
but was not found to be installed on your system.
If this would cause problems for you,
please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466
        
  import pandas as pd


### Collect model and data paths

In [2]:
MODEL_PATH = "https://raw.githubusercontent.com/DARPA-ASKEM/simulation-integration/main/data/models/"
DATA_PATH = "../../docs/source/"

# Models
petri1 = os.path.join(MODEL_PATH, "SEIRHD_with_reinfection01_petrinet.json")
regnet1 = os.path.join(MODEL_PATH, "LV_rabbits_wolves_model02_regnet.json")
stock1 = os.path.join(MODEL_PATH, "SIR_stockflow.json")
stock2 = os.path.join(MODEL_PATH, "SEIRHDS_stockflow.json")

### Set parameters for sampling

In [3]:
# What is logging_step_size?

In [44]:
start_time = 0.0
end_time = 151.0
logging_step_size = 10.0

### Define functions for generating synthetic data

In [45]:
# Function to add Gaussian noise to `sample` results
def add_gaussian_noise(data: pd.DataFrame, std_dev: float, col_state_map: dict, start_time: float = 0.0, end_time: float = 100.0) -> pd.DataFrame:
    noise = np.random.normal(0, std_dev, size=data.shape)
    noisy_data = data + noise
    noisy_data.insert(0, 'Timestamp', np.arange(start_time + logging_step_size, end_time, logging_step_size).tolist())
    col_state_map = {'Timestamp': 'Timestamp', **col_state_map}
    noisy_data = noisy_data.rename(columns=col_state_map)
    return noisy_data

# Function to sample from a model and generate synthetic data
def get_synthetic_data(model, col_state_map, end_time, logging_step_size, noise_level):
    num_samples = 1
    result = pyciemss.sample(model_path_or_json=model, end_time=end_time, logging_step_size=logging_step_size, num_samples=num_samples, start_time=start_time)
    data_df = result["data"][list(col_state_map.keys())]
    print(data_df)
    noisy_data = add_gaussian_noise(data_df, noise_level, col_state_map, start_time=start_time, end_time=end_time)
    noisy_data.to_csv('noisy_data.csv', index=False)
    return noisy_data

# TODO: make_plot=True

## (1) Create synthetic data from a given model

In [47]:
col_state_map = {'I_state_state': 'Cases', 'H_state_state': 'Hosp', 'D_state_state': 'Deaths'}
noise_level = 0.0
get_synthetic_data(petri1, col_state_map, end_time, logging_step_size, noise_level)

    I_state_state  H_state_state  D_state_state
0    7.936502e+01       2.597956       0.137985
1    2.945332e+02      10.168753       0.830313
2    1.092169e+03      37.781757       3.426448
3    4.048442e+03     140.078873      13.056312
4    1.498669e+04     518.819519      48.740582
5    5.520471e+04    1914.785889     180.672516
6    1.997089e+05    6975.674316     664.434204
7    6.779072e+05   24269.400391    2386.510498
8    1.886455e+06   72901.289062    7947.368652
9    3.332892e+06  153380.953125   21836.242188
10   3.205401e+06  189024.750000   43869.144531
11   1.956736e+06  144376.140625   64985.828125
12   9.639288e+05   82089.820312   78830.046875
13   4.494286e+05   40646.988281   86150.320312
14   2.151614e+05   19434.386719   89695.851562


Unnamed: 0,Timestamp,Cases,Hosp,Deaths
0,10.0,79.36502,2.597956,0.137985
1,20.0,294.5332,10.168753,0.830313
2,30.0,1092.169,37.781757,3.426448
3,40.0,4048.442,140.078873,13.056312
4,50.0,14986.69,518.819519,48.740582
5,60.0,55204.71,1914.785889,180.672516
6,70.0,199708.9,6975.674316,664.434204
7,80.0,677907.2,24269.400391,2386.510498
8,90.0,1886455.0,72901.289062,7947.368652
9,100.0,3332892.0,153380.953125,21836.242188


## (2) Calibrate the model to the synthetic data

In [48]:
data_mapping = {"Cases": "I", "Hosp": "H", "Deaths": "D"} # data_mapping = "column_name": "observable/state_variable"
num_iterations = 100
dataset = DATA_PATH + "noisy_data.csv"

calibrated_results = calibrate(petri1, dataset, data_mapping=data_mapping, num_iterations=num_iterations)
parameter_estimates = calibrated_results["inferred_parameters"]
calibrated_results

{'inferred_parameters': AutoGuideList(
   (0): AutoDelta()
   (1): AutoLowRankMultivariateNormal()
 ),
 'loss': 5445.07262519002}

In [49]:
parameter_estimates()

{'persistent_beta': tensor(0.5850, grad_fn=<ExpandBackward0>),
 'persistent_gamma': tensor(0.3135, grad_fn=<ExpandBackward0>),
 'persistent_hosp': tensor(0.1309, grad_fn=<ExpandBackward0>),
 'persistent_death_hosp': tensor(0.0634, grad_fn=<ExpandBackward0>),
 'persistent_I0': tensor(11.1051, grad_fn=<ExpandBackward0>)}

In [52]:
num_samples = 10
calibrated_sample_results = pyciemss.sample(petri1, end_time, logging_step_size, num_samples, 
                start_time=start_time, inferred_parameters=parameter_estimates)
calibrated_sample_results

{'data':      timepoint_id  sample_id  persistent_beta_param  persistent_gamma_param  \
 0               0          0               0.583356                 0.30949   
 1               1          0               0.583356                 0.30949   
 2               2          0               0.583356                 0.30949   
 3               3          0               0.583356                 0.30949   
 4               4          0               0.583356                 0.30949   
 ..            ...        ...                    ...                     ...   
 145            10          9               0.606524                 0.30237   
 146            11          9               0.606524                 0.30237   
 147            12          9               0.606524                 0.30237   
 148            13          9               0.606524                 0.30237   
 149            14          9               0.606524                 0.30237   
 
      persistent_hosp_param  p

In [53]:
# TODO: make_plot for calibrated samples

In [None]:
# Sanity check: compare calibrated parameters to original