In [None]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("../../") 

import json
import os
import re
import textwrap

import numpy as np
import pandas as pd
import yaml

import seaborn as sns
import matplotlib.pyplot as plt
plt.rcParams['pdf.fonttype'] = 42
import matplotlib.gridspec as gridspec
plt.rcParams["font.family"] = "Optima"
plt.rcParams["font.weight"] = "light"

import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline
import scipy.stats

import arviz as az

from epimodel import preprocess_data, run_model, EpidemiologicalParameters
from epimodel.plotting.intervention_plots import combine_npi_samples

In [None]:
# Regions of the Brauner et al. temperate-Europe dataset. The three B_* lists are
# parallel: entry k of each describes the same region, and later cells index them
# by position (ISO-code lookup, latitude plot). Their order must therefore match.
B_REGIONS_ISO = ['AL', 'AD', 'AT', 'BE', 'BA', 'BG', 'HR',
 'CZ', 'DK', 'EE', 'FR', 'DE', 'GR', 'HU', 'IE', 'IT',
 'LV', 'LT', 'MT', 'NL', 'PL', 'PT', 'RO', 'RS', 'SK',
 'SI', 'ES', 'CH', 'GB']
# y-position of each region in the latitude plot. Presumably capital-city
# latitudes in degrees N (e.g. 41.32 ~ Tirana) -- TODO confirm source.
B_LAT = [41.32, 42.5, 48.2, 50.85, 43.87, 42.7, 45.82,
    50.08, 55.67, 59.43, 48.85, 52.52, 37.97, 47.47, 53.33, 41.9,
    56.93, 54.68, 35.88, 52.37, 52.23, 38.7, 44.42, 44.82, 48.13,
    46.05, 40.38, 46.95, 51.5]
# Display names. Countries that also appear in S_REGIONS use the same spelled-out
# names, so rows from both models share one label in the combined plot. All
# other countries keep their ISO code.
B_REGIONS = ['AL', 'AD', 'Austria', 'BE', 'BA', 'BG', 'HR', 'Czech Rep.', 'DK', 'EE', 'FR', 'Germany', 'GR', 'HU', 'IE', 'Italy', 'LV', 'LT', 'MT', 'Netherlands', 'PL', 'PT', 'RO', 'RS', 'SK', 'SI', 'ES', 'Switzerland', 'England']
# Regions of the Sharma et al. dataset, in the column order of its local-amplitude samples.
S_REGIONS = ['Austria', 'Czech Rep.', 'England', 'Germany', 'Italy', 'Netherlands', 'Switzerland']

# Guard the positional coupling above: a missing or extra entry would otherwise
# silently shift ISO codes and latitudes onto the wrong countries.
assert len(B_REGIONS_ISO) == len(B_LAT) == len(B_REGIONS), "B_* region lists are misaligned"
assert len(set(B_REGIONS)) == len(B_REGIONS), "duplicate entry in B_REGIONS"
# The latitude plot iterates over B_REGIONS only, so Sharma regions must be listed there too.
assert set(S_REGIONS) <= set(B_REGIONS), "every Sharma region must also appear in B_REGIONS"


In [None]:
def load_json(path, vars=("seasonality_beta1", "seasonality_max_R_day"), base_dir="../../"):
    """Load one run's summary JSON and tabulate selected posterior variables.

    Args:
        path: Summary file path, resolved against ``base_dir``. Absolute paths
            are used unchanged.
        vars: Keys of the summary whose arrays become DataFrame columns. The
            name shadows the builtin but is kept for caller compatibility.
        base_dir: Directory that relative ``path`` values are resolved against.
            Defaults to ``"../../"``.

    Returns:
        ``(d, df)``. ``d`` is the parsed summary dict, augmented with ``MODEL``,
        ``DATA`` and ``LABEL`` parsed from ``model_config_name``. ``df`` is a
        DataFrame with one column per entry of ``vars`` plus a constant
        ``label`` column.

    Raises:
        ValueError: if ``model_config_name`` lacks the ``model<NAME>_`` or
            ``data<NAME>`` fragment.
    """
    # Open the ``path`` argument. Reading a notebook-global (e.g. ``fn``) would only
    # work by accident, when the caller's loop variable happened to share that name.
    with open(os.path.join(base_dir, path)) as f:
        d = json.load(f)

    config_name = d["model_config_name"]
    model_match = re.search('model(.*)_', config_name)
    data_match = re.search('data(.*)', config_name)
    if model_match is None or data_match is None:
        raise ValueError(
            f"Cannot parse model/data from model_config_name={config_name!r} ({path})")
    d["MODEL"] = model_match.group(1)
    d["DATA"] = data_match.group(1)
    d["LABEL"] = f"Seasonal {d['MODEL']} et al." #\n{d['DATA']} data" # NB: Change for 2x2 plots
    if d['DATA'] == "BraunerTE":
        d["LABEL"] += "\n(temperate Europe)"
    print(f"Loaded {d['MODEL']} model, {d['DATA']} data. Rhat: {d['rhat']}")

    cols = {v: np.array(d[v]) for v in vars}
    # Scalar label is broadcast to every row by the DataFrame constructor.
    cols["label"] = d["LABEL"]
    return d, pd.DataFrame(cols)


In [None]:
# Summary files for the local-seasonality sensitivity runs: five runs for the
# Brauner model / temperate-Europe data, five for the Sharma model / data.
SUMMARY_FILES = [
    "sensitivity_final/modelBrauner_dataBraunerTE/seasonality_local/complex_seasonal_2021-06-27-214513_pid47284_summary.json",
    "sensitivity_final/modelBrauner_dataBraunerTE/seasonality_local/complex_seasonal_2021-06-27-214413_pid46689_summary.json",
    "sensitivity_final/modelBrauner_dataBraunerTE/seasonality_local/complex_seasonal_2021-06-27-214443_pid47122_summary.json",
    "sensitivity_final/modelBrauner_dataBraunerTE/seasonality_local/complex_seasonal_2021-06-27-214543_pid47441_summary.json",
    "sensitivity_final/modelBrauner_dataBraunerTE/seasonality_local/complex_seasonal_2021-06-27-214614_pid47588_summary.json",
    "sensitivity_final/default_cmodelSharma_dataSharma/seasonality_local/20210628-002851-52446_summary.json",
    "sensitivity_final/default_cmodelSharma_dataSharma/seasonality_local/20210628-002856-52455_summary.json",
    "sensitivity_final/default_cmodelSharma_dataSharma/seasonality_local/20210628-002901-52575_summary.json",
    "sensitivity_final/default_cmodelSharma_dataSharma/seasonality_local/20210628-002906-52694_summary.json",
    "sensitivity_final/default_cmodelSharma_dataSharma/seasonality_local/20210628-002911-52834_summary.json",
]
# Region names, in sample-column order, for each dataset label parsed by load_json.
REGIONS_BY_DATA = {"BraunerTE": B_REGIONS, "Sharma": S_REGIONS}

ds = []
# Keep the loop variable named `fn`: load_json has historically read that global.
for fn in SUMMARY_FILES:
    d, df0 = load_json(fn)
    d.update(df=df0, Rs=REGIONS_BY_DATA[d['DATA']], fn=fn)
    ds.append(d)


In [None]:

# One figure per run: boxplots of each region's local seasonal amplitude
# (regions ordered by posterior mean), with the shared base amplitude last.
os.makedirs("figures", exist_ok=True)  # savefig raises if the directory is missing

for d in ds:
    # Indexed as (sample, region) below; assumes columns follow d['Rs'] order -- TODO confirm.
    local_beta1 = np.array(d["seasonality_local_beta1"])
    dfs = [
        pd.DataFrame({"Country": r, "Local gamma": local_beta1[:, i]})
        for i, r in enumerate(d['Rs'])
    ]
    dfs.sort(key=lambda frame: frame["Local gamma"].mean())
    dfs.append(pd.DataFrame({"Country": "Base\ngamma", "Local gamma": np.array(d["seasonality_beta1"])}))
    df = pd.concat(dfs, axis=0, ignore_index=True)

    local_sd = d['exp_config']['local_seasonality_sd']
    fig, ax = plt.subplots(figsize=(6, 8))
    # seaborn orders the categorical y-axis by first appearance, i.e. the sort above.
    sns.boxplot(data=df, x="Local gamma", y="Country", fliersize=0, ax=ax)
    ax.set_title(f"Local seasonal amplitudes, sd={local_sd:.2f}")
    ax.set_xlim(-0.2, 0.8)
    sns.despine(ax=ax)
    fig.savefig(f'figures/Fig_seasonality_local_{d["DATA"]}_{local_sd:.2f}.pdf', bbox_inches='tight')
    plt.close(fig)  # avoid accumulating open figures across runs


In [None]:
# Write per-country posterior medians of the Brauner (temperate Europe) local
# amplitudes, keyed by ISO code, to tmp.csv.
# `local_sd` used to be whatever the previous cell's loop left behind, i.e. the
# sd of the last run in `ds`. It is set explicitly here so this cell does not
# depend on that loop having run. TODO: choose the intended sd deliberately.
local_sd = ds[-1]['exp_config']['local_seasonality_sd']
bd = next(
    (d for d in ds
     if d['exp_config']['local_seasonality_sd'] == local_sd and d["DATA"] == "BraunerTE"),
    None,
)
if bd is None:
    raise ValueError(f"No BraunerTE run with local_seasonality_sd={local_sd}")

iso_by_region = dict(zip(B_REGIONS, B_REGIONS_ISO))
b_local_beta1 = np.array(bd["seasonality_local_beta1"])  # convert once, not once per region
dfs = [
    pd.DataFrame({"Country": iso_by_region[r], "Model": "Brauner",
                  "Local gamma": b_local_beta1[:, i]})
    for i, r in enumerate(bd['Rs'])
]
df = pd.concat(dfs, axis=0, ignore_index=True)
df.groupby(["Country", "Model"]).median().to_csv("tmp.csv")



In [None]:
# For one local-seasonality sd, combine both models' per-region posteriors of the
# local amplitude. They are drawn (1) as stacked KDE ridges, one row per country,
# and (2) as interval bars placed at each region's B_LAT value.
SDs = sorted(set(d['exp_config']['local_seasonality_sd'] for d in ds))
pal = sns.color_palette()

KDE_BW = 0.2          # bandwidth factor passed to scipy.stats.gaussian_kde
RIDGE_HEIGHT = 0.15   # scales KDE densities into y-axis row units
GRID_XS = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5]  # faint vertical reference lines


def _run_for(data, local_sd):
    """Return the first run in `ds` with the given DATA label and local_seasonality_sd."""
    for d in ds:
        if d['exp_config']['local_seasonality_sd'] == local_sd and d["DATA"] == data:
            return d
    raise ValueError(f"No {data} run with local_seasonality_sd={local_sd}")


def _interval_summary(bx, sx):
    """Pooled 2.5/25/50/75/97.5% quantiles of one region's samples from both models."""
    return np.quantile(np.concatenate([bx, sx]), [0.025, 0.25, 0.5, 0.75, 0.975])


def _ridge_heights(bx, sx, xs):
    """KDE heights on grid `xs` for the Brauner band and the Sharma band stacked on it.

    When a region has samples from both models, each density is halved so the
    stacked outline is their average. With no Sharma samples, the Sharma band is 0.0.
    """
    bxs = scipy.stats.gaussian_kde(bx, KDE_BW)(xs) * RIDGE_HEIGHT
    if len(sx) > 0:
        sxs = scipy.stats.gaussian_kde(sx, KDE_BW)(xs) * RIDGE_HEIGHT / 2
        bxs = bxs / 2
    else:
        sxs = 0.0
    return bxs, sxs


def _draw_interval(ax, q, y):
    """Draw the median (+), 95% interval (thin, translucent) and 50% interval (thick) at height y."""
    lo95, lo50, med, hi50, hi95 = q
    ax.scatter(med, y, marker="+", color='k', s=60)
    ax.plot([lo95, hi95], [y, y], color='k', lw=2, alpha=0.5)
    ax.plot([lo50, hi50], [y, y], color='k', lw=3, alpha=1.0)


os.makedirs("figures", exist_ok=True)  # savefig raises if the directory is missing

for local_sd in [0.1]:  # iterate over SDs instead to plot every sensitivity setting
    print(local_sd)
    bd = _run_for("BraunerTE", local_sd)
    sd = _run_for("Sharma", local_sd)

    b_local_beta1 = np.array(bd["seasonality_local_beta1"])
    s_local_beta1 = np.array(sd["seasonality_local_beta1"])

    # Long-format frame: one block of rows per (country, model), sorted by block
    # mean, followed by each model's base amplitude.
    dfs = []
    for i, r in enumerate(bd['Rs']):
        dfs.append(pd.DataFrame({"Country": r, "Model": "Brauner", "Local gamma": b_local_beta1[:, i]}))
    for i, r in enumerate(sd['Rs']):
        dfs.append(pd.DataFrame({"Country": r, "Model": "Sharma", "Local gamma": s_local_beta1[:, i]}))
    dfs.sort(key=lambda frame: frame["Local gamma"].mean())
    dfs.append(pd.DataFrame(
        {"Country": "Base\ngamma", "Model": "Brauner", "Local gamma": np.array(bd["seasonality_beta1"])}))
    dfs.append(pd.DataFrame(
        {"Country": "Base\ngamma", "Model": "Sharma", "Local gamma": np.array(sd["seasonality_beta1"])}))
    df = pd.concat(dfs, axis=0, ignore_index=True)

    # --- (1) stacked KDE ridges, one row per country ---
    fig, ax = plt.subplots(figsize=(5, 10))
    Rs = df['Country'].unique()
    # Rs[0] (the country of the lowest-mean block after the sort above) stays on the
    # bottom row. All other rows, "Base\ngamma" included, run from highest to lowest
    # pooled mean. NOTE(review): confirm that pinning Rs[0] is intended.
    Rs[1:] = sorted(Rs[1:],
        key=lambda r: df[df['Country'] == r]["Local gamma"].mean(), reverse=True)
    ax.set_yticks(range(len(Rs)))
    ax.set_yticklabels(Rs)
    ax.set_ylim(-0.5, len(Rs) - 0.5)
    # Reference lines are loop-invariant: draw them once per axes, not once per row.
    for axx in GRID_XS:
        ax.axvline(x=axx, lw=0.2, zorder=-2, c="#777")
    for i, r in enumerate(Rs):
        df2 = df[df['Country'] == r]
        bx = df2[df2['Model'] == 'Brauner']["Local gamma"].values
        sx = df2[df2['Model'] == 'Sharma']["Local gamma"].values
        q = _interval_summary(bx, sx)
        yoff = i
        # Ridge spans the pooled central 95% interval.
        xs = np.linspace(q[0], q[-1], 500)
        bxs, sxs = _ridge_heights(bx, sx, xs)

        ax.fill_between(xs, yoff, yoff + bxs, color=pal[0], alpha=1)
        ax.fill_between(xs, yoff + bxs, yoff + bxs + sxs, color=pal[1], alpha=1)
        ax.plot(xs, yoff + bxs + sxs, color='k')
        _draw_interval(ax, q, yoff)

    ax.set_title(f"Local seasonal amplitudes, sd={local_sd:.2f}")
    ax.set_xlim(-0.2, 0.8)
    sns.despine(ax=ax)
    fig.savefig(f'figures/Fig_seasonality_local_kdes_{local_sd:.2f}.pdf', bbox_inches='tight')
    plt.close(fig)

    # --- (2) interval bars positioned at each region's B_LAT value ---
    # Only the interval markers are drawn here, so no KDEs are computed.
    fig, ax = plt.subplots(figsize=(5, 10))
    for axx in GRID_XS:
        ax.axvline(x=axx, lw=0.2, zorder=-2, c="#777")
    for i, r in enumerate(B_REGIONS):
        df2 = df[df['Country'] == r]
        bx = df2[df2['Model'] == 'Brauner']["Local gamma"].values
        sx = df2[df2['Model'] == 'Sharma']["Local gamma"].values
        yoff = B_LAT[i]
        _draw_interval(ax, _interval_summary(bx, sx), yoff)

    ax.set_title(f"Local seasonal amplitudes, sd={local_sd:.2f}")
    ax.set_xlim(-0.2, 0.8)
    sns.despine(ax=ax)
    fig.savefig(f'figures/Fig_seasonality_local_latplot_{local_sd:.2f}.pdf', bbox_inches='tight')
    plt.close(fig)

