In [None]:
# Import packages
import os
from matplotlib import pyplot as plt
import pandas as pd
import math 
import datetime

import numpy as np

# Import AuTuMN modules
from autumn.settings import Models, Region
from autumn.settings.folders import OUTPUT_DATA_PATH
from autumn.tools.project import get_project
from autumn.tools import db
from autumn.tools.plots.calibration.plots import calculate_r_hats, get_output_from_run_id
from autumn.tools.plots.uncertainty.plots import _plot_uncertainty, _get_target_values
from autumn.tools.plots.plotter.base_plotter import COLOR_THEME
from autumn.tools.plots.utils import get_plot_text_dict, change_xaxis_to_date, REF_DATE, ALPHAS, COLORS, _apply_transparency, _plot_targets_to_axis, change_xaxis_to_date
from autumn.models.covid_19.stratifications.agegroup import AGEGROUP_STRATA

import matplotlib.patches as mpatches

from autumn.dashboards.calibration_results.plots import get_uncertainty_df

In [None]:
# Model/region to analyse, and the name of the calibration run folder to load
# (it is joined onto <OUTPUT_DATA_PATH>/calibrate/<model>/<region> in the next cell).
model, region = Models.COVID_19, Region.MALAYSIA
dirname = "2021-10-19"

In [None]:
# Locate this project's calibration run and load its MCMC results.
project = get_project(model, region)
project_calib_dir = os.path.join(
    OUTPUT_DATA_PATH, "calibrate", project.model_name, project.region_name
)
calib_path = os.path.join(project_calib_dir, dirname)

mcmc_tables = db.load.load_mcmc_tables(calib_path)
mcmc_params = db.load.load_mcmc_params_tables(calib_path)

# Uncertainty outputs, plus the distinct scenario ids they contain.
uncertainty_df = get_uncertainty_df(calib_path, mcmc_tables, project.plots)
scenario_list = uncertainty_df['scenario'].unique()

# Local output folder for this run, with its sub-folders.
output_dir = f"{model}_{region}_{dirname}"
base_dir = os.path.join("outputs", output_dir)
os.makedirs(base_dir, exist_ok=True)
dirs_to_make = ["median", "csv_files"]
for subdir_name in dirs_to_make:
    subdir_path = os.path.join(base_dir, subdir_name)
    os.makedirs(subdir_path, exist_ok=True)

In [None]:
# Column names of the age-stratified derived outputs, e.g. "notificationsXagegroup_0".
# Ordered by age group first, then by output type.
outputs = [
    f"{output_type}Xagegroup_{agegroup}"
    for agegroup in AGEGROUP_STRATA
    for output_type in (
        "notifications",
        "incidence",
        "hospital_admissions",
        "infection_deaths",
        "accum_deaths",
    )
]

# Map each column name to the first table returned by load_derived_output_tables for it.
do = {
    output: db.load.load_derived_output_tables(calib_path, column=output)[0]
    for output in outputs
}
    

In [None]:
# Baseline parameter set of the project, converted with to_dict().
# NOTE(review): `params` is not used by any cell visible here; confirm it is still needed.
params = project.param_set.baseline.to_dict()

def generate_pediatric_results(output_name, scenario_number, age_brackets, derived_outputs=None):
    """Sum one age-stratified derived output over a set of age groups for one scenario.

    Despite the name, this works for any age brackets; the plotting cell also uses it for adults.

    Args:
        output_name: Column-name prefix, e.g. ``"notificationsXagegroup_"``. Each age-group
            label is appended to it to form a column name.
        scenario_number: Value of the ``scenario`` column to select.
        age_brackets: Mapping of bracket label -> list of age-group labels. The age groups
            of all brackets are summed together; the bracket labels themselves are unused.
        derived_outputs: Mapping of column name -> table with ``times``, ``scenario`` and
            that column. Defaults to the notebook-level ``do`` dict.

    Returns:
        ``(time_values, summed_output)`` restricted to ``scenario_number``. ``time_values``
        is taken from the table of the last age group.

    Raises:
        ValueError: If ``age_brackets`` contains no age groups.
    """
    if derived_outputs is None:
        derived_outputs = do

    agegroups = [agegroup for groups in age_brackets.values() for agegroup in groups]
    if not agegroups:
        # Without any age group there is no table to take time values from.
        raise ValueError("age_brackets must contain at least one age group")

    summed_output = 0
    for agegroup in agegroups:
        column = f"{output_name}{agegroup}"
        table = derived_outputs[column]
        in_scenario = table["scenario"] == scenario_number
        # pandas aligns Series on their index when adding. Presumably all tables share the
        # same row index (same derived-output source); TODO confirm.
        summed_output = summed_output + table[column][in_scenario]

    time_values = table["times"][in_scenario]
    return time_values, summed_output

In [None]:
output_names = [
    "notificationsXagegroup_",
    "infection_deathsXagegroup_",
    "hospital_admissionsXagegroup_",
    "accum_deathsXagegroup_",
]

# x-axis window in model time. Presumably days since ref_date, the reference date
# passed to change_xaxis_to_date below.
scenario_x_min, scenario_x_max = 640, 791
sc_to_plot = [scenario_list[3], scenario_list[4]]
# Colours are looked up by scenario id, so each plotted scenario gets its own theme colour.
# Previously the first two scenarios' colours were used by position.
sc_colors = [COLOR_THEME[sc] for sc in sc_to_plot]
sc_linestyles = ["dashed", "solid"]
sc_labels = [f"scenario {sc}" for sc in sc_to_plot]
ref_date = datetime.date(2019, 12, 31)

# One figure per age split. Each entry holds:
#   - the age groups to sum,
#   - one panel title per entry of output_names,
#   - the y-limits for the accumulated-deaths panel.
age_splits = [
    (
        {"0-19": ["0", "5", "10", "15"]},
        ["COVID-19 daily notifications <=19y    ", "   COVID-19 related deaths<=19y",
         "hospitalisations<=19y", "accumulated deaths<=19y"],
        (30, 60),
    ),
    (
        {"20-75": ["20", "25", "30", "35", "40", "45", "50", "55", "60", "65", "70", "75"]},
        ["COVID-19 daily notifications >19y   ", "COVID-19 related deaths>19y",
         "hospitalisations>19y", "accumulated deaths>19y"],
        (20000, 33000),
    ),
]

for age_brackets, title_names, accum_deaths_ylim in age_splits:
    # Create each panel exactly once. Calling fig.add_subplot per scenario creates a new,
    # opaque Axes on top of the previous one in recent matplotlib versions, which hides
    # earlier lines and leaves the legend with only one entry.
    fig, axes = plt.subplots(1, len(output_names), figsize=(14, 5))
    for ax, output, title in zip(axes, output_names, title_names):
        for sc, color, linestyle, label in zip(sc_to_plot, sc_colors, sc_linestyles, sc_labels):
            times, values = generate_pediatric_results(output, sc, age_brackets)
            ax.plot(times, values, color=color, linestyle=linestyle, label=label)

            if output == "notificationsXagegroup_":  # print cumulative cases
                cum_cases = np.cumsum(values)
                print("cumulative cases = " + f"{max(cum_cases)}")
            print(f"max {title.strip()} ({label}) = {max(values)}")

        change_xaxis_to_date(ax, ref_date, rotation=45)
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        if output == "accum_deathsXagegroup_":
            ax.set_ylim(*accum_deaths_ylim)
        ax.set_title(title)
        ax.set_xlim(scenario_x_min, scenario_x_max)
        ax.legend(loc="best")

    fig.tight_layout()
    plt.show()


