In [None]:
import pandas as pd
from matplotlib import pyplot as plt
import datetime
import os
import warnings

from summer.utils import ref_times_to_dti

from autumn.core.project.project import get_project
from autumn.core.project.timeseries import load_timeseries
from autumn.core.plots.utils import REF_DATE
from autumn.core.runs.managed import ManagedRun
from autumn.core.runs.calibration.utils import get_posteriors
from autumn.core.utils.pandas import pdfilt
from autumn.settings.region import Region
from autumn.core.plots.plotter.base_plotter import COLOR_THEME
from autumn.core.plots.utils import get_plot_text_dict, change_xaxis_to_date, REF_DATE, ALPHAS, COLORS, _apply_transparency, _plot_targets_to_axis
from autumn.core.plots.uncertainty.plots import _plot_uncertainty, _get_target_values
from autumn.calibration.utils import get_uncertainty_df

In [None]:
# Silence DeprecationWarning messages so they do not clutter the cell outputs.
warnings.filterwarnings("ignore", category=DeprecationWarning)

In [None]:
# Identifier of the run to inspect. It is passed to ManagedRun and later split on "/"
# into four components (unpacked there as model, country, run, commit).
run_id = "sm_sir/malaysia/1664889114/ac07778"
# Region name handed to get_project.
# NOTE(review): this repeats the second component of run_id — keep the two in sync.
region = "malaysia"

In [None]:
# Handle on the run identified by run_id; later cells read its PowerBI database
# (mr.powerbi) and calibration data (mr.calibration) through it.
mr = ManagedRun(run_id)

In [None]:
# Left disabled: nothing in the visible notebook uses full_run.
# full_run = mr.full_run.get_derived_outputs()
pbi = mr.powerbi.get_db()
# Targets stored with the run; plot_outputs draws them as black markers.
targets = pbi.get_targets()
# Uncertainty results; plot_outputs selects results[(output, scenario)] and then
# plots the 0.5 column of that selection.
results = pbi.get_uncertainty()
# NOTE(review): mcmc_params is not used by any cell visible here.
mcmc_params = mr.calibration.get_mcmc_params()

In [None]:
# Load the sm_sir project for the chosen region; plot_outputs reads the scenario
# descriptions from project.param_set.scenarios to build its legend labels.
project = get_project("sm_sir", region, reload=True)

In [None]:
# Repository root: three directory levels above the current working directory.
project_root = os.path.abspath(os.path.join(os.getcwd(), *([os.pardir] * 3)))

# The run id must consist of exactly four "/"-separated components.
model, country, run, commit = run_id.split("/")

# Path of the project's timeseries.json file inside the repository.
project_file_path = os.path.join(
    project_root,
    "autumn",
    "projects",
    model,
    country,
    country,
    "timeseries.json",
)
all_targets = load_timeseries(project_file_path)

# Replace the index of every loaded target with the output of ref_times_to_dti
# (presumably calendar dates relative to REF_DATE — confirm against summer.utils).
for target in all_targets:
    series = all_targets[target]
    series.index = ref_times_to_dti(REF_DATE, series.index)

In [None]:
# Panel titles keyed by model output name; plot_outputs uses them as axis titles.
title_lookup = {
    "notifications": "Daily number of notified Covid-19 cases",
    "infection_deaths": "Daily number of Covid-19 deaths",
    "accum_deaths": "Cumulative number of Covid-19 deaths",
    "incidence": "Daily incidence (incl. asymptomatics and undetected)",
    "hospital_occupancy": "Hospital beds occupied by Covid-19 patients",
    "icu_occupancy": "ICU beds occupied by Covid-19 patients",
    "cdr": "Proportion detected among symptomatics",
    "proportion_vaccinated": "Proportion vaccinated",
    "prop_incidence_strain_delta": "Proportion of Delta variant in new cases",
    "prop_incidence_strain_alpha_beta": "Proportion of Alpha variant in new cases",
    "prop_ever_infected": "Proportion ever infected",
    "prop_detected_traced": "Proportion of cases contact traced",
    # The `outputs` tuple also lists the three outputs below; without a title here,
    # plotting any of them fails on the title lookup.
    "hospital_admissions": "New Covid-19 hospital admissions",
    "icu_admissions": "New Covid-19 ICU admissions",
    "proportion_seropositive": "Proportion seropositive",
}


In [None]:
# Date window passed to plot_outputs as the x-axis limits.
plot_left_date = datetime.date(year=2021, month=4, day=5)
plot_right_date = datetime.date(year=2022, month=10, day=15)

# RGB line colours; plot_outputs picks one by scenario index.
colours = (
    (0.2, 0.2, 0.8),  # index 0
    (0.8, 0.2, 0.2),  # index 1
    (0.2, 0.8, 0.2),  # index 2
    (0.8, 0.8, 0.2),  # index 3
    (0.8, 0.2, 0.2),  # index 4
    (0.2, 0.8, 0.2),  # index 5
    (0.8, 0.8, 0.2),  # index 6
)

# Names of model outputs of interest.
outputs = (
    "notifications",
    "cdr",
    "hospital_admissions",
    "hospital_occupancy",
    "icu_admissions",
    "icu_occupancy",
    "prop_ever_infected",
    "incidence",
    "proportion_seropositive",
)

In [None]:
def plot_outputs(outputs, left_date, right_date, scenarios=(0, 2, 4)):
    """
    Plot the median trajectory of each requested output, one output per panel of a 2 x 2 grid.

    In every panel the 0.5 column of `results[(output, scenario)]` is drawn as a line for
    each scenario. Target data points are overlaid where available: entries of
    `all_targets` as faint white markers and entries of `targets` as black markers.
    The legend is shown on the first panel only.

    Relies on the notebook-level objects `results`, `colours`, `project`, `all_targets`,
    `targets` and `title_lookup`.

    Args:
        outputs: Sequence of output names. At most the first four are plotted; if fewer
            than four are given, the remaining panels are left empty.
        left_date: Left limit of the x-axis.
        right_date: Right limit of the x-axis.
        scenarios: Scenario indices to draw. Index 0 is labelled "baseline"; any other
            index i is labelled with project.param_set.scenarios[i - 1]["description"].
            Each index is also used to pick the line colour from `colours`.

    """
    fig, axes = plt.subplots(2, 2, figsize=(12, 8), sharex="all")
    panels = axes.reshape(-1)

    # zip stops at the shorter sequence, so fewer than four outputs is not an error.
    for axis, output in zip(panels, outputs):
        for scenario in scenarios:
            results_df = results[(output, scenario)]
            if scenario == 0:
                scenario_label = "baseline"
            else:
                scenario_label = project.param_set.scenarios[scenario - 1]["description"]
            axis.plot(
                results_df.index,
                results_df[0.500],
                color=colours[scenario],
                label=scenario_label,
            )

        # Targets do not depend on the scenario, so they are drawn once per panel,
        # after the scenario lines.
        if output in all_targets and len(all_targets[output]) > 0:
            all_targets[output].plot.line(
                ax=axis,
                linewidth=0.,
                markersize=8.,
                marker="o",
                markerfacecolor="w",
                markeredgecolor="w",
                alpha=0.2,
                label="_nolegend_",
                zorder=11,
            )
        if output in targets:
            targets[output].plot.line(
                ax=axis,
                linewidth=0.,
                markersize=5.,
                marker="o",
                markerfacecolor="k",
                markeredgecolor="k",
                label="_nolegend_",
                zorder=12,
            )

        axis.tick_params(axis="x", labelrotation=45)
        # Limits are set after all plotting so that nothing drawn later overrides them.
        axis.set_xlim(left=left_date, right=right_date)
        # Fall back to the raw output name when no title has been registered.
        axis.set_title(title_lookup.get(output, output))

    if len(outputs) > 0:
        panels[0].legend(loc="upper left")
    fig.tight_layout()

In [None]:
# Outputs shown in the 2 x 2 figure, one per panel in the order listed.
outputs_to_plot = ("notifications", "infection_deaths", "hospital_occupancy", "icu_occupancy")
plot_outputs(
    outputs=outputs_to_plot,
    left_date=plot_left_date,
    right_date=plot_right_date,
)