In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, SAGEConv
from torch_geometric.data import Data
from torch_geometric.utils import from_scipy_sparse_matrix

import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics import accuracy_score, confusion_matrix, roc_auc_score
from scipy.sparse import csr_matrix
import joblib

import random

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
# ====================
# Seeds and Hyperparameters
# ====================
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

set_seed()

# Hyperparameters
k = 50

In [3]:
# ====================
# LOAD DATA
# ====================
df = pd.read_csv("../../dataset/dataset.csv")  

# 2. Define your feature columns and label
features = ['CHF_F30',
            'HICHOLRP',
            'INCONT',
            'BKBONMOM',
            'PREG',
            'AGE',
            'ETHNICNIH',
            'F45CALC',
            'F60ALCWK',
            'F60CALC',]

outcome = 'ANYFX'

X = df[features].values
y = df[outcome].values

In [4]:
class_0_idx = np.where(y == 0)[0]
class_1_idx = np.where(y == 1)[0]

undersample_size = min(len(class_0_idx), len(class_1_idx))

undersample_idx_0 = np.random.choice(class_0_idx, undersample_size, replace=False)
undersample_idx_1 = np.random.choice(class_1_idx, undersample_size, replace=False)

balanced_idx = np.concatenate([undersample_idx_0, undersample_idx_1])

X_balanced = X[balanced_idx]
y_balanced = y[balanced_idx]

# ====================
# STRATIFIED SPLIT
# ====================
splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.4, random_state=42)
train_idx, temp_idx = next(splitter.split(X_balanced, y_balanced))

X_train, y_train = X_balanced[train_idx], y_balanced[train_idx]
X_temp, y_temp = X_balanced[temp_idx], y_balanced[temp_idx]

splitter_val = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=42)
val_idx, test_idx = next(splitter_val.split(X_temp, y_temp))

X_val, y_val = X_temp[val_idx], y_temp[val_idx]
X_test, y_test = X_temp[test_idx], y_temp[test_idx]

# ====================
# SCALING
# ====================
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_val = scaler.transform(X_val)
X_test = scaler.transform(X_test)

# ====================
# KNN
# ====================
def build_knn_graph(X_split, k):
    knn_local = NearestNeighbors(n_neighbors=k)
    knn_local.fit(X_split)
    adj_matrix = knn_local.kneighbors_graph(X_split, mode='connectivity')
    edge_index, _ = from_scipy_sparse_matrix(csr_matrix(adj_matrix))
    return edge_index

knn = NearestNeighbors(n_neighbors=k)
knn.fit(X)

# ====================
# DEFINE MODELS
# ====================
class GCN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        return x

class GraphSAGE(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(GraphSAGE, self).__init__()
        self.sage1 = SAGEConv(input_dim, hidden_dim)
        self.sage2 = SAGEConv(hidden_dim, output_dim)

    def forward(self, x, edge_index):
        x = F.relu(self.sage1(x, edge_index))
        x = self.sage2(x, edge_index)
        return x

# ====================
# EVALUATION FUNCTION
# ====================
@torch.no_grad()
def evaluate(model, X_split, y_split, k=k):
    model.eval()
    edge_index = build_knn_graph(X_split, k)
    data_split = Data(x=torch.tensor(X_split, dtype=torch.float).to(device),
                      edge_index=edge_index.to(device))
    logits = model(data_split.x, data_split.edge_index)
    preds = logits.argmax(dim=1).cpu().numpy()

    acc = accuracy_score(y_split, preds)
    auc = roc_auc_score(y_split, preds)
    tn, fp, fn, tp = confusion_matrix(y_split, preds).ravel()
    sensitivity = tp / (tp + fn)
    specificity = tn / (tn + fp)
    return acc, sensitivity, specificity, auc

class EarlyStopping:
    def __init__(self, patience=20, verbose=False, delta=0):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.delta = delta

    def __call__(self, val_auc, model, path):
        score = val_auc

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(model, path)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(model, path)
            self.counter = 0

    def save_checkpoint(self, model, path):
        torch.save(model.state_dict(), path)
        
# ====================
# TRAINING FUNCTION
# ====================
def run_training(model_type='gcn'):
    if model_type == 'gcn':
        model = GCN(input_dim=X.shape[1], hidden_dim=64, output_dim=2).to(device)
    elif model_type == 'sage':
        model = GraphSAGE(input_dim=X.shape[1], hidden_dim=64, output_dim=2).to(device)
    else:
        raise ValueError("Choose either 'gcn' or 'sage'")
    
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=10, verbose=True)
    criterion = nn.CrossEntropyLoss()

    early_stopping = EarlyStopping(patience=30, verbose=True)
    save_path = f'saved_models/best_{model_type}_model.pt'

    data_train = Data(x=torch.tensor(X_train, dtype=torch.float).to(device),
                      edge_index=build_knn_graph(X_train, k).to(device),
                      y=torch.tensor(y_train, dtype=torch.long).to(device))
    
    for epoch in range(1, 501):
        model.train()
        optimizer.zero_grad()
        out = model(data_train.x, data_train.edge_index)
        loss = criterion(out, data_train.y)
        loss.backward()
        optimizer.step()

        if epoch % 5 == 0:
            train_acc, _, _, _ = evaluate(model, X_train, y_train, k)
            val_acc, val_sens, val_spec, val_auc = evaluate(model, X_val, y_val, k)
            print(f'{model_type.upper()} Epoch {epoch:03d} | Loss: {loss:.4f} | Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f} | Val AUC: {val_auc:.4f}')

            scheduler.step(val_auc)
            early_stopping(val_auc, model, save_path)

            if early_stopping.early_stop:
                print("Early stopping triggered.")
                break

    model.load_state_dict(torch.load(save_path))
    test_acc, test_sens, test_spec, test_auc = evaluate(model, X_test, y_test, k)
    print(f'\n{model_type.upper()} Test Metrics:')
    print(f'Accuracy: {test_acc:.4f}')
    print(f'Sensitivity: {test_sens:.4f}')
    print(f'Specificity: {test_spec:.4f}')
    print(f'AUC: {test_auc:.4f}')


In [5]:
run_training('gcn')

GCN Epoch 005 | Loss: 0.6937 | Train Acc: 0.5377 | Val Acc: 0.5334 | Val AUC: 0.5334
GCN Epoch 010 | Loss: 0.6941 | Train Acc: 0.5510 | Val Acc: 0.5438 | Val AUC: 0.5438
GCN Epoch 015 | Loss: 0.6849 | Train Acc: 0.5550 | Val Acc: 0.5236 | Val AUC: 0.5236
EarlyStopping counter: 1 out of 30
GCN Epoch 020 | Loss: 0.6852 | Train Acc: 0.5521 | Val Acc: 0.5207 | Val AUC: 0.5207
EarlyStopping counter: 2 out of 30
GCN Epoch 025 | Loss: 0.6825 | Train Acc: 0.5659 | Val Acc: 0.5282 | Val AUC: 0.5282
EarlyStopping counter: 3 out of 30
GCN Epoch 030 | Loss: 0.6817 | Train Acc: 0.5700 | Val Acc: 0.5363 | Val AUC: 0.5363
EarlyStopping counter: 4 out of 30
GCN Epoch 035 | Loss: 0.6810 | Train Acc: 0.5675 | Val Acc: 0.5282 | Val AUC: 0.5282
EarlyStopping counter: 5 out of 30
GCN Epoch 040 | Loss: 0.6804 | Train Acc: 0.5675 | Val Acc: 0.5334 | Val AUC: 0.5334
EarlyStopping counter: 6 out of 30
GCN Epoch 045 | Loss: 0.6800 | Train Acc: 0.5711 | Val Acc: 0.5340 | Val AUC: 0.5340
EarlyStopping counter: 7 

In [6]:
run_training('sage')

SAGE Epoch 005 | Loss: 0.6951 | Train Acc: 0.5504 | Val Acc: 0.5207 | Val AUC: 0.5207
SAGE Epoch 010 | Loss: 0.6839 | Train Acc: 0.5586 | Val Acc: 0.5392 | Val AUC: 0.5392
SAGE Epoch 015 | Loss: 0.6826 | Train Acc: 0.5608 | Val Acc: 0.5300 | Val AUC: 0.5300
EarlyStopping counter: 1 out of 30
SAGE Epoch 020 | Loss: 0.6784 | Train Acc: 0.5700 | Val Acc: 0.5288 | Val AUC: 0.5288
EarlyStopping counter: 2 out of 30
SAGE Epoch 025 | Loss: 0.6774 | Train Acc: 0.5677 | Val Acc: 0.5426 | Val AUC: 0.5426
SAGE Epoch 030 | Loss: 0.6760 | Train Acc: 0.5692 | Val Acc: 0.5282 | Val AUC: 0.5282
EarlyStopping counter: 1 out of 30
SAGE Epoch 035 | Loss: 0.6750 | Train Acc: 0.5706 | Val Acc: 0.5305 | Val AUC: 0.5305
EarlyStopping counter: 2 out of 30
SAGE Epoch 040 | Loss: 0.6740 | Train Acc: 0.5767 | Val Acc: 0.5271 | Val AUC: 0.5271
EarlyStopping counter: 3 out of 30
SAGE Epoch 045 | Loss: 0.6730 | Train Acc: 0.5757 | Val Acc: 0.5253 | Val AUC: 0.5253
EarlyStopping counter: 4 out of 30
SAGE Epoch 050 |