In [1]:
import os
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
import pandas as pd
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv, SAGEConv
from torch_geometric.utils import from_scipy_sparse_matrix
from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.neighbors import NearestNeighbors
from scipy.sparse import csr_matrix

In [2]:
# 1. Load your data
df = pd.read_csv("../../dataset/dataset.csv")  

# 2. Define your feature columns and label
small_columns = ['CHF_F30',
                 'HICHOLRP',
                 'INCONT',
                 'BKBONMOM',
                 'PREG',
                 'AGE',
                 'ETHNICNIH',
                 'F45CALC',
                 'F60ALCWK',
                 'F60CALC',]

outcome_column = 'ANYFX'

In [3]:
# Separate the classes
class_0 = df[df[outcome_column] == 0]
class_1 = df[df[outcome_column] == 1]

# Undersample the majority class
if len(class_0) > len(class_1):
    class_0 = class_0.sample(n=len(class_1), random_state=42)
else:
    class_1 = class_1.sample(n=len(class_0), random_state=42)

# Combine back
df_balanced = pd.concat([class_0, class_1]).sample(frac=1, random_state=42).reset_index(drop=True)

X = df_balanced[small_columns].values
y = df_balanced[outcome_column].values

# Create stratified splits
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.20, random_state=42)
for train_val_idx, test_idx in sss.split(X, y):
    X_temp, X_test = X[train_val_idx], X[test_idx]
    y_temp, y_test = y[train_val_idx], y[test_idx]

sss_val = StratifiedShuffleSplit(n_splits=1, test_size=0.25, random_state=42)  # 0.25 of temp = 0.2 overall
for train_idx, val_idx in sss_val.split(X_temp, y_temp):
    X_train, X_val = X_temp[train_idx], X_temp[val_idx]
    y_train, y_val = y_temp[train_idx], y_temp[val_idx]


In [4]:
# 4. Build KNN graph
def build_knn_graph(X, k=5):
    nbrs = NearestNeighbors(n_neighbors=k, algorithm='ball_tree').fit(X)
    edge_index = []
    for idx, neighbors in enumerate(nbrs.kneighbors(X, return_distance=False)):
        for neighbor in neighbors:
            if idx != neighbor:
                edge_index.append([idx, neighbor])
    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    return edge_index

X_all = np.vstack([X_train, X_val, X_test])
y_all = np.hstack([y_train, y_val, y_test])

edge_index = build_knn_graph(X_all, k=5)

# 5. Save edge matrix (adjacency matrix)
def save_edge_matrix(edge_index, num_nodes, filename='edge_matrix.npy'):
    adj = np.zeros((num_nodes, num_nodes))
    edges = edge_index.cpu().numpy()
    adj[edges[0], edges[1]] = 1
    np.save(filename, adj)

save_edge_matrix(edge_index, num_nodes=X_all.shape[0])

In [5]:
# ====================
# 2. KNN Graph Construction
# ====================

k = 5
knn = NearestNeighbors(n_neighbors=k)
knn.fit(X_train)
knn_graph = knn.kneighbors_graph(X_train, mode='connectivity')

edge_index, _ = from_scipy_sparse_matrix(csr_matrix(knn_graph))

# Save edge index for future use
os.makedirs('saved_models', exist_ok=True)
torch.save(edge_index, 'saved_models/edge_index.pt')

# Create PyG Data object
data = Data(x=torch.tensor(X_train, dtype=torch.float),
            edge_index=edge_index,
            y=torch.tensor(y_train, dtype=torch.long))


In [6]:
class GNN(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(GNN, self).__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        return x

In [7]:
class GraphSAGE(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, num_classes):
        super(GraphSAGE, self).__init__()
        self.conv1 = SAGEConv(input_dim, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, hidden_dim)
        self.lin = torch.nn.Linear(hidden_dim, num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.3, training=self.training)
        x = F.relu(self.conv2(x, edge_index))
        x = F.dropout(x, p=0.3, training=self.training)
        x = self.lin(x)
        return x

In [8]:
class GraphSAGE2(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, num_classes):
        super(GraphSAGE, self).__init__()
        self.conv1 = SAGEConv(input_dim, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, hidden_dim)
        self.lin = torch.nn.Linear(hidden_dim, num_classes)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.3, training=self.training)
        x = F.relu(self.conv2(x, edge_index))
        x = F.dropout(x, p=0.3, training=self.training)
        x = self.lin(x)
        return x

In [9]:
# 6. Create PyG Data object
x_tensor = torch.tensor(X_all, dtype=torch.float)
y_tensor = torch.tensor(y_all, dtype=torch.long)

data = Data(x=x_tensor, edge_index=edge_index, y=y_tensor)

# 7. Create train/val/test masks
num_nodes = data.num_nodes
train_mask = torch.zeros(num_nodes, dtype=torch.bool)
val_mask = torch.zeros(num_nodes, dtype=torch.bool)
test_mask = torch.zeros(num_nodes, dtype=torch.bool)

train_mask[:len(X_train)] = True
val_mask[len(X_train):len(X_train) + len(X_val)] = True
test_mask[-len(X_test):] = True

data.train_mask = train_mask
data.val_mask = val_mask
data.test_mask = test_mask

In [10]:
def train():
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()

In [11]:
def evaluate(split):
    model.eval()
    out = model(data.x, data.edge_index)
    pred_probs = F.softmax(out, dim=1)[:, 1]  # Probability for class 1
    pred = out.argmax(dim=1)

    if split == 'val':
        mask = data.val_mask
    elif split == 'test':
        mask = data.test_mask
    else:
        mask = data.train_mask

    y_true = data.y[mask].cpu().numpy()
    y_pred = pred[mask].cpu().numpy()
    y_score = pred_probs[mask].detach().cpu().numpy()

    acc = (y_true == y_pred).mean()

    # Sensitivity, specificity
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
    sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0

    # AUC
    auc = roc_auc_score(y_true, y_score)

    return acc, sensitivity, specificity, auc

In [12]:
class EarlyStopping:
    def __init__(self, patience=20, verbose=False, delta=0):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.delta = delta

    def __call__(self, val_auc, model, path):
        score = val_auc

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(model, path)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(model, path)
            self.counter = 0

    def save_checkpoint(self, model, path):
        torch.save(model.state_dict(), path)

In [13]:
# Base GNN model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GNN(input_dim=x_tensor.shape[1], hidden_dim=64, output_dim=len(np.unique(y_all))).to(device)
data = data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=10, verbose=True)


save_dir = 'saved_models'
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, 'best_gnn_model.pt')

early_stopping = EarlyStopping(patience=30, verbose=True)

for epoch in range(1, 501):
    loss = train()
    if epoch % 5 == 0:
        train_acc, _, _, _ = evaluate('train')
        val_acc, val_sens, val_spec, val_auc = evaluate('val')
        print(f'Epoch {epoch:03d} | Loss: {loss:.4f} | Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f} | Val AUC: {val_auc:.4f}')
        
        # Step the scheduler
        scheduler.step(val_auc)

        # Check early stopping
        early_stopping(val_auc, model, save_path)

        if early_stopping.early_stop:
            print("Early stopping triggered.")
            break

Epoch 005 | Loss: 33.4024 | Train Acc: 0.5000 | Val Acc: 0.4997 | Val AUC: 0.5299
Epoch 010 | Loss: 13.8016 | Train Acc: 0.4846 | Val Acc: 0.4940 | Val AUC: 0.4996
EarlyStopping counter: 1 out of 30
Epoch 015 | Loss: 10.5827 | Train Acc: 0.5012 | Val Acc: 0.5049 | Val AUC: 0.4967
EarlyStopping counter: 2 out of 30
Epoch 020 | Loss: 13.0077 | Train Acc: 0.4971 | Val Acc: 0.5066 | Val AUC: 0.5210
EarlyStopping counter: 3 out of 30
Epoch 025 | Loss: 7.9737 | Train Acc: 0.5000 | Val Acc: 0.4986 | Val AUC: 0.5322
Epoch 030 | Loss: 10.6202 | Train Acc: 0.5023 | Val Acc: 0.5055 | Val AUC: 0.5262
EarlyStopping counter: 1 out of 30
Epoch 035 | Loss: 2.6544 | Train Acc: 0.4988 | Val Acc: 0.5049 | Val AUC: 0.4857
EarlyStopping counter: 2 out of 30
Epoch 040 | Loss: 11.5749 | Train Acc: 0.5008 | Val Acc: 0.5009 | Val AUC: 0.5306
EarlyStopping counter: 3 out of 30
Epoch 045 | Loss: 3.7405 | Train Acc: 0.4998 | Val Acc: 0.5003 | Val AUC: 0.5249
EarlyStopping counter: 4 out of 30
Epoch 050 | Loss: 11

In [14]:
# Load best model and evaluate
model.load_state_dict(torch.load(save_path))
test_acc, test_sens, test_spec, test_auc = evaluate('test')
print(f'\nTest Metrics:')
print(f'Accuracy: {test_acc:.4f}')
print(f'Sensitivity: {test_sens:.4f}')
print(f'Specificity: {test_spec:.4f}')
print(f'AUC: {test_auc:.4f}')


Test Metrics:
Accuracy: 0.5014
Sensitivity: 0.0035
Specificity: 0.9988
AUC: 0.5275


In [15]:
# GraphSAGE model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GraphSAGE(input_dim=X.shape[1], hidden_dim=64, 
                  num_classes=2
                  ).to(device)
data = data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=10, verbose=True)

save_dir = 'saved_models'
os.makedirs(save_dir, exist_ok=True)
save_path = os.path.join(save_dir, 'best_graphsage_model.pt')

early_stopping = EarlyStopping(patience=30, verbose=True)

for epoch in range(1, 501):
    loss = train()
    if epoch % 5 == 0:
        train_acc, _, _, _ = evaluate('train')
        val_acc, val_sens, val_spec, val_auc = evaluate('val')
        print(f'Epoch {epoch:03d} | Loss: {loss:.4f} | Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f} | Val AUC: {val_auc:.4f}')
        
        # Step the scheduler
        scheduler.step(val_auc)

        # Check early stopping
        early_stopping(val_auc, model, save_path)

        if early_stopping.early_stop:
            print("Early stopping triggered.")
            break

TypeError: GraphSAGE.forward() takes 2 positional arguments but 3 were given

In [None]:
# Load best model and evaluate
model.load_state_dict(torch.load(save_path))
test_acc, test_sens, test_spec, test_auc = evaluate('test')
print(f'\nTest Metrics:')
print(f'Accuracy: {test_acc:.4f}')
print(f'Sensitivity: {test_sens:.4f}')
print(f'Specificity: {test_spec:.4f}')
print(f'AUC: {test_auc:.4f}')

FileNotFoundError: [Errno 2] No such file or directory: 'saved_models/best_graphsage_model.pt'

In [None]:
# Balance dataset by undersampling majority class
minority_class = df[outcome_column].value_counts().idxmin()
majority_class = df[outcome_column].value_counts().idxmax()

minority_df = df[df[outcome_column] == minority_class]
majority_df = df[df[outcome_column] == majority_class].sample(n=len(minority_df), random_state=42)

df_balanced = pd.concat([minority_df, majority_df]).sample(frac=1, random_state=42)  # shuffle

X = df_balanced[small_columns].values
y = df_balanced[outcome_column].values

# Stratified 60/20/20 split
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.20, random_state=42)
for train_val_idx, test_idx in sss.split(X, y):
    X_temp, X_test = X[train_val_idx], X[test_idx]
    y_temp, y_test = y[train_val_idx], y[test_idx]

sss_val = StratifiedShuffleSplit(n_splits=1, test_size=0.25, random_state=42)
for train_idx, val_idx in sss_val.split(X_temp, y_temp):
    X_train, X_val = X_temp[train_idx], X_temp[val_idx]
    y_train, y_val = y_temp[train_idx], y_temp[val_idx]


# ====================
# 3. Model Definition
# ====================

class GraphSAGE(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_classes):
        super(GraphSAGE, self).__init__()
        self.conv1 = SAGEConv(input_dim, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, hidden_dim)
        self.lin = nn.Linear(hidden_dim, num_classes)

    def forward(self, x, edge_index):
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=0.3, training=self.training)
        x = F.relu(self.conv2(x, edge_index))
        x = F.dropout(x, p=0.3, training=self.training)
        x = self.lin(x)
        return x

# ====================
# 4. Early Stopping Class
# ====================

class EarlyStopping:
    def __init__(self, patience=30, verbose=False, delta=0):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.delta = delta

    def __call__(self, val_auc, model, path):
        score = val_auc

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(model, path)
        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(model, path)
            self.counter = 0

    def save_checkpoint(self, model, path):
        torch.save(model.state_dict(), path)

# ====================
# 5. Training Setup
# ====================

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GraphSAGE(input_dim=X.shape[1], hidden_dim=64, num_classes=2).to(device)
data = data.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=10, verbose=True)
criterion = nn.CrossEntropyLoss()

save_path = 'saved_models/best_graphsage_model.pt'
early_stopping = EarlyStopping(patience=30, verbose=True)

# ====================
# 6. Training Loop
# ====================

def train():
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = criterion(out, data.y)
    loss.backward()
    optimizer.step()
    return loss.item()

@torch.no_grad()
def evaluate(split):
    model.eval()
    if split == 'train':
        X_split, y_split = X_train, y_train
    elif split == 'val':
        X_split, y_split = X_val, y_val
    else:
        X_split, y_split = X_test, y_test

    edge_idx, _ = from_scipy_sparse_matrix(csr_matrix(knn.kneighbors_graph(X_split, mode='connectivity')))
    data_split = Data(x=torch.tensor(X_split, dtype=torch.float).to(device),
                      edge_index=edge_idx.to(device))

    logits = model(data_split.x, data_split.edge_index)
    preds = logits.argmax(dim=1).cpu().numpy()

    acc = accuracy_score(y_split, preds)
    auc = roc_auc_score(y_split, preds)
    cm = confusion_matrix(y_split, preds)
    tn, fp, fn, tp = cm.ravel()
    sensitivity = tp / (tp + fn)
    specificity = tn / (tn + fp)
    return acc, sensitivity, specificity, auc

# Training
for epoch in range(1, 501):
    loss = train()
    if epoch % 5 == 0:
        train_acc, _, _, _ = evaluate('train')
        val_acc, val_sens, val_spec, val_auc = evaluate('val')
        print(f'Epoch {epoch:03d} | Loss: {loss:.4f} | Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f} | Val AUC: {val_auc:.4f}')

        scheduler.step(val_auc)
        early_stopping(val_auc, model, save_path)

        if early_stopping.early_stop:
            print("Early stopping triggered.")
            break

# Load best model
model.load_state_dict(torch.load(save_path))

# Final Evaluation
test_acc, test_sens, test_spec, test_auc = evaluate('test')
print(f'\nTest Metrics:')
print(f'Accuracy: {test_acc:.4f}')
print(f'Sensitivity: {test_sens:.4f}')
print(f'Specificity: {test_spec:.4f}')
print(f'AUC: {test_auc:.4f}')

ValueError: Encountered invalid 'dim_size' (got '1737' but expected >= '5208')