In [1]:
import matplotlib
import matplotlib.pyplot as plt

import jax
import jax.numpy as np
from jax.random import PRNGKey

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive

import pandas as pd
from functools import partial
from glm import glm, GLM, log_link, Gamma
from models import SEIR_hierarchical, plot_samples, plot_forecast, plot_R0
import util


def gen_covariates(places=None, num_places=5, intercept=False, drop_first=True):
    """Build a dummy-coded covariate table with one row per place.

    Args:
        places: Sequence of place labels. When None, the labels are the
            strings '0' .. str(num_places - 1).
        num_places: Number of synthetic labels to generate; only used when
            `places` is None.
        intercept: When truthy, append a constant column named 'intercept'
            filled with 1.
        drop_first: Forwarded to `pd.get_dummies`.

    Returns:
        pandas.DataFrame of dummy columns derived from a single 'place'
        column, optionally followed by the 'intercept' column.
    """
    labels = [str(idx) for idx in range(num_places)] if places is None else places
    place_frame = pd.DataFrame({'place': labels})
    covariates = pd.get_dummies(place_frame, drop_first=drop_first)
    if intercept:
        covariates['intercept'] = 1
    return covariates


Bad key "nbagg.transparent" on line 426 in
/usr/local/Cellar/python3/3.6.4_2/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle.
You probably need to get an updated matplotlibrc file from
https://github.com/matplotlib/matplotlib/blob/v3.2.1/matplotlibrc.template
or from the matplotlib source distribution

Bad key "animation.mencoder_path" on line 509 in
/usr/local/Cellar/python3/3.6.4_2/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle.
You probably need to get an updated matplotlibrc file from
https://github.com/matplotlib/matplotlib/blob/v3.2.1/matplotlibrc.template
or from the matplotlib source distribution

Bad key "animation.mencoder_args" on line 512 in
/usr/local/Cellar/python3/3.6.4_2/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle.
You probably need to get an update

In [2]:
# Import the project `models` module and evaluate `models.SEIR_stochastic`
# as the cell's last expression so its repr is shown as the cell output.
import models
models.SEIR_stochastic

<function models.SEIR_stochastic>

# Run Inference

In [None]:
# States to fit; a larger alternative set is ['MA', 'NY', 'WA', 'AL'].
states = ['MA', 'NY']
data, place_data = util.load_state_Xy(which=states)
place_covariates = place_data.drop(columns='state')
data = data.join(place_covariates, how='inner')

# Keyword arguments forwarded to the model on every call below.
args = dict(
    data=data,
    place_data=place_data,
    use_rw=False,
    rw_scale=1e-2,
    det_noise_scale=0.2,
)

prob_model = SEIR_hierarchical

# NUTS sampler initialised with numpyro's init_to_median strategy.
init_strategy = numpyro.infer.util.init_to_median()
kernel = NUTS(prob_model, init_strategy=init_strategy)
mcmc = MCMC(kernel, num_warmup=10, num_samples=10, num_chains=1)

mcmc.run(PRNGKey(1), use_obs=True, **args)
mcmc.print_summary()
mcmc_samples = mcmc.get_samples()


# Prior samples for comparison (no posterior samples supplied).
prior = Predictive(prob_model, posterior_samples={}, num_samples=100)
prior_samples = prior(PRNGKey(2), **args)

# Posterior predictive samples for visualization.
# `args` is mutated in place here: drift is set to zero for forecasting.
args['rw_scale'] = 0
post_pred = Predictive(prob_model, posterior_samples=mcmc_samples)
post_pred_samples = post_pred(PRNGKey(2), T_future=100, **args)

sample: 100%|██████████| 20/20 [01:53<00:00,  5.67s/it, 1023 steps of size 8.49e-06. acc. prob=0.36]
  rho_k = 1. - (var_within - gamma_k_c.mean(axis=0)) / var_estimator



                                    mean       std    median      5.0%     95.0%     n_eff     r_hat
                         E0[0]  99908.52    106.08  99881.62  99743.53 100024.55      3.06      1.88
                         E0[1] 154882.89    152.29 154875.91 154704.61 155153.47      3.46      1.61
                 E_duration[0]      6.75      0.01      6.75      6.74      6.77      5.03      1.19
                 E_duration[1]      4.15      0.01      4.15      4.14      4.16      7.45      0.92
E_duration_C(state, OneHot)[0]      0.44      0.00      0.44      0.43      0.44      2.48      2.19
E_duration_C(state, OneHot)[1]     -0.64      0.00     -0.64     -0.64     -0.63      5.18      1.05
       E_duration_Intercept[0]     -0.13      0.00     -0.13     -0.13     -0.12      6.04      0.91
                         I0[0]  79967.14     35.31  79966.16  79924.23  80022.87      3.86      1.36
                         I0[1] 133515.92    209.30 133444.38 133235.41 133834.70      4.84

In [None]:
# Persist this run under the name 'US_covariates': first the MCMC summary,
# then the prior, posterior, and posterior-predictive sample dictionaries.
# NOTE(review): output locations and file formats are decided inside `util`
# and are not visible from this notebook — confirm before relying on them.
util.write_summary('US_covariates', mcmc)
util.save_samples('US_covariates', prior_samples, mcmc_samples, post_pred_samples)

In [None]:
# Means over axis 0 of the posterior / posterior-predictive sample arrays.
# "gamma" is reported as the reciprocal of the mean I_duration.
mean_I_duration = mcmc_samples['I_duration'].mean(axis=0)
mean_R0 = mcmc_samples['R0'].mean(axis=0)
mean_R0_future = post_pred_samples['R0_future'].mean(axis=0)

print("gamma", 1/mean_I_duration)
print("R0", mean_R0)
print("future", mean_R0_future)

In [None]:
# Plot a T-day forecast for every place in `data`, one figure per place.
T = 100
scale = 'log'

# `data` is indexed by (place, date): level 0 gives the places, level 1 the dates.
places = data.index.unique(level=0)
start = data.index.unique(level=1).min()
num_places = len(places)

# Sample sites that are left out of the per-place slicing below.
# NOTE(review): presumably these have no place axis — confirm against the model.
shared_params = ['beta0_base', 'gamma_base', 'sigma_base', 'det_rate_base']

# The date axis is identical for every place, so build it once outside the loop.
t = pd.date_range(start=start, periods=T, freq='D')

for i, place in enumerate(places):

    # Select this place's slice (index i along axis 1) from each remaining site.
    place_samples = {k: v[:, i, ...]
                     for k, v in post_pred_samples.items()
                     if k not in shared_params}

    positive = data.loc[place].positive

    fig, ax = plot_forecast(place_samples, T, positive, t=t, scale=scale)

    name = place
    plt.suptitle(f'{name} {T} days ')
    plt.tight_layout()
    plt.show()

    
    

In [None]:
# For each place: reload saved samples from disk, save posterior-predictive
# figures at several horizons and scales, then save a plot of R0 over time.
#
# NOTE(review): this cell reads `pop` and `model`, and uses `start` before the
# loop assigns it; none of these are defined for this purpose in the visible
# cells, so the cell appears to depend on leftover kernel state — confirm it
# still runs after Restart & Run All.
load = True

for place in ['Italy', 'US', 'WA', 'NY', 'MA']:

    if load:
        # `start` carries over from earlier state / the previous iteration.
        confirmed = data[place].confirmed[start:]
        start = confirmed.index.min()

        T = len(confirmed)
        N = pop[place]

        filename = f'out/{place}_samples.npz'
        # SECURITY: allow_pickle=True unpickles arbitrary objects; only load
        # sample files that this project produced itself.
        x = np.load(filename, allow_pickle=True)
        mcmc_samples = x['arr_0'].item()
        post_pred_samples = x['arr_1'].item()


    # Inspect and Save Results

    for scale in ['log', 'lin']:
        # Note: this loop variable replaces the T assigned above.
        for T in [len(confirmed), 30, 40, 50, 100]:

            t = pd.date_range(start=start, periods=T, freq='D')

            fig, ax = plt.subplots(figsize=(14,5))
            plot_samples(post_pred_samples, T=T, t=t, ax=ax, plot_fields=['I', 'y'], model=model)
            plt.title('Posterior predictive')

            confirmed.plot(style='o')

            if scale == 'log':
                plt.yscale('log')

            filename = f'figs/{place}_predictive_scale_{scale}_T_{T}.png'
            plt.savefig(filename)

            plt.show()

    # Compute average R0 over time
    gamma = mcmc_samples['gamma'][:,None]
    beta = mcmc_samples['beta']
    t = pd.date_range(start=start, periods=beta.shape[1], freq='D')
    R0 = beta/gamma

    # 10th / 90th percentile band across axis 0 of R0.
    pi = np.percentile(R0, (10, 90), axis=0)
    df = pd.DataFrame(index=t, data={'R0': R0.mean(axis=0)})
    df.plot(style='-o')
    plt.fill_between(t, pi[0,:], pi[1,:], alpha=0.1)

    # Set the title before saving so it is included in the written file.
    plt.title(place)

    filename = f'figs/{place}_R0.png'
    plt.savefig(filename)

    plt.show()

In [None]:
# Posterior diagnostics: (beta, gamma) scatter, R_0 histogram, growth-rate
# histogram and detection-rate histogram. Figures are written only when
# `save` is True.
#
# NOTE(review): the save paths use `place`, which is not assigned in this cell;
# it is whatever value an earlier cell left behind — confirm before saving.
from compartment import SIRModel, SEIRModel

save = False

beta = mcmc_samples['beta0']
gamma = mcmc_samples['gamma']
plt.plot(beta, gamma, '.')
offset = (beta-gamma).mean()

# Reference line through (offset, 0) and (gamma.max() + offset, gamma.max()),
# i.e. gamma = beta - offset; the legend label below states the same relation.
plt.plot([0.+offset, gamma.max()+offset], [0., gamma.max()])
plt.xlabel('beta')
plt.ylabel('gamma')
plt.title('posterior over (beta, gamma)')
plt.legend(['samples', f'gamma = beta - {offset:.2f}'])
if save:
    filename = f'figs/{place}_beta_gamma.pdf'
    plt.savefig(filename)
plt.show()


plt.hist(beta/gamma, bins=100)
plt.title('R_0')
plt.xlabel('beta/gamma')
R_0_mean = np.mean(beta/gamma)
print("R0:", R_0_mean)
if save:
    filename = f'figs/{place}_R0.pdf'
    plt.savefig(filename)
plt.show()


growth_rate = SEIRModel.growth_rate((mcmc_samples['beta0'], 
                                     mcmc_samples['sigma'],
                                     mcmc_samples['gamma']))
plt.hist(growth_rate, bins=100)
plt.title('growth rate')
if save:
    filename = f'figs/{place}_growth_rate.pdf'
    plt.savefig(filename)
plt.show()

plt.hist(mcmc_samples['det_rate'], bins=50)
plt.title('det. rate')
if save:
    # Own file name, so this figure does not overwrite the growth-rate one.
    filename = f'figs/{place}_det_rate.pdf'
    plt.savefig(filename)
plt.show()


## Tests

In [None]:
# Visual check of matplotlib's default color cycle: draw one straight line
# per color, with the line's slope equal to the color's position in the cycle.
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

x = np.arange(10)
for i, color in enumerate(colors):
    y = i*x
    plt.plot(x, y, color=color)


In [None]:
# Reload the data with the default arguments, join the place-level columns
# (minus 'state') onto it, and show columns i..j-1 of the 'MA' rows.
data, place_data = util.load_state_Xy()
place_covariates = place_data.drop(columns='state')
data = data.join(place_covariates, how='inner')

i, j = 23, 40
ma_rows = data.loc['MA']
display(ma_rows.iloc[:, i:j])


In [None]:
# Reload and join the data as above, then build the frame returned by
# util.future_data(data, 100) and show it together with its 't' column.
data, place_data = util.load_state_Xy()
place_covariates = place_data.drop(columns='state')
data = data.join(place_covariates, how='inner')

future = util.future_data(data, 100)
display(future, future['t'])