In [None]:
# Import packages
import os
from matplotlib import pyplot as plt
import pandas as pd
import datetime
import matplotlib as mpl


# Import AuTuMN modules
from autumn.settings import Models, Region
from autumn.settings.folders import OUTPUT_DATA_PATH
from autumn.tools.project import get_project
from autumn.tools import db
from autumn.tools.plots.calibration.plots import calculate_r_hats, get_output_from_run_id
from autumn.tools.plots.uncertainty.plots import _plot_uncertainty, _get_target_values
from autumn.tools.plots.plotter.base_plotter import COLOR_THEME
from autumn.tools.plots.utils import get_plot_text_dict, change_xaxis_to_date, REF_DATE, ALPHAS, COLORS, _apply_transparency, _plot_targets_to_axis

from autumn.dashboards.calibration_results.plots import get_uncertainty_df

In [None]:
# Model, region and calibration run to analyse; `dirname` is the run's sub-folder
# under <OUTPUT_DATA_PATH>/calibrate/<model>/<region>/ (see the loading cell below).
model, region, dirname = Models.COVID_19, Region.SRI_LANKA, "2022-08-02"

In [None]:
# Locate this run's calibration outputs and load the MCMC tables from them
project = get_project(model, region)
project_calib_dir = os.path.join(OUTPUT_DATA_PATH, "calibrate", project.model_name, project.region_name)
calib_path = os.path.join(project_calib_dir, dirname)

mcmc_tables = db.load.load_mcmc_tables(calib_path)
mcmc_params = db.load.load_mcmc_params_tables(calib_path)

# Uncertainty data (has a 'scenario' column); consumed by the plotting cells below
uncertainty_df = get_uncertainty_df(calib_path, mcmc_tables, project.plots)
scenario_list = uncertainty_df["scenario"].unique()

# Create outputs/<model>_<region>_<dirname>/ plus one sub-folder per output category
output_dir = f"{model}_{region}_{dirname}"
base_dir = os.path.join("outputs", output_dir)
dirs_to_make = ["calibration", "MLE", "median", "uncertainty", "csv_files"]
for target_dir in [base_dir, *(os.path.join(base_dir, sub_name) for sub_name in dirs_to_make)]:
    os.makedirs(target_dir, exist_ok=True)

In [None]:
## R-hat convergence diagnostics (disabled; uncomment the lines below to print one value per key)
# r_hats = calculate_r_hats(mcmc_params, mcmc_tables, burn_in=0)
# for key, value in r_hats.items():
#     print(f"{key}: {value}")

In [None]:
# Plot title for each derived output. plot_outputs() looks the title up here,
# so every output_name passed to it must have an entry.
titles = {
    "notifications": "Daily number of notified Covid-19 cases",
    "infection_deaths": "Daily number of Covid-19 deaths",
    "accum_deaths": "Cumulative number of Covid-19 deaths",
    "incidence": "Daily incidence (incl. asymptomatics and undetected)",
    "hospital_occupancy": "Hospital beds occupied by Covid-19 patients",
    "hospital_admissions": "Daily Covid-19 patients admitted to hospitals",
    "icu_occupancy": "ICU beds occupied by Covid-19 patients",
    "cdr": "Proportion detected among symptomatics",
    "proportion_vaccinated": "Proportion vaccinated",
    "prop_incidence_strain_delta": "Proportion of Delta variant in new cases",
    "prop_incidence_strain_alpha_beta": "Proportion of Alpha variant in new cases",
    # Capitalised to match every other title (was "recovered proportion")
    "prop_ever_infected": "Recovered proportion",
    "prop_detected_traced": "Proportion of cases contact traced",
}

def plot_outputs(output_type, output_name, scenario_list, sc_linestyles, sc_colors, show_v_lines=False, x_min=590, x_max=775):
    """Plot one derived output for several scenarios on a new figure.

    Relies on notebook globals: ``titles``, ``calib_path``, ``mcmc_tables`` and ``uncertainty_df``.

    Args:
        output_type: one of
            - "MLE": the "MLE" run of each scenario, one line per scenario;
            - "median": the central (median) line from the uncertainty data only;
            - "uncertainty": quantile bands plus central line from the uncertainty data.
        output_name: derived output to plot; must be a key of ``titles``.
        scenario_list: scenario indices to plot, in drawing order.
        sc_linestyles: indexable by scenario index -> matplotlib linestyle.
        sc_colors: indexable by scenario index -> matplotlib colour.
        show_v_lines: currently unused; kept for call compatibility.
        x_min, x_max: model-time window shown on the x-axis (converted to dates relative to REF_DATE).

    Returns:
        (axis, quantiles, times) for the LAST scenario plotted. ``quantiles`` maps quantile
        level -> list of values (0 for "MLE"). Both are None when scenario_list is empty.

    Raises:
        ValueError: if output_type is not supported.
    """
    supported_output_types = ("MLE", "median", "uncertainty")
    if output_type not in supported_output_types:
        # Previously this printed a message and then crashed on the return statement.
        raise ValueError(f"Please use supported output_type option: {supported_output_types}")

    title = titles[output_name]
    title_fontsize = 24
    label_font_size = 24
    linewidth = 3
    n_xticks = 10
    legend = True

    # initialise figure
    fig = plt.figure(figsize=(12, 8))
    plt.style.use("ggplot")
    axis = fig.add_subplot()

    # prepare colours for uncertainty bands: uncertainty_colors[i] is the colour set for the i-th scenario
    n_scenarios_to_plot = len(scenario_list)
    uncertainty_colors = _apply_transparency(COLORS[:n_scenarios_to_plot], ALPHAS[:n_scenarios_to_plot])

    if output_type == "MLE":
        derived_output_tables = db.load.load_derived_output_tables(calib_path, column=output_name)

    # Defaults so the return value is defined even if scenario_list is empty.
    times, quantiles = None, None
    for i, scenario in enumerate(scenario_list):
        linestyle = sc_linestyles[scenario]
        color = sc_colors[scenario]

        if output_type == "MLE":
            times, values = get_output_from_run_id(output_name, mcmc_tables, derived_output_tables, "MLE", scenario)
            axis.plot(times, values, color=color, linestyle=linestyle, linewidth=linewidth)
            quantiles = 0  # a single run has no quantiles
        elif output_type == "median":
            # Only the 4th colour entry is meaningful here (presumably the central-line colour used by
            # _plot_uncertainty); the first three are placeholders, so band overlay is disabled.
            # NOTE(review): the original passed IPython's `_` (last-output cache) as placeholders and
            # overlay_uncertainty=True; confirm overlay_uncertainty=False suppresses only the bands.
            # The return value is captured so `times`/`quantiles` are bound for the return statement.
            times, quantiles = _plot_uncertainty(
                axis,
                uncertainty_df,
                output_name,
                scenario,
                x_max,
                x_min,
                [None, None, None, color],
                overlay_uncertainty=False,
                start_quantile=0,
                zorder=scenario + 1,
                linestyle=linestyle,
                linewidth=linewidth,
            )
        else:  # output_type == "uncertainty" (validated above)
            times, quantiles = _plot_uncertainty(
                axis,
                uncertainty_df,
                output_name,
                scenario,
                x_max,
                x_min,
                uncertainty_colors[i],
                overlay_uncertainty=True,
                start_quantile=0,
                zorder=scenario + 1,
            )

    # Legend is hard-coded for the two-scenario comparison (baseline vs S4) and drawn for
    # notifications only; it needs colour sets for at least two scenarios.
    if legend and output_name == "notifications" and n_scenarios_to_plot >= 2:
        legend_elem = [
            mpl.patches.Patch(facecolor=uncertainty_colors[0][1], label='baseline'),
            mpl.patches.Patch(facecolor=uncertainty_colors[1][1], label='no vaccination (S4)'),
        ]
        axis.legend(handles=legend_elem, fontsize=16, loc="upper left")

    axis.set_xlim((x_min, x_max))
    axis.set_title(title, fontsize=title_fontsize)
    plt.setp(axis.get_yticklabels(), fontsize=label_font_size)
    plt.setp(axis.get_xticklabels(), fontsize=label_font_size)
    change_xaxis_to_date(axis, REF_DATE)
    plt.locator_params(axis="x", nbins=n_xticks)

    return axis, quantiles, times

# Scenario plots with single lines

In [None]:
output_names = ["notifications", "icu_occupancy", "infection_deaths", "hospital_occupancy", "hospital_admissions", "accum_deaths", "prop_ever_infected"]

scenario_x_min, scenario_x_max = 420, 791
scenarios_to_plot = [0, 4]
# plot_outputs looks these up by scenario index (sc_colors[scenario]), so key them by scenario
# rather than by list position; position-indexed lists break when scenario_list is not 0..n-1.
# The first scenario in scenario_list is drawn dotted, all others solid.
sc_colors = {scenario: COLOR_THEME[scenario] for scenario in scenario_list}
sc_linestyles = {
    scenario: ("dotted" if position == 0 else "solid")
    for position, scenario in enumerate(scenario_list)
}
for output_type in ["MLE"]:
    for output_name in output_names:
        plot_outputs(output_type, output_name, scenarios_to_plot, sc_linestyles, sc_colors, False, x_min=scenario_x_min, x_max=scenario_x_max)
        # To save each figure, uncomment:
        # path = os.path.join(base_dir, output_type, f"{output_name}.png")
        # plt.savefig(path)

# Uncertainty around scenarios


In [None]:
def to_date(x_value, date_str_format="%#d-%b-%Y", ref_date=datetime.date(2019, 12, 31)):
    """Convert a model time (whole days after ``ref_date``) to a formatted date string.

    Args:
        x_value: days after ``ref_date``; truncated to an int.
        date_str_format: ``strftime`` format. ``%#d`` (Windows-only) and ``%-d`` (glibc/BSD-only)
            are handled here portably as "day of month without leading zero"; ``%%`` stays a literal %.
        ref_date: date corresponding to x_value == 0 (default keeps the original hard-coded date).

    Returns:
        The formatted date string, e.g. ``to_date(5) == "5-Jan-2020"``.
    """
    date = ref_date + datetime.timedelta(days=int(x_value))
    day = str(date.day)
    # Substitute the non-portable no-padding day directives ourselves, skipping escaped "%%".
    portable_format = "%%".join(
        part.replace("%#d", day).replace("%-d", day) for part in date_str_format.split("%%")
    )
    return date.strftime(portable_format)

In [None]:
output_type = "uncertainty"
# Quantile level whose peak is reported (0.975 = upper end of the 95% range; 0.5 and 0.025 also exist).
quantile_val = 0.975
# Positions in the quantile series (not model days) searched for the peak, as in the original [1:420].
# NOTE(review): 420 coincides with scenario_x_min — confirm whether a time window was intended.
peak_search_start, peak_search_stop = 1, 420
for output_name in output_names:
    axis, quantiles, times = plot_outputs(output_type, output_name, scenarios_to_plot, sc_linestyles, sc_colors, False, x_min=scenario_x_min, x_max=scenario_x_max)
    # plot_outputs returns the quantiles/times of the LAST scenario in scenarios_to_plot,
    # so the peak reported below refers to that scenario only.
    quantile_values = list(quantiles[quantile_val])
    search_window = quantile_values[peak_search_start:peak_search_stop]
    if not search_window:
        print(f"{output_name}: no values in the peak search window; skipping")
        continue
    # Locate the maximum inside the same window that was searched (a global .index() lookup
    # could return a position outside it when values tie).
    index_quantile = peak_search_start + max(range(len(search_window)), key=search_window.__getitem__)
    quantile_max_val = quantile_values[index_quantile]
    time_quantile = to_date(times[index_quantile])
    print(output_name + " maximum quantile value:", quantile_max_val)
    print(output_name + " time corresponding to maximum:", time_quantile)

    # To save each figure, uncomment:
    # path = os.path.join(base_dir, output_type, f"{output_name}.png")
    # plt.savefig(path)