In [None]:
# Import packages
import os
from matplotlib import pyplot as plt
import pandas as pd
import math 

import numpy as np

# Import AuTuMN modules
from autumn.settings import Models, Region
from autumn.settings.folders import OUTPUT_DATA_PATH
from autumn.tools.project import get_project
from autumn.tools import db
from autumn.tools.plots.calibration.plots import calculate_r_hats, get_output_from_run_id
from autumn.tools.plots.uncertainty.plots import _plot_uncertainty, _get_target_values
from autumn.tools.plots.plotter.base_plotter import COLOR_THEME
from autumn.tools.plots.utils import get_plot_text_dict, change_xaxis_to_date, REF_DATE, ALPHAS, COLORS, _apply_transparency, _plot_targets_to_axis
from autumn.models.covid_19.stratifications.agegroup import AGEGROUP_STRATA

import matplotlib.patches as mpatches

from autumn.dashboards.calibration_results.plots import get_uncertainty_df


In [None]:
# Which project (model + region) and which calibration run folder to analyse
model, region = Models.COVID_19, Region.MANILA
dirname = "2021-09-21"  # name of the calibration run sub-folder

In [None]:
# Locate the calibration outputs of the selected project and run
project = get_project(model, region)
project_calib_dir = os.path.join(OUTPUT_DATA_PATH, "calibrate", project.model_name, project.region_name)
calib_path = os.path.join(project_calib_dir, dirname)

# Load the MCMC tables stored under that run folder
mcmc_tables = db.load.load_mcmc_tables(calib_path)
mcmc_params = db.load.load_mcmc_params_tables(calib_path)

# Uncertainty dataframe and the scenario indices it contains
uncertainty_df = get_uncertainty_df(calib_path, mcmc_tables, project.plots)
scenario_list = uncertainty_df["scenario"].unique()

# Create the output folder and its sub-folders (no error if they already exist)
output_dir = f"{model}_{region}_{dirname}"
base_dir = os.path.join("outputs", output_dir)
dirs_to_make = ["MLE", "median", "csv_files"]
for folder in [base_dir] + [os.path.join(base_dir, sub_dir) for sub_dir in dirs_to_make]:
    os.makedirs(folder, exist_ok=True)

In [None]:
# Figure title for each output name that plot_outputs can be asked to draw
titles = {
    # daily / point-in-time outputs
    "notifications": "Daily number of notified Covid-19 cases",
    "infection_deaths": "Daily number of Covid-19 deaths",
    "accum_deaths": "Cumulative number of Covid-19 deaths",
    "incidence": "Daily incidence (incl. asymptomatics and undetected)",
    "hospital_occupancy": "Hospital beds occupied by Covid-19 patients",
    "icu_occupancy": "ICU beds occupied by Covid-19 patients",
    "new_hospital_admissions": "New hospital admissions",
    # proportions
    "cdr": "Proportion detected among symptomatics",
    "proportion_vaccinated": "Proportion vaccinated",
    "prop_incidence_strain_delta": "Proportion of Delta variant in new cases",
    # accumulated outputs
    "accum_notifications": "Cumulated Covid-19 notifications",
    "accum_incidence": "Cumulated SARS-CoV-2 infections",
}

def plot_outputs(output_type, output_name, scenario_list, sc_linestyles, sc_colors, show_v_lines=False, x_min=590, x_max=775):
    """
    Draw one model output for several scenarios on a new 12x8 figure.

    Reads the module-level `titles`, `calib_path`, `mcmc_tables` and `uncertainty_df`.

    Args:
        output_type: "MLE" (one line per scenario from the MLE run), "median"
            (median line only) or "uncertainty" (overlaid uncertainty bands).
        output_name: name of the output to plot; must be a key of `titles`.
        scenario_list: scenario indices to plot.
        sc_linestyles: line styles, indexed by scenario index.
        sc_colors: line colours, indexed by scenario index.
        show_v_lines: if True, draw a vertical line for each entry of the local
            `release_dates` dict (currently empty, so nothing is drawn).
        x_min: lower x-axis limit (model time).
        x_max: upper x-axis limit (model time).

    Returns:
        The matplotlib axis holding the plot.

    Raises:
        ValueError: if `output_type` is not one of the supported options.
    """
    supported_types = ("MLE", "median", "uncertainty")
    if output_type not in supported_types:
        # Fail once and early rather than printing a message per scenario and crashing later
        raise ValueError(f"Unsupported output_type '{output_type}', please use one of {supported_types}")

    # plot options
    title = titles[output_name]
    title_fontsize = 18
    label_font_size = 15
    linewidth = 3
    n_xticks = 10

    # initialise figure
    fig = plt.figure(figsize=(12, 8))
    plt.style.use("ggplot")
    axis = fig.add_subplot()

    # prepare colors for uncertainty
    n_scenarios_to_plot = len(scenario_list)
    uncertainty_colors = _apply_transparency(COLORS[:n_scenarios_to_plot], ALPHAS[:n_scenarios_to_plot])

    # last plotted value of each scenario (only collected for MLE lines)
    final_values = []

    if output_type == "MLE":
        # loaded once and shared by all scenarios
        derived_output_tables = db.load.load_derived_output_tables(calib_path, column=output_name)

    for i, scenario in enumerate(scenario_list):
        linestyle = sc_linestyles[scenario]
        color = sc_colors[scenario]

        if output_type == "MLE":
            times, values = get_output_from_run_id(output_name, mcmc_tables, derived_output_tables, "MLE", scenario)
            axis.plot(times, values, color=color, linestyle=linestyle, linewidth=linewidth)
            final_values.append(list(values)[-1])
        elif output_type == "median":
            # Only the 4th colour is meaningful here; the first three are explicit placeholders
            # (previously the bare name `_`, which only resolved thanks to IPython's last-output variable).
            # NOTE(review): presumably _plot_uncertainty ignores them when overlay_uncertainty=False - confirm.
            _plot_uncertainty(
                axis,
                uncertainty_df,
                output_name,
                scenario,
                x_max,
                x_min,
                [None, None, None, color],
                overlay_uncertainty=False,
                start_quantile=0,
                zorder=scenario + 1,
                linestyle=linestyle,
                linewidth=linewidth,
            )
        else:  # "uncertainty"
            _plot_uncertainty(
                axis,
                uncertainty_df,
                output_name,
                scenario,
                x_max,
                x_min,
                uncertainty_colors[i],
                overlay_uncertainty=True,
                start_quantile=0,
                zorder=scenario + 1,
            )

    print(output_name)
    # Difference between the final values of the first two plotted scenarios.
    # Guarded because final values only exist for MLE plots with at least two scenarios.
    if len(final_values) >= 2:
        print(final_values[0] - final_values[1])

    axis.set_xlim((x_min, x_max))
    axis.set_title(title, fontsize=title_fontsize)
    plt.setp(axis.get_yticklabels(), fontsize=label_font_size)
    plt.setp(axis.get_xticklabels(), fontsize=label_font_size)
    change_xaxis_to_date(axis, REF_DATE)
    plt.locator_params(axis="x", nbins=n_xticks)

    if show_v_lines:
        release_dates = {}  # model time -> date label; empty, so no line is currently drawn
        y_max = plt.gca().get_ylim()[1]
        linestyles = ["dashdot", "solid"]
        for line_idx, (time, date) in enumerate(release_dates.items()):
            # cycle through the styles so that more than two dates cannot run past the list
            plt.vlines(time, ymin=0, ymax=y_max, linestyle=linestyles[line_idx % len(linestyles)])
            text = f"Lockdown relaxed on {date}"
            plt.text(time - 5, .5*y_max, text, rotation=90, fontsize=11)

    return axis


# Scenario plots with single lines

In [None]:
# Outputs to draw and the model-time window shown on the x axis
output_names = ["accum_notifications", "accum_deaths", "accum_incidence", "icu_occupancy"]
scenario_x_min, scenario_x_max = 610, 920

# The first entry of scenario_list is left out of these figures
sc_to_plot = scenario_list[1:]
legend = ["Without vaccine", "With vaccine"]
lift_time = 731  # 31 Dec 2021
text_font = 14

# Per-scenario styling; plot_outputs looks these up by scenario index
sc_colors = [COLOR_THEME[i] for i in scenario_list]
sc_linestyles = ["dotted"] + ["solid"] * (len(scenario_list) - 1)

for output_type in ["MLE"]:
    for output_name in output_names:
        ax = plot_outputs(output_type, output_name, sc_to_plot, sc_linestyles, sc_colors, False, x_min=scenario_x_min, x_max=scenario_x_max)

        # Same legend for every output, except that it is moved to the right for accumulated deaths
        legend_options = {"labels": legend, "fontsize": text_font, "facecolor": "white"}
        if output_name == "accum_deaths":
            legend_options["loc"] = "right"
        plt.legend(**legend_options)

        # Dashed vertical line at the time restrictions are lifted, slightly taller than the current y range
        ymax = plt.gca().get_ylim()[1]
        plt.vlines(x=lift_time, ymin=0, ymax=1.05 * ymax, linestyle="dashed")

        # Label centred over the period between the left edge of the plot and the lift time
        plt.text(x=(scenario_x_min + lift_time) / 2., y=1. * ymax, s="Vaccination phase", ha="center", fontsize=text_font)

        # The vertical label is placed lower on accumulated outputs than on the others
        s_y = ymax * .4 if output_name.startswith("accum") else ymax
        plt.text(x=lift_time + 3, y=s_y, s="Restrictions lifted", fontsize=text_font, rotation=90, va="top")

        path = os.path.join(base_dir, output_type, f"{output_name}.png")
        plt.savefig(path)
        

# Make Adverse Effects figures

In [None]:
# Names of the derived outputs needed for the adverse-effects figures:
# overall hospital admissions plus, per age group, vaccinations and new hospital admissions.
# (Per-age tts_cases / myocarditis_cases outputs are not loaded.)
outputs = ["hospital_admissions"]
for agegroup in AGEGROUP_STRATA:
    outputs += [
        f"vaccinationXagegroup_{agegroup}",
        f"new_hospital_admissionsXagegroup_{agegroup}",
    ]

# Keep the first derived-output table returned for each output name
do = {
    output_name: db.load.load_derived_output_tables(calib_path, column=output_name)[0]
    for output_name in outputs
}

In [None]:
# Adverse-event rates taken from the baseline parameter set, keyed by vaccine type
# (make_ae_figure looks each one up by integer age)
params = project.param_set.baseline.to_dict()
vaccination_risk = params["vaccination_risk"]
ae_risk = dict(
    AstraZeneca=vaccination_risk["tts_rate"],
    mRNA=vaccination_risk["myocarditis_rate"],
)

In [None]:
# Base font size used by make_ae_figure
text_font = 12

# Displayed age bracket -> model age-group strata summed into that bracket.
# Labels are inclusive and non-overlapping ("50-59", "60-69"), consistent with
# the younger brackets; they were previously shown as "50-60" and "60-70".
age_brackets = {
    "15-19": ["15"],
    "20-29": ["20", "25"],
    "30-39": ["30", "35"],
    "40-49": ["40", "45"],
    "50-59": ["50", "55"],
    "60-69": ["60", "65"],
    "70+": ["70", "75"],
}

# Scenario index holding the vaccination run used for each vaccine type
vacc_scenarios = {
    "mRNA": 2,
    "AstraZeneca": 2,
}

# Adverse effect shown for each vaccine type (full name, used in the figure sub-title)
adverse_effects = {
    "mRNA": "myocarditis",
    "AstraZeneca": "thrombosis with thrombocytopenia syndrome",
}

# Short names of the same adverse effects
adverse_effects_short= {
    "mRNA": "myocarditis",
    "AstraZeneca": "tts",
}

# Sub-title of the left-hand (benefit) side of the figure
left_title = "COVID-19-associated hospitalisations prevented"


def _log10_bar_length(value):
    """Return log10(value) for use as a bar length, floored at 0 for values of 1 or less."""
    # math.log raises ValueError on non-positive input (e.g. a zero adverse-event rate);
    # such values, and values below 1 whose log is negative, get a zero-length bar instead.
    return math.log10(value) if value > 1 else 0.0


def make_ae_figure(vacc_scenario, log_scale=False):
    """
    Draw and save a two-sided horizontal bar chart comparing, for each age bracket,
    the hospitalisations prevented (left, blue) with the adverse-effect cases (right, red).

    Both quantities are expressed per million of the summed `vaccinationXagegroup_*`
    output (NOTE(review): presumably vaccinations delivered - confirm units).
    Reads the module-level `age_brackets`, `vacc_scenarios`, `adverse_effects`,
    `ae_risk`, `do`, `left_title`, `text_font` and `base_dir`.

    Args:
        vacc_scenario: vaccine type, a key of `vacc_scenarios`, `adverse_effects` and `ae_risk`.
        log_scale: if True, bar lengths are log10 of the values instead of the values themselves.

    Raises:
        ValueError: if the vaccination output sums to zero for an age bracket,
            since the per-million rates are then undefined.

    The figure is saved as "<vacc_scenario>_adverse_effects[_log_scale].png" under `base_dir`.
    """
    right_title = f"Cases of {adverse_effects[vacc_scenario]}"
    scenario_idx = vacc_scenarios[vacc_scenario]

    fig = plt.figure(figsize=(10, 4))
    plt.style.use("default")
    axis = fig.add_subplot()

    h_max = 0
    delta_agegroup = 1.5 if log_scale else 7000  # width of the central column holding the age labels
    barwidth = .7
    text_offset = 0.1 if log_scale else 700  # gap between a bar end and its value label
    origin = 0
    for i, age_bracket in enumerate(age_brackets):
        # brackets are stacked from top (first) to bottom (last)
        y = len(age_brackets) - i - .5
        plt.text(x=delta_agegroup / 2, y=y, s=age_bracket, ha="center", va="center", fontsize=text_font)

        # make calculations: sum each output over the strata of this bracket
        hosp_novacc = 0
        hosp_withvacc = 0
        ae_cases = 0
        vacc = 0
        for agegroup in age_brackets[age_bracket]:
            hosp_name = f"new_hospital_admissionsXagegroup_{agegroup}"
            hosp_df = do[hosp_name]
            # scenario 1 is used as the comparator for the vaccination scenario
            hosp_novacc += hosp_df[hosp_name][hosp_df["scenario"] == 1].sum()
            hosp_withvacc += hosp_df[hosp_name][hosp_df["scenario"] == scenario_idx].sum()

            vacc_name = f"vaccinationXagegroup_{agegroup}"
            vacc_df = do[vacc_name]
            vacc_this = vacc_df[vacc_name][vacc_df["scenario"] == scenario_idx].sum()

            vacc += vacc_this
            ae_cases += vacc_this * ae_risk[vacc_scenario][int(agegroup)]

        if vacc == 0:
            # Dividing by zero would yield nan/inf and an obscure failure in int() further down
            raise ValueError(
                f"Vaccination output sums to zero for age bracket {age_bracket} "
                f"in scenario {scenario_idx}: per-million rates are undefined"
            )

        prev_hosp = (hosp_novacc - hosp_withvacc) / vacc * 1.e6
        ae_cases = ae_cases / vacc * 1.e6

        h_max = max(prev_hosp, h_max)

        if log_scale:
            h_val = _log10_bar_length(prev_hosp)
            ae_val = _log10_bar_length(ae_cases)
        else:
            h_val = prev_hosp
            ae_val = max(ae_cases, 0)  # the right-hand bar never has a negative length

        # left-hand bar: grows leftwards from the origin
        rect = mpatches.Rectangle((origin, y - barwidth/2), width=-h_val, height=barwidth, facecolor="cornflowerblue")
        axis.add_patch(rect)
        plt.text(x=-h_val - text_offset, y=y, s=int(prev_hosp), ha="right", va="center", fontsize=text_font*.7)

        # right-hand bar: grows rightwards from the far side of the age-label column
        rect = mpatches.Rectangle((delta_agegroup + origin, y - barwidth/2), width=ae_val, height=barwidth, facecolor="tab:red")
        axis.add_patch(rect)
        plt.text(x=delta_agegroup + origin + ae_val + text_offset, y=y, s=int(ae_cases), ha="left", va="center", fontsize=text_font*.7)

    if log_scale:
        h_max = _log10_bar_length(h_max)

    # main title
    axis.set_title(f"Benefit/Risk analysis with {vacc_scenario} vaccine", fontsize = text_font + 2)

    # sub-titles
    plt.text(x= - h_max / 2, y=len(age_brackets) + .3, s=left_title, ha="center", fontsize=text_font)
    plt.text(x= h_max / 2 + delta_agegroup, y=len(age_brackets) + .3, s=right_title, ha="center", fontsize=text_font)

    # x axis ticks: the axis range is the largest left-hand value rounded up to the next multiple of `magnitude`
    magnitude = 1 if log_scale else 5000
    max_val_display = math.ceil(h_max / magnitude) * magnitude

    n_ticks = 6  # ticks on each side
    tick_step = max_val_display / (n_ticks - 1)
    left_ticks = [-max_val_display + j * tick_step for j in range(n_ticks)]
    right_ticks = [delta_agegroup + j * tick_step for j in range(n_ticks)]
    x_ticks = left_ticks + right_ticks

    # Both sides carry the same labels, mirrored around the central column
    if log_scale:
        left_labels = [int(10**(-v)) for v in left_ticks]
    else:
        left_labels = [int(-v) for v in left_ticks]
    x_labels = left_labels + left_labels[::-1]
    if log_scale:
        # the two innermost ticks (10**0) are displayed as 0
        x_labels[n_ticks - 1] = x_labels[n_ticks] = 0

    plt.xticks(ticks=x_ticks, labels=x_labels)

    # x, y lims
    axis.set_xlim((-max_val_display, max_val_display + delta_agegroup))
    axis.set_ylim((0, len(age_brackets) + 1))

    # remove axes
    axis.set_frame_on(False)
    axis.axes.get_yaxis().set_visible(False)

    log_ext = "_log_scale" if log_scale else ""
    path = os.path.join(base_dir, f"{vacc_scenario}_adverse_effects{log_ext}.png")
    plt.tight_layout()
    plt.savefig(path, dpi=600)

# One figure per vaccine type, first on the linear scale and then on the log scale
for vacc_scenario in ("mRNA", "AstraZeneca"):
    for log_scale in (False, True):
        make_ae_figure(vacc_scenario, log_scale=log_scale)
    

# Dump outputs to csv files

In [None]:
# Settings for the csv export
start_time = 609  # 31 Aug 2021; earlier rows are trimmed from the csv files
csv_outputs = ["icu_occupancy"]  # outputs written to each scenario's csv file

# Columns written for each output: the MLE run and/or these uncertainty quantiles
includes_MLE = True
requested_quantiles = [0.025, 0.5, 0.975]

# To also export age-specific notifications:
# for age in [str(int(5. * i)) for i in range(16)]:
#     csv_outputs.append(f"notificationsXagegroup_{age}")

def get_uncertainty_data(output_name, scenario_idx, quantile):
    """
    Extract one quantile series from the module-level `uncertainty_df`.

    Rows are selected on their "type", "scenario" and "quantile" columns.
    Returns a (times, values) pair: the unique times and the list of values,
    each with its first entry dropped.
    """
    is_output = uncertainty_df["type"] == output_name
    is_scenario = uncertainty_df["scenario"] == scenario_idx
    is_quantile = uncertainty_df["quantile"] == quantile
    selection = uncertainty_df[is_output & is_scenario & is_quantile]

    # skip the first time point and its value
    return selection.time.unique()[1:], selection["value"].tolist()[1:]

# Reference date from which model times are counted, in days.
# pd.Timestamp replaces pd.datetime, which was deprecated in pandas 1.0 and removed in pandas 2.0.
COVID_BASE_DATE = pd.Timestamp(2019, 12, 31)
# Calendar date of the first row kept in the csv files
start_date = COVID_BASE_DATE + pd.to_timedelta(start_time, unit="days")

# The derived-output tables do not depend on the scenario: load them once per output
# instead of once per (scenario, output) pair.
mle_tables_by_output = {}
if includes_MLE:
    for output in csv_outputs:
        mle_tables_by_output[output] = db.load.load_derived_output_tables(calib_path, column=output)

# Outputs for which uncertainty quantiles are available
outputs_with_uncertainty = set(uncertainty_df["type"].unique())

# Write one csv file per scenario
for scenario in scenario_list:
    df = pd.DataFrame()

    # include a column for the date, built from the times of the median notifications series
    t, _unused_values = get_uncertainty_data("notifications", scenario, 0.5)
    df["date"] = pd.to_timedelta(t, unit="days") + (COVID_BASE_DATE)

    for output in csv_outputs:
        if includes_MLE:
            do_times, do_values = get_output_from_run_id(
                output, mcmc_tables, mle_tables_by_output[output], "MLE", scenario
            )

            # get_uncertainty_data drops the first time point, so do the same for the MLE series
            # and check that both share the same time axis
            assert list(do_times[1:]) == list(t), f"MLE and uncertainty times differ for {output}"
            do_values = list(do_values)[1:]

            name = f"{output}_MLE"
            df[name] = do_values

        if output in outputs_with_uncertainty:
            for quantile in requested_quantiles:
                _unused_times, v = get_uncertainty_data(output, scenario, quantile)
                name = f"{output}_{quantile}"
                df[name] = v

    # trim the dataframe to keep requested times only (same rows as dropping date < start_date)
    df = df[~(df.date < start_date)]

    path = os.path.join(base_dir, 'csv_files', f"outputs_scenario_{scenario}.csv")
    df.to_csv(path)
            
