In [None]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("../../") 

import json
import os
import re
import textwrap

import numpy as np
import pandas as pd
import yaml

import seaborn as sns
import matplotlib.pyplot as plt
plt.rcParams['pdf.fonttype'] = 42
import matplotlib.gridspec as gridspec
plt.rcParams["font.family"] = "Optima"
plt.rcParams["font.weight"] = "light"

import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

import arviz as az

from epimodel import preprocess_data, run_model, EpidemiologicalParameters
from epimodel.plotting.intervention_plots import combine_npi_samples

In [None]:
def load_jsons(paths, vars=("seasonality_beta1", "seasonality_max_R_day"), base_dir="../../"):
    """Load per-run summary JSON files and stack the requested samples into one frame.

    Each file must contain ``model_config_name`` (parsed into a model name and a
    data name), ``rhat``, and one list of values per entry of ``vars``.

    Args:
        paths: Iterable of summary-JSON paths, relative to ``base_dir``.
        vars: Keys of the sampled quantities to extract from every file.
        base_dir: Directory that ``paths`` are relative to. The default
            ``"../../"`` reproduces the notebook's original path prefix.

    Returns:
        ``(traces, df)``:
        ``traces`` is the list of loaded dicts, each augmented with ``MODEL``,
        ``DATA`` and ``LABEL`` keys.
        ``df`` is a long-format DataFrame with one column per entry of ``vars``
        plus a ``label`` column. After the per-run rows, all samples are
        repeated once more under the label ``"Combined"``.
        With no paths, ``df`` is empty but has those columns.

    Raises:
        ValueError: if ``model_config_name`` lacks the ``model..._`` or
            ``data...`` part needed to build the labels.
    """
    traces = []
    # Collect per-run frames and concatenate once: DataFrame.append was removed
    # in pandas 2.0, and appending inside the loop is quadratic anyway.
    frames = []
    for fn in paths:
        path = os.path.join(base_dir, fn)
        with open(path, encoding="utf-8") as f:
            d = json.load(f)

        config_name = d["model_config_name"]
        # Greedy: captures everything between "model" and the last "_" after it.
        model_match = re.search('model(.*)_', config_name)
        data_match = re.search('data(.*)', config_name)
        if model_match is None or data_match is None:
            raise ValueError(
                f"Cannot parse model/data names from model_config_name={config_name!r} in {path}"
            )
        d["MODEL"] = model_match.group(1)
        d["DATA"] = data_match.group(1)
        d["LABEL"] = f"{d['MODEL']} model\n{d['DATA']} data"
        print(f"Loaded {d['MODEL']} model, {d['DATA']} data. Rhat: {d['rhat']}")
        traces.append(d)

        cols = {v: np.array(d[v]) for v in vars}
        cols["label"] = d["LABEL"]
        frames.append(pd.DataFrame(cols))

    if not frames:
        # Nothing to pool: return an empty frame that still has the expected columns.
        return traces, pd.DataFrame(columns=[*vars, "label"])

    df = pd.concat(frames, ignore_index=True)

    # Pool the samples of all runs into an extra "Combined" group.
    combined = {v: np.array(df[v].values) for v in vars}
    combined["label"] = "Combined"
    df = pd.concat([df, pd.DataFrame(combined)], ignore_index=True)
    return traces, df


In [None]:
# Runs whose seasonal peak is fixed at January 1; only the amplitude γ is inferred.
beta1_SRC = [
    "sensitivity_final/modelSharma_dataSharma/seasonality_basic_R_prior/20210429-044743-70284_summary.json",
    "sensitivity_final/modelBrauner_dataSharma/seasonality_basic_R_prior/complex_seasonal_2021-04-30-025219_pid22787_summary.json",
    "sensitivity_final/modelSharma_dataBraunerTE/seasonality_basic_R_prior/20210430-084421-36555_summary.json",
    "sensitivity_final/modelBrauner_dataBraunerTE/seasonality_basic_R_prior/complex_seasonal_2021-04-30-012232_pid18922_summary.json",
]

traces, df1 = load_jsons(beta1_SRC)
# Relative drop (in percent) from a peak proportional to (1 + γ) to a trough
# proportional to (1 - γ).
df1["top-to-trough"] = 100 * (1 - (1 - df1["seasonality_beta1"]) / (1 + df1["seasonality_beta1"]))
df1["Seasonality peak"] = "January 1"

# savefig does not create missing directories; ensure the output folder exists.
os.makedirs("figures", exist_ok=True)

fig, ax = plt.subplots()
sns.violinplot(y="label", x="seasonality_beta1", data=df1, width=1.1, linewidth=1.0,
               inner="quartiles", ax=ax)
ax.set_xlabel("Seasonality amplitude γ (with 50% CI)")
ax.set_ylabel(None)
fig.savefig('figures/Fig_seasonality_gamma.pdf', bbox_inches='tight')
plt.close(fig)

fig, ax = plt.subplots()
sns.violinplot(y="label", x="top-to-trough", data=df1, width=1.1, linewidth=1.0,
               inner="quartiles", ax=ax)
ax.set_xlabel("Seasonality peak-to-trough reduction (percent, with 50% CI)")
ax.set_ylabel(None)
fig.savefig('figures/Fig_seasonality_extremes.pdf', bbox_inches='tight')
# This figure is left open, so it is rendered inline as the cell's output.

In [None]:
# Runs where the day of peak seasonality is inferred alongside the amplitude γ.
maxRday_SRC = [
    "sensitivity_final/modelSharma_dataSharma/seasonality_max_R_day_normal/20210429-044738-70161_summary.json",
    "sensitivity_final/modelBrauner_dataSharma/seasonality_max_R_day_normal/complex_seasonal_2021-04-30-180604_pid56261_summary.json",
    "sensitivity_final/modelSharma_dataBraunerTE/seasonality_max_R_day_normal/20210430-060658-29991_summary.json",
    "sensitivity_final/modelBrauner_dataBraunerTE/seasonality_max_R_day_normal/complex_seasonal_2021-04-30-050100_pid27982_summary.json",
]

traces, df2 = load_jsons(maxRday_SRC)
df2["Seasonality peak"] = "Inferred"
# Stack the fixed-peak (df1) and inferred-peak (df2) runs for the split violin plot.
# DataFrame.append was removed in pandas 2.0; pd.concat is the supported equivalent.
dfc = pd.concat([df1, df2], ignore_index=True)

# savefig does not create missing directories; ensure the output folder exists.
os.makedirs("figures", exist_ok=True)

# Day axis where 1 is January 1. Offsets use month lengths Nov=30, Dec=31, Jan=31,
# Feb=28 (non-leap year).
peak_day_ticks = [1 - 31 - 30, 1 - 31, 1 + 0, 1 + 31, 1 + 31 + 28]
peak_day_labels = ["Nov 1", "Dec 1", "Jan 1", "Feb 1", "Mar 1"]
tick_margin_days = 20

fig, ax = plt.subplots()
sns.violinplot(y="label", x="seasonality_max_R_day", data=df2, width=1.1, linewidth=1.0,
               inner="quartiles", cut=0.0, ax=ax)
ax.set_xlabel("Inferred seasonality peak day (with 50% CI)")
ax.set_ylabel(None)
ax.set_xticks(peak_day_ticks)
ax.set_xticklabels(peak_day_labels)
ax.set_xlim((peak_day_ticks[0] - tick_margin_days, peak_day_ticks[-1] + tick_margin_days))
fig.savefig('figures/Fig_seasonality_maxRday.pdf', bbox_inches='tight')
plt.close(fig)

fig, ax = plt.subplots()
sns.violinplot(y="label", x="seasonality_beta1", data=dfc, linewidth=1.0,
               hue="Seasonality peak", split=True, inner="quartiles", cut=0.0, ax=ax)
ax.set_xlabel("Seasonal amplitude γ (with 50% CI)")
ax.set_ylabel(None)
fig.savefig('figures/Fig_seasonality_gamma_with_maxRday.pdf', bbox_inches='tight')
# This figure is left open, so it is rendered inline as the cell's output.
