In [None]:
from matplotlib import pyplot as plt
import os
import datetime
from math import ceil

from summer.utils import ref_times_to_dti

from autumn.settings.constants import COVID_BASE_DATETIME
from autumn.core.runs import ManagedRun
from autumn.core.project import get_project, load_timeseries

In [None]:
# Region analysed in this notebook; should match the region component of run_id.
region = "northern_territory"

In [None]:
# Identifier of the run to analyse, in the form "<model>/<region>/<run>/<commit>".
run_id = "/".join(("sm_sir", "northern_territory", "1661221665", "4c459e0"))

In [None]:
# Handle used below to read this run's Power BI database, calibration (MCMC) samples
# and full-run derived outputs.
mr = ManagedRun(run_id)

In [None]:
# Power BI database: targets and uncertainty quantiles per (output, scenario).
pbi = mr.powerbi.get_db()
targets = pbi.get_targets()
results = pbi.get_uncertainty()

# MCMC samples: parameter values and per-run metadata, joined on the unique run id.
mcmc_params = mr.calibration.get_mcmc_params()
mcmc_runs = mr.calibration.get_mcmc_runs()
n_params = len(mcmc_params.columns)
chains = mcmc_runs["chain"].unique()
mcmc_table = mcmc_params.merge(mcmc_runs, on=["urun"])

# Derived outputs of every run in the full run.
full_run = mr.full_run.get_derived_outputs()

## Calibration parameters are:

In [None]:
# Names of the calibrated parameters (the columns of the MCMC parameter table).
params_list = mcmc_params.columns.tolist()
params_list

In [None]:
plt.style.use("ggplot")

# Samples whose run number is <= burn_in are discarded before plotting the posteriors.
# Assumes the second "_"-separated field of each mcmc_params index label is the run
# (iteration) number — TODO confirm.
burn_in = 500  # This needs to be confirmed over BuildKite
burnt_mcmc_params = mcmc_params[[int(i.split("_")[1]) > burn_in for i in mcmc_params.index]]

# Human-readable titles and x-axis units for the calibrated parameters. Parameters
# missing from these tables fall back to their raw name / no unit label.
param_lookup = {
    "contact_rate": "infection risk per contact",
    "voc_emergence.ba_1.cross_protection.ba_2.early_reinfection": "BA.1 protection against BA.2",
    "voc_emergence.ba_2.cross_protection.ba_5.early_reinfection": "BA.2 protection against BA.5",
    "detect_prop": "detection of symptomatic cases",
    "sojourns.latent.total_time": "infection latent period",
}

unit_lookup = {
    "contact_rate": "probability",
    "voc_emergence.ba_1.cross_protection.ba_2.early_reinfection": "multiplier",
    "voc_emergence.ba_2.cross_protection.ba_5.early_reinfection": "multiplier",
    "detect_prop": "proportion",
    "sojourns.latent.total_time": "days",
}

# One panel per calibrated parameter, two per row (5 x 4 inches each); spare panels are
# switched off. Sized from n_params rather than assuming exactly five parameters.
n_col = 2
n_row = max(1, ceil(n_params / n_col))
fig, axes = plt.subplots(n_row, n_col, figsize=(5 * n_col, 4 * n_row), squeeze=False)
for i_ax, axis in enumerate(axes.reshape(-1)):
    if i_ax >= n_params:
        axis.set_axis_off()
        continue
    param = mcmc_params.columns[i_ax]
    axis.hist(burnt_mcmc_params[param])
    axis.set_title(param_lookup.get(param, param))
    axis.set_xlabel(unit_lookup.get(param, ""))

fig.suptitle("parameter posterior histograms", fontsize=15, y=1)
fig.tight_layout()

In [None]:
# Trace plot per calibrated parameter: one line per MCMC chain against run number,
# with a dashed vertical line marking the burn-in cut-off.
n_col = 2
n_row = ceil(n_params / n_col)
fig, axes = plt.subplots(n_row, n_col, figsize=(8 * n_col, 4 * n_row))

for i_ax, axis in enumerate(axes.reshape(-1)):
    if i_ax >= n_params:
        # More panels than parameters: hide the leftovers.
        axis.set_visible(False)
        continue
    param = params_list[i_ax]
    for chain in chains:
        chain_rows = mcmc_table[mcmc_table["chain"] == chain]
        axis.plot(chain_rows["run"], chain_rows[param], lw=0.5)
    # Span the burn-in marker over the y-range of the traces drawn so far.
    ymin, ymax = axis.get_ylim()
    axis.vlines(x=[burn_in], ymin=ymin, ymax=ymax, color="k", alpha=0.8, ls="--")
    par_name = param_lookup.get(param, param)
    axis.set_title(par_name)

## Available model outputs are:

In [None]:
# run_id is "<model>/<region>/<run>/<commit>". Take the model name from it rather than
# repeating it, and check the region agrees, so the project configuration (e.g. the
# scenario descriptions used in plot_outputs) belongs to the run being analysed.
run_model, run_region = run_id.split("/")[:2]
assert run_region == region, f"run_id region {run_region!r} does not match region {region!r}"
project = get_project(run_model, region, reload=True)

### Get all targets, including those not used in calibration, for use as validation data

In [None]:
# run_id has the form "<model>/<region>/<run>/<commit>"; the region part is bound to `country`.
model, country, run, commit = run_id.split("/")

# Repository root, assumed to be three directories above the notebook's working
# directory — TODO confirm this holds wherever the notebook is run from.
project_root = os.path.abspath(os.path.join(os.getcwd(), *([os.pardir] * 3)))
project_file_path = os.path.join(
    project_root,
    "autumn",
    "projects",
    model,
    "australia",
    country,
    "timeseries.secret.json",
)

# Convert each target's reference-time index to dates, in place on each series.
all_targets = load_timeseries(project_file_path)
for target, target_series in all_targets.items():
    target_series.index = ref_times_to_dti(COVID_BASE_DATETIME, target_series.index)

In [None]:
# RGB colours indexed by scenario number (plot_outputs uses colours[scenario]).
colours = (
    (0.2, 0.2, 0.8),
    (0.2, 0.8, 0.2),
    (0.8, 0.2, 0.2),
    (0.8, 0.8, 0.2),
    (0.8, 0.2, 0.8),
    (0.2, 0.8, 0.8),
    (0.8, 0.8, 0.8),
)

# Panel titles for model outputs. Each key must appear only once: in a dict literal a
# repeated key silently overrides the earlier entry.
title_lookup = {
    "notifications": "daily notifications",
    "infection_deaths": "COVID-19-specific deaths",
    "hospital_admissions": "daily hospital admissions",
    "icu_admissions": "daily ICU admissions",
    "proportion_seropositive": "proportion recovered from COVID-19",
    "incidence": "daily new infections",
    "prop_incidence_strain_delta": "proportion of cases due to Delta",
    "hospital_occupancy": "total hospital beds",
    "icu_occupancy": "total ICU beds",
    "prop_ever_infected": "ever infected with Delta or Omicron",
    "cumulative_infection_deaths": "cumulative COVID-19 deaths",
    "cumulative_hospital_admissions": "cumulative hospital admissions",
}

In [None]:
# NOTE(review): later cells may rely on this import — plot_outputs refers to `timedelta`.
from datetime import timedelta
# Scratch check: the uncertainty results' time index shifted 187 days later (the same
# magnitude as the scenario_2_shift passed to plot_outputs further down).
# TODO(review): remove this cell or give `something` a descriptive name.
something = results.index + timedelta(days=187)
something

In [None]:
def plot_outputs(outputs, left_date, right_date, scenarios, scenario_2_shift=0):
    """Plot scenario uncertainty bands for several model outputs, one panel per output.

    For each scenario, the 2.5-97.5% quantile band (alpha 0.5) and the 25-75% band
    (alpha 0.7) are shaded and the median is drawn as a line. All three come from the
    notebook-level ``results`` frame via ``results[(output, scenario)][quantile]``.
    Targets from ``all_targets`` (faint white) and ``targets`` (black) are overlaid once
    per panel, not once per scenario, so their transparency is not compounded.

    Args:
        outputs: Sequence of output names; one panel is drawn per entry.
        left_date: Left x-axis limit.
        right_date: Right x-axis limit.
        scenarios: Scenario numbers to draw. 0 is labelled "baseline"; scenario ``s > 0``
            is labelled with ``project.param_set.scenarios[s - 1]["description"]``.
        scenario_2_shift: Days added to scenario 2's dates before plotting (may be
            negative). No other scenario is shifted.

    Relies on the notebook globals ``results``, ``project``, ``colours``,
    ``title_lookup``, ``targets`` and ``all_targets``.
    """
    # Local import so the function does not depend on an earlier scratch cell having run.
    from datetime import timedelta

    n_outputs = len(outputs)
    # 6 x 4.5 inches per panel; squeeze=False keeps `axes` 2-D even for a single output.
    fig, axes = plt.subplots(1, n_outputs, figsize=(6 * n_outputs, 4.5), sharex="all", squeeze=False)
    for i_ax, (axis, output) in enumerate(zip(axes.reshape(-1), outputs)):
        for scenario in scenarios:
            colour = colours[scenario]
            results_df = results[(output, scenario)]

            # Shift a local copy of the dates instead of reassigning results_df.index.
            indices = results_df.index
            if scenario == 2:
                indices = indices + timedelta(days=scenario_2_shift)

            interval_label = "baseline" if scenario == 0 else project.param_set.scenarios[scenario - 1]["description"]
            # zorder 10 puts the baseline above the other scenarios' bands.
            scenario_zorder = 10 if scenario == 0 else scenario
            axis.fill_between(
                indices,
                results_df[0.025], results_df[0.975],
                color=colour,
                alpha=0.5,
                label="_nolegend_",
                zorder=scenario_zorder,
            )
            axis.fill_between(
                indices,
                results_df[0.25], results_df[0.75],
                color=colour, alpha=0.7,
                label=interval_label,
                zorder=scenario_zorder,
            )
            axis.plot(indices, results_df[0.500], color=colour)

        # Targets are drawn once per panel, above every scenario band (zorder 11/12).
        if output in all_targets and len(all_targets[output]) > 0:
            all_targets[output].plot.line(
                ax=axis,
                linewidth=0.,
                markersize=8.,
                marker="o",
                markerfacecolor="w",
                markeredgecolor="w",
                alpha=0.4,
                label="_nolegend_",
                zorder=11,
            )
        if output in targets:
            targets[output].plot.line(
                ax=axis,
                linewidth=0.,
                markersize=5.,
                marker="o",
                markerfacecolor="k",
                markeredgecolor="k",
                label="_nolegend_",
                zorder=12,
            )
        axis.tick_params(axis="x", labelrotation=45)
        axis.set_xlim(left=left_date, right=right_date)
        axis.set_title(title_lookup.get(output, output))
        if i_ax == 0:
            axis.legend()
    fig.tight_layout()

In [None]:
# First level of the results columns is the indicator (output) name. The names are
# sorted so the listing is stable between kernels: set iteration order of strings
# depends on hash randomisation.
uncertainty_indicators = sorted({column[0] for column in results.columns})
print(f"Available indicators with uncertainty are: {uncertainty_indicators}")

In [None]:
# Export the uncertainty results table for use outside the notebook.
results.to_csv(path_or_buf="nt_results.csv")

In [None]:
# Baseline plus the first two scenarios, hospital and ICU admissions only.
scenarios = list(range(3))
outputs_to_plot = ("hospital_admissions", "icu_admissions")
start_plot_time, end_plot_time = datetime.date(2021, 6, 1), datetime.date(2022, 9, 1)

# Scenario 2 is drawn 187 days earlier than its reported dates.
# TODO(review): record why this offset is applied.
plot_outputs(
    outputs_to_plot,
    start_plot_time,
    end_plot_time,
    scenarios=scenarios,
    scenario_2_shift=-187.0,
)

In [None]:
# Reload every target, this time from the project's own directory, and convert each
# series' reference-time index to dates in place.
timeseries_path = os.path.join(project.get_path(), "timeseries.secret.json")
all_targets = load_timeseries(timeseries_path)
for target, target_series in all_targets.items():
    target_series.index = ref_times_to_dti(COVID_BASE_DATETIME, target_series.index)

In [None]:
outputs_for_full_run_plots = (
    "notifications",
    "hospital_admissions",
    "icu_admissions",
    "infection_deaths",
)

n_outputs = len(outputs_for_full_run_plots)
n_col = 2
n_row = ceil(n_outputs / n_col)
fig, axes = plt.subplots(n_row, n_col, figsize=(n_col * 8, n_row * 6), sharex="all")

# Only the baseline scenario (0) of chain 1 is drawn. The selection is the same for
# every panel, so compute it once here rather than inside the loop.
scenario_chain = (full_run["scenario"] == 0) & (full_run["chain"] == 1)
# Split the selected rows by run once. sort=False keeps first-appearance order (as
# .unique() did), so each run keeps its line colour. Re-filtering the whole frame per
# run was quadratic.
baseline_runs = full_run[scenario_chain].groupby("run", sort=False)

for i_ax, axis in enumerate(axes.reshape(-1)):
    if i_ax >= n_outputs:
        axis.set_visible(False)
        continue
    output = outputs_for_full_run_plots[i_ax]
    for i_run, selection in baseline_runs:
        axis.plot(ref_times_to_dti(COVID_BASE_DATETIME, selection["times"]), selection[output])
    if output in all_targets and len(all_targets[output]) > 0:
        # All targets: large faint markers with small faint dots on top.
        all_targets[output].plot.line(ax=axis, linewidth=0., markersize=8., marker="o", color="w", alpha=0.4)
        axis.scatter(all_targets[output].index, all_targets[output], color="w", s=5, alpha=0.4, zorder=10)
    if output in targets:
        axis.scatter(targets.index, targets[output], facecolors="k", edgecolors="k", s=15, zorder=10)
    axis.tick_params(axis="x", labelrotation=45)
    axis.set_title(title_lookup.get(output, output))
    axis.set_xlim(
        left=datetime.datetime(2021, 12, 15),
        right=datetime.datetime(2022, 9, 1),
    )

fig.tight_layout()