In [None]:
import arviz as az
import numpy as np
import pandas as pd
from emutools.calibration import plot_posterior_comparison
from emutools.utils import load_param_info
from estival.sampling.tools import idata_to_sampleiterator
from aust_covid.inputs import get_ifrs
from aust_covid.model import build_model
from inputs.constants import PRIMARY_ANALYSIS, RUNS_PATH, RUN_IDS, BURN_IN, OUTPUTS_PATH
from emutools.tex import DummyTexDoc
from estival.sampling import tools as esamp
from emutools.calibration import plot_output_ranges
from aust_covid.plotting import plot_infection_processes
from aust_covid.calibration import get_priors, get_targets
from arviz.labels import MapLabeller
import matplotlib as mpl
from matplotlib import pyplot as plt
from emutools.calibration import get_like_components, plot_like_components_by_analysis

In [None]:
# Parameter metadata table; the abbreviations column is used for plot/axis labels.
param_info = load_param_info()
abbreviations = param_info['abbreviations'].to_dict()

# Baseline parameter values merged with the output of get_ifrs
# (presumably infection fatality rates — confirm); get_ifrs values win on key clashes.
base_values = param_info['value'].to_dict()
parameters = base_values | get_ifrs(DummyTexDoc())

# The first get_priors argument is True only for the 'vacc' and 'both' analyses
# (presumably a vaccination flag — confirm in get_priors).
uses_vacc = PRIMARY_ANALYSIS in ('vacc', 'both')
priors = get_priors(uses_vacc, param_info['abbreviations'], DummyTexDoc())
prior_names = [prior.name for prior in priors]

# Calibration targets, used later when plotting output ranges.
targets = get_targets(DummyTexDoc())

In [None]:
# Likelihood components to compare across analyses. Each value is passed to the plotting
# helper as plot_requests — presumably (x-min, x-max, panel title); confirm in emutools.
requests = {
    'loglikelihood': (-28.0, -12.0, 'total likelihood'),
    'll_adult_seropos_prop': (-3.0, 2.0, 'seroprevalence contribution'),
    'll_deaths_ma': (-10.0, -3.0, 'deaths contribution'),
    'll_notifications_ma': (-17.0, -10.0, 'cases contribution'),
}
like_outputs = get_like_components(requests.keys())

like_comparison_fig = plot_like_components_by_analysis(
    like_outputs,
    'kdeplot',
    plot_requests=requests,
    alpha=0.1,
    linewidth=1.5,
)
like_comparison_fig.savefig(OUTPUTS_PATH / 'aust_covid_figure2.pdf')
# Close the current matplotlib figure once it has been written to disk.
plt.close()

In [None]:
# Build the model and run it once at the highest-log-posterior sample of the primary analysis.
epi_model = build_model(DummyTexDoc(), abbreviations, mobility_ext=True, cross_ref=False)
analysis_folder = RUN_IDS[PRIMARY_ANALYSIS]
run_output_dir = RUNS_PATH / analysis_folder / 'output'

# Index label of the stored sample with the largest log-posterior.
likelihood_df = pd.read_hdf(run_output_dir / 'results.hdf', 'likelihood')
i_max = likelihood_df['logposterior'].idxmax()

idata = az.from_netcdf(run_output_dir / 'calib_full_out.nc')
# NOTE(review): assumes the likelihood table and the sample-iterator frame share the same
# index (presumably (chain, draw)), so i_max selects the matching parameter set — confirm.
best_params = idata_to_sampleiterator(idata).convert('pandas').loc[i_max].to_dict()

# Overlay the best-fit values on a new dict instead of updating `parameters` in place,
# so the baseline parameter set from the setup cell is not silently replaced.
best_run_params = parameters | best_params
epi_model.run(parameters=best_run_params)

In [None]:
# Quantiles used to summarise the sampled model outputs of every analysis.
quantiles = [0.025, 0.25, 0.5, 0.75, 0.975]

# Load the stored output trajectories ('spaghetti' table) of each analysis run.
spaghettis = {}
for analysis, run_id in RUN_IDS.items():
    spaghettis[analysis] = pd.read_hdf(RUNS_PATH / run_id / 'output/results.hdf', 'spaghetti')

quantile_outputs = {
    analysis: esamp.quantiles_for_results(spaghetti, quantiles)
    for analysis, spaghetti in spaghettis.items()
}

# Output ranges for the primary analysis, written as figure 3.
outputs = ['notifications_ma', 'deaths_ma', 'adult_seropos_prop', 'reproduction_number']
base_analysis_ranges = plot_output_ranges(quantile_outputs, targets, outputs, PRIMARY_ANALYSIS, quantiles)
base_analysis_ranges.write_image(OUTPUTS_PATH / 'aust_covid_figure3.pdf')

In [None]:
# Infection-process figure from the best-fit model run, written as figure 4 without a legend.
derived_outputs = epi_model.get_derived_outputs_df()
infection_fig = plot_infection_processes(derived_outputs, targets, 'notifications_ma')
infection_fig.update_layout(showlegend=False)
infection_fig.write_image(OUTPUTS_PATH / 'aust_covid_figure4.pdf')

In [None]:
# Reload the full calibration output and keep only draws labelled BURN_IN onwards
# (equivalent to dropping the first BURN_IN draws when draws are labelled 0..N-1).
# Reloading, rather than slicing the existing object, keeps this cell safe to re-run.
calib_path = RUNS_PATH / analysis_folder / 'output/calib_full_out.nc'
idata = az.from_netcdf(calib_path).sel(draw=slice(BURN_IN, None))

In [None]:
# Prior vs posterior comparison for the parameters in prior_names, on a 5 x 4 panel grid.
# NOTE(review): the meaning of this value is defined by plot_posterior_comparison
# (presumably a quantile cut-off for the plotted range) — confirm.
comparison_cutoff = 0.995
comp_fig = plot_posterior_comparison(
    idata, priors, prior_names, abbreviations, comparison_cutoff, grid=[5, 4]
)
comp_fig.savefig(OUTPUTS_PATH / 'aust_covid_figure5.pdf')

In [None]:
# Pairwise posterior KDE plot of the key parameters, written as figure 6.
key_params = [
    'contact_rate',
    'latent_period',
    'infectious_period',
    'natural_immunity_period',
    'start_cdr',
    'imm_infect_protect',
    'ba2_escape',
    'ba5_escape',
    'imm_prop',
]
# Axis range imposed on the natural immunity period panels.
immunity_period_lims = (0.0, 300.0)
immunity_idx = key_params.index('natural_immunity_period')

# Scope the style overrides to this figure rather than mutating the global rcParams,
# so figures produced afterwards (e.g. when re-running earlier cells) keep default styling.
# max_subplots is raised because the grid for this many variables exceeds arviz's default limit.
with az.rc_context(rc={'plot.max_subplots': 200}), mpl.rc_context({'axes.facecolor': (0.2, 0.2, 0.4)}):
    pair_axes = az.plot_pair(
        idata,
        var_names=key_params,
        kind='kde',
        textsize=30,
        labeller=MapLabeller(var_name_map=abbreviations),
    )
    # NOTE(review): assumes the arviz grid puts variable i + 1 on row i and variable j on
    # column j, so these are panels where natural_immunity_period is the y- and x-variable.
    # For the list above they are pair_axes[2][0] and pair_axes[3][3].
    pair_axes[immunity_idx - 1][0].set_ylim(immunity_period_lims)
    pair_axes[immunity_idx][immunity_idx].set_xlim(immunity_period_lims)

    # Save and close the pair-plot figure explicitly instead of relying on the "current" figure.
    pair_fig = pair_axes[immunity_idx - 1][0].get_figure()
    pair_fig.savefig(OUTPUTS_PATH / 'aust_covid_figure6.pdf')
plt.close(pair_fig)