In [None]:
# Import packages
import os
from matplotlib import pyplot as plt
import pandas as pd
import datetime

# Import AuTuMN modules
from autumn.settings import Models, Region
from autumn.settings.folders import OUTPUT_DATA_PATH
from autumn.tools.project import get_project
from autumn.tools import db
from autumn.tools.plots.calibration.plots import calculate_r_hats, get_output_from_run_id
from autumn.tools.plots.uncertainty.plots import _plot_uncertainty, _get_target_values
from autumn.tools.plots.plotter.base_plotter import COLOR_THEME
from autumn.tools.plots.utils import get_plot_text_dict, change_xaxis_to_date, REF_DATE, ALPHAS, COLORS, _apply_transparency, _plot_targets_to_axis
from autumn.tools.utils.utils import flatten_list
from autumn.dashboards.calibration_results.plots import get_uncertainty_df
from numpy import mean
from math import floor, sqrt

from typing import List

In [None]:
# Specify model details.
# `model` and `region` select the project passed to get_project() in the next cell.
# `dirname` is the name of the calibration run folder: it is joined onto the project's
# "calibrate" output directory to form calib_path, and it is also embedded in the name
# of the local output directory.
model = Models.COVID_19
region = Region.SRI_LANKA
dirname = "2022-04-05"

In [None]:
# Get the relevant project and locate its calibration output data:
# <OUTPUT_DATA_PATH>/calibrate/<model_name>/<region_name>/<dirname>
project = get_project(model, region)
project_calib_dir = os.path.join(
    OUTPUT_DATA_PATH, "calibrate", project.model_name, project.region_name
)
calib_path = os.path.join(project_calib_dir, dirname)
# Load tables from the calibration directory.
# NOTE(review): later cells index each of these with [0], so each loader presumably returns
# a list of DataFrames - confirm against autumn.tools.db.load.
mcmc_tables = db.load.load_mcmc_tables(calib_path)
mcmc_params = db.load.load_mcmc_params_tables(calib_path)
mcmc_runs = db.load.load_mcmc_run_tables(calib_path)

# Build the uncertainty table and list the distinct values of its 'scenario' column
uncertainty_df = get_uncertainty_df(calib_path, mcmc_tables, project.plots)
scenario_list = uncertainty_df['scenario'].unique()

# Make output directories: outputs/<model>_<region>_<dirname>/ plus one sub-folder per output type.
# "outputs" is a relative path, so it is created under the current working directory.
# exist_ok=True makes this cell safe to re-run.
output_dir = f"{model}_{region}_{dirname}"
base_dir = os.path.join("outputs", output_dir)
os.makedirs(base_dir, exist_ok=True)
dirs_to_make = ["calibration", "MLE", "median", "uncertainty", "csv_files"]
for dir_to_make in dirs_to_make:
    os.makedirs(os.path.join(base_dir, dir_to_make), exist_ok=True)

In [None]:
def get_posterior(mcmc_params, mcmc_tables, param_name, burn_in=0):
    """
    Build the weighted posterior sample of a single parameter.
    Each parameter table is joined to its matching run table on the index, and every value of the requested
    parameter is then repeated as many times as the 'weight' found on the same row.
    :param mcmc_params: an iterable of DataFrames, each holding a column named after the parameter
    :param mcmc_tables: an iterable of DataFrames paired one-to-one with mcmc_params, each holding a 'weight' column
    :param param_name: name of the parameter column to extract
    :param burn_in: accepted for interface compatibility; it is not applied by this function
    :return: a single-column DataFrame (column named param_name) containing the expanded sample
    """
    expanded_sample = []
    for params_chunk, runs_chunk in zip(mcmc_params, mcmc_tables):
        joined = params_chunk.merge(runs_chunk, left_index=True, right_index=True)
        values = joined[param_name]
        repeats = joined.weight
        for value, n_repeats in zip(values, repeats):
            # A row with weight k contributes k identical entries
            expanded_sample.extend([value] * n_repeats)

    return pd.DataFrame(expanded_sample, columns=[param_name])

In [None]:
def split_mcmc_outputs_by_chain(mcmc_params, mcmc_runs, mcmc_tables):
    """
    Split the first parameter table and the first MCMC table into one DataFrame per chain.
    Only element 0 of each input list is used.
    :param mcmc_params: a list of parameter DataFrames. If the first table has its own 'chain' column, that column is
    used to select its rows; otherwise the rows are selected with the chain labels of mcmc_runs[0], which assumes
    both tables share the same index (TODO confirm against the loaders).
    :param mcmc_runs: a list of DataFrames whose first element has a 'chain' column; it defines the chain ids
    :param mcmc_tables: a list of DataFrames whose first element has a 'chain' column
    :return: a tuple (mcmc_params_list, mcmc_tables_list), each containing one DataFrame per chain id, in order of
    first appearance of the chain id in mcmc_runs[0]
    """
    params_df, runs_df, tables_df = mcmc_params[0], mcmc_runs[0], mcmc_tables[0]
    chain_ids = runs_df["chain"].unique().tolist()

    # Chain label of every parameter row
    params_chain = params_df["chain"] if "chain" in params_df.columns else runs_df["chain"]

    mcmc_params_list, mcmc_tables_list = [], []
    for i_chain in chain_ids:
        # Slice the parameter table itself: the earlier version appended rows of mcmc_runs here, so the
        # returned "params" never contained any parameter values and the mcmc_params argument was ignored.
        mcmc_params_list.append(params_df[params_chain == i_chain])
        mcmc_tables_list.append(tables_df[tables_df["chain"] == i_chain])

    return mcmc_params_list, mcmc_tables_list

In [None]:
def calculate_r_hat(posterior_chains):
    """
    Calculate the R_hat statistic for a single parameter. The code below is intended to be compatible with chains of
    different lengths. This is why the calculations may look slightly different compared to what is found in classic
    textbooks.
    The chain ids may be any hashable values: they do not need to be the integers 0..m-1.
    :param posterior_chains: a dictionary, The keys are the chains ids and the values contain each chain's posterior
    sample (any sized iterable of numbers).
    :return: the R_hat statistic (float)
    :raises ZeroDivisionError: if fewer than two chains are provided, or if a chain contains a single sample
    """
    m = len(posterior_chains)

    # Compute within-chain means and chain lengths, keyed by the caller's chain ids
    means_per_chain = {j_chain: mean(x_j) for j_chain, x_j in posterior_chains.items()}
    chain_length = {j_chain: len(x_j) for j_chain, x_j in posterior_chains.items()}

    # Compute overall mean over the pooled samples of all chains
    pooled_values = [x for x_j in posterior_chains.values() for x in x_j]
    overall_mean = mean(pooled_values)

    # Compute between-chain variation (B / n)
    b_over_n = 1 / (m - 1) * sum((means_per_chain[j_chain] - overall_mean) ** 2 for j_chain in posterior_chains)

    # Compute within-chain variation (sum of squared deviations) for each chain
    variation = {
        j_chain: sum((x - means_per_chain[j_chain]) ** 2 for x in x_j)
        for j_chain, x_j in posterior_chains.items()
    }

    # Compute the average of within-chain variances (W)
    w = 1 / m * sum(1 / (chain_length[j_chain] - 1) * variation[j_chain] for j_chain in posterior_chains)

    # Calculate the marginal posterior variance
    var_hat = 1 / m * sum(1 / chain_length[j_chain] * variation[j_chain] for j_chain in posterior_chains) + b_over_n

    # Calculate R_hat
    r_hat = sqrt(var_hat / w)

    return r_hat

In [None]:
def calculate_r_hats(mcmc_params: List[pd.DataFrame], mcmc_tables: List[pd.DataFrame], burn_in: int):
    """
    Calculates the R_hat statistic for every column of the first parameter table.
    Only element 0 of mcmc_params and of mcmc_tables is read. For each chain id found in the 'chain' column of
    mcmc_tables[0], the parameter values of that chain are repeated according to the 'weight' column and the
    resulting per-chain samples are passed to calculate_r_hat.
    NOTE: this definition shadows the calculate_r_hats imported from autumn.tools.plots.calibration.plots at the
    top of the notebook.
    :param mcmc_params: a list of DataFrames with one column per parameter
    :param mcmc_tables: a list of DataFrames with 'chain' and 'weight' columns
    :param burn_in: accepted for interface compatibility; it is not applied by this function
    :return: a dictionary mapping each parameter name to its R_hat statistic
    """
    param_names = mcmc_params[0].columns.tolist()
    chain_ids = mcmc_tables[0].chain.unique()

    def _weighted_chain_sample(name, chain_id):
        # Rows of this chain, selected with the chain labels of the MCMC table
        in_chain = mcmc_tables[0].chain == chain_id
        values = mcmc_params[0][in_chain][name].to_list()
        repeats = mcmc_tables[0][in_chain].weight.to_list()
        # A row with weight k contributes k copies of its value
        return flatten_list([[values[k]] * repeats[k] for k in range(len(repeats))])

    return {
        name: calculate_r_hat({chain_id: _weighted_chain_sample(name, chain_id) for chain_id in chain_ids})
        for name in param_names
    }

In [None]:
# Compute R_hat for every parameter column, using the calculate_r_hats defined in this notebook
# (burn_in is passed for completeness; that function does not apply it).
r_hats = calculate_r_hats(mcmc_params, mcmc_tables, burn_in=0)

In [None]:
# Show the raw {parameter name: R_hat} dictionary
print(r_hats)

In [None]:
# Arrange the R_hat values as a one-column table: parameter names as the row index, column label 0
df = pd.DataFrame(data=r_hats, index=[0]).T

print(df)

# Save alongside the notebook's other outputs (the "csv_files" folder is created with the output
# directories above) instead of a user-specific absolute path, so the notebook runs on any machine.
r_hat_xlsx_path = os.path.join(base_dir, "csv_files", "r_hat_statistics.xlsx")
df.to_excel(r_hat_xlsx_path)