## GAT-Based Loan Default Prediction

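This notebook implements a graph attention network (named `GAT`, but built from `TransformerConv` attention layers) to predict loan defaults on graph-structured data. It includes custom metric implementations, focal loss and negative undersampling to handle class imbalance, and `NeighborLoader` neighbour sampling for efficient mini-batch training. Model performance is evaluated with ROC AUC, F1-score, and bootstrap confidence intervals.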

Note: This script is intended for academic reference only.

In [None]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv, TransformerConv
from torch_geometric.data import Data
from torch_geometric.loader import NeighborLoader

import os

def find_best_f1_threshold(labels, probs, thresholds=None):
    """Return the decision threshold that maximises F1 on ``(labels, probs)``.

    Args:
        labels: Ground-truth binary labels (array-like of 0/1).
        probs: Predicted positive-class probabilities (array-like, same length
            as ``labels``). Lists are accepted.
        thresholds: Candidate thresholds to scan. Defaults to the 81 evenly
            spaced values from 0.1 to 0.9 (step 0.01).

    Returns:
        ``(best_threshold, best_f1)``. If no candidate yields F1 > 0 the
        defaults ``(0.5, 0.0)`` are returned.
    """
    # asarray so that list inputs support the vectorised comparison below.
    probs = np.asarray(probs)
    if thresholds is None:
        thresholds = np.linspace(0.1, 0.9, 81)

    best_f1, best_threshold = 0.0, 0.5
    for t in thresholds:
        # Strict '>' means a probability exactly equal to t is classified negative.
        f1 = custom_f1_score(labels, (probs > t).astype(int))
        # Strict improvement keeps the lowest threshold among ties.
        if f1 > best_f1:
            best_f1, best_threshold = f1, t
    return best_threshold, best_f1

def custom_f1_score(y_true, y_pred_binary):
    """Compute the binary F1 score (harmonic mean of precision and recall).

    Both inputs are cast to bool, so any truthy value counts as the positive
    class. Precision, recall and F1 each fall back to 0.0 when their
    denominator is zero instead of raising.
    """
    actual = np.asarray(y_true).astype(bool)
    predicted = np.asarray(y_pred_binary).astype(bool)

    true_pos = np.sum(actual & predicted)
    false_pos = np.sum(predicted & ~actual)
    false_neg = np.sum(actual & ~predicted)

    predicted_pos = true_pos + false_pos
    actual_pos = true_pos + false_neg
    precision = 0.0 if predicted_pos <= 0 else true_pos / predicted_pos
    recall = 0.0 if actual_pos <= 0 else true_pos / actual_pos

    denom = precision + recall
    if denom <= 0:
        return 0.0
    return 2 * (precision * recall) / denom

def custom_roc_auc_score(y_true, y_scores):
    """Compute the area under the ROC curve for binary labels.

    Labels equal to 1 are positives; every other label is treated as negative.
    Samples with identical scores are grouped into a single ROC step (the curve
    moves diagonally across a tie). The result therefore does not depend on how
    a sort happens to order tied scores, and it equals the Mann-Whitney U
    statistic normalised to [0, 1].

    Args:
        y_true: Array-like of binary labels.
        y_scores: Array-like of scores (higher = more likely positive), same length.

    Returns:
        The AUC as a Python float.

    Raises:
        ValueError: if the inputs differ in length, or if ``y_true`` contains
            only one class (AUC is undefined then).
    """
    y_true = np.asarray(y_true).ravel()
    y_scores = np.asarray(y_scores, dtype=float).ravel()
    if y_true.shape != y_scores.shape:
        raise ValueError(
            f"y_true and y_scores must have the same length, got {y_true.size} and {y_scores.size}."
        )

    is_pos = y_true == 1
    num_positive = int(is_pos.sum())
    num_negative = y_true.size - num_positive
    if num_positive == 0 or num_negative == 0:
        raise ValueError("Only one class present in y_true. ROC AUC score is not defined in that case.")

    # Sort by descending score; order inside a tie group is irrelevant because
    # only the last index of each group becomes an ROC point.
    order = np.argsort(-y_scores, kind="stable")
    sorted_scores = y_scores[order]
    sorted_pos = is_pos[order]

    last_in_group = np.r_[np.flatnonzero(np.diff(sorted_scores)), y_scores.size - 1]
    tps = np.cumsum(sorted_pos)[last_in_group]
    fps = (last_in_group + 1) - tps

    # Prepend the (0, 0) origin; the last group always reaches (1, 1).
    tpr = np.r_[0.0, tps / num_positive]
    fpr = np.r_[0.0, fps / num_negative]
    # Trapezoidal integration of TPR over FPR.
    return float(np.sum(np.diff(fpr) * (tpr[1:] + tpr[:-1]) / 2.0))

def custom_resample(data_array, replace=True, n_samples=None, random_state=None):
    """Randomly draw ``n_samples`` elements from a 1-D ``data_array``.

    Args:
        data_array: 1-D array-like to sample from.
        replace: Sample with replacement (bootstrap) when True.
        n_samples: Number of draws; defaults to ``len(data_array)``.
        random_state: ``None`` uses NumPy's global RNG. An int seeds a private
            ``np.random.RandomState``, which yields the same draws as
            ``np.random.seed(int)`` followed by ``np.random.choice`` but without
            resetting the global RNG. A ``RandomState``/``Generator`` instance
            is used as-is.

    Returns:
        A NumPy array of the sampled elements.
    """
    if n_samples is None:
        n_samples = len(data_array)

    if random_state is None:
        rng = np.random
    elif isinstance(random_state, (np.random.RandomState, np.random.Generator)):
        rng = random_state
    else:
        # Private generator: reseeding the global RNG would silently affect
        # every other consumer of np.random in the process.
        rng = np.random.RandomState(random_state)

    return rng.choice(data_array, size=n_samples, replace=replace)

class FocalLoss(nn.Module):
    """Binary focal loss: ``alpha * (1 - p_t) ** gamma * BCE``.

    ``p_t`` is recovered as ``exp(-BCE)``, i.e. the predicted probability of
    the true class for 0/1 targets. The ``(1 - p_t) ** gamma`` factor
    down-weights examples the model already classifies confidently.

    Args:
        alpha: Scalar multiplier applied to every element's loss (not a
            per-class weight).
        gamma: Focusing exponent.
        logits: If True, ``inputs`` are raw logits; otherwise probabilities.
        reduce: If True, return the mean loss; otherwise the per-element tensor.
    """

    def __init__(self, alpha=1, gamma=2, logits=True, reduce=True):
        super().__init__()
        self.alpha, self.gamma = alpha, gamma
        self.logits, self.reduce = logits, reduce

    def forward(self, inputs, targets):
        bce_fn = F.binary_cross_entropy_with_logits if self.logits else F.binary_cross_entropy
        bce = bce_fn(inputs, targets, reduction='none')
        p_t = torch.exp(-bce)
        focal = self.alpha * (1 - p_t) ** self.gamma * bce
        if not self.reduce:
            return focal
        return focal.mean()

class Config:
    """Static file paths and hyper-parameters for training the loan-default model."""

    # --- Data and artefact locations -------------------------------------------
    TRAIN_GRAPH_PATH: str = "../data/graph_data/train_graph.pt"
    TEST_GRAPH_PATH: str = "../data/graph_data/test_graph.pt"
    MODEL_SAVE_PATH: str = "./trained_gat_model.pt"

    # --- Model architecture ----------------------------------------------------
    # NOTE(review): train() derives in_channels from x.shape[1]; IN_CHANNELS is not read there.
    IN_CHANNELS: int = 13
    HIDDEN_CHANNELS: int = 128
    OUT_CHANNELS: int = 1
    HEADS: int = 8
    EDGE_DIM: int = 4  # width of edge_attr passed to the attention layers

    # --- Sampling and optimisation ---------------------------------------------
    NUM_NEIGHBORS: list = [20, 10]  # neighbours sampled per hop by NeighborLoader
    LEARNING_RATE: float = 0.005
    EPOCHS: int = 200
    BATCH_SIZE: int = 128
    NUM_WORKERS: int = 4

class GAT(nn.Module):
    """Two-layer attention-based GNN that outputs raw (un-activated) node scores.

    Despite the class name, both message-passing layers are PyG
    ``TransformerConv`` layers, not ``GATConv``. Layer 1 runs ``heads``
    attention heads and concatenates them, then applies LayerNorm, ELU and
    feature dropout. Layer 2 is a single-head layer projecting to
    ``out_channels``. No output activation is applied, so with
    ``out_channels=1`` each node gets one logit; pair it with a logits-based
    loss or ``torch.sigmoid``.

    Args:
        in_channels: Node feature width.
        hidden_channels: Per-head output width of the first layer.
        out_channels: Output width per node.
        heads: Number of attention heads in the first layer.
        attn_dropout: Dropout passed to both TransformerConv layers.
        feat_dropout: Dropout applied to hidden node features between layers.
        edge_dim: Edge-feature width, or None for graphs without edge features.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, heads,
                 attn_dropout=0.6, feat_dropout=0.5, edge_dim=None):
        super().__init__()
        # Attribute names are part of the state_dict keys; keep them stable so
        # previously saved checkpoints still load.
        self.conv1 = TransformerConv(in_channels, hidden_channels,
                                     heads=heads, concat=True,
                                     dropout=attn_dropout,
                                     edge_dim=edge_dim)
        # concat=True -> hidden_channels * heads features per node.
        self.norm1 = nn.LayerNorm(hidden_channels * heads)
        self.dropout = nn.Dropout(feat_dropout)
        self.conv2 = TransformerConv(hidden_channels * heads, out_channels,
                                     heads=1, concat=False,
                                     dropout=attn_dropout,
                                     edge_dim=edge_dim)

    def forward(self, x, edge_index, edge_attr=None):
        """Return raw per-node scores for the graph ``(x, edge_index, edge_attr)``.

        ``edge_attr`` may be omitted only when the model was built with
        ``edge_dim=None``; presumably TransformerConv requires it otherwise.
        """
        h = self.conv1(x, edge_index, edge_attr)
        h = self.norm1(h)
        h = F.elu(h)
        h = self.dropout(h)
        return self.conv2(h, edge_index, edge_attr)

def load_graph_data(path):
    """Load a saved graph object from ``path`` and normalise its tensor fields.

    Normalisations (each one prints a warning when applied):
      * ``edge_index`` is cast to ``torch.long`` and made contiguous.
      * ``edge_attr`` is cast to ``torch.float``.
      * Labels: if a ``default`` attribute exists it is copied to ``y`` as
        float and ``default`` is deleted; otherwise an existing ``y`` is cast
        to float. If neither is present, only a warning is printed.

    Returns the (mutated) loaded object.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted graph files.
    data = torch.load(path)

    edge_index = getattr(data, 'edge_index', None)
    if edge_index is not None:
        if edge_index.dtype != torch.long:
            print(f"Warning: Converting edge_index from {edge_index.dtype} to torch.long for {path}")
            edge_index = edge_index.long()
            data.edge_index = edge_index
        if not edge_index.is_contiguous():
            print(f"Warning: Making edge_index contiguous for {path}")
            data.edge_index = edge_index.contiguous()

    edge_attr = getattr(data, 'edge_attr', None)
    if edge_attr is not None and edge_attr.dtype != torch.float:
        print(f"Warning: Converting edge_attr from {edge_attr.dtype} to torch.float for {path}")
        data.edge_attr = edge_attr.float()

    # 'default' takes precedence over any existing 'y' and is renamed to 'y'.
    default_labels = getattr(data, 'default', None)
    if default_labels is not None:
        if default_labels.dtype == torch.float:
            data.y = default_labels
        else:
            print(f"Warning: Converting 'default' labels from {default_labels.dtype} to torch.float for {path}")
            data.y = default_labels.float()
        del data.default
        return data

    current_y = getattr(data, 'y', None)
    if current_y is None:
        print(f"Warning: No 'y' or 'default' labels found in {path}. Please ensure labels are correctly loaded.")
    elif current_y.dtype != torch.float:
        print(f"Warning: Converting existing 'y' labels from {current_y.dtype} to torch.float for {path}")
        data.y = current_y.float()
    return data

def calculate_accuracy(predictions, targets):
    """Return the fraction of positions where ``predictions`` equals ``targets``.

    Both inputs are flattened first. Without this, an ``(N, 1)`` prediction
    tensor compared with an ``(N,)`` target would broadcast to an ``(N, N)``
    comparison and yield a meaningless score. Tensors, NumPy arrays and lists
    are accepted.

    Returns:
        Accuracy as a float; 0.0 for empty inputs.

    Raises:
        ValueError: if the inputs contain different numbers of elements.
    """
    predictions = torch.as_tensor(predictions).reshape(-1)
    targets = torch.as_tensor(targets).reshape(-1)
    if predictions.numel() != targets.numel():
        raise ValueError(
            f"predictions and targets differ in size: {predictions.numel()} vs {targets.numel()}"
        )
    if targets.numel() == 0:
        return 0.0
    correct = (predictions == targets).sum().item()
    return correct / targets.numel()

def _report_memory_footprint(data):
    """Print shape, dtype and size of the main tensors of the training graph."""
    print("\n--- Loaded Graph Data Memory Footprint (Full Graph) ---")
    total_bytes = 0
    for attr, label in (("x", "Train features (x)"),
                        ("edge_index", "Train edge_index"),
                        ("edge_attr", "Train edge_attr"),
                        ("y", "Train labels (y)")):
        tensor = getattr(data, attr, None)
        if tensor is None:
            continue
        print(f"{label} shape: {tensor.shape}")
        n_bytes = tensor.element_size() * tensor.nelement()
        total_bytes += n_bytes
        print(f"{label} size: {n_bytes / (1024**3):.2f} GB (dtype: {tensor.dtype})")
    print(f"Total estimated memory for full_train_data: {total_bytes / (1024**3):.2f} GB\n")


def _train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader``; return the mean batch loss."""
    model.train()
    total_loss, num_batches = 0.0, 0
    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        # squeeze(-1) rather than squeeze(): a 1-node subgraph must stay 1-D.
        out_full_subgraph = model(batch.x, batch.edge_index, batch.edge_attr).squeeze(-1)
        # The seed nodes come first in each sampled subgraph; only they carry the loss.
        bs = getattr(batch, 'batch_size', Config.BATCH_SIZE)
        loss = criterion(out_full_subgraph[:bs], batch.y[:bs].float())
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    # max(..., 1) guards against an empty loader (drop_last with < BATCH_SIZE seeds).
    return total_loss / max(num_batches, 1)


def _predict(model, loader, device):
    """Return (probabilities, integer labels) as NumPy arrays for all seed nodes in ``loader``."""
    model.eval()
    probs, labels = [], []
    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            bs = batch.batch_size
            logits = model(batch.x, batch.edge_index, batch.edge_attr).squeeze(-1)[:bs]
            probs.extend(torch.sigmoid(logits).tolist())
            labels.extend(batch.y[:bs].long().tolist())
    return np.array(probs), np.array(labels)


def _evaluate(model, loader, device):
    """Predict on ``loader``, print diagnostics and metrics; return (auc, probs, labels, threshold)."""
    probs, labels = _predict(model, loader, device)
    print("Predicted probabilities range:", np.min(probs), "-", np.max(probs))
    if np.max(probs) < 0.1:
        print("Warning: Narrow prediction range. Model may not be learning effectively.")

    threshold, _ = find_best_f1_threshold(labels, probs)
    preds_binary = (probs > threshold).astype(int)
    print(f"Optimal threshold based on F1: {threshold:.2f}")

    try:
        auc_score = custom_roc_auc_score(labels, probs)
        f1 = custom_f1_score(labels, preds_binary)
        acc = calculate_accuracy(torch.tensor(preds_binary), torch.tensor(labels))
        print(f"AUC: {auc_score:.4f} — F1: {f1:.4f} — Acc: {acc:.4f}")
    except ValueError as e:
        print(f"Metric computation error: {e}")
        auc_score = 0
    return auc_score, probs, labels, threshold


def _report_bootstrap(labels, probs, threshold, n_bootstraps=1000):
    """Print bootstrap mean and 95% percentile CI of AUC and F1 for fixed predictions."""
    auc_scores, f1_scores = [], []
    print(f"\nPerforming {n_bootstraps} bootstraps for AUC and F1-Score.")
    all_idx = np.arange(len(labels))
    for i in range(n_bootstraps):
        idx = custom_resample(all_idx, replace=True, n_samples=len(labels), random_state=i)
        sample_labels = labels[idx]
        # AUC is undefined when a resample contains a single class.
        if len(np.unique(sample_labels)) < 2:
            continue
        try:
            auc_scores.append(custom_roc_auc_score(sample_labels, probs[idx]))
            f1_scores.append(custom_f1_score(sample_labels, (probs[idx] > threshold).astype(int)))
        except ValueError:
            continue

    if auc_scores:
        print(f"Bootstrapped AUC (Mean): {np.mean(auc_scores):.4f}")
        print(f"Bootstrapped AUC (95% CI): ({np.percentile(auc_scores, 2.5):.4f}, {np.percentile(auc_scores, 97.5):.4f})")
    else:
        print("Not enough valid bootstrap samples to calculate AUC statistics.")

    if f1_scores:
        print(f"Bootstrapped F1-Score (Mean): {np.mean(f1_scores):.4f}")
        print(f"Bootstrapped F1-Score (95% CI): ({np.percentile(f1_scores, 2.5):.4f}, {np.percentile(f1_scores, 97.5):.4f})")
    else:
        print("Not enough valid bootstrap samples to calculate F1-Score statistics.")


def train():
    """Train the graph model with early stopping, report bootstrap metrics and save the weights.

    Reads the train/test graphs from ``Config`` paths. Trains on all positive
    nodes plus up to 5x as many randomly sampled negative nodes, using focal
    loss. Early-stops on test AUC (patience 10), restores the best weights,
    re-evaluates them, and writes the state_dict to ``Config.MODEL_SAVE_PATH``.

    NOTE(review): early stopping, F1-threshold selection and the reported
    metrics all use the test set, so these scores are optimistic. A separate
    validation split would avoid the leakage.
    """
    device = torch.device('cpu')
    full_train_data = load_graph_data(Config.TRAIN_GRAPH_PATH).to(device)
    full_test_data = load_graph_data(Config.TEST_GRAPH_PATH).to(device)

    _report_memory_footprint(full_train_data)

    model = GAT(in_channels=full_train_data.x.shape[1],
                hidden_channels=Config.HIDDEN_CHANNELS,
                out_channels=Config.OUT_CHANNELS,
                heads=Config.HEADS,
                edge_dim=Config.EDGE_DIM).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=Config.LEARNING_RATE, weight_decay=5e-4)
    # Class imbalance is handled by focal loss plus negative undersampling below.
    criterion = FocalLoss(alpha=1, gamma=2, logits=True)

    positives = torch.where(full_train_data.y == 1)[0]
    negatives = torch.where(full_train_data.y == 0)[0]
    if len(positives) == 0:
        print("Warning: No positive samples in training data.")

    target_neg_samples = len(positives) * 5
    if target_neg_samples > len(negatives):
        print(f"Warning: Desired negative samples ({target_neg_samples}) exceeds available negatives ({len(negatives)}). Sampling all negatives.")
        sampled_negatives = negatives
    else:
        sampled_negatives = negatives[torch.randperm(len(negatives))[:target_neg_samples]]

    balanced_nodes = torch.cat([positives, sampled_negatives])
    print(f"Training with undersampled negatives: {len(positives)} positives, {len(sampled_negatives)} sampled negatives.")

    train_loader = NeighborLoader(
        full_train_data,
        num_neighbors=Config.NUM_NEIGHBORS,
        batch_size=Config.BATCH_SIZE,
        input_nodes=balanced_nodes,
        shuffle=True,
        num_workers=Config.NUM_WORKERS,
        drop_last=True
    )
    test_loader = NeighborLoader(
        full_test_data,
        num_neighbors=Config.NUM_NEIGHBORS,
        batch_size=Config.BATCH_SIZE,
        input_nodes=torch.arange(full_test_data.num_nodes),
        shuffle=False,
        num_workers=Config.NUM_WORKERS
    )

    best_auc, counter, patience = 0.0, 0, 10
    best_model_state = None

    for epoch in range(Config.EPOCHS):
        avg_loss = _train_one_epoch(model, train_loader, optimizer, criterion, device)
        print(f'Epoch {epoch:03d} — Avg Loss: {avg_loss:.4f}')

        auc_score, _, _, _ = _evaluate(model, test_loader, device)

        if auc_score > best_auc:
            best_auc = auc_score
            # state_dict() returns references to the live parameters; clone them so
            # later optimizer steps cannot overwrite the saved "best" weights.
            best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            counter = 0
            print(f"NEW BEST AUC: {auc_score:.4f}")
        else:
            counter += 1
            print(f"No improvement for {counter} epoch(s).")
            if counter >= patience:
                print("### Early stopping triggered ###")
                break

    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print("### Restored best model weights (early stopping) ###")

    # Re-score with the weights actually kept, so the bootstrap statistics
    # describe the model that is saved rather than the last trained epoch.
    print("\n--- Final evaluation of the restored model ---")
    _, final_preds_proba, final_labels, threshold = _evaluate(model, test_loader, device)
    _report_bootstrap(final_labels, final_preds_proba, threshold)

    # dirname is "" for a bare filename; os.makedirs("") would raise.
    save_dir = os.path.dirname(Config.MODEL_SAVE_PATH)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    torch.save(model.state_dict(), Config.MODEL_SAVE_PATH)
    print(f"\nModel saved to: {Config.MODEL_SAVE_PATH}")

if __name__ == '__main__':
    # Fail fast with a clear message (and a non-zero exit status) instead of a
    # torch.load traceback. SystemExit avoids relying on the site-provided exit().
    data_dir = os.path.dirname(Config.TRAIN_GRAPH_PATH)
    if not os.path.exists(data_dir):
        print(f"Error: Data directory '{data_dir}' not found.")
        raise SystemExit(1)
    for graph_path in (Config.TRAIN_GRAPH_PATH, Config.TEST_GRAPH_PATH):
        if not os.path.isfile(graph_path):
            print(f"Error: Graph file '{graph_path}' not found.")
            raise SystemExit(1)

    train()