In [None]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("../../") 

import json
import os
import re
import textwrap

import numpy as np
import pandas as pd
import yaml

import seaborn as sns
import matplotlib.pyplot as plt
plt.rcParams['pdf.fonttype'] = 42
import matplotlib.gridspec as gridspec
plt.rcParams["font.family"] = "Optima"
plt.rcParams["font.weight"] = "light"

import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

import arviz as az

from epimodel import preprocess_data, run_model, EpidemiologicalParameters
from epimodel.plotting.intervention_plots import combine_npi_samples

In [None]:
def load_jsons(paths, vars=("seasonality_beta1", "seasonality_max_R_day"), base_dir="../../"):
    """Load run-summary JSON files and stack their posterior samples into one long DataFrame.

    For each path, ``os.path.join(base_dir, path)`` is read as JSON. The model and data
    names are parsed from its ``model_config_name`` (the text after ``model`` up to the
    last ``_``, and the text after ``data``), and a plot label is attached.

    Args:
        paths: Summary-JSON paths, relative to ``base_dir``.
        vars: Keys whose per-sample arrays are copied into the returned DataFrame.
            (The name shadows the builtin ``vars``; kept for backward compatibility.)
        base_dir: Directory the paths are relative to. Defaults to ``"../../"``,
            the prefix this notebook has always used.

    Returns:
        ``(traces, df)``. ``traces`` is the list of loaded dicts, each augmented with
        ``MODEL``, ``DATA`` and ``LABEL`` keys. ``df`` has one column per entry of
        ``vars`` plus ``label``. It first holds every run's samples labelled with that
        run's ``LABEL``, then all of those samples again under the label ``"Combined"``.
        With no paths, ``df`` is empty but still has those columns.

    Raises:
        ValueError: if a ``model_config_name`` lacks the ``model...``/``data...`` parts.
    """
    traces = []
    frames = []  # one frame per run; concatenated once (DataFrame.append is gone in pandas 2)
    for fn in paths:
        with open(os.path.join(base_dir, fn)) as f:
            d = json.load(f)

        config_name = d['model_config_name']
        model_match = re.search('model(.*)_', config_name)
        data_match = re.search('data(.*)', config_name)
        if model_match is None or data_match is None:
            raise ValueError(
                f"Cannot parse model/data names from model_config_name={config_name!r} in {fn}"
            )
        d["MODEL"] = model_match.group(1)
        d["DATA"] = data_match.group(1)

        # NB: Change for 2x2 plots -- append f"\n{d['DATA']} data" to show the dataset too.
        d["LABEL"] = f"Seasonal {d['MODEL']} et al."
        if d['DATA'] == "BraunerTE":
            d["LABEL"] += "\n(temperate Europe)"
        print(f"Loaded {d['MODEL']} model, {d['DATA']} data. Rhat: {d['rhat']}")
        traces.append(d)

        cols = {v: np.array(d[v]) for v in vars}
        cols["label"] = d["LABEL"]
        frames.append(pd.DataFrame(cols))

    if not frames:
        return traces, pd.DataFrame(columns=[*vars, "label"])

    df = pd.concat(frames, ignore_index=True)

    # Pool every run's samples into one extra "Combined" group.
    combined = {v: np.array(df[v].values) for v in vars}
    combined["label"] = "Combined"
    df = pd.concat([df, pd.DataFrame(combined)], ignore_index=True)
    return traces, df


In [None]:
# Summary JSONs of the seasonal runs whose seasonality_beta1 posteriors are pooled below.
beta1_SRC = [
    "sensitivity_final/default_cmodelSharma_dataSharma/seasonality_basic_R_prior/20210429-044743-70284_summary.json",
    # NB: Change for 2x2 plots
    #"sensitivity_final/modelBrauner_dataSharma/seasonality_basic_R_prior/complex_seasonal_2021-04-30-025219_pid22787_summary.json",
    #"sensitivity_final/default_cmodelSharma_dataBraunerTE/seasonality_basic_R_prior/20210430-084421-36555_summary.json",
    "sensitivity_final/modelBrauner_dataBraunerTE/seasonality_basic_R_prior/complex_seasonal_2021-04-30-012232_pid18922_summary.json",
]

traces, df1 = load_jsons(beta1_SRC)

beta1_samples = df1["seasonality_beta1"]
df1 = df1.assign(**{
    # Percent drop from the seasonal maximum (1 + beta1) to the minimum (1 - beta1).
    "top-to-trough": 100*(1-(1-beta1_samples) / (1+beta1_samples)),
    # Seasonal amplitude expressed in percent.
    "gamma_percent": 100*beta1_samples,
    "Seasonality peak": "January 1",
})

In [None]:
# Non-seasonal model: posterior Rt0 trajectories with the pooled seasonal curve overlaid,
# rescaled to the posterior median, for visual comparison.
Rt0_default = np.load("Rt0_default.npz")['arr_0']
# NOTE(review): the reductions below treat the last axis as days and reduce over axes 0 and 1
# (presumably posterior samples and regions) -- confirm against the code that wrote the .npz.

FIRST_MODEL_DAY = 213 - 365  # Aug 1 on this axis, whose tick 1 is labelled "Jan 1"
x = np.arange(Rt0_default.shape[-1]) + FIRST_MODEL_DAY
days = np.arange(-365*0.5, 365*0.25, 1.0)  # plotting grid for the seasonal curve

# Pooled posterior samples of the seasonality amplitude.
beta1 = df1[df1.label == "Combined"].seasonality_beta1.values

# Median over beta1 samples of the seasonal multiplier 1 + beta1 * cos(2*pi*(t + 1)/365).
# NOTE(review): this form peaks at t = -1 (tick 1 = "Jan 1"); confirm it matches the model.
gamma_q_0 = np.median(1.0 + beta1 * np.cos(2 * np.pi / 365.0 * (days[:, None] + 1)), axis=1)
# Same curve evaluated exactly on the model days `x`. Slicing `gamma_q_0` instead would
# misalign by half a day (its grid sits on half-integers) and is limited to the grid length.
gamma_at_x = np.median(1.0 + beta1 * np.cos(2 * np.pi / 365.0 * (x[:, None] + 1)), axis=1)

# Scale the curve by the geometric mean over days of median(Rt0) / seasonal multiplier.
R0s = np.median(Rt0_default, axis=(0, 1)) / gamma_at_x
R0 = np.exp(np.mean(np.log(R0s)))

plt.close()
fig, ax = plt.subplots(figsize=(6, 4), dpi=150)

for y in np.median(Rt0_default, axis=0):  # one thin line per entry of axis 1 (presumably regions)
    ax.plot(x, y, color='#67c5', lw=0.6)
ax.plot(x, np.quantile(Rt0_default, [0.025, 0.25, 0.75, 0.975], axis=(0, 1)).T, color='black', lw=1.0, alpha=0.7)
ax.plot(x, np.quantile(Rt0_default, [0.5], axis=(0, 1)).T, 'k', lw=1.5)
ax.plot(days, R0 * gamma_q_0, "--", lw=1.8, color="#777")

ax.set_title("Rt0 with random walk (non-seasonal model)\n(with seasonal curve (gray), overall median, 50% CI and 95% CI)")
ax.set_ylabel("Rt")
ax.set_xticks([182-365, -90, 1, 91])
ax.set_xticklabels(["July 1", "Oct 1", "Jan 1", "April 1"])
ax.set_ylim(0.5, 4.0)

os.makedirs("figures", exist_ok=True)  # savefig does not create missing directories
fig.savefig('figures/Fig_Rt0_modelSharma_dataSharma_default.pdf', bbox_inches='tight')

In [None]:
# Seasonal model: posterior Rt0 trajectories with the pooled seasonal curve overlaid,
# rescaled to the posterior median, for visual comparison.
Rt0_seas = np.load("Rt0_seasonal.npz")['arr_0']
# NOTE(review): the reductions below treat the last axis as days and reduce over axes 0 and 1
# (presumably posterior samples and regions) -- confirm against the code that wrote the .npz.

FIRST_MODEL_DAY = 213 - 365  # Aug 1 on this axis, whose tick 1 is labelled "Jan 1"
x = np.arange(Rt0_seas.shape[-1]) + FIRST_MODEL_DAY
days = np.arange(-365*0.5, 365*0.25, 1.0)  # plotting grid for the seasonal curve

# Pooled posterior samples of the seasonality amplitude.
beta1 = df1[df1.label == "Combined"].seasonality_beta1.values

# Median over beta1 samples of the seasonal multiplier 1 + beta1 * cos(2*pi*(t + 1)/365).
# NOTE(review): this form peaks at t = -1 (tick 1 = "Jan 1"); confirm it matches the model.
gamma_q_0 = np.median(1.0 + beta1 * np.cos(2 * np.pi / 365.0 * (days[:, None] + 1)), axis=1)
# Same curve evaluated exactly on the model days `x`. Slicing `gamma_q_0` instead would
# misalign by half a day (its grid sits on half-integers) and is limited to the grid length.
gamma_at_x = np.median(1.0 + beta1 * np.cos(2 * np.pi / 365.0 * (x[:, None] + 1)), axis=1)

# Scale the curve by the geometric mean over days of median(Rt0) / seasonal multiplier.
R0s = np.median(Rt0_seas, axis=(0, 1)) / gamma_at_x
R0 = np.exp(np.mean(np.log(R0s)))

plt.close()
fig, ax = plt.subplots(figsize=(6, 4), dpi=150)

for y in np.median(Rt0_seas, axis=0):  # one thin line per entry of axis 1 (presumably regions)
    ax.plot(x, y, color='#67c5', lw=0.6)
ax.plot(x, np.quantile(Rt0_seas, [0.025, 0.25, 0.75, 0.975], axis=(0, 1)).T, color='black', lw=1.0, alpha=0.7)
ax.plot(x, np.quantile(Rt0_seas, [0.5], axis=(0, 1)).T, 'k', lw=1.5)
ax.plot(days, R0 * gamma_q_0, "--", lw=1.6, color="darkred")

ax.set_title("Rt0 with random walk and seasonality\n(with seasonal curve (red), overall median, 50% CI and 95% CI)")
ax.set_ylabel("Rt")
ax.set_ylim(0.5, 4.0)
ax.set_xticks([182-365, -90, 1, 91])
ax.set_xticklabels(["July 1", "Oct 1", "Jan 1", "April 1"])

os.makedirs("figures", exist_ok=True)  # savefig does not create missing directories
fig.savefig('figures/Fig_Rt0_modelSharma_dataSharma_seasonal.pdf', bbox_inches='tight')

In [None]:
# Seasonal model: posterior random-walk multiplier applied to basic_R, on a log scale.
# Thin lines: median over axis 0 for each entry of axis 1 (regions, per the title);
# black lines: quantiles pooled over axes 0 and 1.
Rtw = np.load("Rt_walk_seasonal.npz")['arr_0']
plt.close()
fig, ax = plt.subplots(figsize=(6, 4), dpi=150)

x = np.arange(Rtw.shape[-1]) + 213-365
per_region_median = np.median(Rtw, axis=0)
for series in per_region_median:
    ax.plot(x, series, color='#67c5', lw=0.6)
interval_bounds = np.quantile(Rtw, [0.025, 0.25, 0.75, 0.975], axis=(0, 1)).T
ax.plot(x, interval_bounds, color='black', lw=1.0, alpha=0.7)
overall_median = np.quantile(Rtw, [0.5], axis=(0, 1)).T
ax.plot(x, overall_median, 'k', lw=1.5)

ax.set_title("Random walk applied to basic_R (seasonal)\n(regions, overall median, 50% CI and 95% CI)")
ax.set_ylabel("Rt random walk multiplier")
ax.set_yscale("log")
ax.set_ylim(0.4, 2.5)
ax.set_xticks([-90, 1])
ax.set_xticklabels(["Oct 1", "Jan 1"])
fig.savefig(f'figures/Fig_random_walk_modelSharma_dataSharma_seasonal.pdf', bbox_inches='tight');

In [None]:
# Non-seasonal model: posterior random-walk multiplier applied to basic_R, on a log scale.
# Thin lines: median over axis 0 for each entry of axis 1 (regions, per the title);
# black lines: quantiles pooled over axes 0 and 1.
Rtw = np.load("Rt_walk_default.npz")['arr_0']
plt.close()
fig, ax = plt.subplots(figsize=(6, 4), dpi=150)

x = np.arange(Rtw.shape[-1]) + 213-365
per_region_median = np.median(Rtw, axis=0)
for series in per_region_median:
    ax.plot(x, series, color='#67c5', lw=0.6)
interval_bounds = np.quantile(Rtw, [0.025, 0.25, 0.75, 0.975], axis=(0, 1)).T
ax.plot(x, interval_bounds, color='black', lw=1.0, alpha=0.7)
overall_median = np.quantile(Rtw, [0.5], axis=(0, 1)).T
ax.plot(x, overall_median, 'k', lw=1.5)

ax.set_title("Random walk applied to basic_R (non-seasonal)\n(regions, overall median, 50% CI and 95% CI)")
ax.set_ylabel("Rt random walk multiplier")
ax.set_ylim(0.4, 2.5)
ax.set_yscale("log")
ax.set_xticks([-90, 1])
ax.set_xticklabels(["Oct 1", "Jan 1"])
fig.savefig(f'figures/Fig_random_walk_modelSharma_dataSharma_default.pdf', bbox_inches='tight');