# In [None]:
# FILE: train_stan.ipynb

import os
import random
import pickle
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
import networkx as nx
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from epiweeks import Week
from utils import date_today, gravity_law_commute_dist  # Ensure these are defined in utils.py
from model import STAN, CustomGraph  # Import STAN and CustomGraph from model.py

# Reproducibility: a single seed drives every RNG used in this notebook.
RANDOM_SEED = 123


def seed_torch(seed=RANDOM_SEED):
    """Seed Python, NumPy and PyTorch (CPU and every CUDA device) RNGs.

    Also forces cuDNN into deterministic mode and disables its autotuner.
    Note: assigning PYTHONHASHSEED at runtime only affects child processes
    started afterwards; this interpreter's string hashing was fixed at startup.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)
    # Deterministic kernels instead of autotuned (possibly non-deterministic) ones.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_torch()

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
print(f'Using device: {device}')

# NumExpr thread limits.
# NOTE(review): numexpr reads these when it is first imported; if pandas already
# imported numexpr above, these assignments have no effect — consider moving them
# before the imports.
os.environ.update({
    'NUMEXPR_MAX_THREADS': '16',
    'NUMEXPR_NUM_THREADS': '8',
})

# Load the per-state COVID series and attach state-level population statistics.
# NOTE(review): pickle.load can execute arbitrary code; only load this file if it
# was produced locally / comes from a trusted source.
with open('./data/state_covid_data.pickle', 'rb') as fh:  # closes the handle (was leaked)
    raw_data = pickle.load(fh)
# Keep a CSV copy of the loaded data next to the pickle.
raw_data.to_csv('./data/state_covid_data.csv', index=False)

# Collapse uszips.csv (presumably one row per ZIP code — confirm) to one row per
# state: total population, mean density and mean lat/lng.
pop_data = pd.read_csv('./data/uszips.csv')
pop_data = pop_data.groupby('state_name').agg({
    'population': 'sum',
    'density': 'mean',
    'lat': 'mean',
    'lng': 'mean',
}).reset_index()

# Inner join: states in raw_data without a matching state_name are dropped.
raw_data = pd.merge(raw_data, pop_data, how='inner', left_on='state', right_on='state_name')

# Pairwise location weights from the gravity-law commute model:
# loc_dist_map[a][b] = gravity_law_commute_dist(a's lat/lng/pop, b's lat/lng/pop, r=0.5).
loc_list = list(raw_data['state'].unique())

# Look up (lat, lng, population) once per location instead of re-filtering
# raw_data six times for every pair. drop_duplicates keeps each state's first
# row, i.e. the same value the previous `.unique()[0]` lookups returned; the
# .to_numpy() arrays keep NumPy scalar types for the downstream call.
first_rows = raw_data.drop_duplicates(subset='state')
loc_attrs = {
    state: (lat, lng, pop)
    for state, lat, lng, pop in zip(
        first_rows['state'],
        first_rows['lat'].to_numpy(),
        first_rows['lng'].to_numpy(),
        first_rows['population'].to_numpy(),
    )
}

loc_dist_map = {}
for each_loc in loc_list:
    lat1, lng1, pop1 = loc_attrs[each_loc]
    loc_dist_map[each_loc] = {}
    for each_loc2 in loc_list:
        lat2, lng2, pop2 = loc_attrs[each_loc2]
        loc_dist_map[each_loc][each_loc2] = gravity_law_commute_dist(
            lat1, lng1, pop1,
            lat2, lng2, pop2,
            r=0.5
        )

num_locations = len(loc_list)
print(f"Number of unique locations: {num_locations}")

# Heatmap of the pairwise gravity weights (location x location).
# A dict-of-dicts becomes a DataFrame with the outer keys as columns and the
# inner keys as the index; any missing pair is shown as 0.
loc_dist_df = pd.DataFrame(loc_dist_map).fillna(0)

fig, ax = plt.subplots(figsize=(15, 12))
sns.heatmap(loc_dist_df, cmap='viridis', linewidths=.5, ax=ax)
ax.set_title('Location Similarity Based on Gravity Law')
ax.set_xlabel('Location')
ax.set_ylabel('Location')
plt.show()

# Neighbour selection from the gravity weights.
dist_threshold = 18

# Re-order every location's neighbours from strongest to weakest weight.
for src in loc_dist_map:
    loc_dist_map[src] = dict(
        sorted(loc_dist_map[src].items(), key=lambda kv: kv[1], reverse=True)
    )

# Walk neighbours in descending-weight order. A neighbour whose weight exceeds
# dist_threshold is accepted up to rank 3; otherwise only up to rank 1. Stop at
# the first neighbour that does not qualify.
adj_map = {}
for src, neighbours in loc_dist_map.items():
    chosen = []
    for rank, (dst, weight) in enumerate(neighbours.items()):
        max_rank = 3 if weight > dist_threshold else 1
        if rank > max_rank:
            break
        chosen.append(dst)
    adj_map[src] = chosen

# Dense, row-normalised adjacency matrix built from adj_map.
num_states = len(loc_list)

# O(1) state -> index lookup (loc_list holds unique states, so this equals
# loc_list.index without the per-call linear scan).
state_to_index = {state: idx for idx, state in enumerate(loc_list)}

# Edge endpoint indices, one entry per (location, neighbour) pair in adj_map.
edges = [
    (state_to_index[src], state_to_index[dst])
    for src, neighbours in adj_map.items()
    for dst in neighbours
]
rows = [r for r, _ in edges]
cols = [c for _, c in edges]

# Binary adjacency: 1 where connected.
adj_matrix = np.zeros((num_states, num_states), dtype=np.float32)
for r, c in zip(rows, cols):
    adj_matrix[r, c] = 1

# Self-loops: set the diagonal to exactly 1 instead of adding np.eye, so a
# location that adj_map already lists as its own neighbour does not get a
# doubled (2.0) self-weight, keeping the matrix binary before normalisation.
np.fill_diagonal(adj_matrix, 1)

# Row-normalise; every row contains its self-loop, so no row sum is zero.
adj_matrix = adj_matrix / adj_matrix.sum(axis=1, keepdims=True)

adj_matrix = torch.tensor(adj_matrix, dtype=torch.float32).to(device)
print(f"Adjacency Matrix Shape: {adj_matrix.shape}")

# Directed neighbour graph: one node per location in adj_map and an edge
# src -> dst for every dst listed in adj_map[src].
G_nx = nx.DiGraph()
G_nx.add_nodes_from(adj_map)
G_nx.add_edges_from(
    (src, dst) for src, neighbours in adj_map.items() for dst in neighbours
)

# Draw the graph; the seeded spring layout keeps the picture stable across runs.
plt.figure(figsize=(15, 12))
pos = nx.spring_layout(G_nx, seed=RANDOM_SEED)
draw_style = dict(
    with_labels=True,
    node_size=3000,
    node_color='skyblue',
    font_size=10,
    font_weight='bold',
    edge_color='gray',
)
nx.draw(G_nx, pos, **draw_style)
plt.title('Adjacency Map of States Based on Gravity Law')
plt.show()

# Graph wrapper consumed by the STAN model (CustomGraph comes from model.py).
custom_g = CustomGraph(G_nx, device)

# Per-location time series from raw_data, stacked in loc_list order.
# NOTE(review): the np.array(...) calls below assume every location has the same
# number of rows (timesteps) in raw_data; otherwise the arrays become ragged.
active_cases = []
confirmed_cases = []
new_cases = []
death_cases = []
static_feat = []

for each_loc in loc_list:
    # Filter raw_data once per location instead of once per extracted column.
    loc_rows = raw_data[raw_data['state'] == each_loc]
    active_cases.append(loc_rows['active'].values)
    confirmed_cases.append(loc_rows['confirmed'].values)
    new_cases.append(loc_rows['new_cases'].values)
    death_cases.append(loc_rows['deaths'].values)
    static_feat.append(loc_rows[['population', 'density', 'lng', 'lat']].values)

active_cases = np.array(active_cases)
confirmed_cases = np.array(confirmed_cases)
death_cases = np.array(death_cases)
new_cases = np.array(new_cases)
# Static columns are [population, density, lng, lat]; keep each location's first row.
static_feat = np.array(static_feat)[:, 0, :]

# Compartments: R = confirmed - active - deaths, S = population - I - R.
recovered_cases = confirmed_cases - active_cases - death_cases
susceptible_cases = np.expand_dims(static_feat[:, 0], -1) - active_cases - recovered_cases

# Day-over-day changes; the first timestep has no predecessor and is set to 0.
dI = np.concatenate((np.zeros((active_cases.shape[0], 1), dtype=np.float32), np.diff(active_cases, axis=1)), axis=-1)
dR = np.concatenate((np.zeros((recovered_cases.shape[0], 1), dtype=np.float32), np.diff(recovered_cases, axis=1)), axis=-1)
dS = np.concatenate((np.zeros((susceptible_cases.shape[0], 1), dtype=np.float32), np.diff(susceptible_cases, axis=1)), axis=-1)

# Per-location (mean, std) statistics, keyed as normalizer[quantity][location],
# for the S/I/R compartments and their daily changes dS/dI/dR.
# NOTE(review): statistics are taken over the full series, including the
# validation and test windows split off later, so evaluation data leaks into
# the normalisation. Restricting them to the training span would avoid this.


def _mean_std(series):
    """Return ``(mean, std)`` of *series*, substituting 1.0 for a zero std.

    A constant series has std 0; z-scoring with it (and de-normalising
    predictions with it) would produce NaN/inf, whereas a unit std simply maps
    the series to zeros.
    """
    mean = np.mean(series)
    std = np.std(series)
    return mean, (std if std > 0 else 1.0)


normalizer = {'S': {}, 'I': {}, 'R': {}, 'dS': {}, 'dI': {}, 'dR': {}}
series_by_quantity = {
    'S': susceptible_cases,
    'I': active_cases,
    'R': recovered_cases,
    'dI': dI,
    'dR': dR,
    'dS': dS,
}
for i, each_loc in enumerate(loc_list):
    for quantity, series in series_by_quantity.items():
        normalizer[quantity][each_loc] = _mean_std(series[i])

# Sliding-window dataset construction shared by the train / validation / test splits.
def prepare_data(data, sum_I, sum_R, history_window=5, pred_window=15, slide_step=5):
    """Cut a multi-location series into (history, forecast) windows.

    Windows start at t = 0, slide_step, 2 * slide_step, ... and are kept only
    while the whole forecast horizon fits inside the series.

    Args:
        data: array of shape (n_loc, timestep, n_feat). Feature 0 supplies the
            I targets/anchors and feature 1 the R targets/anchors.
        sum_I, sum_R: arrays of shape (n_loc, timestep); their value at the
            last history step of each window is returned as last_I / last_R.
        history_window: number of input timesteps per window.
        pred_window: number of forecast timesteps per window.
        slide_step: stride between consecutive window starts.

    Returns:
        Tuple ``(x, last_I, last_R, concat_I, concat_R, y_I, y_R)`` of float32
        arrays with shapes:
        ``x``: (n_loc, n_windows, history_window * n_feat);
        ``last_I``, ``last_R``, ``concat_I``, ``concat_R``: (n_loc, n_windows);
        ``y_I``, ``y_R``: (n_loc, n_windows, pred_window).
        If the series is too short for a single window, every array is
        returned with n_windows == 0 instead of raising.
    """
    n_loc = data.shape[0]
    timestep = data.shape[1]
    n_feat = data.shape[2]

    x = []
    y_I = []
    y_R = []
    last_I = []
    last_R = []
    concat_I = []
    concat_R = []
    for i in range(0, timestep, slide_step):
        # Stop once the forecast horizon would run past the end of the series.
        if i + history_window + pred_window - 1 >= timestep or i + history_window >= timestep:
            break
        hist_end = i + history_window
        x.append(data[:, i:hist_end, :].reshape((n_loc, history_window * n_feat)))

        # Values at the last history step, used as anchors for the forecast.
        concat_I.append(data[:, hist_end - 1, 0])
        concat_R.append(data[:, hist_end - 1, 1])
        last_I.append(sum_I[:, hist_end - 1])
        last_R.append(sum_R[:, hist_end - 1])

        y_I.append(data[:, hist_end:hist_end + pred_window, 0])
        y_R.append(data[:, hist_end:hist_end + pred_window, 1])

    if not x:
        # No window fits; np.array([]) cannot be transposed to 3-D, so build
        # correctly-shaped empty arrays instead.
        return (
            np.zeros((n_loc, 0, history_window * n_feat), dtype=np.float32),
            np.zeros((n_loc, 0), dtype=np.float32),
            np.zeros((n_loc, 0), dtype=np.float32),
            np.zeros((n_loc, 0), dtype=np.float32),
            np.zeros((n_loc, 0), dtype=np.float32),
            np.zeros((n_loc, 0, pred_window), dtype=np.float32),
            np.zeros((n_loc, 0, pred_window), dtype=np.float32),
        )

    # Stack on a leading window axis, then move locations to the front.
    x = np.array(x, dtype=np.float32).transpose((1, 0, 2))
    last_I = np.array(last_I, dtype=np.float32).transpose((1, 0))
    last_R = np.array(last_R, dtype=np.float32).transpose((1, 0))
    concat_I = np.array(concat_I, dtype=np.float32).transpose((1, 0))
    concat_R = np.array(concat_R, dtype=np.float32).transpose((1, 0))
    y_I = np.array(y_I, dtype=np.float32).transpose((1, 0, 2))
    y_R = np.array(y_R, dtype=np.float32).transpose((1, 0, 2))
    return x, last_I, last_R, concat_I, concat_R, y_I, y_R

# Number of timesteps held out at the end of the series for validation and test.
valid_window = 25
test_window = 25

# Sliding-window configuration for prepare_data.
history_window = 6
pred_window = 15
slide_step = 5

normalize = True

# Stack the daily changes into (n_loc, timestep, 3) with channels [dI, dR, dS].
dynamic_feat = np.stack((dI, dR, dS), axis=-1)

# Per-location z-scoring of every channel with the statistics in `normalizer`.
if normalize:
    for row, loc in enumerate(loc_list):
        for channel, key in enumerate(('dI', 'dR', 'dS')):
            mu, sigma = normalizer[key][loc]
            dynamic_feat[row, :, channel] = (dynamic_feat[row, :, channel] - mu) / sigma

# dI / dR normalisation constants as flat arrays ordered like loc_list.
dI_mean = np.array([normalizer['dI'][loc][0] for loc in loc_list])
dI_std = np.array([normalizer['dI'][loc][1] for loc in loc_list])
dR_mean = np.array([normalizer['dR'][loc][0] for loc in loc_list])
dR_std = np.array([normalizer['dR'][loc][1] for loc in loc_list])

# Chronological split along the time axis: [ train | valid_window | test_window ].
train_end = -valid_window - test_window

train_feat = dynamic_feat[:, :train_end, :]
val_feat = dynamic_feat[:, train_end:-test_window, :]
test_feat = dynamic_feat[:, -test_window:, :]

# Each split is windowed independently with the same configuration.
window_args = (history_window, pred_window, slide_step)

train_x, train_I, train_R, train_cI, train_cR, train_yI, train_yR = prepare_data(
    train_feat,
    active_cases[:, :train_end],
    recovered_cases[:, :train_end],
    *window_args
)
val_x, val_I, val_R, val_cI, val_cR, val_yI, val_yR = prepare_data(
    val_feat,
    active_cases[:, train_end:-test_window],
    recovered_cases[:, train_end:-test_window],
    *window_args
)
test_x, test_I, test_R, test_cI, test_cR, test_yI, test_yR = prepare_data(
    test_feat,
    active_cases[:, -test_window:],
    recovered_cases[:, -test_window:],
    *window_args
)

# STAN model over the location graph; the input is 3 dynamic features per
# history step and the output horizon is pred_window steps.
stan_hparams = dict(
    in_dim=3 * history_window,
    hidden_dim1=32,
    hidden_dim2=32,
    gru_dim=32,
    num_heads=1,
    pred_window=pred_window,
    device=device,
)
model = STAN(custom_g, **stan_hparams).to(device)

# MSE is the base criterion for every loss term; Adam with lr = 1e-2.
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-2)

def _as_device_tensor(array):
    """Return a new float32 tensor copy of *array* placed on `device`."""
    return torch.tensor(array, dtype=torch.float32).to(device)


# Move every model input / target onto `device` as float32 tensors.
train_x, train_I, train_R, train_cI, train_cR, train_yI, train_yR = [
    _as_device_tensor(a)
    for a in (train_x, train_I, train_R, train_cI, train_cR, train_yI, train_yR)
]
val_x, val_I, val_R, val_cI, val_cR, val_yI, val_yR = [
    _as_device_tensor(a)
    for a in (val_x, val_I, val_R, val_cI, val_cR, val_yI, val_yR)
]
test_x, test_I, test_R, test_cI, test_cR, test_yI, test_yR = [
    _as_device_tensor(a)
    for a in (test_x, test_I, test_R, test_cI, test_cR, test_yI, test_yR)
]

# Normalisation constants reshaped to (n_loc, 1, 1); indexing one location
# yields a (1, 1) tensor that broadcasts against predictions.
dI_mean, dI_std, dR_mean, dR_std = [
    _as_device_tensor(a).reshape((a.shape[0], 1, 1))
    for a in (dI_mean, dI_std, dR_mean, dR_std)
]

# Population (static feature column 0) as an (n_loc, 1) column.
N = _as_device_tensor(static_feat[:, 0]).unsqueeze(-1)

# Training bookkeeping.
all_loss = []                   # training loss, one entry per epoch
file_name = './save/stan.pth'   # checkpoint with the best validation loss so far
# Start at +inf so the first finite validation loss is always checkpointed. A
# finite sentinel (previously 1e10) can be exceeded by un-normalised losses,
# leaving no checkpoint (or a stale one from an earlier run) for torch.load later.
min_loss = float('inf')

# The model is trained and evaluated for this single location.
loc_name = 'California'
cur_loc = loc_list.index(loc_name)  # raises ValueError if the state is absent from loc_list

epoch_count = 50 if normalize else 300
scale = 0.1  # weight of the phy_* loss terms relative to the data-fit terms

# Create the checkpoint directory if it doesn't exist.
os.makedirs(os.path.dirname(file_name), exist_ok=True)

# Training loop. Each epoch does one full-batch update on the training windows
# of `cur_loc`, then evaluates on the validation windows (seeded with the hidden
# state from a pass over the training windows). The model is checkpointed
# whenever the validation loss improves.

# The dI/dR series passed to the model are the same every epoch; build them once.
cur_dI = torch.tensor(dI[cur_loc], dtype=torch.float32).to(device)
cur_dR = torch.tensor(dR[cur_loc], dtype=torch.float32).to(device)


def _stan_loss(active_pred, recovered_pred, phy_active, phy_recover, y_I, y_R):
    """Return the combined loss for one location.

    MSE of the data-driven predictions against ``y_I`` / ``y_R``, plus
    ``scale`` times the MSE of the ``phy_*`` predictions. When ``normalize``
    is on, the ``phy_*`` outputs are first z-scored with the location's dI/dR
    statistics, matching the normalised targets.
    """
    if normalize:
        phy_active = (phy_active - dI_mean[cur_loc]) / dI_std[cur_loc]
        phy_recover = (phy_recover - dR_mean[cur_loc]) / dR_std[cur_loc]
    y_I = y_I.squeeze()
    y_R = y_R.squeeze()
    return (
        criterion(active_pred.squeeze(), y_I) +
        criterion(recovered_pred.squeeze(), y_R) +
        scale * criterion(phy_active.squeeze(), y_I) +
        scale * criterion(phy_recover.squeeze(), y_R)
    )


for epoch in tqdm(range(epoch_count), desc='Training Epochs'):
    model.train()
    optimizer.zero_grad()

    # Forward pass. Failures are reported and re-raised: silently breaking out
    # of the loop would let the later cells load a missing or stale checkpoint.
    try:
        active_pred, recovered_pred, phy_active, phy_recover, _ = model(
            train_x,
            train_cI[cur_loc],
            train_cR[cur_loc],
            N[cur_loc],
            train_I[cur_loc],
            train_R[cur_loc],
            cur_dI,
            cur_dR,
            h=None
        )
    except Exception as e:
        print(f"Error during forward pass: {e}")
        raise

    loss = _stan_loss(
        active_pred, recovered_pred, phy_active, phy_recover,
        train_yI[cur_loc], train_yR[cur_loc],
    )

    # Backward pass and optimization
    loss.backward()
    optimizer.step()
    all_loss.append(loss.item())

    # Validation
    model.eval()
    with torch.no_grad():
        try:
            # Pass over the training windows only to obtain the hidden state.
            _, _, _, _, prev_h = model(
                train_x,
                train_cI[cur_loc],
                train_cR[cur_loc],
                N[cur_loc],
                train_I[cur_loc],
                train_R[cur_loc],
                cur_dI,
                cur_dR,
                h=None
            )
            # Continue from that hidden state on the validation windows.
            val_active_pred, val_recovered_pred, val_phy_active, val_phy_recover, _ = model(
                val_x,
                val_cI[cur_loc],
                val_cR[cur_loc],
                N[cur_loc],
                val_I[cur_loc],
                val_R[cur_loc],
                cur_dI,
                cur_dR,
                h=prev_h
            )
        except Exception as e:
            print(f"Error during validation pass: {e}")
            raise

        val_loss = _stan_loss(
            val_active_pred, val_recovered_pred, val_phy_active, val_phy_recover,
            val_yI[cur_loc], val_yR[cur_loc],
        )

    # Checkpoint on validation improvement.
    if val_loss.item() < min_loss:
        state = {
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }
        torch.save(state, file_name)
        min_loss = val_loss.item()

    print(f"Epoch {epoch+1}/{epoch_count}, Loss: {loss.item():.4f}, Val Loss: {val_loss.item():.4f}")

# Training-loss curve, one point per epoch.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(all_loss, label='Training Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training Loss over Epochs')
ax.legend()
plt.show()

# Restore the checkpoint with the lowest validation loss.
# map_location places the tensors on the current device even if the file was
# written on a different one (e.g. saved on GPU, loaded on a CPU-only machine).
checkpoint = torch.load(file_name, map_location=device)
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
model.eval()

# Test-set inference: run over the train+validation windows to obtain the
# hidden state h, then forecast the test windows starting from h.
prev_x = torch.cat((train_x, val_x), dim=1)
prev_I = torch.cat((train_I, val_I), dim=1)
prev_R = torch.cat((train_R, val_R), dim=1)
prev_cI = torch.cat((train_cI, val_cI), dim=1)
prev_cR = torch.cat((train_cR, val_cR), dim=1)

# The same per-location dI/dR series that the training and validation calls
# pass to the model. The previous torch.cat((dI, dI)) duplicated the series
# end-to-end, which no other call site does.
loc_dI = torch.tensor(dI[cur_loc], dtype=torch.float32).to(device)
loc_dR = torch.tensor(dR[cur_loc], dtype=torch.float32).to(device)

# Inference only: no autograd graph is needed (validation already runs under no_grad).
with torch.no_grad():
    prev_active_pred, _, prev_phyactive_pred, _, h = model(
        prev_x,
        prev_cI[cur_loc],  # was train_cI: must cover the same train+val windows as prev_x
        prev_cR[cur_loc],  # was train_cR
        N[cur_loc],
        prev_I[cur_loc],
        prev_R[cur_loc],
        loc_dI,
        loc_dR,
        h=None
    )

    # Forecast the test windows from the train+val hidden state.
    test_pred_active, test_pred_recovered, test_pred_phy_active, test_pred_phy_recover, _ = model(
        test_x,
        test_cI[cur_loc],
        test_cR[cur_loc],
        N[cur_loc],
        test_I[cur_loc],
        test_R[cur_loc],
        loc_dI,
        loc_dR,
        h
    )

if normalize:
    print(f'Estimated alpha in SIR model: {model.alpha_scaled.item():.4f}')
    print(f'Estimated beta in SIR model: {model.beta_scaled.item():.4f}')

# Convert each test window's predicted dI into an active-case trajectory:
# de-normalise (when enabled), cumulate, then offset by the last observed
# active count before the window.
if normalize:
    dI_scale = dI_std[cur_loc].reshape(1, 1).detach().cpu().numpy()
    dI_shift = dI_mean[cur_loc].reshape(1, 1).detach().cpu().numpy()


def _window_active_trajectory(w):
    """Return the cumulative active-case forecast for test window *w*."""
    step = test_pred_active[0, w, :].detach().cpu().numpy()
    if normalize:
        step = (step * dI_scale) + dI_shift
    return np.cumsum(step) + test_I[cur_loc, w].detach().cpu().item()


pred_I = np.array([_window_active_trajectory(w) for w in range(test_pred_active.size(1))])

# Ground-truth target windows for a 2-D series.
def get_real_y(data, history_window=5, pred_window=15, slide_step=5):
    """Slice forecast-horizon windows from *data*, windowed like prepare_data.

    Args:
        data: array of shape (n_loc, timestep).
        history_window: timesteps skipped at the start of each window.
        pred_window: timesteps returned per window.
        slide_step: stride between consecutive window starts.

    Returns:
        float32 array of shape (n_loc, n_windows, pred_window); n_windows is 0
        (instead of raising) when the series is too short for one window.
    """
    n_loc = data.shape[0]
    timestep = data.shape[1]

    y = []
    for i in range(0, timestep, slide_step):
        # Stop once the forecast horizon would run past the end of the series.
        if i + history_window + pred_window - 1 >= timestep or i + history_window >= timestep:
            break
        y.append(data[:, i + history_window:i + history_window + pred_window])

    if not y:
        # np.array([]) cannot be transposed to 3-D; return an empty, shaped array.
        return np.zeros((n_loc, 0, pred_window), dtype=np.float32)
    return np.array(y, dtype=np.float32).transpose((1, 0, 2))

# Ground-truth active cases for the test windows. Windowing only the test slice
# reproduces exactly the windows built for test_x, so I_true[:, k] covers the
# same days as pred_I[k]. Windowing the full series lines up with the test
# windows only for particular series lengths, otherwise the plot compares
# shifted days.
I_true = get_real_y(active_cases[:, -test_window:], history_window, pred_window, slide_step)

# Plot predictions vs true values for Active Cases (last test window, current location).
plt.figure(figsize=(12, 6))
plt.plot(I_true[cur_loc, -1, :], c='r', label='Ground truth')
plt.plot(pred_I[-1, :], c='b', label='Prediction')
plt.xlabel('Time Step')
plt.ylabel('Active Cases')
plt.title('Active Cases: True vs Predicted')
plt.legend()
plt.show()

# (cell output) Using device: cuda
