In [None]:
# Import packages
import os
from matplotlib import pyplot as plt
import pandas as pd
import math 

import numpy as np

# Import AuTuMN modules
from autumn.settings import Models, Region
from autumn.settings.folders import OUTPUT_DATA_PATH
from autumn.tools.project import get_project
from autumn.tools import db
from autumn.tools.plots.calibration.plots import calculate_r_hats, get_output_from_run_id
from autumn.tools.plots.uncertainty.plots import _plot_uncertainty, _get_target_values
from autumn.tools.plots.plotter.base_plotter import COLOR_THEME
from autumn.tools.plots.utils import get_plot_text_dict, change_xaxis_to_date, REF_DATE, ALPHAS, COLORS, _apply_transparency, _plot_targets_to_axis
from autumn.models.covid_19.stratifications.agegroup import AGEGROUP_STRATA

import matplotlib.patches as mpatches

from autumn.dashboards.calibration_results.plots import get_uncertainty_df


In [None]:
# Specify model details: the model, the region and the calibration run folder to analyse
model, region, dirname = Models.COVID_19, Region.MANILA, "2021-10-20"

In [None]:
# Locate the project and the folder holding this calibration run's outputs
project = get_project(model, region)
project_calib_dir = os.path.join(OUTPUT_DATA_PATH, "calibrate", project.model_name, project.region_name)
calib_path = os.path.join(project_calib_dir, dirname)

# Load the MCMC tables and the per-scenario uncertainty (quantile) dataframe
mcmc_tables = db.load.load_mcmc_tables(calib_path)
mcmc_params = db.load.load_mcmc_params_tables(calib_path)
uncertainty_df = get_uncertainty_df(calib_path, mcmc_tables, project.plots)
scenario_list = uncertainty_df["scenario"].unique()

# Output folders: outputs/<model>_<region>_<dirname>/ plus one sub-folder per entry of dirs_to_make
output_dir = f"{model}_{region}_{dirname}"
base_dir = os.path.join("outputs", output_dir)
dirs_to_make = ["MLE", "median", "csv_files"]
for path_to_make in [base_dir] + [os.path.join(base_dir, sub_dir) for sub_dir in dirs_to_make]:
    os.makedirs(path_to_make, exist_ok=True)

In [None]:
# Figure title for each output key that plot_outputs can draw
titles = dict(
    notifications="Daily number of notified Covid-19 cases",
    infection_deaths="Daily number of Covid-19 deaths",
    accum_deaths="Cumulative number of Covid-19 deaths",
    incidence="Daily incidence (incl. asymptomatics and undetected)",
    hospital_occupancy="Hospital beds occupied by Covid-19 patients",
    icu_occupancy="ICU beds occupied by Covid-19 patients",
    new_hospital_admissions="New hospital admissions",
    cdr="Proportion detected among symptomatics",
    proportion_vaccinated="Proportion vaccinated",
    prop_incidence_strain_delta="Proportion of Delta variant in new cases",
)

def plot_outputs(output_type, output_name, scenario_list, sc_linestyles, sc_colors, show_v_lines=False, x_min=590, x_max=775, release_dates=None):
    """Plot one model output for several scenarios on a new figure and return its axis.

    Args:
        output_type: "MLE" (the run fetched by get_output_from_run_id with run id "MLE"),
            "median" (median line from uncertainty_df) or "uncertainty" (shaded quantile bands).
        output_name: Output key; must be a key of the module-level ``titles`` dict.
        scenario_list: Scenario indices to plot.
        sc_linestyles, sc_colors: Looked up by scenario index (``sc_linestyles[scenario]``),
            giving the line style / colour of each scenario.
        show_v_lines: Whether to draw vertical marker lines at ``release_dates``.
        x_min, x_max: Time window (model time units) shown on the x axis.
        release_dates: Optional mapping {model time: date label} of lockdown-relaxation dates,
            drawn only when ``show_v_lines`` is True. Defaults to no lines.

    Returns:
        The matplotlib axis the outputs were drawn on (also the current pyplot axis).

    Raises:
        ValueError: If ``output_type`` is not one of the supported options.
    """
    supported_output_types = ("MLE", "median", "uncertainty")
    if output_type not in supported_output_types:
        # Fail before creating an (otherwise empty) figure.
        raise ValueError(f"Unsupported output_type {output_type!r}; use one of {supported_output_types}")

    # plot options
    title = titles[output_name]
    title_fontsize = 18
    label_font_size = 15
    linewidth = 3
    n_xticks = 10

    # Apply the style before creating the figure so figure-level style settings take effect too.
    plt.style.use("ggplot")
    fig = plt.figure(figsize=(12, 8))
    axis = fig.add_subplot()

    # prepare colors for uncertainty bands: one transparency ramp per plotted scenario (by position)
    n_scenarios_to_plot = len(scenario_list)
    uncertainty_colors = _apply_transparency(COLORS[:n_scenarios_to_plot], ALPHAS[:n_scenarios_to_plot])

    # Only the MLE mode needs the derived-output tables; load them once for all scenarios.
    if output_type == "MLE":
        derived_output_tables = db.load.load_derived_output_tables(calib_path, column=output_name)

    for i, scenario in enumerate(scenario_list):
        linestyle = sc_linestyles[scenario]
        color = sc_colors[scenario]

        if output_type == "MLE":
            times, values = get_output_from_run_id(output_name, mcmc_tables, derived_output_tables, "MLE", scenario)
            axis.plot(times, values, color=color, linestyle=linestyle, linewidth=linewidth)
        elif output_type == "median":
            # NOTE(review): with overlay_uncertainty=False presumably only the median colour (index 3)
            # is read; fill every slot with the scenario colour instead of IPython's `_`, which is
            # undefined outside an interactive session.
            _plot_uncertainty(
                axis,
                uncertainty_df,
                output_name,
                scenario,
                x_max,
                x_min,
                [color] * 4,
                overlay_uncertainty=False,
                start_quantile=0,
                zorder=scenario + 1,
                linestyle=linestyle,
                linewidth=linewidth,
            )
        else:  # "uncertainty"
            _plot_uncertainty(
                axis,
                uncertainty_df,
                output_name,
                scenario,
                x_max,
                x_min,
                uncertainty_colors[i],
                overlay_uncertainty=True,
                start_quantile=0,
                zorder=scenario + 1,
            )

    axis.set_xlim((x_min, x_max))
    axis.set_title(title, fontsize=title_fontsize)
    plt.setp(axis.get_yticklabels(), fontsize=label_font_size)
    plt.setp(axis.get_xticklabels(), fontsize=label_font_size)
    change_xaxis_to_date(axis, REF_DATE)
    axis.locator_params(axis="x", nbins=n_xticks)

    if show_v_lines:
        y_max = axis.get_ylim()[1]
        linestyles = ["dashdot", "solid"]
        for j, (time, date) in enumerate((release_dates or {}).items()):
            # Cycle through the linestyles so more than two dates do not raise IndexError.
            axis.vlines(time, ymin=0, ymax=y_max, linestyle=linestyles[j % len(linestyles)])
            axis.text(time - 5, .5 * y_max, f"Lockdown relaxed on {date}", rotation=90, fontsize=11)

    return axis


# Scenario plots with single lines

In [None]:
output_names = ["notifications", "icu_occupancy"]
scenario_x_min, scenario_x_max = 650, 920

sc_to_plot = [0, 2]
legend = ["With vaccine", "Without vaccine"]
lift_time = 731  # 31 Dec 2021
text_font = 14

# plot_outputs looks colours and linestyles up by scenario index (sc_colors[scenario]), so key them
# by scenario rather than by position in scenario_list; this stays correct when the scenario
# indices are not exactly 0..n-1. Later cells reuse these two mappings.
sc_colors = {scenario: COLOR_THEME[scenario] for scenario in scenario_list}
sc_linestyles = {scenario: "solid" for scenario in scenario_list}

for output_type in ["median", "MLE"]:
    for output_name in output_names:
        axis = plot_outputs(output_type, output_name, sc_to_plot, sc_linestyles, sc_colors, False, x_min=scenario_x_min, x_max=scenario_x_max)
        axis.legend(labels=legend, fontsize=text_font, facecolor="white")

        # Mark the date restrictions are lifted and label the period before it.
        ymax = axis.get_ylim()[1]
        axis.vlines(x=lift_time, ymin=0, ymax=1.05 * ymax, linestyle="dashed")
        axis.text(x=(scenario_x_min + lift_time) / 2., y=ymax, s="Vaccination phase", ha="center", fontsize=text_font)
        axis.text(x=lift_time + 3, y=ymax, s="Restrictions lifted", fontsize=text_font, rotation=90, va="top")

        path = os.path.join(base_dir, output_type, f"{output_name}.png")
        plt.savefig(path)
        

# Make Adverse Effects figures

In [None]:
# Baseline adverse-effect rates per vaccine, read from the project's "vaccination_risk" parameters
params = project.param_set.baseline.to_dict()
ae_risk = {
    vaccine: params["vaccination_risk"][rate_key]
    for vaccine, rate_key in (("AstraZeneca", "tts_rate"), ("mRNA", "myocarditis_rate"))
}

In [None]:
# Aggregated age groups, drawn top to bottom in the benefit/risk figures
agg_agegroups = ["15_19", "20_29", "30_39", "40_49", "50_59", "60_69", "70_plus"]
text_font = 12

# Scenario index holding each vaccine's results.
# NOTE(review): both vaccines point at scenario 2 — confirm this is intended.
vacc_scenarios = dict(mRNA=2, AstraZeneca=2)

# Adverse effect per vaccine: full name (used in figure titles) and short name (used in output keys)
adverse_effects = dict(mRNA="myocarditis", AstraZeneca="thrombosis with thrombocytopenia syndrome")
adverse_effects_short = dict(mRNA="myocarditis", AstraZeneca="tts")

left_title = "COVID-19-associated hospitalisations prevented"

def format_age_label(age_bracket):
    """Convert an age-bracket key to a display label.

    Keys starting with "70" become "70+"; otherwise underscores become hyphens ("20_29" -> "20-29").
    """
    return "70+" if age_bracket.startswith("70") else age_bracket.replace("_", "-")
    


def safe_log10(value, floor=None):
    """Return ``log10(value)``, or 0.0 when ``value`` is not positive.

    If ``floor`` is given, the result is clamped from below at ``floor``. This stops zero or
    negative quantiles (e.g. an age group with no adverse-event cases) from making
    ``math.log10`` raise ``ValueError``.
    """
    if value <= 0:
        return 0.0
    log_value = math.log10(value)
    return log_value if floor is None else max(log_value, floor)


def get_quantile_values(df, quantiles, sign=1.0):
    """Return ``sign * value`` for each quantile in ``quantiles``, in the order given.

    ``df`` must have "quantile" and "value" columns with exactly one row per requested quantile.
    ``Series.item()`` raises ``ValueError`` otherwise, rather than silently picking a row.
    """
    return [sign * float(df.loc[df["quantile"] == q, "value"].item()) for q in quantiles]


def make_ae_figure(vacc_scenario, log_scale=False, time_point=913):
    """Draw and save a benefit/risk bar chart for one vaccine.

    For each age group in the module-level ``agg_agegroups``:
    - the left bar shows the median hospitalisations prevented;
    - the right bar shows the median adverse-effect cases;
    - each bar has its 2.5%-97.5% quantile interval drawn as a horizontal line.

    Values come from the module-level ``uncertainty_df`` at scenario
    ``vacc_scenarios[vacc_scenario]`` and time ``time_point``. The figure is saved to
    ``base_dir/<vacc_scenario>_adverse_effects[_log_scale].png`` at 600 dpi.

    Args:
        vacc_scenario: Vaccine key into ``vacc_scenarios``, ``adverse_effects`` and
            ``adverse_effects_short`` (e.g. "mRNA", "AstraZeneca").
        log_scale: If True, bar lengths are log10 of the values (non-positive values drawn as 0).
        time_point: Model time at which the cumulative differences are read. Defaults to 913,
            the value previously hard-coded (presumably 1 July 2022 if times are days since
            REF_DATE — confirm).
    """
    trimmed_df = uncertainty_df[
        (uncertainty_df["scenario"] == vacc_scenarios[vacc_scenario]) & (uncertainty_df["time"] == time_point)
    ]

    right_title = f"Cases of {adverse_effects[vacc_scenario]}"

    # Apply the style before creating the figure so figure-level settings use it too.
    plt.style.use("default")
    fig = plt.figure(figsize=(10, 4))
    axis = fig.add_subplot()

    h_max = 0  # largest upper-bound hospital value drawn; sets the symmetric x range
    # Gap between the two halves, where the age labels sit (in plotted x units).
    delta_agegroup = 1 if log_scale else 1000
    barwidth = .7
    text_offset = 0.1 if log_scale else 20
    unc_color = "black"
    unc_lw = 1.
    origin = 0
    ae_origin = delta_agegroup + origin  # left edge of the adverse-effect bars

    for i, age_bracket in enumerate(agg_agegroups):
        y = len(agg_agegroups) - i - .5  # first age group at the top
        plt.text(x=delta_agegroup / 2, y=y, s=format_age_label(age_bracket), ha="center", va="center", fontsize=text_font)

        # get outputs
        hosp_output_name = f"abs_diff_cumulative_hospital_admissionsXagg_age_{age_bracket}"
        ae_output_name = f"abs_diff_cumulative_{adverse_effects_short[vacc_scenario]}_casesXagg_age_{age_bracket}"

        # [median, lower, upper]
        prev_hosp_values = get_quantile_values(
            trimmed_df[trimmed_df["type"] == hosp_output_name], [0.5, 0.025, 0.975]
        )
        # Adverse-effect differences are negated. The 0.975 quantile is read before 0.025 so that,
        # after negation, the list is still [median, lower, upper].
        ae_values = get_quantile_values(
            trimmed_df[trimmed_df["type"] == ae_output_name], [0.5, 0.975, 0.025], sign=-1.0
        )

        # Log values are only computed when needed; zero or negative values no longer crash.
        if log_scale:
            plot_h_values = [safe_log10(v) for v in prev_hosp_values]
            plot_ae_values = [safe_log10(v, floor=0) for v in ae_values]
        else:
            plot_h_values = prev_hosp_values
            plot_ae_values = ae_values

        h_max = max(plot_h_values[2], h_max)

        # hospital bar: grows leftwards from the origin
        rect = mpatches.Rectangle((origin, y - barwidth / 2), width=-plot_h_values[0], height=barwidth, facecolor="cornflowerblue")
        axis.add_patch(rect)
        plt.hlines(y=y, xmin=-plot_h_values[1], xmax=-plot_h_values[2], color=unc_color, linewidth=unc_lw)
        plt.text(x=-plot_h_values[0] - text_offset, y=y + barwidth / 2, s=str(int(prev_hosp_values[0])), ha="right", va="center", fontsize=text_font * .7)

        # adverse-effect bar: grows rightwards from ae_origin; a negative median is drawn with zero length
        ae_bar_width = max(plot_ae_values[0], 0)
        rect = mpatches.Rectangle((ae_origin, y - barwidth / 2), width=ae_bar_width, height=barwidth, facecolor="tab:red")
        axis.add_patch(rect)
        plt.hlines(y=y, xmin=ae_origin + plot_ae_values[1], xmax=ae_origin + plot_ae_values[2], color=unc_color, linewidth=unc_lw)
        plt.text(x=ae_origin + ae_bar_width + text_offset, y=y + barwidth / 2, s=str(int(ae_values[0])), ha="left", va="center", fontsize=text_font * .7)

    # main title
    axis.set_title(f"Benefit/Risk analysis with {vacc_scenario} vaccine", fontsize=text_font + 2)

    # Symmetric x range: round the largest upper bound up (to a whole power of ten on log scale,
    # to a multiple of 500 otherwise).
    if log_scale:
        max_val_display = math.ceil(h_max)
    else:
        magnitude = 500
        max_val_display = math.ceil(h_max / magnitude) * magnitude

    # sub-titles
    plt.text(x=-max_val_display / 2, y=len(agg_agegroups) + .3, s=left_title, ha="center", fontsize=text_font)
    plt.text(x=max_val_display / 2 + delta_agegroup, y=len(agg_agegroups) + .3, s=right_title, ha="center", fontsize=text_font)

    if log_scale:
        ticks = range(max_val_display + 1)
        x_ticks = [-t for t in reversed(ticks)] + [delta_agegroup + t for t in ticks]
        labels = [10 ** p for p in ticks]
        x_labels = labels[::-1] + labels
        # The innermost tick on each side sits at a bar origin; label it 0 rather than 10**0.
        x_labels[max_val_display] = x_labels[max_val_display + 1] = 0
    else:
        n_ticks = 6
        step = max_val_display / (n_ticks - 1)
        left_ticks = [-max_val_display + j * step for j in range(n_ticks)]
        x_ticks = left_ticks + [delta_agegroup + j * step for j in range(n_ticks)]
        # Labels are magnitudes: counting down to 0 at the left origin, then up from 0 on the right.
        x_labels = [int(-v) for v in left_ticks] + [int(-v) for v in reversed(left_ticks)]

    plt.xticks(ticks=x_ticks, labels=x_labels)

    # x, y lims
    axis.set_xlim((-max_val_display, max_val_display + delta_agegroup))
    axis.set_ylim((0, len(agg_agegroups) + 1))

    # remove axes
    axis.set_frame_on(False)
    axis.axes.get_yaxis().set_visible(False)

    log_ext = "_log_scale" if log_scale else ""
    path = os.path.join(base_dir, f"{vacc_scenario}_adverse_effects{log_ext}.png")
    plt.tight_layout()
    plt.savefig(path, dpi=600)

# One linear-scale and one log-scale benefit/risk figure per vaccine
for vacc_scenario in ("mRNA", "AstraZeneca"):
    for log_scale in (False, True):
        make_ae_figure(vacc_scenario, log_scale=log_scale)
    

# Counterfactual "no-vaccine" scenario

In [None]:
# Baseline (scenario 0) vs counterfactual (scenario 1) uncertainty plots.
# Reuses sc_linestyles / sc_colors from the scenario-plots cell, so that cell must run first.
output_type = "uncertainty"
output_names = ["notifications", "icu_occupancy", "accum_deaths"]
sc_to_plot = [0, 1]
x_min, x_max = 400, 670
vacc_start = 426  # model time at which vaccination starts (presumably 1 March 2021 — confirm)
for output_name in output_names:
    # Use the window variables rather than repeating the literals, so editing x_min/x_max takes effect.
    axis = plot_outputs(output_type, output_name, sc_to_plot, sc_linestyles, sc_colors, False, x_min=x_min, x_max=x_max)
    y_max = axis.get_ylim()[1]
    axis.vlines(x=vacc_start, ymin=0, ymax=y_max, linestyle="dashdot")
    axis.text(x=vacc_start - 5, y=.6 * y_max, s="Vaccination starts", rotation=90, fontsize=12)

    path = os.path.join(base_dir, f"{output_name}_counterfactual.png")
    plt.tight_layout()
    plt.savefig(path, dpi=600)


### Approximate number of lives saved by vaccination by 21 October 2021:

In [None]:
# Median cumulative deaths at `today` in scenario 0 (baseline) and scenario 1 (presumably the
# no-vaccine counterfactual). Their difference approximates deaths averted by vaccination.
# Note this is the difference of the two medians, not the median of the difference.
today = 660  # 21 Oct
df = uncertainty_df[(uncertainty_df["type"] == "accum_deaths") & (uncertainty_df["quantile"] == 0.5) & (uncertainty_df["time"] == today)]

# .item() replaces float(Series), which is deprecated in pandas. It still raises if the row is
# missing or duplicated.
baseline = float(df.loc[df["scenario"] == 0, "value"].item())
counterfact = float(df.loc[df["scenario"] == 1, "value"].item())

print(counterfact - baseline)