# Estimating COVID-19's $R_t$ in Real Time with MCMC
Kevin Systrom - April 12

In [None]:
import pandas as pd
import numpy as np

from matplotlib import pyplot as plt

import pymc3 as pm

# Region names (US territories) kept aside from the per-state analysis.
# NOTE(review): not referenced in any of the visible cells — confirm it is
# still used before relying on it.
FILTERED_REGIONS = [
    'Virgin Islands',
    'American Samoa',
    'Northern Mariana Islands',
    'Guam',
    'Puerto Rico',
]

%config InlineBackend.figure_format = 'retina'

In [None]:
url = 'https://raw.githubusercontent.com/nytimes/covid-19-data/master/us-states.csv'
# Case counts from the NYT repository as a Series indexed by (state, date),
# sorted so per-state slices come out in chronological order.
# Columns are selected by name rather than position so an upstream column
# reorder cannot silently pick the wrong field.
# `read_csv(squeeze=True)` was deprecated in pandas 1.4 and removed in 2.0;
# `.squeeze('columns')` turns the single remaining data column into a Series
# on both older and newer pandas.
states = (pd.read_csv(url,
                      usecols=['date', 'state', 'cases'],
                      index_col=['state', 'date'],
                      parse_dates=['date'])
          .squeeze('columns')
          .sort_index())

In [None]:
# Region to model; must match a value of the 'state' index level of `states`.
state_name = "California"

def prepare_cases(cases, window=7, std=2):
    """Turn cumulative case counts into raw and smoothed daily new cases.

    Daily new cases are the first difference of ``cases``. They are smoothed
    with a centred, Gaussian-weighted rolling mean and rounded to whole
    numbers. Every row up to and including the last one whose smoothed value
    is exactly 0 is then dropped, so the returned smoothed series contains no
    zeros.

    Parameters
    ----------
    cases : pandas.Series
        Cumulative case counts, assumed to be in chronological order.
    window : int, optional
        Rolling-window length in observations (default 7).
    std : float, optional
        Standard deviation of the Gaussian window weights (default 2).

    Returns
    -------
    original : pandas.Series
        Unsmoothed daily new cases over the same rows as ``smoothed``.
    smoothed : pandas.Series
        Rounded, smoothed daily new cases.
    """
    new_cases = cases.diff()

    # min_periods=1 lets the edges be estimated from a partial window.
    smoothed = new_cases.rolling(window,
                                 win_type='gaussian',
                                 min_periods=1,
                                 center=True).mean(std=std).round()

    # Locate the last zero by *position*, not by the largest index label:
    # a label-based lookup picks the wrong row when the index is not
    # monotonic and fails outright on duplicate labels.
    zero_positions = np.flatnonzero(smoothed.eq(0).to_numpy())
    idx_start = zero_positions[-1] + 1 if zero_positions.size else 0

    smoothed = smoothed.iloc[idx_start:]
    # Slice positionally so `original` lines up row-for-row with `smoothed`.
    original = new_cases.iloc[idx_start:]

    return original, smoothed

# Select the chosen state's series (the 'state' level is dropped, leaving
# a date index) and label it for plots/logs.
cases = (
    states
    .xs(state_name)
    .rename(f"{state_name} cases")
)

# Raw and smoothed daily new cases for that state.
(original, smoothed) = prepare_cases(cases)

In [None]:
# Floor the unsmoothed daily counts at 1 (removes zeros and negative revisions).
# NOTE(review): `original` is not used by the model cells below — confirm this
# clipping is still needed, or whether it was meant for `smoothed`.
original = original.clip(lower=1)

In [None]:
# Discard the first ten rows of the smoothed series before modelling.
# Note: `original` is not trimmed here, so the two no longer share an index.
smoothed = smoothed.iloc[slice(10, None)]

In [None]:
def lambda_to_rt(λ, k, γ=4):
    """Convert the latent daily case rate λ into R_t.

    Inverts the Bettencourt & Ribeiro relation
    λ_t = k_{t-1} · exp((R_t - 1) / γ), giving R_t = γ · log(λ_t / k_{t-1}) + 1.
    Here γ is the serial interval in days; B&R write the same exponent with
    the reciprocal rate, so their γ equals 1 / γ here.

    Parameters
    ----------
    λ : tensor
        Expected new cases per day, one entry per row of ``k`` (in this
        notebook a PyMC3 random walk of length ``len(k)``).
    k : pandas.Series
        Observed (smoothed) daily new cases.
    γ : float, optional
        Serial interval in days (default 4).

    Returns
    -------
    Tensor of length ``len(k) - 1``: R_t for every day after the first,
    since day 0 has no previous observation k_{t-1}.
    """
    # Previous day's observation k_{t-1}, as a plain array for tensor math.
    k_tm1 = k.iloc[:-1].values
    # Drop day 0 so λ_t lines up with k_{t-1}.
    λ = λ[1:]
    # Daily log-growth times the serial interval. The previous version divided
    # by γ, which shrank every deviation of R_t from 1 by a factor of γ².
    return pm.math.log(λ / k_tm1) * γ + 1

In [None]:
# Model: the latent daily case rate λ_t follows a Gaussian random walk and the
# smoothed daily counts are Poisson(λ_t). R_t is recorded as a deterministic
# transform of λ so its posterior appears in the trace.
with pm.Model() as model:
    # Scale of the day-to-day random-walk steps, learned from the data.
    step_width = pm.HalfNormal('step_width', sigma=100.)
    
    # One λ per row of `smoothed`, initialised at the smoothed observations.
    # NOTE(review): the walk is on the linear scale and unconstrained, so λ can
    # go negative, which the Poisson rate below cannot accept — confirm this
    # does not cause divergences or failed chains.
    lam = pm.GaussianRandomWalk('lambda',
                               mu=0,
                               sigma=step_width,
                               shape=len(smoothed),
                               testval=smoothed.values)
    
    # R_t for every day after the first (see lambda_to_rt).
    pm.Deterministic('Rt', lambda_to_rt(lam, smoothed))
    
    # Likelihood: observed smoothed counts given the latent rate.
    pm.Poisson('obs', lam, observed=smoothed.values)
    
    # Single core; target_accept raised above PyMC3's 0.8 default, which makes
    # the sampler take smaller steps.
    trace = pm.sample(cores=1, tune=1000, target_accept=0.95)

In [None]:
# Trace plots for the variables sampled above; the trailing ';' suppresses the
# notebook's echo of the returned axes.
# NOTE(review): later PyMC3 releases deprecate pm.traceplot in favour of
# pm.plot_trace / arviz.plot_trace — confirm against the installed version.
pm.traceplot(trace);

In [None]:
# Posterior draws of λ (faint grey) against the smoothed observations, both on
# a log scale. trace['lambda'] is presumably (draws, days); transposing makes
# each draw its own line.
# NOTE(review): λ is unconstrained in the model, so np.log of negative draws
# yields NaN (with a RuntimeWarning) and those segments are not drawn.
plt.plot(np.log(trace['lambda'].T), alpha=.01, color='0.5')
plt.plot(np.log(smoothed.values))
# Raw string: '\l' is not a valid Python escape (Deprecation/SyntaxWarning);
# '\log' also makes mathtext render the operator upright.
plt.ylabel(r'$\log(\lambda)$');

In [None]:
fig, ax = plt.subplots()

# One faint grey line per posterior draw of R_t.
ax.plot(trace['Rt'].T, color='.5', linewidth=1, alpha=.01);
ax.set_ylabel('$R_{t}$');
ax.set_ylim(bottom=.8, top=1.3)
# Dotted horizontal reference line at R_t = 1.
ax.axhline(1.0, color='k', linestyle=":")