In [None]:
# Import packages
import datetime
import os
from math import floor, sqrt
from typing import List, Optional

import pandas as pd
from matplotlib import pyplot as plt
from numpy import mean

# Import AuTuMN modules
from autumn.settings import Models, Region
from autumn.settings.folders import OUTPUT_DATA_PATH
from autumn.tools.project import get_project
from autumn.tools import db
from autumn.tools.plots.calibration.plots import calculate_r_hats, get_output_from_run_id
from autumn.tools.plots.uncertainty.plots import _plot_uncertainty, _get_target_values
from autumn.tools.plots.plotter.base_plotter import COLOR_THEME
from autumn.tools.plots.utils import get_plot_text_dict, change_xaxis_to_date, REF_DATE, ALPHAS, COLORS, _apply_transparency, _plot_targets_to_axis
from autumn.tools.utils.utils import flatten_list
from autumn.dashboards.calibration_results.plots import get_uncertainty_df

In [None]:
# Specify model details
# These three names are read by the later cells: `model` and `region` select the AuTuMN
# project via get_project(), and `dirname` is the calibration run folder loaded from
# OUTPUT_DATA_PATH/calibrate/<project model name>/<project region name>/<dirname>.
# They also form the local output folder name outputs/<model>_<region>_<dirname>.
model = Models.COVID_19
region = Region.SRI_LANKA
dirname = "2022-04-12"  # name of the calibration run folder to analyse

In [None]:
# Locate this project's calibration outputs:
#   OUTPUT_DATA_PATH/calibrate/<project model name>/<project region name>/<dirname>
project = get_project(model, region)
project_calib_dir = os.path.join(
    OUTPUT_DATA_PATH,
    "calibrate",
    project.model_name,
    project.region_name,
)
calib_path = os.path.join(project_calib_dir, dirname)

# Read the MCMC databases found under the calibration folder
mcmc_tables = db.load.load_mcmc_tables(calib_path)
mcmc_params = db.load.load_mcmc_params_tables(calib_path)
mcmc_runs = db.load.load_mcmc_run_tables(calib_path)

# Uncertainty outputs (built by the dashboard helper from the calibration results and
# the project's plot config), and the distinct scenarios they contain
uncertainty_df = get_uncertainty_df(calib_path, mcmc_tables, project.plots)
scenario_list = uncertainty_df["scenario"].unique()

# Local output tree, relative to the current working directory:
#   outputs/<model>_<region>_<dirname>/{calibration,MLE,median,uncertainty,csv_files}
output_dir = f"{model}_{region}_{dirname}"
base_dir = os.path.join("outputs", output_dir)
os.makedirs(base_dir, exist_ok=True)
dirs_to_make = [
    "calibration",
    "MLE",
    "median",
    "uncertainty",
    "csv_files",
]
for dir_to_make in dirs_to_make:
    os.makedirs(
        os.path.join(base_dir, dir_to_make),
        exist_ok=True,
    )

In [None]:
def param_traces(
    mcmc_params: List[pd.DataFrame],
    mcmc_tables: List[pd.DataFrame],
    burn_in: int,
    optional_param_request: Optional[List[str]] = None,
    param_index: int = 1,
    axis=None,
):
    """
    Plot the MCMC trace of one calibrated parameter, drawing one line per chain.

    Only the first table of ``mcmc_params`` and of ``mcmc_tables`` is used. Each stored
    row is repeated ``weight`` times, so the x-axis counts iterations rather than rows.

    Args:
        mcmc_params: Parameter tables. Columns of the first table are treated as
            parameter names; any column containing "dispersion_param" is skipped.
        mcmc_tables: MCMC tables. The first must have "chain" and "weight" columns and
            share its index with ``mcmc_params[0]``, because rows are matched by label.
        burn_in: Number of leading iterations of each chain to leave out of the plot.
            The x-axis keeps the original iteration numbers.
        optional_param_request: Explicit candidate parameter names. If None or empty,
            all non-dispersion parameters are the candidates.
        param_index: Position within the candidates of the parameter to plot. The
            default of 1 keeps the previous behaviour of plotting the second candidate.
        axis: Existing matplotlib axes to draw on. If None, a new 12x8 figure is created.

    Raises:
        ValueError: If ``burn_in`` is negative.
        IndexError: If ``param_index`` is out of range for the candidate parameters.
    """
    if burn_in < 0:
        raise ValueError(f"burn_in must be non-negative, got {burn_in}")

    params_df = mcmc_params[0]
    runs_df = mcmc_tables[0]

    # Default to the epidemiological parameters only, not the dispersion parameters
    candidate_params = optional_param_request or [
        param for param in params_df.columns if "dispersion_param" not in param
    ]
    param_name = candidate_params[param_index]

    if axis is None:
        _, axis = plt.subplots(figsize=(12, 8))

    for chain_id in runs_df["chain"].unique():
        in_chain = runs_df["chain"] == chain_id
        # Expand each stored row by its weight. The weight is presumably the number of
        # iterations the chain stayed at that state (TODO confirm). Casting to int also
        # accepts float-typed weight columns.
        weights = runs_df.loc[in_chain, "weight"].astype(int).to_numpy()
        trace = params_df.loc[in_chain, param_name].to_numpy().repeat(weights)
        # Drop the burn-in but keep true iteration numbers on the x-axis
        axis.plot(range(burn_in, len(trace)), trace[burn_in:], linewidth=0.3)

    axis.set_xlabel("iterations")
    axis.set_ylabel(param_name)

In [None]:
# Draw per-chain MCMC traces for one parameter, keeping every iteration (burn_in=0)
param_traces(mcmc_params, mcmc_tables, burn_in=0);