In [None]:
# Import packages
import os
from matplotlib import pyplot as plt

# Import AuTuMN modules
from autumn.settings import Models, Region
from autumn.settings.folders import OUTPUT_DATA_PATH
from autumn.tools.project import get_project
from autumn.tools import db
from autumn.tools.plots.calibration.plots import calculate_r_hats, get_output_from_run_id
from autumn.tools.plots.uncertainty.plots import _plot_uncertainty
from autumn.tools.plots.plotter.base_plotter import COLOR_THEME
from autumn.tools.plots.utils import get_plot_text_dict, change_xaxis_to_date, REF_DATE, ALPHAS, COLORS, _apply_transparency
from autumn.dashboards.calibration_results.plots import get_uncertainty_df


In [None]:
# Specify model details: which model/region calibration to analyse, and the name of
# the calibration run sub-directory (under the project's "calibrate" output folder).
model, region = Models.COVID_19, Region.MANILA
dirname = "2021-08-11"  # run directory name, joined onto the project calibration dir below

In [None]:
# Get the relevant project and locate its calibration outputs for the selected run.
project = get_project(model, region)
project_calib_dir = os.path.join(
    OUTPUT_DATA_PATH, "calibrate", project.model_name, project.region_name
)
calib_path = os.path.join(project_calib_dir, dirname)

# Fail early with the full path if the run directory is missing (e.g. a mistyped
# `dirname`), instead of erroring somewhere inside the loaders.
if not os.path.isdir(calib_path):
    raise FileNotFoundError(f"Calibration output directory not found: {calib_path}")

# Load the MCMC run tables and MCMC parameter tables stored under calib_path.
mcmc_tables = db.load.load_mcmc_tables(calib_path)
mcmc_params = db.load.load_mcmc_params_tables(calib_path)

# Output uncertainty table (presumably per-scenario quantiles — confirm) used by the
# "median" and "uncertainty" plotting modes.
uncertainty_df = get_uncertainty_df(calib_path, mcmc_tables, project.plots)

In [None]:
# R_hat convergence diagnostics from the MCMC chains (presumably keyed by parameter name).
# burn_in=0 means no initial iterations are discarded before computing the statistic.
r_hats = calculate_r_hats(mcmc_params, mcmc_tables, burn_in=0)
for param_name, r_hat in r_hats.items():
    print(f"{param_name}: {r_hat}")

In [None]:
# Plot outputs for each selected scenario.
# output_type controls what is drawn per scenario:
#   "MLE"         -> the maximum-likelihood run's trajectory (baseline scenario 0 dotted)
#   "median"      -> _plot_uncertainty without the uncertainty overlay
#   "uncertainty" -> _plot_uncertainty with the uncertainty overlay (transparent bands)
output_type = "uncertainty"  # one of ["MLE", "median", "uncertainty"]
supported_output_types = ("MLE", "median", "uncertainty")
if output_type not in supported_output_types:
    # Validate once up front rather than reporting per scenario and drawing an empty figure.
    raise ValueError(
        f"Unsupported output_type {output_type!r}; use one of {supported_output_types}"
    )

output_name = "notifications"
scenario_list = [0, 1, 2, 3]  # scenario indices to plot; 0 is the baseline
# x-axis window in model time; presumably days since REF_DATE, since the axis is
# converted to dates with change_xaxis_to_date(axis, REF_DATE) below — confirm.
x_min, x_max = 580, 720
title = get_plot_text_dict(output_name)
title_fontsize = 18
label_font_size = 15
linewidth = 3

# Apply the style before creating the figure so both figure and axes pick it up.
# Note: plt.style.use changes global rcParams for the rest of the kernel session.
plt.style.use("ggplot")
fig, axis = plt.subplots(figsize=(12, 8))

# Prepare one transparency-adjusted colour set per plotted scenario for the uncertainty plot.
n_scenarios_to_plot = len(scenario_list)
uncertainty_colors = _apply_transparency(COLORS[:n_scenarios_to_plot], ALPHAS[:n_scenarios_to_plot])

# Derived output tables are only needed to look up the MLE run's trajectory.
if output_type == "MLE":
    derived_output_tables = db.load.load_derived_output_tables(calib_path, column=output_name)

for i, scenario in enumerate(scenario_list):
    if output_type == "MLE":
        # Draw the baseline dotted so it stands apart from the comparison scenarios.
        linestyle = "dotted" if scenario == 0 else "solid"
        times, values = get_output_from_run_id(output_name, mcmc_tables, derived_output_tables, "MLE", scenario)
        axis.plot(times, values, color=COLOR_THEME[scenario], linestyle=linestyle, linewidth=linewidth)
    elif output_type == "median":
        # Without the overlay only the last colour is presumably used (confirm against
        # _plot_uncertainty); the first three entries are explicit placeholders rather
        # than IPython's `_` output cache, which is undefined outside IPython.
        _plot_uncertainty(
            axis,
            uncertainty_df,
            output_name,
            scenario,
            x_max,
            x_min,
            [None, None, None, COLOR_THEME[scenario]],
            overlay_uncertainty=False,
            start_quantile=0,
            zorder=scenario + 1,
        )
    else:  # "uncertainty" (output_type was validated above)
        scenario_colors = uncertainty_colors[i]
        _plot_uncertainty(
            axis,
            uncertainty_df,
            output_name,
            scenario,
            x_max,
            x_min,
            scenario_colors,
            overlay_uncertainty=True,
            start_quantile=0,
            zorder=scenario + 1,
        )

axis.set_xlim((x_min, x_max))
axis.set_title(title, fontsize=title_fontsize)
# tick_params also stores the label size for ticks matplotlib creates later, unlike
# plt.setp, which only updates the tick label objects that exist right now.
axis.tick_params(axis="both", labelsize=label_font_size)
change_xaxis_to_date(axis, REF_DATE)