In [None]:
from matplotlib import pyplot as plt
import os
from math import ceil
from summer.utils import ref_times_to_dti

from autumn.core.project import get_project, load_timeseries
from autumn.core.plots.utils import REF_DATE
from autumn.core.runs.managed import ManagedRun

import numpy as np
from random import sample
from scipy.stats import truncnorm
from itertools import combinations

## Specify the run id and the outputs to plot

In [None]:
# Identifier of the calibration run to analyse. Its first two "/"-separated
# parts are read as the model and region names when the run is loaded.
run_id = "hierarchical_sir/multi/1653962883/0e2003c"

# MCMC rows whose "run" value is below this threshold are dropped as burn-in.
BURN_IN = 5000

# Derived outputs drawn trajectory-by-trajectory in the individual-run fits.
outputs_for_full_run_plots = ("incidence_AUS", "incidence_ITA")

# Derived outputs and scenario indices drawn with uncertainty bands.
outputs_to_plot_with_uncertainty = ("incidence_AUS", "incidence_ITA")
scenarios_to_plot_with_uncertainty = [0]

In [None]:
# Optional display labels: raw parameter name -> label shown on parameter plots.
# A name without an entry is displayed unchanged.
param_lookup = dict()

# Optional display labels: derived output name -> title shown on output plots.
# A name without an entry is displayed unchanged.
title_lookup = dict()

## Load run outputs and pre-process data

In [None]:
# The run id starts with "<model>/<region>/..."; only those two parts are used.
model, region = run_id.split("/")[:2]

# Handle on the stored run; everything below is read through it.
mr = ManagedRun(run_id)

# Derived outputs of the individual full model runs.
full_run = mr.full_run.get_derived_outputs()

# Targets and uncertainty quantiles are both read from the PowerBI database.
pbi = mr.powerbi.get_db()
targets = pbi.get_targets()
results = pbi.get_uncertainty()

# Calibration tables: sampled parameter values and per-run metadata.
mcmc_params = mr.calibration.get_mcmc_params()
mcmc_runs = mr.calibration.get_mcmc_runs()

In [None]:
# Project definition for this model/region (requested with reload=True).
project = get_project(model, region, reload=True)

# Names of the calibrated parameters and the identifiers of the MCMC chains.
params_list = mcmc_params.columns.to_list()
chains = mcmc_runs["chain"].unique()

# Parameter values joined with the run metadata on "urun".
mcmc_table = mcmc_params.merge(mcmc_runs, on="urun")

# Keep only the rows whose "run" value has reached the burn-in threshold.
is_post_burnin = mcmc_runs["run"] >= BURN_IN
post_burnin_uruns = mcmc_runs[is_post_burnin].index
post_burnin_mcmc_table = mcmc_table.filter(items=post_burnin_uruns, axis=0)

## Plot posteriors

In [None]:
plt.style.use("ggplot")

# Grid with three panels per row, one panel per calibrated parameter.
n_params = len(params_list)
n_col = 3
n_row = ceil(n_params / n_col)
fig, axes = plt.subplots(n_row, n_col, figsize=(n_col * 4, n_row * 3.5))
flat_axes = axes.reshape(-1)

# Panels left over once every parameter has a slot are hidden.
for spare_axis in flat_axes[n_params:]:
    spare_axis.set_visible(False)

# Weighted histogram of the post-burn-in samples of each parameter.
for param, axis in zip(mcmc_params.columns, flat_axes):
    axis.hist(post_burnin_mcmc_table[param], weights=post_burnin_mcmc_table["weight"])

    # Fall back to the raw parameter name when no display label is registered.
    par_name = param_lookup.get(param, param)
    axis.set_title(par_name)
    axis.set_xlabel(par_name)

fig.suptitle("parameter posterior histograms", fontsize=15, y=1)
fig.tight_layout()

## Plot traces

In [None]:
# Chains left out of the trace plots.
# NOTE(review): chain 6 was previously skipped through a hard-coded check with no
# recorded reason; confirm it should still be excluded (use an empty tuple to plot all).
EXCLUDED_TRACE_CHAINS = (6,)

# Computed here so that this cell does not depend on the posterior-histogram cell.
n_params = len(params_list)
n_col = 2
n_row = ceil(n_params / n_col)
fig, axes = plt.subplots(n_row, n_col, figsize=(8 * n_col, 4 * n_row))

# Split the table by chain once instead of re-filtering it for every parameter.
chain_tables = {
    chain: mcmc_table[mcmc_table["chain"] == chain]
    for chain in chains
    if chain not in EXCLUDED_TRACE_CHAINS
}

for i_ax, axis in enumerate(axes.reshape(-1)):
    if i_ax >= n_params:
        # Spare panel in the last row of the grid.
        axis.set_visible(False)
        continue

    param = params_list[i_ax]
    # One thin line per chain: parameter value against the "run" column.
    for chain_rows in chain_tables.values():
        axis.plot(chain_rows["run"], chain_rows[param], lw=0.5)

    axis.set_title(param_lookup.get(param, param))
    axis.set_xlabel("run")

fig.tight_layout()

# Model outputs

### Load all targets, including those not used in calibration, for validation

In [None]:
# Every time series stored in the project's timeseries.json file.
timeseries_path = os.path.join(project.get_path(), "timeseries.json")
all_targets = load_timeseries(timeseries_path)

# Re-index each series in place with ref_times_to_dti(REF_DATE, ...), the same
# conversion that is applied to model times before plotting.
for target_name in all_targets:
    target_series = all_targets[target_name]
    target_series.index = ref_times_to_dti(REF_DATE, target_series.index)

### Calibration fits with individual model runs

In [None]:
n_outputs = len(outputs_for_full_run_plots)
n_col = 2
n_row = ceil(n_outputs / n_col)
fig, axes = plt.subplots(n_row, n_col, figsize=(n_col * 8, n_row * 6), sharex="all")

# Only baseline (scenario 0) trajectories of chain 0 are drawn. The selection does
# not depend on the output, so the rows are filtered and split by run once, and the
# model times of each run are converted to dates once.
# sort=False keeps the runs in order of first appearance.
baseline_chain0 = full_run[(full_run["scenario"] == 0) & (full_run["chain"] == 0)]
run_trajectories = [
    (ref_times_to_dti(REF_DATE, run_rows["times"]), run_rows)
    for run_number, run_rows in baseline_chain0.groupby("run", sort=False)
]

for i_ax, axis in enumerate(axes.reshape(-1)):
    if i_ax >= n_outputs:
        # Spare panel in the last row of the grid.
        axis.set_visible(False)
        continue

    output = outputs_for_full_run_plots[i_ax]

    # One line per sampled model run.
    for run_dates, run_rows in run_trajectories:
        axis.plot(run_dates, run_rows[output])

    # All available data for this output, whether or not it was calibrated to.
    if output in all_targets and len(all_targets[output]) > 0:
        all_targets[output].plot.line(ax=axis, linewidth=0., markersize=10., marker="o")
        axis.scatter(all_targets[output].index, all_targets[output], color="k", s=5, alpha=0.5, zorder=10)

    # Target values stored with the run, highlighted in red.
    if output in targets:
        axis.scatter(targets.index, targets[output], facecolors="r", edgecolors="k", s=15, zorder=10)

    axis.tick_params(axis="x", labelrotation=45)
    axis.set_title(title_lookup.get(output, output))

fig.tight_layout()

### Calibration fits with uncertainty

In [None]:
# RGB colour per scenario index (index 0 is the baseline).
# NOTE(review): entries 4-6 repeat entries 1-3, so those scenarios share colours.
colours = ((0.2, 0.2, 0.8), (0.8, 0.2, 0.2), (0.2, 0.8, 0.2), (0.8, 0.8, 0.2), (0.8, 0.2, 0.2), (0.2, 0.8, 0.2), (0.8, 0.8, 0.2))


def plot_outputs_with_uncertainty(outputs, scenarios=(0,)):
    """Plot quantile bands for each requested output, overlaid with target data.

    One panel is drawn per output, two panels per row. For every scenario the
    0.025-0.975 and 0.25-0.75 quantile ranges are shaded and the 0.5 quantile is
    drawn as a line. Targets are then drawn once per panel: everything found in
    ``all_targets`` as faint white markers and the entries of ``targets`` as
    black markers. Only the first panel carries a legend.

    The function reads the notebook-level names ``results``, ``all_targets``,
    ``targets``, ``project``, ``title_lookup`` and ``colours``.

    Args:
        outputs: Sequence of output names; ``results[(output, scenario)]`` must
            exist for every requested combination.
        scenarios: Iterable of scenario indices. Index 0 is labelled "baseline";
            any other index is labelled with
            ``project.param_set.scenarios[scenario - 1]["description"]``.

    Returns:
        None. The figure is created as a side effect.
    """
    n_outputs = len(outputs)
    n_col = 2
    n_row = ceil(n_outputs / n_col)
    fig, axes = plt.subplots(n_row, n_col, figsize=(n_col * 8, n_row * 6), sharex="all")
    for i_ax, axis in enumerate(axes.reshape(-1)):
        if i_ax >= n_outputs:
            # Spare panel in the last row of the grid.
            axis.set_visible(False)
            continue

        output = outputs[i_ax]
        for scenario in scenarios:
            # Wrap around the palette instead of failing on high scenario indices.
            colour = colours[scenario % len(colours)]
            results_df = results[(output, scenario)]
            indices = results_df.index
            interval_label = "baseline" if scenario == 0 else project.param_set.scenarios[scenario - 1]["description"]
            # The baseline bands are stacked above the bands of the other scenarios.
            scenario_zorder = 10 if scenario == 0 else scenario

            # Outer band: 0.025 to 0.975 quantiles (kept out of the legend).
            axis.fill_between(
                indices,
                results_df[0.025], results_df[0.975],
                color=colour,
                alpha=0.5,
                label="_nolegend_",
                zorder=scenario_zorder,
            )
            # Inner band: 0.25 to 0.75 quantiles; this one carries the legend entry.
            axis.fill_between(
                indices,
                results_df[0.25], results_df[0.75],
                color=colour, alpha=0.7,
                label=interval_label,
                zorder=scenario_zorder,
            )
            # NOTE(review): no zorder is given here, so the median line uses
            # matplotlib's default and may be drawn beneath the bands - confirm intent.
            axis.plot(indices, results_df[0.500], color=colour)

        # Targets and panel decoration do not depend on the scenario, so they are
        # drawn once per panel rather than once per scenario (which would stack
        # the semi-transparent markers on top of each other).
        if output in all_targets and len(all_targets[output]) > 0:
            all_targets[output].plot.line(
                ax=axis,
                linewidth=0.,
                markersize=8.,
                marker="o",
                markerfacecolor="w",
                markeredgecolor="w",
                alpha=0.2,
                label="_nolegend_",
                zorder=11,
            )
        if output in targets:
            targets[output].plot.line(
                ax=axis,
                linewidth=0.,
                markersize=5.,
                marker="o",
                markerfacecolor="k",
                markeredgecolor="k",
                label="_nolegend_",
                zorder=12,
            )
        axis.tick_params(axis="x", labelrotation=45)
        axis.set_title(title_lookup.get(output, output))

        if i_ax == 0:
            axis.legend()
    fig.tight_layout()

In [None]:
# Uncertainty plots for the outputs and scenarios chosen in the configuration cell.
plot_outputs_with_uncertainty(
    outputs=outputs_to_plot_with_uncertainty,
    scenarios=scenarios_to_plot_with_uncertainty,
)

# Exploration of the hierarchical parameterisation

In [None]:
# Build a run identifier of the form "<chain, 2 digits>_<run, 6 digits>" for every
# row of the full-run table, then list the distinct identifiers.
chain_label = full_run["chain"].astype(str).str.zfill(2)
run_label = full_run["run"].astype(str).str.zfill(6)
full_run["urun"] = chain_label + "_" + run_label
uruns = full_run["urun"].unique().tolist()

In [None]:
# Range of the log-likelihood and maximum of the "ap_loglikelihood" column over
# the post-burn-in rows.
post_burnin_ll = post_burnin_mcmc_table["loglikelihood"]
min_ll = post_burnin_ll.min()
max_ll = post_burnin_ll.max()
max_ap_ll = post_burnin_mcmc_table["ap_loglikelihood"].max()

In [None]:
def plot_hierarchical_outputs(urun, title=None):
    """Summarise one MCMC run: incidence fits, hierarchical prior and scores.

    Draws a 2x2 figure:

    * top row: modelled incidence for AUS and ITA (baseline scenario) together
      with the data found in ``all_targets``;
    * bottom left: density of the truncated-normal prior defined by the run's
      ``hyper_beta_mean`` and ``hyper_beta_sd``, with the run's ``beta.AUS``
      (red) and ``beta.ITA`` (green) marked on the x axis;
    * bottom right: a fitting score and the summed log-prior of the two betas.

    The fitting score is the run's log-likelihood rescaled to [0, 1] between
    ``min_ll`` and ``max_ll``, raised to the power 4 and shown as a percentage.

    The function reads the notebook-level names ``post_burnin_mcmc_table``,
    ``full_run``, ``all_targets``, ``min_ll`` and ``max_ll``.

    Args:
        urun: Run identifier; must be a label of ``post_burnin_mcmc_table`` and
            is matched against ``full_run["urun"]``.
        title: Optional figure title. Defaults to ``"Run id: <urun>"``.

    Returns:
        None. The figure is created as a side effect.

    Raises:
        KeyError: If ``urun`` is not a label of ``post_burnin_mcmc_table``.
    """
    x_vals = np.arange(0, 0.5, 0.001)
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    main_title = f"Run id: {urun}" if title is None else title
    fig.suptitle(main_title, fontsize=18)
    params = post_burnin_mcmc_table.loc[urun]

    ### plot model fits to data
    # The rows of this run are the same for both locations, so select them once.
    # They are restricted to the baseline scenario so that trajectories from
    # different scenarios are never joined into a single line.
    selection = full_run[(full_run["urun"] == urun) & (full_run["scenario"] == 0)]
    run_dates = ref_times_to_dti(REF_DATE, selection["times"])
    for i, loc in enumerate(["AUS", "ITA"]):
        axis = axes[0, i]
        output = f"incidence_{loc}"
        axis.plot(run_dates, selection[output])

        # Same guard as the other fit plots: skip data that is missing or empty.
        if output in all_targets and len(all_targets[output]) > 0:
            all_targets[output].plot.line(ax=axis, linewidth=0., markersize=10., marker="o")
            axis.scatter(all_targets[output].index, all_targets[output], color="k", s=5, alpha=0.5, zorder=10)

        axis.get_xaxis().set_visible(False)
        axis.set_title(f"Incidence in {loc}")

    ### plot hierarchical prior
    hyper_mu, hyper_sd = params['hyper_beta_mean'], params['hyper_beta_sd']
    betas = [params['beta.AUS'], params['beta.ITA']]
    # Normal(hyper_mu, hyper_sd) truncated below at zero: scipy expects the bounds
    # in standardised units, i.e. (0 - loc) / scale for the lower bound.
    prior_kwargs = dict(a=-hyper_mu / hyper_sd, b=np.inf, loc=hyper_mu, scale=hyper_sd)

    axis = axes[1, 0]
    y_vals = truncnorm.pdf(x_vals, **prior_kwargs)
    axis.plot(x_vals, y_vals, color="blue")
    axis.fill_between(x_vals, y_vals, color="blue", alpha=.1)

    ymax = max(y_vals)
    axis.text(x=.2, y = .8 * ymax, s=f"mu={round(hyper_mu, 3)}\nsd={round(hyper_sd,3)}", color="blue", fontsize=10)

    # Sampled betas of this run, placed on the x axis.
    axis.scatter(
        x=betas,
        y=[0] * 2,
        color=["red", "green"],
        marker="o"
    )
    # NOTE(review): fixed triangles at x = 0.1 and x = 0.2, drawn just below the
    # axis - presumably reference values for the AUS and ITA betas; confirm.
    axis.scatter(
        x=[.1, .2],
        y= [-0.03*ymax] * 2,
        color= ["red", "green"],
        marker="^"
    )
    axis.set_xlim((-.05, .5))
    axis.set_title("Hierarchical prior")

    ### print iteration scores
    log_prior_sum = round(sum(truncnorm.logpdf(betas, **prior_kwargs)), 2)
    # When every post-burn-in run has the same log-likelihood the range is zero;
    # treat that case as a perfect relative fit instead of dividing by zero.
    ll_range = max_ll - min_ll
    relative_fit = (params['loglikelihood'] - min_ll) / ll_range if ll_range > 0 else 1.0
    fit_score = round(100. * relative_fit**4)

    # The last panel is used as a blank canvas for text only.
    axis = axes[1, 1]
    fs = 15
    axis.text(.2, .8, s=f"fitting score = {fit_score}%",fontsize=fs)
    axis.text(.2, .6, s=f"log-prior = {log_prior_sum}",fontsize=fs)

    axis.set_xlim((0, 1))
    axis.set_ylim((0, 1))
    axis.set_frame_on(False)
    axis.get_xaxis().set_visible(False)
    axis.get_yaxis().set_visible(False)

### Plot randomly-selected runs

In [None]:
n_samples = 2
# Fixed seed so that Restart & Run All reproduces the same figures.
SAMPLING_SEED = 0

# plot_hierarchical_outputs looks the run up in post_burnin_mcmc_table, so only
# full runs that also have a post-burn-in row can be plotted.
post_burnin_labels = set(post_burnin_mcmc_table.index)
plottable_uruns = [candidate for candidate in uruns if candidate in post_burnin_labels]
if not plottable_uruns:
    raise ValueError("No full-run urun matches a post-burn-in MCMC row, so there is nothing to sample.")

# Draw without replacement, never asking for more runs than are available.
urun_rng = np.random.default_rng(SAMPLING_SEED)
sampled_uruns = urun_rng.choice(
    plottable_uruns,
    size=min(n_samples, len(plottable_uruns)),
    replace=False,
).tolist()

for urun in sampled_uruns:
    plot_hierarchical_outputs(urun)

### Plot highest-likelihood runs

In [None]:
# Run reaching the highest log-likelihood (first one if several are tied).
is_best_fit = post_burnin_mcmc_table["loglikelihood"] == max_ll
best_fit_urun = post_burnin_mcmc_table.loc[is_best_fit].index[0]
plot_hierarchical_outputs(best_fit_urun, "Best model fit (highest likelihood)")

# Run reaching the highest "ap_loglikelihood" (first one if several are tied).
is_best_ap = post_burnin_mcmc_table["ap_loglikelihood"] == max_ap_ll
best_ap_urun = post_burnin_mcmc_table.loc[is_best_ap].index[0]
plot_hierarchical_outputs(best_ap_urun, "Best a-posteriori likelihood")

### Correlations between parameters

In [None]:
# Number of leading columns of the post-burn-in table to compare pairwise.
# NOTE(review): this assumes the first five columns are the calibrated
# parameters - confirm against the column order of the merged table.
N_CORRELATION_PARAMS = 5

# A dedicated name is used so that the notebook-wide `params_list` is not
# overwritten (which would change earlier cells if they were re-run).
corr_params = post_burnin_mcmc_table.columns.to_list()[:N_CORRELATION_PARAMS]
param_pairs = list(combinations(corr_params, 2))

# Size the grid from the number of pairs so that none is silently dropped;
# five parameters give ten pairs, i.e. a 4x3 grid on a 20x20 figure.
n_col = 3
n_row = max(1, ceil(len(param_pairs) / n_col))
fig, axes = plt.subplots(n_row, n_col, figsize=(20, 5 * n_row), squeeze=False)

for i_ax, axis in enumerate(axes.reshape(-1)):
    if i_ax >= len(param_pairs):
        # Spare panel in the last row of the grid.
        axis.set_visible(False)
        continue

    x_param, y_param = param_pairs[i_ax]
    axis.scatter(post_burnin_mcmc_table[x_param], post_burnin_mcmc_table[y_param], s=10, color="purple")
    # Use the display labels where they are defined, as in the other parameter plots.
    axis.set_xlabel(param_lookup.get(x_param, x_param))
    axis.set_ylabel(param_lookup.get(y_param, y_param))

fig.tight_layout()
