Runs an extended Kalman filter on Prof Flaxman's SEIIR predictions and
measurement data for New York State from mid-April to 7 July 2020.

S - Susceptible
E - Exposed
I1 - Presymptomatic
I2 - Symptomatic
R - Recovered

In [None]:
#%load_ext autoreload
#%autoreload

In [None]:
#%reset

In [1]:
import math
import datetime

import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt

import data_sets
import seiir_compartmental

In [2]:
# functions to support the Kalman filtering
def get_predicts_prior(day, seiir):
    x_hat = np.array([[seiir['S'].loc[day]],
                      [seiir['E'].loc[day]],
                      [seiir['I1'].loc[day]],
                      [seiir['I2'].loc[day]],
                      [seiir['R'].loc[day]]])

    beta_k = seiir['beta_pred'].loc[day]

    return x_hat, beta_k


def step_seiir(x_hat, constants, beta_k, days=7):
    s_dict = {'S': x_hat[0, 0],
              'E': x_hat[1, 0],
              'I1': x_hat[2, 0],
              'I2': x_hat[3, 0],
              'R': x_hat[4, 0]}

    s = pd.Series(s_dict)

    for i in range(days):
        infectious = s.loc['I1'] + s.loc['I2']
        s = seiir_compartmental.compartmental_covid_step(s, s.sum(),
                                                         infectious,
                                                         constants['alpha'],
                                                         beta_k,
                                                         constants['gamma1'],
                                                         constants['gamma2'],
                                                         constants['sigma'],
                                                         constants['theta'])
    x_hat_future_prior = np.array([[s.loc['S']],
                                   [s.loc['E']],
                                   [s.loc['I1']],
                                   [s.loc['I2']],
                                   [s.loc['R']]])

    return x_hat_future_prior


def predict_step(x_hat_k1_prior, P, Q, beta_k, constants):
    S = x_hat_k1_prior[0, 0]
    E = x_hat_k1_prior[1, 0]
    I1 = x_hat_k1_prior[2, 0]
    I2 = x_hat_k1_prior[3, 0]
    R = x_hat_k1_prior[4, 0]
    N = S + E + I1 + I2 + R
    alpha = constants['alpha']
    sigma = constants['sigma']
    gamma1 = constants['gamma1']
    gamma2 = constants['gamma2']

    part_f_S = np.array([[-beta_k * math.pow(I1 + I2, alpha) / N],
                         [beta_k * math.pow(I1 + I2, alpha) / N],
                         [0],
                         [0],
                         [0]])

    part_f_E = np.array([[0],
                         [-sigma],
                         [sigma],
                         [0],
                         [0]])

    part_f_I1 = np.array([[-alpha * beta_k * S * math.pow(I1+I2, alpha-1) / N],
                          [alpha * beta_k * S * math.pow(I1+I2, alpha-1) / N],
                          [-gamma1],
                          [gamma1],
                          [0]])

    part_f_I2 = np.array([[-alpha * beta_k * S * math.pow(I1+I2, alpha-1) / N],
                          [alpha * beta_k * S * math.pow(I1+I2, alpha-1) / N],
                          [0],
                          [-gamma2],
                          [gamma2]])

    part_f_R = np.array([[0],
                         [0],
                         [0],
                         [0],
                         [0]])

    # 5x5
    f_jacob = np.concatenate([part_f_S, part_f_E, part_f_I1, part_f_I2,
                              part_f_R], axis=1)

    # 5x5
    # P_k1_prior = f_jacob * P * f_jacob^T + Q
    P_k1_prior = np.matmul(np.matmul(f_jacob, P), np.transpose(f_jacob)) + Q
    return P_k1_prior


def update_step(x_hat, x_hat_k1, P_k1, Rn, rho1, rho2, z_k):
    # 5x5
    ep = 10**-10
    H = np.array([[ep, 0, 0, 0, 0],
                  [0, ep, 0, 0, 0],
                  [0, 0, ep, rho1, 0],
                  [0, 0, 0, rho2, 0],
                  [0, 0, 0, 0, ep]])
    #H = np.array([[0, 0, 0, 0, 0],
    #                  [0, 0, 0, 0, 0],
    #                  [0, 0, 0, rho1, 0],
    #                  [0, 0, 0, rho2, 0],
    #                  [0, 0, 0, 0, 0]])
    
    # Si = H * P_k1 * H^T + Rn
    Si = np.matmul(np.matmul(H, P_k1), np.transpose(H)) + Rn

    # K_new = P_k1 * H^T * Si^(-1)
    K_new = np.matmul(np.matmul(P_k1, np.transpose(H)), np.linalg.inv(Si))
    y_new = np.matmul(H, x_hat)

    # 5x1
    diff = z_k - y_new

    x_hat_k1_post = x_hat_k1 + np.matmul(K_new, diff)

    P_k1_post = P_k1 - np.matmul(np.matmul(K_new, Si), np.transpose(K_new))

    # joseph formulation
    #joe2 = np.matmul(np.matmul(K_new, Rn), np.transpose(K_new))
    #joe1 = np.eye(5) - np.matmul(K_new, H)
    #P_k1_post = np.matmul(np.matmul(joe1, P_k1), np.transpose(joe1)) - joe2

    return x_hat_k1_post, P_k1_post

def get_predict_measures():
    predictions = pd.read_csv(
        r'case_vs_symptom/kalman/predicted_measures/predictions.csv',
        header=None, names=['measure_date', 'prediction_date', 'prop', 'case',
                            'prop_pred7', 'case_pred7', 'prop_pred1',
                            'case_pred1'],
        index_col=False, usecols=[0, 4, 5])

    start_d = datetime.timedelta(days=int(predictions['measure_date'].iloc[0]))
    start = datetime.date(2019, 12, 31) + start_d
    end_d = datetime.timedelta(days=int(predictions['measure_date'].iloc[-1]))
    end = datetime.date(2019, 12, 31) + end_d

    measure_rng = pd.date_range(start=start, end=end)

    # code to save:
    # prediction_rng = pd.date_range(start=datetime.date(2019, 12, 31) +
    # datetime.timedelta(days=int(predictions['prediction_date'].iloc[0])),
    # end=datetime.date(2019, 12, 31) +
    # datetime.timedelta(days=int(predictions['prediction_date'].iloc[-1])))

    predictions.set_index(measure_rng, inplace=True)

    return predictions


def get_data_sets(state, fips=None):
    # the post hoc seiir model predictions provided by Prof Flaxman without
    # any kalman filtering:
    #seiir = pd.read_csv(r'data/seiir_compartments_post-hoc_ny_state_20200910.csv', header=0,
    #                    index_col='date', parse_dates=True)
    seiir = pd.read_csv(r'data/seiir_compartments_post-hoc_ny_state.csv', header=0,
                        index_col='date', parse_dates=True)
    

    fb_data = data_sets.create_symptom_df()
    fb_data_val = data_sets.create_symptom_df(valid=True)

    if fips is None:
        case_data = data_sets.create_case_df_state()
        case_data_geo = case_data.loc[state]['case_rate'].copy()
        fb_data_geo = fb_data.loc[state].groupby('date').sum().copy()
        fb_data_val_geo = fb_data_val.loc[state].groupby('date').sum().copy()

    else:
        case_data = data_sets.create_case_df_county()
        case_data_geo = case_data.loc[fips]['case_rate'].copy()
        
        if fips == 'New York City':
            nyc_fips = ['36005', '36061', '36047', '36085']
            fb_data_geo = fb_data.loc[(slice(None), '36081'), :].copy()
            fb_data_geo = fb_data_geo.mean(level='date')
            fb_data_val_geo = fb_data_val.loc[(slice(None), '36081'), :].copy()
            fb_data_val_geo = fb_data_val.mean(level='date')
            for borough in nyc_fips:
                fb_data_geo += fb_data.loc[(slice(None), borough), :].copy().mean(level='date')
                fb_data_val_geo += fb_data_val.loc[(slice(None), borough), :].copy().mean(level='date')
        else:
            fb_data_geo = fb_data.loc[(slice(None), fips), :].copy()
            fb_data_val_geo = fb_data_val.loc[(slice(None), fips), :].copy()
            # collapse down to a single index column (date)
            fb_data_geo = fb_data_geo.mean(level='date')
            fb_data_val_geo = fb_data_val_geo.mean(level='date')


    return seiir, fb_data_geo, fb_data_val_geo, case_data_geo


def calc_fb_ma7(fb_data):
    """
    Returns a Pandas series
    """
    # the fb_data is a DataFrame while the case_data is a Series
    fb_ma7 = fb_data.rolling(window=7).mean()
    fb_ma7 = fb_ma7.iloc[6:, :]
    prop_ma7 = fb_ma7['num_stl'].div(fb_ma7['n'])

    return prop_ma7, fb_ma7['n'].copy(), fb_ma7['num_stl'].copy()

In [3]:
# set constants
K0 = datetime.date(2020, 4, 12)

the_state = 'NY'
the_county = data_sets.get_fips(the_state, 'Albany')
#the_county = 'New York City'
constants = {
    'alpha': 0.948786,
    'gamma1': 0.500000,
    'gamma2': 0.662215,
    'sigma': 0.266635,
    'theta': 6.000000
    }

# set initial values for Kalman filter parameters
P_mult = 1
Q_mult = 1

# Rn is the R noise matrix; it remains constant thru the stepping of the
# Kalman filter
# prior to experiments, Rn_mult = 5*10**-4
Rn_mult = 5*10**-8

Rn_22 = 88
Rn_32 = 150

Rn_23 = 0
Rn_33 = 1

Rn = Rn_mult * np.array([[0, 0, 0, 0, 0],
                         [0, 0, 0, 0, 0],
                         [0, 0, Rn_22, 0, 0],
                         [0, 0, Rn_32, Rn_33, 0],
                         [0, 0, 0, 0, 0]])
Q = Q_mult * np.eye(5)
P = P_mult * np.eye(5)

# generate data
seiir, fb_data, fb_data_val, case_data = get_data_sets(
    data_sets.STATES[the_state], fips=the_county)


if the_county == 'New York City':
    county_pop = 0
    for each in ['36081', '36005', '36061', '36047', '36085']:
        this_count, state_pop = data_sets.get_pops(each)
        county_pop += this_count
else:
    county_pop, state_pop = data_sets.get_pops(the_county)

b = county_pop / state_pop

# calculate moving averages on the fb and case data
case_ma7 = case_data.rolling(window=7).mean()
case_ma7_all = case_ma7.iloc[6:]
prop_ma7, n_ma7, num_stl_ma7 = calc_fb_ma7(fb_data)
prop_ma7_valid, n_ma7_valid, num_stl_ma7_valid = calc_fb_ma7(fb_data_val)

# get starting compartment values for the state level
x_hat_state_k0, beta_k0 = get_predicts_prior(K0, seiir)
x_hat_k0 = b * x_hat_state_k0
I2_county = x_hat_k0[3, 0]
rho1 = prop_ma7.loc['2020-04-12'] / I2_county
rho2 = case_ma7_all.loc['2020-04-12'] / I2_county


# approximate rho values
# I2_county = b * I2
# prop_county = rho1 * I2_county
# case_county = rho2 * I2_county


# create empty dictionaries to hold the estimated values
prop_est = {}
case_est = {}
seiir_pred = {}



# Original data run ----------------
start = K0
d = start
while d <= datetime.date(2020, 9, 30):
    # each cycle of the while loop executes a step

    # get state level compartments
    x_hat_state_k, beta_k = get_predicts_prior(d, seiir)

    # step the state level compartments 7 days forward
    x_hat_state_k1 = step_seiir(x_hat_state_k, constants, beta_k)

    # convert the state level compartments to county level values
    x_hat_k = b * x_hat_state_k
    x_hat_k1 = b * x_hat_state_k1

    # get measurements
    z_k = np.array([[0],
                    [0],
                    [prop_ma7.loc[d]],
                    [case_ma7_all.loc[d]],
                    [0]])

    # predict step
    P = predict_step(x_hat_k1, P, Q, beta_k, constants)
    print('step -----------')
    print('P:')
    print(P)

    # update step
    x_hat_post, P_post = update_step(x_hat_k, x_hat_k1, P, Rn,
                                     rho1, rho2, z_k)
    print('P_post:')
    print(P_post)

    # store estimated values for proportion and case rate
    indexDate = d + datetime.timedelta(days=7)
    prop_est[indexDate] = rho1 * x_hat_post[3, 0]
    case_est[indexDate] = rho2 * x_hat_post[3, 0]
    seiir_pred[indexDate] = b * x_hat_k1[3, 0]

    # update the P and d
    P = P_post
    d += datetime.timedelta(days=1)
    
# create pandas series of the estimated case rate
predicted_case = pd.Series(case_est)
predicted_seiir_prior = pd.Series(seiir_pred)

step -----------
P:
[[ 1.15255384 -0.15255384  0.13808601  0.04479925 -0.18288526]
 [-0.15255384  1.22364807 -0.20918024 -0.04479925  0.18288526]
 [ 0.13808601 -0.20918024  1.32109422 -0.25        0.        ]
 [ 0.04479925 -0.04479925 -0.25        1.68852871 -0.43852871]
 [-0.18288526  0.18288526  0.         -0.43852871  1.43852871]]
P_post:
[[-6.66133815e-16  5.55111512e-17  1.66533454e-16 -1.38083989e-15
  -4.71844785e-16]
 [ 8.32667268e-17  0.00000000e+00  8.32667268e-17 -5.41233725e-16
   1.94289029e-16]
 [ 1.38777878e-16  8.32667268e-17  1.23341048e+00 -4.44726865e-05
   1.97640969e-17]
 [-1.31145095e-15 -5.82867088e-16 -4.44884238e-05  2.79928156e-04
  -7.77156117e-16]
 [-4.71844785e-16  2.22044605e-16  3.35510485e-17 -8.32667268e-16
   0.00000000e+00]]
step -----------
P:
[[ 1.09293092e+00 -9.29309201e-02  1.69260125e-01 -1.69217329e-01
  -4.27957342e-05]
 [-9.29309201e-02  1.09293092e+00 -1.69260125e-01  1.69217329e-01
   4.27957342e-05]
 [ 1.69260127e-01 -1.69260127e-01  1.308

step -----------
P:
[[ 1.08598406e+00 -8.59840602e-02  1.60886637e-01 -1.60847234e-01
  -3.94029047e-05]
 [-8.59840601e-02  1.08598406e+00 -1.60886636e-01  1.60847233e-01
   3.94029050e-05]
 [ 1.60886639e-01 -1.60886639e-01  1.30110844e+00 -3.01127384e-01
   1.89411003e-05]
 [-1.60847240e-01  1.60847240e-01 -3.01127391e-01  1.30126908e+00
  -1.41692402e-04]
 [-3.93991710e-05  3.93991710e-05  1.89480881e-05 -1.41699389e-04
   1.00012275e+00]]
P_post:
[[ 2.22044605e-16 -8.32667268e-17 -1.38459169e-09  1.46188617e-12
   4.06575815e-20]
 [-9.71445147e-17 -6.66133815e-16  1.52225407e-09 -1.50110480e-12
  -6.77626358e-21]
 [ 1.38777878e-16  2.77555756e-17  1.20442151e+00 -5.72018617e-05
   1.35525272e-20]
 [-5.82867088e-16  0.00000000e+00 -5.72229655e-05  2.79916226e-04
  -2.71050543e-20]
 [ 2.71050543e-20 -1.35525272e-20  4.74411760e-09 -1.32761266e-12
  -2.22044605e-16]]
step -----------
P:
[[ 1.08575773e+00 -8.57577308e-02  1.60673934e-01 -1.60634582e-01
  -3.93518176e-05]
 [-8.57577307e-

step -----------
P:
[[ 1.08650483e+00 -8.65048326e-02  1.61360803e-01 -1.61321269e-01
  -3.95341819e-05]
 [-8.65048325e-02  1.08650483e+00 -1.61360803e-01  1.61321269e-01
   3.95341822e-05]
 [ 1.61360806e-01 -1.61360806e-01  1.30106251e+00 -3.01081433e-01
   1.89240797e-05]
 [-1.61321276e-01  1.61321276e-01 -3.01081440e-01  1.30122312e+00
  -1.41675375e-04]
 [-3.95304388e-05  3.95304388e-05  1.89310635e-05 -1.41682358e-04
   1.00012275e+00]]
P_post:
[[-6.66133815e-16  1.38777878e-16 -1.38835618e-09  1.46532786e-12
  -8.80914265e-20]
 [ 1.11022302e-16 -4.44089210e-16  1.52639867e-09 -1.50515711e-12
  -6.77626358e-21]
 [ 1.66533454e-16  5.55111512e-17  1.20425229e+00 -5.71545155e-05
  -1.01643954e-20]
 [-1.16573418e-15  0.00000000e+00 -5.71756069e-05  2.79916213e-04
   0.00000000e+00]
 [-8.13151629e-20 -6.77626358e-21  4.73814869e-09 -1.32594234e-12
  -2.22044605e-16]]
step -----------
P:
[[ 1.08633932e+00 -8.63393238e-02  1.61206515e-01 -1.61167019e-01
  -3.94961949e-05]
 [-8.63393237e-

P_post:
[[ 4.44089210e-16 -1.11022302e-16 -1.41583062e-09  1.49172341e-12
  -5.42101086e-20]
 [-9.71445147e-17  2.22044605e-16  1.55543892e-09 -1.53185797e-12
  -1.62630326e-19]
 [ 1.66533454e-16  8.32667268e-17  1.20323893e+00 -5.68709945e-05
  -6.77626358e-21]
 [-4.99600361e-16 -3.88578059e-16 -5.68920122e-05  2.79916134e-04
   8.13151629e-20]
 [-5.42101086e-20 -1.69406589e-19  4.70477857e-09 -1.31660457e-12
   0.00000000e+00]]
step -----------
P:
[[ 1.09021288e+00 -9.02128763e-02  1.64713687e-01 -1.64673246e-01
  -4.04408357e-05]
 [-9.02128762e-02  1.09021288e+00 -1.64713686e-01  1.64673246e-01
   4.04408360e-05]
 [ 1.64713689e-01 -1.64713689e-01  1.30080973e+00 -3.00828562e-01
   1.88304125e-05]
 [-1.64673252e-01  1.64673252e-01 -3.00828569e-01  1.30097015e+00
  -1.41581673e-04]
 [-4.04370250e-05  4.04370250e-05  1.88373719e-05 -1.41588632e-04
   1.00012275e+00]]
P_post:
[[-2.22044605e-16 -2.49800181e-16 -1.41891304e-09  1.49516510e-12
   2.71050543e-19]
 [-2.63677968e-16 -4.440892

step -----------
P:
[[ 1.09052989e+00 -9.05298894e-02  1.64978457e-01 -1.64937922e-01
  -4.05358749e-05]
 [-9.05298893e-02  1.09052989e+00 -1.64978457e-01  1.64937921e-01
   4.05358752e-05]
 [ 1.64978460e-01 -1.64978460e-01  1.30072086e+00 -3.00739657e-01
   1.87974777e-05]
 [-1.64937928e-01  1.64937928e-01 -3.00739664e-01  1.30088121e+00
  -1.41548726e-04]
 [-4.05320625e-05  4.05320625e-05  1.88044269e-05 -1.41555675e-04
   1.00012275e+00]]
P_post:
[[ 6.66133815e-16 -3.60822483e-16 -1.41975653e-09  1.49541490e-12
   2.50721752e-19]
 [-3.88578059e-16  2.22044605e-16  1.56031549e-09 -1.53546620e-12
   1.35525272e-20]
 [ 2.49800181e-16 -5.55111512e-17  1.20295925e+00 -5.67927366e-05
   2.37169225e-20]
 [-6.93889390e-16  2.77555756e-17 -5.68137292e-05  2.79916112e-04
   0.00000000e+00]
 [ 2.37169225e-19  6.77626358e-21  4.69019775e-09 -1.31252412e-12
  -2.22044605e-16]]
step -----------
P:
[[ 1.09088490e+00 -9.08848973e-02  1.65306828e-01 -1.65266218e-01
  -4.06101383e-05]
 [-9.08848971e-

step -----------
P:
[[ 1.09497277e+00 -9.49727704e-02  1.68902435e-01 -1.68860842e-01
  -4.15934263e-05]
 [-9.49727703e-02  1.09497277e+00 -1.68902435e-01  1.68860842e-01
   4.15934265e-05]
 [ 1.68902438e-01 -1.68902438e-01  1.30045115e+00 -3.00469849e-01
   1.86975381e-05]
 [-1.68860849e-01  1.68860849e-01 -3.00469856e-01  1.30061130e+00
  -1.41448749e-04]
 [-4.15895339e-05  4.15895339e-05  1.87044620e-05 -1.41455673e-04
   1.00012275e+00]]
P_post:
[[ 8.88178420e-16 -4.57966998e-16 -1.45641540e-09  1.52858282e-12
   0.00000000e+00]
 [-4.57966998e-16 -2.22044605e-16  1.59900054e-09 -1.56899493e-12
   1.35525272e-20]
 [ 4.44089210e-16 -2.77555756e-16  1.20160506e+00 -5.64138587e-05
  -2.03287907e-20]
 [-1.55431223e-15  1.08246745e-15 -5.64347536e-05  2.79916006e-04
   8.13151629e-20]
 [ 0.00000000e+00  0.00000000e+00  4.64604447e-09 -1.30016898e-12
   2.22044605e-16]]
step -----------
P:
[[ 1.09548020e+00 -9.54802028e-02  1.69338987e-01 -1.69297269e-01
  -4.17182725e-05]
 [-9.54802027e-

step -----------
P:
[[ 1.09928897e+00 -9.92889660e-02  1.72593705e-01 -1.72551074e-01
  -4.26308451e-05]
 [-9.92889659e-02  1.09928897e+00 -1.72593705e-01  1.72551074e-01
   4.26308454e-05]
 [ 1.72593708e-01 -1.72593708e-01  1.30008908e+00 -3.00107644e-01
   1.85633687e-05]
 [-1.72551081e-01  1.72551081e-01 -3.00107651e-01  1.30024897e+00
  -1.41314530e-04]
 [-4.26268838e-05  4.26268838e-05  1.85702562e-05 -1.41321417e-04
   1.00012275e+00]]
P_post:
[[ 0.00000000e+00 -2.49800181e-16 -1.48869753e-09  1.55839230e-12
  -6.09863722e-20]
 [-2.35922393e-16 -4.44089210e-16  1.63379768e-09 -1.60088609e-12
  -6.77626358e-21]
 [ 3.88578059e-16  2.77555756e-17  1.20025101e+00 -5.60350105e-05
  -3.04931861e-20]
 [-1.63757896e-15 -1.94289029e-16 -5.60558028e-05  2.79915900e-04
   1.08420217e-19]
 [-8.13151629e-20 -1.35525272e-20  4.59612347e-09 -1.28619953e-12
  -4.44089210e-16]]
step -----------
P:
[[ 1.09969763e+00 -9.96976278e-02  1.72940939e-01 -1.72898213e-01
  -4.27259816e-05]
 [-9.96976277e-

step -----------
P:
[[ 1.10408249e+00 -1.04082489e-01  1.76589705e-01 -1.76545937e-01
  -4.37675111e-05]
 [-1.04082489e-01  1.10408249e+00 -1.76589704e-01  1.76545937e-01
   4.37675114e-05]
 [ 1.76589707e-01 -1.76589707e-01  1.29967775e+00 -2.99696166e-01
   1.84109481e-05]
 [-1.76545944e-01  1.76545944e-01 -2.99696173e-01  1.29983733e+00
  -1.41162053e-04]
 [-4.37634765e-05  4.37634765e-05  1.84177949e-05 -1.41168899e-04
   1.00012275e+00]]
P_post:
[[ 0.00000000e+00 -4.16333634e-16 -1.52341101e-09  1.59125491e-12
   5.42101086e-20]
 [-4.02455846e-16  0.00000000e+00  1.67138686e-09 -1.63352665e-12
   2.71050543e-20]
 [ 2.22044605e-16 -5.55111512e-17  1.19875973e+00 -5.56177667e-05
  -2.37169225e-20]
 [-8.32667268e-16  2.77555756e-17 -5.56384467e-05  2.79915783e-04
   1.35525272e-19]
 [ 4.74338450e-20  2.03287907e-20  4.54120304e-09 -1.27083096e-12
   0.00000000e+00]]
step -----------
P:
[[ 1.10408441e+00 -1.04084413e-01  1.76594925e-01 -1.76551161e-01
  -4.37643691e-05]
 [-1.04084413e-

step -----------
P:
[[ 1.10694682e+00 -1.06946823e-01  1.78972301e-01 -1.78927905e-01
  -4.43960787e-05]
 [-1.06946823e-01  1.10694682e+00 -1.78972301e-01  1.78927905e-01
   4.43960790e-05]
 [ 1.78972304e-01 -1.78972304e-01  1.29957475e+00 -2.99593126e-01
   1.83727810e-05]
 [-1.78927912e-01  1.78927912e-01 -2.99593133e-01  1.29973426e+00
  -1.41123872e-04]
 [-4.43919941e-05  4.43919941e-05  1.83796181e-05 -1.41130708e-04
   1.00012275e+00]]
P_post:
[[-2.22044605e-16  2.35922393e-16 -1.54722521e-09  1.61401448e-12
  -9.48676901e-20]
 [ 2.22044605e-16  0.00000000e+00  1.69567119e-09 -1.65520375e-12
   0.00000000e+00]
 [-1.38777878e-16  2.77555756e-17  1.19795293e+00 -5.53920468e-05
   6.77626358e-21]
 [ 5.55111512e-16 -8.32667268e-17 -5.54126709e-05  2.79915720e-04
  -1.35525272e-19]
 [-8.80914265e-20  2.03287907e-20  4.51790456e-09 -1.26431222e-12
   0.00000000e+00]]
step -----------
P:
[[ 1.10837245e+00 -1.08372452e-01  1.80135202e-01 -1.80090485e-01
  -4.47167336e-05]
 [-1.08372452e-

step -----------
P:
[[ 1.11210564e+00 -1.12105645e-01  1.83060600e-01 -1.83014970e-01
  -4.56297579e-05]
 [-1.12105644e-01  1.11210564e+00 -1.83060600e-01  1.83014970e-01
   4.56297582e-05]
 [ 1.83060603e-01 -1.83060603e-01  1.29899500e+00 -2.99013162e-01
   1.81579478e-05]
 [-1.83014977e-01  1.83014977e-01 -2.99013169e-01  1.29915408e+00
  -1.40908959e-04]
 [-4.56256071e-05  4.56256071e-05  1.81647274e-05 -1.40915738e-04
   1.00012275e+00]]
P_post:
[[ 0.00000000e+00  6.93889390e-17 -1.57943811e-09  1.64443459e-12
  -1.15196481e-19]
 [ 6.93889390e-17  2.22044605e-16  1.73190581e-09 -1.68559611e-12
  -1.35525272e-20]
 [-3.05311332e-16 -1.11022302e-16  1.19630693e+00 -5.49314994e-05
   1.35525272e-20]
 [ 1.22124533e-15  3.05311332e-16 -5.49519957e-05  2.79915591e-04
  -1.08420217e-19]
 [-1.15196481e-19 -2.03287907e-20  4.45129718e-09 -1.24567267e-12
   0.00000000e+00]]
step -----------
P:
[[ 1.11107698e+00 -1.11076978e-01  1.82243703e-01 -1.82198308e-01
  -4.53952782e-05]
 [-1.11076978e-

step -----------
P:
[[ 1.10515643e+00 -1.05156430e-01  1.77477388e-01 -1.77433375e-01
  -4.40135103e-05]
 [-1.05156430e-01  1.10515643e+00 -1.77477388e-01  1.77433375e-01
   4.40135105e-05]
 [ 1.77477391e-01 -1.77477391e-01  1.29960680e+00 -2.99625184e-01
   1.83846527e-05]
 [-1.77433382e-01  1.77433382e-01 -2.99625191e-01  1.29976633e+00
  -1.41135748e-04]
 [-4.40094591e-05  4.40094591e-05  1.83914916e-05 -1.41142586e-04
   1.00012275e+00]]
P_post:
[[ 0.00000000e+00 -2.77555756e-17 -1.53143564e-09  1.59844360e-12
   1.35525272e-19]
 [-2.77555756e-17  4.44089210e-16  1.67965003e-09 -1.64018799e-12
  -1.62630326e-19]
 [ 1.38777878e-16 -1.38777878e-16  1.19843936e+00 -5.55281334e-05
   3.38813179e-20]
 [-7.49400542e-16  5.27355937e-16 -5.55487890e-05  2.79915758e-04
   2.71050543e-20]
 [ 1.28749008e-19 -1.62630326e-19  4.52971575e-09 -1.26761674e-12
   0.00000000e+00]]
step -----------
P:
[[ 1.10505927e+00 -1.05059267e-01  1.77396277e-01 -1.77352285e-01
  -4.39922805e-05]
 [-1.05059267e-

step -----------
P:
[[ 1.11353323e+00 -1.13533231e-01  1.84251914e-01 -1.84206024e-01
  -4.58902431e-05]
 [-1.13533231e-01  1.11353323e+00 -1.84251914e-01  1.84206024e-01
   4.58902434e-05]
 [ 1.84251917e-01 -1.84251917e-01  1.29909051e+00 -2.99108705e-01
   1.81933434e-05]
 [-1.84206031e-01  1.84206031e-01 -2.99108712e-01  1.29924966e+00
  -1.40944368e-04]
 [-4.58860599e-05  4.58860599e-05  1.82001339e-05 -1.40951158e-04
   1.00012275e+00]]
P_post:
[[ 2.22044605e-16 -3.74700271e-16 -1.59474967e-09  1.65595315e-12
   2.30392962e-19]
 [-3.74700271e-16  2.22044605e-16  1.74643988e-09 -1.69922409e-12
  -6.77626358e-21]
 [ 1.94289029e-16 -8.32667268e-17  1.19599000e+00 -5.48428460e-05
   0.00000000e+00]
 [-8.88178420e-16  5.55111512e-17 -5.48633268e-05  2.79915566e-04
   1.35525272e-19]
 [ 2.23616698e-19 -6.77626358e-21  4.44984224e-09 -1.24526623e-12
   0.00000000e+00]]
step -----------
P:
[[ 1.11448941e+00 -1.14489410e-01  1.84997395e-01 -1.84951284e-01
  -4.61115645e-05]
 [-1.14489410e-

step -----------
P:
[[ 1.12522930e+00 -1.25229297e-01  1.93243202e-01 -1.93194742e-01
  -4.84603730e-05]
 [-1.25229297e-01  1.12522930e+00 -1.93243202e-01  1.93194742e-01
   4.84603733e-05]
 [ 1.93243205e-01 -1.93243205e-01  1.29826645e+00 -2.98284337e-01
   1.78879814e-05]
 [-1.93194749e-01  1.93194749e-01 -2.98284344e-01  1.29842498e+00
  -1.40638892e-04]
 [-4.84560248e-05  4.84560248e-05  1.78946928e-05 -1.40645603e-04
   1.00012275e+00]]
P_post:
[[ 4.44089210e-16 -1.66533454e-16 -1.67622022e-09  1.72994952e-12
  -1.35525272e-19]
 [-1.66533454e-16  0.00000000e+00  1.83289398e-09 -1.77388659e-12
   1.96511644e-19]
 [ 2.77555756e-17  5.55111512e-17  1.19260223e+00 -5.38949992e-05
  -3.38813179e-20]
 [ 1.38777878e-16 -1.66533454e-16 -5.39152348e-05  2.79915301e-04
  -8.13151629e-20]
 [-1.35525272e-19  1.82959117e-19  4.33386840e-09 -1.21281395e-12
   0.00000000e+00]]
step -----------
P:
[[ 1.12714169e+00 -1.27141690e-01  1.94675291e-01 -1.94626425e-01
  -4.88665086e-05]
 [-1.27141690e-

step -----------
P:
[[ 1.13173723e+00 -1.31737228e-01  1.98008177e-01 -1.97958282e-01
  -4.98947384e-05]
 [-1.31737228e-01  1.13173723e+00 -1.98008177e-01  1.97958282e-01
   4.98947387e-05]
 [ 1.98008180e-01 -1.98008180e-01  1.29768697e+00 -2.97704643e-01
   1.76732492e-05]
 [-1.97958290e-01  1.97958290e-01 -2.97704649e-01  1.29784507e+00
  -1.40424081e-04]
 [-4.98903119e-05  4.98903119e-05  1.76799040e-05 -1.40430735e-04
   1.00012275e+00]]
P_post:
[[ 8.88178420e-16 -1.94289029e-16 -1.71611939e-09  1.76467174e-12
  -1.15196481e-19]
 [-1.38777878e-16  2.22044605e-16  1.87620461e-09 -1.81057946e-12
  -8.13151629e-20]
 [ 2.49800181e-16  5.55111512e-17  1.19069693e+00 -5.33619140e-05
   3.38813179e-21]
 [-6.93889390e-16 -3.60822483e-16 -5.33820083e-05  2.79915152e-04
   0.00000000e+00]
 [-1.08420217e-19 -7.45388994e-20  4.26317226e-09 -1.19303070e-12
   0.00000000e+00]]
step -----------
P:
[[ 1.13140376e+00 -1.31403758e-01  1.97753174e-01 -1.97703339e-01
  -4.98357483e-05]
 [-1.31403758e-

step -----------
P:
[[ 1.12949368e+00 -1.29493679e-01  1.96386462e-01 -1.96337065e-01
  -4.93969882e-05]
 [-1.29493679e-01  1.12949368e+00 -1.96386462e-01  1.96337065e-01
   4.93969885e-05]
 [ 1.96386465e-01 -1.96386465e-01  1.29790414e+00 -2.97921896e-01
   1.77537231e-05]
 [-1.96337073e-01  1.96337073e-01 -2.97921903e-01  1.29806241e+00
  -1.40504584e-04]
 [-4.93925878e-05  4.93925878e-05  1.77603982e-05 -1.40511259e-04
   1.00012275e+00]]
P_post:
[[ 2.22044605e-16 -1.11022302e-16 -1.70293132e-09  1.75312542e-12
   5.42101086e-20]
 [-1.11022302e-16  4.44089210e-16  1.86153806e-09 -1.79786741e-12
   1.35525272e-20]
 [ 8.32667268e-17 -2.77555756e-17  1.19135936e+00 -5.35472585e-05
  -1.01643954e-20]
 [-2.49800181e-16  0.00000000e+00 -5.35674014e-05  2.79915204e-04
   1.35525272e-19]
 [ 6.77626358e-20  2.03287907e-20  4.28785349e-09 -1.19993739e-12
  -2.22044605e-16]]
step -----------
P:
[[ 1.13002670e+00 -1.30026702e-01  1.96768986e-01 -1.96719466e-01
  -4.95196231e-05]
 [-1.30026701e-

P_post:
[[ 4.44089210e-16 -1.11022302e-16 -1.74082154e-09  1.78596027e-12
   1.49077799e-19]
 [-8.32667268e-17  6.66133815e-16  1.90173055e-09 -1.83070226e-12
  -5.42101086e-20]
 [ 2.49800181e-16 -4.44089210e-16  1.18976398e+00 -5.31008962e-05
  -6.77626358e-21]
 [-1.08246745e-15  1.41553436e-15 -5.31209263e-05  2.79915079e-04
   1.62630326e-19]
 [ 1.49077799e-19 -6.09863722e-20  4.23477588e-09 -1.18508485e-12
   0.00000000e+00]]
step -----------
P:
[[ 1.13529887e+00 -1.35298866e-01  2.00584052e-01 -2.00533405e-01
  -5.06470544e-05]
 [-1.35298866e-01  1.13529887e+00 -2.00584052e-01  2.00533405e-01
   5.06470547e-05]
 [ 2.00584056e-01 -2.00584056e-01  1.29744099e+00 -2.97458577e-01
   1.75821047e-05]
 [-2.00533413e-01  2.00533413e-01 -2.97458584e-01  1.29759892e+00
  -1.40332902e-04]
 [-5.06425817e-05  5.06425817e-05  1.75887371e-05 -1.40339534e-04
   1.00012275e+00]]
P_post:
[[ 4.44089210e-16  8.32667268e-17 -1.73940790e-09  1.78593251e-12
   6.77626358e-21]
 [ 5.55111512e-17 -2.220446

In [None]:
# MASE calculation -----------------

left = predicted_case.index[0]
right = predicted_case.index[-1]

# take error between prediction and actual
e = (case_ma7_all.loc[left:right] - predicted_case.loc[left:right]).abs()

case_ma7_diff = case_ma7_all.diff(7)
denom = (case_ma7_diff.loc[left:right]).abs()

term = e.div(denom)
term.replace([np.inf, -np.inf], np.nan, inplace=True)
term.dropna(inplace=True)
mase = term.sum() / len(term)
print('mase:', mase)

In [None]:
# Not currently used ---------------
# Validation data run ---------------------
start = prop_ma7_valid.index[0]
d = start
while d <= prop_ma7_valid.index[-1]:
    # each cycle of the while loop executes a step

    # get state level compartments
    x_hat_state_k, beta_k = get_predicts_prior(d, seiir)

    # step the state level compartments 7 days forward
    x_hat_state_k1 = step_seiir(x_hat_state_k, constants, beta_k)

    # convert the state level compartments to county level values
    x_hat_k = b * x_hat_state_k
    x_hat_k1 = b * x_hat_state_k1

    # get measurements
    z_k = np.array([[0],
                    [0],
                    [prop_ma7_valid.loc[d]],
                    [case_ma7_all.loc[d]],
                    [0]])

    # predict step
    P = predict_step(x_hat_k1, P, Q, beta_k, constants)
    print('step -----------')
    print('P:')
    print(P)

    # update step
    x_hat_post, P_post = update_step(x_hat_k, x_hat_k1, P, Rn,
                                     rho1, rho2, z_k)
    print('P_post:')
    print(P_post)

    # store estimated values for proportion and case rate
    indexDate = d + datetime.timedelta(days=7)
    prop_est[indexDate] = rho1 * x_hat_post[3, 0]
    case_est[indexDate] = rho2 * x_hat_post[3, 0]
    seiir_pred[indexDate] = b * x_hat_k1[3, 0]

    # update the P and d
    P = P_post
    d += datetime.timedelta(days=1)
    
# create pandas series of the estimated case rate
predicted_case = pd.Series(case_est)
predicted_seiir_prior = pd.Series(seiir_pred)

In [None]:
# Plotting constants and variables ----------------
matplotlib.rcParams.update({'font.size': 20})
plt.style.use('seaborn-whitegrid')
purple = '#33016F'
gold = '#9E7A27'
gray = '#797979'
width = 4
%matplotlib qt

tick_end = predicted_seiir_prior.index[-1]

week_interval = pd.date_range(start=start, end=tick_end, freq='W')
week_interval = [x.to_pydatetime().date() for x in week_interval]

In [None]:
# Single plot with all lines

matplotlib.rcParams.update({'font.size': 20})
plt.style.use('seaborn-whitegrid')
tick_start = K0
tick_end = predicted_seiir_prior.index[-1]



fig3, ax31 = plt.subplots(1)
plt.sca(ax31)
plt.plot(case_ma7_all.loc[start:d].index, case_ma7_all.loc[start:d], label='Case Positive Rate', c=gold,
         linewidth=width)
plt.plot(predicted_seiir_prior.index, predicted_seiir_prior,
         label='IHME 7-day Forecast', c=gray, linewidth=width)
plt.plot(predicted_case.index, predicted_case, label='Our 7-Day Forecast',
         c=purple, linewidth=width)
#plt.ylim(-2, 72)
plt.xticks(week_interval, rotation=30, ha='right', rotation_mode='anchor')
# plt.xlabel('Date')
plt.ylabel('Number of Cases per Day')
plt.legend(loc='upper left')

ax32 = ax31.twinx()
plt.sca(ax32)
plt.plot(prop_ma7.index, num_stl_ma7, c='red', label='Facebook Positive Symptoms',
         linewidth=width)
#plt.ylim(-.065, 2.5)
plt.grid(axis='y', linestyle=':')
plt.xticks(week_interval, rotation=30, ha='right', rotation_mode='anchor')

plt.ylabel('Number of Positive Symptom Response per Day')
plt.legend(loc='upper right')

plt.show()

In [None]:
# plot validation data

fig_v, ax_v = plt.subplots(1)
plt.sca(ax_v)


plt.plot(case_ma7_all.loc[start:end].index, case_ma7_all.loc[start:end], label='Case Positive Rate', c=gold,
         linewidth=width)
plt.plot(predicted_seiir_prior.index, predicted_seiir_prior,
         label='IHME 7-day Forecast', c=gray, linewidth=width)
plt.plot(predicted_case.index, predicted_case, label='Our 7-Day Forecast',
         c=purple, linewidth=width)
#plt.ylim(-2, 72)
plt.xticks(week_interval, rotation=30, ha='right', rotation_mode='anchor')
# plt.xlabel('Date')
plt.ylabel('Number of Cases per Day')
plt.legend(loc='upper left')

ax32 = ax_v.twinx()
plt.sca(ax32)
plt.plot(prop_ma7_valid.index, 100*prop_ma7_valid, c='red', label='Facebook Symptom Rate',
         linewidth=width)
plt.ylim(-.065, 2.5)
plt.grid(axis='y', linestyle=':')
plt.xticks(week_interval, rotation=30, ha='right', rotation_mode='anchor')

plt.ylabel('Percentage of Positive Symptom Response')
plt.legend(loc='upper right')

plt.show()

In [4]:
# compute squared error for the overlapping time period -------------------------


overlap = pd.date_range(start='2020-04-19', end='2020-07-07')
# convert to needed data type:
overlap_dt = [x.to_pydatetime().date() for x in overlap]

# error between the measured case rate (smoothed) and the
# SEIIR model scaled to the county (without any Kalman adjustment)
seiir_sq_err = 0

# error between the measured case rate (smoothed) and the predicted
# output of the Kalman filter
kalman_sq_err = 0

for day in overlap_dt:
    seiir_sq_err += (case_ma7_all[day] - predicted_seiir_prior[day])**2
    kalman_sq_err += (predicted_case[day] - case_ma7_all[day])**2

seiir_mse = seiir_sq_err / len(overlap_dt)
kalman_mse = kalman_sq_err / len(overlap_dt)

print('SSE between seiir forecast and case rate:', seiir_sq_err)
print('MSE between seiir forecast and measured case rate:', seiir_mse)
print()
print('SSE between kalman forecast and case rate:', kalman_sq_err)
print('MSE between kalman forecast and measured case rate:', kalman_mse)

SSE between seiir forecast and case rate: 24321.005802395495
MSE between seiir forecast and measured case rate: 304.01257252994367

SSE between kalman forecast and case rate: 10118.27370247036
MSE between kalman forecast and measured case rate: 126.4784212808795


In [None]:
# plot findings -- multiple plots
matplotlib.rcParams.update({'font.size': 20})
plt.style.use('seaborn-whitegrid')
tick_start = K0
tick_end = predicted_seiir_prior.index[-1]

week_interval = pd.date_range(start=tick_start, end=tick_end, freq='W')
week_interval = [x.to_pydatetime().date() for x in week_interval]

fig1, ax11 = plt.subplots(1)
plt.sca(ax11)
width = 4
plt.plot(case_ma7_all.loc[start:d].index, case_ma7_all.loc[start:d], label='Case Positive Rate', c=gold,
         linewidth=width)
plt.plot(predicted_seiir_prior.index, predicted_seiir_prior,
         label='IHME 7-day Forecast', c=gray, linewidth=width)
plt.ylim(-2, 72)
plt.xticks(week_interval, rotation=30, ha='right', rotation_mode='anchor')
plt.ylabel('Number of Cases per Day')
# plt.xlabel('Date')
plt.legend(loc='upper left')

fig2, ax21 = plt.subplots(1)
plt.sca(ax21)
plt.plot(case_ma7_all.loc[start:d].index, case_ma7_all.loc[start:d], label='Case Positive Rate', c=gold,
         linewidth=width)
plt.plot(predicted_seiir_prior.index, predicted_seiir_prior,
         label='IHME 7-day Forecast', c=gray, linewidth=width)
plt.plot(predicted_case.index, predicted_case, label='Our 7-Day Forecast',
         c=purple, linewidth=width)
plt.ylim(-2, 72)
plt.xticks(week_interval, rotation=30, ha='right', rotation_mode='anchor')
# plt.xlabel('Date')
plt.ylabel('Number of Cases per Day')
plt.legend(loc='upper left')

fig3, ax31 = plt.subplots(1)
plt.sca(ax31)
plt.plot(case_ma7_all.loc[start:d].index, case_ma7_all.loc[start:d], label='Case Positive Rate', c=gold,
         linewidth=width)
plt.plot(predicted_seiir_prior.index, predicted_seiir_prior,
         label='IHME 7-day Forecast', c=gray, linewidth=width)
plt.plot(predicted_case.index, predicted_case, label='Our 7-Day Forecast',
         c=purple, linewidth=width)
plt.ylim(-2, 72)
plt.xticks(week_interval, rotation=30, ha='right', rotation_mode='anchor')
# plt.xlabel('Date')
plt.ylabel('Number of Cases per Day')
plt.legend(loc='upper left')

ax32 = ax31.twinx()
plt.sca(ax32)
plt.plot(prop_ma7.index, 100*prop_ma7, c='red', label='Facebook Symptom Rate',
         linewidth=width)
plt.ylim(-.065, 2.5)
plt.grid(axis='y', linestyle=':')
plt.xticks(week_interval, rotation=30, ha='right', rotation_mode='anchor')

plt.ylabel('Percentage of Positive Symptom Response')
plt.legend(loc='upper right')


plt.show()