In [None]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("../../") 

import os
import yaml
import numpy as np
import textwrap
import json

import seaborn as sns
import matplotlib.pyplot as plt
plt.rcParams['pdf.fonttype'] = 42
import matplotlib.gridspec as gridspec
plt.rcParams["font.family"] = "Optima"
plt.rcParams["font.weight"] = "light"

import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
import scipy.stats

import arviz as az

from epimodel import preprocess_data, run_model, EpidemiologicalParameters
from epimodel.plotting.intervention_plots import combine_npi_samples

In [None]:
def load_jsons(paths, vars=("seasonality_beta1", "seasonality_max_R_day"), base_dir="../../"):
    """Load run-summary JSON files and gather selected variables into one long DataFrame.

    For each file, the model and data names are parsed from its
    ``model_config_name`` and a plot label is derived from them. The values
    stored under each key in ``vars`` become columns, with one row per array
    element. After all files are read, every row is appended once more with
    the label ``"Combined"``, grouped by the ``mobility`` flag. This lets
    pooled distributions be plotted next to the per-model ones.

    Args:
        paths: Iterable of summary-JSON paths, relative to ``base_dir``.
        vars: JSON keys whose values are converted with ``np.array`` into
            DataFrame columns. (The name shadows the builtin; it is kept for
            backward compatibility with keyword callers.)
        base_dir: Directory that ``paths`` are relative to. Defaults to
            ``"../../"``, the prefix previously hard-coded here.

    Returns:
        ``(traces, df)``, where:
        - ``traces`` is the list of loaded JSON dicts, each augmented with
          ``MODEL``, ``DATA`` and ``LABEL`` keys.
        - ``df`` has one column per entry of ``vars``, plus ``label``,
          ``mobility`` (the strings ``"True"``/``"False"``) and
          ``mobility_type``.
        If ``paths`` is empty, ``df`` is empty but still has those columns.

    Raises:
        ValueError: If ``model_config_name``, or the ``data_path`` of a
            mobility run, does not match the expected naming pattern.
    """
    # Local imports: the notebook's import cell does not bring these into scope.
    import re
    import pandas as pd

    def _first_group(pattern, text, what):
        # Equivalent to re.search(...).groups()[0], but fails with a clear message
        # instead of "'NoneType' object has no attribute 'groups'".
        match = re.search(pattern, text)
        if match is None:
            raise ValueError(f"Could not parse {what} from {text!r} using pattern {pattern!r}")
        return match.group(1)

    traces = []
    frames = []
    for fn in paths:
        with open(os.path.join(base_dir, fn)) as f:
            d = json.load(f)

        config_name = d["model_config_name"]
        # Greedy "(.*)_" captures everything between "model" and the LAST underscore.
        d["MODEL"] = _first_group(r"model(.*)_", config_name, "model name")
        d["DATA"] = _first_group(r"data(.*)", config_name, "data name")
        # NB: for 2x2 plots, also append f"\n{d['DATA']} data" to the label.
        d["LABEL"] = f"Seasonal {d['MODEL']} et al."
        if d["DATA"] == "BraunerTE":
            d["LABEL"] += "\n(temperate Europe)"
        print(f"Loaded {d['MODEL']} model, {d['DATA']} data. Rhat: {d['rhat']}")
        traces.append(d)

        cols = {v: np.array(d[v]) for v in vars}
        cols["label"] = d["LABEL"]
        has_mobility = "mobility" in d["data_path"]
        # Stored as the strings "True"/"False" so seaborn treats it as a categorical hue.
        cols["mobility"] = str(has_mobility)
        cols["mobility_type"] = (
            _first_group(r"mobility_(.*)\.csv", d["data_path"], "mobility type")
            if has_mobility
            else "None"
        )
        frames.append(pd.DataFrame(cols))

    if not frames:
        return traces, pd.DataFrame(columns=[*vars, "label", "mobility", "mobility_type"])

    # A single concat replaces the removed DataFrame.append and avoids quadratic growth.
    df = pd.concat(frames, ignore_index=True)
    # Duplicate all rows under a "Combined" label, one block per mobility value.
    # .assign returns a copy, so there is no SettingWithCopyWarning.
    combined = [
        df[df["mobility"] == mob].assign(label="Combined")
        for mob in df["mobility"].unique()
    ]
    df = pd.concat([df, *combined], ignore_index=True)
    return traces, df


In [None]:
# Run-summary JSONs, relative to the repository root, for the seasonality-amplitude
# (beta1) sensitivity comparison. Paths are kept exactly as produced by the runs.
beta1_SRC = [
    "sensitivity_final/default_cmodelSharma_dataSharma/seasonality_basic_R_prior/20210429-044743-70284_summary.json",
    "sensitivity_final/modelBrauner_dataBraunerTE/seasonality_mobility/complex_seasonal_2021-06-25-223518_pid53529_summary.json",
    "sensitivity_final/modelBrauner_dataBraunerTE/seasonality_mobility/complex_seasonal_2021-06-25-223548_pid53688_summary.json",
    "sensitivity_final/modelBrauner_dataBraunerTE/seasonality_mobility/complex_seasonal_2021-06-25-223618_pid53905_summary.json",
    "sensitivity_final/modelBrauner_dataBraunerTE/seasonality_basic_R_prior/complex_seasonal_2021-04-30-012232_pid18922_summary.json",
]

# One figure is produced per mobility_type value listed here.
MOBILITY_TYPES = ["RecreGrocTransWorkResid", "GrocTransWorkResid", "GrocTransWork"]
FIG_DIR = "figures"

traces, df1 = load_jsons(beta1_SRC)
# Build a new frame instead of mutating in place, so re-running this cell is idempotent.
df1 = (
    df1.assign(gamma_percent=100 * df1["seasonality_beta1"])
       .assign(**{"Seasonality peak": "January 1"})
       .fillna("False")
)
print(df1["mobility"].unique())

# savefig does not create missing directories.
os.makedirs(FIG_DIR, exist_ok=True)
for mt in MOBILITY_TYPES:
    # Keep the no-mobility runs ("None") plus the runs using this mobility type.
    df_mt = df1[df1["mobility_type"].isin(["None", mt])]
    fig, ax = plt.subplots()
    # split=True draws the two "mobility" hue levels as the halves of each violin.
    sns.violinplot(
        y="label", x="gamma_percent", data=df_mt, linewidth=1.0,
        inner="quartiles", split=True, hue="mobility", ax=ax,
    )
    ax.set_title(f"Sensitivity to mobility types {mt}", fontsize=10)
    ax.set_xlabel("Seasonality amplitude γ (with 50% CI)")
    ax.set_ylabel(None)
    ax.set_xticks([0, 20, 40, 60])
    ax.set_xticklabels(["0%", "20%", "40%", "60%"])
    fig.savefig(os.path.join(FIG_DIR, f"Fig_seasonality_mobility_{mt}.pdf"), bbox_inches="tight")
    # Close this specific figure so the loop does not accumulate open figures.
    plt.close(fig)
