# Predicting equine West Nile Virus (WNV) cases by county using a graph LSTM neural network (GLSTM) for binary classification

This tutorial walks through predicting equine West Nile virus (WNV) cases by county with a graph LSTM (GLSTM) neural network that exploits both spatial and temporal dependencies. The model combines graph neural network layers with LSTM layers to capture patterns in WNV transmission and outputs a binary prediction of equine case presence for each county and week. By incorporating county-level features and historical data, the framework aims to identify regions at high risk for equine WNV cases and so support more targeted intervention and prevention. To run this notebook, select the `grwg_2024_env` Jupyter kernel.

**Primary Libraries/Packages**:

| Name               | Description                                                                                                      | Link                                                                                      |
|--------------------|------------------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------|
| `pandas`           | Data manipulation and analysis library for Python.                                                              | [pandas](https://pandas.pydata.org/)                                                     |
| `geopandas`        | Extends `pandas` to handle geographic data and spatial operations.                                              | [geopandas](https://geopandas.org/)                                                      |
| `numpy`            | Fundamental library for numerical computation in Python.                                                        | [numpy](https://numpy.org/)                                                              |
| `torch`            | Deep learning framework by PyTorch for tensor computation and neural network models.                            | [torch](https://pytorch.org/)                                                            |
| `matplotlib`       | Comprehensive library for creating static, animated, and interactive visualizations in Python.                  | [matplotlib](https://matplotlib.org/)                                                    |
| `torchmetrics`     | Library of metrics for evaluating PyTorch models, supporting many tasks like classification and regression.      | [torchmetrics](https://torchmetrics.readthedocs.io/)                                     |
| `networkx`         | Library for creating, analyzing, and visualizing complex networks and graphs.                                   | [networkx](https://networkx.github.io/)                                                  |
| `libpysal`         | Spatial analysis library in Python, used for spatial statistics and econometrics.                               | [libpysal](https://pysal.org/libpysal/)                                                  |
| `scikit-learn`     | Comprehensive library for machine learning, with tools for classification, regression, clustering, and more.    | [scikit-learn](https://scikit-learn.org/)                                                |
| `torch-geometric`  | Extension library for PyTorch with support for deep learning on graphs and other irregular structures.          | [torch-geometric](https://pytorch-geometric.readthedocs.io/)                             |
| `captum`           | Model interpretability library for PyTorch, providing tools to understand and interpret model predictions.      | [captum](https://captum.ai/)                                                             |
| `imbalanced-learn` | Tools for handling imbalanced datasets, such as oversampling and undersampling techniques.                      | [imbalanced-learn](https://imbalanced-learn.org/)                                        |
| `seaborn`          | Statistical data visualization library based on `matplotlib`, providing a high-level interface for drawing plots.| [seaborn](https://seaborn.pydata.org/)                                                   |
| `pytorch-lightning`| Lightweight wrapper for PyTorch that simplifies model training and accelerates development.                     | [pytorch-lightning](https://www.pytorchlightning.ai/)                                    |


*Terminology*:

* Binary Classification: A type of classification where the model predicts one of two possible classes, often represented as "0" and "1". In our case, 0 means disease absence and 1 means disease presence. Data are aggregated by county and week for a subset of US states (KS, OK, TX, LA, AR, MS), and one prediction is made per county per week.

* Graph Neural Network (GNN): A neural network architecture designed to operate on graph-structured data. It captures relationships between entities, such as spatial connections. 

* LSTM (Long Short-Term Memory): A type of recurrent neural network (RNN) that can capture long-term dependencies in sequential data, useful for time series predictions.

* GLSTM (Graph LSTM): A combination of GNNs and LSTMs, where spatial dependencies (modeled by GNNs) and temporal dependencies (modeled by LSTMs) are integrated to handle both spatial and temporal data.
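
To make the GNN half of this definition concrete, here is a minimal NumPy sketch of one GCN-style graph-convolution step (symmetrically normalized adjacency with self-loops). The three-county graph, the feature values, and the identity weights are made up for illustration, and the `GLSTM4` model used later may implement its layers differently.

```python
import numpy as np

def gcn_layer(adj, features, weight):
    """One GCN-style propagation step: ReLU(D^-1/2 (A + I) D^-1/2 X W)."""
    a_hat = adj + np.eye(adj.shape[0])       # add self-loops
    deg = a_hat.sum(axis=1)                  # node degree, self-loop included
    d_inv_sqrt = 1.0 / np.sqrt(deg)
    norm_adj = a_hat * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]
    return np.maximum(norm_adj @ features @ weight, 0.0)

# Hypothetical toy graph: county 0 borders county 1, county 1 borders county 2.
adj = np.array([[0, 1, 0],
                [1, 0, 1],
                [0, 1, 0]], dtype=float)
features = np.array([[1.0, 0.0],
                     [0.0, 1.0],
                     [1.0, 1.0]])
weight = np.eye(2)  # identity weights keep the arithmetic easy to follow

hidden = gcn_layer(adj, features, weight)
print(hidden.shape)  # (3, 2)
```

Each row of `hidden` mixes a county's own features with those of its neighbors. In a GLSTM, embeddings like these for consecutive weeks would then be passed through an LSTM to capture the temporal part.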

*Tutorial Outline*:
* 1\. **[Exploring the dataset and problem space](#1.-Exploring-the-dataset-and-problem-space)**
* 2\. **[Training the GLSTM model to predict disease presence/absence](#2.-Training-the-GLSTM-model-to-predict-disease-presence/absence)**
* 3\. **[Evaluating training performance and visualizing predictions](#3.-Evaluating-training-performance-and-visualizing-predictions)**
* 4\. **[Mapping Predictions vs. Reported vs. Differences](#4.-Mapping-Predictions-vs.-Reported-vs.-Differences)**

# 0. Preliminaries

In [None]:
# Import necessary libraries
import torch
import pandas as pd
import matplotlib.pyplot as plt
import geopandas as gpd
import numpy as np
from sklearn.metrics import precision_score, recall_score, f1_score
import torch.nn as nn
from GLSTM_models import GLSTM4
# prepare_for_mapping and map_results4 are used in Section 4; they are assumed here to live in GLSTM_utils
from GLSTM_utils import (
    merge_data, align, get_neighbors, resample_and_order, split_normalize_format,
    plot_loss_accuracy, get_predictions, match_predictions, conf_mat,
    plot_gradient_importance, plot_roc_curve, map_results, map_results4,
    prepare_for_mapping, keep_most_frequent, calculate_feature_importance,
    plot_feature_importance,
)
from GLSTM_training import train2, evaluate

# Data and Model Variables:

# a. Habitat Heterogeneity
All data was collected at a 1x1 km spatial resolution. Source: [Tuanmu, M.-N., and W. Jetz. (2015)](https://doi.org/10.1111/geb.12365)

| **Description**                                                                                 |
|-------------------------------------------------------------------------------------------------|
| Coefficient of variation (cv)/ Normalized dispersion of EVI                                   |
| Evenness of EVI                                                                                |
| Range of EVI                                                                                   |
| Shannon / Diversity of EVI                                                                     |
| Simpson / Diversity of EVI                                                                      |
| Standard Deviation(std) / Dispersion of EVI                                                    |
| Contrast / Exponentially weighted difference in EVI between adjacent pixels                   |
| Correlation / Linear dependency of EVI on adjacent pixels                                     |
| Dissimilarity / Difference in EVI between adjacent pixels                                      |
| Entropy / Disorderliness of EVI                                                                  |
| Homogeneity / Similarity of EVI between adjacent pixels                                        |
| Maximum / Dominance of EVI combinations between adjacent pixels                                 |
| Uniformity / Orderliness of EVI                                                                |
| Variance / Dispersion of EVI combinations between adjacent pixels                               |


# b. Land Cover Data
All data was extracted at a 1x1 km spatial resolution. Source: [Tuanmu and Jetz (2014)](https://doi.org/10.1111/geb.12182)

| **Class Number** | **Description**                                     |
|-------------------|-----------------------------------------------------|
| 1                 | Evergreen/Deciduous Needleleaf Trees               |
| 2                 | Evergreen Broadleaf Trees                            |
| 3                 | Deciduous Broadleaf Trees                            |
| 4                 | Mixed/Other Trees                                   |
| 5                 | Shrubs                                              |
| 6                 | Herbaceous Vegetation                               |
| 7                 | Cultivated and Managed Vegetation                   |
| 8                 | Regularly Flooded Vegetation                        |
| 9                 | Urban/Built-up                                      |
| 10                | Snow/Ice                                           |
| 11                | Barren                                             |
| 12                | Open Water                                         |

# c. Topographic Data
All data was extracted at a 1x1 km spatial resolution. Source: [Amatulli et al. (2018)](https://doi.org/10.1038/sdata.2018.40)

| **Description**                  |
|----------------------------------|
| Aspect Cosine                    |
| Aspect Sine                      |
| Elevation                        |
| Profile Curvature (pcurv)       |
| Roughness                        |
| Slope                            |
| Tangential Curvature (tcurv)    |
| Topographic Position Index (tpi) |
| Terrain Ruggedness Index (tri)   |
| Vector Ruggedness Measure (vrm)  |

# d. Other Features and Target Variable

| **Description**                                                             | **Source**                                                                                                                                                 | **Resolution**         | **Static/Dynamic** |
|-----------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------|------------------------|---------------------|
| **Target**                                                                  |                                                                                                                                                           |                        |                     |
| Equine West Nile Virus (WNV) Incidence Reports                             | [ArboNET](https://www.cdc.gov/mosquitoes/php/arbonet/index.html)                                                                                       | Weekly                 | Dynamic             |
| **Features**                                                                |                                                                                                                                                           |                        |                     |
| Bird and Mosquito West Nile Virus Surveillance Reports                       | [ArboNET](https://www.cdc.gov/mosquitoes/php/arbonet/index.html)                                                                                       | Weekly                 | Dynamic             |
| Temperature                                                                 | [PRISM Climate Group](https://prism.oregonstate.edu)                                                                                                    | Monthly                | Dynamic             |
| Precipitation Data                                                           | [PRISM Climate Group](https://prism.oregonstate.edu)                                                                                                    | Monthly                | Dynamic             |
| Drought Data                                                                | [NOAA](https://www.ncdc.noaa.gov/)                                                                                                                      | Monthly                | Dynamic             |
| Normalized Difference Vegetation Index (NDVI)                              | [MODIS/Terra Vegetation Indices](https://doi.org/10.1038/sdata.2018.227)                                                                               | Annual, 1x1 km        | Dynamic             |
| County-Level Human Population Estimate                                       | [U.S. Census Bureau](https://www.census.gov/)                                                                                                           | 2020                   | Static              |
| Bird Species Richness                                                       | [Humphreys et al. (2021)](https://doi.org/10.3390/v13091811), [eBird Database](https://doi.org/10.1016/j.biocon.2013.11.003)                           | Weekly                 | Dynamic             |
| County-level Horse Counts                                                   | [Humphreys et al. (2021)](https://doi.org/10.3390/v13091811), [Gridded Livestock of the World](https://doi.org/10.1038/sdata.2018.227)                 | Time-invariant         | Static              |


### 1. Exploring the dataset and problem space

In [None]:
#read in the data
reduced_features = pd.read_csv("reduced_features.csv")
south_counties = gpd.read_file("south_counties.shp")
earthenv = pd.read_csv("earthenv_2024-09-03.csv")
habitat_hetero = pd.read_csv("habitat_hetero_2024-09-03.csv")
topographic = pd.read_csv("topographic_2024-09-03.csv")
ndvi = pd.read_csv("ndvi_2024-09-05.csv")

In [None]:
#merge all tabular data into a single dataframe using the merge_data function from GLSTM_utils.py
all_merge = merge_data(earthenv, topographic, habitat_hetero, reduced_features, ndvi)

In [None]:
#Merge the tabular data with the shapefile, standardize column names, and drop unnecessary columns
df = align(all_merge, south_counties)

In [None]:
#Generates the adjacency matrix from the shapefile
adj_matrix = get_neighbors(south_counties)
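
Graph libraries such as PyTorch Geometric typically expect connectivity as a `2 x E` list of edge endpoints rather than a dense matrix. The sketch below shows that conversion with NumPy, assuming the adjacency matrix is a dense, symmetric 0/1 array; the return format of `get_neighbors` is not shown in this notebook, and `adjacency_to_edge_index` is a hypothetical helper.

```python
import numpy as np

def adjacency_to_edge_index(adj):
    """Return a (2, E) array of [source, target] index pairs for non-zero entries."""
    src, dst = np.nonzero(adj)
    return np.vstack([src, dst])

# Hypothetical toy adjacency: three counties in a row (0-1-2).
toy_adj = np.array([[0, 1, 0],
                    [1, 0, 1],
                    [0, 1, 0]])

# Shared borders are mutual, so the matrix should be symmetric.
assert np.array_equal(toy_adj, toy_adj.T)

edge_index = adjacency_to_edge_index(toy_adj)
print(edge_index)
```

Each undirected border appears twice (once per direction), which is the usual convention for message passing on undirected graphs.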

In [None]:
#2798 samples with WNV Presence out of 319302 total samples. 
df['Binary'].sum()
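
The counts in the comment above imply a severe class imbalance. A short calculation, using only those two numbers, shows why plain accuracy is a misleading metric here and why the classes are rebalanced before training:

```python
n_pos = 2798        # county-weeks with WNV presence (from the comment above)
n_total = 319302    # all county-weeks

pos_rate = n_pos / n_total
neg_per_pos = (n_total - n_pos) / n_pos
majority_acc = 1 - pos_rate   # accuracy of a model that always predicts 0

print(f"Positive rate: {pos_rate:.2%}")                    # 0.88%
print(f"Negatives per positive: {neg_per_pos:.1f}")        # 113.1
print(f"Always-predict-0 accuracy: {majority_acc:.2%}")    # 99.12%
```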

In [None]:
#Performs oversampling to balance the classes, ensures temporal order, and provides a DF used to map the results
ordered, resampled_indices, to_match = resample_and_order(df, resample = True)
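
The internals of `resample_and_order` are not shown here. As an illustration of the general idea, the sketch below performs simple random oversampling with NumPy: minority-class rows are repeated (sampled with replacement) until both classes have the same count, and the resulting indices are sorted so the original row order is preserved. The function name and toy labels are hypothetical, and the helper may use a different oversampling method.

```python
import numpy as np

def random_oversample(labels, seed=0):
    """Return sorted row indices in which minority rows are repeated until classes balance."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    pos_idx = np.flatnonzero(labels == 1)
    neg_idx = np.flatnonzero(labels == 0)
    if len(pos_idx) < len(neg_idx):
        minority, majority = pos_idx, neg_idx
    else:
        minority, majority = neg_idx, pos_idx
    extra = rng.choice(minority, size=len(majority) - len(minority), replace=True)
    # Sorting keeps duplicated rows next to their originals and preserves row order.
    return np.sort(np.concatenate([majority, minority, extra]))

toy_labels = np.array([0, 0, 0, 0, 0, 0, 0, 1, 0, 1])  # 8 negatives, 2 positives
idx = random_oversample(toy_labels)
balanced = toy_labels[idx]
print(int(balanced.sum()), int(len(balanced) - balanced.sum()))  # 8 8
```

Because oversampling duplicates rows, the same county-week can appear more than once downstream, which is worth keeping in mind when predictions are later matched back to counties.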

In [None]:
# Validation set: 2018 and 2019
val_data = ordered[ordered['Year'].isin([2018, 2019])]

# Training set: every year except the validation (2018, 2019) and test (2012) years.
# Year is compared as an integer here, matching the other filters, so the exclusion actually applies.
train_data = ordered[~ordered['Year'].isin([2012, 2018, 2019])]

# Test set: 2012
test_data = ordered[ordered['Year'] ==2012]
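
A year-based split is easy to get subtly wrong, for example when the values passed to `isin` have a different dtype than the `Year` column. The sketch below, on a made-up toy frame, shows a quick sanity check that the three splits share no years and together cover every row.

```python
import pandas as pd

# Hypothetical toy data with integer years.
toy = pd.DataFrame({"Year": [2011, 2012, 2013, 2018, 2019, 2020],
                    "Binary": [0, 1, 0, 0, 1, 0]})

val_years, test_years = [2018, 2019], [2012]
val = toy[toy["Year"].isin(val_years)]
test = toy[toy["Year"].isin(test_years)]
train = toy[~toy["Year"].isin(val_years + test_years)]

# No year may appear in more than one split, and no row may be lost.
year_sets = [set(part["Year"]) for part in (train, val, test)]
assert year_sets[0].isdisjoint(year_sets[1])
assert year_sets[0].isdisjoint(year_sets[2])
assert year_sets[1].isdisjoint(year_sets[2])
assert len(train) + len(val) + len(test) == len(toy)

print(sorted(year_sets[0]))
```

The same assertions can be run on `train_data`, `val_data`, and `test_data` to confirm there is no leakage between splits.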

In [None]:
#Normalizes the training subset and builds the PyTorch Geometric data objects used as model inputs
#split=False skips the internal train/test split because the data were already split by year above
data_train, X_train_normalized, node_id_train = split_normalize_format(train_data, resampled_indices, adj_matrix, split=False)

In [None]:
#Normalizes the validation subset and builds the PyTorch Geometric data objects used as model inputs
#split=False skips the internal train/test split because the data were already split by year above
data_val, X_val_normalized, node_id_val = split_normalize_format(val_data, resampled_indices, adj_matrix, split=False)

In [None]:
#Normalizes the test subset and builds the PyTorch Geometric data objects used as model inputs
#split=False skips the internal train/test split because the data were already split by year above
data_test, X_test_normalized, node_id_test = split_normalize_format(test_data, resampled_indices, adj_matrix, split=False)
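
How `split_normalize_format` normalizes features is not shown in this notebook. One common approach, sketched below with NumPy under that caveat, is z-score standardization in which the mean and standard deviation are computed on the training set only and then reused for the validation and test sets. The function names and toy arrays are hypothetical.

```python
import numpy as np

def fit_standardizer(x_train):
    """Compute per-feature mean and std on the training data only."""
    mean = x_train.mean(axis=0)
    std = x_train.std(axis=0)
    std[std == 0] = 1.0   # avoid division by zero for constant columns
    return mean, std

def standardize(x, mean, std):
    """Apply previously fitted statistics to any split."""
    return (x - mean) / std

x_train = np.array([[1.0, 10.0],
                    [3.0, 10.0],
                    [5.0, 10.0]])
x_val = np.array([[3.0, 12.0]])

mean, std = fit_standardizer(x_train)
train_z = standardize(x_train, mean, std)
val_z = standardize(x_val, mean, std)
print(val_z)
```

Reusing the training statistics keeps information from the held-out years out of the preprocessing step.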

### 2. Training the GLSTM model to predict disease presence/absence

In [None]:
# Instantiate the model
input_dim = X_train_normalized.shape[1]
#input_dim = 54
hidden_dim1 = 32
dropout = 0.2
activation_function = torch.relu
hidden_dim2 = 32
hidden_dim3 = 32
hidden_dim4 = 32
hidden_dim5 = 32
output_dim = 2  # two output classes: absence (0) and presence (1)
model = GLSTM4(input_dim, hidden_dim1, hidden_dim2, hidden_dim3, hidden_dim4, hidden_dim5, output_dim, dropout, activation_function)

In [None]:
# Set up training parameters, optimizer, and criterion
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
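
`nn.CrossEntropyLoss` takes raw logits (one per class) and integer class labels, applies a log-softmax internally, and averages the negative log-likelihood over samples by default. The NumPy sketch below reproduces that arithmetic on two made-up samples so the numbers are easy to check by hand.

```python
import numpy as np

def cross_entropy(logits, targets):
    """Mean softmax cross-entropy for integer class targets."""
    shifted = logits - logits.max(axis=1, keepdims=True)   # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets].mean()

# Two samples, two classes (absence = 0, presence = 1); values are illustrative.
logits = np.array([[2.0, 0.0],    # confident and correct for class 0
                   [0.0, 0.0]])   # undecided; true class is 1
targets = np.array([0, 1])

loss = cross_entropy(logits, targets)
print(f"{loss:.4f}")  # 0.4100
```

The first sample contributes about 0.127 and the second contributes ln 2 (about 0.693), so confident correct predictions are penalized far less than undecided ones.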

In [None]:
%%time
# Train loop
num_epochs = 100
#store histories
train_loss_history = []
train_acc_history = []
val_loss_history = []
val_acc_history = []
all_gradients =[]
for epoch in range(num_epochs):
    #train
    train_loss, train_acc, gradients, param_names = train2(model, data_train, criterion, optimizer)
    
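    #save gradients
    all_gradients.append(gradients)
    # Evaluate on the held-out 2012 test set after every epoch
    test_loss, test_accuracy = evaluate(model, data_test, criterion, optimizer)

    #save training and test metrics (the val_* lists below hold these test-set values)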
    train_loss_history.append(train_loss)
    train_acc_history.append(train_acc)
    val_loss_history.append(test_loss)
    val_acc_history.append(test_accuracy)

    #print evaluation metrics
    print(f'Epoch {epoch+1}/{num_epochs}, TestLoss: {test_loss:.4f}, TestAccuracy: {test_accuracy:.4f}, TrainLoss: {train_loss:.4f}, TrainAccuracy: {train_acc:.4f}')

In [None]:
name_mapping_features = {
    'AVI': 'Avian WNV',
    'MOS': 'Mosquito WNV',
    'SEN': 'Chicken WNV',
    'PPT': 'Precipitation',
    'Temp': 'Temperature',
    'None': 'No Drought',
    'D0': 'Drought 0',
    'D1': 'Drought 1',
    'D2': 'Drought 2',
    'D3': 'Drought 3',
    'D4': 'Drought 4',
    'Richness': 'Bird Richness',
    'Horses': 'Horse Count',
    'class_1': 'Needleleaf Trees',
    'class_10': 'Snow/Ice',
    'class_11': 'Barren',
    'class_12': 'Open Water',
    'class_2': 'Evergreen Trees',
    'class_3': 'Deciduous Trees',
    'class_4': 'Mixed Trees',
    'class_5': 'Shrubs',
    'class_6': 'Herbaceous',
    'class_7': 'Cultivated Vegetation',
    'class_8': 'Flooded Vegetation',
    'class_9': 'Urban',
    'Contrast': 'EVI Contrast',
    'Correlation': 'EVI Correlation',
    'cv': 'EVI Dispersion',
    'Dissimilarity': 'EVI Dissimilarity',
    'Entropy': 'EVI Entropy',
    'evenness': 'EVI Evenness',
    'Homogeneity': 'EVI Homogeneity',
    'Maximum': 'EVI Dominance',
    'range': 'EVI Range',
    'shannon': 'Shannon Index',
    'simpson': 'Simpson Index',
    'std': 'EVI Std Dev',
    'Uniformity': 'EVI Uniformity',
    'Variance': 'EVI Variance',
    'aspectcosine': 'Eastness',
    'aspectsine': 'Northness',
    'elevation': 'Elevation',
    'pcurv': 'Profile Curve',
    'roughness': 'Roughness',
    'slope': 'Slope',
    'tcurv': 'Tangential Curve',
    'tpi': 'Topo Position',
    'tri': 'Terrain Roughness',
    'vrm': 'Vector Ruggedness',
    'Value': 'NDVI'
}


In [None]:
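# Calculate feature importance
# NOTE: feature_names is not defined earlier in this notebook; it presumably needs to hold
# the feature column names in the same order as the columns of X_train_normalized.
importance_dict = calculate_feature_importance(model, data_train, 0, feature_names, device='cpu')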

plot_feature_importance(importance_dict, threshold=5, name_mapping=name_mapping_features)

### 3. Evaluating training performance and visualizing predictions

In [None]:
plot_loss_accuracy(train_loss_history, train_acc_history, val_loss_history, val_acc_history)

In [None]:
y_true, y_pred = get_predictions(model, data_val)

In [None]:
conf_mat(y_true, y_pred)

In [None]:
precision = precision_score(y_true, y_pred, average='binary')
recall = recall_score(y_true, y_pred, average='binary')
f1 = f1_score(y_true, y_pred, average='binary')

print(f'Precision: {precision:.4f}')
print(f'Recall: {recall:.4f}')
print(f'F1 Score: {f1:.4f}')
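
For reference, the three scores above can be derived directly from confusion-matrix counts. The pure-Python sketch below does so on a small made-up example; with `average='binary'`, scikit-learn reports these values for the positive class (label 1).

```python
def precision_recall_f1(y_true, y_pred):
    """Precision, recall, and F1 for the positive class (label 1)."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if (precision + recall) else 0.0)
    return precision, recall, f1

# Hypothetical labels: 3 true presences, 5 true absences.
toy_true = [1, 1, 1, 0, 0, 0, 0, 0]
toy_pred = [1, 1, 0, 1, 0, 0, 0, 0]

p, r, f = precision_recall_f1(toy_true, toy_pred)
print(f"{p:.3f} {r:.3f} {f:.3f}")  # 0.667 0.667 0.667
```

Precision answers "of the county-weeks flagged as presence, how many were real?", while recall answers "of the real presences, how many were flagged?". Both matter more than accuracy when positives are rare.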

In [None]:
plot_roc_curve(y_true, y_pred)
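
An ROC curve is most informative when it is built from continuous scores (such as the predicted probability of class 1) rather than hard 0/1 labels, because hard labels give only a single operating point. Whether `y_pred` holds labels or scores depends on `get_predictions`, which is not shown here. The NumPy sketch below, on made-up values, computes the area under the ROC curve as the probability that a random positive outranks a random negative, and shows how thresholding can lose ranking information.

```python
import numpy as np

def roc_auc(y_true, scores):
    """AUC as P(score_pos > score_neg) + 0.5 * P(tie), over all positive/negative pairs."""
    y_true = np.asarray(y_true)
    scores = np.asarray(scores, dtype=float)
    pos = scores[y_true == 1]
    neg = scores[y_true == 0]
    diff = pos[:, None] - neg[None, :]
    return ((diff > 0).sum() + 0.5 * (diff == 0).sum()) / diff.size

toy_true = [0, 0, 1, 1]
toy_probs = [0.10, 0.40, 0.45, 0.80]           # hypothetical class-1 probabilities
toy_hard = [int(s >= 0.5) for s in toy_probs]  # thresholded at 0.5 -> [0, 0, 0, 1]

auc_probs = roc_auc(toy_true, toy_probs)
auc_hard = roc_auc(toy_true, toy_hard)
print(auc_probs, auc_hard)  # 1.0 0.75
```

The probabilities rank every positive above every negative (AUC 1.0), but after thresholding one positive is tied with the negatives and the AUC drops to 0.75.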

In [None]:
name_mapping_grads = {
    'conv1.bias': 'Convolution 1 Bias',
    'conv1.lin.weight': 'Convolution 1 Weights',
    'lstm1.weight_ih_l0': 'LSTM 1 Input Weights',
    'lstm1.weight_hh_l0': 'LSTM 1 Hidden Weights',
    'lstm1.bias_ih_l0': 'LSTM 1 Input Bias',
    'lstm1.bias_hh_l0': 'LSTM 1 Hidden Bias',
    'conv2.bias': 'Convolution 2 Bias',
    'conv2.lin.weight': 'Convolution 2 Weights',
    'conv3.bias': 'Convolution 3 Bias',
    'conv3.lin.weight': 'Convolution 3 Weights',
    'conv4.bias': 'Convolution 4 Bias',
    'conv4.lin.weight': 'Convolution 4 Weights',
    'conv5.bias': 'Convolution 5 Bias',
    'conv5.lin.weight': 'Convolution 5 Weights',
    'fc.weight': 'Fully Connected Layer Weights',
    'fc.bias': 'Fully Connected Layer Bias'
}


In [None]:
plot_gradient_importance(param_names, gradients, name_mapping_grads)

In [None]:
preds = match_predictions(model, data_test, node_id_test, to_match, south_counties)

In [None]:
df_unique = keep_most_frequent(preds, ['Year', 'Week', 'node_id'], 'ClassLabel')
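
Because oversampling duplicates rows, the same year/week/county combination can appear several times in the predictions, possibly with different predicted labels. The implementation of `keep_most_frequent` is not shown; the pandas sketch below illustrates one plausible way to collapse such duplicates to the most common label per key. The function name and toy frame are hypothetical.

```python
import pandas as pd

def most_frequent_label(df, keys, label_col):
    """Collapse duplicate key rows to the most common label (ties go to the smallest label)."""
    return (df.groupby(keys)[label_col]
              .agg(lambda s: s.mode().iloc[0])
              .reset_index())

# Hypothetical predictions: county 7 appears three times in week 30 and twice in week 31.
toy_preds = pd.DataFrame({
    "Year": [2012, 2012, 2012, 2012, 2012],
    "Week": [30, 30, 30, 31, 31],
    "node_id": [7, 7, 7, 7, 7],
    "ClassLabel": [1, 1, 0, 0, 0],
})

toy_unique = most_frequent_label(toy_preds, ["Year", "Week", "node_id"], "ClassLabel")
print(toy_unique)
```

The result has exactly one row per county-week, which is what a map of predictions needs.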

### 4. Mapping Predictions vs. Reported vs. Differences

In [None]:
states = gpd.read_file('cb_2018_us_state_500k.shp')
south_states = states[states['NAME'].isin(['Mississippi', 'Arkansas', 'Louisiana', 'Texas', 'Kansas', 'Oklahoma'])]

In [None]:
results_agg_gdf, states_gdf = prepare_for_mapping(df_unique, south_counties, south_states)

In [None]:
# Check that both GeoDataFrames share the same CRS before overlaying them
print(f"Base GeoDataFrame CRS: {results_agg_gdf.crs}")
print(f"Overlay GeoDataFrame CRS (south states): {south_states.crs}")

In [None]:
map_results4(2012, south_counties, results_agg_gdf, south_states)

In [None]:
#get 2018, 2019 predictions
preds = match_predictions(model, data_val, node_id_val, to_match, south_counties, [2018,2019])

In [None]:
df_unique = keep_most_frequent(preds, ['Year', 'Week', 'node_id'], 'ClassLabel')

In [None]:
results_agg_gdf, states_gdf = prepare_for_mapping(df_unique, south_counties, south_states)

In [None]:
map_results4([2018,2019], south_counties, results_agg_gdf, south_states)