In [None]:
# Import packages
import os
from matplotlib import pyplot as plt
import pandas as pd

# Import AuTuMN modules
from autumn.settings import Models, Region
from autumn.settings.folders import OUTPUT_DATA_PATH
from autumn.tools.project import get_project
from autumn.tools import db
from autumn.tools.plots.calibration.plots import calculate_r_hats, get_output_from_run_id
from autumn.tools.plots.uncertainty.plots import _plot_uncertainty, _get_target_values
from autumn.tools.plots.plotter.base_plotter import COLOR_THEME
from autumn.tools.plots.utils import get_plot_text_dict, change_xaxis_to_date, REF_DATE, ALPHAS, COLORS, _apply_transparency, _plot_targets_to_axis

from autumn.dashboards.calibration_results.plots import get_uncertainty_df


In [None]:
# Model, region and calibration run to analyse.
model, region = Models.COVID_19, Region.HO_CHI_MINH_CITY
# Name of the run folder under OUTPUT_DATA_PATH/calibrate/<model>/<region>.
dirname = "2021-08-20"

In [None]:
# Locate the chosen calibration run of this project on disk.
project = get_project(model, region)
project_calib_dir = os.path.join(OUTPUT_DATA_PATH, "calibrate", project.model_name, project.region_name)
calib_path = os.path.join(project_calib_dir, dirname)

# Load the MCMC run tables and their parameter tables from the run folder.
mcmc_tables = db.load.load_mcmc_tables(calib_path)
mcmc_params = db.load.load_mcmc_params_tables(calib_path)

# Output quantiles (columns: type / scenario / quantile / time / value) used by the plots below.
uncertainty_df = get_uncertainty_df(calib_path, mcmc_tables, project.plots)

In [None]:
# R-hat convergence diagnostics; burn_in=0 means no initial iterations are discarded.
r_hats = calculate_r_hats(mcmc_params, mcmc_tables, burn_in=0)
for diag_name, r_hat in r_hats.items():
    print(f"{diag_name}: {r_hat}")

In [None]:
# Figure title for each output name; plot_outputs looks its title up here.
titles = dict(
    notifications="Daily number of Covid-19 cases",
    infection_deaths="Daily number of Covid-19 deaths",
    hospital_occupancy="Hospital beds occupied by Covid-19 patients",
    icu_occupancy="ICU beds occupied by Covid-19 patients",
    cdr="Proportion detected among symptomatics",
    proportion_vaccinated="Proportion vaccinated",
)

def plot_outputs(
    output_type,
    output_name,
    scenario_list,
    sc_linestyles,
    sc_colors,
    show_v_lines=True,
    x_min=590,
    x_max=775,
):
    """Plot one model output for several scenarios on a new figure.

    Args:
        output_type: "MLE" (maximum-likelihood run read from the derived output
            tables), "median" (median line from ``uncertainty_df``) or
            "uncertainty" (median plus shaded uncertainty bands).
        output_name: Output to plot; must also be a key of ``titles``.
        scenario_list: Scenario indices to draw.
        sc_linestyles: Line style per scenario, indexed by scenario number.
        sc_colors: Line colour per scenario, indexed by scenario number.
        show_v_lines: If True, draw vertical markers at the two lockdown-release times.
        x_min, x_max: Horizontal axis limits in model time (the axis is relabelled
            as dates relative to REF_DATE).

    Returns:
        The matplotlib Axes holding the plot.

    Raises:
        ValueError: If ``output_type`` is not one of the supported options.

    Reads the notebook globals ``calib_path``, ``mcmc_tables``, ``uncertainty_df``
    and ``titles``.
    """
    # Validate before creating a figure so a typo cannot silently produce an empty plot.
    supported_types = ("MLE", "median", "uncertainty")
    if output_type not in supported_types:
        raise ValueError(
            f"Unsupported output_type {output_type!r}; expected one of {supported_types}"
        )

    # plot options
    title = titles[output_name]
    title_fontsize = 18
    label_font_size = 15
    linewidth = 3
    n_xticks = 10

    # Initialise figure; the style is set before the axes are created so they pick it up.
    fig = plt.figure(figsize=(12, 8))
    plt.style.use("ggplot")
    axis = fig.add_subplot()

    # One transparency-adjusted colour set per plotted scenario, used for the uncertainty bands.
    n_scenarios_to_plot = len(scenario_list)
    uncertainty_colors = _apply_transparency(COLORS[:n_scenarios_to_plot], ALPHAS[:n_scenarios_to_plot])

    if output_type == "MLE":
        derived_output_tables = db.load.load_derived_output_tables(calib_path, column=output_name)

    for i, scenario in enumerate(scenario_list):
        linestyle = sc_linestyles[scenario]
        color = sc_colors[scenario]

        if output_type == "MLE":
            times, values = get_output_from_run_id(output_name, mcmc_tables, derived_output_tables, "MLE", scenario)
            axis.plot(times, values, color=color, linestyle=linestyle, linewidth=linewidth)
        elif output_type == "median":
            # Only the 4th entry carries the line colour here; the first three are
            # placeholders. They used to be IPython's `_` (the last cell output),
            # which does not exist outside IPython.
            _plot_uncertainty(
                axis,
                uncertainty_df,
                output_name,
                scenario,
                x_max,
                x_min,
                [None, None, None, color],
                overlay_uncertainty=False,
                start_quantile=0,
                zorder=scenario + 1,
                linestyle=linestyle,
                linewidth=linewidth,
            )
        else:  # "uncertainty"
            _plot_uncertainty(
                axis,
                uncertainty_df,
                output_name,
                scenario,
                x_max,
                x_min,
                uncertainty_colors[i],
                overlay_uncertainty=True,
                start_quantile=0,
                zorder=scenario + 1,
            )

    axis.set_xlim((x_min, x_max))
    axis.set_title(title, fontsize=title_fontsize)
    plt.setp(axis.get_yticklabels(), fontsize=label_font_size)
    plt.setp(axis.get_xticklabels(), fontsize=label_font_size)
    change_xaxis_to_date(axis, REF_DATE)
    axis.locator_params(axis="x", nbins=n_xticks)

    if show_v_lines:
        # Lockdown-release times (model time) with their calendar dates, each drawn in its own line style.
        release_dates = {624: "15 Sep 2021", 609: "31 Aug 2021"}
        release_linestyles = ["dashdot", "solid"]
        y_max = axis.get_ylim()[1]
        for (time, date), v_linestyle in zip(release_dates.items(), release_linestyles):
            axis.vlines(time, ymin=0, ymax=y_max, linestyle=v_linestyle)
            text = f"Lockdown relaxed on {date}"
            axis.text(time - 5, .5 * y_max, text, rotation=90, fontsize=11)

    return axis

# Per-scenario line styles and colours, indexed by scenario number:
# scenario 0 is dotted black, 1-4 use the late-release style and 5-8 the
# early-release style, each group cycling through the four vaccination colours.
early_release_linestyle = "solid"
late_release_linestyle = "dashdot"
vacc_colors = ["crimson", "seagreen", "slateblue", "coral"]

sc_linestyles = (
    ["dotted"]
    + [late_release_linestyle] * len(vacc_colors)
    + [early_release_linestyle] * len(vacc_colors)
)
sc_colors = ["black", *vacc_colors, *vacc_colors]

# Plot outputs
output_type = "MLE"  # one of ["MLE", "median", "uncertainty"]
output_names = [
    "notifications",
    "infection_deaths",
    "hospital_occupancy",
    "icu_occupancy",
]
# Alternative selection: output_names = ["proportion_vaccinated"]

# Optional: regenerate the MLE and median plots for all nine scenarios.
# for output_type in ["median", "MLE"]:
#     scenario_list = list(range(9))
#     for output_name in output_names:
#         plot_outputs(output_type, output_name, scenario_list, sc_linestyles, sc_colors, True)
#         plt.savefig(f"outputs/{output_type}/{output_name}.png")
    


In [None]:
# For each vaccination variant V1-V4, compare its late-release scenario (V + 1)
# and early-release scenario (V + 5) against scenario 0, with uncertainty bands.
output_type = "uncertainty"
# savefig does not create missing folders, so make sure the target exists.
os.makedirs(f"outputs/{output_type}", exist_ok=True)
for i_comparison in range(4):
    scenario_list = [0, i_comparison + 1, i_comparison + 5]
    for output_name in output_names:
        axis = plot_outputs(output_type, output_name, scenario_list, sc_linestyles, sc_colors, True)
        path = f"outputs/{output_type}/{output_name}_scenario_V{i_comparison + 1}.png"
        axis.figure.savefig(path)

In [None]:
# Calibration fit: scenario 0 with uncertainty bands, overlaid with the target
# data points defined in project.plots for the same output.
# savefig does not create missing folders, so make sure the target exists.
os.makedirs("outputs/calibration", exist_ok=True)
for output_name in output_names + ["cdr"]:
    axis = plot_outputs("uncertainty", output_name, [0], sc_linestyles, sc_colors, False)
    path = f"outputs/calibration/{output_name}.png"

    # Keep only the target definitions that belong to this output.
    targets = {k: v for k, v in project.plots.items() if v["output_key"] == output_name}
    values, times = _get_target_values(targets, output_name)
    _plot_targets_to_axis(axis, values, times, on_uncertainty_plot=True)

    axis.figure.savefig(path)




In [None]:
def get_data(output_name, scenario_idx, quantile, source_df=None):
    """Return the time series of one quantile of one output for one scenario.

    Args:
        output_name: Value to match in the ``type`` column.
        scenario_idx: Value to match in the ``scenario`` column.
        quantile: Value to match in the ``quantile`` column.
        source_df: Frame with ``type``, ``scenario``, ``quantile``, ``time`` and
            ``value`` columns. Defaults to the notebook-level ``uncertainty_df``.

    Returns:
        ``(times, values)``: the unique times of the matching rows (numpy array)
        and their values (list). The first entry is dropped from both.
    """
    if source_df is None:
        source_df = uncertainty_df
    mask = (
        (source_df["type"] == output_name)
        & (source_df["scenario"] == scenario_idx)
        & (source_df["quantile"] == quantile)
    )
    selected = source_df[mask]
    # NOTE(review): the reason for skipping the first time point is not visible here;
    # kept unchanged from the original logic.
    times = selected.time.unique()[1:]
    values = selected["value"].tolist()[1:]
    return times, values
# Calendar date of model time 0; model times are converted to dates as days after it.
# (pd.datetime was removed in pandas 2.0, so use pd.Timestamp instead.)
COVID_BASE_DATE = pd.Timestamp("2019-12-31")

# Export the 2.5%, 50% and 97.5% quantiles of every output to one CSV per scenario (0-8).
for scenario in range(9):
    # The date column is taken from the notifications time axis.
    scenario_times = get_data("notifications", scenario, 0.025)[0]
    columns = {"date": pd.to_timedelta(scenario_times, unit="days") + COVID_BASE_DATE}
    for output in output_names:
        for quantile in [0.025, 0.50, 0.975]:
            columns[f"{output}_{quantile}"] = get_data(output, scenario, quantile)[1]

    # Build the frame in one go rather than inserting columns one at a time.
    df = pd.DataFrame(columns)
    df.to_csv(f"outputs_scenario_{scenario}.csv")
            
