In [13]:
import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, classification_report
from sklearn.neighbors import kneighbors_graph
import random

In [14]:
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
random.seed(SEED)

df = pd.read_csv("ThoraricSurgery.csv")

y = df["Risk1Yr"].map({"F": 0, "T": 1}).values.astype(int)
X = df.drop(columns=["Risk1Yr"])

for col in X.columns:
    if X[col].dtype == "object":
        le = LabelEncoder()
        X[col] = le.fit_transform(X[col])

scaler = StandardScaler()
X = scaler.fit_transform(X)

k = 5
knn_graph = kneighbors_graph(X, k, mode="connectivity", include_self=False)

row, col = knn_graph.nonzero()
edge_index = torch.tensor([row, col], dtype=torch.long)

x = torch.tensor(X, dtype=torch.float)
y_torch = torch.tensor(y, dtype=torch.long)

train_idx, test_idx = train_test_split(
    range(len(y)), test_size=0.2, stratify=y, random_state=SEED
)

train_mask = torch.zeros(len(y), dtype=torch.bool)
test_mask = torch.zeros(len(y), dtype=torch.bool)
train_mask[train_idx] = True
test_mask[test_idx] = True

data = Data(x=x, edge_index=edge_index, y=y_torch,
            train_mask=train_mask, test_mask=test_mask)

class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=0.3, training=self.training)
        x = self.conv2(x, edge_index)
        return x

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GCN(in_channels=data.num_features, hidden_channels=32, out_channels=2).to(device)
data = data.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

def train():
    model.train()
    optimizer.zero_grad()
    out = model(data)
    loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()

@torch.no_grad()
def test():
    model.eval()
    out = model(data)
    pred = out.argmax(dim=1)
    y_true = data.y[data.test_mask].cpu().numpy()
    y_pred = pred[data.test_mask].cpu().numpy()
    y_prob = F.softmax(out, dim=1)[:, 1][data.test_mask].cpu().numpy()
    return (
        accuracy_score(y_true, y_pred),
        f1_score(y_true, y_pred),
        roc_auc_score(y_true, y_prob),
        classification_report(y_true, y_pred)
    )

print("\n===== GNN Training =====")
for epoch in range(1, 101):
    loss = train()
    if epoch % 10 == 0:
        acc, f1, auc, _ = test()
        print(f"Epoch {epoch:03d}, Loss: {loss:.4f}, Acc: {acc:.4f}, F1: {f1:.4f}, AUC: {auc:.4f}")

acc, f1, auc, report = test()
print("\n===== GNN Final Results =====")
print(f"Accuracy: {acc:.4f}, F1 Score: {f1:.4f}, ROC-AUC: {auc:.4f}")
print("\nClassification Report:\n", report)



===== GNN Training =====
Epoch 010, Loss: 0.4061, Acc: 0.8404, F1: 0.0000, AUC: 0.6732
Epoch 020, Loss: 0.3609, Acc: 0.8191, F1: 0.0000, AUC: 0.6848
Epoch 030, Loss: 0.3592, Acc: 0.8298, F1: 0.0000, AUC: 0.6411
Epoch 040, Loss: 0.3428, Acc: 0.8191, F1: 0.0000, AUC: 0.6634
Epoch 050, Loss: 0.3334, Acc: 0.8191, F1: 0.0000, AUC: 0.6571
Epoch 060, Loss: 0.3222, Acc: 0.8298, F1: 0.0000, AUC: 0.6429
Epoch 070, Loss: 0.3161, Acc: 0.8298, F1: 0.0000, AUC: 0.6277
Epoch 080, Loss: 0.3079, Acc: 0.8191, F1: 0.0000, AUC: 0.6098
Epoch 090, Loss: 0.3006, Acc: 0.8191, F1: 0.0000, AUC: 0.5991
Epoch 100, Loss: 0.2991, Acc: 0.8191, F1: 0.0000, AUC: 0.6018

===== GNN Final Results =====
Accuracy: 0.8191, F1 Score: 0.0000, ROC-AUC: 0.6018

Classification Report:
               precision    recall  f1-score   support

           0       0.85      0.96      0.90        80
           1       0.00      0.00      0.00        14

    accuracy                           0.82        94
   macro avg       0.42     