In [None]:
from IPython import get_ipython

# Work file for plotting Rt fits

In [None]:
# Render inline figures at double resolution (sharper on HiDPI displays).
%config InlineBackend.figure_format = 'retina'
# Opens a Qt console attached to this kernel.
# NOTE(review): presumably needs a desktop session with qtconsole installed;
# consider removing it for headless "Restart & Run All" (e.g. nbconvert).
%qtconsole

In [None]:
import arviz as az
import datetime
import matplotlib.dates as md
import os
import os.path
import pandas as pd
import pystan
import scipy.stats as ss
import seaborn as sns
from tqdm import tqdm
import numpy as np
from matplotlib import pyplot as plt
# Default look for this notebook: "notebook" scaling and a colourblind-safe palette.
sns.set_context(context='notebook')
sns.set_palette(palette='colorblind')

In [None]:
# Countries to keep from the OWID dataset, grouped for readability.
loc_europe_EU = [
    'Austria',
    'Belgium',
    'Bulgaria',
    'Croatia',
    'Cyprus',
    'Czech Republic',
    'Denmark',
    'Estonia',
    'Finland',
    'France',
    'Germany',
    'Greece',
    'Hungary',
    'Ireland',
    'Italy',
    'Latvia',
    'Lithuania',
    'Luxembourg',
    'Malta',
    'Netherlands',
    'Poland',
    'Portugal',
    'Romania',
    'Slovakia',
    'Slovenia',
    'Spain',
    'Sweden']

loc_europe_others = [
    'Norway',
    'United Kingdom',
    'Switzerland']

loc_non_europe = [
    'United States',
    'Russia']

# De-duplicated and sorted so the list has the same order on every run
# (iterating a plain set of strings depends on hash randomisation).
ACTIVE_COUNTRIES = sorted(set(loc_europe_EU + loc_europe_others + loc_non_europe))

In [None]:
def write_to_file(obj, name, ext):
    """Write ``obj`` to ``data/<name>.<ext>`` (relative to the CWD) via ``obj.to_excel``.

    The ``data`` folder is created first if it does not exist yet.

    Parameters
    ----------
    obj : object with a ``to_excel(path)`` method, e.g. a pandas DataFrame.
    name : str
        File name without extension.
    ext : str
        Extension without the leading dot. For a DataFrame, pandas chooses
        the Excel writer from it, so use an Excel extension such as 'xlsx'.
    """
    dirpath = os.path.join(os.getcwd(), 'data')
    if not os.path.isdir(dirpath):
        print("no folder, creating...")
        print ("Path to created folder is %s" % dirpath)
        os.makedirs(dirpath, exist_ok=True)
    # Always write the object we were given. Real write errors propagate
    # instead of being mistaken for a missing folder.
    obj.to_excel(os.path.join(dirpath, name + '.' + ext))

Download the dataset from Our World in Data (GitHub repository).

In [None]:
url = 'https://raw.githubusercontent.com/owid/covid-19-data/master/public/data/owid-covid-data.csv'
# Select columns by header name rather than position. If a column is
# inserted upstream, positional `usecols` plus overridden `names` would
# silently mislabel the data.
states = (pd.read_csv(url,
                      usecols=['location', 'date', 'total_cases'],
                      parse_dates=['date'])
          .rename(columns={'location': 'state', 'total_cases': 'positive'})
          [['state', 'date', 'positive']])
# .copy() so later in-place edits act on an independent frame, not a slice.
states = states[states['state'].isin(ACTIVE_COUNTRIES)].copy()

In [None]:
# fix cumulative negatives in cases

# write pre to file for debug
#write_to_file(states, 'stan_debug_pre', 'xlsx')

def adjust_negatives(states):
    """Make each state's cumulative ``positive`` counts non-decreasing.

    A value larger than a later value in the same run is lowered to the
    smallest value that follows it. Each value becomes the minimum of itself
    and every later value in its run (a reverse cumulative minimum).

    Runs are contiguous rows with the same ``state`` and non-missing
    ``positive``. A run breaks wherever ``state`` changes between consecutive
    rows or ``positive`` is NaN, so values never leak across states or
    across gaps. NaNs are left untouched.

    This gives the same result as repeatedly lowering a value to its
    successor's until no drop remains. It does so in one vectorised pass
    instead of one recursive pass per step, which could exceed Python's
    recursion limit on long declining stretches.

    Parameters
    ----------
    states : pandas.DataFrame
        Rows in their intended order (per state, chronologically), with
        ``state`` and ``positive`` columns.

    Returns
    -------
    pandas.DataFrame
        A copy of ``states`` with ``positive`` adjusted. The input frame is
        not modified.
    """
    out = states.copy()
    positive = out['positive']
    state = out['state']

    # A new run starts at a state change, at a NaN, or right after a NaN.
    new_run = (state != state.shift()) | positive.isna() | positive.shift().isna()
    run_id = new_run.cumsum()

    # Reverse cumulative minimum within each run: walk the rows backwards,
    # take the running minimum per run, then restore the original order.
    suffix_min = positive.iloc[::-1].groupby(run_id.iloc[::-1]).cummin().iloc[::-1]
    out['positive'] = suffix_min.to_numpy()
    return out

states = adjust_negatives(states)

# Cast the non-missing counts to int; missing rows stay NaN after alignment.
has_count = states.positive.notnull()
states.positive = states.positive[has_count].apply(int)

# write post to file for debug
#write_to_file(states, 'stan_debug_post', 'xlsx')

# Index by (state, date), sorted, for per-state lookups.
states = states.set_index(['state', 'date']).sort_index()

In [None]:
from datetime import datetime as dtime
from datetime import timedelta
import matplotlib
import matplotlib.ticker as ticker

# Plot window: the last RT_WINDOW_DAYS days up to now. Both ends come from a
# single clock reading, so the window is exactly that long.
RT_WINDOW_DAYS = 48
date_now = dtime.now()
date_then = date_now - timedelta(days=RT_WINDOW_DAYS)

# rt.live-inspired palette, installed as the current seaborn palette.
rtlive = [
    "#5db044",  # green
    "#da5d5d",  # red
    "#ebedf2",  # very light grey
    "#c7cdda",  # light grey
    "#787a7f",  # mid grey
]
sns.set_palette(palette=rtlive)

In [None]:
def plot_graph(fit, ax, state, layout='epiforecasts'):
    """Plot one state's Rt posterior on ``ax`` in the chosen visual style.

    Parameters
    ----------
    fit : arviz.InferenceData
        Fit passed through to the layout function.
    ax : matplotlib.axes.Axes
        Axes to draw on.
    state : str
        Panel title.
    layout : str
        'rtlive' selects plot_rtlive_layout; any other value falls back to
        plot_epi_layout.

    Returns
    -------
    bool
        Always True.
    """
    draw = plot_rtlive_layout if layout == 'rtlive' else plot_epi_layout
    draw(fit, ax, state)
    return True

In [None]:
def plot_epi_layout(fit, ax, state):
    """Draw an epiforecasts-style Rt panel for one state on ``ax``.

    Percentiles of ``fit.posterior.Rt`` are taken over its first two axes
    (presumably chain and draw; TODO confirm) and plotted against
    ``fit.posterior.Rt_dates``:
    - the 25-75% band is filled with palette colour [3];
    - the 5-25% and 75-95% bands are filled with colour [2];
    - the 5% and 95% curves are outlined in colour [4];
    - a dotted line marks Rt = 1.
    The x-range is the global ``date_then``..``date_now`` window and the
    y-range is fixed to 0..3.

    Parameters
    ----------
    fit : arviz.InferenceData
        Fit providing ``posterior.Rt`` and ``posterior.Rt_dates``.
    ax : matplotlib.axes.Axes
        Axes to draw on.
    state : str
        Panel title.

    Returns
    -------
    bool
        Always True.
    """
    rt = np.asarray(fit.posterior.Rt)
    # One call for all four percentiles; rows come back in the order requested.
    ll, l, h, hh = np.percentile(rt, [5, 25, 75, 95], axis=(0, 1))
    x = fit.posterior.Rt_dates.values
    palette = sns.color_palette()

    # Global rc setting: it affects axes created after this call, not `ax`,
    # which already exists.
    matplotlib.rc('axes', edgecolor='black')
    ax.plot(x, hh, ls='-', marker='', lw=0.6, color=palette[4])
    ax.plot(x, ll, ls='-', marker='', lw=0.6, color=palette[4])

    ax.fill_between(x, h, hh, color=palette[2])
    ax.fill_between(x, ll, l, color=palette[2])
    ax.fill_between(x, l, h, color=palette[3])

    ax.axhline(1.0, linestyle='dotted', lw=1, color='black')
    ax.xaxis.set_major_formatter(md.DateFormatter('%b %d'))
    ax.xaxis.set_major_locator(md.DayLocator(interval=7))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

    plt.setp(ax.xaxis.get_majorticklabels(), rotation=90)

    ax.set_title(state, fontsize=13, color='black', family='sans-serif', loc='left', weight='bold')
    ax.set_ylabel('Effective Rt', fontsize=11, color='black')
    ax.xaxis.label.set_size(10)
    ax.xaxis.label.set_color('black')
    ax.tick_params(which='both', length=3, color='black', labelcolor='black')
    ax.set_xlim([date_then, date_now])

    ax.set_ylim(bottom=0, top=3, auto=False)
    # Despine only this panel. A bare sns.despine() acts on every axes of the
    # *current* figure, which need not be the figure that owns `ax`.
    sns.despine(ax=ax)

    return True

In [None]:
def plot_rtlive_layout(fit, ax, state):
    """Draw an rt.live-style Rt panel for one state on ``ax``.

    The posterior median of ``fit.posterior.Rt`` (taken over its first two
    axes, presumably chain and draw; TODO confirm) is drawn as two lines:
    - palette colour [0] where the median is <= 1.30;
    - colour [1] where it is > 0.98.
    Both lines are drawn in the overlapping 0.98-1.30 range, with the
    colour-[1] line added last. The 2.5-97.5% interval is shaded with
    colour [0] below Rt = 1 and colour [1] above it. The x-range is the
    global ``date_then``..``date_now`` window and the y-range is fixed to 0..3.

    Parameters
    ----------
    fit : arviz.InferenceData
        Fit providing ``posterior.Rt`` and ``posterior.Rt_dates``.
    ax : matplotlib.axes.Axes
        Axes to draw on.
    state : str
        Panel title.

    Returns
    -------
    bool
        Always True.
    """
    rt = np.asarray(fit.posterior.Rt)
    m = np.median(rt, axis=(0, 1))
    # Only the 95% interval is drawn, so only its two bounds are computed.
    ll, hh = np.percentile(rt, [2.5, 97.5], axis=(0, 1))
    x = fit.posterior.Rt_dates.values
    palette = sns.color_palette()

    # Split the median into two coloured lines by masking with NaN
    # (matplotlib leaves gaps where the data is NaN).
    m_good = m.copy()
    m_bad = m.copy()
    m_good[m_good > 1.30] = np.nan
    m_bad[m_bad <= 0.98] = np.nan

    ax.plot(x, m_good, ls='-', marker='', lw=1.5, color=palette[0])
    ax.plot(x, m_bad, ls='-', marker='', lw=1.5, color=palette[1])

    # Part of the interval below 1 (both bounds clipped to <= 1) ...
    ax.fill_between(x, np.where(hh < 1, hh, 1), np.where(ll < 1, ll, 1), alpha=0.10, color=palette[0])
    # ... and part above 1 (both bounds clipped to >= 1).
    ax.fill_between(x, np.where(hh > 1, hh, 1), np.where(ll > 1, ll, 1), alpha=0.10, color=palette[1])

    ax.axhline(1.0, linestyle='dotted', lw=1, color='black')
    ax.xaxis.set_major_formatter(md.DateFormatter('%b %d'))
    ax.xaxis.set_major_locator(md.DayLocator(interval=7))
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1))

    plt.setp(ax.xaxis.get_majorticklabels(), rotation=90)
    ax.set_title(state, fontsize=18, color='xkcd:dark grey')
    ax.set_ylabel('Effective Rt', fontsize=11, color='xkcd:dark grey')
    ax.xaxis.label.set_size(10)
    ax.tick_params(which='both', length=3)
    ax.set_xlim([date_then, date_now])

    ax.set_ylim(bottom=0, top=3, auto=False)
    # Despine only this panel. A bare sns.despine() acts on every axes of the
    # *current* figure, which need not be the figure that owns `ax`.
    sns.despine(ax=ax)

    return True

In [None]:
def make_title_space(fig, topmargin=1):
    """Resize ``fig`` so exactly ``topmargin`` inches sit above the subplots.

    The figure's current top margin is replaced by ``topmargin`` inches. The
    subplot area and the bottom margin keep their physical size. The figure
    height changes accordingly, and ``subplots_adjust`` rescales the
    fractional ``bottom``/``top`` parameters to the new height.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure to adjust in place.
    topmargin : float
        Desired space above the subplot grid, in inches.

    Returns
    -------
    matplotlib.figure.Figure
        The same ``fig``, for chaining.
    """
    params = fig.subplotpars
    height = fig.get_size_inches()[1]

    new_height = height - (1-params.top)*height  + topmargin
    fig.subplots_adjust(bottom=params.bottom*height/new_height,
                        top=1-topmargin/new_height)
    fig.set_figheight(new_height)
    return fig

In [None]:
import glob

In [None]:
def save_state_fits(fits, directory):
    """Write each fit to ``<directory>/<key>.nc``, replacing old contents.

    The directory is created if needed. Every non-directory entry already in
    it (matched by ``*``, so dotfiles are kept) is deleted first.
    Subdirectories are left alone instead of crashing ``os.remove``.

    Each fit is written to a ``.temp`` name and then moved into place. An
    interrupted write therefore leaves no partial ``*.nc`` file, and
    ``load_state_fits`` only globs ``*.nc``.

    Parameters
    ----------
    fits : Mapping[str, arviz.InferenceData]
        Fits keyed by state name; each key becomes the file stem.
    directory : str
        Target directory.

    Raises
    ------
    OSError
        If the directory cannot be created, or a file cannot be removed,
        written or moved.
    """
    os.makedirs(directory, exist_ok=True)

    # clean out old files (os.remove cannot delete directories, so skip them)
    for path in glob.glob(os.path.join(directory, '*')):
        if not os.path.isdir(path):
            os.remove(path)

    for k, f in fits.items():
        fname = os.path.join(directory, k + '.nc')
        fname_tmp = fname + '.temp'
        az.to_netcdf(f, fname_tmp)
        # os.replace also overwrites an existing target on Windows.
        os.replace(fname_tmp, fname)

def load_state_fits(directory):
    """Load every ``*.nc`` file in ``directory`` with ``az.from_netcdf``.

    Parameters
    ----------
    directory : str
        Folder previously filled by ``save_state_fits``.

    Returns
    -------
    dict
        Maps each file's stem (name without ``.nc``) to its loaded fit, in
        ``glob`` order. A missing or empty directory yields ``{}``.
    """
    pattern = os.path.join(directory, '*.nc')
    return {os.path.splitext(os.path.basename(path))[0]: az.from_netcdf(path)
            for path in glob.glob(pattern)}

In [None]:
#save_state_fits(state_fits, 'state_fits_EUW')

In [None]:
from collections import OrderedDict 

# Load the saved fits and keep them ordered alphabetically by state name.
state_fits = OrderedDict(sorted(load_state_fits('state_fits_EUW').items()))

In [None]:
# Switch to the rtlive colours for the Rt panels below.
sns.set_palette(palette=rtlive)

 ### Plot single

In [None]:
# Single-panel preview for one state.
this_state = 'Sweden'
fit = state_fits[this_state]
fig, ax = plt.subplots(1, 1, figsize=(7, 3))
out = plot_graph(fit=fit, ax=ax, state=this_state, layout='epiforecasts')

 ### Plot all

In [None]:
# Grid of Rt panels, nc per row, one per state in `state_fits`.
nc = 5
excluded = []
for name in excluded:
    state_fits.pop(name, None)

# temporarily limit states for testing purposes, set to 0 for unlimited
state_count = 0

# Number of panels actually drawn: never more than the fits available.
n_plots = min(state_count, len(state_fits)) if state_count else len(state_fits)
nr = max(1, int(np.ceil(n_plots / nc)))

fig, axes = plt.subplots(nrows=nr, ncols=nc, figsize=(18, 3*nr))
flat_axes = axes.flatten()

for (k, fit), ax in zip(list(state_fits.items())[:n_plots], flat_axes):
    plot_graph(fit, ax, k, layout='epiforecasts')

# Hide the unused trailing cells of the grid (all of them if nothing was plotted).
for ax in flat_axes[n_plots:]:
    ax.set_visible(False)

# Title date is hard-coded; use the commented line for today's date instead.
this_date = dtime(2020, 5, 10).strftime('%Y-%m-%d')
#this_date = dtime.today().strftime('%Y-%m-%d')
fig_title = "Stan Rt EU+, " + this_date
fig.suptitle(fig_title, y=0.98, fontsize=24, family="monospace", fontweight='bold')
fig.set_facecolor('w')
fig.subplots_adjust(left=0.10, right=0.95, hspace=0.80, wspace=0.25)

# Reserve 1.8 inches above the grid for the suptitle.
fig = make_title_space(fig, topmargin=1.8)

# -- option to save graph -- #
# imgdir = 'images/'
# fname = 'Stan R0 EUW 2020-05-10.png'
# fpath = imgdir + fname
# plt.savefig(fpath, dpi=150, facecolor='w', edgecolor='w', 
#             format=None, transparent=False, 
#             bbox_inches=None, pad_inches=0.0,
#             frameon=None, metadata=None)

In [None]:
# Restore the notebook's default colourblind-safe palette.
sns.set_palette(palette='colorblind')

Here is a plot of the distribution of $R_t$ on the latest sampled day for each state, ordered by its median value:

In [None]:
# Posterior draws of Rt on the last sampled day, one row per draw.
# The number of draws is read from each fit rather than assumed to be
# 4 chains x 1000 draws. The frame is assembled with a single concat
# instead of re-concatenating growing arrays for every state.
last_day_rt = {k: np.asarray(f.posterior.Rt[:, :, -1]).ravel()
               for k, f in state_fits.items()}

if last_day_rt:
    df = pd.concat([pd.DataFrame({'state': k, 'Rt': rt}) for k, rt in last_day_rt.items()],
                   ignore_index=True)
else:
    df = pd.DataFrame({'state': [], 'Rt': []})

# Order states by the median of their last-day Rt distribution.
sort_Rts = [np.median(rt) for rt in last_day_rt.values()]
state_order = np.array(list(last_day_rt.keys()))[np.argsort(sort_Rts)]

Compare to http://rt.live:

In [None]:
# Last-day Rt per state, sorted by median, with Rt = 1 as a reference line.
plt.figure(figsize=(24, 4))
box_ax = sns.boxplot(data=df, x='state', y='Rt', order=state_order,
                     fliersize=0, whis=1.0/1.35)
box_ax.axhline(1, color='k')
plt.xticks(rotation=90)
out = plt.axis(ymin=0, ymax=2.5)

We should keep an eye on this model's $\tau$: in the states where it is well measured, the posterior favors a slightly smaller $\tau$ than the default prior. It might be worth building a hierarchical model to re-infer $\tau$ from the national data. The black line is the prior.

In [None]:
# One KDE of the tau posterior per state, each in its own husl colour.
tau_palette = sns.color_palette('husl', n_colors=len(state_fits))
with tau_palette:
    for state_fit in state_fits.values():
        sns.kdeplot(state_fit.posterior.tau.values.flatten())

# Prior on tau in black: lognormal with shape 0.57 and scale exp(1.4).
tau_grid = np.linspace(0, 20, 1024)
prior_pdf = ss.lognorm(0.57, scale=np.exp(1.4)).pdf(tau_grid)
plt.plot(tau_grid, prior_pdf, color='k')
plt.axis(xmin=0, xmax=12)

out = plt.xlabel(r'$\tau$ ($\mathrm{d}$)')

