In [None]:
import os
import warnings

import matplotlib.pyplot as plt
import pandas as pd
from numpy import linspace, exp
from scipy import stats

from autumn.core import inputs
from notebooks.user.rragonnet.project_specific.School_Closure.plotting_constants import (
    SCHOOL_PROJECT_NOTEBOOK_PATH,
    FIGURE_WIDTH,
    RESOLUTION,
    set_up_style
)

# Apply the project's shared plotting style (imported from plotting_constants; presumably
# sets matplotlib rcParams -- confirm there) before any figure in this notebook is created.
set_up_style()

## Sojourn time in latent and active compartments

In [None]:
# Plot the sojourn-time distributions of the latent and active compartments. Each is a
# gamma distribution with integer shape `n_replicates` (i.e. an Erlang distribution) and
# mean `sojourn_means[state]` days. Presumably this mirrors a model that chains
# `n_replicates` sub-compartments per state -- confirm against the model definition.
fig_dir = os.path.join(SCHOOL_PROJECT_NOTEBOOK_PATH, "input_figs")
os.makedirs(fig_dir, exist_ok=True)  # savefig does not create missing directories
fig_path = os.path.join(fig_dir, "sojourns.pdf")

sojourn_means = {
    "latent": 5.5,
    "active": 8
}
titles = {
    "latent": "incubation period",
    "active": "infectious period"
}
# Named distinctly from the school-status `colors` used by the UNESCO plot further down.
sojourn_colors = {
    "latent": "mediumpurple",
    "active": "coral"
}
n_replicates = 4

# Evaluation grid (days), shared by both panels, so it is built once outside the loop.
X_MIN_DAYS, X_MAX_DAYS = 0., 20.
x = linspace(X_MIN_DAYS, X_MAX_DAYS, 1000)

fig, axes = plt.subplots(1, 2, figsize=(FIGURE_WIDTH, .42 * FIGURE_WIDTH), dpi=RESOLUTION, sharey=True)
for i_ax, (ax, state) in enumerate(zip(axes, sojourn_means)):
    # The mean of gamma(a, scale) is a * scale, so this distribution has mean sojourn_means[state].
    distri = stats.gamma(a=n_replicates, scale=sojourn_means[state] / n_replicates)
    density = distri.pdf(x)  # evaluated once, reused for the outline and the fill

    ax.plot(x, density, "-", color="black", lw=1.5, alpha=0.6)
    ax.fill_between(x, density, alpha=.5, color=sojourn_colors[state])

    ax.set_xlabel("days")
    if i_ax == 0:  # y axis is shared, so only label the left-hand panel
        ax.set_ylabel("density")
    ax.set_title(titles[state], fontsize=12)
    ax.locator_params(nbins=5)

fig.tight_layout()
fig.savefig(fig_path)

## Load list of included countries

In [None]:
import yaml

# Countries included in the analysis, as a mapping ISO3 code -> country name (the plot
# cell iterates `included_countries.items()` as `iso3, country`).
# safe_load is sufficient for a plain data file and cannot construct arbitrary Python objects;
# UTF-8 is explicit so non-ASCII country names decode the same on every platform.
with open("included_countries.yml", encoding="utf-8") as file:
    included_countries = yaml.safe_load(file)

# Fail here with a clear message rather than later on `.items()` if the file has the wrong shape.
assert isinstance(included_countries, dict), (
    "included_countries.yml must be a mapping of ISO3 codes to country names"
)

input_db = inputs.database.get_input_db()

## UNESCO school closure data

In [62]:
 # Get the UNSECO school closures data

unesco_data = input_db.query(
    table_name='school_closure', 
    columns=["date", "status", "country_id"],
)
unesco_data = unesco_data[unesco_data["country_id"].isin(included_countries)]

In [70]:
# Encode country_id as a categorical and store its integer codes in `order`; the timeline
# plot uses `order` as each country's row (y position).
# `.assign` returns a new frame instead of writing into a possibly-sliced one, and the
# keyword arguments are evaluated in order, so the `order` lambda sees the categorical column.
unesco_data = unesco_data.assign(
    country_id=pd.Categorical(unesco_data["country_id"]),
    order=lambda frame: frame["country_id"].cat.codes,
)

In [104]:
import datetime

# Calendar date used as the x anchor for the right-aligned country labels
# on the UNESCO timeline plot.
date_0 = datetime.date(year=2020, month=2, day=1)

In [113]:
# Shorter display labels for countries whose names in included_countries.yml are long;
# any country not listed here is labelled with its full name.
short_country_name = dict(
    [
        ("Russian Federation", "Russia"),
        ("Iran, Islamic Republic of", "Iran"),
        ("Korea, Republic of", "South Korea"),
        ("Lao People's Democratic Republic", "Laos"),
    ]
)

In [118]:
# Timeline of UNESCO school status: one row per country, one vertical tick per
# (date, country) record, coloured by status. Saved to unesco.png / unesco.pdf.
colors = {
    "Fully open": "green",
    "Partially open": "cornflowerblue",
    "Closed due to COVID-19": "tomato",
    "Academic break": "grey",
}
# An unmapped status (or a missing one) would become a NaN colour and an opaque matplotlib error.
unmapped_status = ~unesco_data["status"].isin(list(colors))
assert not unmapped_status.any(), (
    f"No colour defined for statuses: {unesco_data.loc[unmapped_status, 'status'].unique()}"
)

fig, ax = plt.subplots(1, 1, figsize=(FIGURE_WIDTH, 2. * FIGURE_WIDTH), dpi=RESOLUTION)

ax.scatter(unesco_data['date'], unesco_data['order'], color=unesco_data['status'].map(colors), marker="|", s=50)

# Row index of each country, computed once instead of filtering the whole frame per country.
first_rows = unesco_data.drop_duplicates("country_id")
order_by_country = dict(zip(first_rows["country_id"], first_rows["order"]))

for iso3, country in included_countries.items():
    if iso3 not in order_by_country:
        # No UNESCO records for this country, so it has no row to label.
        warnings.warn(f"No UNESCO school closure data for {country} ({iso3}); label skipped")
        continue
    country_label = short_country_name.get(country, country)
    ax.text(x=date_0, y=float(order_by_country[iso3]), s=country_label, ha='right', va='center')

# One row of padding below the first country and above the last, derived from the data
# rather than hard-coding the number of countries.
n_rows = int(unesco_data["order"].max()) + 1
ax.set_ylim((-1, n_rows))
ax.get_yaxis().set_visible(False)
for side in ("top", "right", "left"):
    ax.spines[side].set_visible(False)

# Empty proxy scatters give the legend one labelled entry per status, drawn with the same marker.
for status, color in colors.items():
    ax.scatter([], [], color=color, marker="|", s=50, label=status)
ax.legend(loc="lower center", bbox_to_anchor=(0.5, 1.0), ncol=2, frameon=False)

fig.tight_layout()
fig.savefig("unesco.png")
fig.savefig("unesco.pdf")
plt.close(fig)
