In [8]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, SAGEConv
from torch_geometric.utils import from_scipy_sparse_matrix
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import StandardScaler
from sklearn.metrics.pairwise import cosine_distances
from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix
from scipy.sparse import csr_matrix

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [9]:
graph_method = 'cosine'  # 'knn' or 'cosine'
k_neighbors = 5          # Only used for KNN graph
cosine_threshold = 0.7   # Only used for cosine graph
hidden_dim = 64
num_epochs = 500
early_stopping_patience = 30
train_val_test_split = (0.6, 0.2, 0.2)
save_dir = 'saved_models_cosine_0.7'
os.makedirs(save_dir, exist_ok=True)

In [10]:
# ====================
# LOAD DATA
# ====================
df = pd.read_csv("../../dataset/dataset.csv")  

# 2. Define your feature columns and label
features = ['CHF_F30',
            'HICHOLRP',
            'INCONT',
            'BKBONMOM',
            'PREG',
            'AGE',
            'ETHNICNIH',
            'F45CALC',
            'F60ALCWK',
            'F60CALC',]

outcome = 'ANYFX'

X = df[features].values
y = df[outcome].values

In [11]:
# Balance classes
minority_class = df[outcome].value_counts().idxmin()
majority_class = df[outcome].value_counts().idxmax()

minority_df = df[df[outcome] == minority_class]
majority_df = df[df[outcome] == majority_class].sample(n=len(minority_df), random_state=42)

df_balanced = pd.concat([minority_df, majority_df]).sample(frac=1, random_state=42).reset_index(drop=True)

X = df_balanced[features].values
y = df_balanced[outcome].values

# Stratified split
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.20, random_state=42)
for train_val_idx, test_idx in sss.split(X, y):
    X_temp, X_test = X[train_val_idx], X[test_idx]
    y_temp, y_test = y[train_val_idx], y[test_idx]

sss_val = StratifiedShuffleSplit(n_splits=1, test_size=0.25, random_state=42)
for train_idx, val_idx in sss_val.split(X_temp, y_temp):
    X_train, X_val = X_temp[train_idx], X_temp[val_idx]
    y_train, y_val = y_temp[train_idx], y_temp[val_idx]

In [12]:
# ====================
# SCALING
# ====================
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_val = scaler.transform(X_val)
X_test = scaler.transform(X_test)

# ====================
# Graph Construction
# ====================
def build_graph(X_split, method='knn', k=5, threshold=0.5):
    if method == 'knn':
        knn = NearestNeighbors(n_neighbors=k)
        knn.fit(X_split)
        adj_matrix = knn.kneighbors_graph(X_split, mode='connectivity')
        edge_index, _ = from_scipy_sparse_matrix(csr_matrix(adj_matrix))
        return edge_index

    elif method == 'cosine':
        distances = cosine_distances(X_split)
        adj_matrix = (distances <= threshold).astype(np.float32)
        np.fill_diagonal(adj_matrix, 0)
        edge_index, _ = from_scipy_sparse_matrix(csr_matrix(adj_matrix))
        return edge_index

    else:
        raise ValueError("Unknown graph construction method. Choose 'knn' or 'cosine'.")


# ====================
# Model Definitions
# ====================
class GCN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.lin = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.3, training=self.training)
        x = F.relu(self.conv2(x, edge_index))
        x = F.dropout(x, p=0.3, training=self.training)
        x = self.lin(x)
        return x

class GraphSAGE(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(GraphSAGE, self).__init__()
        self.conv1 = SAGEConv(input_dim, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, hidden_dim)
        self.lin = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.3, training=self.training)
        x = F.relu(self.conv2(x, edge_index))
        x = F.dropout(x, p=0.3, training=self.training)
        x = self.lin(x)
        return x

# ====================
# Early Stopping
# ====================
class EarlyStopping:
    def __init__(self, patience=20, verbose=False, delta=0):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.delta = delta

    def __call__(self, val_auc, model, path):
        score = val_auc

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(model, path)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(model, path)
            self.counter = 0

    def save_checkpoint(self, model, path):
        torch.save(model.state_dict(), path)

# ====================
# Training & Evaluation
# ====================
def train(model, optimizer, data, criterion):
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = criterion(out, data.y)
    loss.backward()
    optimizer.step()
    return loss.item()

@torch.no_grad()
def evaluate(model, X_split, y_split, method='knn', k=5, threshold=0.5):
    model.eval()
    edge_index = build_graph(X_split, method=method, k=k, threshold=threshold)
    data_split = Data(x=torch.tensor(X_split, dtype=torch.float).to(device),
                      edge_index=edge_index.to(device))

    logits = model(data_split.x, data_split.edge_index)
    preds = logits.argmax(dim=1).cpu().numpy()

    acc = accuracy_score(y_split, preds)
    auc = roc_auc_score(y_split, preds)
    tn, fp, fn, tp = confusion_matrix(y_split, preds).ravel()
    sensitivity = tp / (tp + fn)
    specificity = tn / (tn + fp)
    return acc, sensitivity, specificity, auc

# ====================
# Main Training Loop
# ====================
def run_training(model_type='gcn'):
    if model_type == 'gcn':
        model = GCN(input_dim=X.shape[1], hidden_dim=hidden_dim, output_dim=2).to(device)
    elif model_type == 'sage':
        model = GraphSAGE(input_dim=X.shape[1], hidden_dim=hidden_dim, output_dim=2).to(device)
    else:
        raise ValueError("Choose either 'gcn' or 'sage'")

    optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=10, verbose=True)
    criterion = nn.CrossEntropyLoss()

    early_stopping = EarlyStopping(patience=early_stopping_patience, verbose=True)
    save_path = f'{save_dir}/best_{model_type}_model.pt'

    edge_index_train = build_graph(X_train, method=graph_method, k=k_neighbors, threshold=cosine_threshold)
    data_train = Data(x=torch.tensor(X_train, dtype=torch.float).to(device),
                      edge_index=edge_index_train.to(device),
                      y=torch.tensor(y_train, dtype=torch.long).to(device))

    for epoch in range(1, num_epochs + 1):
        loss = train(model, optimizer, data_train, criterion)
        if epoch % 5 == 0:
            train_acc, _, _, _ = evaluate(model, X_train, y_train, method=graph_method, k=k_neighbors, threshold=cosine_threshold)
            val_acc, val_sens, val_spec, val_auc = evaluate(model, X_val, y_val, method=graph_method, k=k_neighbors, threshold=cosine_threshold)
            print(f'{model_type.upper()} Epoch {epoch:03d} | Loss: {loss:.4f} | Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f} | Val AUC: {val_auc:.4f}')

            scheduler.step(val_auc)
            early_stopping(val_auc, model, save_path)

            if early_stopping.early_stop:
                print("Early stopping triggered.")
                break

    model.load_state_dict(torch.load(save_path))
    test_acc, test_sens, test_spec, test_auc = evaluate(model, X_test, y_test, method=graph_method, k=k_neighbors, threshold=cosine_threshold)
    print(f'\n{model_type.upper()} Test Metrics:')
    print(f'Accuracy: {test_acc:.4f}')
    print(f'Sensitivity: {test_sens:.4f}')
    print(f'Specificity: {test_spec:.4f}')
    print(f'AUC: {test_auc:.4f}')

In [13]:
run_training(model_type='gcn')

GCN Epoch 005 | Loss: 0.6914 | Train Acc: 0.5275 | Val Acc: 0.5567 | Val AUC: 0.5566
GCN Epoch 010 | Loss: 0.6901 | Train Acc: 0.5373 | Val Acc: 0.5757 | Val AUC: 0.5757
GCN Epoch 015 | Loss: 0.6880 | Train Acc: 0.5467 | Val Acc: 0.5682 | Val AUC: 0.5682
EarlyStopping counter: 1 out of 30
GCN Epoch 020 | Loss: 0.6873 | Train Acc: 0.5486 | Val Acc: 0.5550 | Val AUC: 0.5550
EarlyStopping counter: 2 out of 30
GCN Epoch 025 | Loss: 0.6862 | Train Acc: 0.5490 | Val Acc: 0.5625 | Val AUC: 0.5625
EarlyStopping counter: 3 out of 30
GCN Epoch 030 | Loss: 0.6872 | Train Acc: 0.5497 | Val Acc: 0.5682 | Val AUC: 0.5682
EarlyStopping counter: 4 out of 30
GCN Epoch 035 | Loss: 0.6868 | Train Acc: 0.5478 | Val Acc: 0.5659 | Val AUC: 0.5659
EarlyStopping counter: 5 out of 30
GCN Epoch 040 | Loss: 0.6869 | Train Acc: 0.5484 | Val Acc: 0.5728 | Val AUC: 0.5728
EarlyStopping counter: 6 out of 30
GCN Epoch 045 | Loss: 0.6855 | Train Acc: 0.5457 | Val Acc: 0.5723 | Val AUC: 0.5723
EarlyStopping counter: 7 

In [14]:
run_training(model_type='sage')

SAGE Epoch 005 | Loss: 0.6928 | Train Acc: 0.5463 | Val Acc: 0.5538 | Val AUC: 0.5539
SAGE Epoch 010 | Loss: 0.6908 | Train Acc: 0.5311 | Val Acc: 0.5423 | Val AUC: 0.5421
EarlyStopping counter: 1 out of 30
SAGE Epoch 015 | Loss: 0.6873 | Train Acc: 0.5541 | Val Acc: 0.5642 | Val AUC: 0.5642
SAGE Epoch 020 | Loss: 0.6859 | Train Acc: 0.5584 | Val Acc: 0.5636 | Val AUC: 0.5636
EarlyStopping counter: 1 out of 30
SAGE Epoch 025 | Loss: 0.6841 | Train Acc: 0.5630 | Val Acc: 0.5676 | Val AUC: 0.5676
SAGE Epoch 030 | Loss: 0.6840 | Train Acc: 0.5549 | Val Acc: 0.5625 | Val AUC: 0.5625
EarlyStopping counter: 1 out of 30
SAGE Epoch 035 | Loss: 0.6837 | Train Acc: 0.5684 | Val Acc: 0.5665 | Val AUC: 0.5665
EarlyStopping counter: 2 out of 30
SAGE Epoch 040 | Loss: 0.6842 | Train Acc: 0.5672 | Val Acc: 0.5682 | Val AUC: 0.5682
SAGE Epoch 045 | Loss: 0.6803 | Train Acc: 0.5684 | Val Acc: 0.5671 | Val AUC: 0.5671
EarlyStopping counter: 1 out of 30
SAGE Epoch 050 | Loss: 0.6804 | Train Acc: 0.5741 |