In [1]:
# from google.colab import drive
# drive.mount('/content/drive')

In [2]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

# LSTM Networks to predict Dengue cases

Several departments in the U.S. Federal Government (Department of Health and Human Services, Department of Defense, Department of Commerce, and the Department of Homeland Security) have joined together, with the support of the Pandemic Prediction and Forecasting Science and Technology Interagency Working Group under the National Science and Technology Council, to design an infectious disease forecasting project with the aim of galvanizing efforts to predict epidemics of dengue.

As a result of this collaboration, the Epidemic Prediction Initiative has gathered a dataset that consists of dengue case counts for San Juan, Puerto Rico and Iquitos, Peru. 


## Data

The data consists in case counts between 1990-2009 for San Juan, Puerto Rico and 2000-2009 for Iquitos, Peru. 
For both datasets we have used a 70/30 split to be used for training and testing respectively.

### San Juan Dataset

In [3]:
sanjuan_dataset = pd.read_csv("dengueData/SanJuan/san_juan_training_data.csv")
print(sanjuan_dataset.head())
print(sanjuan_dataset.tail())

      season  season_week week_start_date  denv1_cases  denv2_cases  \
0  1990/1991            1      1990-04-30            0            0   
1  1990/1991            2      1990-05-07            0            0   
2  1990/1991            3      1990-05-14            0            0   
3  1990/1991            4      1990-05-21            0            0   
4  1990/1991            5      1990-05-28            0            0   

   denv3_cases  denv4_cases  other_positive_cases  additional_cases  \
0            0            0                     4                 0   
1            0            0                     5                 0   
2            0            0                     4                 0   
3            0            0                     3                 0   
4            0            0                     6                 0   

   total_cases  
0            4  
1            5  
2            4  
3            3  
4            6  
        season  season_week week_start_date 

In [None]:
sanjuan_cases = sanjuan_dataset["total_cases"].values
print("San Juan weekly observations", sanjuan_cases.shape)
split_index = int(len(sanjuan_cases)* 0.7)
sanjuan_train = sanjuan_cases[:split_index]
sanjuan_test = sanjuan_cases[split_index:]

print("Training observations: ", sanjuan_train.shape)
print("Testing observations: ", sanjuan_test.shape)

San Juan weekly observations (988,)
Training observations:  (691,)
Testing observations:  (297,)


### Iquitos dataset

In [None]:
iquitos_dataset = pd.read_csv("dengueData/Iquitos/iquitos_training_data.csv")
print(iquitos_dataset.head())
print(iquitos_dataset.tail())

      season  season_week week_start_date  denv1_cases  denv2_cases  \
0  2000/2001            1      2000-07-01            0            0   
1  2000/2001            2      2000-07-08            0            0   
2  2000/2001            3      2000-07-15            0            0   
3  2000/2001            4      2000-07-22            0            0   
4  2000/2001            5      2000-07-29            0            0   

   denv3_cases  denv4_cases  other_positive_cases  total_cases  
0            0            0                     0            0  
1            0            0                     0            0  
2            0            0                     0            0  
3            0            0                     0            0  
4            0            0                     0            0  
        season  season_week week_start_date  denv1_cases  denv2_cases  \
463  2008/2009           48      2009-05-28            0            0   
464  2008/2009           49      2009

In [None]:
iquitos_cases = iquitos_dataset["total_cases"].values
print("Iquitos weekly observations", iquitos_cases.shape)
split_index = int(len(iquitos_cases)* 0.7)
iquitos_train = iquitos_dataset[:split_index]
iquitos_test = iquitos_dataset[split_index:]

print("Training observations: ", iquitos_train.shape)
print("Testing observations: ", iquitos_test.shape)

Iquitos weekly observations (468,)
Training observations:  (327, 9)
Testing observations:  (141, 9)


## Our LSTM Model

In [None]:
from keras.models import Sequential
from keras.layers import LSTM, Dense, Input, Flatten
from keras.layers.merge import concatenate
from keras.constraints import non_neg
from keras.models import Model
from sklearn.metrics import mean_squared_error
from keras.optimizers import Adam
from keras.utils import plot_model

input_layer = Input(shape=(1,4))
b1_out = LSTM(64, return_sequences=False)(input_layer)

b2_out = Dense(32, activation="relu", kernel_regularizer="l2")(input_layer)
b2_out = Flatten()(b2_out)

concatenated = concatenate([b1_out, b2_out])
out = Dense(4, activation="relu", kernel_regularizer="l2")(concatenated)
out = Dense(4, activation="relu", kernel_regularizer="l2")(out)
out = Dense(1, activation="linear", kernel_constraint=non_neg(), name='output_layer')(out)

model = Model([input_layer], out)
model.compile(loss=["mse"], optimizer=Adam(0.0001), metrics=["mae", "mse", "mean_absolute_percentage_error"])

plot_model(model, show_shapes=True, to_file='model.png')

Using TensorFlow backend.


![model.png](attachment:model.png)

### Training our model

Since we use l2 regularization, our model isn't prone to overfitting.

![imagen.png](attachment:imagen.png)

## SARIMA model

We trained a model with a SARIMA model using time series cross-validation for 1-4 week ahead forecasts, as each test observation got insterted for further training, we re-trained the model using R's auto.arima function, which performs a grid search to find the optimal parameters p, d, q, P, D, Q for a SARIMA model.
![TimeSeriesCross1.png](attachment:TimeSeriesCross1.png)
![TimeSeriesCross4.png](attachment:TimeSeriesCross4.png)

## Results

In [None]:

sanjuan_LSTM_MSE = [94.96, 152.43, 228.64, 298.65]
sanjuan_SARIMA_MSE = [108.60, 175.98, 264.92, 397.47]

plt.title("San Juan MSE")
plt.ylabel("MSE")
plt.xlabel("Weeks ahead prediction")
plt.xticks(np.arange(4), range(1,5))
plt.plot(sanjuan_SARIMA_MSE, label="SARIMA")
plt.plot(sanjuan_LSTM_MSE, label="LSTM")
plt.legend()
plt.show()

In [None]:

# sanjuan_LSTM_MSE = [94.96, 152.43, 228.64, 298.65]
# sanjuan_SARIMA_MSE = [108.60, 175.98, 264.92, 397.47]

# improvement = []
# for i in range(len(sanjuan_LSTM_MSE)):
#     improvement.append( (sanjuan_SARIMA_MSE[i] / sanjuan_LSTM_MSE[i] * 100) - 100)
# plt.title("Improvement over SARIMA")
# plt.ylabel("% ")
# plt.xlabel("Weeks ahead prediction")
# plt.xticks(np.arange(4), range(1,5))
# plt.plot(improvement, label="LSTM improvement")
# # plt.plot(sanjuan_LSTM_MSE, label="LSTM")
# plt.legend()
# plt.show()

In [None]:
iquitos_LSTM_MSE = [61.39, 87.23, 101.81, 127.22]
iquitos_SARIMA_MSE = [60.80, 86.16, 108.77, 130.44]

plt.title("Iquitros MSE")
plt.ylabel("MSE")
plt.xlabel("Weeks ahead prediction")
plt.xticks(np.arange(4), range(1,5))
plt.plot(iquitos_SARIMA_MSE, label="SARIMA")
plt.plot(iquitos_LSTM_MSE, label="LSTM")
plt.legend()
plt.show()