In [None]:
import os

from scipy import stats
from numpy import linspace, exp
import pandas as pd
import matplotlib.pyplot as plt

from autumn.core import inputs
from notebooks.user.rragonnet.project_specific.School_Closure.plotting_constants import (
    SCHOOL_PROJECT_NOTEBOOK_PATH,
    FIGURE_WIDTH,
    RESOLUTION,
    set_up_style
)

# Apply the shared project plotting style imported from plotting_constants
# (presumably matplotlib rcParams settings — see set_up_style for details).
set_up_style()

## Sojourn time in latent and active compartments

In [None]:
# Plot the sojourn-time distributions of the latent and active states.
# Each is a gamma distribution with integer shape `n_replicates` and mean
# `sojourn_means[state]` (gamma mean = shape * scale, hence scale = mean / shape).
# NOTE(review): the integer shape presumably mirrors the model splitting each state
# into `n_replicates` sequential sub-compartments — confirm against the model code.
fig_path = os.path.join(SCHOOL_PROJECT_NOTEBOOK_PATH, "input_figs", "sojourns.pdf")
# savefig does not create missing directories, so make sure the target exists.
os.makedirs(os.path.dirname(fig_path), exist_ok=True)

sojourn_means = {
    "latent": 5.5,
    "active": 8
}
titles = {
    "latent": "incubation period",
    "active": "infectious period"
}
# Named distinctly so later cells that define their own colour mapping do not clobber it.
sojourn_colors = {
    "latent": "mediumpurple",
    "active": "coral"
}
n_replicates = 4

# Evaluation grid in days, shared by both panels.
x_min, x_max = 0., 20
x = linspace(x_min, x_max, 1000)

fig, axes = plt.subplots(1, 2, figsize=(FIGURE_WIDTH, .42 * FIGURE_WIDTH), dpi=RESOLUTION, sharey=True)
for i_ax, (ax, state) in enumerate(zip(axes, sojourn_means)):
    distri = stats.gamma(a=n_replicates, scale=sojourn_means[state] / n_replicates)
    density = distri.pdf(x)

    ax.plot(x, density, '-', color="black", lw=1.5, alpha=0.6)
    ax.fill_between(x, density, alpha=.5, color=sojourn_colors[state])

    ax.set_xlabel("days")
    # The y-axis is shared, so only the left panel needs a label.
    if i_ax == 0:
        ax.set_ylabel("density")
    ax.set_title(titles[state], fontsize=12)
    ax.locator_params(nbins=5)

plt.tight_layout()
plt.savefig(fig_path)

## Load list of included countries

In [None]:
import yaml

# Load the mapping of ISO3 country code -> country name. The keys are matched against
# the "country_id" column and the values are used as plot labels further down.
# safe_load is sufficient for a plain mapping and cannot construct arbitrary Python objects.
# Explicit UTF-8 so accented country names decode identically on every platform.
with open('included_countries.yml', encoding="utf-8") as yml_file:
    included_countries = yaml.safe_load(yml_file)

n_included_countries = len(included_countries)
# Handle to the project's input database, queried in the next section.
input_db = inputs.database.get_input_db()

## UNESCO school-closure data

In [None]:
 # Get the UNESCO school closures data, restricted to the included countries.
unesco_data = input_db.query(
    table_name='school_closure',
    columns=["date", "status", "country_id"],
)
# .copy() so the columns added in the next cell go to an independent frame rather
# than a filtered view of the query result (avoids SettingWithCopyWarning).
unesco_data = unesco_data[unesco_data["country_id"].isin(list(included_countries))].copy()

In [None]:
# Rank countries by ISO3 code. Fixing the categories to the full list of included
# countries keeps each country's row position stable even if some country has no
# UNESCO records. Inferred categories would only cover the codes present in the data.
unesco_data["country_id"] = pd.Categorical(
    unesco_data["country_id"], categories=sorted(included_countries)
)
unesco_data['rev_order'] = unesco_data.country_id.cat.codes
# Flip the ranking so the alphabetically first country gets the largest y value (top row).
unesco_data['order'] = n_included_countries - unesco_data['rev_order'] - 1


In [None]:
import datetime
from matplotlib.patches import Rectangle
from matplotlib.legend_handler import HandlerLine2D


In [None]:
# Shorter display labels for countries whose official names are too long for the plot
# (official name -> label). Countries not listed here keep their official name.
short_country_name = dict([
    ("Russian Federation", "Russia"),
    ("Iran, Islamic Republic of", "Iran"),
    ("Korea, Republic of", "South Korea"),
    ("Lao People's Democratic Republic", "Laos"),
])

In [None]:
# Timeline of UNESCO school status per country: one row per country, one tick per record.

# Colour for each UNESCO school status; the key order also defines the legend order.
# Named distinctly so it does not overwrite the sojourn cell's colour mapping.
status_colors = {
    "Fully open": "green",
    "Partially open": "cornflowerblue",
    "Closed due to COVID-19": "tomato",
    "Academic break": "grey"
}
fsize = 10
# Reset to matplotlib defaults (presumably discarding the rcParams changes made by
# set_up_style) before picking the font for this figure.
plt.style.use("default")
plt.rcParams["font.family"] = "Times New Roman"
fig, ax = plt.subplots(1, 1, figsize=(1.3 * FIGURE_WIDTH, 2. * FIGURE_WIDTH), dpi=RESOLUTION)

# Fail early with a clear message instead of handing NaN colours to matplotlib.
observed_statuses = unesco_data['status'].dropna().unique()
unknown_statuses = set(observed_statuses) - set(status_colors)
assert not unknown_statuses, f"No colour defined for status(es): {unknown_statuses}"

# Plot the data: each record becomes a vertical tick at (date, country row).
for status in observed_statuses:
    sub_data = unesco_data[unesco_data['status'] == status]
    ax.scatter(sub_data['date'], sub_data['order'], color=status_colors[status], marker="|", s=50, label=status)

# Add country names to the left of the first data point. Look up each country's row once
# instead of filtering the whole frame per country.
order_by_country = unesco_data.drop_duplicates(subset="country_id").set_index("country_id")["order"]
text_date = datetime.date(2020, 2, 1)
for iso3, country in included_countries.items():
    order = order_by_country.get(iso3)
    if order is None:
        # No UNESCO records for this country: report it and leave the row unlabelled.
        print(f"No UNESCO records for {country} ({iso3}); row left unlabelled.")
        continue
    country_label = short_country_name.get(country, country)
    ax.text(x=text_date, y=float(order), s=country_label, ha='right', va='center', fontsize=fsize)

# Mark the different years with dashed lines spanning all country rows.
# NOTE(review): the date arithmetic below assumes the 'date' column holds values that
# can be combined with datetime.date — confirm the dtype returned by input_db.query.
first_date = min(unesco_data.date)
last_date = max(unesco_data.date)
new_years = [datetime.date(2021, 1, 1), datetime.date(2022, 1, 1)]
for new_year in new_years:
    ax.vlines(x=new_year, ymin=-1, ymax=n_included_countries, colors="black", zorder=-10, linestyles="dashed", linewidths=.8)

# Label each year at the midpoint of the part of it covered by the data.
year_positions = {
    "2020": first_date + (datetime.date(2021, 1, 1) - first_date) / 2,
    "2021": datetime.date(2021, 1, 1) + (datetime.date(2022, 1, 1) - datetime.date(2021, 1, 1)) / 2,
    "2022": datetime.date(2022, 1, 1) + (last_date - datetime.date(2022, 1, 1)) / 2
}
for year, mid_x in year_positions.items():
    ax.text(x=mid_x, y=n_included_countries, s=year, fontsize=fsize, ha="center", va="center")

ax.set_ylim((-1, n_included_countries))
ax.get_yaxis().set_visible(False)
for side in ("top", "right", "left"):
    ax.spines[side].set_visible(False)

# Build the legend from empty proxy lines so that every status appears, in
# status_colors order, whether or not it occurs in the data. The thick lines are
# easier to read than the scatter ticks.
handles = [ax.plot([], ls="-", color=col)[0] for col in status_colors.values()]
labels = list(status_colors)
leg = ax.legend(handles, labels, ncol=4, bbox_to_anchor=(0.5, -.02), loc="upper center", frameon=False)
for line in leg.get_lines():
    line.set_linewidth(7.0)

plt.tight_layout()
# NOTE(review): saved to the working directory, unlike sojourns.pdf which goes to
# input_figs — confirm this is intended.
plt.savefig("unesco.png")
plt.savefig("unesco.pdf")
plt.close(fig)
