In [None]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("../../") 

import json
import os
import re
import textwrap

import numpy as np
import pandas as pd
import yaml

import seaborn as sns
import matplotlib.pyplot as plt
plt.rcParams['pdf.fonttype'] = 42
import matplotlib.gridspec as gridspec
plt.rcParams["font.family"] = "Optima"
plt.rcParams["font.weight"] = "light"

import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
import scipy.stats

import arviz as az

from epimodel import preprocess_data, run_model, EpidemiologicalParameters
from epimodel.plotting.intervention_plots import combine_npi_samples

In [None]:
# Region labels per dataset; the plotting loop picks one list based on the loaded run's
# DATA field ("BraunerTE" -> B_REGIONS, "Sharma" -> S_REGIONS). The order is assumed to
# match the column order of that dataset's `seasonality_local_beta1` samples — TODO confirm.
B_REGIONS = [  # BraunerTE data: two-letter country codes
    'AL', 'AD', 'AT', 'BE', 'BA', 'BG', 'HR', 'CZ', 'DK', 'EE',
    'FR', 'DE', 'GR', 'HU', 'IE', 'IT', 'LV', 'LT', 'MT', 'NL',
    'PL', 'PT', 'RO', 'RS', 'SK', 'SI', 'ES', 'CH', 'GB',
]
S_REGIONS = [  # Sharma data: region names
    'Austria', 'Czech', 'England', 'Germany',
    'Italy', 'Netherlands', 'Switzerland',
]


In [None]:
def load_json(path, vars=("seasonality_beta1", "seasonality_max_R_day"), root="../../"):
    """Load a run-summary JSON file and tabulate selected posterior variables.

    Parses the model and data names out of the file's ``model_config_name`` and stores
    them in the returned dict as ``MODEL`` and ``DATA``. Also builds a plot label
    ``LABEL`` and prints a one-line summary that includes ``rhat``.

    Args:
        path: Path of the summary file, relative to ``root``.
        vars: Keys of the JSON dict whose values become DataFrame columns. The values
            are converted with ``np.array``; presumably equal-length sample lists — TODO confirm.
        root: Directory that ``path`` is relative to (default ``"../../"``, the repo
            root as seen from this notebook).

    Returns:
        ``(d, df)``. ``d`` is the parsed JSON dict augmented with ``MODEL``, ``DATA`` and
        ``LABEL``. ``df`` has one column per entry in ``vars`` plus a constant ``label``
        column.

    Raises:
        ValueError: if ``model_config_name`` does not contain the ``model…_`` and
            ``data…`` parts.
        KeyError: if ``model_config_name``, ``rhat`` or any key in ``vars`` is missing.
    """
    # Use the `path` argument (not a global) so the function works from any caller.
    full_path = os.path.join(root, path)
    with open(full_path, encoding="utf-8") as f:
        d = json.load(f)

    config_name = d["model_config_name"]
    # Greedy: captures everything between "model" and the LAST underscore.
    model_match = re.search(r"model(.*)_", config_name)
    data_match = re.search(r"data(.*)", config_name)
    if model_match is None or data_match is None:
        raise ValueError(
            f"Cannot parse model/data names from model_config_name={config_name!r} "
            f"in {full_path}"
        )
    d["MODEL"] = model_match.group(1)
    d["DATA"] = data_match.group(1)

    # NB: for 2x2 plots, also append f"\n{d['DATA']} data" to the label.
    d["LABEL"] = f"Seasonal {d['MODEL']} et al."
    if d["DATA"] == "BraunerTE":
        d["LABEL"] += "\n(temperate Europe)"
    print(f"Loaded {d['MODEL']} model, {d['DATA']} data. Rhat: {d['rhat']}")

    cols = {v: np.array(d[v]) for v in vars}
    cols["label"] = d["LABEL"]  # scalar, broadcast to every row by the DataFrame constructor
    return d, pd.DataFrame(cols)


In [None]:
# For each run with per-country ("local") seasonality, plot the posterior samples of every
# country's local seasonal amplitude (`seasonality_local_beta1`, labelled gamma), sorted by
# mean, followed by the base amplitude (`seasonality_beta1`). Save one PDF per run, named
# by data set and `local_seasonality_sd`.
# NOTE(review): the runs presumably differ in `local_seasonality_sd` (otherwise their PDFs
# would overwrite each other) — confirm.
LOCAL_SEASONALITY_RUNS = [
    "sensitivity_final/modelBrauner_dataBraunerTE/seasonality_local/complex_seasonal_2021-06-27-214513_pid47284_summary.json",
    "sensitivity_final/modelBrauner_dataBraunerTE/seasonality_local/complex_seasonal_2021-06-27-214413_pid46689_summary.json",
    "sensitivity_final/modelBrauner_dataBraunerTE/seasonality_local/complex_seasonal_2021-06-27-214443_pid47122_summary.json",
    "sensitivity_final/modelBrauner_dataBraunerTE/seasonality_local/complex_seasonal_2021-06-27-214543_pid47441_summary.json",
    "sensitivity_final/modelBrauner_dataBraunerTE/seasonality_local/complex_seasonal_2021-06-27-214614_pid47588_summary.json",
]
LOCAL_GAMMA_XLIM = (-0.2, 0.8)  # fixed x-range so figures of different runs are comparable

os.makedirs("figures", exist_ok=True)  # savefig does not create missing directories

# NB: the loop variable stays `fn`, because a version of load_json reads the global `fn`.
for fn in LOCAL_SEASONALITY_RUNS:
    d, df0 = load_json(fn)
    Rs = {"BraunerTE": B_REGIONS, "Sharma": S_REGIONS}[d['DATA']]
    local_beta1 = np.array(d["seasonality_local_beta1"])
    print(local_beta1.shape)
    # Column i is labelled Rs[i]; a width mismatch would silently mislabel countries.
    assert local_beta1.ndim == 2 and local_beta1.shape[1] == len(Rs), (
        f"expected {len(Rs)} region columns for {d['DATA']} data, got shape {local_beta1.shape}"
    )

    dfs = [
        pd.DataFrame({"Country": r, "Local gamma": local_beta1[:, i]})
        for i, r in enumerate(Rs)
    ]
    # seaborn orders categories by first appearance, so this sort sets the y-axis order.
    dfs.sort(key=lambda frame: frame["Local gamma"].mean())
    dfs.append(pd.DataFrame({"Country": "Base\ngamma", "Local gamma": np.array(d["seasonality_beta1"])}))
    df = pd.concat(dfs, axis=0, ignore_index=True)
    print(df.groupby("Country", sort=False)["Local gamma"].median())

    fig, ax = plt.subplots(figsize=(6, 8))
    sns.boxplot(data=df, x="Local gamma", y="Country", fliersize=0, ax=ax)
    local_sd = d['exp_config']['local_seasonality_sd']
    ax.set_title(f"Local seasonal amplitudes, sd={local_sd:.2f}")
    ax.set_xlim(*LOCAL_GAMMA_XLIM)
    sns.despine(ax=ax)
    fig.savefig(f'figures/Fig_seasonality_local_{d["DATA"]}_{local_sd:.2f}.pdf', bbox_inches='tight')
    plt.show()      # render this run's figure now
    plt.close(fig)  # don't accumulate open figures across loop iterations
