In [None]:
import matplotlib.pyplot as pplt
import plotly.graph_objects as go
import theano.tensor as tt
import scipy as sc
import pymc3 as pm
import math as math
import numpy as np
import pandas as pd
import descartes

# Reading Live Data
The first step in building a model is collecting data. Here, we pull daily per-state data from the COVID Tracking Project (https://covidtracking.com), then trim it to the columns we need and fill in missing values. (Note: the project stopped collecting new data in March 2021, so this feed no longer updates.)

In [None]:
# Population of each state, DC and territory, keyed by two-letter postal code.
# DC and the territories (AS, GU, MP, PR, VI) are filtered out of the case data
# later, but their populations are kept here for completeness.
# NOTE(review): figures look like 2019 Census estimates — confirm the source.
statePops = dict(
    AK=731545, AL=4903185, AR=3017804, AS=49437, AZ=7278717,
    CA=39512223, CO=5758736, CT=3565287, DC=705749, DE=973764,
    FL=21477737, GA=10617423, GU=168485, HI=1415872, IA=3155070,
    ID=1787065, IL=12671821, IN=6732219, KS=2913314, KY=4467673,
    LA=4648794, MA=6892503, MD=6045680, ME=1344212, MI=9986857,
    MN=5639632, MO=6137428, MP=51433, MS=2976149, MT=1068778,
    NC=10488084, ND=762062, NE=1934408, NH=1359711, NJ=8882190,
    NM=2096829, NV=3080156, NY=19453561, OH=11689100, OK=3956971,
    OR=4217737, PA=12801989, PR=3193694, RI=1059361, SC=5148714,
    SD=884659, TN=6829174, TX=28995881, UT=3205958, VA=8535519,
    VI=106235, VT=623989, WA=7614893, WI=5822434, WV=1792147,
    WY=578759,
)
    
# this function will attempt to fill in missing data from each state as accurately as possible
def fill_na(df):
    df = df.sort_values(by=['date'])
    filled_df = []
    
    # loop through table for each state
    for state in np.unique(df.state):
        # subset by state
        this_df = df.loc[df.state == state,]
        # fill NA forward with first good value
        this_df = this_df.fillna(method='ffill')
        # fill NA backward with zero, this ASSUMES deaths were 0 before first set of reported deaths
        this_df.death = this_df.death.fillna(0)
        # appending filled data to filledDf
        filled_df.append(this_df)

    # concatenating filledDf into dataframe
    filled_df = pd.concat(filled_df)

    return filled_df

# this function will take in a csv and a name and return a dataframe with that name
def load_us_df(file):
    df = pd.read_csv(file)
    df = df[['date','state','positive','negative','recovered','death','hospitalized','total']]
    df = fill_na(df)
    df['date'] = df['date'].apply(lambda d: str(d)[0:4] + '-' + str(d)[4:6] + '-' + str(d)[6:8])
    
    return df

# COVID Tracking Project per-state daily history (CSV endpoint)
historicalStatesURL = 'https://covidtracking.com/api/v1/states/daily.csv'

# Load and gap-fill, then drop DC and the five territories so only the
# 50 states remain.
stateData = load_us_df(historicalStatesURL)
stateData = stateData.loc[~stateData['state'].isin(['AS', 'DC', 'GU', 'MP', 'PR', 'VI'])]

stateData

# Bayesian Modelling
Here we define a function that takes one state's data, its population, and a choice of starting-value strategy. It then runs Markov chain Monte Carlo (MCMC) sampling on an SIR-based data-generating model to estimate the model's parameters.

In [None]:
# here, start is a parameter that will indicate which of 3 sets of starting values to work with
def sampleMCMC(data, pop, start):

    # splitting data into infections and time as numpy arrays
    dataDeath = data['death'].to_numpy()
    time = np.linspace(0,len(data)-1, len(data))

    # establishing model
    with pm.Model() as model:
        
        # create population number priors
        i0 = pm.Poisson('i0', mu=pop/1000)
        s0 = pm.Deterministic('s0', pop - i0)
        
        # extract starting components
        pos = float(data['positive'].iloc[-1])
        rec = float(data['recovered'].iloc[-1])
        dea = float(data['death'].iloc[-1])
        tot = float(data['total'].iloc[-1])
        
        # create starting values based on data, does not inform inference but starts at a reasonable value
        # start beta conditional on start argument
        beta_start = (
            ((pos/tot)/2 if start==1 else
            (pos/tot + pos/pop)/2 if start == 2 else
            (pos/tot)) if (not(math.isnan(pos)) and not(math.isnan(tot))) else .05)
        
        # start gamma conditional on start argument
        gamma_start = (
            ((rec/tot)/2 if start==1 else
            (rec/tot + rec/pop)/2 if start==2 else
            (rec/tot)) if (not(math.isnan(rec)) and not(math.isnan(tot))) else .047)
        
        # start rho conditional on start argument
        rho_start = (
            ((dea/pos)/2 if start==1 else
            (dea/pos + dea/pop)/2 if start==2 else
            (dea/pos)) if (not(math.isnan(dea)) and not(math.isnan(pos))) else .036)
        
        # creating priors for beta, gamma, and rho
        beta = pm.InverseGamma('beta', mu=.05, sigma=.5, testval=beta_start)
        gamma = pm.InverseGamma('gamma', mu=.047, sigma=.5, testval=gamma_start)
        rho = pm.TruncatedNormal('rho', mu=.036, sigma=.01, lower=0, upper=1, testval=rho_start)

        # create number of removed based on analytic solution and above parameters
        sirRem = pm.Deterministic('sirRem',
            pop - ((s0 + i0)**(beta/(beta - gamma)))*
            (s0 + i0*tt.exp(time*(beta - gamma)))**(-gamma/(beta - gamma)))
        # create number of deaths as a fraction of number of removed
        sirDeath = pm.Deterministic('sirDeath', rho*sirRem)
        
        # create variance prior
        sigma = pm.HalfCauchy('sigma', beta=2)
        
        # create likelihood with modelled counts and observed counts
        obsDeath = pm.TruncatedNormal('obsDeath', mu=sirDeath, sigma=sigma,
                                     lower=0, upper=pop, observed=dataDeath)

        # specifying model conditions
        step=pm.NUTS(target_accept=.99)
        start=pm.find_MAP()
        
        # execute sampling
        model_trace = pm.sample(draws=500, tune=500, step=step, start=start, chains=5, cores=16)

    # return posterior samples and other information
    return model_trace

In [None]:
# Fit every state, trying the three starting-value strategies in turn, and keep
# only fits whose chains pass the r_hat convergence check.
DIVIDER = "======================================================================================================="
# parameters whose r_hat values are summed for the convergence check
RHAT_PARAMS = ['beta', 'gamma', 'rho']
# failure threshold on that sum: 3.15 over three parameters = mean r_hat 1.05
R_HAT_SUM_MAX = 3.15


def _fit_state(state):
    """Return sampleMCMC's trace for ``state`` from the first of start
    strategies 1, 2, 3 that does not raise, or None if all three raise."""
    state_data = stateData.loc[stateData.state == state]
    for start in (1, 2, 3):
        try:
            return sampleMCMC(state_data, statePops[state], start)
        except Exception as e:
            # report the error instead of discarding it silently
            display(f"{state}: sampling with start={start} failed ({type(e).__name__}: {e})")
    return None


summary_parts = []
trace_rows = []

for state in stateData['state'].unique():

    display(state + ":")
    display(DIVIDER)

    this_sample = _fit_state(state)
    if this_sample is None:
        trace_rows.append({"state": state, "trace": None})
        display(state + ' failed')
        display(DIVIDER)
        continue

    # per-parameter summary table; errors='ignore' tolerates arviz versions
    # that do not emit every dropped column
    these_results = (pm.summary(this_sample, var_names=['i0','beta','gamma','rho','sigma'], round_to=5)
        .drop(['mcse_sd','ess_sd','ess_bulk','ess_tail'], axis=1, errors='ignore')
        .reset_index()
        .rename(columns={"index": "param"}))
    these_results['state'] = state

    # trace plots for a visual convergence check
    pm.plot_trace(this_sample, var_names=('i0','beta','gamma','rho','sigma'))
    pplt.show()

    # select r_hat by parameter name rather than row position; a NaN r_hat
    # makes the sum NaN and `not (NaN <= x)` is True, so undiagnosable fits fail
    r_hat_sum = (these_results.loc[these_results['param'].isin(RHAT_PARAMS), 'r_hat']
                 .astype(float).sum(skipna=False))

    if not r_hat_sum <= R_HAT_SUM_MAX:
        trace_rows.append({"state": state, "trace": None})

        display(state + ' failed')
        display(DIVIDER)

    else:
        summary_parts.append(these_results)
        trace_rows.append({"state": state, "trace": this_sample})

        display('summary:')
        display(DIVIDER)

        display(these_results)

        display(state + ' succeeded')
        display(DIVIDER)

# build the result tables once (DataFrame.append was removed in pandas 2.0)
summary = pd.concat(summary_parts) if summary_parts else pd.DataFrame()
trace_results = pd.DataFrame(trace_rows, columns=['state', 'trace'])

# Plotting

In [None]:
# Map the posterior-mean transmission rate (beta) for each state.
beta_rows = summary[summary['param'] == 'beta']

fig_beta = go.Figure(data=go.Choropleth(
    locations=beta_rows['state'],         # two-letter state codes
    z=beta_rows['mean'].astype(float),    # value mapped to colour
    locationmode='USA-states',            # interpret `locations` as US state codes
    colorscale='YlOrBr',
    colorbar_title="Transmission Rate per Day",
))

fig_beta.update_layout(
    title_text='MCMC Estimates of COVID-19 Transmission Rate by State',
    geo_scope='usa',  # limit map scope to the USA
)

fig_beta.show()

In [None]:
# Map the posterior-mean removal rate (gamma) for each state.
gamma_rows = summary[summary['param'] == 'gamma']

fig_gamma = go.Figure(data=go.Choropleth(
    locations=gamma_rows['state'],        # two-letter state codes
    z=gamma_rows['mean'].astype(float),   # value mapped to colour
    locationmode='USA-states',            # interpret `locations` as US state codes
    colorscale='Greens',
    colorbar_title="Removal Rate per Day",
))

fig_gamma.update_layout(
    title_text='MCMC Estimates of COVID-19 Removal Rate by State',
    geo_scope='usa',  # limit map scope to the USA
)

fig_gamma.show()

In [None]:
# Map the posterior-mean mortality fraction (rho) for each state.
rho_rows = summary[summary['param'] == 'rho']

fig_rho = go.Figure(data=go.Choropleth(
    locations=rho_rows['state'],          # two-letter state codes
    z=rho_rows['mean'].astype(float),     # value mapped to colour
    locationmode='USA-states',            # interpret `locations` as US state codes
    colorscale='Reds',
    colorbar_title="Mortality Rate per Infection",
))

fig_rho.update_layout(
    title_text='MCMC Estimates of COVID-19 Mortality Rate by State',
    geo_scope='usa',  # limit map scope to the USA
)

fig_rho.show()

In [None]:
# Map the posterior-mean mortality fraction (rho) for each state.
# NOTE(review): in the original notebook this cell is an exact copy of the
# preceding rho cell and re-renders the same map; consider removing it or
# plotting a different quantity here.
rho_mask = summary['param'] == 'rho'
rho_choropleth = go.Choropleth(
    locations=summary.loc[rho_mask, 'state'],
    z=summary.loc[rho_mask, 'mean'].astype(float),
    locationmode='USA-states',
    colorscale='Reds',
    colorbar_title="Mortality Rate per Infection",
)

fig_rho = go.Figure(data=rho_choropleth)
fig_rho.update_layout(title_text='MCMC Estimates of COVID-19 Mortality Rate by State',
                      geo_scope='usa')  # limit map scope to the USA
fig_rho.show()

In [None]:
# Map an "overall risk" score per state: rho * (beta / gamma), i.e. the
# mortality fraction scaled by beta / gamma (the SIR basic reproduction number).
# Pivoting to one row per state aligns the three parameters by state name,
# instead of relying on three separately filtered selections lining up by row.
param_means = (summary[summary['param'].isin(['beta', 'gamma', 'rho'])]
               .pivot(index='state', columns='param', values='mean')
               .astype(float))
overall_risk = param_means['rho'] * (param_means['beta'] / param_means['gamma'])

fig_ovrsk = go.Figure(data=go.Choropleth(
    locations=overall_risk.index.tolist(),  # two-letter state codes
    z=overall_risk.to_numpy(),              # value mapped to colour
    locationmode='USA-states',              # interpret `locations` as US state codes
    colorscale='Burgyl',
    colorbar_title="Overall Risk",
))

fig_ovrsk.update_layout(
    title_text='Estimated Overall Risk of COVID-19 by State',
    geo_scope='usa',  # limit map scope to the USA
)

fig_ovrsk.show()