In [None]:
from matplotlib import pyplot as plt
import os
import numpy as np
import pandas as pd
import datetime
from matplotlib.patches import Rectangle
import matplotlib.gridspec as gridspec
from copy import copy


from summer.utils import ref_times_to_dti
from autumn.core.runs.managed import ManagedRun
from autumn.models.sm_covid import base_params
from autumn.settings.constants import COVID_BASE_DATETIME
from autumn.core import inputs

from autumn.projects.sm_covid.common_school.project_maker import get_school_project_timeseries
from notebooks.user.rragonnet.project_specific.School_Closure.plotting_constants import (
    SCHOOL_PROJECT_NOTEBOOK_PATH, 
    FIGURE_WIDTH,
    RESOLUTION,
    INCLUDED_COUNTRIES,
    set_up_style
)

# Apply the plotting style helper imported from plotting_constants
# (presumably the shared style of the school-closure figures — confirm in that module).
set_up_style()
# Folder "output_figs" located under the school-project notebook directory.
output_fig_path = os.path.join(SCHOOL_PROJECT_NOTEBOOK_PATH, "output_figs")
# Reminder of the relative font-size names accepted by matplotlib rcParams:
# xx-small, x-small, small, medium, large, x-large, xx-large, larger, or smaller

In [None]:
def update_rcparams():
    """Override the global matplotlib rcParams with the compact settings used by the panels.

    Sets a small base font (6 pt) with relative sizes for titles, labels, ticks
    and legends, a 1 pt default line width, short tick marks and tight label
    padding. Mutates ``plt.rcParams`` in place and returns None.
    """
    overrides = {
        'font.size': 6,
        'axes.titlesize': "large",
        'axes.labelsize': "x-large",
        'legend.fontsize': 'large',
        'legend.title_fontsize': 'large',
        'lines.linewidth': 1.,
        'axes.labelpad': 2.,
    }

    # x and y ticks share exactly the same label size and geometry.
    for tick_axis in ('xtick', 'ytick'):
        overrides[f'{tick_axis}.labelsize'] = 'large'
        overrides[f'{tick_axis}.major.size'] = 2.5
        overrides[f'{tick_axis}.major.width'] = 0.8
        overrides[f'{tick_axis}.major.pad'] = 2

    plt.rcParams.update(overrides)

In [None]:
# Identifier of the stored model run whose outputs are plotted in this notebook.
run_id = "sm_covid/france/1661230366/1244249"
ISO3 = "FRA"  # matched against the 'country_id' column of the school-closure table
COUNTRY_NAME = "France"  # printed as the title of the combined figure

# Open the managed run and read its PowerBI-style database.
mr = ManagedRun(run_id)
pbi = mr.powerbi.get_db()
targets = pbi.get_targets()
# Uncertainty quantiles; indexed below as results[(output_name, scenario)][quantile].
results = pbi.get_uncertainty()

# First and last time points of the model outputs; used to crop the targets
# and to set the x-limits of the time-series panels.
model_dates = pbi.get_derived_outputs().index
model_start, model_end = min(model_dates), max(model_dates)


In [None]:
# Display labels keyed by model output name; plotting functions fall back to the
# raw output name when a key is missing.
# The original literal listed "hospital_admissions" and "icu_admissions" twice;
# Python silently keeps only the last value for a repeated key, so the earlier
# (dead) entries have been removed. The effective labels are unchanged.
title_lookup = {
    "infection_deaths": "COVID-19 deaths",
    "cumulative_infection_deaths": "Cumulative COVID-19 deaths",
    "cumulative_incidence": "Cumulative COVID-19 incidence",

    "incidence": "daily new infections",
    "hospital_admissions": "daily hospital admissions",
    "hospital_occupancy": "total hospital beds",
    "icu_admissions": "daily ICU admissions",
    "icu_occupancy": "total ICU beds",
    "prop_ever_infected": "ever infected with Delta or Omicron",

    "peak_hospital_occupancy": "Peak COVID-19 hospital occupancy"
}
# Named colours per scenario (index 0, then 1).
sc_colours = ["black", "crimson"]
# RGB colours indexed by scenario number for the uncertainty plots
# (index 0 is blue, index 1 is red).
unc_sc_colours = ((0.2, 0.2, 0.8), (0.8, 0.2, 0.2), (0.2, 0.8, 0.2), (0.8, 0.8, 0.2), (0.8, 0.2, 0.2), (0.2, 0.8, 0.2), (0.8, 0.8, 0.2))


In [None]:
# Read the date, status and country_id columns of the 'school_closure' table
# from the input database.
# NOTE(review): the variable name suggests the table holds UNESCO data — confirm provenance.
input_db = inputs.database.get_input_db()
unesco_data = input_db.query(
    table_name='school_closure', 
    columns=["date", "status", "country_id"],
)

In [None]:
# Background colours used to mark partially-open and fully-closed school days.
SCHOOL_COLORS = {'partial': 'azure', 'full': 'thistle'}


def add_school_closure_patches(ax, iso3, ymax, school_colors=SCHOOL_COLORS):
    """Mark school-closure days on ``ax`` with one vertical line per recorded date.

    Args:
        ax: matplotlib Axes to draw on.
        iso3: country code matched against the 'country_id' column of ``unesco_data``.
        ymax: upper end of the vertical lines (they start at 0).
        school_colors: mapping with keys 'partial' and 'full' giving the line colours.
    """
    country_rows = unesco_data[unesco_data['country_id'] == iso3]

    # (status string in the table, key into school_colors), drawn in this order.
    status_colour_keys = (
        ("Partially open", 'partial'),
        ("Closed due to COVID-19", 'full'),
    )
    for status, colour_key in status_colour_keys:
        status_dates = country_rows.loc[country_rows['status'] == status, 'date'].to_list()
        ax.vlines(
            status_dates,
            ymin=0,
            ymax=ymax,
            lw=1,
            alpha=1.,
            color=school_colors[colour_key],
            zorder=1,
        )



# Model calibration

In [None]:
# Build one observation series per calibration target, indexed by calendar date
# and restricted to the model's time span.
# COUNTRY_NAME (defined with the run settings, value "France") replaces the
# repeated string literal so the country is set in a single place.
timeseries = get_school_project_timeseries(COUNTRY_NAME)
all_targets = {}
for target_name, target_ts in timeseries.items():
    target_series = pd.Series(
        data=target_ts['values'], index=target_ts['times'], name=target_ts['output_key']
    )
    # Replace the raw time index by dates computed relative to COVID_BASE_DATETIME.
    target_series.index = ref_times_to_dti(COVID_BASE_DATETIME, target_series.index)

    # Apply both bounds in ONE boolean mask. The previous chained form
    # `s[lower_mask][upper_mask]` built the second mask from the unfiltered index,
    # so its length no longer matched as soon as the first mask removed any
    # observation, and pandas raised on the mismatched boolean indexer.
    in_model_window = (target_series.index >= model_start) & (target_series.index <= model_end)
    all_targets[target_name] = target_series[in_model_window]


def plot_model_fit(axis, results, output_name):
    """Plot the model quantiles of one output, with observations when available.

    Args:
        axis: matplotlib Axes to draw on.
        results: uncertainty table indexed as ``results[(output_name, scenario)][quantile]``;
            only scenario 0 is read here.
        output_name: model output to plot. Also used to look up observations in
            ``all_targets`` and a display label in ``title_lookup`` (the raw name
            is used as the y-label when no label is registered).

    Draws the median line plus the 25-75% and 2.5-97.5% bands, sets the y-label,
    fixes the x-limits to (model_start, model_end) and adds a legend to ``axis``.
    """
    update_rcparams()

    observations = all_targets.get(output_name)
    if observations is not None and len(observations) > 0:
        axis.scatter(
            observations.index, observations,
            marker=".", color='black', label='observations', zorder=11, s=.5,
        )

    colour = unc_sc_colours[0]

    results_df = results[(output_name, 0)]
    indices = results_df.index

    axis.plot(indices, results_df[0.500], color=colour, zorder=10, label="model (median)")

    axis.fill_between(
        indices,
        results_df[0.25], results_df[0.75],
        color=colour,
        alpha=0.5,
        edgecolor=None,
        label="model (IQR)"
    )
    axis.fill_between(
        indices,
        results_df[0.025], results_df[0.975],
        color=colour,
        alpha=0.3,
        edgecolor=None,
        label="model (95% CI)",
    )

    axis.set_ylabel(title_lookup.get(output_name, output_name))
    axis.set_xlim((model_start, model_end))

    # Attach the legend to the Axes that was passed in. `plt.legend` acts on
    # pyplot's *current* axes, which is only the same object when `axis` happens
    # to be the most recently created or selected one.
    axis.legend(markerscale=2.)

In [None]:
# for output_name in ["abs_diff_cumulative_infection_deaths", "infection_deaths", "cumulative_infection_deaths"]:
#     fig, axis = plt.subplots(1, 1, figsize=(FIGURE_WIDTH, FIGURE_WIDTH *.7))
#     plot_model_fit(axis, results, output_name)

# Scenario comparison over time

In [None]:

def plot_two_scenarios(axis, results, output_name, include_unc=False):
    """Overlay scenario 0 ("baseline") and scenario 1 ("schools open") for one output.

    Args:
        axis: matplotlib Axes to draw on.
        results: uncertainty table indexed as ``results[(output_name, scenario)][quantile]``.
        output_name: model output to plot; its display label comes from
            ``title_lookup`` when registered, otherwise the raw name is used.
        include_unc: when True, also shade the 25-75% band of each scenario.

    The y-axis runs from 0 to 110% of the highest plotted value, school-closure
    days for ``ISO3`` are marked over that full height, and the x-limits are set
    to (model_start, model_end).
    """
    update_rcparams()

    ymax = 0.
    for scenario in [0, 1]:
        colour = unc_sc_colours[scenario]
        results_df = results[(output_name, scenario)]
        indices = results_df.index
        label = "baseline" if scenario == 0 else "schools open"
        # The baseline band is drawn above the "schools open" band.
        scenario_zorder = 10 if scenario == 0 else scenario + 2

        if include_unc:
            axis.fill_between(
                indices,
                results_df[0.25], results_df[0.75],
                color=colour, alpha=0.7,
                zorder=scenario_zorder
            )
            # The shaded band reaches the 75th percentile, so it must count
            # towards the y-limit; otherwise its top can be cut off by set_ylim.
            ymax = max(ymax, results_df[0.75].max())

        axis.plot(indices, results_df[0.500], color=colour, label=label, lw=1.)
        # Series.max() skips missing values, unlike the builtin max().
        ymax = max(ymax, results_df[0.500].max())

    plot_ymax = ymax * 1.1
    add_school_closure_patches(axis, ISO3, ymax=plot_ymax)

    axis.set_ylabel(title_lookup.get(output_name, output_name))
    axis.set_xlim((model_start, model_end))
    axis.set_ylim((0, plot_ymax))
    axis.legend()


In [None]:
# for output_name in ["infection_deaths", "cumulative_infection_deaths", "missed_school_death_ratio"]:
#     fig, axis = plt.subplots(1, 1, figsize=(FIGURE_WIDTH, FIGURE_WIDTH *.7))
#     plt.style.use("ggplot")
#     plot_two_scenarios(axis,results, output_name, True)

# Scenario comparison final size

In [None]:
def plot_final_size_compare(axis, results, output_name):
    """Draw a box-style summary of the final value of one output for both scenarios.

    For scenario 0 ("baseline", x=1) and scenario 1 ("schools open", x=2) the last
    row of ``results[(output_name, scenario)]`` is read and drawn as a horizontal
    median line, a filled 25-75% box and a vertical 2.5-97.5% line.

    Args:
        axis: matplotlib Axes to draw on.
        results: uncertainty table indexed as ``results[(output_name, scenario)][quantile]``.
        output_name: model output to summarise; labelled via ``title_lookup`` when
            registered, otherwise with the raw name.
    """
    update_rcparams()

    scenario_labels = ["baseline", "schools open"]
    box_width = .7
    line_colour = 'black'
    iqr_colour = 'lightcoral'

    highest_upper_bound = 0
    for scenario in range(len(scenario_labels)):
        final_quantiles = results[(output_name, scenario)].iloc[-1]
        centre = 1 + scenario
        box_left = centre - box_width / 2.
        box_right = centre + box_width / 2.

        # Median marker across the full box width.
        axis.hlines(y=final_quantiles[0.5], xmin=box_left, xmax=box_right, lw=1., color=line_colour, zorder=3)

        # Interquartile box.
        iqr_box = Rectangle(
            xy=(box_left, final_quantiles[0.25]),
            width=box_width,
            height=final_quantiles[0.75] - final_quantiles[0.25],
            zorder=2,
            facecolor=iqr_colour,
        )
        axis.add_patch(iqr_box)

        # 95% interval, drawn underneath the box.
        axis.vlines(x=centre, ymin=final_quantiles[0.025], ymax=final_quantiles[0.975], lw=.7, color=line_colour, zorder=1)

        highest_upper_bound = max(highest_upper_bound, final_quantiles[0.975])

    y_label = title_lookup[output_name] if output_name in title_lookup else output_name
    axis.set_ylabel(y_label)
    axis.set_xticks(ticks=[1, 2], labels=scenario_labels)

    axis.set_xlim((0., 3.))
    axis.set_ylim((0, highest_upper_bound * 1.2))



In [None]:
# for output_name in ["cumulative_infection_deaths", "cumulative_incidence", "peak_hospital_occupancy"]:        
#     fig, axis = plt.subplots(1, 1, figsize=(FIGURE_WIDTH * .6 , FIGURE_WIDTH *.7))
#     plot_final_size_compare(axis, results, output_name)

# Age-specific incidence

In [None]:
# colours = ["cornflowerblue", "darkorange", "mediumseagreen", "pink", "purple"]
# One fill colour per age group, in the order of base_params['age_groups'].
colours = ["cornflowerblue", "slateblue", "mediumseagreen", "lightcoral", "purple"]


def plot_incidence_by_age(derived_outputs, ax, scenario, as_proportion: bool):
    """Draw a stacked area chart of incidence by age group for one scenario.

    Args:
        derived_outputs: table of model outputs indexed as
            ``derived_outputs[output_name, scenario]``; must contain "incidence"
            and one "incidenceXagegroup_<age>" output per age group.
        ax: matplotlib Axes to draw on.
        scenario: scenario index whose outputs are plotted.
        as_proportion: when True, each age group's incidence is divided by the
            total incidence (0 where the total is 0); otherwise raw values are stacked.

    School-closure days for ``ISO3`` are marked behind the areas. A legend is
    only added for the absolute-value plot of scenario 0.
    """
    update_rcparams()
    y_label = "COVID-19 incidence proportion" if as_proportion else "COVID-19 incidence"

    total_incidence = derived_outputs["incidence", scenario]
    times = total_incidence.index.to_list()
    running_total = [0] * len(total_incidence)
    age_groups = base_params['age_groups']
    last_age_index = len(age_groups) - 1
    for i_age, age_group in enumerate(age_groups):
        output_name = f"incidenceXagegroup_{age_group}"

        # Label "<lower>-<upper>" for bounded groups and "<lower>+" for the oldest.
        # (The original repeated the `i_age < last` test inside this branch, which
        # left an unreachable `else ""` arm.)
        if i_age < last_age_index:
            age_group_name = f"{age_group}-{age_groups[i_age + 1] - 1}"
        else:
            age_group_name = f"{age_group}+"

        age_group_incidence = derived_outputs[output_name, scenario]

        if as_proportion:
            numerator, denominator = age_group_incidence, total_incidence
            # Division guarded against a zero total: those entries keep the 0 from `out`.
            age_group_proportion = np.divide(numerator, denominator, out=np.zeros_like(numerator), where=denominator!=0)
            new_running_total = age_group_proportion + running_total
        else:
            new_running_total = age_group_incidence + running_total

        ax.fill_between(times, running_total, new_running_total, color=colours[i_age], label=age_group_name, zorder=2, alpha=.8)
        running_total = copy(new_running_total)

    # `running_total` holds the top of the stack after the loop; reading it (rather
    # than the loop-local `new_running_total`) stays defined with no age groups.
    y_max = max(running_total)
    plot_ymax = y_max * 1.1
    add_school_closure_patches(ax, ISO3, ymax=plot_ymax)

    # Start the x-axis at the first time with positive incidence.
    # NOTE(review): this always reads scenario 0, whatever `scenario` is —
    # presumably so that both scenarios share the same start; confirm.
    t_min = derived_outputs['incidence', 0].gt(0).idxmax()
    ax.set_xlim((t_min, model_end))
    ax.set_ylim((0, plot_ymax))

    ax.set_ylabel(y_label)

    if not as_proportion and scenario == 0:
        handles, labels = ax.get_legend_handles_labels()
        # Reverse the entries so the legend order matches the visual stacking (oldest on top).
        ax.legend(
            reversed(handles),
            reversed(labels),
            title="Age:",
            labelspacing=.2,
            handlelength=1.,
            handletextpad=.5,
            columnspacing=1.,
            facecolor="white",
            ncol=2,
        )

In [None]:
# derived_outputs = pbi.get_derived_outputs()

# fig, axis = plt.subplots(1, 1, figsize=(FIGURE_WIDTH, FIGURE_WIDTH *.5))
# plot_incidence_by_age(derived_outputs, axis, 0, as_proportion=False)

In [None]:
# fig, axis = plt.subplots(1, 1, figsize=(FIGURE_WIDTH, FIGURE_WIDTH *.5))
# plot_incidence_by_age(derived_outputs, axis, 1, as_proportion=True)

# Make combined multi-panel figure

In [None]:
# Assemble the combined figure: a title strip, a middle row with three
# time-series panels (left) and three final-size panels (right), and a bottom
# 2x2 grid of age-specific incidence panels.
update_rcparams()

fig = plt.figure(figsize=(8.3, 11.7), dpi=300) # create an A4 figure (8.3 x 11.7 inches)
outer = gridspec.GridSpec(
    3, 1, hspace=.1, height_ratios=(3, 62, 35), 
    left=0.07, right=0.97, bottom=0.03, top =.97   # this affects the outer margins of the saved figure 
)

#### Top row with country name
ax1 = fig.add_subplot(outer[0, 0])
t = ax1.text(0.5,0.5, COUNTRY_NAME, fontsize=16)
t.set_ha('center')
t.set_va('center')
# no ticks on the title strip
ax1.set_xticks([])
ax1.set_yticks([])

#### Second row will need to be split
outer_cell = outer[1, 0]
# first split in left/right panels
inner_grid = gridspec.GridSpecFromSubplotSpec(1, 2, subplot_spec=outer_cell, wspace=.2, width_ratios=(70, 30))
left_grid = inner_grid[0, 0]  # will contain timeseries plots
right_grid = inner_grid[0, 1]  # will contain final size plots

#### Split left panel into 3 panels
inner_left_grid = gridspec.GridSpecFromSubplotSpec(3, 1, subplot_spec=left_grid, hspace=.05, height_ratios=(1, 1, 1))
# calibration
ax2 = fig.add_subplot(inner_left_grid[0, 0])
plot_model_fit(ax2, results, "infection_deaths")
# hide x tick labels: the bottom panel of this column shows the shared dates
plt.setp(ax2.get_xticklabels(), visible=False)
# scenario compare deaths
ax3 = fig.add_subplot(inner_left_grid[1, 0], sharex=ax2)
plot_two_scenarios(ax3, results, "infection_deaths", True)
plt.setp(ax3.get_xticklabels(), visible=False)
# scenario compare cumulative deaths (the hospital-occupancy variant is kept commented out)
ax4 = fig.add_subplot(inner_left_grid[2, 0], sharex=ax2)
# plot_two_scenarios(ax4, results, "hospital_occupancy", False)
plot_two_scenarios(ax4, results, "cumulative_infection_deaths", True)


#### Split right panel into 3 panels
inner_right_grid = gridspec.GridSpecFromSubplotSpec(3, 1, subplot_spec=right_grid, hspace=.1, height_ratios=(1, 1, 1))
# final size deaths
ax5 = fig.add_subplot(inner_right_grid[0, 0])
plot_final_size_compare(ax5, results, "cumulative_infection_deaths")
# final size incidence
ax6 = fig.add_subplot(inner_right_grid[1, 0])
plot_final_size_compare(ax6, results, "cumulative_incidence")
# hosp peak
ax7 = fig.add_subplot(inner_right_grid[2, 0])
plot_final_size_compare(ax7, results, "peak_hospital_occupancy")

#### Third row will need to be split into 4 panels
# columns: scenario 0 (left) and scenario 1 (right); rows: absolute incidence (top) and proportions (bottom)
derived_outputs = pbi.get_derived_outputs()
outer_cell = outer[2, 0]
inner_grid = gridspec.GridSpecFromSubplotSpec(2, 2, subplot_spec=outer_cell, wspace=.2, hspace=.05, width_ratios=(50, 50), height_ratios=(50,50))

# top left
ax8 = fig.add_subplot(inner_grid[0, 0])
plot_incidence_by_age(derived_outputs, ax8, 0, as_proportion=False)
plt.setp(ax8.get_xticklabels(), visible=False)

# top right
ax9 = fig.add_subplot(inner_grid[0, 1])
plot_incidence_by_age(derived_outputs, ax9, 1, as_proportion=False)
plt.setp(ax9.get_xticklabels(), visible=False)

# bottom left
ax10 = fig.add_subplot(inner_grid[1, 0], sharex=ax8)
plot_incidence_by_age(derived_outputs, ax10, 0, as_proportion=True)
# bottom right
ax11 = fig.add_subplot(inner_grid[1, 1], sharex=ax9)
plot_incidence_by_age(derived_outputs, ax11, 1, as_proportion=True)

In [None]:
# Export the combined figure as PNG and PDF with a white background.
# The file names are relative, so the files land in the current working directory.
# NOTE(review): output_fig_path is defined in the first cell but is not used in the
# visible cells — confirm whether the figures were meant to be written there.
fig.savefig("out.png", facecolor="white")
fig.savefig("out.pdf", facecolor="white")