In [None]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("../../") 

import os
import re
import yaml
import numpy as np
import pandas as pd
import textwrap
import json

import seaborn as sns
import matplotlib.pyplot as plt
plt.rcParams['pdf.fonttype'] = 42
import matplotlib.gridspec as gridspec
plt.rcParams["font.family"] = "Optima"
plt.rcParams["font.weight"] = "light"

import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
import scipy.stats

import arviz as az

from epimodel import preprocess_data, run_model, EpidemiologicalParameters
from epimodel.plotting.intervention_plots import combine_npi_samples

In [None]:
def load_jsons(paths, vars=("seasonality_beta1", "seasonality_max_R_day")):
    """Load summary JSON files and stack the selected variables into one frame.

    Every entry of ``paths`` is opened relative to ``../../``. Each parsed
    document is annotated with three extra keys derived from its
    ``model_config_name``: ``"MODEL"``, ``"DATA"`` and a plot ``"LABEL"``.
    A one-line summary (model, data, ``rhat``) is printed per file.

    Parameters
    ----------
    paths : iterable of str
        File names, relative to ``../../``.
    vars : sequence of str
        Keys of each JSON document whose values are copied into the frame as
        columns. (The name shadows the builtin; kept for backward
        compatibility with keyword callers.)

    Returns
    -------
    traces : list of dict
        The parsed, annotated JSON documents, in input order.
    df : pandas.DataFrame
        Columns ``vars`` plus ``"label"``. Rows of all files come first,
        followed by a copy of all of those rows relabelled ``"Combined"``.
        Empty (but with the right columns) when ``paths`` is empty.

    Raises
    ------
    AttributeError
        If ``model_config_name`` does not match the ``model…_`` / ``data…``
        patterns (``re.search`` returns ``None``).
    KeyError
        If a document lacks ``model_config_name``, ``rhat`` or one of ``vars``.
    """
    traces = []
    frames = []
    for fn in paths:
        # Only the parsing needs the file handle; release it right away.
        with open('../../'+fn) as f:
            d = json.load(f)

        d["MODEL"] = re.search('model(.*)_', d['model_config_name']).groups()[0]
        d["DATA"] = re.search('data(.*)', d['model_config_name']).groups()[0]
        d["LABEL"] = f"Seasonal {d['MODEL']} et al." #\n{d['DATA']} data" # NB: Change for 2x2 plots
        if d['DATA'] == "BraunerTE":
            d["LABEL"] += "\n(temperate Europe)"
        print(f"Loaded {d['MODEL']} model, {d['DATA']} data. Rhat: {d['rhat']}")
        traces.append(d)

        cols = {v: np.array(d[v]) for v in vars}
        cols["label"] = d["LABEL"]
        frames.append(pd.DataFrame(cols))

    if not frames:
        # Nothing loaded: return an empty frame with the expected columns.
        return traces, pd.DataFrame(columns=[*vars, "label"])

    # DataFrame.append was removed in pandas 2.0; concatenate once instead of
    # growing the frame inside the loop.
    per_file = pd.concat(frames, ignore_index=True)
    combined = per_file[list(vars)].assign(label="Combined")
    df = pd.concat([per_file, combined], ignore_index=True)
    return traces, df


In [None]:
# Summary files of the runs used for the fixed-peak seasonality figures.
beta1_SRC = [
    "sensitivity_final/default_cmodelSharma_dataSharma/seasonality_basic_R_prior/20210429-044743-70284_summary.json",
    # NB: Change for 2x2 plots
    #"sensitivity_final/modelBrauner_dataSharma/seasonality_basic_R_prior/complex_seasonal_2021-04-30-025219_pid22787_summary.json",
    #"sensitivity_final/default_cmodelSharma_dataBraunerTE/seasonality_basic_R_prior/20210430-084421-36555_summary.json",
    "sensitivity_final/modelBrauner_dataBraunerTE/seasonality_basic_R_prior/complex_seasonal_2021-04-30-012232_pid18922_summary.json",
]

traces, df1 = load_jsons(beta1_SRC)

# Derived columns: amplitude in percent and the relative drop of the
# multiplier from its maximum (1 + beta1) to its minimum (1 - beta1).
df1 = df1.assign(**{
    "top-to-trough": lambda d: 100 * (1 - (1 - d["seasonality_beta1"]) / (1 + d["seasonality_beta1"])),
    "gamma_percent": lambda d: 100 * d["seasonality_beta1"],
    "Seasonality peak": "January 1",
})

_PERCENT_TICKS = [0, 20, 40, 60, 80]
_PERCENT_TICK_LABELS = ["0%", "20%", "40%", "60%", "80%"]

# Violin plot of the amplitude; written to disk only (figure is closed).
sns.violinplot(
    data=df1, x="gamma_percent", y="label",
    linewidth=1.0, inner="quartiles", split=True,
)
plt.xlabel("Seasonality amplitude γ (with 50% CI)")
plt.ylabel(None)
plt.xticks(_PERCENT_TICKS, _PERCENT_TICK_LABELS)
plt.savefig('figures/Fig_seasonality_gamma.pdf', bbox_inches='tight')
plt.close()

# Violin plot of the peak-to-trough reduction; the figure is left open.
sns.violinplot(
    data=df1, x="top-to-trough", y="label",
    linewidth=1.0, inner="quartiles",
)
plt.xlabel("Seasonality peak-to-trough reduction (with 50% CI)")
plt.ylabel(None)
plt.xticks(_PERCENT_TICKS, _PERCENT_TICK_LABELS)
plt.savefig('figures/Fig_seasonality_extremes.pdf', bbox_inches='tight')

# Median, 50% interval and 95% interval per label.
_Q_LEVELS = (0.5, 0.25, 0.75, 0.025, 0.975)
_Q_NAMES = ['Med', 'C1a', 'C1b', 'C2a', 'C2b']


def _quantiles_by_label(column):
    """Return one row of `_Q_NAMES` quantiles of `column` for each label in df1."""
    return df1.groupby('label').apply(
        lambda grp: pd.Series(np.quantile(grp[column], _Q_LEVELS), index=_Q_NAMES)
    )


quants_gamma = _quantiles_by_label("gamma_percent")
quants_p2t = _quantiles_by_label("top-to-trough")

# Print both tables as LaTeX `tabular` environments, rows in reversed order.
_TABLE_HEAD = r"""
\begin{tabular}{l|r r r}
\textbf{Model} & \textbf{Median} & \textbf{50\% CI} & \textbf{95\% CI}\\
\hline
"""
_TABLE_FOOT = r"""
\end{tabular}
"""

for quants in (quants_gamma, quants_p2t):
    table_rows = [
        f"{lab} & {row.Med:.1f} & {row.C1a:.1f} -- {row.C1b:.1f} & {row.C2a:.1f} -- {row.C2b:.1f} \\\\"
        for lab, row in reversed(list(quants.iterrows()))
    ]
    print(_TABLE_HEAD + '\n'.join(table_rows) + _TABLE_FOOT)


In [None]:
# Ridge-style posterior figures: one row per label with a scaled KDE curve,
# a "+" at the median and two bars for the 95% and 50% intervals below it.
pal = sns.color_palette()
labels = list(reversed(df1["label"].unique()))
M = 5  # vertical stretch applied to each KDE so the curve is visible above its row

for name, column, title in [
        ("gamma", "gamma_percent", "Seasonality amplitude γ"),
        ("gamma2", "gamma_percent", "Seasonality amplitude γ"),
        ("extremes", "top-to-trough", "Peak-to-trough seasonal R reduction"),
        ("extremes2", "top-to-trough", "Peak-to-trough seasonal R reduction"),
        ]:

    plt.figure(figsize=(4, 3))

    # x position of the row labels; it sits just left of the first tick.
    tx = -3 if name in ["gamma", "gamma2", "extremes2"] else 8

    for y, yl in enumerate(labels):
        color = pal[y % len(pal)]  # wrap around if there are more rows than colours
        x = df1.loc[df1["label"] == yl, column]
        kde = scipy.stats.gaussian_kde(x, 0.2)
        x0, xA0, xB0, xM, xB1, xA1, x1 = np.quantile(x, [0.005, 0.025, 0.25, 0.5, 0.75, 0.975, 1.0-0.005])

        # Density curve over the central 99% of the samples, resting on row y.
        xs = np.linspace(x0, x1, 500)
        density = kde(xs) * M + y
        plt.fill_between(xs, density, y, color=color, alpha=0.3)
        plt.plot(xs, density, color=color)

        # Median marker and interval bars, drawn slightly below the row baseline.
        yoff = y - 0.2
        plt.scatter(xM, yoff, marker="+", color=color, s=100)
        plt.plot([xA0, xA1], [yoff, yoff], color=color, lw=3, alpha=0.5)
        plt.plot([xB0, xB1], [yoff, yoff], color=color, lw=5, alpha=1.0)

        plt.text(tx, y, yl, ha="right", va="center")

    plt.xlabel(title, labelpad=8)
    plt.ylabel(None)
    # One chain: previously "gamma" fell through into the final `else` and its
    # ticks were overwritten with the "extremes" percent ticks.
    if name == "gamma":
        plt.xticks([0, 10, 20, 30, 40, 50, 60], ["0.0", "0.1", "0.2", "0.3", "0.4", "0.5", "0.6"])
    elif name == "gamma2":
        plt.xticks([0, 10, 20, 30, 40, 50], ["0.0", "0.1", "0.2", "0.3", "0.4", "0.5"])
    elif name == "extremes2":
        plt.xticks([0, 10, 20, 30, 40, 50, 60], ["0%", "10%", "20%", "30%", "40%", "50%", "60%"])
    else:
        plt.xticks([10, 20, 30, 40, 50, 60], ["10%", "20%", "30%", "40%", "50%", "60%"])
    plt.yticks([])
    # Half a row of padding below the first and above the last row
    # (equals the former fixed (-0.5, 2.5) when there are three labels).
    plt.ylim(-0.5, len(labels) - 0.5)
    sns.despine(left=True, trim=True)
    plt.savefig(f'figures/Fig_seasonality_{name}_2.pdf', bbox_inches='tight')
    plt.close()


In [None]:
# Per-label quantiles of the amplitude, in percent: median, 50% interval
# bounds, 95% interval bounds. The list of (label, row) pairs is the cell output.
q = df1.groupby('label').apply(
    lambda grp: pd.Series(
        100 * np.quantile(grp["seasonality_beta1"], (0.5, 0.25, 0.75, 0.025, 0.975)),
        index=['Med', 'C1a', 'C1b', 'C2a', 'C2b'],
    )
)
list(q.iterrows())

In [None]:
# Summary files of the runs in which the seasonality peak day is inferred.
maxRday_SRC=[
    "sensitivity_final/default_cmodelSharma_dataSharma/seasonality_max_R_day_normal/20210429-044738-70161_summary.json",
    # NB: Change for 2x2 plots
    #"sensitivity_final/modelBrauner_dataSharma/seasonality_max_R_day_normal/complex_seasonal_2021-04-30-180604_pid56261_summary.json",
    #"sensitivity_final/default_cmodelSharma_dataBraunerTE/seasonality_max_R_day_normal/20210430-060658-29991_summary.json",
    "sensitivity_final/modelBrauner_dataBraunerTE/seasonality_max_R_day_normal/complex_seasonal_2021-04-30-050100_pid27982_summary.json",
    ]

# NOTE: this rebinds `traces`, replacing the list loaded for `df1`.
traces, df2 = load_jsons(maxRday_SRC)
df2["gamma_percent"] = 100*df2["seasonality_beta1"]
df2["Seasonality peak"] = "Variable"
# DataFrame.append was removed in pandas 2.0; pd.concat is the replacement.
# Columns that exist only in df1 (e.g. "top-to-trough") are NaN for df2's rows.
dfc = pd.concat([df1, df2], ignore_index=True)

# Posterior of the inferred peak day per label; saved to disk and closed.
sns.violinplot(y="label", x="seasonality_max_R_day", data=df2, linewidth=1.0, inner="quartiles", cut=0.0)
plt.xlabel("Inferred seasonality peak day (with 50% CI)")
plt.ylabel(None)
# Tick positions are day numbers with Jan 1 = 1, stepping by month lengths.
plt.xticks([1-31-30, 1-31, 1+0, 1+31, 1+31+28], ["Nov 1", "Dec 1", "Jan 1", "Feb 1", "Mar 1"])
plt.xlim((1-31-30-20, 1+31+28+20))
plt.savefig('figures/Fig_seasonality_maxRday.pdf', bbox_inches='tight')
plt.close()
print(df2[df2.label=="Combined"]["seasonality_max_R_day"].median())

# Amplitude posteriors of both variants as split violins, one half per
# "Seasonality peak" value; this figure is left open.
sns.violinplot(y="label", x="gamma_percent", data=dfc, linewidth=1.0, hue="Seasonality peak", split=True, inner="quartiles", cut=0.0)
plt.xlabel("Seasonal amplitude γ (with 50% CI)")
plt.ylabel(None)
plt.xticks([0, 20, 40, 60, 80], ["0%", "20%", "40%", "60%", "80%"])
plt.savefig('figures/Fig_seasonality_gamma_with_maxRday.pdf', bbox_inches='tight')


In [None]:
# Seasonal multiplier over a year for the "Combined" samples with an inferred
# peak day, with the fixed-peak median curve overlaid for comparison.
df2b = df2[df2.label == "Combined"]
days = np.arange(-365*0.25, 365*0.75, 0.5)

# One column per posterior sample, one row per entry of `days`.
# NOTE(review): the peak day is *added* to the day index, so each cosine peaks
# at day = -peak_day; confirm this matches the model's sign convention.
amplitude = df2b.seasonality_beta1.values
peak_day = df2b.seasonality_max_R_day.values
phase = 2 * np.pi / 365.0 * (days.reshape((-1, 1)) + peak_day)
gamma_samples = 1.0 + amplitude * np.cos(phase)
gamma_q = np.quantile(gamma_samples, (0.025, 0.25, 0.5, 0.75, 0.975), axis=1)

# Quantile curves: the median (index 2) is drawn thicker than the others.
for i, curve in enumerate(gamma_q):
    plt.plot(days, curve, "k", lw=(1.2 if i == 2 else 0.3))
# Shade the outer (2.5%-97.5%) band, then the inner (25%-75%) band on top.
for lower, upper in ((gamma_q[0], gamma_q[4]), (gamma_q[1], gamma_q[3])):
    plt.fill_between(days, lower, upper, alpha=0.10, color="b")

# Median multiplier of the fixed-peak "Combined" samples (constant shift of 1).
fixed_amplitude = df1[df1.label == "Combined"].seasonality_beta1.values
fixed_phase = 2 * np.pi / 365.0 * (days.reshape((-1, 1)) + 1)
gamma_q_0 = np.median(1.0 + fixed_amplitude * np.cos(fixed_phase), axis=1)
plt.plot(days, gamma_q_0, "--", lw=1.2, color="darkred")

plt.xlabel(None)
plt.ylabel("Seasonal multiplier Γ(t)")
plt.xticks([-90, 1, 91, 182, 274], ["Oct 1", "Jan 1", "April 1", "July 1", "Oct 1"])
plt.ylim((0.58, 1.42))
plt.savefig('figures/Fig_seasonality_multiplier_with_maxRday.pdf', bbox_inches='tight')


In [None]:
# Seasonal multiplier over one year for the fixed-peak "Combined" samples,
# with the two study periods shaded in grey behind the curves.
days = np.arange(-365*0, 365*1.0, 0.5)
combined_beta1 = df1[df1.label == "Combined"].seasonality_beta1.values
gamma_samples = 1.0 + combined_beta1 * np.cos(2 * np.pi / 365.0 * (days.reshape((-1, 1)) + 1))
gamma_q = np.quantile(gamma_samples, (0.025, 0.25, 0.5, 0.75, 0.975), axis=1)

# Quantile curves; the median (index 2) is thicker and more opaque.
for i, curve in enumerate(gamma_q):
    is_median = (i == 2)
    plt.plot(days, curve, "k", lw=(1.2 if is_median else 0.9), alpha=(0.7 if is_median else 0.3))
plt.fill_between(days, gamma_q[0], gamma_q[4], alpha=0.8, color="#c0d0ff")
plt.fill_between(days, gamma_q[1], gamma_q[3], alpha=0.15, color="#02f")

plt.xlabel(None)
plt.ylabel("Seasonal multiplier Γ(t)")
# NOTE(review): the -90 tick lies outside the plotted day range (0 to 365);
# confirm it is intended rather than carried over from the previous figure.
plt.xticks([-90, 1, 91, 182, 274, 366], ["Oct 1", "Jan 1", "Apr 1", "Jul 1", "Oct 1", "Jan 1"])

plt.ylim((0.58, 1.42))

# Shaded day ranges, each captioned near the top of the axes.
# (A "Leech" period, 2020-05-01 to 2020-09-21, is currently disabled.)
brauner_period = [22, 150]
sharma_period = [213, 365+9]
for period, legend_label, caption in (
        (brauner_period, "Brauner", "Brauner et al."),
        (sharma_period, "Sharma", "Sharma et al."),
):
    plt.fill_between(period, 0, 2, color='#aaa', alpha=0.4, label=legend_label, zorder=-2, lw=0)
    plt.text(np.mean(period), 1.38, caption, ha="center", va="center", size=9)
sns.despine(trim=True)

plt.savefig('figures/Fig_seasonality_multiplier_fixed.pdf', bbox_inches='tight')

In [None]:
# Interactive help lookup (IPython `?` syntax) left over from development;
# it has no effect on the figures and can be removed from the final notebook.
plt.plot?