In [None]:
from matplotlib import pyplot as plt
import os
import numpy as np
import pandas as pd
import datetime
from matplotlib.patches import Rectangle

from summer.utils import ref_times_to_dti
from autumn.core.runs.managed import ManagedRun
from autumn.models.sm_covid import base_params
from autumn.settings.constants import COVID_BASE_DATETIME

from autumn.projects.sm_covid.common_school.project_maker import get_school_project_timeseries
from notebooks.user.rragonnet.project_specific.School_Closure.plotting_constants import (
    SCHOOL_PROJECT_NOTEBOOK_PATH, 
    FIGURE_WIDTH,
    RESOLUTION,
    INCLUDED_COUNTRIES,
    set_up_style
)

# Apply the shared plotting style defined in the School_Closure plotting_constants module.
set_up_style()
# Directory for exported figures (not yet used by the cells shown in this section).
output_fig_path = os.path.join(SCHOOL_PROJECT_NOTEBOOK_PATH, "output_figs")

In [None]:
# Managed calibration run for France; all results below are read from its PowerBI database.
run_id = "sm_covid/france/1660709197/8b139a0"
mr = ManagedRun(run_id)
pbi = mr.powerbi.get_db()
# NOTE(review): `targets` is not used by the cells below, which rebuild the
# targets from get_school_project_timeseries instead.
targets = pbi.get_targets()
# Quantile results; the plotting helpers below index this as results[(output_name, scenario)].
results = pbi.get_uncertainty()

# Time span covered by the model outputs; used to clip the targets and set x-axis limits.
model_dates = pbi.get_derived_outputs().index
model_start, model_end = min(model_dates), max(model_dates)

In [None]:
# Human-readable plot titles for model outputs. The plotting helpers fall back to
# the raw output name for any output not listed here. Each key appears exactly once
# (a repeated key in a dict literal silently overrides the earlier entry).
title_lookup = {
    "infection_deaths": "COVID-19-specific deaths",
    "cumulative_infection_deaths": "Cumulative COVID-19-specific deaths",
    "cumulative_incidence": "Cumulative COVID-19 disease incidence",

    "hospital_admissions": "daily hospital admissions",
    "icu_admissions": "daily ICU admissions",
    "incidence": "daily new infections",
    "hospital_occupancy": "total hospital beds",
    "icu_occupancy": "total ICU beds",
    "prop_ever_infected": "ever infected with Delta or Omicron",
}
# Median-line colours: scenario 0 ("baseline") and scenario 1 ("schools open").
sc_colours = ["black", "crimson"]
# Uncertainty-band colours; only the first entry is used by the cells shown here.
unc_sc_colours = ((0.2, 0.2, 0.8), (0.8, 0.2, 0.2), (0.2, 0.8, 0.2), (0.8, 0.8, 0.2), (0.8, 0.2, 0.2), (0.2, 0.8, 0.2), (0.8, 0.8, 0.2))


# Model calibration

In [None]:
# Calibration targets for France. Each series is converted from model reference times
# to dates and clipped to the simulated window [model_start, model_end].
timeseries = get_school_project_timeseries("France")

all_targets = {}
for key, ts in timeseries.items():
    series = pd.Series(data=ts["values"], index=ts["times"], name=ts["output_key"])
    # COVID_BASE_DATETIME (imported in the first cell) is assumed to be the sm_covid
    # model's reference date, i.e. the date corresponding to model time 0 — confirm.
    series.index = ref_times_to_dti(COVID_BASE_DATETIME, series.index)
    # Build one mask over the full index. Chaining two boolean filters would apply the
    # second (full-length) mask to an already-shortened series and raise a length error.
    in_window = (series.index >= model_start) & (series.index <= model_end)
    all_targets[key] = series[in_window]


In [None]:

def plot_model_fit(results, output_name):
    """Plot the baseline (scenario 0) projection of one output against its targets.

    Shades the 2.5%-97.5% and 25%-75% quantile bands, draws the median as a line and,
    when non-empty target data exist in ``all_targets`` for this output, overlays
    them as black markers.

    Args:
        results: Uncertainty results indexed by ``(output_name, scenario)``. Each entry
            is expected to expose quantile columns 0.025, 0.25, 0.5, 0.75 and 0.975
            (assumption based on the columns read below).
        output_name: Name of the model output to plot.
    """
    plt.rcParams.update({'font.size': 12})

    fig, axis = plt.subplots(1, 1, figsize=(FIGURE_WIDTH, FIGURE_WIDTH * .7))
    band_colour = unc_sc_colours[0]

    baseline = results[(output_name, 0)]
    times = baseline.index

    # Wide band first, then the narrower interquartile band on top of it.
    band_specs = (
        ((0.025, 0.975), 0.5, {"label": "_nolegend_"}),
        ((0.25, 0.75), 0.6, {}),
    )
    for (low_q, high_q), opacity, extra_kwargs in band_specs:
        axis.fill_between(
            times,
            baseline[low_q], baseline[high_q],
            color=band_colour,
            alpha=opacity,
            **extra_kwargs,
        )

    # Median on top of both bands.
    axis.plot(times, baseline[0.500], color=band_colour, zorder=10)

    has_targets = output_name in all_targets and len(all_targets[output_name]) > 0
    if has_targets:
        # Markers only (zero line width), drawn above the median line.
        all_targets[output_name].plot.line(
            ax=axis,
            linewidth=0.,
            markersize=2.,
            marker="o",
            markerfacecolor="black",
            markeredgecolor="black",
            alpha=1,
            label="_nolegend_",
            zorder=11,
        )

    axis.tick_params(axis="x", labelrotation=45)
    axis.set_title(title_lookup.get(output_name, output_name))
    axis.set_xlim((model_start, model_end))
    fig.tight_layout()

In [None]:
# Calibration fit for cumulative COVID-19 deaths.
plot_model_fit(results, output_name="cumulative_infection_deaths")

# Scenario comparison over time

In [None]:

def plot_two_scenarios(results, output_name, include_unc=False):
    """Overlay the baseline and schools-open projections of one output over time.

    Scenario 0 is labelled "baseline" and scenario 1 "schools open"; each median is
    drawn as a line in the matching ``sc_colours`` entry.

    Args:
        results: Uncertainty results indexed by ``(output_name, scenario)``. Each entry
            is expected to expose quantile columns 0.25, 0.5 and 0.75 (assumption based
            on the columns read below).
        output_name: Name of the model output to plot.
        include_unc: If True, also shade each scenario's 25%-75% quantile band.
    """
    plt.rcParams.update({'font.size': 12})
    fig, axis = plt.subplots(1, 1, figsize=(FIGURE_WIDTH, FIGURE_WIDTH * .7))

    for scenario in [0, 1]:
        colour = sc_colours[scenario]
        results_df = results[(output_name, scenario)]
        indices = results_df.index
        label = "baseline" if scenario == 0 else "schools open"
        # Baseline band is drawn above the schools-open band.
        scenario_zorder = 10 if scenario == 0 else scenario

        if include_unc:
            # NOTE(review): median lines use matplotlib's default zorder, so the baseline
            # band (zorder 10) is drawn over both medians; confirm this layering is intended.
            axis.fill_between(
                indices,
                results_df[0.25], results_df[0.75],
                color=colour, alpha=0.7,
                zorder=scenario_zorder,
            )

        axis.plot(indices, results_df[0.500], color=colour, label=label, lw=2.)

    # Axis decoration does not depend on the scenario: apply it once, after both
    # scenarios are drawn, so the legend is built a single time with both entries.
    axis.tick_params(axis="x", labelrotation=45)
    axis.set_title(title_lookup.get(output_name, output_name))
    axis.set_xlim((model_start, model_end))
    axis.legend()
    fig.tight_layout()


In [None]:
# Daily COVID-19 deaths under both scenarios, medians only.
plot_two_scenarios(results, "infection_deaths", include_unc=False)

# Scenario comparison: final size

In [None]:
def plot_final_size_compare(results, output_name, scenario_labels=("baseline", "schools open")):
    """Compare the last value of an output across scenarios as box-and-whisker glyphs.

    For each scenario, the last row of its quantile frame is drawn at x = 1, 2, ...:

    - a horizontal bar at the median;
    - a box spanning the 25%-75% quantiles;
    - a vertical whisker spanning the 2.5%-97.5% quantiles.

    Args:
        results: Uncertainty results indexed by ``(output_name, scenario_index)``. Each
            entry is expected to expose quantile columns 0.025, 0.25, 0.5, 0.75 and 0.975
            (assumption based on the columns read below).
        output_name: Name of the model output to compare (e.g. a cumulative output).
        scenario_labels: Tick label for each scenario; label ``i`` is paired with
            ``results[(output_name, i)]``. The default reproduces the two-scenario
            baseline vs schools-open comparison.
    """
    plt.rcParams.update({'font.size': 12})
    fig, axis = plt.subplots(1, 1, figsize=(FIGURE_WIDTH * .6, FIGURE_WIDTH * .7))
    box_width = .7
    half_width = box_width / 2.
    line_colour = 'black'
    box_colour = 'coral'

    positions = []
    y_max = 0
    for i, _label in enumerate(scenario_labels):
        # Last row of the scenario's frame (presumably the final time point).
        quantiles = results[(output_name, i)].iloc[-1]
        x = 1 + i
        positions.append(x)

        # Median bar, drawn on top of the box.
        axis.hlines(
            y=quantiles[0.5], xmin=x - half_width, xmax=x + half_width,
            lw=2., color=line_colour, zorder=3,
        )

        # Interquartile box.
        rect = Rectangle(
            xy=(x - half_width, quantiles[0.25]),
            width=box_width,
            height=quantiles[0.75] - quantiles[0.25],
            zorder=2,
            facecolor=box_colour,
        )
        axis.add_patch(rect)

        # 2.5%-97.5% whisker, drawn behind the box.
        axis.vlines(x=x, ymin=quantiles[0.025], ymax=quantiles[0.975], lw=1.3, color=line_colour, zorder=1)

        y_max = max(y_max, quantiles[0.975])

    axis.set_title(title_lookup.get(output_name, output_name))
    # Use the explicit Axes API rather than pyplot's implicit "current axes".
    axis.set_xticks(positions)
    axis.set_xticklabels(list(scenario_labels), fontsize=15)

    # One empty slot on either side of the glyphs; 20% headroom above the tallest whisker.
    axis.set_xlim((0., len(scenario_labels) + 1.))
    axis.set_ylim((0, y_max * 1.2))
    fig.tight_layout()



In [None]:
# Final cumulative deaths: baseline vs schools open.
plot_final_size_compare(results, output_name="cumulative_infection_deaths")

In [None]:
# Final cumulative incidence: baseline vs schools open.
plot_final_size_compare(results, output_name="cumulative_incidence")

# Age-specific incidence