In [None]:
import arviz as az
import numpy as np
import pandas as pd
from plotly import graph_objects as go
from emutools.utils import load_param_info
from arviz.labels import MapLabeller
import matplotlib as mpl
from inputs.constants import SUPPLEMENT_PATH, RUN_IDS, RUNS_PATH, PRIMARY_ANALYSIS, BURN_IN

# Plot configuration shared by every figure in this notebook.
# Panel background for all matplotlib axes, as an RGB tuple in [0, 1] (dark blue).
mpl.rcParams['axes.facecolor'] = (0.2, 0.2, 0.4)
# ArviZ caps the number of subplots per figure; the pair plots below cover many
# parameters, so the cap is set to 200 to let the full grids render.
az.rcParams['plot.max_subplots'] = 200

In [None]:
# Load the full calibration output of the primary analysis run.
idata_path = RUNS_PATH / RUN_IDS[PRIMARY_ANALYSIS] / 'output/calib_full_out.nc'
idata_full = az.from_netcdf(idata_path)

# Drop the first BURN_IN draws. Guard against a burn-in that would discard every
# draw: `sel` would silently yield an empty posterior and the plots would fail later
# with an unrelated-looking error.
n_draws_total = idata_full.posterior.sizes['draw']
assert BURN_IN < n_draws_total, (
    f'BURN_IN ({BURN_IN}) leaves no draws: {idata_path} has only {n_draws_total}'
)
idata = idata_full.sel(draw=np.s_[BURN_IN:])

# Mapping from parameter name to display label, passed to the plot labellers.
abbreviations = load_param_info()['abbreviations']

In [None]:
# Parameter groups for the pair plots; every name must be a variable in idata.posterior.

# Every posterior variable except the `*_dispersion` ones (presumably likelihood
# noise terms rather than epidemiological quantities -- confirm).
epi_params = [param for param in idata.posterior.keys() if '_dispersion' not in param]

# Headline model parameters.
key_params = [
    'contact_rate',
    'latent_period',
    'infectious_period',
    'natural_immunity_period',
    'start_cdr',
    'imm_infect_protect',
    'ba2_escape',
    'ba5_escape',
    'imm_prop',
]

# Immunity-related subset of key_params.
imm_params = [
    'natural_immunity_period',
    'imm_infect_protect',
    'ba2_escape',
    'ba5_escape',
    'imm_prop',
]

# Parameters whose joint posteriors are inspected together (presumably those
# expected to be correlated -- confirm). Unlike the groups above, this one also
# includes ifr_adjuster and ba5_seed_time.
correlated_params = [
    'contact_rate',
    'infectious_period',
    'imm_prop',
    'start_cdr',
    'natural_immunity_period',
    'ifr_adjuster',
    'ba2_escape',
    'ba5_escape',
    'ba5_seed_time',
]

# Fail early with a readable message if a hard-coded name is absent from this
# run's posterior (typo or renamed parameter), instead of deep inside az.plot_pair.
missing_params = sorted(
    set(key_params + imm_params + correlated_params) - set(idata.posterior.keys())
)
assert not missing_params, f'Parameters missing from posterior: {missing_params}'

In [None]:
# Pairwise KDE plot of every parameter in epi_params, labelled with the abbreviations.
# The trailing semicolon suppresses display of plot_pair's return value.
fig = az.plot_pair(
    idata,
    var_names=epi_params,
    kind='kde',
    labeller=MapLabeller(var_name_map=abbreviations),
    textsize=35,
);

In [None]:
# Pairwise KDE plot of the headline parameters in key_params.
# The trailing semicolon suppresses display of plot_pair's return value.
fig = az.plot_pair(
    idata,
    var_names=key_params,
    kind='kde',
    labeller=MapLabeller(var_name_map=abbreviations),
    textsize=30,
);

In [None]:
# Pairwise KDE plot of the immunity-related parameters in imm_params.
# The grid is smaller than the others, hence the smaller text size.
fig = az.plot_pair(
    idata,
    var_names=imm_params,
    kind='kde',
    labeller=MapLabeller(var_name_map=abbreviations),
    textsize=20,
);

In [None]:
# Pairwise KDE plot of the parameters grouped in correlated_params.
# The trailing semicolon suppresses display of plot_pair's return value.
fig = az.plot_pair(
    idata,
    var_names=correlated_params,
    kind='kde',
    labeller=MapLabeller(var_name_map=abbreviations),
    textsize=30,
);