In [None]:
from matplotlib import pyplot as plt
import os
import numpy as np
import pandas as pd
import datetime
from matplotlib.patches import Rectangle
import matplotlib.gridspec as gridspec
from copy import copy


from summer.utils import ref_times_to_dti
from autumn.core.runs.managed import ManagedRun
from autumn.models.sm_covid import base_params
from autumn.settings.constants import COVID_BASE_DATETIME
from autumn.core import inputs

from autumn.projects.sm_covid.common_school.project_maker import get_school_project_timeseries
from notebooks.user.rragonnet.project_specific.School_Closure.plotting_constants import (
    SCHOOL_PROJECT_NOTEBOOK_PATH, 
    FIGURE_WIDTH,
    RESOLUTION,
    INCLUDED_COUNTRIES,
    set_up_style
)

# Directory (under the project's notebook path) where exported figures are meant to go.
output_fig_path = os.path.join(SCHOOL_PROJECT_NOTEBOOK_PATH, "output_figs")
# Apply the project's shared matplotlib style before any figure is created.
set_up_style()

In [None]:
# Calibration run to summarise, and the country it was run for.
run_id = "sm_covid/france/1660800488/e54d5c7"
COUNTRY_NAME = "France"
ISO3 = "FRA"

# Open the run's PowerBI-format database and pull calibration targets and output quantiles.
mr = ManagedRun(run_id)
pbi = mr.powerbi.get_db()
results = pbi.get_uncertainty()
targets = pbi.get_targets()

# Modelled period = span of the derived-outputs time index.
model_dates = pbi.get_derived_outputs().index
model_start = min(model_dates)
model_end = max(model_dates)

In [None]:
# Human-readable panel titles keyed by model output name. The plotting functions below
# fall back to the raw output name for any key not listed here.
# Each key appears once: a Python dict literal silently keeps only the last duplicate,
# so earlier duplicate entries were dead text.
title_lookup = {
    "infection_deaths": "COVID-19-specific deaths",
    "cumulative_infection_deaths": "Cumulative COVID-19-specific deaths",
    "cumulative_incidence": "Cumulative COVID-19 disease incidence",

    "incidence": "daily new infections",
    "hospital_admissions": "daily hospital admissions",
    "hospital_occupancy": "total hospital beds",
    "icu_admissions": "daily ICU admissions",
    "icu_occupancy": "total ICU beds",
    "prop_ever_infected": "ever infected with Delta or Omicron",
}
# Line colours for scenario 0 (baseline) and scenario 1 (schools open).
sc_colours = ["black", "crimson"]
# RGB colours for uncertainty bands; plot_model_fit uses only the first entry.
unc_sc_colours = ((0.2, 0.2, 0.8), (0.8, 0.2, 0.2), (0.2, 0.8, 0.2), (0.8, 0.8, 0.2), (0.8, 0.2, 0.2), (0.2, 0.8, 0.2), (0.8, 0.8, 0.2))


In [None]:
# UNESCO school-closure records (date, status, country_id) from the shared inputs database.
input_db = inputs.database.get_input_db()
unesco_data = input_db.query(table_name='school_closure', columns=["date", "status", "country_id"])

In [None]:
def add_school_closure_patches(ax, iso3, ymax):
    """Shade the dates on which a country's schools were partially open or closed.

    Reads the module-level ``unesco_data`` table and draws one vertical line per
    matching date, from y=0 up to ``ymax``, at zorder 1 (behind default-zorder lines).
    Dates with status "Partially open" are drawn in light yellow, and those with
    "Closed due to COVID-19" in lavender blush.

    Args:
        ax: matplotlib Axes to draw on.
        iso3: country code compared against the table's ``country_id`` column.
        ymax: top of the shaded lines, in data coordinates.
    """
    country_rows = unesco_data[unesco_data['country_id'] == iso3]
    # Partial closures are drawn first, full closures second (same zorder, so add order decides).
    status_shades = (
        ("Partially open", 'lightyellow'),
        ("Closed due to COVID-19", 'lavenderblush'),
    )
    for status, shade in status_shades:
        status_dates = country_rows.loc[country_rows['status'] == status, 'date'].to_list()
        ax.vlines(status_dates, ymin=0, ymax=ymax, lw=1, alpha=1, color=shade, zorder=1)


# Model calibration

In [None]:
# Calibration targets for this country, re-indexed to calendar dates and restricted to
# the modelled period [model_start, model_end].
timeseries = get_school_project_timeseries(COUNTRY_NAME)
all_targets = {}
for key, spec in timeseries.items():
    target = pd.Series(data=spec['values'], index=spec['times'], name=spec['output_key'])
    target.index = ref_times_to_dti(COVID_BASE_DATETIME, target.index)
    # Build ONE mask from one index. Chaining two boolean masks that were both computed on
    # the unfiltered index raises "Boolean index has wrong length" as soon as the first
    # mask drops any point.
    in_window = (target.index >= model_start) & (target.index <= model_end)
    all_targets[key] = target[in_window]


def plot_model_fit(axis, results, output_name):
    """Plot scenario 0's quantiles for one output, overlaid with its calibration targets.

    Draws a light band between the 2.5th and 97.5th percentiles, a darker band between
    the 25th and 75th percentiles, and the median line. If ``all_targets`` holds a
    non-empty series for ``output_name``, it is drawn as black markers on top.

    Args:
        axis: matplotlib Axes to draw on.
        results: quantile tables keyed by ``(output_name, scenario)``; each table is
            indexed by time with quantile columns (0.025, 0.25, 0.5, 0.75, 0.975 are used).
        output_name: model output to plot. The title comes from ``title_lookup`` when present.
    """
    plt.rcParams.update({'font.size': 12})

    band_colour = unc_sc_colours[0]
    baseline_quantiles = results[(output_name, 0)]
    times = baseline_quantiles.index

    # Outer band, then inner band, then the median on top.
    axis.fill_between(
        times,
        baseline_quantiles[0.025],
        baseline_quantiles[0.975],
        color=band_colour,
        alpha=0.5,
        label="_nolegend_",
    )
    axis.fill_between(
        times,
        baseline_quantiles[0.25],
        baseline_quantiles[0.75],
        color=band_colour,
        alpha=0.6,
    )
    axis.plot(times, baseline_quantiles[0.500], color=band_colour, zorder=10)

    observed = all_targets.get(output_name)
    if observed is not None and len(observed) > 0:
        # Markers only (zero line width), drawn above the median line.
        observed.plot.line(
            ax=axis,
            linewidth=0.,
            markersize=2.,
            marker="o",
            markerfacecolor="black",
            markeredgecolor="black",
            alpha=1,
            label="_nolegend_",
            zorder=11,
        )

    axis.tick_params(axis="x", labelrotation=45)
    axis.set_title(title_lookup.get(output_name, output_name))
    axis.set_xlim((model_start, model_end))
    plt.tight_layout()

In [None]:
# One calibration figure each for daily and cumulative COVID-19 deaths.
for output_name in ("infection_deaths", "cumulative_infection_deaths"):
    fig, axis = plt.subplots(1, 1, figsize=(FIGURE_WIDTH, FIGURE_WIDTH * .7))
    plot_model_fit(axis, results, output_name)

# Scenario comparison over time

In [None]:

def plot_two_scenarios(axis, results, output_name, include_unc=False):
    """Overlay the median projections of scenario 0 ("baseline") and scenario 1 ("schools open").

    School-closure periods for ``ISO3`` are shaded in the background.

    Args:
        axis: matplotlib Axes to draw on.
        results: quantile tables keyed by ``(output_name, scenario)``; each table is
            indexed by time with quantile columns (0.25, 0.5, 0.75 are used here).
        output_name: model output to plot. The title comes from ``title_lookup`` when present.
        include_unc: if True, also shade each scenario's 25th-75th percentile band.
    """
    plt.rcParams.update({'font.size': 12})
    scenarios = [0, 1]

    # First pass: y-range. When the IQR band is drawn, its upper edge must fit too,
    # otherwise set_ylim below would clip it.
    ymax = 0.
    for scenario in scenarios:
        results_df = results[(output_name, scenario)]
        ymax = max(ymax, results_df[0.500].max())
        if include_unc:
            ymax = max(ymax, results_df[0.75].max())
    plot_ymax = ymax * 1.1

    # Shade closures BEFORE drawing the scenarios. Artists sharing a zorder are painted in
    # the order they are added, so the schools-open band (zorder 1, same as the shading)
    # would otherwise be hidden under the opaque closure lines.
    add_school_closure_patches(axis, ISO3, ymax=plot_ymax)

    for scenario in scenarios:
        colour = sc_colours[scenario]
        results_df = results[(output_name, scenario)]
        indices = results_df.index
        label = "baseline" if scenario == 0 else "schools open"
        # Baseline band is painted above the schools-open band.
        scenario_zorder = 10 if scenario == 0 else scenario

        if include_unc:
            axis.fill_between(
                indices,
                results_df[0.25], results_df[0.75],
                color=colour, alpha=0.7,
                zorder=scenario_zorder,
            )

        axis.plot(indices, results_df[0.500], color=colour, label=label, lw=2.)

    axis.tick_params(axis="x", labelrotation=45)
    axis.set_title(title_lookup.get(output_name, output_name))
    axis.set_xlim((model_start, model_end))
    axis.set_ylim((0, plot_ymax))
    axis.legend()
    plt.tight_layout()


In [None]:
# Apply the ggplot style to these two figures only. Calling plt.style.use() inside the
# loop styled the first figure's axes inconsistently (they already existed) and changed
# the global style for every cell run afterwards, which is hidden state.
with plt.style.context("ggplot"):
    for output_name in ["infection_deaths", "cumulative_infection_deaths"]:
        fig, axis = plt.subplots(1, 1, figsize=(FIGURE_WIDTH, FIGURE_WIDTH * .7))
        plot_two_scenarios(axis, results, output_name, include_unc=False)

# Scenario comparison final size

In [None]:
def plot_final_size_compare(axis, results, output_name):
    """Box-style comparison of an output's last-time-point value across the two scenarios.

    For scenario 0 ("baseline", x=1) and scenario 1 ("schools open", x=2), the last row
    of the quantile table is drawn as:
    - a box spanning the 25th-75th percentiles,
    - a horizontal line at the median,
    - a vertical whisker from the 2.5th to the 97.5th percentile.

    Args:
        axis: matplotlib Axes to draw on.
        results: quantile tables keyed by ``(output_name, scenario)``; only the last
            row (``.iloc[-1]``) of each table is used.
        output_name: model output to compare. The title comes from ``title_lookup`` when present.
    """
    plt.rcParams.update({'font.size': 12})
    box_width = .7
    color = 'black'
    box_color = 'coral'
    scenario_labels = ["baseline", "schools open"]
    tick_positions = [1 + i for i in range(len(scenario_labels))]

    y_max = 0
    for i, x in enumerate(tick_positions):
        quantiles = results[(output_name, i)].iloc[-1]

        # median
        axis.hlines(y=quantiles[0.5], xmin=x - box_width / 2., xmax=x + box_width / 2., lw=2., color=color, zorder=3)

        # IQR box
        height = quantiles[0.75] - quantiles[0.25]
        rect = Rectangle(xy=(x - box_width / 2., quantiles[0.25]), width=box_width, height=height, zorder=2, facecolor=box_color)
        axis.add_patch(rect)

        # 2.5th-97.5th percentile whisker
        axis.vlines(x=x, ymin=quantiles[0.025], ymax=quantiles[0.975], lw=1.3, color=color, zorder=1)

        y_max = max(y_max, quantiles[0.975])

    axis.set_title(title_lookup.get(output_name, output_name))
    # Set ticks on `axis` itself. plt.xticks() targets plt.gca(), which is a different Axes
    # when `axis` has not been added to the figure yet (as in the multi-panel figure), and
    # would overwrite that other panel's date ticks.
    axis.set_xticks(tick_positions)
    axis.set_xticklabels(scenario_labels, fontsize=15)

    axis.set_xlim((0., 3.))
    axis.set_ylim((0, y_max * 1.2))
    plt.tight_layout()



In [None]:
# Final-size comparison for cumulative deaths and cumulative incidence, one narrow figure each.
for output_name in ("cumulative_infection_deaths", "cumulative_incidence"):
    fig, axis = plt.subplots(1, 1, figsize=(FIGURE_WIDTH * .6, FIGURE_WIDTH * .7))
    plot_final_size_compare(axis, results, output_name)

# Age-specific incidence

In [None]:
# Fill colours for the age-group layers (cycled if there are more groups than colours).
colours = ["cornflowerblue", "darkorange", "mediumseagreen", "pink", "purple"]
def plot_incidence_by_age(derived_outputs, ax, scenario, as_proportion: bool):
    """Stacked-area plot of COVID-19 incidence split by age group for one scenario.

    Age groups are ``base_params['age_groups']`` (lower bounds). For each group, the series
    ``derived_outputs["incidenceXagegroup_<lower>", scenario]`` is stacked on top of the
    previous ones. School-closure periods for ``ISO3`` are shaded in the background.

    Args:
        derived_outputs: derived-output table indexed by ``(output_name, scenario)``;
            ``"incidence"`` and the per-age-group incidence series are read.
        ax: matplotlib Axes to draw on.
        scenario: scenario index to plot.
        as_proportion: if True, stack each group's share of total incidence. Time points
            where total incidence is 0 contribute 0. Otherwise stack absolute incidence
            and draw a legend (top layer listed first).
    """
    y_label = "COVID-19 incidence proportion" if as_proportion else "COVID-19 incidence"

    total_incidence = derived_outputs["incidence", scenario]
    times = total_incidence.index.to_list()
    total_values = total_incidence.to_numpy(dtype=float)
    running_total = np.zeros(len(total_values))
    age_groups = base_params['age_groups']
    for i_age, age_group in enumerate(age_groups):
        output_name = f"incidenceXagegroup_{age_group}"

        if i_age < len(age_groups) - 1:
            # Inclusive range ending one year before the next group's lower bound.
            age_group_name = f"{age_group}-{age_groups[i_age + 1] - 1}"
        else:
            age_group_name = f"{age_group}+"

        age_group_incidence = derived_outputs[output_name, scenario].to_numpy(dtype=float)
        if as_proportion:
            # Safe division: the share is left at 0 wherever total incidence is 0.
            contribution = np.divide(
                age_group_incidence, total_values,
                out=np.zeros_like(age_group_incidence),
                where=total_values != 0,
            )
        else:
            contribution = age_group_incidence
        new_running_total = running_total + contribution

        # Cycle colours so more age groups than colours does not raise IndexError.
        layer_colour = colours[i_age % len(colours)]
        ax.fill_between(times, running_total, new_running_total, color=layer_colour, label=age_group_name, zorder=2, alpha=.8)
        running_total = new_running_total

    y_max = running_total.max() if running_total.size else 0.
    # Keep a non-degenerate y-range even if nothing was stacked or incidence is all zero.
    plot_ymax = y_max * 1.1 if y_max > 0 else 1.
    add_school_closure_patches(ax, ISO3, ymax=plot_ymax)

    ax.set_xlim((model_start, model_end))
    ax.set_ylim((0, plot_ymax))

    ax.set_ylabel(y_label)

    if not as_proportion:
        # Reverse so the legend order matches the visual stacking (oldest group on top).
        handles, labels = ax.get_legend_handles_labels()
        ax.legend(
            handles[::-1],
            labels[::-1],
            title="Age:",
            fontsize=12,
            title_fontsize=12,
            labelspacing=.5,
            facecolor="white"
        )

In [None]:
# Full derived-output table, indexed by (output name, scenario), for the age breakdown.
derived_outputs = pbi.get_derived_outputs()

# Scenario 0: absolute incidence stacked by age group.
fig, axis = plt.subplots(1, 1, figsize=(FIGURE_WIDTH, FIGURE_WIDTH * .5))
plot_incidence_by_age(derived_outputs, ax=axis, scenario=0, as_proportion=False)

In [None]:
# Scenario 1: each age group's share of total incidence.
fig, axis = plt.subplots(1, 1, figsize=(FIGURE_WIDTH, FIGURE_WIDTH * .5))
plot_incidence_by_age(derived_outputs, ax=axis, scenario=1, as_proportion=True)

# Make combined multi-panel figure

In [None]:


# A4 summary figure for one country:
#   row 0: country-name banner
#   row 1: left  = model fit to deaths, scenario comparison of daily and cumulative deaths
#          right = final-size comparison of cumulative deaths and cumulative incidence,
#                  plus a placeholder panel
#   row 2: 2x2 age-specific incidence (columns: scenario 0 / 1; rows: absolute / share)


def _add_panel(figure, subplot_spec, draw):
    """Create an Axes at `subplot_spec`, populate it with `draw(ax)`, then add it to `figure`.

    The Axes is populated *before* it is attached. This keeps the original build order of
    this figure (several plotting helpers call plt.tight_layout() while drawing).
    """
    panel_ax = plt.Subplot(figure, subplot_spec)
    draw(panel_ax)
    figure.add_subplot(panel_ax)
    return panel_ax


def _draw_centred_text(ax, text, centre_vertically=True):
    """Write `text` at (0.5, 0.5) of an otherwise empty panel and remove its ticks."""
    label = ax.text(0.5, 0.5, text)
    label.set_ha('center')
    if centre_vertically:
        label.set_va('center')
    ax.set_xticks([])
    ax.set_yticks([])


fig = plt.figure(figsize=(8.3, 11.7), dpi=RESOLUTION)  # create an A4 figure (size in inches)
outer = gridspec.GridSpec(3, 1, wspace=0., hspace=0., height_ratios=(5, 60, 35))

#### Top row: country name
_add_panel(fig, outer[0, 0], lambda ax: _draw_centred_text(ax, COUNTRY_NAME))

#### Second row: left column (time series) and right column (final size)
inner_grid = gridspec.GridSpecFromSubplotSpec(1, 2, subplot_spec=outer[1, 0], wspace=0., hspace=0., width_ratios=(70, 30))
left_grid = inner_grid[0, 0]
right_grid = inner_grid[0, 1]

# Left column: three stacked time-series panels
inner_left_grid = gridspec.GridSpecFromSubplotSpec(3, 1, subplot_spec=left_grid, wspace=0., hspace=0., height_ratios=(1, 1, 1))
# model fit to daily deaths
_add_panel(fig, inner_left_grid[0, 0], lambda ax: plot_model_fit(ax, results, "infection_deaths"))
# scenario comparison: daily deaths
_add_panel(fig, inner_left_grid[1, 0], lambda ax: plot_two_scenarios(ax, results, "infection_deaths", include_unc=False))
# scenario comparison: cumulative deaths
_add_panel(fig, inner_left_grid[2, 0], lambda ax: plot_two_scenarios(ax, results, "cumulative_infection_deaths", include_unc=False))

# Right column: three stacked panels
inner_right_grid = gridspec.GridSpecFromSubplotSpec(3, 1, subplot_spec=right_grid, wspace=0., hspace=0., height_ratios=(1, 1, 1))
# final size: cumulative deaths
_add_panel(fig, inner_right_grid[0, 0], lambda ax: plot_final_size_compare(ax, results, "cumulative_infection_deaths"))
# final size: cumulative incidence
_add_panel(fig, inner_right_grid[1, 0], lambda ax: plot_final_size_compare(ax, results, "cumulative_incidence"))
# placeholder panel (peak hospital occupancy not implemented yet)
_add_panel(fig, inner_right_grid[2, 0], lambda ax: _draw_centred_text(ax, 'peak hosp coming soon', centre_vertically=False))

#### Third row: 2x2 age-specific incidence
inner_grid = gridspec.GridSpecFromSubplotSpec(2, 2, subplot_spec=outer[2, 0], wspace=0., hspace=0., width_ratios=(50, 50), height_ratios=(50, 50))
# Row-major order: top-left, top-right, bottom-left, bottom-right.
# Each lambda is called immediately inside _add_panel, so capturing the loop variables is safe.
for row, as_proportion in enumerate((False, True)):
    for col, scenario in enumerate((0, 1)):
        _add_panel(
            fig,
            inner_grid[row, col],
            lambda ax: plot_incidence_by_age(derived_outputs, ax, scenario, as_proportion=as_proportion),
        )

plt.show()