In [None]:
# Import packages
import os
from matplotlib import pyplot as plt
import pandas as pd
import datetime
import matplotlib as mpl

# Import AuTuMN modules
from autumn.settings import Models, Region
from autumn.settings.folders import OUTPUT_DATA_PATH
from autumn.tools.project import get_project
from autumn.tools import db
from autumn.tools.plots.calibration.plots import calculate_r_hats, get_output_from_run_id
from autumn.tools.plots.uncertainty.plots import _plot_uncertainty, _get_target_values
from autumn.tools.plots.plotter.base_plotter import COLOR_THEME
from autumn.tools.plots.utils import get_plot_text_dict, change_xaxis_to_date, REF_DATE, ALPHAS, COLORS, _apply_transparency, _plot_targets_to_axis

from autumn.dashboards.calibration_results.plots import get_uncertainty_df

In [None]:
def get_uncertainty_data(directory_name, model=Models.COVID_19, region=Region.SRI_LANKA):
    """Load the outputs of one calibration run directory.

    The run is looked up under
    ``OUTPUT_DATA_PATH/calibrate/<project.model_name>/<project.region_name>/<directory_name>``.

    Args:
        directory_name: Name of the calibration run directory.
        model: Model of the project to load (default: COVID-19).
        region: Region of the project to load (default: Sri Lanka).

    Returns:
        Tuple ``(scenario_list, uncertainty_df, mcmc_tables, calib_path)``:
        ``scenario_list`` holds the unique values of ``uncertainty_df["scenario"]``,
        ``uncertainty_df`` is built by ``get_uncertainty_df`` from the MCMC tables
        and the project's plot configuration, and ``calib_path`` is the full path
        of the run directory.
    """
    project = get_project(model, region)
    calib_path = os.path.join(
        OUTPUT_DATA_PATH, "calibrate", project.model_name, project.region_name, directory_name
    )

    # Only the MCMC run tables are needed; the parameter tables were never used.
    mcmc_tables = db.load.load_mcmc_tables(calib_path)

    uncertainty_df = get_uncertainty_df(calib_path, mcmc_tables, project.plots)
    scenario_list = uncertainty_df["scenario"].unique()
    return scenario_list, uncertainty_df, mcmc_tables, calib_path

In [None]:
# Plot title for each derived output, keyed by output name. plot_outputs reads
# its title from here, so every output it is asked to draw needs an entry.
# NOTE(review): "prop_ever_infected" is titled "recovered proportion" (lowercase,
# and "recovered" is not the same as "ever infected"), and the alpha_beta entry
# names only Alpha — confirm the intended wording.
titles = dict(
    notifications="Daily number of notified Covid-19 cases",
    infection_deaths="Daily number of Covid-19 deaths",
    accum_deaths="Cumulative number of Covid-19 deaths",
    incidence="Daily incidence (incl. asymptomatics and undetected)",
    hospital_occupancy="Hospital beds occupied by Covid-19 patients",
    icu_occupancy="ICU beds occupied by Covid-19 patients",
    cdr="Proportion detected among symptomatics",
    proportion_vaccinated="Proportion vaccinated",
    prop_incidence_strain_delta="Proportion of Delta variant in new cases",
    prop_incidence_strain_alpha_beta="Proportion of Alpha variant in new cases",
    prop_ever_infected="recovered proportion",
    prop_detected_traced="Proportion of cases contact traced",
)

def plot_outputs(output_type, output_name, scenario_to_plot, dir_names, show_v_lines=False, x_min=590, x_max=775):
    """Overlay one derived output from several calibration runs on a single axis.

    Results from every directory in ``dir_names`` (loaded with
    ``get_uncertainty_data``) are drawn on the same, newly created axis. The run
    stored in ``"2022-06-06"`` uses the first uncertainty colour set (legend
    label "contact matrix China"); every other run uses the second set
    ("contact matrix Hong Kong").

    Args:
        output_type: "MLE" (draws the "MLE" run via ``get_output_from_run_id``),
            "median" or "uncertainty" (both delegate to ``_plot_uncertainty``).
        output_name: Derived output to plot; must be a key of ``titles``.
        scenario_to_plot: Scenario indices to draw for each run.
        dir_names: Calibration run directory names to overlay.
        show_v_lines: Unused; kept so existing calls keep working.
        x_min, x_max: x-axis limits, in model time.

    Returns:
        ``(axis, quantiles, times)``. ``times`` and ``quantiles`` come from the last
        scenario plotted for the last directory (``quantiles`` is 0 for "MLE");
        both are None if nothing was plotted.

    Raises:
        ValueError: if ``output_type`` is not a supported option.
    """
    supported_output_types = ("MLE", "median", "uncertainty")
    if output_type not in supported_output_types:
        # Fail fast: previously this only printed and then crashed at `return`
        # because `times`/`quantiles` were never bound.
        raise ValueError(
            f"Unsupported output_type {output_type!r}; use one of {supported_output_types}"
        )

    title = titles[output_name]
    title_fontsize = 24
    label_font_size = 24
    linewidth = 3
    n_xticks = 10
    # All runs are drawn with solid lines; runs are distinguished by colour only.
    linestyle = "solid"
    # Calibration run whose results are labelled "contact matrix China".
    china_dir_name = "2022-06-06"

    # One colour set per contact-matrix variant (loop-invariant, so computed once).
    uncertainty_colors = _apply_transparency(COLORS[:2], ALPHAS[:2])

    # Initialise figure
    fig = plt.figure(figsize=(12, 8))
    plt.style.use("ggplot")
    axis = fig.add_subplot()

    # Bound up front so the return is valid on every path (e.g. "median", or no runs).
    times, quantiles = None, None

    for dir_name in dir_names:
        # Get uncertainty data for this calibration run
        scenario_list, uncertainty_df, mcmc_tables, calib_path = get_uncertainty_data(dir_name)
        sc_colors = [COLOR_THEME[i] for i in scenario_list]
        scenario_colors = uncertainty_colors[0] if dir_name == china_dir_name else uncertainty_colors[1]

        if output_type == "MLE":
            derived_output_tables = db.load.load_derived_output_tables(calib_path, column=output_name)

        for scenario in scenario_to_plot:
            color = sc_colors[scenario]

            if output_type == "MLE":
                times, values = get_output_from_run_id(
                    output_name, mcmc_tables, derived_output_tables, "MLE", scenario
                )
                axis.plot(times, values, color=color, linestyle=linestyle, linewidth=linewidth)
                quantiles = 0
            elif output_type == "median":
                # The original passed IPython's `_` (last cell output) for the first
                # three colours. "none" is matplotlib's transparent colour, so only
                # the last colour is visible.
                # NOTE(review): assumes _plot_uncertainty uses index 3 for the median line.
                times, quantiles = _plot_uncertainty(
                    axis,
                    uncertainty_df,
                    output_name,
                    scenario,
                    x_max,
                    x_min,
                    ["none", "none", "none", color],
                    overlay_uncertainty=True,
                    start_quantile=0,
                    zorder=scenario + 1,
                    linestyle=linestyle,
                    linewidth=linewidth,
                )
            else:  # "uncertainty"
                times, quantiles = _plot_uncertainty(
                    axis,
                    uncertainty_df,
                    output_name,
                    scenario,
                    x_max,
                    x_min,
                    scenario_colors,
                    overlay_uncertainty=True,
                    start_quantile=0,
                    zorder=scenario + 1,
                    linestyle=linestyle,
                )

    if output_name == "notifications":
        legend_elem = [
            mpl.patches.Patch(facecolor=uncertainty_colors[0][1], label="contact matrix China"),
            mpl.patches.Patch(facecolor=uncertainty_colors[1][1], label="contact matrix Hong Kong"),
        ]
        axis.legend(handles=legend_elem, fontsize=16, loc="upper left")

    axis.set_xlim((x_min, x_max))
    axis.set_title(title, fontsize=title_fontsize)
    # tick_params also styles ticks created later (e.g. by locator_params below),
    # unlike plt.setp on the tick labels that exist right now.
    axis.tick_params(axis="both", labelsize=label_font_size)
    change_xaxis_to_date(axis, REF_DATE)
    axis.locator_params(axis="x", nbins=n_xticks)

    return axis, quantiles, times

# Scenario plots with single lines

In [None]:
# Derived outputs to plot, and the calibration run directories whose results are overlaid.
output_names = [
    "notifications",
    "icu_occupancy",
    "infection_deaths",
    "hospital_occupancy",
    "accum_deaths",
    "prop_ever_infected",
]
dirnames = [
    "2022-06-06",
    "2022-06-28",
]

# x-axis limits (model time) intended for the scenario plots.
scenario_x_min = 420
scenario_x_max = 791



# Uncertainty around scenarios


In [None]:
def to_date(x_value, date_str_format="%#d-%b-%Y", ref_date=datetime.date(2019, 12, 31)):
    """Convert a model time value (days after ``ref_date``) to a formatted date string.

    Args:
        x_value: Number of days after ``ref_date``; truncated with ``int()``.
        date_str_format: ``strftime`` format. The non-padded day directives
            ``%#d`` (Windows-only) and ``%-d`` (glibc-only) are replaced with the
            day number before formatting, so they work on every platform.
            A literal ``%%`` is left untouched.
        ref_date: Date corresponding to time 0 (default 2019-12-31).

    Returns:
        The formatted date, e.g. ``to_date(1) == "1-Jan-2020"``.
    """
    import re

    date = ref_date + datetime.timedelta(days=int(x_value))
    day = str(date.day)
    # Match "%%" first so an escaped percent sign is never mistaken for a directive.
    portable_format = re.sub(
        r"%%|%[#-]d",
        lambda match: match.group() if match.group() == "%%" else day,
        date_str_format,
    )
    return date.strftime(portable_format)

# Calibration plots

In [None]:
calibration_x_min, calibration_x_max = 420, 680

# Specify model details
model = Models.COVID_19
region = Region.SRI_LANKA
project = get_project(model, region)

# "prop_ever_infected" is already in output_names; dict.fromkeys removes the
# duplicate while keeping order, so no output is plotted twice.
calibration_outputs = list(dict.fromkeys(output_names + ["cdr", "prop_ever_infected"]))

for output_name in calibration_outputs:
    axis, quantiles, times = plot_outputs(
        "uncertainty",
        output_name,
        [0],
        dirnames,
        False,
        x_min=calibration_x_min,
        x_max=calibration_x_max,
    )

    # Overlay the calibration targets whose output_key matches this output.
    output_targets = {
        key: target for key, target in project.plots.items() if target["output_key"] == output_name
    }
    target_values, target_times = _get_target_values(output_targets, output_name)
    axis.scatter(target_times, target_values, marker="o", color="black", s=10, zorder=999)
    _plot_targets_to_axis(axis, target_values, target_times, on_uncertainty_plot=True)

    # Render each figure as it is finished instead of accumulating open figures in the loop.
    plt.show()