Runs an extended Kalman filter on Prof Flaxman's SEIIR predictions and
measurement data for New York State from mid-April to 7 July 2020.

S - Susceptible
E - Exposed
I1 - Presymptomatic
I2 - Symptomatic
R - Recovered

In [None]:
#%load_ext autoreload
#%autoreload

In [None]:
#%reset

In [1]:
import math
import datetime

import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt

import data_sets
import seiir_compartmental

In [2]:
# functions to support the Kalman filtering
def get_predicts_prior(day, seiir):
    """Look up the SEIIR compartment values and beta for one day.

    Parameters
    ----------
    day : label
        Index label passed to ``.loc`` on each column of ``seiir``.
    seiir : pandas.DataFrame
        Table with the columns 'S', 'E', 'I1', 'I2', 'R' and 'beta_pred'.

    Returns
    -------
    tuple
        ``(x_hat, beta_k)`` where ``x_hat`` is a 5x1 numpy column ordered
        S, E, I1, I2, R and ``beta_k`` is the 'beta_pred' entry for ``day``.
    """
    compartments = ('S', 'E', 'I1', 'I2', 'R')
    # one single-element row per compartment -> column vector
    rows = [[seiir[name].loc[day]] for name in compartments]

    return np.array(rows), seiir['beta_pred'].loc[day]


def step_seiir(x_hat, constants, beta_k, days=7):
    """Advance a compartment vector ``days`` single-day steps.

    Each step calls ``seiir_compartmental.compartmental_covid_step`` with the
    current compartments, their sum, the I1 + I2 total, and the model
    constants; ``beta_k`` is held fixed for every step.

    Parameters
    ----------
    x_hat : numpy.ndarray
        Column vector whose first five rows are S, E, I1, I2, R.
    constants : dict
        Must provide 'alpha', 'gamma1', 'gamma2', 'sigma' and 'theta'.
    beta_k : float
        Beta value handed unchanged to every step.
    days : int, optional
        Number of steps to take (default 7).

    Returns
    -------
    numpy.ndarray
        New 5x1 column vector in the order S, E, I1, I2, R.
    """
    names = ('S', 'E', 'I1', 'I2', 'R')
    state = pd.Series({name: x_hat[row, 0] for row, name in enumerate(names)})

    for _ in range(days):
        total_infectious = state.loc['I1'] + state.loc['I2']
        state = seiir_compartmental.compartmental_covid_step(
            state, state.sum(), total_infectious,
            constants['alpha'], beta_k,
            constants['gamma1'], constants['gamma2'],
            constants['sigma'], constants['theta'])

    # repack the stepped series as a column vector
    return np.array([[state.loc[name]] for name in names])


def predict_step(x_hat_k1_prior, P, Q, beta_k, constants):
    """Propagate the covariance: ``f_jacob @ P @ f_jacob.T + Q``.

    Parameters
    ----------
    x_hat_k1_prior : numpy.ndarray
        Column vector (S, E, I1, I2, R) at which the Jacobian is evaluated.
    P : array-like, 5x5
        Current covariance.
    Q : array-like, 5x5
        Process-noise term added after propagation.
    beta_k : float
        Beta value used in the S and E rows of the Jacobian.
    constants : dict
        Must provide 'alpha', 'sigma', 'gamma1' and 'gamma2'.

    Returns
    -------
    numpy.ndarray
        The 5x5 prior covariance for the next step.

    Notes
    -----
    ``math.pow(I1 + I2, alpha - 1)`` raises when I1 + I2 is 0 and alpha < 1.
    NOTE(review): the matrix below holds only the partial derivatives of the
    rate terms (no identity added) and differentiates with N held fixed --
    confirm this is the linearization intended for the discrete step.
    """
    S, E, I1, I2, R = (x_hat_k1_prior[row, 0] for row in range(5))
    N = S + E + I1 + I2 + R
    alpha = constants['alpha']
    sigma = constants['sigma']
    gamma1 = constants['gamma1']
    gamma2 = constants['gamma2']
    infectious = I1 + I2

    # magnitude of the S-column entries (sign differs between rows S and E)
    wrt_S = beta_k * math.pow(infectious, alpha) / N
    # magnitude shared by the I1 and I2 columns in rows S and E
    wrt_I = alpha * beta_k * S * math.pow(infectious, alpha - 1) / N

    # 5x5; columns are the partials with respect to S, E, I1, I2, R
    f_jacob = np.array([
        [-wrt_S, 0, -wrt_I, -wrt_I, 0],
        [wrt_S, -sigma, wrt_I, wrt_I, 0],
        [0, sigma, -gamma1, 0, 0],
        [0, 0, gamma1, -gamma2, 0],
        [0, 0, 0, gamma2, 0],
    ])

    # 5x5
    return f_jacob @ P @ f_jacob.T + Q


def update_step(x_hat, x_hat_k1, P_k1, Rn, rho1, rho2, z_k):
    """Apply the measurement update to the predicted state and covariance.

    The innovation is ``z_k - H @ x_hat`` (evaluated on ``x_hat``), while the
    gain-weighted correction is added to ``x_hat_k1``.

    Parameters
    ----------
    x_hat : numpy.ndarray
        5x1 state used to form the predicted measurement.
    x_hat_k1 : numpy.ndarray
        5x1 prior state that receives the correction.
    P_k1 : numpy.ndarray
        5x5 prior covariance.
    Rn : numpy.ndarray
        5x5 measurement-noise matrix.
    rho1, rho2 : float
        Factors placed in the I2 column of H (rows 2 and 3 respectively).
    z_k : numpy.ndarray
        5x1 measurement vector.

    Returns
    -------
    tuple
        ``(x_hat_k1_post, P_k1_post)``.
    """
    # 5x5 observation matrix: a tiny value on the diagonal everywhere except
    # the I2 entry, which carries rho2; rho1 maps I2 into row 2.
    # (The original also kept a disabled variant with exact zeros there.)
    tiny = 10**-10
    H = np.diag([tiny, tiny, tiny, rho2, tiny])
    H[2, 3] = rho1
    H_t = H.T

    # Si = H * P_k1 * H^T + Rn
    Si = H @ P_k1 @ H_t + Rn

    # K_new = P_k1 * H^T * Si^(-1)
    K_new = P_k1 @ H_t @ np.linalg.inv(Si)

    # 5x1 innovation
    innovation = z_k - H @ x_hat

    x_hat_k1_post = x_hat_k1 + K_new @ innovation

    # P - K * Si * K^T; a Joseph-form update was left disabled in the original
    P_k1_post = P_k1 - K_new @ Si @ K_new.T

    return x_hat_k1_post, P_k1_post

def get_predict_measures():
    """Load the stored 7-day predictions, indexed by measurement date.

    Reads columns 0, 4 and 5 of
    ``case_vs_symptom/kalman/predicted_measures/predictions.csv`` (no header
    row), which receive the names 'measure_date', 'prop_pred7' and
    'case_pred7'.  The first and last 'measure_date' values are treated as
    day offsets from 2019-12-31 and the frame is re-indexed with the daily
    date range between them, so the file needs exactly one row per day of
    that range.

    Returns
    -------
    pandas.DataFrame
        The three loaded columns on a daily DatetimeIndex.
    """
    column_names = ['measure_date', 'prediction_date', 'prop', 'case',
                    'prop_pred7', 'case_pred7', 'prop_pred1', 'case_pred1']
    predictions = pd.read_csv(
        r'case_vs_symptom/kalman/predicted_measures/predictions.csv',
        header=None, names=column_names, index_col=False, usecols=[0, 4, 5])

    day_zero = datetime.date(2019, 12, 31)
    offsets = predictions['measure_date']
    first_day = day_zero + datetime.timedelta(days=int(offsets.iloc[0]))
    last_day = day_zero + datetime.timedelta(days=int(offsets.iloc[-1]))

    # The original kept a disabled snippet that built the same kind of range
    # from 'prediction_date'; that column is not loaded here.
    return predictions.set_index(pd.date_range(start=first_day, end=last_day))


def _mean_by_date(frame, fips):
    """Return the rows of ``frame`` for one county, averaged per 'date' level."""
    county_rows = frame.loc[(slice(None), fips), :]
    # groupby(level=...).mean() replaces DataFrame.mean(level=...), which was
    # removed in pandas 2.0; sort=False leaves the dates in their stored order.
    return county_rows.groupby(level='date', sort=False).mean()


def get_data_sets(state, fips=None):
    """Load the SEIIR predictions plus symptom and case data for one geography.

    Parameters
    ----------
    state : label
        Key used with ``.loc`` on the state-level frames; only used when
        ``fips`` is None.
    fips : str, optional
        None selects state-level data.  'New York City' sums the per-date
        means of the five borough codes.  Any other value is used as a
        county key on the second index level of the symptom frames.

    Returns
    -------
    tuple
        ``(seiir, fb_data_geo, fb_data_val_geo, case_data_geo)``: the SEIIR
        table read from CSV, the symptom frame and its validation
        counterpart reduced to one row per date, and the 'case_rate' series
        of the selected geography.
    """
    # the post hoc seiir model predictions provided by Prof Flaxman without
    # any kalman filtering (a dated variant of this file,
    # data/seiir_compartments_post-hoc_ny_state_20200910.csv, was used before)
    seiir = pd.read_csv(r'data/seiir_compartments_post-hoc_ny_state.csv',
                        header=0, index_col='date', parse_dates=True)

    fb_data = data_sets.create_symptom_df()
    fb_data_val = data_sets.create_symptom_df(valid=True)

    if fips is None:
        case_data = data_sets.create_case_df_state()
        case_data_geo = case_data.loc[state]['case_rate'].copy()
        fb_data_geo = fb_data.loc[state].groupby('date').sum()
        fb_data_val_geo = fb_data_val.loc[state].groupby('date').sum()
        return seiir, fb_data_geo, fb_data_val_geo, case_data_geo

    case_data = data_sets.create_case_df_county()
    case_data_geo = case_data.loc[fips]['case_rate'].copy()

    if fips == 'New York City':
        # Start from 36081, then add the other four borough codes.  The
        # original averaged the whole validation frame here instead of the
        # 36081 subset; both frames now go through the same helper.
        fb_data_geo = _mean_by_date(fb_data, '36081')
        fb_data_val_geo = _mean_by_date(fb_data_val, '36081')
        for borough in ['36005', '36061', '36047', '36085']:
            fb_data_geo = fb_data_geo + _mean_by_date(fb_data, borough)
            fb_data_val_geo = (fb_data_val_geo
                               + _mean_by_date(fb_data_val, borough))
    else:
        # collapse down to a single index column (date)
        fb_data_geo = _mean_by_date(fb_data, fips)
        fb_data_val_geo = _mean_by_date(fb_data_val, fips)

    return seiir, fb_data_geo, fb_data_val_geo, case_data_geo


def calc_fb_ma7(fb_data):
    """Compute 7-row moving averages of the symptom survey columns.

    Parameters
    ----------
    fb_data : pandas.DataFrame
        Frame with at least the columns 'n' and 'num_stl'.

    Returns
    -------
    tuple of pandas.Series
        ``(prop_ma7, n_ma7, num_stl_ma7)``: the averaged 'num_stl' divided by
        the averaged 'n', followed by copies of the two averaged columns.
        The first six rows (incomplete windows) are dropped from all three.
    """
    # the fb_data is a DataFrame while the case_data is a Series
    smoothed = fb_data.rolling(window=7).mean().iloc[6:, :]
    responses = smoothed['n']
    positives = smoothed['num_stl']

    return positives.div(responses), responses.copy(), positives.copy()

In [3]:
# set constants
# K0 is the first day fed to the filter
K0 = datetime.date(2020, 4, 12)

the_state = 'NY'
#the_county = data_sets.get_fips(the_state, 'New York City')
the_county = 'New York City'
constants = {
    'alpha': 0.948786,
    'gamma1': 0.500000,
    'gamma2': 0.662215,
    'sigma': 0.266635,
    'theta': 6.000000
    }

# set initial values for Kalman filter parameters
P_mult = 1
Q_mult = 1

# Rn is the R noise matrix; it remains constant thru the stepping of the
# Kalman filter
# prior to experiments, Rn_mult = 5*10**-4
Rn_mult = 5*10**-8

# entries of the 2x2 sub-block of Rn that covers rows/columns 2 and 3
Rn_22 = 88
Rn_32 = 150

Rn_23 = 0
Rn_33 = 1

# Rn_23 was defined but never placed in the matrix; it now occupies its
# [2, 3] slot (the value is still 0, so Rn is numerically unchanged).
# NOTE(review): with Rn_32 = 150 and Rn_23 = 0 the matrix is not symmetric
# -- confirm that is intended.
Rn = Rn_mult * np.array([[0, 0, 0, 0, 0],
                         [0, 0, 0, 0, 0],
                         [0, 0, Rn_22, Rn_23, 0],
                         [0, 0, Rn_32, Rn_33, 0],
                         [0, 0, 0, 0, 0]])
Q = Q_mult * np.eye(5)
P = P_mult * np.eye(5)

# generate data
seiir, fb_data, fb_data_val, case_data = get_data_sets(
    data_sets.STATES[the_state], fips=the_county)

# population of the modelled geography and of its state; for New York City
# the five borough counts are added together
if the_county == 'New York City':
    borough_pops = [data_sets.get_pops(code)
                    for code in ('36081', '36005', '36061', '36047', '36085')]
    county_pop = sum(count for count, _ in borough_pops)
    state_pop = borough_pops[-1][1]
else:
    county_pop, state_pop = data_sets.get_pops(the_county)

# b scales state-level compartment values down to the modelled geography
b = county_pop / state_pop

# calculate moving averages on the fb and case data
# (the first six rows of a 7-row window are incomplete and are dropped)
case_ma7 = case_data.rolling(window=7).mean()
case_ma7_all = case_ma7.iloc[6:]
prop_ma7, n_ma7, num_stl_ma7 = calc_fb_ma7(fb_data)
prop_ma7_valid, n_ma7_valid, num_stl_ma7_valid = calc_fb_ma7(fb_data_val)

# get starting compartment values for the state level, then scale them
x_hat_state_k0, beta_k0 = get_predicts_prior(K0, seiir)
x_hat_k0 = b * x_hat_state_k0
I2_county = x_hat_k0[3, 0]

# approximate rho values from the 2020-04-12 measurements:
#   I2_county   = b * I2
#   prop_county = rho1 * I2_county
#   case_county = rho2 * I2_county
rho1 = prop_ma7.loc['2020-04-12'] / I2_county
rho2 = case_ma7_all.loc['2020-04-12'] / I2_county

# create empty dictionaries to hold the estimated values
prop_est, case_est, seiir_pred = {}, {}, {}



# Original data run ----------------
# Walks forward one day at a time from K0; every pass stores values under
# the date seven days after d.
start = K0
d = start
while d <= datetime.date(2020, 9, 30):
    # each cycle of the while loop executes a step

    # get state level compartments
    x_hat_state_k, beta_k = get_predicts_prior(d, seiir)

    # step the state level compartments 7 days forward
    x_hat_state_k1 = step_seiir(x_hat_state_k, constants, beta_k)

    # convert the state level compartments to county level values
    x_hat_k = b * x_hat_state_k
    x_hat_k1 = b * x_hat_state_k1

    # get measurements (only rows 2 and 3 carry data)
    z_k = np.array([[0],
                    [0],
                    [prop_ma7.loc[d]],
                    [case_ma7_all.loc[d]],
                    [0]])

    # predict step
    P = predict_step(x_hat_k1, P, Q, beta_k, constants)
    print('step -----------')
    print('P:')
    print(P)

    # update step
    x_hat_post, P_post = update_step(x_hat_k, x_hat_k1, P, Rn,
                                     rho1, rho2, z_k)
    print('P_post:')
    print(P_post)

    # store estimated values for proportion and case rate
    indexDate = d + datetime.timedelta(days=7)
    prop_est[indexDate] = rho1 * x_hat_post[3, 0]
    case_est[indexDate] = rho2 * x_hat_post[3, 0]
    # x_hat_k1 is already scaled by b above; the original multiplied by b a
    # second time here.
    # NOTE(review): unlike case_est, this value is not multiplied by rho2 --
    # confirm that is intended before comparing it with the case rate.
    seiir_pred[indexDate] = x_hat_k1[3, 0]

    # update the P and d
    P = P_post
    d += datetime.timedelta(days=1)

# create pandas series of the estimated case rate
predicted_case = pd.Series(case_est)
predicted_seiir_prior = pd.Series(seiir_pred)

step -----------
P:
[[ 1.10878568 -0.10878568  0.11660686  0.03783076 -0.15443762]
 [-0.10878568  1.17987991 -0.18770108 -0.03783076  0.15443762]
 [ 0.11660686 -0.18770108  1.32109422 -0.25        0.        ]
 [ 0.03783076 -0.03783076 -0.25        1.68852871 -0.43852871]
 [-0.15443762  0.15443762  0.         -0.43852871  1.43852871]]
P_post:
[[-4.44089210e-16  2.22044605e-16  2.22044605e-16 -1.99840144e-15
  -3.60822483e-16]
 [ 2.22044605e-16  0.00000000e+00  1.11022302e-16 -5.62050406e-16
   1.11022302e-16]
 [ 2.22044605e-16  1.11022302e-16  1.24180214e+00 -7.15216268e-07
  -1.73160087e-17]
 [-1.99146255e-15 -5.62050406e-16 -7.17224716e-07  4.50154017e-06
  -3.33066907e-16]
 [-3.60822483e-16  1.38777878e-16 -1.20971525e-17 -3.33066907e-16
   0.00000000e+00]]
step -----------
P:
[[ 1.06667867e+00 -6.66786681e-02  1.43876173e-01 -1.43875592e-01
  -5.81009852e-07]
 [-6.66786681e-02  1.06667867e+00 -1.43876173e-01  1.43875592e-01
   5.81009852e-07]
 [ 1.43876173e-01 -1.43876173e-01  1.310

step -----------
P:
[[ 1.06165339e+00 -6.16533897e-02  1.36706398e-01 -1.36705868e-01
  -5.29647246e-07]
 [-6.16533897e-02  1.06165339e+00 -1.36706398e-01  1.36705868e-01
   5.29647247e-07]
 [ 1.36706398e-01 -1.36706398e-01  1.30312540e+00 -3.03125720e-01
   3.16082391e-07]
 [-1.36705869e-01  1.36705869e-01 -3.03125721e-01  1.30312801e+00
  -2.29013569e-06]
 [-5.29235969e-07  5.29235969e-07  3.16994334e-07 -2.29104764e-06
   1.00000197e+00]]
P_post:
[[-4.44089210e-16  4.16333634e-17 -1.48344476e-10  3.05311332e-15
  -1.05879118e-22]
 [ 3.46944695e-17  2.22044605e-16  1.63636382e-10 -2.74780199e-15
   5.29395592e-22]
 [-1.66533454e-16  0.00000000e+00  1.21250622e+00 -9.54642394e-07
  -4.23516474e-22]
 [ 4.71844785e-16 -1.11022302e-16 -9.57396747e-07  4.50153724e-06
   2.11758237e-21]
 [-1.05879118e-22  5.29395592e-22  6.39318461e-10 -2.87686905e-15
   2.22044605e-16]]
step -----------
P:
[[ 1.06151280e+00 -6.15127954e-02  1.36550696e-01 -1.36550167e-01
  -5.29038900e-07]
 [-6.15127954e-

step -----------
P:
[[ 1.06211618e+00 -6.21161753e-02  1.37210850e-01 -1.37210318e-01
  -5.31752421e-07]
 [-6.21161753e-02  1.06211618e+00 -1.37210850e-01  1.37210318e-01
   5.31752422e-07]
 [ 1.37210850e-01 -1.37210850e-01  1.30309154e+00 -3.03091857e-01
   3.15880575e-07]
 [-1.37210319e-01  1.37210319e-01 -3.03091858e-01  1.30309415e+00
  -2.28993388e-06]
 [-5.31339755e-07  5.31339755e-07  3.16792134e-07 -2.29084543e-06
   1.00000197e+00]]
P_post:
[[ 2.22044605e-16  5.55111512e-17 -1.48889345e-10  1.44328993e-15
   1.16467030e-21]
 [ 6.24500451e-17 -2.22044605e-16  1.64222469e-10 -2.58126853e-15
   5.18807680e-21]
 [ 2.49800181e-16  2.77555756e-17  1.21234974e+00 -9.53938224e-07
  -6.35274710e-22]
 [-1.11022302e-15  0.00000000e+00 -9.56691115e-07  4.50153724e-06
   1.69406589e-21]
 [ 1.16467030e-21  5.18807680e-21  6.38649773e-10 -2.87386209e-15
   2.22044605e-16]]
step -----------
P:
[[ 1.06199932e+00 -6.19993205e-02  1.37080798e-01 -1.37080267e-01
  -5.31266680e-07]
 [-6.19993205e-

step -----------
P:
[[ 1.06333830e+00 -6.33383003e-02  1.38540088e-01 -1.38539551e-01
  -5.37178945e-07]
 [-6.33383003e-02  1.06333830e+00 -1.38540088e-01  1.38539551e-01
   5.37178946e-07]
 [ 1.38540089e-01 -1.38540089e-01  1.30303035e+00 -3.03030666e-01
   3.15515893e-07]
 [-1.38539552e-01  1.38539552e-01 -3.03030667e-01  1.30303296e+00
  -2.28956919e-06]
 [-5.36762546e-07  5.36762545e-07  3.16426689e-07 -2.29047999e-06
   1.00000197e+00]]
P_post:
[[ 0.00000000e+00  2.77555756e-17 -1.50450513e-10  2.69229083e-15
  -1.90582413e-21]
 [ 2.77555756e-17  2.22044605e-16  1.65848946e-10 -1.44328993e-15
  -4.55280209e-21]
 [ 0.00000000e+00 -2.49800181e-16  1.21195299e+00 -9.52152906e-07
   2.64697796e-22]
 [ 8.32667268e-17  1.22124533e-15 -9.54902115e-07  4.50153723e-06
   4.23516474e-22]
 [-2.01170325e-21 -4.44692297e-21  6.37051535e-10 -2.86667035e-15
  -4.44089210e-16]]
step -----------
P:
[[ 1.06401348e+00 -6.40134780e-02  1.39266864e-01 -1.39266324e-01
  -5.40187339e-07]
 [-6.40134780e-

step -----------
P:
[[ 1.06502395e+00 -6.50239463e-02  1.40323998e-01 -1.40323453e-01
  -5.45030504e-07]
 [-6.50239463e-02  1.06502395e+00 -1.40323998e-01  1.40323453e-01
   5.45030505e-07]
 [ 1.40323999e-01 -1.40323999e-01  1.30282533e+00 -3.02825641e-01
   3.14293990e-07]
 [-1.40323454e-01  1.40323454e-01 -3.02825642e-01  1.30282793e+00
  -2.28834728e-06]
 [-5.44609748e-07  5.44609748e-07  3.15202000e-07 -2.28925529e-06
   1.00000197e+00]]
P_post:
[[ 0.00000000e+00  4.16333634e-17 -1.52283963e-10  1.94289029e-15
   8.47032947e-22]
 [ 5.55111512e-17 -6.66133815e-16  1.67920566e-10 -2.13717932e-15
   2.22346149e-21]
 [ 1.11022302e-16 -2.77555756e-17  1.21134155e+00 -9.49401467e-07
  -1.58818678e-22]
 [-6.66133815e-16  5.55111512e-16 -9.52144228e-07  4.50153722e-06
   0.00000000e+00]
 [ 9.52912066e-22  2.32934060e-21  6.33699494e-10 -2.85158682e-15
  -4.44089210e-16]]
step -----------
P:
[[ 1.06521923e+00 -6.52192285e-02  1.40536888e-01 -1.40536342e-01
  -5.45811428e-07]
 [-6.52192285e-

step -----------
P:
[[ 1.06890094e+00 -6.89009418e-02  1.44378966e-01 -1.44378404e-01
  -5.62115777e-07]
 [-6.89009418e-02  1.06890094e+00 -1.44378966e-01  1.44378404e-01
   5.62115777e-07]
 [ 1.44378966e-01 -1.44378966e-01  1.30254105e+00 -3.02541365e-01
   3.12599779e-07]
 [-1.44378405e-01  1.44378405e-01 -3.02541366e-01  1.30254365e+00
  -2.28665306e-06]
 [-5.61684159e-07  5.61684159e-07  3.13504220e-07 -2.28755750e-06
   1.00000197e+00]]
P_post:
[[ 2.22044605e-16 -3.05311332e-16 -1.56918450e-10  2.83106871e-15
  -1.05879118e-22]
 [-3.19189120e-16  4.44089210e-16  1.72885289e-10 -9.71445147e-16
  -7.41153829e-22]
 [ 2.77555756e-17 -4.16333634e-16  1.21004815e+00 -9.43581328e-07
   0.00000000e+00]
 [ 1.11022302e-16  1.77635684e-15 -9.46311682e-07  4.50153719e-06
   8.47032947e-22]
 [-1.05879118e-22 -8.47032947e-22  6.27934780e-10 -2.82564558e-15
  -4.44089210e-16]]
step -----------
P:
[[ 1.06922554e+00 -6.92255385e-02  1.44711716e-01 -1.44711152e-01
  -5.63548044e-07]
 [-6.92255385e-

step -----------
P:
[[ 1.07319524e+00 -7.31952438e-02  1.48726560e-01 -1.48725980e-01
  -5.80691631e-07]
 [-7.31952438e-02  1.07319524e+00 -1.48726560e-01  1.48725980e-01
   5.80691631e-07]
 [ 1.48726561e-01 -1.48726561e-01  1.30220095e+00 -3.02201260e-01
   3.10572835e-07]
 [-1.48725980e-01  1.48725980e-01 -3.02201261e-01  1.30220355e+00
  -2.28462610e-06]
 [-5.80248692e-07  5.80248692e-07  3.11472851e-07 -2.28552612e-06
   1.00000197e+00]]
P_post:
[[ 2.22044605e-16  1.52655666e-16 -1.61826108e-10  2.41473508e-15
  -2.64697796e-21]
 [ 1.52655666e-16 -4.44089210e-16  1.78160459e-10 -2.47024623e-15
  -9.52912066e-22]
 [ 5.55111512e-17 -5.55111512e-17  1.20861732e+00 -9.37142764e-07
  -3.17637355e-22]
 [-3.60822483e-16  3.88578059e-16 -9.39859158e-07  4.50153716e-06
   8.47032947e-22]
 [-2.43521972e-21 -1.05879118e-21  6.21305414e-10 -2.79581393e-15
   0.00000000e+00]]
step -----------
P:
[[ 1.07360007e+00 -7.36000713e-02  1.49125777e-01 -1.49125195e-01
  -5.82477300e-07]
 [-7.36000713e-

[[ 1.07465694e+00 -7.46569355e-02  1.50158659e-01 -1.50158072e-01
  -5.87182459e-07]
 [-7.46569355e-02  1.07465694e+00 -1.50158659e-01  1.50158072e-01
   5.87182459e-07]
 [ 1.50158659e-01 -1.50158659e-01  1.30201758e+00 -3.02017888e-01
   3.09479986e-07]
 [-1.50158073e-01  1.50158073e-01 -3.02017889e-01  1.30202017e+00
  -2.28353325e-06]
 [-5.86736233e-07  5.86736233e-07  3.10377491e-07 -2.28443075e-06
   1.00000197e+00]]
P_post:
[[-2.22044605e-16 -1.80411242e-16 -1.63270786e-10  3.38618023e-15
   0.00000000e+00]
 [-1.66533454e-16  2.22044605e-16  1.79792653e-10 -2.80331314e-15
  -1.05879118e-22]
 [-1.38777878e-16 -8.32667268e-17  1.20809712e+00 -9.34801915e-07
   0.00000000e+00]
 [ 5.82867088e-16  8.32667268e-17 -9.37512796e-07  4.50153715e-06
   0.00000000e+00]
 [ 1.05879118e-22 -1.05879118e-22  6.18401709e-10 -2.78274845e-15
   0.00000000e+00]]
step -----------
P:
[[ 1.07448074e+00 -7.44807375e-02  1.49983023e-01 -1.49982437e-01
  -5.86462805e-07]
 [-7.44807375e-02  1.07448074e+00 -

P_post:
[[-6.66133815e-16  1.52655666e-16 -1.67227177e-10  3.80251386e-15
  -3.81164826e-21]
 [ 1.38777878e-16 -6.66133815e-16  1.83907112e-10 -3.08086889e-15
  -1.05879118e-22]
 [-2.77555756e-16  8.32667268e-17  1.20706514e+00 -9.30158109e-07
  -5.82335151e-22]
 [ 9.43689571e-16 -1.11022302e-16 -9.32859626e-07  4.50153713e-06
   1.27054942e-21]
 [-3.81164826e-21 -1.05879118e-22  6.14382197e-10 -2.76466006e-15
   0.00000000e+00]]
step -----------
P:
[[ 1.07874373e+00 -7.87437349e-02  1.54149653e-01 -1.54149049e-01
  -6.04055905e-07]
 [-7.87437349e-02  1.07874373e+00 -1.54149653e-01  1.54149049e-01
   6.04055905e-07]
 [ 1.54149654e-01 -1.54149654e-01  1.30176628e+00 -3.01766593e-01
   3.07982326e-07]
 [-1.54149050e-01  1.54149050e-01 -3.01766594e-01  1.30176888e+00
  -2.28203558e-06]
 [-6.03598975e-07  6.03598976e-07  3.08876819e-07 -2.28293007e-06
   1.00000197e+00]]
P_post:
[[ 0.00000000e+00 -4.30211422e-16 -1.67972358e-10  2.27595720e-15
   2.22346149e-21]
 [-4.44089210e-16  0.000000

step -----------
P:
[[ 1.07992034e+00 -7.99203390e-02  1.55229760e-01 -1.55229151e-01
  -6.09618158e-07]
 [-7.99203390e-02  1.07992034e+00 -1.55229760e-01  1.55229151e-01
   6.09618158e-07]
 [ 1.55229761e-01 -1.55229761e-01  1.30150483e+00 -3.01505140e-01
   3.06424132e-07]
 [-1.55229152e-01  1.55229152e-01 -3.01505141e-01  1.30150742e+00
  -2.28047738e-06]
 [-6.09159512e-07  6.09159512e-07  3.07314965e-07 -2.28136821e-06
   1.00000197e+00]]
P_post:
[[-2.22044605e-16  4.16333634e-17 -1.68755315e-10  4.46864767e-15
  -4.65868121e-21]
 [ 5.55111512e-17  4.44089210e-16  1.85815113e-10 -2.99760217e-15
   0.00000000e+00]
 [-3.60822483e-16  2.77555756e-17  1.20632024e+00 -9.26806133e-07
  -5.29395592e-23]
 [ 1.60982339e-15 -2.77555756e-17 -9.29499299e-07  4.50153712e-06
  -1.69406589e-21]
 [-4.65868121e-21 -1.05879118e-22  6.09599190e-10 -2.74313992e-15
   4.44089210e-16]]
step -----------
P:
[[ 1.07907857e+00 -7.90785658e-02  1.54429367e-01 -1.54428761e-01
  -6.06094004e-07]
 [-7.90785658e-

step -----------
P:
[[ 1.07561192e+00 -7.56119234e-02  1.51098730e-01 -1.51098138e-01
  -5.91199429e-07]
 [-7.56119234e-02  1.07561192e+00 -1.51098730e-01  1.51098138e-01
   5.91199429e-07]
 [ 1.51098730e-01 -1.51098730e-01  1.30194856e+00 -3.01948867e-01
   3.09068636e-07]
 [-1.51098139e-01  1.51098139e-01 -3.01948868e-01  1.30195115e+00
  -2.28312190e-06]
 [-5.90750751e-07  5.90750751e-07  3.09965253e-07 -2.28401851e-06
   1.00000197e+00]]
P_post:
[[-4.44089210e-16 -5.55111512e-17 -1.64342651e-10  3.63598041e-15
  -3.17637355e-21]
 [-5.55111512e-17  2.22044605e-16  1.80935988e-10 -2.96984659e-15
   5.29395592e-22]
 [-1.94289029e-16  2.77555756e-17  1.20778600e+00 -9.33401888e-07
  -4.76456033e-22]
 [ 8.04911693e-16 -8.32667268e-17 -9.36109784e-07  4.50153715e-06
   0.00000000e+00]
 [-3.28225267e-21  4.23516474e-22  6.17007401e-10 -2.77647405e-15
   2.22044605e-16]]
step -----------
P:
[[ 1.07555064e+00 -7.55506367e-02  1.51036966e-01 -1.51036375e-01
  -5.90967937e-07]
 [-7.55506367e-

step -----------
P:
[[ 1.08129645e+00 -8.12964455e-02  1.56567612e-01 -1.56566997e-01
  -6.14730888e-07]
 [-8.12964455e-02  1.08129645e+00 -1.56567612e-01  1.56566997e-01
   6.14730889e-07]
 [ 1.56567612e-01 -1.56567612e-01  1.30153235e+00 -3.01532660e-01
   3.06588143e-07]
 [-1.56566998e-01  1.56566998e-01 -3.01532661e-01  1.30153494e+00
  -2.28064139e-06]
 [-6.14268010e-07  6.14268010e-07  3.07479597e-07 -2.28153284e-06
   1.00000197e+00]]
P_post:
[[ 4.44089210e-16 -3.33066907e-16 -1.70624737e-10  2.10942375e-15
  -3.17637355e-22]
 [-3.05311332e-16  2.22044605e-16  1.87661692e-10 -2.60902411e-15
  -1.16467030e-21]
 [ 2.22044605e-16 -1.38777878e-16  1.20594732e+00 -9.25128058e-07
  -2.11758237e-22]
 [-7.77156117e-16  3.88578059e-16 -9.27818500e-07  4.50153711e-06
   1.27054942e-21]
 [-4.23516474e-22 -1.05879118e-21  6.08865428e-10 -2.73983479e-15
   0.00000000e+00]]
step -----------
P:
[[ 1.08192994e+00 -8.19299438e-02  1.57164585e-01 -1.57163968e-01
  -6.17309388e-07]
 [-8.19299438e-

step -----------
P:
[[ 1.09041577e+00 -9.04157747e-02  1.64936102e-01 -1.64935450e-01
  -6.51140128e-07]
 [-9.04157747e-02  1.09041577e+00 -1.64936101e-01  1.64935450e-01
   6.51140129e-07]
 [ 1.64936102e-01 -1.64936102e-01  1.30087692e+00 -3.00877218e-01
   3.02681873e-07]
 [-1.64935451e-01  1.64935451e-01 -3.00877219e-01  1.30087950e+00
  -2.27673510e-06]
 [-6.50656007e-07  6.50656007e-07  3.03565007e-07 -2.27761823e-06
   1.00000197e+00]]
P_post:
[[-4.44089210e-16 -1.11022302e-16 -1.80254534e-10  5.02375919e-15
  -1.58818678e-21]
 [-1.24900090e-16  0.00000000e+00  1.97919042e-10 -3.02535774e-15
   1.05879118e-22]
 [-4.71844785e-16 -2.77555756e-17  1.20306103e+00 -9.12140041e-07
  -5.29395592e-23]
 [ 1.94289029e-15  1.38777878e-16 -9.14803065e-07  4.50153705e-06
  -8.47032947e-22]
 [-1.58818678e-21  0.00000000e+00  5.96025443e-10 -2.68205783e-15
  -2.22044605e-16]]
step -----------
P:
[[ 1.09148537e+00 -9.14853656e-02  1.65878017e-01 -1.65877361e-01
  -6.55468783e-07]
 [-9.14853656e-

step -----------
P:
[[ 1.09462027e+00 -9.46202750e-02  1.68595162e-01 -1.68594494e-01
  -6.68207666e-07]
 [-9.46202750e-02  1.09462027e+00 -1.68595162e-01  1.68594494e-01
   6.68207667e-07]
 [ 1.68595163e-01 -1.68595163e-01  1.30040533e+00 -3.00405634e-01
   2.99871345e-07]
 [-1.68594495e-01  1.68594495e-01 -3.00405635e-01  1.30040791e+00
  -2.27392455e-06]
 [-6.67715492e-07  6.67715492e-07  3.00748308e-07 -2.27480151e-06
   1.00000197e+00]]
P_post:
[[ 0.00000000e+00  6.93889390e-17 -1.83999121e-10  2.38697950e-15
   6.35274710e-22]
 [ 5.55111512e-17  2.22044605e-16  2.02105693e-10 -3.71924713e-15
  -1.69406589e-21]
 [ 1.38777878e-16  1.11022302e-16  1.20166246e+00 -9.05846669e-07
   2.64697796e-22]
 [-7.49400542e-16 -5.27355937e-16 -9.08495553e-07  4.50153702e-06
   4.23516474e-22]
 [ 7.41153829e-22 -1.69406589e-21  5.88706253e-10 -2.64912138e-15
  -2.22044605e-16]]
step -----------
P:
[[ 1.09474661e+00 -9.47466059e-02  1.68710561e-01 -1.68709892e-01
  -6.68607734e-07]
 [-9.47466059e-

step -----------
P:
[[ 1.09367047e+00 -9.36704723e-02  1.67774729e-01 -1.67774065e-01
  -6.64402630e-07]
 [-9.36704723e-02  1.09367047e+00 -1.67774729e-01  1.67774065e-01
   6.64402630e-07]
 [ 1.67774729e-01 -1.67774729e-01  1.30050521e+00 -3.00505512e-01
   3.00466598e-07]
 [-1.67774065e-01  1.67774065e-01 -3.00505513e-01  1.30050779e+00
  -2.27451981e-06]
 [-6.63912283e-07  6.63912283e-07  3.01344868e-07 -2.27539808e-06
   1.00000197e+00]]
P_post:
[[ 2.22044605e-16  1.38777878e-16 -1.83145499e-10  4.05231404e-15
  -1.27054942e-21]
 [ 1.38777878e-16 -2.22044605e-16  2.01162864e-10 -3.80251386e-15
   5.29395592e-22]
 [-2.22044605e-16  2.22044605e-16  1.20197353e+00 -9.07246423e-07
  -1.05879118e-22]
 [ 9.43689571e-16 -5.82867088e-16 -9.09898421e-07  4.50153703e-06
  -4.23516474e-22]
 [-1.27054942e-21  5.29395592e-22  5.90297676e-10 -2.65628389e-15
   0.00000000e+00]]
step -----------
P:
[[ 1.09456902e+00 -9.45690200e-02  1.68574192e-01 -1.68573524e-01
  -6.67634406e-07]
 [-9.45690200e-

step -----------
P:
[[ 1.10023072e+00 -1.00230715e-01  1.73439443e-01 -1.73438754e-01
  -6.89037330e-07]
 [-1.00230715e-01  1.10023072e+00 -1.73439443e-01  1.73438754e-01
   6.89037331e-07]
 [ 1.73439443e-01 -1.73439443e-01  1.30012111e+00 -3.00121406e-01
   2.98177419e-07]
 [-1.73438755e-01  1.73438755e-01 -3.00121407e-01  1.30012368e+00
  -2.27223062e-06]
 [-6.88532610e-07  6.88532610e-07  2.99050792e-07 -2.27310399e-06
   1.00000197e+00]]
P_post:
[[ 0.00000000e+00  5.55111512e-17 -1.89891630e-10  3.30291350e-15
   9.52912066e-22]
 [ 6.93889390e-17  6.66133815e-16  2.08193379e-10 -1.97064587e-15
  -4.34104385e-21]
 [-8.32667268e-17 -3.05311332e-16  1.20000329e+00 -8.98380549e-07
   4.76456033e-22]
 [ 8.32667268e-17  1.30451205e-15 -9.01014287e-07  4.50153699e-06
   8.47032947e-22]
 [ 9.52912066e-22 -4.34104385e-21  5.82022984e-10 -2.61904747e-15
   0.00000000e+00]]
step -----------
P:
[[ 1.10024196e+00 -1.00241956e-01  1.73414406e-01 -1.73413717e-01
  -6.89628486e-07]
 [-1.00241956e-

In [None]:
# MASE calculation -----------------
# NOTE(review): this averages the per-day ratio |error| / |7-day change in
# the case rate|, skipping days where the ratio is infinite or NaN --
# confirm this is the intended definition of MASE.

left, right = predicted_case.index[0], predicted_case.index[-1]

# take error between prediction and actual
e = (case_ma7_all.loc[left:right] - predicted_case.loc[left:right]).abs()

# absolute change of the smoothed case rate over 7 rows
case_ma7_diff = case_ma7_all.diff(7)
denom = case_ma7_diff.loc[left:right].abs()

term = e.div(denom).replace([np.inf, -np.inf], np.nan).dropna()
mase = term.sum() / len(term)
print('mase:', mase)

In [None]:
# Not currently used ---------------
# Validation data run ---------------------
# Start from a fresh filter: the original reused the covariance and the
# result dictionaries left behind by the original data run, so its output
# mixed both runs.
P = P_mult * np.eye(5)
prop_est = {}
case_est = {}
seiir_pred = {}

start = prop_ma7_valid.index[0]
d = start
while d <= prop_ma7_valid.index[-1]:
    # each cycle of the while loop executes a step

    # get state level compartments
    x_hat_state_k, beta_k = get_predicts_prior(d, seiir)

    # step the state level compartments 7 days forward
    x_hat_state_k1 = step_seiir(x_hat_state_k, constants, beta_k)

    # convert the state level compartments to county level values
    x_hat_k = b * x_hat_state_k
    x_hat_k1 = b * x_hat_state_k1

    # get measurements (only rows 2 and 3 carry data)
    z_k = np.array([[0],
                    [0],
                    [prop_ma7_valid.loc[d]],
                    [case_ma7_all.loc[d]],
                    [0]])

    # predict step
    P = predict_step(x_hat_k1, P, Q, beta_k, constants)
    print('step -----------')
    print('P:')
    print(P)

    # update step
    x_hat_post, P_post = update_step(x_hat_k, x_hat_k1, P, Rn,
                                     rho1, rho2, z_k)
    print('P_post:')
    print(P_post)

    # store estimated values for proportion and case rate
    indexDate = d + datetime.timedelta(days=7)
    prop_est[indexDate] = rho1 * x_hat_post[3, 0]
    case_est[indexDate] = rho2 * x_hat_post[3, 0]
    # x_hat_k1 is already scaled by b above; the original multiplied by b a
    # second time here.
    seiir_pred[indexDate] = x_hat_k1[3, 0]

    # update the P and d
    P = P_post
    d += datetime.timedelta(days=1)

# create pandas series of the estimated case rate
predicted_case = pd.Series(case_est)
predicted_seiir_prior = pd.Series(seiir_pred)

In [4]:
# Plotting constants and variables ----------------
matplotlib.rcParams.update({'font.size': 20})
plt.style.use('seaborn-whitegrid')
purple = '#33016F'
gold = '#9E7A27'
gray = '#797979'
width = 4
%matplotlib qt

tick_end = predicted_seiir_prior.index[-1]

week_interval = pd.date_range(start=start, end=tick_end, freq='W')
week_interval = [x.to_pydatetime().date() for x in week_interval]

In [5]:
# Single plot with all lines

matplotlib.rcParams.update({'font.size': 20})
# Matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and
# 3.8 removed the old names; use whichever this installation provides.
grid_style = ('seaborn-whitegrid'
              if 'seaborn-whitegrid' in plt.style.available
              else 'seaborn-v0_8-whitegrid')
plt.style.use(grid_style)
tick_start = K0
tick_end = predicted_seiir_prior.index[-1]


# left axis: measured case rate and the two forecasts
fig3, ax31 = plt.subplots(1)
plt.sca(ax31)
plt.plot(case_ma7_all.loc[start:d].index, case_ma7_all.loc[start:d], label='Case Positive Rate', c=gold,
         linewidth=width)
plt.plot(predicted_seiir_prior.index, predicted_seiir_prior,
         label='IHME 7-day Forecast', c=gray, linewidth=width)
plt.plot(predicted_case.index, predicted_case, label='Our 7-Day Forecast',
         c=purple, linewidth=width)
#plt.ylim(-2, 72)
plt.xticks(week_interval, rotation=30, ha='right', rotation_mode='anchor')
# plt.xlabel('Date')
plt.ylabel('Number of Cases per Day')
plt.legend(loc='upper left')

# right axis: smoothed count of positive symptom responses
ax32 = ax31.twinx()
plt.sca(ax32)
plt.plot(prop_ma7.index, num_stl_ma7, c='red', label='Facebook Positive Symptoms',
         linewidth=width)
#plt.ylim(-.065, 2.5)
plt.grid(axis='y', linestyle=':')
plt.xticks(week_interval, rotation=30, ha='right', rotation_mode='anchor')

plt.ylabel('Number of Positive Symptom Response per Day')
plt.legend(loc='upper right')

plt.show()

In [None]:
# plot validation data

fig_v, ax_v = plt.subplots(1)
plt.sca(ax_v)

# The original sliced with `end`, which is not defined at notebook level
# (it only exists inside get_predict_measures), so the cell raised a
# NameError.  `d` is the loop variable left by the run cell and is what the
# other plotting cells use as the right-hand bound.
plt.plot(case_ma7_all.loc[start:d].index, case_ma7_all.loc[start:d], label='Case Positive Rate', c=gold,
         linewidth=width)
plt.plot(predicted_seiir_prior.index, predicted_seiir_prior,
         label='IHME 7-day Forecast', c=gray, linewidth=width)
plt.plot(predicted_case.index, predicted_case, label='Our 7-Day Forecast',
         c=purple, linewidth=width)
#plt.ylim(-2, 72)
plt.xticks(week_interval, rotation=30, ha='right', rotation_mode='anchor')
# plt.xlabel('Date')
plt.ylabel('Number of Cases per Day')
plt.legend(loc='upper left')

# right axis: validation symptom proportion shown as a percentage
ax32 = ax_v.twinx()
plt.sca(ax32)
plt.plot(prop_ma7_valid.index, 100*prop_ma7_valid, c='red', label='Facebook Symptom Rate',
         linewidth=width)
plt.ylim(-.065, 2.5)
plt.grid(axis='y', linestyle=':')
plt.xticks(week_interval, rotation=30, ha='right', rotation_mode='anchor')

plt.ylabel('Percentage of Positive Symptom Response')
plt.legend(loc='upper right')

plt.show()

In [None]:
# compute squared error for the overlapping time period
overlap = pd.date_range(start='2020-04-19', end='2020-07-07')
overlap_dt = [stamp.to_pydatetime().date() for stamp in overlap]

# error between the measured case rate (smoothed) and the
# SEIIR model scaled to the county (without any Kalman adjustment)
seiir_sq_err = sum((case_ma7_all[day] - predicted_seiir_prior[day])**2
                   for day in overlap_dt)

# error between the measured case rate (smoothed) and the predicted
# output of the Kalman filter
kalman_sq_err = sum((predicted_case[day] - case_ma7_all[day])**2
                    for day in overlap_dt)

print('SSE between seiir forecast and case rate:', seiir_sq_err)
print('SSE between kalman forecast and case rate:', kalman_sq_err)


In [None]:
# plot findings -- multiple plots
matplotlib.rcParams.update({'font.size': 20})
# Matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and
# 3.8 removed the old names; use whichever this installation provides.
grid_style = ('seaborn-whitegrid'
              if 'seaborn-whitegrid' in plt.style.available
              else 'seaborn-v0_8-whitegrid')
plt.style.use(grid_style)
tick_start = K0
tick_end = predicted_seiir_prior.index[-1]

week_interval = pd.date_range(start=tick_start, end=tick_end, freq='W')
week_interval = [x.to_pydatetime().date() for x in week_interval]

width = 4


def _plot_case_axis(ax, with_kalman):
    """Draw the case rate and IHME forecast on ``ax``; optionally add ours."""
    plt.sca(ax)
    plt.plot(case_ma7_all.loc[start:d].index, case_ma7_all.loc[start:d],
             label='Case Positive Rate', c=gold, linewidth=width)
    plt.plot(predicted_seiir_prior.index, predicted_seiir_prior,
             label='IHME 7-day Forecast', c=gray, linewidth=width)
    if with_kalman:
        plt.plot(predicted_case.index, predicted_case,
                 label='Our 7-Day Forecast', c=purple, linewidth=width)
    plt.ylim(-2, 72)
    plt.xticks(week_interval, rotation=30, ha='right', rotation_mode='anchor')
    # plt.xlabel('Date')
    plt.ylabel('Number of Cases per Day')
    plt.legend(loc='upper left')


# figure 1: case rate against the IHME forecast only
fig1, ax11 = plt.subplots(1)
_plot_case_axis(ax11, with_kalman=False)

# figure 2: adds the Kalman forecast
fig2, ax21 = plt.subplots(1)
_plot_case_axis(ax21, with_kalman=True)

# figure 3: same as figure 2 plus the symptom rate on a second y axis
fig3, ax31 = plt.subplots(1)
_plot_case_axis(ax31, with_kalman=True)

ax32 = ax31.twinx()
plt.sca(ax32)
plt.plot(prop_ma7.index, 100*prop_ma7, c='red', label='Facebook Symptom Rate',
         linewidth=width)
plt.ylim(-.065, 2.5)
plt.grid(axis='y', linestyle=':')
plt.xticks(week_interval, rotation=30, ha='right', rotation_mode='anchor')

plt.ylabel('Percentage of Positive Symptom Response')
plt.legend(loc='upper right')


plt.show()