# Estimating COVID-19's $R_t$ in Real Time with PyMC3


This notebook is based on the fantastic work of [Kevin Systrom](https://github.com/k-sys) and the fine folks at [rt.live](https://rt.live).

The model itself is based on Luís M. A. Bettencourt and Ruy M. Ribeiro's [paper](https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0002185).


In [None]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import os
import requests
import pymc3 as pm
import pandas as pd
import numpy as np
import theano
import theano.tensor as tt
import theano.tensor.slinalg

from matplotlib import pyplot as plt
from matplotlib import dates as mdates
from matplotlib import ticker

from datetime import date
from datetime import datetime

from IPython.display import clear_output

%config InlineBackend.figure_format = 'retina'

In [None]:
import os
from datetime import date
from datetime import timedelta
import time

import plotly.graph_objects as go
import plotly.express as px

from tqdm import tqdm
import gc
import sys
sys.path.append('..')

from src.utils.load_data import load_population_df, load_confirmed_cases_df 

import subprocess

%load_ext autoreload
%autoreload 2

## Load State Information
#### Load

In [None]:
# NOTE(review): The COVID Tracking Project stopped collecting data in March 2021;
# confirm this endpoint still serves the historical CSV.
# NOTE(review): `states` appears unused by the county-level pipeline in the cells shown.
url = 'https://covidtracking.com/api/v1/states/daily.csv'
states = pd.read_csv(url,
                     parse_dates=['date'],
                     index_col=['state', 'date']).sort_index()

# Territories excluded from modelling:
#   GU/AS/VI do not have enough data for this model to run (MP is also dropped;
#   its reason was not recorded).
#   PR had a -384 change in its total count, so it cannot be modelled.
# errors='ignore' keeps this cell working if a territory is absent from the feed.
states = states.drop(['MP', 'GU', 'AS', 'PR', 'VI'], level='state', errors='ignore')

#### Clean data with known modifications

In [None]:
# Errors in Covidtracking.com: (state, date) -> corrected cumulative positive count.
# Applied in this order, one row at a time.
POSITIVE_CORRECTIONS = {
    ('WA', '2020-04-21'): 12512,
    ('WA', '2020-04-22'): 12753,
    ('WA', '2020-04-23'): 12753 + 190,
    ('VA', '2020-04-22'): 10266,
    ('VA', '2020-04-23'): 10988,
    ('PA', '2020-04-22'): 35684,
    ('PA', '2020-04-23'): 37053,
    ('MA', '2020-04-20'): 39643,
    ('CT', '2020-04-18'): 17550,
    ('CT', '2020-04-19'): 17962,
    ('HI', '2020-04-22'): 586,
    ('RI', '2020-03-07'): 3,
}

for _state_date, _positive in POSITIVE_CORRECTIONS.items():
    states.loc[_state_date, 'positive'] = _positive

In [None]:
def create_case_pop_df(
        population_file_path,
        case_file_path,
        cum_cases_col='cases',
        date_col='date',
        pop_fips_col='fips',
        case_fips_col='countyFIPS',
        case_county_col='County Name',
        case_state_col='state',
):
    '''
    Join county-level cumulative case counts with county population data and
    derive per-county daily and trailing case counts.

    Parameters
    ----------
    population_file_path : str
        Passed to ``load_population_df``.
    case_file_path : str
        Passed to ``load_confirmed_cases_df``.
    cum_cases_col : str
        Column holding cumulative confirmed cases in the case data.
    date_col : str
        Date column used to order each county's series.
    pop_fips_col, case_fips_col : str
        FIPS key column in the population and case data. The population key
        is renamed to ``case_fips_col`` before a left join onto the cases.
    case_county_col, case_state_col : str
        Columns combined into the ``County_State`` label
        ("<County Name title-cased> <STATE upper-cased>").

    Returns
    -------
    pd.DataFrame
        The de-duplicated merged frame, minus rows whose ``County_State`` is
        missing, sorted by ``County_State`` then ``date_col``, with two added columns:
        ``new_cases`` is the row-over-row change in ``cum_cases_col`` within each
        county (NaN on a county's first row).
        ``active_cases`` is ``cum_cases_col`` minus its value 14 rows earlier in
        the same county (0 when fewer than 14 earlier rows exist).
    '''
    # Population Data at county level
    pop_df = load_population_df(population_file_path)
    print(pop_df.shape)

    # COVID Cases
    cases_df = load_confirmed_cases_df(case_file_path)
    print(cases_df.shape)

    cases_pop_df = pd.merge(
        left=cases_df,
        right=pop_df.rename(columns={pop_fips_col: case_fips_col}),
        on=case_fips_col,
        how='left',
    ).drop_duplicates()

    cases_pop_df['County_State'] = (
        cases_pop_df[case_county_col].str.title()
        + ' ' + cases_pop_df[case_state_col].str.upper()
    )

    # groupby() skips rows whose key is missing; drop them explicitly so the
    # per-county columns below are computed on exactly the rows returned.
    # Sorting by (county, date) puts every county's series in date order, which
    # diff()/shift() rely on.
    cases_pop_df = (
        cases_pop_df
        .dropna(subset=['County_State'])
        .sort_values(['County_State', date_col])
    )

    # Vectorised per-county differences; avoids assigning into groupby slices.
    by_county = cases_pop_df.groupby('County_State')[cum_cases_col]
    cases_pop_df['new_cases'] = by_county.diff()
    cases_pop_df['active_cases'] = (
        cases_pop_df[cum_cases_col] - by_county.shift(14).fillna(0)
    )
    return cases_pop_df

In [None]:
# County-level cumulative cases joined with county population data
# (adds County_State, new_cases and active_cases columns).
cases_pop_df = create_case_pop_df(
    case_file_path='../data/county_level/covid_confirmed_usafacts.csv',
    population_file_path='../data/misc/CountyHealthRankings19.csv',
)
cases_pop_df.head()

## Load Patient Information

Data for the section below is sourced from the University of Washington's [Outbreak and Pandemic Preparedness team](https://github.com/beoutbreakprepared/nCoV2019/raw/master/latest_data/latestdata.tar.gz)

This lets us compare each patient's symptom-onset date with their confirmation date. That gives a better sense of when cases likely began, as opposed to when they were reported in the official counts.

In [None]:
# Patient line list with symptom-onset and confirmation dates; used to estimate
# the onset-to-confirmation reporting delay.
LINELIST_PATH = '../data/misc/latestdata.csv'


### Calculate the Probability Distribution of Delay

In [None]:
def calc_p_delay(onset_confirm_path=LINELIST_PATH):
    '''
    Estimate the empirical distribution of the delay (in whole days) between
    symptom onset and case confirmation from a patient line list.

    Parameters
    ----------
    onset_confirm_path : str
        Path to a line-list CSV with ``date_onset_symptoms`` and
        ``date_confirmation`` columns holding ``dd.mm.YYYY`` date strings.

    Returns
    -------
    pd.Series
        Probability of each delay, indexed by delay in days from 0 up to the
        largest observed delay. Delays never observed get probability 0.
        The values sum to 1.

    Raises
    ------
    ValueError
        If no row has a usable (onset, confirmation) pair.
    '''
    # Load the patient CSV. ``usecols`` returns columns in file order rather than
    # list order, so rename by name instead of overwriting ``columns`` positionally.
    patients = pd.read_csv(
        onset_confirm_path,
        parse_dates=False,
        usecols=['date_confirmation', 'date_onset_symptoms'],
        low_memory=False,
    ).rename(columns={'date_onset_symptoms': 'Onset',
                      'date_confirmation': 'Confirmed'})

    # There's an errant reversed date
    patients = patients.replace('01.31.2020', '31.01.2020')

    # Only keep if both values are present
    patients = patients.dropna()

    # Must have strings that look like individual dates ("09.03.2020" is 10
    # chars long); longer or shorter entries are discarded.
    is_ten_char = lambda x: x.str.len().eq(10)
    patients = patients[is_ten_char(patients.Confirmed) &
                        is_ten_char(patients.Onset)]

    # Convert both to datetimes; unparseable strings become NaT.
    patients = patients.assign(
        Onset=pd.to_datetime(patients.Onset, format='%d.%m.%Y', errors='coerce'),
        Confirmed=pd.to_datetime(patients.Confirmed, format='%d.%m.%Y', errors='coerce'),
    )

    # Only keep records where confirmed >= onset (a non-negative delay).
    # Comparisons involving NaT are False, so unparsed dates are dropped too.
    patients = patients[patients.Confirmed >= patients.Onset]
    if patients.empty:
        raise ValueError(
            f'No usable onset/confirmation date pairs in {onset_confirm_path!r}')

    # Calculate the delta in days between onset and confirmation
    delay = (patients.Confirmed - patients.Onset).dt.days

    # Convert samples to an empirical distribution over 0..max delay
    p_delay = delay.value_counts().sort_index()
    p_delay = p_delay.reindex(np.arange(0, p_delay.index.max() + 1), fill_value=0)
    return p_delay / p_delay.sum()

In [None]:
# Empirical onset-to-confirmation delay distribution; create_and_run_model
# below reads this notebook-level `p_delay`.
p_delay = calc_p_delay(LINELIST_PATH)

In [None]:
# `patients` is a local variable inside `calc_p_delay`, so it does not exist at
# notebook scope. Rebuild the cleaned onset/confirmation pairs here, using the
# same steps as `calc_p_delay`, so this cell runs on a fresh kernel.
patients = pd.read_csv(
    LINELIST_PATH,
    usecols=['date_confirmation', 'date_onset_symptoms'],
    low_memory=False,
).rename(columns={'date_onset_symptoms': 'Onset',
                  'date_confirmation': 'Confirmed'})
patients = patients.replace('01.31.2020', '31.01.2020').dropna()
patients = patients[patients.Confirmed.str.len().eq(10) &
                    patients.Onset.str.len().eq(10)]
patients = patients.assign(
    Onset=pd.to_datetime(patients.Onset, format='%d.%m.%Y', errors='coerce'),
    Confirmed=pd.to_datetime(patients.Confirmed, format='%d.%m.%Y', errors='coerce'),
)
patients = patients[patients.Confirmed >= patients.Onset]

ax = patients.plot.scatter(
    title='Onset vs. Confirmed Dates - COVID19',
    x='Onset',
    y='Confirmed',
    alpha=.1,
    lw=0,
    s=10,
    figsize=(6,6))

formatter = mdates.DateFormatter('%m/%d')
locator = mdates.WeekdayLocator(interval=2)

for axis in [ax.xaxis, ax.yaxis]:
    axis.set_major_formatter(formatter)
    axis.set_major_locator(locator)

# $R_t$ estimation

### Translate Confirmation Dates to Onset Dates

Our goal is to translate positive test counts to the dates on which symptoms likely began. Since we have the delay distribution, we can distribute case counts back in time according to it. To do this, we reverse the case time series and convolve it with the distribution of delay from onset to confirmation. We then reverse the result again to obtain the onset curve. Note that the resulting series is 'right-censored': some recent onsets have not been confirmed yet, so the most recent counts look artificially low.

In [None]:
def confirmed_to_onset(confirmed, p_delay):
    '''
    Spread confirmed case counts back in time onto their likely onset dates.

    Each day's confirmations are distributed over earlier days according to
    ``p_delay`` (probability of each delay in days, starting at 0). The
    result therefore starts ``len(p_delay) - 1`` days before ``confirmed``
    and ends on the same last date.

    Parameters
    ----------
    confirmed : pd.Series
        Daily confirmed counts indexed by date; must not contain NaN.
    p_delay : array-like
        Delay distribution indexed by delay in days.

    Returns
    -------
    pd.Series
        Daily onset counts on a daily DatetimeIndex ending at
        ``confirmed.index[-1]``.
    '''
    assert not confirmed.isna().any()

    # Convolving the time-reversed series pushes counts into the past.
    newest_first = confirmed.values[::-1]
    onset_newest_first = np.convolve(newest_first, p_delay)
    onset_values = onset_newest_first[::-1]

    onset_dates = pd.date_range(end=confirmed.index[-1],
                                periods=len(onset_values))
    return pd.Series(onset_values, index=onset_dates)

### Adjust for Right-Censoring

Since we distributed observed cases into the past to recreate the onset curve, we now have a right-censored time series. We can correct for that by asking what % of people have a delay less than or equal to the time between the day in question and the current day.

For example, suppose 100 cases actually had symptom onset 5 days ago. Only the cases whose delay is at most 5 days have been reported so far, and that portion is the cumulative distribution function (CDF) of our delay distribution at 5 days. If the portion is, say, 60%, we currently count 60 onsets for that day, which is only 60% of the true total. Dividing by 0.6 (scaling up by about 1.67×, i.e. roughly 67% higher) recovers the total of 100. We apply this correction to estimate the actual onset counts, removing the right-censoring.

In [None]:
def adjust_onset_for_right_censorship(onset, p_delay):
    '''
    Scale recent onset counts up to undo right-censoring.

    For the day ``d`` days before the last onset date, only the fraction
    ``sum(p_delay[0..d])`` of its cases has been confirmed so far. Dividing
    by that fraction estimates the eventual total for the day.

    Parameters
    ----------
    onset : pd.Series
        Daily onset counts, oldest first. Must be at least as long as ``p_delay``.
    p_delay : pd.Series
        Delay distribution indexed by delay in days, starting at 0.

    Returns
    -------
    adjusted
        ``onset`` divided elementwise by the reporting fraction.
    cumulative_p_delay : np.ndarray
        The reporting fraction aligned with ``onset`` (oldest first). It is 1
        for days older than the longest delay.
    '''
    delay_cdf = np.asarray(p_delay.cumsum())

    # Days beyond the longest delay are fully reported: pad with ones, then
    # reverse so position 0 lines up with the oldest onset date.
    pad_width = (0, len(onset) - len(delay_cdf))
    reporting_fraction = np.pad(delay_cdf, pad_width, constant_values=1)[::-1]

    return onset / reporting_fraction, reporting_fraction

### Sample the Posterior with PyMC3

We assume a Poisson likelihood and feed it what we believe is the onset curve based on reported data. We model this onset curve using the math in the Bettencourt & Ribeiro [paper](https://journals.plos.org/plosone/article?id=10.1371/journal.pone.0002185):

$$ I^\prime = Ie^{\gamma(R_t-1)} $$

We define $\theta = \gamma(R_t-1)$ and model $I^\prime = Ie^{\theta}$, where $\theta$ follows a random walk. We treat $\gamma$ as a random variable: it is the reciprocal of the serial interval, whose prior is based on known estimates. We can then recover $R_t$ as $R_t = \frac{\theta}{\gamma}+1$.

The only tricky part is that we feed _onset_ cases into the likelihood. So the Poisson mean $\mu$ is the positive, non-zero number of onset cases we expect to observe today.

We compute $\mu$ by first estimating yesterday's total onset cases after correcting for right-censoring, then plugging that into the equation above. We then re-apply today's censoring fraction to get the number of onset cases we would expect to have observed for today so far.

In [None]:
class MCMCModel(object):
    '''
    PyMC3 model of R_t for a single region, fitted to its onset curve.

    The daily growth exponent ``theta`` (one value per day after the first in
    the window) gets either a Gaussian random-walk prior (``run``) or a
    Gaussian-process prior (``run_gp``). The serial interval is a Gamma random
    variable, and ``r_t = theta * serial_interval + 1``. Onset counts enter
    through a Poisson likelihood that accounts for right-censoring via
    ``cumulative_p_delay``.
    '''

    def __init__(self, region, onset, cumulative_p_delay, window=50):
        '''
        Parameters
        ----------
        region : hashable
            Label used for identification only.
        onset : pd.Series
            Daily onset counts indexed by date; only the last ``window`` days are kept.
        cumulative_p_delay : array-like
            Reporting fraction aligned with ``onset`` (as returned by
            ``adjust_onset_for_right_censorship``); truncated to the same window.
        window : int
            Number of most recent days to model.
        '''
        # Just for identification purposes
        self.region = region

        # For the model, we'll only look at the last N
        self.onset = onset.iloc[-window:]
        self.cumulative_p_delay = cumulative_p_delay[-window:]

        # Filled by run()/run_gp(); r_t has one value per date after the first.
        self.trace = None
        self.trace_index = self.onset.index[1:]

    def _add_likelihood(self, theta):
        '''Add the serial-interval prior, the ``r_t`` deterministic and the
        Poisson onset likelihood to the currently active PyMC3 model.'''
        # Let the serial interval be a random variable and calculate r_t
        serial_interval = pm.Gamma('serial_interval', alpha=6, beta=1.5)
        gamma = 1.0 / serial_interval
        pm.Deterministic('r_t', theta / gamma + 1)

        # Undo yesterday's censoring, grow by exp(theta), then re-apply today's
        # censoring to get the onset count we expect to have observed so far.
        inferred_yesterday = self.onset.values[:-1] / self.cumulative_p_delay[:-1]
        expected_today = inferred_yesterday * self.cumulative_p_delay[1:] * pm.math.exp(theta)

        # Ensure cases stay above zero for poisson
        mu = pm.math.maximum(.1, expected_today)
        observed = self.onset.round().values[1:]
        pm.Poisson('cases', mu=mu, observed=observed)

    def run(self, chains=1, tune=3000, draws=1000, target_accept=.95, random_seed=None):
        '''Fit with a Gaussian random-walk prior on theta; stores the trace
        in ``self.trace`` and returns ``self``. Pass ``random_seed`` for
        reproducible sampling.'''
        with pm.Model():

            # Random walk magnitude
            step_size = pm.HalfNormal('step_size', sigma=.03)

            # Theta random walk: an initial value plus len(onset)-2 scaled steps,
            # accumulated to len(onset)-1 values (one per modelled day).
            theta_raw_init = pm.Normal('theta_raw_init', 0.1, 0.1)
            theta_raw_steps = pm.Normal('theta_raw_steps', shape=len(self.onset)-2) * step_size
            theta_raw = tt.concatenate([[theta_raw_init], theta_raw_steps])
            theta = pm.Deterministic('theta', theta_raw.cumsum())

            self._add_likelihood(theta)

            self.trace = pm.sample(
                chains=chains,
                tune=tune,
                draws=draws,
                target_accept=target_accept,
                random_seed=random_seed)

        return self

    def run_gp(self, chains=1, tune=1000, draws=1000, target_accept=.8, random_seed=None):
        '''Fit with a latent Gaussian-process prior on theta (ExpQuad kernel,
        fixed amplitude 0.05, Gamma-distributed length scale); stores the
        trace in ``self.trace`` and returns ``self``.'''
        with pm.Model():
            gp_shape = len(self.onset) - 1

            length_scale = pm.Gamma("length_scale", alpha=3, beta=.4)

            eta = .05
            cov_func = eta**2 * pm.gp.cov.ExpQuad(1, length_scale)

            gp = pm.gp.Latent(mean_func=pm.gp.mean.Constant(c=0),
                              cov_func=cov_func)

            # Place a GP prior over the function f.
            theta = gp.prior("theta", X=np.arange(gp_shape)[:, None])

            self._add_likelihood(theta)

            self.trace = pm.sample(
                chains=chains,
                tune=tune,
                draws=draws,
                target_accept=target_accept,
                random_seed=random_seed)

        return self

### Run PyMC3 Model

In [None]:
def df_from_model(model):
    '''
    Summarise a fitted model's ``r_t`` posterior as a DataFrame.

    Returns one row per date in ``model.trace_index``, indexed by
    (region, date). The columns are the posterior mean and median, plus the
    lower/upper bounds of the 90% and 50% HPD intervals.
    '''
    samples = model.trace['r_t']
    summary = np.column_stack([
        samples.mean(axis=0),
        np.median(samples, axis=0),
        pm.stats.hpd(samples, credible_interval=.9),
        pm.stats.hpd(samples, credible_interval=.5),
    ])

    index = pd.MultiIndex.from_product(
        [[model.region], model.trace_index], names=['region', 'date'])
    columns = ['mean', 'median', 'lower_90', 'upper_90', 'lower_50', 'upper_50']
    return pd.DataFrame(data=summary, index=index, columns=columns)

def create_and_run_model(name, county_state, case_col='new_cases',
                         delay_distribution=None, **run_kwargs):
    '''
    Build and fit an ``MCMCModel`` for one region.

    Parameters
    ----------
    name : hashable
        Region label stored on the model.
    county_state : pd.DataFrame
        The region's rows, indexed by date, containing ``case_col``.
    case_col : str
        Column of daily confirmed cases; NaN rows are dropped.
    delay_distribution : pd.Series, optional
        Onset-to-confirmation delay distribution. Defaults to the
        notebook-level ``p_delay`` (looked up at call time) for backward
        compatibility.
    **run_kwargs
        Forwarded to ``MCMCModel.run`` (e.g. ``chains``, ``random_seed``).

    Returns
    -------
    MCMCModel
        The fitted model.
    '''
    if delay_distribution is None:
        delay_distribution = p_delay

    confirmed = county_state[case_col].dropna()
    onset = confirmed_to_onset(confirmed, delay_distribution)
    # Only the reporting fraction is needed: MCMCModel applies the
    # right-censoring correction to the raw onset curve itself.
    _, cumulative_p_delay = adjust_onset_for_right_censorship(onset, delay_distribution)
    return MCMCModel(name, onset, cumulative_p_delay).run(**run_kwargs)

### Render Charts

In [None]:
def plot_rt(name, result, ax, c=(.3,.3,.3,1), ci=(0,0,0,.05)):
    '''
    Draw one region's posterior median R_t and its 90% HPD band on ``ax``.

    ``result`` must have ``median``, ``lower_90`` and ``upper_90`` columns
    indexed by date. ``c`` is the median line colour and ``ci`` the band
    fill colour. A dotted reference line marks R_t = 1.
    '''
    median_style = dict(marker='o', markersize=4, markerfacecolor='w',
                        lw=1, c=c, markevery=2)

    ax.set_title(name)
    ax.set_ylim(0.5, 1.6)

    ax.plot(result['median'], **median_style)
    ax.fill_between(result.index,
                    result['lower_90'].values,
                    result['upper_90'].values,
                    color=ci, lw=0)
    ax.axhline(1.0, linestyle=':', lw=1)

    date_axis = ax.xaxis
    date_axis.set_major_formatter(mdates.DateFormatter('%m/%d'))
    date_axis.set_major_locator(mdates.WeekdayLocator(interval=2))

## Alternative Method: All Combined

In [None]:
def check_divergences(models):
    '''
    Re-sample every model whose trace contains divergent transitions.

    Prints the divergence counts of the models that diverged, re-runs each
    of those in place with ``chains=2``, and returns the same dict.

    Parameters
    ----------
    models : dict
        Maps region -> fitted model exposing ``trace['diverging']`` and
        ``run(chains=...)``.
    '''
    counts = [np.count_nonzero(model.trace['diverging']) for model in models.values()]
    divergences = pd.Series(counts, index=models.keys())
    diverged = divergences[divergences > 0]

    print('Diverging states:')
    display(diverged)

    # Rerun counties with divergences
    for region in diverged.index:
        models[region].run(chains=2)

    gc.collect()
    return models

In [None]:
def regional_rt_model(subset_df,
                      region_col='County_State',
                      date_col='date',
                      case_col='new_cases',
                      output_path=None):
    '''
    Fit an R_t model for every region in ``subset_df`` and collect the summaries.

    Rows are first summed per (region, date), so passing county rows with
    ``region_col='state'`` yields state-level estimates. Models with divergent
    transitions are re-sampled (see ``check_divergences``).

    Parameters
    ----------
    subset_df : pd.DataFrame
        Must contain ``region_col``, ``date_col`` and ``case_col``.
    region_col, date_col, case_col : str
        Region label, date and daily-case columns.
    output_path : str, optional
        If given and at least one region was modelled, the results are written
        there as CSV (parent directories are created).

    Returns
    -------
    (results, err_list)
        ``results`` is a DataFrame indexed by (region, date) as produced by
        ``df_from_model``, or None if no region could be modelled.
        ``err_list`` lists the regions whose model raised an exception.
    '''
    # Aggregate to one row per (region, date), e.g. counties summed to a state.
    subset_df = (
        subset_df.groupby([region_col, date_col])[case_col]
        .sum()
        .reset_index()
        .sort_values(date_col)
    )

    models = {}
    err_list = []
    num_regions = subset_df[region_col].nunique()
    grouped = subset_df.set_index(date_col).groupby(region_col)
    for j, (region, grp) in enumerate(grouped, start=1):
        print(f'\t\t{j} of {num_regions} regions in current subset...')
        try:
            models[region] = create_and_run_model(region, grp, case_col)
        except Exception as err:
            # Keep going with the remaining regions, but report why this one failed.
            print(f'\t\tFailed to model {region}: {err!r}')
            err_list.append(region)

    gc.collect()

    # Re-sample any model whose trace has divergent transitions.
    models = check_divergences(models)

    # Build the combined frame once instead of concatenating inside the loop.
    frames = [df_from_model(model) for model in models.values()]
    results = pd.concat(frames, axis=0) if frames else None

    if output_path is not None:
        if results is None:
            print(f'No regions were modelled; nothing written to {output_path}')
        else:
            parent = os.path.dirname(output_path)
            if parent:
                os.makedirs(parent, exist_ok=True)
            results.to_csv(output_path)

    return results, err_list

In [None]:
# State-level R_t for Alabama.
p_delay = calc_p_delay(LINELIST_PATH)
state = 'AL'
# Filter on the lower-case 'state' column: create_case_pop_df builds County_State
# from it, and region_col below groups on the same column.
subset_df = cases_pop_df[cases_pop_df['state'].isin([state])]
# NOTE(review): other data paths in this notebook live under ../data/;
# confirm ../../DATA/ is the intended output location.
results, err_list = regional_rt_model(
    subset_df,
    case_col='new_cases',
    region_col='state',
    output_path=f'../../DATA/rt_state/rt_state_{state}.csv',
)

In [None]:
# State-level R_t for Delaware (results are not written to disk).
p_delay = calc_p_delay(LINELIST_PATH)
de_rows = cases_pop_df['state'].isin(['DE'])
subset_df = cases_pop_df[de_rows]
results, err_list = regional_rt_model(subset_df, region_col='state', case_col='new_cases')

In [None]:
# One R_t panel per region, four panels per row.
assert results is not None and len(results), 'no R_t results to plot'

ncols = 4
# Count regions actually present; MultiIndex levels can retain unused labels.
regions = results.index.get_level_values('region').unique()
nrows = int(np.ceil(len(regions) / ncols))

fig, axes = plt.subplots(
    nrows=nrows,
    ncols=ncols,
    figsize=(14, nrows*3),
    sharey='row',
    squeeze=False)

for ax, (county_state, result) in zip(axes.flat, results.groupby('region')):
    plot_rt(county_state, result.droplevel(0), ax)

# Hide the panels left over when the region count is not a multiple of ncols.
for ax in axes.flat[len(regions):]:
    ax.set_visible(False)

fig.tight_layout()
fig.set_facecolor('w')

# Appendix

In [None]:
# Force a garbage-collection pass to free unreachable objects (e.g. left over
# from earlier model runs); the cell displays the number collected.
gc.collect()