In [None]:
from matplotlib import pyplot as plt
import os
import warnings
from math import ceil
from summer.utils import ref_times_to_dti

from autumn.core.project import get_project, load_timeseries
from autumn.core.plots.utils import REF_DATE
from autumn.core.runs.managed import ManagedRun
from autumn.core.plots.calibration.plots import plot_prior, calculate_r_hats

In [None]:
# Suppress every DeprecationWarning for the rest of the session to keep cell outputs readable
warnings.simplefilter("ignore", DeprecationWarning)

## Specify the run id and the outputs to plot

In [None]:
# Identifier of the run to inspect; its first two "/"-separated parts are later read as model and region
run_id = "sm_covid2/france/1678316952/4a682e4"

# MCMC iterations whose run index is below this value are left out of the post-burn-in table
BURN_IN = 10

# Outputs drawn run by run in the "individual model runs" figure.
# Further candidates, currently switched off: "transformed_random_process", "cumulative_infection_deaths".
outputs_for_full_run_plots = ("infection_deaths", "prop_ever_infected")

# Outputs and scenarios drawn in the uncertainty figure
outputs_to_plot_with_uncertainty = (
    "infection_deaths",
    "transformed_random_process",
    "cumulative_infection_deaths",
    "prop_ever_infected",
)
scenarios_to_plot_with_uncertainty = [0, 1]

In [None]:
# Human-readable names for calibrated parameters, used as panel titles and x-axis labels.
# Parameters missing from this mapping are shown under their raw name.
param_lookup = {
    "contact_rate": "infection risk per contact",
    "testing_to_detection.assumed_cdr_parameter": "CDR at one test per 1,000 population per day",
    "voc_emergence.omicron.new_voc_seed.start_time": "Omicron emergence date",
    "voc_emergence.omicron.contact_rate_multiplier": "relative transmissibility Omicron",
    "age_stratification.cfr.multiplier": "modification to fatality rate",
    "age_stratification.prop_hospital.multiplier": "modification to hospitalisation rate",
}

# Human-readable titles for model outputs; outputs missing from this mapping keep their raw name.
# Each key appears exactly once: "hospital_admissions" and "icu_admissions" used to be listed twice,
# with the later entries silently winning, so only those effective values are kept here.
title_lookup = {
    "notifications": "daily notifications",
    "infection_deaths": "COVID-19-specific deaths",
    "hospital_admissions": "daily hospital admissions",
    "icu_admissions": "daily ICU admissions",
    "proportion_seropositive": "proportion recovered from COVID-19",
    "incidence": "daily new infections",
    "prop_incidence_strain_delta": "proportion of cases due to Delta",
    "hospital_occupancy": "total hospital beds",
    "icu_occupancy": "total ICU beds",
    "prop_ever_infected": "ever infected with Delta or Omicron",
    "cdr": "case detection rate",
}

## Load run outputs and pre-process data

In [None]:
# The first two "/"-separated components of the run id give the model and the region
model, region = run_id.split("/")[:2]

mr = ManagedRun(run_id)

# Derived outputs of the full model runs
full_run = mr.full_run.get_derived_outputs()

# Targets and uncertainty results are both read from the run's PowerBI database handle
pbi = mr.powerbi.get_db()
targets, results = pbi.get_targets(), pbi.get_uncertainty()

# Calibration tables: sampled parameter values and the matching run metadata
mcmc_params = mr.calibration.get_mcmc_params()
mcmc_runs = mr.calibration.get_mcmc_runs()

In [None]:
project = get_project(model, region, reload=True)

# Names of the calibrated parameters and identifiers of the MCMC chains
params_list = [*mcmc_params.columns]
chains = mcmc_runs["chain"].unique()

# Parameter values joined with the run metadata on the "urun" key
mcmc_table = mcmc_params.merge(right=mcmc_runs, on=["urun"])

# Restrict the joined table to iterations at or beyond the burn-in threshold
is_post_burnin = mcmc_runs["run"] >= BURN_IN
post_burnin_uruns = mcmc_runs.loc[is_post_burnin].index
post_burnin_mcmc_table = mcmc_table.filter(items=post_burnin_uruns, axis=0)

### Determine MLE and MAP run details

In [None]:
def _unpadded_ids(urun):
    """Split a "<chain>_<run>" identifier into chain and run strings without leading zeros.

    A component made only of zeros (chain 0 or run 0) is returned as "0" rather than
    as an empty string, so both results can always be printed and passed to int().
    """
    chain_part, run_part = urun.split("_")
    return chain_part.lstrip("0") or "0", run_part.lstrip("0") or "0"


# Maximum-likelihood iteration: the row of mcmc_runs with the highest "loglikelihood"
mle_urun = mcmc_runs["loglikelihood"].idxmax()
mle_chain, mle_irun = _unpadded_ids(mle_urun)
print(f"MLE found in chain {mle_chain} run {mle_irun}")

# The MLE parameter values are only kept when that iteration was accepted
if mcmc_runs.loc[mle_urun, "accept"] == 1:
    mle_params = mcmc_params.loc[mle_urun]
    mle_accepted = True
else:
    print("MLE params were not accepted")
    mle_params = None
    mle_accepted = False

# Maximum a posteriori iteration: the row with the highest "ap_loglikelihood"
map_urun = mcmc_runs["ap_loglikelihood"].idxmax()
map_chain, map_irun = _unpadded_ids(map_urun)
map_params = mcmc_params.loc[map_urun]
print(f"MAP found in chain {map_chain} run {map_irun}")

# Parameter values at run 0 of every chain, i.e. the MCMC starting points
start_params = mcmc_params.loc[mcmc_runs[mcmc_runs["run"] == 0].index]

### Calculate R_hats

In [None]:
# Convergence diagnostics from calculate_r_hats, which receives the parameter and run tables
# (each wrapped in a single-element list) together with the burn-in length.
# NOTE(review): presumably these are Gelman-Rubin R-hat statistics per parameter — confirm against calculate_r_hats.
r_hats = calculate_r_hats([mcmc_params], [mcmc_runs], BURN_IN)
# Bare expression so the notebook displays the result
r_hats

## Plot posteriors

In [None]:
plt.style.use("ggplot")

# Grid with three panels per row, one panel per calibrated parameter
n_params = len(params_list)
n_col = 3
n_row = ceil(n_params / n_col)
fig, axes = plt.subplots(n_row, n_col, figsize=(n_col * 5, n_row * 3.5))

for i_ax, axis in enumerate(axes.ravel()):
    # Panels beyond the last parameter are hidden
    if i_ax >= n_params:
        axis.set_visible(False)
        continue

    param = mcmc_params.columns[i_ax]

    # Prior for this parameter; it is only drawn when it carries distribution parameters
    matching_priors = [prior for prior in project.calibration.all_priors if prior["param_name"] == param]
    prior_dict = matching_priors[0]
    if prior_dict["distri_params"] is not None:
        plot_prior(0, prior_dict, ax=axis, print_distri=False, alpha=0.5)

    # Weighted histogram of the post-burn-in samples
    axis.hist(
        post_burnin_mcmc_table[param],
        weights=post_burnin_mcmc_table["weight"],
        density=True,
        bins=15,
        color="coral",
        label="posterior",
    )

    # MLE (only when that iteration was accepted) and MAP, marked on the x-axis
    if mle_accepted:
        axis.scatter(mle_params[param], 0, color="black", marker="*", label="MLE", zorder=10, clip_on=False)
    axis.scatter(map_params[param], 0, color="green", marker="*", label="MAP", zorder=9, clip_on=False)

    # Starting value of every chain
    seeds = start_params[param]
    axis.scatter(seeds, [0] * len(seeds), color="blue", marker=2, label="seed", zorder=10, clip_on=False)

    # Readable name when one is registered, raw parameter name otherwise
    par_name = param_lookup.get(param, param)
    axis.set_title(par_name)
    axis.set_xlabel(par_name)

    # A single legend, on the first panel, is enough
    if i_ax == 0:
        axis.legend()

fig.suptitle("parameter posterior histograms", fontsize=15, y=1)
fig.tight_layout()

## Plot traces

In [None]:
# Two trace panels per row; n_params is the parameter count computed for the posterior figure
n_col = 2
n_row = ceil(n_params / n_col)
fig, axes = plt.subplots(n_row, n_col, figsize=(8 * n_col, 4 * n_row))

for i_ax, axis in enumerate(axes.ravel()):
    # Panels beyond the last parameter are hidden
    if i_ax >= n_params:
        axis.set_visible(False)
        continue

    param = params_list[i_ax]

    # One thin line per chain: parameter value against run index
    for chain in chains:
        chain_rows = mcmc_table[mcmc_table["chain"] == chain]
        axis.plot(chain_rows["run"], chain_rows[param], lw=0.5, label=chain)

    axis.set_title(param_lookup.get(param, param))

    if i_ax == 0:
        leg = axis.legend(title="Chain:")
        # Thicker legend handles, since the traces themselves are very thin
        for handle in leg.get_lines():
            handle.set_linewidth(2)

## Model outputs

### Get all targets, including those not used in calibration, for use as validation

In [None]:
# Targets attached to the project's calibration; each item's `.data` is read in the next cell
calib_targets = project.calibration.targets

In [None]:
# Target data keyed by output name, re-indexed from model reference times to calendar dates.
all_targets = {}
for calib_target in calib_targets:
    # Work on a copy: re-indexing the project's own data in place would make this cell
    # non-repeatable, because a second run would try to convert dates that are already converted.
    target_series = calib_target.data.copy()
    target_series.index = ref_times_to_dti(REF_DATE, target_series.index)
    all_targets[target_series.name] = target_series

### Calibration fits with individual model runs

In [None]:
# One panel per requested output, two panels per row, all sharing the x-axis
n_outputs = len(outputs_for_full_run_plots)
n_col = 2
n_row = ceil(n_outputs / n_col)
fig, axes = plt.subplots(n_row, n_col, figsize=(n_col * 8, n_row * 6), sharex="all")

# Row selections that do not depend on the plotted output are built once, outside the panel loop
is_baseline = full_run["scenario"] == 0
scenario_chain = is_baseline & (full_run["chain"] == 0)
chain0_runs = full_run[scenario_chain]

# The chain/run identifiers are strings whose leading zeros were stripped, so index 0 may
# arrive as ""; `or 0` maps that case to 0 instead of letting int("") raise ValueError.
map_selection = full_run[
    (full_run["run"] == int(map_irun or 0)) & is_baseline & (full_run["chain"] == int(map_chain or 0))
]
if mle_accepted:
    mle_selection = full_run[
        (full_run["run"] == int(mle_irun or 0)) & is_baseline & (full_run["chain"] == int(mle_chain or 0))
    ]

for i_ax, axis in enumerate(axes.reshape(-1)):
    # Panels beyond the last requested output are hidden
    if i_ax >= n_outputs:
        axis.set_visible(False)
        continue

    output = outputs_for_full_run_plots[i_ax]

    # Every baseline run of chain 0, one line each
    for i_run in chain0_runs["run"].unique():
        selection = chain0_runs[chain0_runs["run"] == i_run]
        axis.plot(ref_times_to_dti(REF_DATE, selection["times"]), selection[output])

    # Project targets in black, targets stored with the run in red
    if output in all_targets and len(all_targets[output]) > 0:
        axis.scatter(all_targets[output].index, all_targets[output], facecolors="black", edgecolors="k", s=15, zorder=10)
    if output in targets:
        axis.scatter(targets.index, targets[output], facecolors="r", edgecolors="k", s=15, zorder=10)

    # Highlight the MLE run (when it was accepted) and the MAP run on top of everything else
    if mle_accepted:
        axis.plot(ref_times_to_dti(REF_DATE, mle_selection["times"]), mle_selection[output], "--", color="black", zorder=15, label="MLE", linewidth=2)
    axis.plot(ref_times_to_dti(REF_DATE, map_selection["times"]), map_selection[output], "--", color="blue", zorder=15, label="MAP", linewidth=2)

    axis.tick_params(axis="x", labelrotation=45)
    axis.set_title(title_lookup.get(output, output))

    if i_ax == 0:
        axis.legend()


### Calibration fits with uncertainty

In [None]:
# RGB colours indexed by scenario number (0 is the baseline)
colours = ((0.2, 0.2, 0.8), (0.8, 0.2, 0.2), (0.2, 0.8, 0.2), (0.8, 0.8, 0.2), (0.8, 0.2, 0.2), (0.2, 0.8, 0.2), (0.8, 0.8, 0.2))

def plot_outputs_with_uncertainty(outputs, scenarios=(0,)):
    """Plot uncertainty bands for the requested outputs, one panel per output.

    For every scenario the 2.5-97.5% and 25-75% intervals are drawn as shaded bands with the
    median as a line, all read from the module-level ``results`` table, whose columns are
    keyed by ``(output, scenario)`` and then by quantile. Target data from ``all_targets``
    and ``targets`` are overlaid once per panel.

    Args:
        outputs: Sequence of output names to plot (indexable, with a length).
        scenarios: Iterable of scenario indices; 0 is labelled "baseline", any other index
            ``s`` is labelled with the description of ``project.param_set.scenarios[s - 1]``.
    """
    n_outputs = len(outputs)
    n_col = 2
    n_row = ceil(n_outputs / n_col)
    fig, axes = plt.subplots(n_row, n_col, figsize=(n_col * 8, n_row * 6), sharex="all")
    for i_ax, axis in enumerate(axes.reshape(-1)):
        # Panels beyond the last requested output are hidden
        if i_ax >= n_outputs:
            axis.set_visible(False)
            continue

        output = outputs[i_ax]
        for scenario in scenarios:
            # Wrap around so scenario indices beyond the palette size still get a colour
            colour = colours[scenario % len(colours)]
            results_df = results[(output, scenario)]
            indices = results_df.index
            interval_label = "baseline" if scenario == 0 else project.param_set.scenarios[scenario - 1]["description"]
            # The baseline is drawn above the other scenarios
            scenario_zorder = 10 if scenario == 0 else scenario
            axis.fill_between(
                indices,
                results_df[0.025], results_df[0.975],
                color=colour,
                alpha=0.5,
                label="_nolegend_",
                zorder=scenario_zorder,
            )
            axis.fill_between(
                indices,
                results_df[0.25], results_df[0.75],
                color=colour, alpha=0.7,
                label=interval_label,
                zorder=scenario_zorder,
            )
            # Same zorder as the bands, and added after them, so the median is drawn on top
            # of its own interval instead of underneath it
            axis.plot(indices, results_df[0.500], color=colour, zorder=scenario_zorder)

        # Targets do not depend on the scenario, so they are drawn once per panel; drawing them
        # once per scenario stacked identical semi-transparent markers on top of each other.
        if output in all_targets and len(all_targets[output]) > 0:
            all_targets[output].plot.line(
                ax=axis,
                linewidth=0.,
                markersize=8.,
                marker="o",
                markerfacecolor="w",
                markeredgecolor="w",
                alpha=0.2,
                label="_nolegend_",
                zorder=11,
            )
        if output in targets:
            targets[output].plot.line(
                ax=axis,
                linewidth=0.,
                markersize=5.,
                marker="o",
                markerfacecolor="k",
                markeredgecolor="k",
                label="_nolegend_",
                zorder=12,
            )

        axis.tick_params(axis="x", labelrotation=45)
        axis.set_title(title_lookup.get(output, output))

        if i_ax == 0:
            axis.legend()
    fig.tight_layout()

In [None]:
# Draw the uncertainty figure for the outputs and scenarios chosen in the configuration cell
plot_outputs_with_uncertainty(
    outputs=outputs_to_plot_with_uncertainty,
    scenarios=scenarios_to_plot_with_uncertainty,
)