In [None]:
pip -q install torch==2.0.1 torch-geometric==2.4.0 torch-scatter torch-sparse torch-cluster torch-spline-conv scikit-learn pandas numpy tqdm


In [None]:
# hybrid_gat_mlp_churn.py
# pipeline for the WSDM–KKBox churn task.

import os
import gc
import math
import random
import numpy as np
import pandas as pd
from datetime import datetime
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import random_split

from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.metrics import roc_auc_score, f1_score, accuracy_score
from sklearn.model_selection import train_test_split

from torch_geometric.data import Data
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import GATv2Conv, BatchNorm

# -------------------------
# CONFIG
# -------------------------
# Central configuration for the whole pipeline; every stage below reads its
# settings from this dict rather than taking them as arguments.
CFG = {
    # Input CSV locations, passed as a group to build_user_level_table().
    "PATHS": {
        "train": "/content/train_v2.csv",
        "members": "/content/members_v3.csv",
        "transactions": "/content/transactions_v2.csv",
        "user_logs": "/content/user_logs_v2.csv",  # ~30GB
    },
    # Seed for Python/NumPy/torch RNGs and for the stratified splits in main().
    "RANDOM_SEED": 42,
    "LOG_CHUNK_ROWS": 1_000_000,     # chunk size for user_logs
    # Upper clip applied to each user's summed is_cancel in aggregate_transactions().
    "MAX_CANCEL_CAP": 4,
    "GRAPH_MAX_NEIGHBORS_PER_GROUP": 10,  # K edges per user inside each (city, registered_via) group
    # Number of seed nodes per NeighborLoader batch.
    "BATCH_SIZE": 512,
    # AdamW learning rate and weight decay.
    "LR": 5e-3,
    "WEIGHT_DECAY": 1e-4,
    # Training stops after MAX_EPOCHS, or earlier once validation AUC has not
    # improved for EARLY_STOP_PATIENCE consecutive epochs.
    "MAX_EPOCHS": 50,
    "EARLY_STOP_PATIENCE": 5,
    # GAT encoder width per head and number of heads in the first layer.
    "GAT_HIDDEN": 32,
    "GAT_HEADS": 4,
    # Dropout rate; HybridGATMLP uses it for the GAT layers and also for the
    # dropout applied on the MLP branch.
    "GAT_DROPOUT": 0.6,
    # Hidden width of the tabular MLP branch.
    "MLP_HIDDEN": 64,
    # Fractions of all nodes held out for validation and test in main().
    "VAL_RATIO": 0.15,
    "TEST_RATIO": 0.15,
    # Use the GPU when torch reports one, otherwise fall back to CPU.
    "DEVICE": "cuda" if torch.cuda.is_available() else "cpu",
}

# Seed the three RNG sources used in this file so runs are repeatable.
random.seed(CFG["RANDOM_SEED"])
np.random.seed(CFG["RANDOM_SEED"])
torch.manual_seed(CFG["RANDOM_SEED"])

# -------------------------
# UTILITIES
# -------------------------
def parse_date(series):
    """Convert a column to pandas datetimes; unparseable values become NaT.

    Numeric input is read as ``yyyymmdd`` calendar dates (e.g. ``20170331``).
    Handing such integers straight to ``pd.to_datetime`` would interpret them
    as nanoseconds since the Unix epoch and place every date in 1970.

    Non-numeric input (strings, existing datetimes) goes through
    ``pd.to_datetime(..., errors="coerce")`` unchanged, so calling this on an
    already-parsed column is a no-op.

    Args:
        series: column of dates, typically a ``pd.Series``.

    Returns:
        The datetime-typed equivalent of *series*, with NaT for bad values.
    """
    is_numeric = (
        pd.api.types.is_numeric_dtype(series)
        and not pd.api.types.is_bool_dtype(series)
    )
    if not is_numeric:
        return pd.to_datetime(series, errors="coerce")

    numeric = pd.to_numeric(series, errors="coerce")
    # Keep only values that can be an 8-digit yyyymmdd; NaN, inf, negatives and
    # short numbers fall outside the range and end up as NaT.
    numeric = numeric.where(numeric.between(10_000_101, 99_991_231))
    # Go through the nullable integer type so floats such as 20170331.0 render
    # as "20170331" rather than "20170331.0".
    as_text = numeric.round().astype("Int64").astype("string")
    return pd.to_datetime(as_text, format="%Y%m%d", errors="coerce")

def safe_int(x):
    """Convert *x* to ``int``, returning 0 when the conversion is not possible.

    Only conversion failures are absorbed: ``TypeError`` (e.g. ``None``),
    ``ValueError`` (e.g. ``"abc"`` or NaN) and ``OverflowError`` (infinity).
    Anything else, including ``KeyboardInterrupt``, propagates to the caller.
    """
    try:
        return int(x)
    except (TypeError, ValueError, OverflowError):
        return 0

def dict_increment(d, key, val):
    """Add *val* to ``d[key]`` in place, treating a missing key as 0."""
    current = d.get(key, 0)
    d[key] = current + val

# -------------------------
# 3.2 DATA PREPARATION
# -------------------------
def aggregate_transactions(path_transactions):
    """Read the transactions CSV and collapse it to one row per ``msno``.

    Output columns: ``msno``, ``total_transactions`` (row count),
    ``total_payment`` (sum of the payment column), ``is_cancel_sum`` (summed
    ``is_cancel``, clipped at ``CFG["MAX_CANCEL_CAP"]``) and
    ``last_transaction_date`` (latest ``transaction_date``).
    """
    print("Aggregating transactions ...")
    tx = pd.read_csv(path_transactions)

    # Parse whichever of the two date columns the file actually has.
    present_date_cols = [
        name
        for name in ("transaction_date", "membership_expire_date")
        if name in tx.columns
    ]
    for name in present_date_cols:
        tx[name] = parse_date(tx[name])

    # Prefer the amount actually paid; otherwise fall back to the plan price.
    if "actual_amount_paid" in tx.columns:
        pay_col = "actual_amount_paid"
    else:
        pay_col = "payment_plan_price"

    per_user = tx.groupby("msno").agg(
        total_transactions=("msno", "size"),
        total_payment=(pay_col, "sum"),
        is_cancel_sum=("is_cancel", "sum"),
        last_transaction_date=("transaction_date", "max")
    )
    agg = per_user.reset_index()

    # Cap the cancel count, then make it a plain integer column.
    capped_cancels = agg["is_cancel_sum"].clip(upper=CFG["MAX_CANCEL_CAP"])
    agg["is_cancel_sum"] = capped_cancels.fillna(0).astype(int)
    agg["total_payment"] = agg["total_payment"].fillna(0.0)

    return agg

def aggregate_user_logs(path_logs, chunk_rows):
    """Stream the user-log CSV in chunks and total the activity per ``msno``.

    Each chunk is grouped by ``msno`` and the per-chunk results are summed
    into a running table, so memory scales with the number of distinct users
    rather than the number of log rows.

    Args:
        path_logs: path to the user-log CSV. Expected columns (typical KKBox):
            msno, date, num_25, num_50, num_75, num_985, num_100, num_unq,
            total_secs. Only msno, date, total_secs and num_unq are read here.
        chunk_rows: number of CSV rows to load per chunk.

    Returns:
        DataFrame with columns ``msno``, ``log_days`` (int),
        ``total_secs_sum`` (float) and ``total_songs_played`` (float; the sum
        of ``num_unq``). A metric whose source column is absent falls back to
        the row count. Rows are not in first-appearance order; join on
        ``msno`` rather than relying on position. A file with no data rows
        yields an empty frame with the same columns.

    Note:
        ``log_days`` is the sum of per-chunk distinct dates, so a (user, date)
        pair that appears in more than one chunk is counted once per chunk.
    """
    print("Aggregating user logs (chunked) ...")
    metric_cols = ["log_days", "total_secs_sum", "total_songs_played"]
    running = None  # per-user running totals, indexed by msno

    # The context manager closes the underlying file even if a chunk fails.
    with pd.read_csv(path_logs, chunksize=chunk_rows) as chunks:
        for chunk in tqdm(chunks, desc="user_logs chunks"):
            if "date" in chunk.columns:
                chunk["date"] = parse_date(chunk["date"])
            # coerce numeric
            for col in ("total_secs", "num_unq"):
                if col in chunk.columns:
                    chunk[col] = pd.to_numeric(chunk[col], errors="coerce").fillna(0)

            # aggregate at chunk level
            g = chunk.groupby("msno").agg(
                log_days=("date", "nunique") if "date" in chunk.columns else ("msno", "size"),
                total_secs_sum=("total_secs", "sum") if "total_secs" in chunk.columns else ("msno", "size"),
                total_songs_played=("num_unq", "sum") if "num_unq" in chunk.columns else ("msno", "size"),
            )

            # Index-aligned add: users missing on either side contribute 0.
            running = g if running is None else running.add(g, fill_value=0)

            del chunk, g
            gc.collect()

    if running is None:
        # No chunk was produced; return the expected schema with zero rows.
        return pd.DataFrame({
            "msno": pd.Series(dtype=object),
            "log_days": pd.Series(dtype=int),
            "total_secs_sum": pd.Series(dtype=float),
            "total_songs_played": pd.Series(dtype=float),
        })

    df = running[metric_cols].rename_axis("msno").reset_index()
    # numeric types
    df["log_days"] = df["log_days"].astype(int)
    df["total_secs_sum"] = df["total_secs_sum"].astype(float)
    df["total_songs_played"] = df["total_songs_played"].astype(float)
    return df

def load_members(path_members):
    """Load the members CSV and keep only the columns used downstream.

    ``registration_init_time`` is parsed as a ``yyyymmdd`` date when the
    column is present; values that do not fit that format become NaT.

    Returns:
        DataFrame with columns msno, city, gender, registered_via and
        registration_init_time, in that order.
    """
    print("Loading members ...")
    members = pd.read_csv(path_members)

    reg_col = "registration_init_time"
    if reg_col in members.columns:
        # KKBox stores the registration date as a yyyymmdd integer.
        parsed = pd.to_datetime(members[reg_col], format="%Y%m%d", errors="coerce")
        members[reg_col] = parsed

    return members[["msno", "city", "gender", "registered_via", reg_col]]

def load_train_labels(path_train):
    """Load the ``msno`` / ``is_churn`` label pairs from the training CSV.

    Missing labels are treated as 0 (not churned) and the column is returned
    as integers.
    """
    labels = pd.read_csv(path_train).loc[:, ["msno", "is_churn"]]
    churn = labels["is_churn"].fillna(0)
    labels["is_churn"] = churn.astype(int)
    return labels

def build_user_level_table(paths):
    """Assemble one standardized feature row per labelled user.

    Labels are left-joined with members, transaction aggregates and log
    aggregates on ``msno``; derived features are added; every feature column
    is then z-scored.

    Args:
        paths: mapping with the keys "transactions", "user_logs", "members"
            and "train", each a CSV path.

    Returns:
        ``(df, feature_cols, scaler)``: the merged frame, the ordered list of
        15 feature column names, and the fitted ``StandardScaler``.

    Note:
        The scaler is fitted on every row of ``df``, i.e. before any
        train/validation/test split is made.
    """
    tx_agg = aggregate_transactions(paths["transactions"])
    logs_agg = aggregate_user_logs(paths["user_logs"], CFG["LOG_CHUNK_ROWS"])
    members = load_members(paths["members"])
    labels = load_train_labels(paths["train"])

    print("Merging all sources ...")
    df = labels
    for right in (members, tx_agg, logs_agg):
        df = df.merge(right, on="msno", how="left")

    # Users absent from a source get zeros for that source's counters.
    counter_cols = (
        "total_transactions", "total_payment", "is_cancel_sum",
        "log_days", "total_secs_sum", "total_songs_played",
    )
    for name in counter_cols:
        df[name] = pd.to_numeric(df[name], errors="coerce").fillna(0)

    # 3.3 FEATURE ENGINEERING
    print("Feature engineering ...")
    for name in ("last_transaction_date", "registration_init_time"):
        df[name] = parse_date(df[name])

    # Days between registration and the last transaction; missing or negative
    # spans collapse to 0.
    tenure_days = (df["last_transaction_date"] - df["registration_init_time"]).dt.days
    df["membership_days"] = tenure_days
    df["membership_days"] = df["membership_days"].fillna(0).clip(lower=0)

    registered = df["registration_init_time"].dt
    df["registration_year"] = registered.year.fillna(0).astype(int)
    df["registration_month"] = registered.month.fillna(0).astype(int)

    # Integer-code the low-cardinality categoricals; missing -> "unknown".
    for name in ("city", "gender", "registered_via"):
        as_text = df[name].fillna("unknown").astype(str)
        df[name] = LabelEncoder().fit_transform(as_text)

    # Simple per-unit ratios; the small epsilon avoids division by zero.
    df["avg_secs_per_day"] = df["total_secs_sum"].div(df["log_days"] + 1e-6).fillna(0)
    df["avg_songs_per_day"] = df["total_songs_played"].div(df["log_days"] + 1e-6).fillna(0)
    df["pay_per_tx"] = df["total_payment"].div(df["total_transactions"] + 1e-6).fillna(0)

    feature_cols = [
        "city", "gender", "registered_via",
        "total_transactions", "total_payment", "is_cancel_sum",
        "log_days", "total_secs_sum", "total_songs_played",
        "membership_days", "registration_year", "registration_month",
        "avg_secs_per_day", "avg_songs_per_day", "pay_per_tx",
    ]

    # scaling (Z-score)
    print("Scaling numeric features ...")
    scaler = StandardScaler()
    df[feature_cols] = scaler.fit_transform(df[feature_cols].astype(float))

    return df, feature_cols, scaler

# -------------------------
# GRAPH CONSTRUCTION (Eq.4)
# -------------------------
def build_sparse_similarity_edges(df, max_neighbors=10):
    """Build a sparse, symmetric edge list linking users with equal
    ``(city, registered_via)``.

    Inside each group the members are arranged on a ring in row order and
    every member is linked, in both directions, to the next
    ``min(max_neighbors, group_size - 1)`` members. This bounds the degree
    while keeping each group connected.

    Small groups make the ring wrap around onto pairs that are already linked,
    so the result is de-duplicated: each directed edge appears exactly once.

    Args:
        df: frame with ``city`` and ``registered_via`` columns. It is not
            modified; rows are numbered 0..len(df)-1 in their current order.
        max_neighbors: ring offsets to link per member (K).

    Returns:
        ``(edge_index, node_ids)``: a ``torch.long`` tensor of shape
        ``(2, num_edges)`` and the array of node ids ``0..len(df)-1``.
    """
    print("Building sparse similarity graph ...")
    df = df.reset_index(drop=True)
    df["node_id"] = np.arange(len(df))
    num_nodes = len(df)

    src_parts = []
    dst_parts = []

    # One vectorised step per (group, offset) instead of one Python append per
    # edge; the loop is short enough that no progress bar is needed.
    for _, group in df.groupby(["city", "registered_via"]):
        ids = group["node_id"].to_numpy(dtype=np.int64)
        if len(ids) <= 1:
            continue
        k_max = min(max_neighbors, len(ids) - 1)
        for k in range(1, k_max + 1):
            peers = np.roll(ids, -k)  # peers[i] == ids[(i + k) % len(ids)]
            # Add both directions so the graph is symmetric.
            src_parts += [ids, peers]
            dst_parts += [peers, ids]

    if src_parts:
        src = np.concatenate(src_parts)
        dst = np.concatenate(dst_parts)
        # Encode each (src, dst) pair as one integer so duplicates can be
        # dropped with a single 1-D unique.
        keys = np.unique(src * num_nodes + dst)
        src, dst = np.divmod(keys, num_nodes)
        edge_index = torch.from_numpy(np.stack([src, dst]))
    else:
        edge_index = torch.empty((2, 0), dtype=torch.long)

    print(f"Edges built: {edge_index.size(1):,}")
    return edge_index, df["node_id"].values

# -------------------------
# MODEL: GAT + MLP Hybrid
# -------------------------
class HybridGATMLP(nn.Module):
    """Two-branch binary classifier: a 2-layer GATv2 encoder over the graph
    plus an MLP over the tabular features, fused by concatenation.

    Args:
        in_dim_tabular: width of the tabular input; the MLP branch projects
            back to this width before fusion.
        gat_in_dim: width of the node features fed to the GAT branch.
        gat_hidden: per-head output width of each GAT layer.
        gat_heads: number of attention heads in the first GAT layer (the
            second layer uses a single head).
        gat_dropout: dropout rate used inside the GAT layers and by the
            dropout module shared by both branches.
        mlp_hidden: hidden width of the MLP branch.
    """

    def __init__(self, in_dim_tabular, gat_in_dim, gat_hidden=32, gat_heads=4,
                 gat_dropout=0.6, mlp_hidden=64):
        super().__init__()
        multi_head_width = gat_hidden * gat_heads  # heads are concatenated

        # Graph branch.
        self.gat1 = GATv2Conv(gat_in_dim, gat_hidden, heads=gat_heads, dropout=gat_dropout, concat=True)
        self.bn1 = BatchNorm(multi_head_width)
        self.gat2 = GATv2Conv(multi_head_width, gat_hidden, heads=1, dropout=gat_dropout, concat=True)
        self.bn2 = BatchNorm(gat_hidden)

        # Tabular branch: expand to mlp_hidden, then back to the input width.
        self.fc1 = nn.Linear(in_dim_tabular, mlp_hidden)
        self.bn_tab1 = nn.BatchNorm1d(mlp_hidden)
        self.fc2 = nn.Linear(mlp_hidden, in_dim_tabular)
        self.bn_tab2 = nn.BatchNorm1d(in_dim_tabular)

        # One dropout module serves both branches.
        self.dropout = nn.Dropout(gat_dropout)

        # Head over the concatenated [graph embedding, tabular embedding].
        fused_dim = gat_hidden + in_dim_tabular
        self.classifier = nn.Sequential(
            nn.Linear(fused_dim, fused_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(fused_dim, 1)
        )

    def _encode_graph(self, x_gat, edge_index):
        """Run both GAT layers, each followed by batch norm, ELU and dropout."""
        h = self.dropout(F.elu(self.bn1(self.gat1(x_gat, edge_index))))
        h = self.dropout(F.elu(self.bn2(self.gat2(h, edge_index))))
        return h

    def _encode_tabular(self, x_tab):
        """Run both linear layers, each followed by batch norm, ReLU and dropout."""
        h = self.dropout(F.relu(self.bn_tab1(self.fc1(x_tab))))
        h = self.dropout(F.relu(self.bn_tab2(self.fc2(h))))
        return h

    def forward(self, x_tab, x_gat, edge_index):
        """Return one raw logit per node (no sigmoid applied)."""
        graph_emb = self._encode_graph(x_gat, edge_index)
        tab_emb = self._encode_tabular(x_tab)
        fused = torch.cat([graph_emb, tab_emb], dim=1)
        return self.classifier(fused).squeeze(1)

# -------------------------
# TRAIN / EVAL
# -------------------------
def train_one_epoch(model, loader, optimizer, pos_weight, device):
    """Run one optimisation pass over *loader* and return the mean loss.

    A ``NeighborLoader`` batch holds the seed nodes first (``batch.batch_size``
    of them) followed by the neighbours sampled for message passing. The model
    runs on the whole subgraph, but only the seed nodes contribute to the
    loss: the sampled neighbours can belong to the validation or test split,
    and training on their labels would leak them.

    Batches without a ``batch_size`` attribute are treated as consisting
    entirely of target nodes.

    Args:
        model: module called as ``model(x_tab, x_gat, edge_index)`` that
            returns one logit per node.
        loader: iterable of batches exposing ``x_tab``, ``x_gat``,
            ``edge_index``, ``y``, ``num_nodes`` and ``to(device)``.
        optimizer: optimizer over the model's parameters.
        pos_weight: tensor weighting the positive class in the BCE loss.
        device: device each batch is moved to.

    Returns:
        Loss averaged over all target nodes seen, or 0.0 for an empty loader.
    """
    model.train()
    total_loss = 0.0
    total = 0
    for batch in loader:
        batch = batch.to(device)
        n_seed = getattr(batch, "batch_size", None)
        if n_seed is None:
            n_seed = batch.num_nodes

        optimizer.zero_grad()
        logits = model(batch.x_tab, batch.x_gat, batch.edge_index)[:n_seed]
        targets = batch.y[:n_seed].float()
        loss = F.binary_cross_entropy_with_logits(logits, targets, pos_weight=pos_weight)
        loss.backward()
        optimizer.step()

        # Weight by the number of target nodes so the epoch mean is per node.
        total_loss += float(loss.item()) * n_seed
        total += n_seed
    return total_loss / max(1, total)

@torch.no_grad()
def evaluate(model, loader, device):
    """Score *model* on the target nodes of *loader*.

    A ``NeighborLoader`` batch holds the seed nodes first (``batch.batch_size``
    of them) followed by sampled neighbours. Only the seed nodes are scored;
    the neighbours can belong to other splits and would also be counted once
    for every batch they are sampled into.

    Batches without a ``batch_size`` attribute are scored in full.

    Returns:
        ``(auc, f1, acc)``: ROC-AUC on the predicted probabilities, and F1 and
        accuracy at a 0.5 probability threshold.
    """
    model.eval()
    ys, ps = [], []
    for batch in loader:
        batch = batch.to(device)
        n_seed = getattr(batch, "batch_size", None)
        if n_seed is None:
            n_seed = batch.num_nodes

        logits = model(batch.x_tab, batch.x_gat, batch.edge_index)[:n_seed]
        ps.append(torch.sigmoid(logits).detach().cpu().numpy())
        ys.append(batch.y[:n_seed].cpu().numpy())
    y_true = np.concatenate(ys)
    y_prob = np.concatenate(ps)
    y_pred = (y_prob >= 0.5).astype(int)
    auc = roc_auc_score(y_true, y_prob)
    f1 = f1_score(y_true, y_pred)
    acc = accuracy_score(y_true, y_pred)
    return auc, f1, acc

# -------------------------
# BUILD GRAPH DATA OBJECT
# -------------------------
def build_pyg_data(df, feature_cols, edge_index):
    """Pack features, labels and edges into a torch-geometric ``Data`` object.

    Both model branches receive the same feature matrix; it is stored twice
    (``x_tab`` and ``x_gat``) as independent copies.

    Args:
        df: frame holding the columns in *feature_cols* plus ``is_churn``.
        feature_cols: ordered feature column names.
        edge_index: edge tensor stored on the object as-is.

    Returns:
        ``Data`` with ``x_tab``, ``x_gat``, ``y``, ``edge_index`` and
        ``num_nodes`` set.
    """
    features = torch.tensor(df[feature_cols].values, dtype=torch.float32)
    labels = torch.tensor(df["is_churn"].values, dtype=torch.long)

    graph = Data()
    fields = (
        ("x_tab", features.clone()),      # input of the MLP branch
        ("x_gat", features.clone()),      # input of the GAT branch
        ("y", labels),
        ("edge_index", edge_index),
        ("num_nodes", features.size(0)),
    )
    for attr, value in fields:
        setattr(graph, attr, value)
    return graph

# -------------------------
# MAIN
# -------------------------
def main():
    """Run the pipeline end to end: build features and graph, split, train
    with early stopping on validation AUC, then report test metrics.
    """
    # The fitted scaler is returned for callers that need it; unused here.
    df, feature_cols, _scaler = build_user_level_table(CFG["PATHS"])

    print(f"Users after merge: {len(df):,}")
    print("Constructing similarity edges ...")
    edge_index, _node_ids = build_sparse_similarity_edges(
        df[["city", "registered_via"]].copy(),
        max_neighbors=CFG["GRAPH_MAX_NEIGHBORS_PER_GROUP"]
    )

    # Build torch-geometric Data
    data = build_pyg_data(df, feature_cols, edge_index)

    # Train/Val/Test split (stratified on y). The second split is taken from
    # what remains after the test split, so its ratio is rescaled to keep
    # VAL_RATIO a fraction of the full data set.
    idx = np.arange(data.num_nodes)
    train_idx, test_idx = train_test_split(idx, test_size=CFG["TEST_RATIO"],
                                           stratify=df["is_churn"], random_state=CFG["RANDOM_SEED"])
    train_idx, val_idx = train_test_split(train_idx, test_size=CFG["VAL_RATIO"]/(1-CFG["TEST_RATIO"]),
                                          stratify=df["is_churn"].iloc[train_idx], random_state=CFG["RANDOM_SEED"])

    # masks
    data.train_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
    data.val_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
    data.test_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
    data.train_mask[torch.tensor(train_idx)] = True
    data.val_mask[torch.tensor(val_idx)] = True
    data.test_mask[torch.tensor(test_idx)] = True

    # Subgraph loaders (NeighborLoader keeps graph structure)
    def make_loader(mask, shuffle=False):
        return NeighborLoader(
            data, num_neighbors=[15, 10], batch_size=CFG["BATCH_SIZE"],
            input_nodes=mask, shuffle=shuffle
        )

    # Shuffle the training seeds each epoch; evaluation order is irrelevant.
    train_loader = make_loader(data.train_mask, shuffle=True)
    val_loader = make_loader(data.val_mask)
    test_loader = make_loader(data.test_mask)

    device = CFG["DEVICE"]
    print("Device:", device)

    model = HybridGATMLP(
        in_dim_tabular=len(feature_cols),
        gat_in_dim=len(feature_cols),
        gat_hidden=CFG["GAT_HIDDEN"],
        gat_heads=CFG["GAT_HEADS"],
        gat_dropout=CFG["GAT_DROPOUT"],
        mlp_hidden=CFG["MLP_HIDDEN"],
    ).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=CFG["LR"], weight_decay=CFG["WEIGHT_DECAY"])

    # class weights: w1=10 (churn), w0=1
    pos_weight = torch.tensor([10.0], device=device)

    # Early stopping
    best_auc = -1.0
    best_state = None
    patience = CFG["EARLY_STOP_PATIENCE"]
    bad = 0

    for epoch in range(1, CFG["MAX_EPOCHS"] + 1):
        tr_loss = train_one_epoch(model, train_loader, optimizer, pos_weight, device)
        val_auc, val_f1, val_acc = evaluate(model, val_loader, device)
        print(f"Epoch {epoch:02d} | loss {tr_loss:.4f} | val AUC {val_auc:.4f} | F1 {val_f1:.4f} | Acc {val_acc:.4f}")

        if val_auc > best_auc + 1e-4:
            best_auc = val_auc
            # clone() is required: on a CPU device .cpu() returns the live
            # tensor itself, so later optimizer steps would overwrite the
            # "best" snapshot in place.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            bad = 0
        else:
            bad += 1
            if bad >= patience:
                print("Early stopping.")
                break

    if best_state is not None:
        model.load_state_dict({k: v.to(device) for k, v in best_state.items()})

    # Final test
    test_auc, test_f1, test_acc = evaluate(model, test_loader, device)
    print(f"TEST | AUC {test_auc:.4f} | F1 {test_f1:.4f} | Acc {test_acc:.4f}")

# Run the full pipeline only when this code is executed directly, not when it
# is imported as a module.
if __name__ == "__main__":
    main()


In [None]:
# EVALUATION REPORTS
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

# -------------------------
# EXTENDED EVALUATION
# -------------------------
@torch.no_grad()
def full_evaluation(model, loader, device, title="Test Set"):
    """Print a classification report and show a confusion-matrix heatmap.

    A ``NeighborLoader`` batch holds the seed nodes first (``batch.batch_size``
    of them) followed by sampled neighbours. Only the seed nodes are scored;
    the neighbours can belong to other splits and would also be counted once
    for every batch they are sampled into. Batches without a ``batch_size``
    attribute are scored in full.

    Args:
        model: module called as ``model(x_tab, x_gat, edge_index)``.
        loader: iterable of batches to score.
        device: device each batch is moved to.
        title: label used in the printed header and the plot title.

    Returns:
        ``(acc, auc, cm)``: accuracy and confusion matrix at a 0.5 probability
        threshold, and ROC-AUC on the predicted probabilities.
    """
    model.eval()
    y_true, y_prob = [], []

    for batch in loader:
        batch = batch.to(device)
        n_seed = getattr(batch, "batch_size", None)
        if n_seed is None:
            n_seed = batch.num_nodes

        logits = model(batch.x_tab, batch.x_gat, batch.edge_index)[:n_seed]
        y_prob.append(torch.sigmoid(logits).cpu().numpy())
        y_true.append(batch.y[:n_seed].cpu().numpy())

    y_true = np.concatenate(y_true)
    y_prob = np.concatenate(y_prob)
    y_pred = (y_prob >= 0.5).astype(int)

    # Classification metrics
    print(f"\n=== {title} Metrics ===")
    print(classification_report(y_true, y_pred, digits=4))
    acc = accuracy_score(y_true, y_pred)
    auc = roc_auc_score(y_true, y_prob)
    print(f"Accuracy: {acc:.4f} | AUC: {auc:.4f}")

    # Confusion Matrix
    cm = confusion_matrix(y_true, y_pred)
    plt.figure(figsize=(5,4))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", cbar=False,
                xticklabels=["Not Churn", "Churn"],
                yticklabels=["Not Churn", "Churn"])
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title(f"{title} – Confusion Matrix")
    plt.show()

    return acc, auc, cm


# -------------------------
# LOSS & METRIC CURVES
# -------------------------
def plot_training_curves(train_losses, val_aucs, val_losses=None):
    """Plot per-epoch loss curves and, when given, the validation AUC curve.

    Each argument may be a list or a NumPy array with one value per epoch.
    Every series is plotted against its own 1-based epoch axis, so series of
    different lengths can be drawn together.

    Args:
        train_losses: training loss per epoch.
        val_aucs: validation AUC per epoch; the AUC figure is skipped when
            this is ``None`` or empty.
        val_losses: optional validation loss per epoch, drawn on the loss
            figure when provided and non-empty.
    """
    def has_values(series):
        # Explicit length check: truth-testing a NumPy array raises.
        return series is not None and len(series) > 0

    def epoch_axis(series):
        return np.arange(1, len(series) + 1)

    plt.figure(figsize=(6,4))
    plt.plot(epoch_axis(train_losses), train_losses, label="Train Loss")
    if has_values(val_losses):
        plt.plot(epoch_axis(val_losses), val_losses, label="Val Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Training & Validation Loss Curves")
    plt.legend()
    plt.show()

    if has_values(val_aucs):
        plt.figure(figsize=(6,4))
        plt.plot(epoch_axis(val_aucs), val_aucs, label="Val AUC", color="green")
        plt.xlabel("Epoch")
        plt.ylabel("AUC")
        plt.title("Validation AUC Curve")
        plt.legend()
        plt.show()
