In [None]:
import arviz as az
import numpy as np
import pandas as pd
from plotly import graph_objects as go
from emutools.utils import load_param_info
from inputs.constants import SUPPLEMENT_PATH, RUN_IDS, RUNS_PATH, PRIMARY_ANALYSIS, BURN_IN

In [None]:
import seaborn as sns

In [None]:
# Load the primary analysis's full calibration output, then keep only draws
# labelled BURN_IN onwards (label-based selection on the ``draw`` coordinate).
idata = az.from_netcdf(
    RUNS_PATH / RUN_IDS[PRIMARY_ANALYSIS] / 'output/calib_full_out.nc'
).sel(draw=slice(BURN_IN, None))

In [None]:
# Flatten the posterior into a DataFrame and discard every column whose name
# contains '_dispersion', so the plots below only see the remaining parameters.
post_df = idata.posterior.to_dataframe()
post_df = post_df.drop(columns=[col for col in post_df.columns if '_dispersion' in col])

In [None]:
# Parameters shown together in the `sns.pairplot(post_df[key_params])` cell.
key_params = [
    'contact_rate', 'latent_period', 'infectious_period',
    'natural_immunity_period', 'start_cdr', 'imm_infect_protect',
    'ifr_adjuster', 'ba2_escape', 'ba5_escape',
    'ba2_rel_ifr', 'imm_prop',
]

In [None]:
# Immunity-related subset of the key parameters (going by the names), given
# its own pair plot in the final cell.
imm_params = [
    'natural_immunity_period', 'imm_infect_protect',
    'ba2_escape', 'ba5_escape', 'imm_prop',
]

In [None]:
def get_bin_centres(bins):
    """Return the midpoint of each histogram bin given its edges.

    Args:
        bins: 1-D sequence of bin edges (length ``n_bins + 1``), e.g. an edge
            array returned by ``np.histogram`` / ``np.histogram2d``.

    Returns:
        1-D float array of length ``len(bins) - 1`` with the centre of each bin
        (empty when fewer than two edges are given).
    """
    edges = np.asarray(bins, dtype=float)
    # Average each adjacent pair of edges, which is correct for non-uniform
    # bin widths as well as uniform ones (no assumption that every bin has the
    # width of the first).
    return (edges[:-1] + edges[1:]) / 2

def get_hist_df_from_params(data, param_1, param_2, bins):
    """Tabulate a 2-D histogram of two columns of ``data`` as a DataFrame.

    Args:
        data: object indexable by column name (the posterior DataFrame here).
        param_1: column binned along the returned frame's index.
        param_2: column binned along the returned frame's columns.
        bins: forwarded unchanged to ``np.histogram2d``.

    Returns:
        DataFrame of bin counts whose index holds the ``param_1`` bin centres
        and whose columns hold the ``param_2`` bin centres.
    """
    counts, edges_1, edges_2 = np.histogram2d(data[param_1], data[param_2], bins=bins)
    return pd.DataFrame(
        counts,
        index=get_bin_centres(edges_1),
        columns=get_bin_centres(edges_2),
    )

def plot_3d_param_hist(param_1, param_2, abbreviations, data=None, bins=50):
    """Plot the joint histogram of two posterior parameters as a 3-D surface.

    Args:
        param_1: parameter shown on the x axis.
        param_2: parameter shown on the y axis.
        abbreviations: mapping (dict or pandas Series) from parameter name to
            axis label; the raw parameter name is used if an entry is missing.
        data: samples to histogram; defaults to the notebook-level ``post_df``.
        bins: forwarded to ``np.histogram2d`` (default 50).

    Returns:
        The plotly ``Figure``; it renders when it is a cell's last expression.
    """
    if data is None:
        data = post_df
    hist_df = get_hist_df_from_params(data, param_1, param_2, bins)

    # plotly's Surface reads z as z[row][col] = value at (x[col], y[row]),
    # whereas hist_df has param_1 along its rows, so transpose it. Passing the
    # bin centres as x/y makes the axes show parameter values, not bin indices.
    surface = go.Surface(
        z=hist_df.T.to_numpy(),
        x=hist_df.index.to_numpy(),
        y=hist_df.columns.to_numpy(),
    )
    fig = go.Figure(data=[surface])

    scene = {
        'xaxis': {'title': abbreviations.get(param_1, param_1)},
        'yaxis': {'title': abbreviations.get(param_2, param_2)},
        # np.histogram2d is called without density=True, so z holds raw counts.
        'zaxis': {'title': 'count', 'range': (0.0, hist_df.max().max() * 1.5)},
    }
    margin = {side: 25 for side in ('t', 'b', 'l', 'r')}
    return fig.update_layout(height=800, scene=scene, margin=margin)

In [None]:
# Joint posterior surface of `natural_immunity_period` against `ba2_escape`,
# with axis labels taken from the project's parameter abbreviations.
param_1, param_2 = 'natural_immunity_period', 'ba2_escape'
param_abbreviations = load_param_info()['abbreviations']
plot_3d_param_hist(param_1, param_2, param_abbreviations)

In [None]:
# Scatter pair plot of the posterior, excluding *_dispersion variables to match
# the filter applied to post_df; ';' hides the returned axes-array repr.
az.plot_pair(idata, var_names=['~_dispersion'], filter_vars='like');

In [None]:
# Hexbin pair plot of the posterior, excluding *_dispersion variables to match
# the filter applied to post_df; ';' hides the returned axes-array repr.
az.plot_pair(idata, var_names=['~_dispersion'], filter_vars='like', kind='hexbin');

In [None]:
# KDE pair plot of the posterior, excluding *_dispersion variables to match
# the filter applied to post_df; ';' hides the returned axes-array repr.
az.plot_pair(idata, var_names=['~_dispersion'], filter_vars='like', kind='kde');

In [None]:
# Pair plot of every non-dispersion posterior column; ';' suppresses the
# PairGrid repr that would otherwise be echoed under the figure.
sns.pairplot(post_df);

In [None]:
# Pair plot restricted to key_params; ';' suppresses the PairGrid repr.
sns.pairplot(post_df[key_params]);

In [None]:
# Pair plot restricted to imm_params; ';' suppresses the PairGrid repr.
sns.pairplot(post_df[imm_params]);