Runs an extended Kalman filter on Prof Flaxman's SEIIR predictions and
measurement data for New York State from mid-April to 7 July 2020.

S - Susceptible
E - Exposed
I1 - Presymptomatic
I2 - Symptomatic
R - Recovered

In [None]:
#%load_ext autoreload
#%autoreload

In [None]:
#%reset

In [1]:
import math
import datetime

import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt

import data_sets
import seiir_compartmental

In [2]:
# functions to support the Kalman filtering
def get_predicts_prior(day, seiir):
    x_hat = np.array([[seiir['S'].loc[day]],
                      [seiir['E'].loc[day]],
                      [seiir['I1'].loc[day]],
                      [seiir['I2'].loc[day]],
                      [seiir['R'].loc[day]]])

    beta_k = seiir['beta_pred'].loc[day]

    return x_hat, beta_k


def step_seiir(x_hat, constants, beta_k, days=7):
    s_dict = {'S': x_hat[0, 0],
              'E': x_hat[1, 0],
              'I1': x_hat[2, 0],
              'I2': x_hat[3, 0],
              'R': x_hat[4, 0]}

    s = pd.Series(s_dict)

    for i in range(days):
        infectious = s.loc['I1'] + s.loc['I2']
        s = seiir_compartmental.compartmental_covid_step(s, s.sum(),
                                                         infectious,
                                                         constants['alpha'],
                                                         beta_k,
                                                         constants['gamma1'],
                                                         constants['gamma2'],
                                                         constants['sigma'],
                                                         constants['theta'])
    x_hat_future_prior = np.array([[s.loc['S']],
                                   [s.loc['E']],
                                   [s.loc['I1']],
                                   [s.loc['I2']],
                                   [s.loc['R']]])

    return x_hat_future_prior


def predict_step(x_hat_k1_prior, P, Q, beta_k, constants):
    S = x_hat_k1_prior[0, 0]
    E = x_hat_k1_prior[1, 0]
    I1 = x_hat_k1_prior[2, 0]
    I2 = x_hat_k1_prior[3, 0]
    R = x_hat_k1_prior[4, 0]
    N = S + E + I1 + I2 + R
    alpha = constants['alpha']
    sigma = constants['sigma']
    gamma1 = constants['gamma1']
    gamma2 = constants['gamma2']

    part_f_S = np.array([[-beta_k * math.pow(I1 + I2, alpha) / N],
                         [beta_k * math.pow(I1 + I2, alpha) / N],
                         [0],
                         [0],
                         [0]])

    part_f_E = np.array([[0],
                         [-sigma],
                         [sigma],
                         [0],
                         [0]])

    part_f_I1 = np.array([[-alpha * beta_k * S * math.pow(I1+I2, alpha-1) / N],
                          [alpha * beta_k * S * math.pow(I1+I2, alpha-1) / N],
                          [-gamma1],
                          [gamma1],
                          [0]])

    part_f_I2 = np.array([[-alpha * beta_k * S * math.pow(I1+I2, alpha-1) / N],
                          [alpha * beta_k * S * math.pow(I1+I2, alpha-1) / N],
                          [0],
                          [-gamma2],
                          [gamma2]])

    part_f_R = np.array([[0],
                         [0],
                         [0],
                         [0],
                         [0]])

    # 5x5
    f_jacob = np.concatenate([part_f_S, part_f_E, part_f_I1, part_f_I2,
                              part_f_R], axis=1)

    # 5x5
    # P_k1_prior = f_jacob * P * f_jacob^T + Q
    P_k1_prior = np.matmul(np.matmul(f_jacob, P), np.transpose(f_jacob)) + Q
    return P_k1_prior


def update_step(x_hat, x_hat_k1, P_k1, Rn, rho1, rho2, z_k):
    # 5x5
    ep = 10**-10
    H = np.array([[ep, 0, 0, 0, 0],
                  [0, ep, 0, 0, 0],
                  [0, 0, ep, 0, 0],
                  [0, 0, 0, rho2, 0],
                  [0, 0, 0, 0, ep]])
    #H = np.array([[0, 0, 0, 0, 0],
    #                  [0, 0, 0, 0, 0],
    #                  [0, 0, 0, rho1, 0],
    #                  [0, 0, 0, rho2, 0],
    #                  [0, 0, 0, 0, 0]])
    
    # Si = H * P_k1 * H^T + Rn
    Si = np.matmul(np.matmul(H, P_k1), np.transpose(H)) + Rn

    # K_new = P_k1 * H^T * Si^(-1)
    K_new = np.matmul(np.matmul(P_k1, np.transpose(H)), np.linalg.inv(Si))
    y_new = np.matmul(H, x_hat)

    # 5x1
    diff = z_k - y_new

    x_hat_k1_post = x_hat_k1 + np.matmul(K_new, diff)

    P_k1_post = P_k1 - np.matmul(np.matmul(K_new, Si), np.transpose(K_new))

    # joseph formulation
    #joe2 = np.matmul(np.matmul(K_new, Rn), np.transpose(K_new))
    #joe1 = np.eye(5) - np.matmul(K_new, H)
    #P_k1_post = np.matmul(np.matmul(joe1, P_k1), np.transpose(joe1)) - joe2

    return x_hat_k1_post, P_k1_post

def get_predict_measures():
    predictions = pd.read_csv(
        r'case_vs_symptom/kalman/predicted_measures/predictions.csv',
        header=None, names=['measure_date', 'prediction_date', 'prop', 'case',
                            'prop_pred7', 'case_pred7', 'prop_pred1',
                            'case_pred1'],
        index_col=False, usecols=[0, 4, 5])

    start_d = datetime.timedelta(days=int(predictions['measure_date'].iloc[0]))
    start = datetime.date(2019, 12, 31) + start_d
    end_d = datetime.timedelta(days=int(predictions['measure_date'].iloc[-1]))
    end = datetime.date(2019, 12, 31) + end_d

    measure_rng = pd.date_range(start=start, end=end)

    # code to save:
    # prediction_rng = pd.date_range(start=datetime.date(2019, 12, 31) +
    # datetime.timedelta(days=int(predictions['prediction_date'].iloc[0])),
    # end=datetime.date(2019, 12, 31) +
    # datetime.timedelta(days=int(predictions['prediction_date'].iloc[-1])))

    predictions.set_index(measure_rng, inplace=True)

    return predictions


def get_data_sets(state, fips=None):
    # the post hoc seiir model predictions provided by Prof Flaxman without
    # any kalman filtering:
    #seiir = pd.read_csv(r'data/seiir_compartments_post-hoc_ny_state_20200910.csv', header=0,
    #                    index_col='date', parse_dates=True)
    seiir = pd.read_csv(r'data/seiir_compartments_post-hoc_ny_state.csv', header=0,
                        index_col='date', parse_dates=True)
    

    fb_data = data_sets.create_symptom_df()
    fb_data_val = data_sets.create_symptom_df(valid=True)

    if fips is None:
        case_data = data_sets.create_case_df_state()
        case_data_geo = case_data.loc[state]['case_rate'].copy()
        fb_data_geo = fb_data.loc[state].groupby('date').sum().copy()
        fb_data_val_geo = fb_data_val.loc[state].groupby('date').sum().copy()

    else:
        case_data = data_sets.create_case_df_county()
        case_data_geo = case_data.loc[fips]['case_rate'].copy()
        
        if fips == 'New York City':
            nyc_fips = ['36005', '36061', '36047', '36085']
            fb_data_geo = fb_data.loc[(slice(None), '36081'), :].copy()
            fb_data_geo = fb_data_geo.mean(level='date')
            fb_data_val_geo = fb_data_val.loc[(slice(None), '36081'), :].copy()
            fb_data_val_geo = fb_data_val.mean(level='date')
            for borough in nyc_fips:
                fb_data_geo += fb_data.loc[(slice(None), borough), :].copy().mean(level='date')
                fb_data_val_geo += fb_data_val.loc[(slice(None), borough), :].copy().mean(level='date')
        else:
            fb_data_geo = fb_data.loc[(slice(None), fips), :].copy()
            fb_data_val_geo = fb_data_val.loc[(slice(None), fips), :].copy()
            # collapse down to a single index column (date)
            fb_data_geo = fb_data_geo.mean(level='date')
            fb_data_val_geo = fb_data_val_geo.mean(level='date')


    return seiir, fb_data_geo, fb_data_val_geo, case_data_geo


def calc_fb_ma7(fb_data):
    """
    Returns a Pandas series
    """
    # the fb_data is a DataFrame while the case_data is a Series
    fb_ma7 = fb_data.rolling(window=7).mean()
    fb_ma7 = fb_ma7.iloc[6:, :]
    prop_ma7 = fb_ma7['num_stl'].div(fb_ma7['n'])

    return prop_ma7, fb_ma7['n'].copy(), fb_ma7['num_stl'].copy()

In [3]:
# set constants
K0 = datetime.date(2020, 4, 12)

the_state = 'NY'
the_county = data_sets.get_fips(the_state, 'Albany')
#the_county = 'New York City'
constants = {
    'alpha': 0.948786,
    'gamma1': 0.500000,
    'gamma2': 0.662215,
    'sigma': 0.266635,
    'theta': 6.000000
    }

# set initial values for Kalman filter parameters
P_mult = 1
Q_mult = 1

# Rn is the R noise matrix; it remains constant thru the stepping of the
# Kalman filter
# prior to experiments, Rn_mult = 5*10**-4
Rn_mult = 5*10**-8

Rn_22 = 88
Rn_32 = 150

Rn_23 = 0
Rn_33 = 1

Rn = Rn_mult * np.array([[0, 0, 0, 0, 0],
                         [0, 0, 0, 0, 0],
                         [0, 0, Rn_22, 0, 0],
                         [0, 0, Rn_32, Rn_33, 0],
                         [0, 0, 0, 0, 0]])
Q = Q_mult * np.eye(5)
P = P_mult * np.eye(5)

# generate data
seiir, fb_data, fb_data_val, case_data = get_data_sets(
    data_sets.STATES[the_state], fips=the_county)


if the_county == 'New York City':
    county_pop = 0
    for each in ['36081', '36005', '36061', '36047', '36085']:
        this_count, state_pop = data_sets.get_pops(each)
        county_pop += this_count
else:
    county_pop, state_pop = data_sets.get_pops(the_county)

b = county_pop / state_pop

# calculate moving averages on the fb and case data
case_ma7 = case_data.rolling(window=7).mean()
case_ma7_all = case_ma7.iloc[6:]
prop_ma7, n_ma7, num_stl_ma7 = calc_fb_ma7(fb_data)
prop_ma7_valid, n_ma7_valid, num_stl_ma7_valid = calc_fb_ma7(fb_data_val)

# get starting compartment values for the state level
x_hat_state_k0, beta_k0 = get_predicts_prior(K0, seiir)
x_hat_k0 = b * x_hat_state_k0
I2_county = x_hat_k0[3, 0]
rho1 = prop_ma7.loc['2020-04-12'] / I2_county
rho2 = case_ma7_all.loc['2020-04-12'] / I2_county


# approximate rho values
# I2_county = b * I2
# prop_county = rho1 * I2_county
# case_county = rho2 * I2_county


# create empty dictionaries to hold the estimated values
prop_est = {}
case_est = {}
seiir_pred = {}



# Original data run ----------------
start = K0
d = start
while d <= datetime.date(2020, 9, 30):
    # each cycle of the while loop executes a step

    # get state level compartments
    x_hat_state_k, beta_k = get_predicts_prior(d, seiir)

    # step the state level compartments 7 days forward
    x_hat_state_k1 = step_seiir(x_hat_state_k, constants, beta_k)

    # convert the state level compartments to county level values
    x_hat_k = b * x_hat_state_k
    x_hat_k1 = b * x_hat_state_k1

    # get measurements
    z_k = np.array([[0],
                    [0],
                    [0],
                    [case_ma7_all.loc[d]],
                    [0]])

    # predict step
    P = predict_step(x_hat_k1, P, Q, beta_k, constants)
    print('step -----------')
    print('P:')
    print(P)

    # update step
    x_hat_post, P_post = update_step(x_hat_k, x_hat_k1, P, Rn,
                                     rho1, rho2, z_k)
    print('P_post:')
    print(P_post)

    # store estimated values for proportion and case rate
    indexDate = d + datetime.timedelta(days=7)
    prop_est[indexDate] = rho1 * x_hat_post[3, 0]
    case_est[indexDate] = rho2 * x_hat_post[3, 0]
    seiir_pred[indexDate] = b * x_hat_k1[3, 0]

    # update the P and d
    P = P_post
    d += datetime.timedelta(days=1)
    
# create pandas series of the estimated case rate
predicted_case = pd.Series(case_est)
predicted_seiir_prior = pd.Series(seiir_pred)

step -----------
P:
[[ 1.15261946 -0.15261946  0.13811573  0.04480889 -0.18292461]
 [-0.15261946  1.22371368 -0.20920995 -0.04480889  0.18292461]
 [ 0.13811573 -0.20920995  1.32109422 -0.25        0.        ]
 [ 0.04480889 -0.04480889 -0.25        1.68852871 -0.43852871]
 [-0.18292461  0.18292461  0.         -0.43852871  1.43852871]]
P_post:
[[ 0.00000000e+00  3.88578059e-16  9.43689571e-16 -6.66133815e-15
   4.02455846e-15]
 [ 4.16333634e-16  0.00000000e+00 -1.38777878e-16  1.30451205e-15
  -1.11022302e-16]
 [ 8.88178420e-16 -1.38777878e-16  1.23339885e+00 -4.44279673e-05
   2.38889244e-16]
 [-6.64052147e-15  1.34614542e-15 -4.44436887e-05  2.79647662e-04
   1.27675648e-15]
 [ 4.02455846e-15 -1.38777878e-16  2.67024913e-16  1.11022302e-15
  -6.66133815e-16]]
step -----------
P:
[[ 1.09295356e+00 -9.29535624e-02  1.69279965e-01 -1.69237206e-01
  -4.27582931e-05]
 [-9.29535624e-02  1.09295356e+00 -1.69279965e-01  1.69237206e-01
   4.27582931e-05]
 [ 1.69279967e-01 -1.69279967e-01  1.308

P_post:
[[ 4.44089210e-16 -4.16333634e-17 -1.38310077e-09  1.46069268e-12
   6.09863722e-20]
 [-6.93889390e-17 -4.44089210e-16  1.52100463e-09 -1.49641410e-12
  -1.69406589e-19]
 [-1.66533454e-16 -3.60822483e-16  1.20437215e+00 -5.71307502e-05
   2.03287907e-20]
 [ 1.19348975e-15  1.83186799e-15 -5.71518273e-05  2.79635752e-04
   0.00000000e+00]
 [ 6.09863722e-20 -1.62630326e-19  4.73549323e-09 -1.32387146e-12
  -4.44089210e-16]]
step -----------
P:
[[ 1.08594390e+00 -8.59438962e-02  1.60844959e-01 -1.60805601e-01
  -3.93582833e-05]
 [-8.59438961e-02  1.08594390e+00 -1.60844959e-01  1.60805601e-01
   3.93582836e-05]
 [ 1.60844962e-01 -1.60844962e-01  1.30109304e+00 -3.01111954e-01
   1.89164196e-05]
 [-1.60805608e-01  1.60805608e-01 -3.01111961e-01  1.30125351e+00
  -1.41544724e-04]
 [-3.93545550e-05  3.93545550e-05  1.89233986e-05 -1.41551703e-04
   1.00012263e+00]]
P_post:
[[ 6.66133815e-16 -1.24900090e-16 -1.38250078e-09  1.45855550e-12
  -1.55854062e-19]
 [-1.66533454e-16 -2.220446

step -----------
P:
[[ 1.08658531e+00 -8.65853075e-02  1.61434965e-01 -1.61395452e-01
  -3.95138199e-05]
 [-8.65853074e-02  1.08658531e+00 -1.61434965e-01  1.61395451e-01
   3.95138201e-05]
 [ 1.61434968e-01 -1.61434968e-01  1.30105917e+00 -3.01078073e-01
   1.89038830e-05]
 [-1.61395458e-01  1.61395458e-01 -3.01078080e-01  1.30121961e+00
  -1.41532183e-04]
 [-3.95100789e-05  3.95100789e-05  1.89108595e-05 -1.41539159e-04
   1.00012263e+00]]
P_post:
[[-8.88178420e-16  2.91433544e-16 -1.38772888e-09  1.46679890e-12
  -2.64274280e-19]
 [ 3.19189120e-16  0.00000000e+00  1.52566526e-09 -1.50293666e-12
  -1.35525272e-20]
 [-6.66133815e-16  8.32667268e-17  1.20422818e+00 -5.70905116e-05
  -1.01643954e-20]
 [ 2.74780199e-15 -1.38777878e-16 -5.71115802e-05  2.79635741e-04
  -5.42101086e-20]
 [-2.57498016e-19 -6.77626358e-21  4.73272863e-09 -1.32309886e-12
  -2.22044605e-16]]
step -----------
P:
[[ 1.08652125e+00 -8.65212473e-02  1.61374666e-01 -1.61335166e-01
  -3.94997609e-05]
 [-8.65212472e-

step -----------
P:
[[ 1.08832584e+00 -8.83258444e-02  1.63027070e-01 -1.62987139e-01
  -3.99310570e-05]
 [-8.83258443e-02  1.08832584e+00 -1.63027069e-01  1.62987138e-01
   3.99310573e-05]
 [ 1.63027072e-01 -1.63027072e-01  1.30097646e+00 -3.00995330e-01
   1.88732653e-05]
 [-1.62987145e-01  1.62987145e-01 -3.00995338e-01  1.30113684e+00
  -1.41501554e-04]
 [-3.99272824e-05  3.99272824e-05  1.88802340e-05 -1.41508523e-04
   1.00012263e+00]]
P_post:
[[ 2.22044605e-16  6.93889390e-17 -1.40281209e-09  1.47828971e-12
  -7.45388994e-20]
 [ 6.93889390e-17 -6.66133815e-16  1.54130650e-09 -1.51895163e-12
   2.03287907e-19]
 [ 5.55111512e-17  4.44089210e-16  1.20370253e+00 -5.69435960e-05
  -1.69406589e-20]
 [ 2.77555756e-17 -1.85962357e-15 -5.69646272e-05  2.79635700e-04
  -5.42101086e-20]
 [-6.77626358e-20  1.96511644e-19  4.71652846e-09 -1.31857042e-12
   2.22044605e-16]]
step -----------
P:
[[ 1.08911901e+00 -8.91190059e-02  1.63743591e-01 -1.63703467e-01
  -4.01235653e-05]
 [-8.91190058e-

step -----------
P:
[[ 1.09120937e+00 -9.12093688e-02  1.65582599e-01 -1.65541939e-01
  -4.06605933e-05]
 [-9.12093687e-02  1.09120937e+00 -1.65582599e-01  1.65541938e-01
   4.06605936e-05]
 [ 1.65582602e-01 -1.65582602e-01  1.30067057e+00 -3.00689331e-01
   1.87600282e-05]
 [-1.65541945e-01  1.65541945e-01 -3.00689338e-01  1.30083073e+00
  -1.41388275e-04]
 [-4.06567727e-05  4.06567727e-05  1.87669657e-05 -1.41395212e-04
   1.00012263e+00]]
P_post:
[[ 0.00000000e+00  3.60822483e-16 -1.42383960e-09  1.49830148e-12
  -2.03287907e-20]
 [ 3.60822483e-16 -4.44089210e-16  1.56464136e-09 -1.53893565e-12
   1.89735380e-19]
 [-1.11022302e-16  3.88578059e-16  1.20274625e+00 -5.66763008e-05
  -6.77626358e-21]
 [ 1.11022302e-16 -1.41553436e-15 -5.66972567e-05  2.79635625e-04
  -1.08420217e-19]
 [-2.03287907e-20  1.76182853e-19  4.67815139e-09 -1.30784177e-12
   0.00000000e+00]]
step -----------
P:
[[ 1.09096778e+00 -9.09677774e-02  1.65367558e-01 -1.65326956e-01
  -4.06023737e-05]
 [-9.09677773e-

step -----------
P:
[[ 1.09277408e+00 -9.27740796e-02  1.66976186e-01 -1.66935158e-01
  -4.10282484e-05]
 [-9.27740795e-02  1.09277408e+00 -1.66976186e-01  1.66935158e-01
   4.10282487e-05]
 [ 1.66976189e-01 -1.66976189e-01  1.30059616e+00 -3.00614892e-01
   1.87324837e-05]
 [-1.66935165e-01  1.66935165e-01 -3.00614899e-01  1.30075626e+00
  -1.41360720e-04]
 [-4.10243985e-05  4.10243985e-05  1.87394144e-05 -1.41367651e-04
   1.00012263e+00]]
P_post:
[[-4.44089210e-16  3.88578059e-16 -1.43720605e-09  1.51154089e-12
  -5.42101086e-20]
 [ 3.74700271e-16  4.44089210e-16  1.57855615e-09 -1.55028768e-12
   0.00000000e+00]
 [-3.05311332e-16  5.55111512e-17  1.20227943e+00 -5.65458268e-05
  -3.38813179e-21]
 [ 1.05471187e-15 -1.38777878e-16 -5.65667499e-05  2.79635588e-04
  -2.71050543e-20]
 [-5.42101086e-20  6.77626358e-21  4.66397433e-09 -1.30387872e-12
  -4.44089210e-16]]
step -----------
P:
[[ 1.09342804e+00 -9.34280359e-02  1.67556321e-01 -1.67515141e-01
  -4.11798234e-05]
 [-9.34280358e-

step -----------
P:
[[ 1.09834820e+00 -9.83482009e-02  1.71792355e-01 -1.71749987e-01
  -4.23678967e-05]
 [-9.83482008e-02  1.09834820e+00 -1.71792354e-01  1.71749986e-01
   4.23678970e-05]
 [ 1.71792357e-01 -1.71792357e-01  1.30015280e+00 -3.00171369e-01
   1.85683578e-05]
 [-1.71749994e-01  1.71749994e-01 -3.00171376e-01  1.30031257e+00
  -1.41196534e-04]
 [-4.23639550e-05  4.23639550e-05  1.85752446e-05 -1.41203420e-04
   1.00012263e+00]]
P_post:
[[-2.22044605e-16 -4.02455846e-16 -1.47988372e-09  1.55034319e-12
  -4.74338450e-20]
 [-3.88578059e-16 -8.88178420e-16  1.62440900e-09 -1.59069979e-12
  -3.38813179e-20]
 [ 2.77555756e-17  5.55111512e-17  1.20053673e+00 -5.60587289e-05
   6.77626358e-21]
 [ 1.66533454e-16  0.00000000e+00 -5.60795213e-05  2.79635452e-04
  -5.42101086e-20]
 [-5.42101086e-20 -2.71050543e-20  4.60129300e-09 -1.28635622e-12
  -4.44089210e-16]]
step -----------
P:
[[ 1.09876787e+00 -9.87678735e-02  1.72153161e-01 -1.72110697e-01
  -4.24634653e-05]
 [-9.87678734e-

step -----------
P:
[[ 1.10343958e+00 -1.03439584e-01  1.76068396e-01 -1.76024832e-01
  -4.35638251e-05]
 [-1.03439584e-01  1.10343958e+00 -1.76068396e-01  1.76024832e-01
   4.35638253e-05]
 [ 1.76068399e-01 -1.76068399e-01  1.29976253e+00 -2.99780950e-01
   1.84238831e-05]
 [-1.76024839e-01  1.76024839e-01 -2.99780957e-01  1.29992201e+00
  -1.41052006e-04]
 [-4.35598026e-05  4.35598026e-05  1.84307315e-05 -1.41058854e-04
   1.00012263e+00]]
P_post:
[[ 8.88178420e-16 -2.63677968e-16 -1.51802224e-09  1.58367763e-12
   2.37169225e-19]
 [-2.49800181e-16  0.00000000e+00  1.66522388e-09 -1.62631020e-12
   6.77626358e-21]
 [ 4.16333634e-16 -2.77555756e-17  1.19897441e+00 -5.56220496e-05
   1.01643954e-20]
 [-1.72084569e-15  1.11022302e-16 -5.56427260e-05  2.79635330e-04
   5.42101086e-20]
 [ 2.37169225e-19  6.77626358e-21  4.54588661e-09 -1.27086736e-12
  -4.44089210e-16]]
step -----------
P:
[[ 1.10366778e+00 -1.03667777e-01  1.76256933e-01 -1.76213315e-01
  -4.36173448e-05]
 [-1.03667777e-

step -----------
P:
[[ 1.10447685e+00 -1.04476846e-01  1.76924955e-01 -1.76881149e-01
  -4.38054191e-05]
 [-1.04476846e-01  1.10447685e+00 -1.76924955e-01  1.76881149e-01
   4.38054194e-05]
 [ 1.76924958e-01 -1.76924958e-01  1.29968116e+00 -2.99699558e-01
   1.83937641e-05]
 [-1.76881156e-01  1.76881156e-01 -2.99699565e-01  1.29984059e+00
  -1.41021875e-04]
 [-4.38013807e-05  4.38013807e-05  1.84006044e-05 -1.41028715e-04
   1.00012263e+00]]
P_post:
[[-2.22044605e-16 -5.55111512e-17 -1.52555582e-09  1.59397495e-12
  -1.89735380e-19]
 [-4.16333634e-17 -4.44089210e-16  1.67329439e-09 -1.63363767e-12
  -2.03287907e-20]
 [-4.16333634e-16  5.55111512e-17  1.19865762e+00 -5.55335056e-05
  -6.77626358e-21]
 [ 2.05391260e-15 -1.38777878e-16 -5.55541583e-05  2.79635305e-04
  -8.13151629e-20]
 [-1.69406589e-19 -1.35525272e-20  4.53446337e-09 -1.26767412e-12
   0.00000000e+00]]
step -----------
P:
[[ 1.10501071e+00 -1.05010711e-01  1.77371451e-01 -1.77327529e-01
  -4.39220952e-05]
 [-1.05010710e-

step -----------
P:
[[ 1.11389660e+00 -1.13896595e-01  1.84505445e-01 -1.84459487e-01
  -4.59582129e-05]
 [-1.13896595e-01  1.11389660e+00 -1.84505445e-01  1.84459487e-01
   4.59582132e-05]
 [ 1.84505448e-01 -1.84505448e-01  1.29895732e+00 -2.98975448e-01
   1.81258068e-05]
 [-1.84459495e-01  1.84459495e-01 -2.98975455e-01  1.29911621e+00
  -1.40753819e-04]
 [-4.59540348e-05  4.59540348e-05  1.81325767e-05 -1.40760589e-04
   1.00012263e+00]]
P_post:
[[ 2.22044605e-16 -3.88578059e-16 -1.59300456e-09  1.65131797e-12
   2.71050543e-19]
 [-4.16333634e-16 -4.44089210e-16  1.74542875e-09 -1.69608771e-12
  -6.77626358e-21]
 [ 4.71844785e-16  5.55111512e-17  1.19582657e+00 -5.47422066e-05
   2.03287907e-20]
 [-2.19269047e-15  5.55111512e-17 -5.47626508e-05  2.79635084e-04
   8.13151629e-20]
 [ 2.64274280e-19 -2.03287907e-20  4.43459989e-09 -1.23975732e-12
   0.00000000e+00]]
step -----------
P:
[[ 1.11327741e+00 -1.13277410e-01  1.84003031e-01 -1.83957198e-01
  -4.58333266e-05]
 [-1.13277410e-

step -----------
P:
[[ 1.10593685e+00 -1.05936849e-01  1.78118163e-01 -1.78074015e-01
  -4.41486669e-05]
 [-1.05936849e-01  1.10593685e+00 -1.78118163e-01  1.78074014e-01
   4.41486672e-05]
 [ 1.78118166e-01 -1.78118166e-01  1.29955096e+00 -2.99569310e-01
   1.83455646e-05]
 [-1.78074022e-01  1.78074022e-01 -2.99569317e-01  1.29971029e+00
  -1.40973658e-04]
 [-4.41446075e-05  4.41446075e-05  1.83523914e-05 -1.40980485e-04
   1.00012263e+00]]
P_post:
[[ 0.00000000e+00 -4.16333634e-17 -1.53568139e-09  1.60060853e-12
   4.74338450e-20]
 [-4.16333634e-17  0.00000000e+00  1.68419440e-09 -1.64304681e-12
   0.00000000e+00]
 [ 1.94289029e-16  0.00000000e+00  1.19820483e+00 -5.54069449e-05
   6.77626358e-21]
 [-6.66133815e-16 -8.32667268e-17 -5.54275629e-05  2.79635270e-04
   0.00000000e+00]
 [ 5.42101086e-20  6.77626358e-21  4.51720157e-09 -1.26284850e-12
  -2.22044605e-16]]
step -----------
P:
[[ 1.10550768e+00 -1.05507684e-01  1.77757078e-01 -1.77713019e-01
  -4.40590786e-05]
 [-1.05507684e-

step -----------
P:
[[ 1.11163001e+00 -1.11630013e-01  1.82753569e-01 -1.82708162e-01
  -4.54064756e-05]
 [-1.11630013e-01  1.11163001e+00 -1.82753569e-01  1.82708162e-01
   4.54064759e-05]
 [ 1.82753572e-01 -1.82753572e-01  1.29926244e+00 -2.99280680e-01
   1.82387611e-05]
 [-1.82708170e-01  1.82708170e-01 -2.99280688e-01  1.29942155e+00
  -1.40866815e-04]
 [-4.54023228e-05  4.54023228e-05  1.82455615e-05 -1.40873615e-04
   1.00012263e+00]]
P_post:
[[-4.44089210e-16  3.05311332e-16 -1.58038876e-09  1.64335212e-12
  -1.49077799e-19]
 [ 3.19189120e-16 -4.44089210e-16  1.73075818e-09 -1.68409731e-12
   0.00000000e+00]
 [-4.44089210e-16  1.38777878e-16  1.19657072e+00 -5.49502157e-05
  -1.69406589e-20]
 [ 1.60982339e-15 -2.22044605e-16 -5.49707193e-05  2.79635142e-04
  -5.42101086e-20]
 [-1.49077799e-19  0.00000000e+00  4.46692198e-09 -1.24879360e-12
  -4.44089210e-16]]
step -----------
P:
[[ 1.11260794e+00 -1.12607939e-01  1.83515582e-01 -1.83469941e-01
  -4.56412440e-05]
 [-1.12607939e-

step -----------
P:
[[ 1.12562070e+00 -1.25620701e-01  1.93535041e-01 -1.93486544e-01
  -4.84972575e-05]
 [-1.25620701e-01  1.12562070e+00 -1.93535041e-01  1.93486543e-01
   4.84972578e-05]
 [ 1.93535044e-01 -1.93535044e-01  1.29823582e+00 -2.98253684e-01
   1.78587244e-05]
 [-1.93486551e-01  1.93486551e-01 -2.98253690e-01  1.29839418e+00
  -1.40486638e-04]
 [-4.84929082e-05  4.84929082e-05  1.78654266e-05 -1.40493339e-04
   1.00012263e+00]]
P_post:
[[ 0.00000000e+00 -3.05311332e-16 -1.67708755e-09  1.72720171e-12
   2.71050543e-20]
 [-3.05311332e-16  0.00000000e+00  1.83397081e-09 -1.77255433e-12
  -6.77626358e-21]
 [ 3.33066907e-16 -5.55111512e-17  1.19248896e+00 -5.38093413e-05
  -3.38813179e-21]
 [-1.66533454e-15  1.38777878e-16 -5.38295489e-05  2.79634823e-04
   2.71050543e-20]
 [ 1.35525272e-20 -6.77626358e-21  4.32588100e-09 -1.20936586e-12
   0.00000000e+00]]
step -----------
P:
[[ 1.12717090e+00 -1.27170897e-01  1.94688426e-01 -1.94639594e-01
  -4.88323164e-05]
 [-1.27170897e-

P:
[[ 1.13193171e+00 -1.31931708e-01  1.98156230e-01 -1.98106350e-01
  -4.98796171e-05]
 [-1.31931708e-01  1.13193171e+00 -1.98156230e-01  1.98106350e-01
   4.98796174e-05]
 [ 1.98156233e-01 -1.98156233e-01  1.29769276e+00 -2.97710415e-01
   1.76576834e-05]
 [-1.98106358e-01  1.98106358e-01 -2.97710421e-01  1.29785071e+00
  -1.40285522e-04]
 [-4.98751914e-05  4.98751914e-05  1.76643321e-05 -1.40292171e-04
   1.00012263e+00]]
P_post:
[[ 0.00000000e+00  1.66533454e-16 -1.71623263e-09  1.76306192e-12
   8.80914265e-20]
 [ 1.66533454e-16  2.22044605e-16  1.87601906e-09 -1.80816473e-12
   2.03287907e-20]
 [ 5.55111512e-17  0.00000000e+00  1.19065242e+00 -5.32960097e-05
   2.03287907e-20]
 [-3.33066907e-16 -1.11022302e-16 -5.33160811e-05  2.79634680e-04
  -2.71050543e-20]
 [ 8.13151629e-20  2.03287907e-20  4.25814571e-09 -1.19043049e-12
   0.00000000e+00]]
step -----------
P:
[[ 1.13187757e+00 -1.31877566e-01  1.98105697e-01 -1.98055818e-01
  -4.98791660e-05]
 [-1.31877566e-01  1.13187757e+0

step -----------
P:
[[ 1.12902052e+00 -1.29020525e-01  1.96015236e-01 -1.95965966e-01
  -4.92692855e-05]
 [-1.29020525e-01  1.12902052e+00 -1.96015236e-01  1.95965966e-01
   4.92692858e-05]
 [ 1.96015239e-01 -1.96015239e-01  1.29786727e+00 -2.97884992e-01
   1.77222862e-05]
 [-1.95965974e-01  1.95965974e-01 -2.97884999e-01  1.29802535e+00
  -1.40350149e-04]
 [-4.92648994e-05  4.92648994e-05  1.77289513e-05 -1.40356814e-04
   1.00012263e+00]]
P_post:
[[-4.44089210e-16  5.55111512e-17 -1.69619246e-09  1.74582571e-12
   3.38813179e-20]
 [ 5.55111512e-17 -4.44089210e-16  1.85512486e-09 -1.79031789e-12
   6.77626358e-21]
 [-5.55111512e-17  2.77555756e-17  1.19145656e+00 -5.35207677e-05
  -3.38813179e-21]
 [ 5.55111512e-17 -5.55111512e-17 -5.35408958e-05  2.79634743e-04
   0.00000000e+00]
 [ 2.71050543e-20  1.35525272e-20  4.28415022e-09 -1.19769974e-12
  -2.22044605e-16]]
step -----------
P:
[[ 1.12937417e+00 -1.29374168e-01  1.96282658e-01 -1.96233320e-01
  -4.93377850e-05]
 [-1.29374168e-

[[ 6.66133815e-16 -4.16333634e-16 -1.73517206e-09  1.77888260e-12
  -1.69406589e-19]
 [-4.16333634e-16  0.00000000e+00  1.89620852e-09 -1.82534543e-12
   5.42101086e-20]
 [ 3.88578059e-16  0.00000000e+00  1.18985311e+00 -5.30726009e-05
  -3.72694497e-20]
 [-1.24900090e-15 -8.32667268e-17 -5.30926162e-05  2.79634618e-04
   8.13151629e-20]
 [-1.62630326e-19  6.09863722e-20  4.23179809e-09 -1.18306498e-12
   0.00000000e+00]]
step -----------
P:
[[ 1.13550415e+00 -1.35504153e-01  2.00743711e-01 -2.00693084e-01
  -5.06272209e-05]
 [-1.35504153e-01  1.13550415e+00 -2.00743711e-01  2.00693084e-01
   5.06272212e-05]
 [ 2.00743714e-01 -2.00743714e-01  1.29746328e+00 -2.97480851e-01
   1.75727359e-05]
 [-2.00693092e-01  2.00693092e-01 -2.97480858e-01  1.29762106e+00
  -1.40200543e-04]
 [-5.06227483e-05  5.06227483e-05  1.75793634e-05 -1.40207170e-04
   1.00012263e+00]]
P_post:
[[ 4.44089210e-16 -3.05311332e-16 -1.74001505e-09  1.78607129e-12
  -8.80914265e-20]
 [-2.77555756e-16 -4.44089210e-16  

In [None]:
# MASE calculation -----------------

left = predicted_case.index[0]
right = predicted_case.index[-1]

# take error between prediction and actual
e = (case_ma7_all.loc[left:right] - predicted_case.loc[left:right]).abs()

case_ma7_diff = case_ma7_all.diff(7)
denom = (case_ma7_diff.loc[left:right]).abs()

term = e.div(denom)
term.replace([np.inf, -np.inf], np.nan, inplace=True)
term.dropna(inplace=True)
mase = term.sum() / len(term)
print('mase:', mase)

In [None]:
# Not currently used ---------------
# Validation data run ---------------------
start = prop_ma7_valid.index[0]
d = start
while d <= prop_ma7_valid.index[-1]:
    # each cycle of the while loop executes a step

    # get state level compartments
    x_hat_state_k, beta_k = get_predicts_prior(d, seiir)

    # step the state level compartments 7 days forward
    x_hat_state_k1 = step_seiir(x_hat_state_k, constants, beta_k)

    # convert the state level compartments to county level values
    x_hat_k = b * x_hat_state_k
    x_hat_k1 = b * x_hat_state_k1

    # get measurements
    z_k = np.array([[0],
                    [0],
                    [prop_ma7_valid.loc[d]],
                    [case_ma7_all.loc[d]],
                    [0]])

    # predict step
    P = predict_step(x_hat_k1, P, Q, beta_k, constants)
    print('step -----------')
    print('P:')
    print(P)

    # update step
    x_hat_post, P_post = update_step(x_hat_k, x_hat_k1, P, Rn,
                                     rho1, rho2, z_k)
    print('P_post:')
    print(P_post)

    # store estimated values for proportion and case rate
    indexDate = d + datetime.timedelta(days=7)
    prop_est[indexDate] = rho1 * x_hat_post[3, 0]
    case_est[indexDate] = rho2 * x_hat_post[3, 0]
    seiir_pred[indexDate] = b * x_hat_k1[3, 0]

    # update the P and d
    P = P_post
    d += datetime.timedelta(days=1)
    
# create pandas series of the estimated case rate
predicted_case = pd.Series(case_est)
predicted_seiir_prior = pd.Series(seiir_pred)

In [4]:
# Plotting constants and variables ----------------
matplotlib.rcParams.update({'font.size': 20})
plt.style.use('seaborn-whitegrid')
purple = '#33016F'
gold = '#9E7A27'
gray = '#797979'
width = 4
%matplotlib qt

tick_end = predicted_seiir_prior.index[-1]

week_interval = pd.date_range(start=start, end=tick_end, freq='W')
week_interval = [x.to_pydatetime().date() for x in week_interval]

In [5]:
# Single plot with all lines

matplotlib.rcParams.update({'font.size': 20})
plt.style.use('seaborn-whitegrid')
tick_start = K0
tick_end = predicted_seiir_prior.index[-1]



fig3, ax31 = plt.subplots(1)
plt.sca(ax31)
plt.plot(case_ma7_all.loc[start:d].index, case_ma7_all.loc[start:d], label='Case Positive Rate', c=gold,
         linewidth=width)
plt.plot(predicted_seiir_prior.index, predicted_seiir_prior,
         label='IHME 7-day Forecast', c=gray, linewidth=width)
plt.plot(predicted_case.index, predicted_case, label='Our 7-Day Forecast',
         c=purple, linewidth=width)
#plt.ylim(-2, 72)
plt.xticks(week_interval, rotation=30, ha='right', rotation_mode='anchor')
# plt.xlabel('Date')
plt.ylabel('Number of Cases per Day')
plt.legend(loc='upper left')

ax32 = ax31.twinx()
plt.sca(ax32)
plt.plot(prop_ma7.index, num_stl_ma7, c='red', label='Facebook Positive Symptoms',
         linewidth=width)
#plt.ylim(-.065, 2.5)
plt.grid(axis='y', linestyle=':')
plt.xticks(week_interval, rotation=30, ha='right', rotation_mode='anchor')

plt.ylabel('Number of Positive Symptom Response per Day')
plt.legend(loc='upper right')

plt.show()

In [None]:
# plot validation data

fig_v, ax_v = plt.subplots(1)
plt.sca(ax_v)


plt.plot(case_ma7_all.loc[start:end].index, case_ma7_all.loc[start:end], label='Case Positive Rate', c=gold,
         linewidth=width)
plt.plot(predicted_seiir_prior.index, predicted_seiir_prior,
         label='IHME 7-day Forecast', c=gray, linewidth=width)
plt.plot(predicted_case.index, predicted_case, label='Our 7-Day Forecast',
         c=purple, linewidth=width)
#plt.ylim(-2, 72)
plt.xticks(week_interval, rotation=30, ha='right', rotation_mode='anchor')
# plt.xlabel('Date')
plt.ylabel('Number of Cases per Day')
plt.legend(loc='upper left')

ax32 = ax_v.twinx()
plt.sca(ax32)
plt.plot(prop_ma7_valid.index, 100*prop_ma7_valid, c='red', label='Facebook Symptom Rate',
         linewidth=width)
plt.ylim(-.065, 2.5)
plt.grid(axis='y', linestyle=':')
plt.xticks(week_interval, rotation=30, ha='right', rotation_mode='anchor')

plt.ylabel('Percentage of Positive Symptom Response')
plt.legend(loc='upper right')

plt.show()

In [10]:
# compute squared error for the overlapping time period ----------------------


overlap = pd.date_range(start='2020-04-19', end='2020-07-07')
# convert to needed data type:
overlap_dt = [x.to_pydatetime().date() for x in overlap]

# error between the measured case rate (smoothed) and the
# SEIIR model scaled to the county (without any Kalman adjustment)
seiir_sq_err = 0

# error between the measured case rate (smoothed) and the predicted
# output of the Kalman filter
kalman_sq_err = 0

for day in overlap_dt:
    seiir_sq_err += (case_ma7_all[day] - predicted_seiir_prior[day])**2
    kalman_sq_err += (predicted_case[day] - case_ma7_all[day])**2
    
seiir_mse = seiir_sq_err / len(overlap_dt)
kalman_mse = kalman_sq_err / len(overlap_dt)

print('SSE between seiir forecast and measured case rate:', seiir_sq_err)
print('MSE between seiir forecast and measured case rate:', seiir_mse)
print()
print('SSE between kalman forecast and measured case rate:', kalman_sq_err)
print('MSE between kalman forecast and measured case rate:', kalman_mse)

SSE between seiir forecast and measured case rate: 24310.658825013477
MSE between seiir forecast and measured case rate: 303.8832353126685

SSE between kalman forecast and measured case rate: 10119.845325075148
MSE between kalman forecast and measured case rate: 126.49806656343935


In [None]:
# plot findings -- multiple plots
matplotlib.rcParams.update({'font.size': 20})
plt.style.use('seaborn-whitegrid')
tick_start = K0
tick_end = predicted_seiir_prior.index[-1]

week_interval = pd.date_range(start=tick_start, end=tick_end, freq='W')
week_interval = [x.to_pydatetime().date() for x in week_interval]

fig1, ax11 = plt.subplots(1)
plt.sca(ax11)
width = 4
plt.plot(case_ma7_all.loc[start:d].index, case_ma7_all.loc[start:d], label='Case Positive Rate', c=gold,
         linewidth=width)
plt.plot(predicted_seiir_prior.index, predicted_seiir_prior,
         label='IHME 7-day Forecast', c=gray, linewidth=width)
plt.ylim(-2, 72)
plt.xticks(week_interval, rotation=30, ha='right', rotation_mode='anchor')
plt.ylabel('Number of Cases per Day')
# plt.xlabel('Date')
plt.legend(loc='upper left')

fig2, ax21 = plt.subplots(1)
plt.sca(ax21)
plt.plot(case_ma7_all.loc[start:d].index, case_ma7_all.loc[start:d], label='Case Positive Rate', c=gold,
         linewidth=width)
plt.plot(predicted_seiir_prior.index, predicted_seiir_prior,
         label='IHME 7-day Forecast', c=gray, linewidth=width)
plt.plot(predicted_case.index, predicted_case, label='Our 7-Day Forecast',
         c=purple, linewidth=width)
plt.ylim(-2, 72)
plt.xticks(week_interval, rotation=30, ha='right', rotation_mode='anchor')
# plt.xlabel('Date')
plt.ylabel('Number of Cases per Day')
plt.legend(loc='upper left')

fig3, ax31 = plt.subplots(1)
plt.sca(ax31)
plt.plot(case_ma7_all.loc[start:d].index, case_ma7_all.loc[start:d], label='Case Positive Rate', c=gold,
         linewidth=width)
plt.plot(predicted_seiir_prior.index, predicted_seiir_prior,
         label='IHME 7-day Forecast', c=gray, linewidth=width)
plt.plot(predicted_case.index, predicted_case, label='Our 7-Day Forecast',
         c=purple, linewidth=width)
plt.ylim(-2, 72)
plt.xticks(week_interval, rotation=30, ha='right', rotation_mode='anchor')
# plt.xlabel('Date')
plt.ylabel('Number of Cases per Day')
plt.legend(loc='upper left')

ax32 = ax31.twinx()
plt.sca(ax32)
plt.plot(prop_ma7.index, 100*prop_ma7, c='red', label='Facebook Symptom Rate',
         linewidth=width)
plt.ylim(-.065, 2.5)
plt.grid(axis='y', linestyle=':')
plt.xticks(week_interval, rotation=30, ha='right', rotation_mode='anchor')

plt.ylabel('Percentage of Positive Symptom Response')
plt.legend(loc='upper right')


plt.show()