In [None]:
# Import packages
import os
from math import log

from matplotlib import pyplot as plt
import pandas as pd
import numpy as np

# Import AuTuMN modules
from autumn.settings import Models, Region
from autumn.settings.folders import OUTPUT_DATA_PATH
from autumn.tools.project import get_project
from autumn.tools import db
from autumn.tools.plots.calibration.plots import calculate_r_hats, get_output_from_run_id, get_posterior
from autumn.tools.plots.uncertainty.plots import _plot_uncertainty, _get_target_values
from autumn.tools.plots.plotter.base_plotter import COLOR_THEME
from autumn.tools.plots.utils import get_plot_text_dict, change_xaxis_to_date, REF_DATE, ALPHAS, COLORS, _apply_transparency, _plot_targets_to_axis, split_mcmc_outputs_by_chain

from autumn.calibration.utils import get_uncertainty_df
import yaml


In [None]:
# Model and region whose calibration outputs are analysed in this notebook
model = Models.COVID_19
region = Region.VICTORIA_2020

# Date-named output sub-directory of the calibration used for the LHS-started chains
# (LHS presumably = Latin hypercube sampling of starting points — confirm)
dirname_lhs = "2021-09-05"

# Date-named output sub-directory of the main calibration, and the burn-in value passed to
# get_posterior when summarising its posterior
dirname_main = "2021-09-04"
main_burn_in = 8000

In [None]:
# Resolve the project and the directory holding its calibration outputs
project = get_project(model, region)
project_calib_dir = os.path.join(
    OUTPUT_DATA_PATH,
    "calibrate",
    project.model_name,
    project.region_name,
)

# LHS-started calibration: MCMC run tables and parameter tables
calib_path = os.path.join(project_calib_dir, dirname_lhs)
mcmc_tables = db.load.load_mcmc_tables(calib_path)
mcmc_params = db.load.load_mcmc_params_tables(calib_path)

# Main calibration: used below for the posterior median / credible-interval summaries
main_calib_path = os.path.join(project_calib_dir, dirname_main)
main_mcmc_tables = db.load.load_mcmc_tables(main_calib_path)
main_mcmc_params = db.load.load_mcmc_params_tables(main_calib_path)

# Parameters to display: a hard-coded list rather than every name found in mcmc_params
param_names = [
    'victorian_clusters.metro.mobility.microdistancing.face_coverings_adjuster.parameters.effect',
    'sojourn.compartment_periods_calculated.active.total_period',
    'contact_rate',
    'victorian_clusters.intercluster_mixing',
    'infectious_seed',
    'infection_fatality.top_bracket_overwrite',
    'clinical_stratification.props.hospital.multiplier',
    'testing_to_detection.assumed_cdr_parameter',
    'sojourn.compartment_periods.icu_early',
    'victorian_clusters.metro.mobility.microdistancing.behaviour_adjuster.parameters.effect',
]

# Parameter traces of the LHS-started chains

In [None]:
# Summarise each parameter's main-analysis posterior by its median and the bounds of its
# central 95% credible interval (2.5% and 97.5% quantiles).
medians, lower, upper = {}, {}, {}
for param_name in param_names:
    posterior_values = get_posterior(
        main_mcmc_params, main_mcmc_tables, param_name, burn_in=main_burn_in
    )
    lower[param_name], medians[param_name], upper[param_name] = np.quantile(
        posterior_values, [0.025, 0.5, 0.975]
    )

In [None]:
# Trace of every parameter for each LHS-started chain (no burn-in), overlaid with the main
# analysis posterior median (solid black) and 95% credible-interval bounds (dotted black).
n_rows, n_cols = 5, 2
assert len(param_names) <= n_rows * n_cols, "not enough panels for all parameters"

# Apply the style before creating the figure: a style set afterwards does not restyle axes
# that already exist, so a fresh-kernel run would look different from a re-run.
plt.style.use("ggplot")
fig, axes = plt.subplots(n_rows, n_cols, sharex=True, sharey=False, figsize=(15, 18))

mcmc_params_list, mcmc_tables_list = split_mcmc_outputs_by_chain(mcmc_params, mcmc_tables)

h_color = "black"
for axis, param_name in zip(axes.ravel(), param_names):
    chain_lengths = []
    for chain_params, chain_table in zip(mcmc_params_list, mcmc_tables_list):
        param_values = get_posterior([chain_params], [chain_table], param_name, burn_in=0)
        axis.plot(param_values, alpha=0.8, linewidth=0.5)
        chain_lengths.append(len(param_values))

    # Span the reference lines over the longest chain, not just the last one plotted.
    x_max = max(chain_lengths, default=0)
    axis.hlines(y=medians[param_name], xmin=0, xmax=x_max, zorder=100, color=h_color, linestyle="solid")
    axis.hlines(y=lower[param_name], xmin=0, xmax=x_max, zorder=100, color=h_color, linestyle="dotted")
    axis.hlines(y=upper[param_name], xmin=0, xmax=x_max, zorder=100, color=h_color, linestyle="dotted")

    axis.set_ylabel(get_plot_text_dict(param_name), fontsize=15)

plt.tight_layout()
plt.savefig("lhs_start_traces_median.png", dpi=150)
plt.savefig("lhs_start_traces_median.pdf")

# Log-likelihood vs parameter values

In [None]:
def plot_param_vs_loglike(mcmc_tables, mcmc_params, param_name, burn_in, axis, posterior=False, colors=None):
    """
    Scatter-plot the transformed log-likelihood of accepted runs against one parameter's value.

    Each log-likelihood value v is displayed as -log(max_ll - v), where max_ll is the largest
    value over ALL plotted runs (every table) plus one. The best run therefore maps to 0, worse
    runs are negative, and all chains share one scale. The first retained point of each chain
    (in table order) is additionally marked with a violet star.

    Args:
        mcmc_tables: list of MCMC run tables; each needs "run", "chain", "accept" and the
            log-likelihood column selected by `posterior`.
        mcmc_params: list of parameter tables aligned with `mcmc_tables`; each needs "run",
            "chain", "name" and "value".
        param_name: name of the parameter to plot.
        burn_in: only runs whose "run" index is strictly greater than this are plotted.
        axis: matplotlib axis to draw on.
        posterior: plot the "ap_loglikelihood" column instead of "loglikelihood" when True.
        colors: sequence indexed by chain id giving each chain's colour; defaults to COLOR_THEME.

    Returns:
        The lowest transformed value plotted, for use as a lower y-axis limit.

    Raises:
        ValueError: if no accepted run of `param_name` remains after the burn-in.
    """
    if colors is None:
        colors = COLOR_THEME
    var_key = "ap_loglikelihood" if posterior else "loglikelihood"

    # Filter every table first so the transformation uses one shared reference value.
    selected_tables = []
    for mcmc_df, param_df in zip(mcmc_tables, mcmc_params):
        merged_df = param_df.merge(mcmc_df, on=["run", "chain"])
        mask = (merged_df["accept"] == 1) & (merged_df["name"] == param_name) & (merged_df["run"] > burn_in)
        if mask.any():
            selected_tables.append(merged_df[mask])

    if not selected_tables:
        raise ValueError(f"No accepted runs after burn-in found for parameter '{param_name}'")

    # The +1 keeps the argument of log strictly positive for the best run.
    max_loglike = max(table[var_key].max() for table in selected_tables) + 1
    min_loglike = min(table[var_key].min() for table in selected_tables)

    for table in selected_tables:
        for chain_id in table["chain"].unique():
            chain_df = table[table["chain"] == chain_id]
            param_values = list(chain_df["value"])
            trans_loglikelihood_values = [-log(max_loglike - v) for v in chain_df[var_key]]

            axis.plot(param_values[0], trans_loglikelihood_values[0], "*", color="violet", markersize=15)
            axis.plot(param_values, trans_loglikelihood_values, ".", color=colors[chain_id], markersize=5)

    # Returned after ALL tables are drawn (previously only the first table was processed).
    return -log(max_loglike - min_loglike)

In [None]:
# One panel per parameter: transformed log-likelihood of the LHS-started runs against the
# parameter value. Panels share the y-axis; unused grid cells are hidden.
n_rows, n_cols = 4, 3
assert len(param_names) <= n_rows * n_cols, "not enough panels for all parameters"

# Apply the style before creating the figure so existing axes are not left unstyled.
plt.style.use("ggplot")
fig, axes = plt.subplots(n_rows, n_cols, sharex=False, sharey=True, figsize=(15, 18))
axes_flat = axes.ravel()

min_lls = []
for i_param, param_name in enumerate(param_names):
    axis = axes_flat[i_param]
    min_lls.append(plot_param_vs_loglike(mcmc_tables, mcmc_params, param_name, 0, axis, posterior=False))

    axis.set_title(get_plot_text_dict(param_name))
    if i_param % n_cols == 0:  # label only the left-hand column
        axis.set_ylabel("likelihood (transformed)")

# With sharey=True one limit applies to every panel, so use the lowest value over all parameters
# (setting it per axis would let the last parameter's limit override all others).
if min_lls:
    axes_flat[0].set_ylim((min(min_lls) - 2, 0.5))

# Hide grid cells left over when there are fewer parameters than panels.
for axis in axes_flat[len(param_names):]:
    axis.set_axis_off()

plt.tight_layout()
plt.savefig("likelihood_against_params.png", dpi=150)
plt.savefig("likelihood_against_params.pdf")