# In [None]:  (Jupyter cell marker from notebook export)
import numpy as np
import networkx as nx
import pandas as pd
import math
import itertools
from itertools import combinations
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.datasets import Coauthor
from torch_geometric.utils import to_networkx, to_scipy_sparse_matrix
from torch_geometric.nn import GATConv, SGConv, GCNConv
from collections import Counter
import gc
import os
import random
import scipy.sparse as sp
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, cohen_kappa_score
from tqdm import tqdm

# --- Basic Setup ---
# Torch device used by the HOI stage and by every GNN except SGC (which the main
# block pins to the CPU): the default CUDA device when one is visible, else the CPU.
_DEVICE_NAME = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_DEVICE_NAME)

def set_seed(seed: int = 42):
    """Seed every random number generator this script touches.

    Seeds Python's ``random``, NumPy's global RNG and torch. When CUDA is
    available it also seeds all CUDA devices and switches cuDNN to
    deterministic kernels with autotuning disabled. ``PYTHONHASHSEED`` is
    exported for child processes only; the running interpreter's string
    hashing is fixed at startup and is not affected.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    for cuda_seeder in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        cuda_seeder(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def sample_per_class_fixed(labels, candidate_idx, num_classes, k_per_class):
    """Draw up to ``k_per_class`` random nodes of each class from ``candidate_idx``.

    Args:
        labels: 1-D tensor of per-node class ids.
        candidate_idx: 1-D long tensor of node ids to sample from.
        num_classes: classes ``0 .. num_classes - 1`` are visited in order.
        k_per_class: maximum number of nodes taken per class. A class with fewer
            candidates contributes all of them.

    Returns:
        1-D long tensor of the sampled node ids, grouped by class. It is empty
        when nothing was sampled. Randomness comes from torch's global RNG via
        ``torch.randperm``.
    """
    picked = []
    candidate_labels = labels[candidate_idx]
    for cls in range(num_classes):
        members = candidate_idx[candidate_labels == cls]
        take = min(k_per_class, members.numel())
        if take <= 0:
            continue
        shuffled = torch.randperm(members.numel())
        picked.append(members[shuffled[:take]])
    if not picked:
        return torch.tensor([], dtype=torch.long)
    return torch.cat(picked)

def calculate_all_metrics(pred, prob, labels, mask):
    """Score predictions on the nodes selected by ``mask``.

    Args:
        pred: per-node predicted class ids (torch tensor or NumPy array).
        prob: per-node class-probability matrix (torch tensor or NumPy array).
        labels: per-node ground-truth class ids (torch tensor).
        mask: boolean torch tensor selecting the evaluated nodes.

    Returns:
        dict with keys ``acc`` (accuracy), ``f1`` (macro F1), ``auc``
        (one-vs-rest ROC-AUC) and ``kappa`` (Cohen's kappa). If sklearn rejects
        the AUC inputs with a ``ValueError``, ``auc`` is reported as 0.5.
    """
    def _on_cpu(values):
        # Accept either torch tensors (any device) or NumPy arrays.
        return values.cpu() if torch.is_tensor(values) else torch.from_numpy(values).cpu()

    rows = mask.cpu().nonzero(as_tuple=True)[0]
    y_true = labels.cpu()[rows].numpy()
    y_pred = _on_cpu(pred)[rows].numpy()
    y_prob = _on_cpu(prob)[rows].numpy()

    try:
        auc = roc_auc_score(y_true, y_prob, multi_class='ovr')
    except ValueError:
        auc = 0.5

    return {
        'acc': accuracy_score(y_true, y_pred),
        'f1': f1_score(y_true, y_pred, average='macro', zero_division=0),
        'auc': auc,
        'kappa': cohen_kappa_score(y_true, y_pred),
    }

# ====================================================================
# 1. GNN MODEL DEFINITIONS (7 Models)
# ====================================================================

class GAT(nn.Module):
    """Two-layer graph attention network.

    Layer 1 has 8 heads of 8 hidden units each, concatenated. Layer 2 is a single
    averaged head (``concat=False``) with ``num_classes`` outputs. Dropout 0.6 is
    applied to the inputs, the hidden layer and the attention coefficients.
    """

    def __init__(self, num_features, num_classes):
        super().__init__()
        self.hid = 8
        self.in_head = 8
        self.out_head = 1
        self.conv1 = GATConv(num_features, self.hid, heads=self.in_head, dropout=0.6)
        self.conv2 = GATConv(
            self.hid * self.in_head,
            num_classes,
            concat=False,
            heads=self.out_head,
            dropout=0.6,
        )

    def forward(self, data):
        """Return per-node class log-probabilities for ``data.x`` on ``data.edge_index``."""
        edges = data.edge_index
        hidden = F.dropout(data.x, p=0.6, training=self.training)
        hidden = F.elu(self.conv1(hidden, edges))
        hidden = F.dropout(hidden, p=0.6, training=self.training)
        logits = self.conv2(hidden, edges)
        return F.log_softmax(logits, dim=1)

class SGC(nn.Module):
    """Simplified graph convolution: two propagation hops (K=2) followed by one linear map.

    Propagated features are not cached (``cached=False``).
    """

    def __init__(self, num_features, num_classes):
        super().__init__()
        self.conv = SGConv(num_features, num_classes, K=2, cached=False)

    def forward(self, data):
        """Return per-node class log-probabilities."""
        logits = self.conv(data.x, data.edge_index)
        return F.log_softmax(logits, dim=1)

class GCN(nn.Module):
    """Two-layer GCN with a 16-unit ReLU hidden layer (no dropout)."""

    def __init__(self, num_features, num_classes):
        super().__init__()
        self.conv1 = GCNConv(num_features, 16)
        self.conv2 = GCNConv(16, num_classes)

    def forward(self, data):
        """Return per-node class log-probabilities."""
        edges = data.edge_index
        hidden = F.relu(self.conv1(data.x, edges))
        logits = self.conv2(hidden, edges)
        return F.log_softmax(logits, dim=1)

class PlanetoidGCN(nn.Module):
    """Two-layer GCN (16 hidden units) that also owns a free node-embedding table.

    ``embedding`` is a ``(num_nodes, 16)`` parameter initialised from a standard
    normal and exposed through ``get_embedding``.

    NOTE(review): ``embedding`` is never used in ``forward``; predictions depend
    only on ``conv1``/``conv2``. A loss computed on ``get_embedding()`` therefore
    updates only the table and cannot change the classifier output. Confirm
    whether it was meant to be tied to the hidden layer, as in Planetoid. If it
    is, the context sampler must stop using evaluation-node labels first.
    """

    def __init__(self, num_features, num_classes, num_nodes):
        super().__init__()
        self.conv1 = GCNConv(num_features, 16)
        self.conv2 = GCNConv(16, num_classes)
        self.embedding = nn.Parameter(torch.randn(num_nodes, 16))

    def forward(self, data):
        """Return per-node class log-probabilities from the two GCN layers."""
        edges = data.edge_index
        hidden = F.relu(self.conv1(data.x, edges))
        return F.log_softmax(self.conv2(hidden, edges), dim=1)

    def get_embedding(self):
        """Return the free ``(num_nodes, 16)`` embedding parameter."""
        return self.embedding

class Delta(nn.Module):
    """Gated difference block applied independently to each of K inputs.

    For input ``H_k``, with a shared map ``f`` and a per-input map ``f_k`` (both
    ``Linear(d_in, d_red)``), it computes::

        beta_k = sigmoid(Linear(f(H_k) * f_k(H_k)))
        out_k  = act(beta_k * (f(H_k) - f_k(H_k)))

    ``set_fk_list(K)`` must be called before ``forward``.

    Args:
        d_in: input feature width.
        d_red: output (reduced) width.
        act: activation module. ``None`` (default) creates a fresh ``nn.ReLU``
            for this instance.
    """

    def __init__(self, d_in, d_red=6, act=None):
        super().__init__()
        self.f = nn.Linear(d_in, d_red, bias=True)
        self.fk_list = None  # populated by set_fk_list(K)
        self.gate = nn.Sequential(nn.Linear(d_red, 1), nn.Sigmoid())
        # The old default ``act=nn.ReLU()`` was evaluated once at definition time,
        # so every Delta shared (and registered) the same module instance.
        self.act = nn.ReLU() if act is None else act

    def set_fk_list(self, K):
        """Create ``K`` per-input linear maps with the same shape as ``self.f``."""
        self.fk_list = nn.ModuleList(
            [nn.Linear(self.f.in_features, self.f.out_features, bias=True) for _ in range(K)]
        )

    def forward(self, H_list):
        """Apply the gated difference to each tensor in ``H_list``; returns a list.

        Raises:
            RuntimeError: if ``set_fk_list`` has not been called.
            IndexError: if ``H_list`` has more entries than the ``K`` configured maps.
        """
        if self.fk_list is None:
            raise RuntimeError("Delta.set_fk_list(K) must be called before forward().")
        out = []
        for k, h in enumerate(H_list):
            f_h = self.f(h)
            fk_h = self.fk_list[k](h)
            beta = self.gate(f_h * fk_h)
            out.append(self.act(beta * (f_h - fk_h)))
        return out

class MGNNLayer(nn.Module):
    """One motif-aware propagation layer.

    Features are projected by ``lin_z`` and propagated once with ``Atil``. The
    result is then propagated separately with each of the K matrices in
    ``Atil_list``. Each branch goes through a ``Delta`` block (width ``d_red``),
    and the branches are concatenated column-wise, giving ``K * d_red`` columns.
    """

    def __init__(self, d_in, d_hidden, K, d_red=6):
        super().__init__()
        self.lin_z = nn.Linear(d_in, d_hidden, bias=False)
        self.delta = Delta(d_hidden, d_red=d_red)
        self.delta.set_fk_list(K)

    def agg_messages(self, Atil_k_list, Z):
        """Return ``[A @ Z for A in Atil_k_list]`` computed with ``torch.spmm``."""
        messages = []
        for motif_adj in Atil_k_list:
            messages.append(torch.spmm(motif_adj, Z))
        return messages

    def forward(self, X, Atil, Atil_list):
        projected = self.lin_z(X)
        Z = torch.spmm(Atil, projected)
        branch_inputs = self.agg_messages(Atil_list, Z)
        branch_outputs = self.delta(branch_inputs)
        return torch.cat(branch_outputs, dim=1)

class MGNN(nn.Module):
    """Motif GNN classifier: one ``MGNNLayer`` followed by a linear read-out.

    Expects ``data.adj_norm`` (propagation matrix) and ``data.motif_adjs`` (a list
    of K motif matrices) in addition to ``data.x``.
    """

    def __init__(self, d_in, d_hidden, num_cls, K=3, d_red=8):
        super().__init__()
        self.layer = MGNNLayer(d_in, d_hidden, K=K, d_red=d_red)
        # MGNNLayer concatenates K branches of width d_red.
        self.out = nn.Linear(K * d_red, num_cls)

    def forward(self, data):
        """Return per-node class log-probabilities."""
        hidden = self.layer(data.x, data.adj_norm, data.motif_adjs)
        logits = self.out(hidden)
        return F.log_softmax(logits, dim=1)

class SDMG(nn.Module):
    """Two stacked multi-head GAT layers producing ``out_dim``-wide node embeddings.

    Each layer uses ``heads`` heads of width ``out_dim // heads``. GATConv
    concatenates heads by default, so both layers emit ``out_dim`` columns when
    ``out_dim`` is divisible by ``heads``. The output is raw (no softmax); it is
    used as features for a downstream classifier.
    """

    def __init__(self, feat_dim, out_dim=256, heads=2):
        super().__init__()
        head_width = out_dim // heads
        self.gat1 = GATConv(feat_dim, head_width, heads=heads, dropout=0.2)
        self.gat2 = GATConv(out_dim, head_width, heads=heads, dropout=0.2)

    def forward(self, data):
        edges = data.edge_index
        hidden = F.elu(self.gat1(data.x, edges))
        return self.gat2(hidden, edges)

class ActOp(nn.Module):
    """Learnable soft mixture of ReLU, tanh and identity activations.

    ``alpha`` holds three logits, initialised to zero (an equal 1/3 mix). The
    output is ``sum_i softmax(alpha)_i * act_i(x)``, applied elementwise.
    """

    def __init__(self):
        super().__init__()
        self.alpha = nn.Parameter(torch.zeros(3))

    def forward(self, x):
        candidates = (F.relu(x), torch.tanh(x), x)
        stacked = torch.stack(candidates, dim=-1)  # trailing axis indexes the activation
        mix = F.softmax(self.alpha, dim=0)
        return torch.einsum("...d,d->...", stacked, mix)

class LayerAggregator(nn.Module):
    """Softmax-weighted sum of per-layer node representations.

    ``alpha`` has one logit per layer, initialised to zero (a uniform average).
    ``forward`` takes a list of equally shaped 2-D tensors and returns their
    weighted sum.
    """

    def __init__(self, num_layers: int):
        super().__init__()
        self.alpha = nn.Parameter(torch.zeros(num_layers))

    def forward(self, hs: list):
        layer_weights = F.softmax(self.alpha, dim=0)
        stacked = torch.stack(hs, dim=0)  # (num_layers, n, d)
        return torch.einsum("l,lnd->nd", layer_weights, stacked)

class WholeGraphEncoder(nn.Module):
    """Stack of ``num_layers`` GCN layers, each followed by a learnable ``ActOp``.

    Dropout with probability ``dropout`` is applied to every layer's input.
    ``forward`` returns the list of all layer outputs (one per layer), not only
    the last one.
    """

    def __init__(self, in_dim, hidden_dim, num_layers, dropout=0.5):
        super().__init__()
        widths = [in_dim] + [hidden_dim] * num_layers
        self.layers = nn.ModuleList(
            [GCNConv(w_in, w_out) for w_in, w_out in zip(widths[:-1], widths[1:])]
        )
        self.acts = nn.ModuleList([ActOp() for _ in range(num_layers)])
        self.dropout = dropout

    def forward(self, x, edge_index):
        layer_outputs = []
        h = x
        for conv, activation in zip(self.layers, self.acts):
            h = F.dropout(h, p=self.dropout, training=self.training)
            h = activation(conv(h, edge_index))
            layer_outputs.append(h)
        return layer_outputs

class CSSE(nn.Module):
    """GCN encoder with searchable activations and layer aggregation, plus a linear classifier.

    "Architecture" parameters are the ``ActOp`` mixing logits and the
    ``LayerAggregator`` weights. Every other parameter counts as a network
    weight. The two groups are exposed separately so they can use separate
    optimizers.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, num_layers=3, dropout=0.5):
        super().__init__()
        self.encoder = WholeGraphEncoder(in_dim, hidden_dim, num_layers, dropout)
        self.aggregator = LayerAggregator(num_layers)
        self.classifier = nn.Linear(hidden_dim, out_dim)

    def forward(self, data):
        """Return per-node class log-probabilities."""
        layer_outputs = self.encoder(data.x, data.edge_index)
        pooled = self.aggregator(layer_outputs)
        return F.log_softmax(self.classifier(pooled), dim=1)

    def get_arch_params(self):
        """Return the activation-mix and layer-aggregation parameters, in that order."""
        arch = []
        for module in (self.encoder.acts, self.aggregator):
            arch.extend(module.parameters())
        return arch

    def get_net_params(self):
        """Return every parameter not returned by ``get_arch_params``, in registration order."""
        arch_ids = {id(p) for p in self.get_arch_params()}
        net = []
        for param in self.parameters():
            if id(param) not in arch_ids:
                net.append(param)
        return net

# ====================================================================
# 2. GNN TRAINING PIPELINES
# ====================================================================

def standard_gnn_training(model, data, device, lr, wd, epochs):
    """Full-batch supervised training; return class probabilities for every node.

    Runs ``epochs`` Adam steps (learning rate ``lr``, weight decay ``wd``) on the
    NLL loss of the training nodes. ``model(data)`` must return log-probabilities.
    The result is ``exp`` of the eval-mode output, computed without gradients.
    There is no validation or early stopping.
    """
    model = model.to(device)
    data = data.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    train_mask = data.train_mask
    for _epoch in range(epochs):
        model.train()
        optimizer.zero_grad()
        log_probs = model(data)
        loss = F.nll_loss(log_probs[train_mask], data.y[train_mask])
        loss.backward()
        optimizer.step()
    model.eval()
    with torch.no_grad():
        final_log_probs = model(data)
        return torch.exp(final_log_probs)

def sample_context(graph, labels, num_samples=100, known_mask=None):
    """Sample (node, context) pairs and 0/1 similarity targets for a Planetoid-style loss.

    Each of ``num_samples`` draws picks a random anchor node and a random other
    node within two hops of it (anchors with no such node are skipped).

    Args:
        graph: networkx graph whose node ids index ``labels`` / ``known_mask``.
        labels: per-node class labels, indexable by node id.
        num_samples: number of anchor draws.
        known_mask: optional per-node boolean mask of nodes whose labels may be used.

            * ``None`` (legacy): the target for each pair is label equality. This
              reads the labels of *all* nodes, including validation/test nodes, so
              do not use it for training.
            * given: labels are compared only when both endpoints are known. Any
              other neighbour pair becomes a graph-context positive (gamma=1) and
              is paired with a random negative node (gamma=0). When at least two
              known nodes exist, each draw also adds one label-context pair of two
              random known nodes (gamma = same label).

    Returns:
        ``(context_pairs, gamma_values)``: a list of ``(i, j)`` node pairs and a
        list of 0/1 targets of the same length.
    """
    context_pairs, gamma_values = [], []
    nodes = list(graph.nodes)
    if not nodes:
        return [], []

    known = None
    known_nodes = []
    if known_mask is not None:
        known = np.asarray(torch.as_tensor(known_mask).cpu()).astype(bool)
        known_nodes = [n for n in nodes if known[n]]

    def _same_label(a, b):
        return 1 if (labels[a] == labels[b]) else 0

    for _ in range(num_samples):
        node = np.random.choice(nodes)
        reachable = nx.single_source_shortest_path_length(graph, node, cutoff=2)
        neighbors = [n for n in reachable if n != node]
        if neighbors:
            neighbor = np.random.choice(neighbors)
            if known is None:
                context_pairs.append((node, neighbor))
                gamma_values.append(_same_label(node, neighbor))
            elif known[node] and known[neighbor]:
                context_pairs.append((node, neighbor))
                gamma_values.append(_same_label(node, neighbor))
            else:
                # Graph context: nearby nodes are positives, a random node is a negative.
                context_pairs.append((node, neighbor))
                gamma_values.append(1)
                negative = np.random.choice(nodes)
                if negative != node:
                    context_pairs.append((node, negative))
                    gamma_values.append(0)
        if len(known_nodes) > 1:
            a, b = np.random.choice(known_nodes, size=2, replace=False)
            context_pairs.append((a, b))
            gamma_values.append(_same_label(a, b))
    return context_pairs, gamma_values

def context_loss(embedding, context_pairs, gamma_values):
    """Mean negative log-likelihood of the context targets under dot-product similarity.

    For each pair ``(i, j)`` with target ``gamma`` the log-likelihood is
    ``gamma * logsigmoid(e_i . e_j) + (1 - gamma) * logsigmoid(-e_i . e_j)``.
    Pairs whose indices are not below ``embedding.shape[0]`` are skipped but
    still count in the denominator ``len(context_pairs)``.

    Returns:
        0-d tensor on ``embedding.device``. It is ``0.0`` when there are no pairs.
    """
    device = embedding.device
    if not context_pairs:
        return torch.tensor(0.0, device=device)
    paired = list(zip(context_pairs, gamma_values))  # zip truncates like the old loop
    pairs = torch.as_tensor(
        [(int(i), int(j)) for (i, j), _ in paired], dtype=torch.long, device=device
    ).reshape(-1, 2)
    gammas = torch.as_tensor(
        [float(g) for _, g in paired], dtype=torch.float32, device=device
    )
    valid = (pairs < embedding.shape[0]).all(dim=1)
    pairs, gammas = pairs[valid], gammas[valid]
    scores = (embedding[pairs[:, 0]] * embedding[pairs[:, 1]]).sum(dim=1)
    log_lik = gammas * F.logsigmoid(scores) + (1 - gammas) * F.logsigmoid(-scores)
    return -log_lik.sum() / len(context_pairs)

def planetoid_gcn_training(model, data, G, device, lr, wd, epochs):
    """Train with supervised NLL plus 0.5 x a context loss on ``model.get_embedding()``.

    Context targets are sampled with ``known_mask=data.train_mask``, so only
    training-node labels are used. Earlier versions compared labels of all
    nodes, which leaks validation/test labels into training.

    Returns:
        Class probabilities for every node from the final model in eval mode.
    """
    model, data = model.to(device), data.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    labels_cpu = data.y.cpu()
    known_cpu = data.train_mask.cpu()
    for _ in range(epochs):
        model.train()
        optimizer.zero_grad()
        out = model(data)
        loss_s = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
        context_pairs, gamma_values = sample_context(G, labels_cpu, known_mask=known_cpu)
        loss_u = context_loss(model.get_embedding(), context_pairs, gamma_values)
        loss = loss_s + 0.5 * loss_u
        loss.backward()
        optimizer.step()
    model.eval()
    with torch.no_grad():
        return torch.exp(model(data))

def csse_training(model, data, device, lr, wd, alr, epochs):
    """Alternate weight and architecture updates for CSSE; return class probabilities.

    Each epoch takes one Adam step on the network weights (``lr``, weight decay
    ``wd``) using the training-node NLL. It then takes one Adam step on the
    architecture parameters (``alr``, no weight decay) using the
    validation-node NLL. Probabilities come from the final model in eval mode.
    """
    model = model.to(device)
    data = data.to(device)
    weight_opt = optim.Adam(model.get_net_params(), lr=lr, weight_decay=wd)
    arch_opt = optim.Adam(model.get_arch_params(), lr=alr, weight_decay=0.0)

    def _masked_nll(mask):
        return F.nll_loss(model(data)[mask], data.y[mask])

    for _ in range(epochs):
        model.train()
        for opt, mask in ((weight_opt, data.train_mask), (arch_opt, data.val_mask)):
            opt.zero_grad()
            _masked_nll(mask).backward()
            opt.step()
    model.eval()
    with torch.no_grad():
        return torch.exp(model(data))

def sdmg_training_and_inference(model, data, device, lr=0.005, wd=5e-4, epochs=200):
    """Pre-train SDMG embeddings, then fit a logistic-regression read-out on training nodes.

    Args:
        model: module returning raw per-node outputs. They are used as logits for
            cross-entropy during pre-training and as features afterwards.
        data: graph with ``x``, ``edge_index``, ``y`` and ``train_mask``.
        device: device for the torch part.
        lr, wd, epochs: Adam settings for the pre-training. The defaults match
            the previously hard-coded values.

    Returns:
        float32 CPU tensor of shape ``(num_nodes, max(label) + 1)``. Column ``c``
        holds the probability of label ``c``. A label absent from the training
        nodes gets probability 0 instead of shifting the other columns.
    """
    model, data = model.to(device), data.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    for _ in range(epochs):  # simplified supervised pre-training
        model.train()
        optimizer.zero_grad()
        out = model(data)
        loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()
    model.eval()
    with torch.no_grad():
        features = model(data).detach().cpu().numpy()

    train_np = data.train_mask.cpu().numpy()
    y_np = data.y.cpu().numpy()
    # `multi_class` is deprecated in recent scikit-learn; its old 'auto' value is
    # the default behaviour, so it is simply omitted.
    clf = LogisticRegression(solver='lbfgs', random_state=42, max_iter=200)
    clf.fit(features[train_np], y_np[train_np])
    class_probs = clf.predict_proba(features)

    # predict_proba columns follow clf.classes_, not label ids; scatter them so
    # column c always means label c.
    num_classes = int(y_np.max()) + 1
    probs = np.zeros((features.shape[0], num_classes), dtype=class_probs.dtype)
    probs[:, clf.classes_.astype(int)] = class_probs
    return torch.from_numpy(probs).float()

# ====================================================================
# 3. SIMPLEX AND HOI HELPER FUNCTIONS
# ====================================================================

# modified_function.py
def group_by_size(cliques):
    """Bucket cliques by vertex count.

    Returns a list whose entry ``s - 1`` holds the cliques with ``s`` vertices, up
    to the largest size present (intermediate sizes may be empty lists). Empty
    cliques are dropped, and an empty input gives ``[]``.
    """
    if not cliques:
        return []
    largest = max(len(c) for c in cliques)
    buckets = [[] for _ in range(largest)]
    for clique in cliques:
        size = len(clique)
        if size:
            buckets[size - 1].append(clique)
    return buckets

def _calculate_variance(participation_counter, num_nodes):
    if not participation_counter: return 0.0
    freqs = np.array(list(participation_counter.values()))
    all_freqs = np.concatenate([freqs, np.zeros(num_nodes - len(freqs))])
    return np.var(all_freqs)

def generate_all_simplex_types_greedy(G: nx.Graph, budget: int = 10000, max_subset_size=None):
    """Build three clique (simplex) families from ``G``, each grouped with ``group_by_size``.

    Returns ``(general, maximal, balanced)``:

    * general: every sub-clique of every maximal clique, optionally capped at
      ``max_subset_size`` vertices.
    * maximal: the maximal cliques from ``nx.find_cliques`` (each sorted).
    * balanced: the maximal cliques plus up to ``budget`` extra sub-cliques chosen
      greedily. Each step adds the candidate that most lowers the variance of the
      per-node participation counts, where nodes in no clique count as 0. Ties
      keep the first candidate in set-iteration order, and the loop stops early
      when no candidate lowers the variance.

    Args:
        G: undirected networkx graph.
        budget: maximum number of greedy additions.
        max_subset_size: optional cap on sub-clique size for the general family
            and the greedy candidates. ``None`` (default) keeps every size. A
            maximal clique with ``m`` vertices then yields ``2**m - 1``
            sub-cliques.
    """
    print("1/4: Finding Maximal Cliques...")
    maximal_cliques_list = [sorted(c) for c in nx.find_cliques(G)]
    print("2/4: Generating All Cliques...")
    all_cliques_set = set()
    for clique in maximal_cliques_list:
        top = len(clique) if max_subset_size is None else min(len(clique), max_subset_size)
        for k in range(1, top + 1):
            # `clique` is sorted and combinations() preserves order, so every
            # subset tuple is already sorted.
            all_cliques_set.update(combinations(clique, k))
    all_cliques_list = [list(c) for c in all_cliques_set]

    print(f"3/4: Building Balanced Simplices (Greedy, Budget={budget})...")
    num_nodes = G.number_of_nodes()
    aug_max_set = {tuple(c) for c in maximal_cliques_list}
    node_part = Counter(n for c in aug_max_set for n in c)
    # Running moments of the zero-padded participation counts: var = E[x^2] - E[x]^2.
    # This scores each candidate in O(|candidate|) instead of copying the Counter
    # and rebuilding a num_nodes-long array per candidate. Results agree with
    # np.var up to floating-point rounding.
    count_sum = sum(node_part.values())
    count_sq_sum = sum(v * v for v in node_part.values())

    def _variance(total, total_sq):
        if num_nodes == 0:
            return 0.0
        mean = total / num_nodes
        return total_sq / num_nodes - mean * mean

    current_var = _variance(count_sum, count_sq_sum)
    candidates = all_cliques_set - aug_max_set
    for _ in range(budget):
        if not candidates:
            break
        best_clique, min_var = None, current_var
        for cand in candidates:
            # Incrementing a count c raises c**2 by 2c + 1 (clique vertices are distinct).
            delta_sq = sum(2 * node_part[n] + 1 for n in cand)
            pot_var = _variance(count_sum + len(cand), count_sq_sum + delta_sq)
            if pot_var < min_var:
                min_var, best_clique = pot_var, cand
        if best_clique is None:
            break
        aug_max_set.add(best_clique)
        candidates.remove(best_clique)
        count_sq_sum += sum(2 * node_part[n] + 1 for n in best_clique)
        count_sum += len(best_clique)
        node_part.update(best_clique)
        current_var = min_var

    balanced_list = [list(c) for c in aug_max_set]
    print("4/4: Formatting final results...")
    return group_by_size(all_cliques_list), group_by_size(maximal_cliques_list), group_by_size(balanced_list)

def get_x_known(data, train_mask, val_mask, num_classes):
    """Return the labeled (train or validation) nodes as rows ``[node_id, one-hot label...]``.

    The result is an int64 NumPy array of shape ``(n_known, 1 + num_classes)``,
    with rows in ascending node-id order.
    """
    known_ids = (train_mask | val_mask).nonzero(as_tuple=True)[0].cpu().numpy()
    known_labels = data.y[known_ids].cpu().numpy()
    one_hot = np.eye(num_classes, dtype=np.int64)[known_labels]
    return np.column_stack([known_ids, one_hot])

def precompute_hoi_coefficients(n_max, n_classes, device):
    """Multinomial coefficients for every labeling of cliques with 1 .. n_max + 1 vertices.

    Entry ``k`` (for ``k >= 1``) is a float32 tensor with ``n_classes ** k``
    values. Value ``t`` belongs to the ``t``-th label assignment
    ``(c_1, ..., c_k)`` in lexicographic order (``c_1`` most significant,
    matching the flattening of ``prob_product``) and equals
    ``k! / prod_c (count of c)!``. Entry 0 is an empty tensor.
    """
    print("Pre-computing HOI coefficients...")
    factorials = torch.tensor(
        [math.factorial(i) for i in range(n_max + 2)], dtype=torch.float32, device=device
    )
    coefficients = [torch.empty(0, device=device)]
    for order in range(1, n_max + 2):
        values = []
        for assignment in itertools.product(range(n_classes), repeat=order):
            label_counts = torch.bincount(torch.tensor(assignment), minlength=n_classes)
            values.append(factorials[order] / torch.prod(factorials[label_counts]))
        coefficients.append(torch.tensor(values, device=device))
    return coefficients

def prob_product(vectors):
    """Flattened outer product ``v_1 (x) v_2 (x) ... (x) v_k`` of 1-D tensors.

    Entry ``(i_1, ..., i_k)`` is stored at the row-major flat index, so
    ``v_1`` is the most significant axis. A single vector is returned
    unchanged.
    """
    result = vectors[0]
    for vec in vectors[1:]:
        # torch.ger is deprecated; torch.outer is its documented replacement.
        result = torch.outer(result, vec).flatten()
    return result

def generalized_outer_product(P, index_lists):
    """Stack the flattened outer products of ``P``'s rows for each index list.

    Row ``r`` of the result is ``P[i_1] (x) ... (x) P[i_s]``, flattened
    row-major, for ``index_lists[r] = [i_1, ..., i_s]``. The result has shape
    ``(len(index_lists), C ** s)`` where ``C = P.shape[1]``. All products are
    computed in one batched pass instead of a Python loop per clique.

    Raises:
        ValueError: if the index lists do not all have the same positive length.
            Products of different lengths cannot be stacked.
    """
    if not index_lists:
        return torch.empty(0, device=P.device)
    sizes = {len(ix) for ix in index_lists}
    if len(sizes) != 1 or 0 in sizes:
        raise ValueError(f"index lists must share one positive length, got sizes {sorted(sizes)}")
    idx = torch.as_tensor(
        [[int(i) for i in ix] for ix in index_lists], dtype=torch.long, device=P.device
    )
    rows = P[idx]  # (num_lists, s, C)
    result = rows[:, 0]
    for t in range(1, rows.shape[1]):
        # Batched outer product; same element order as torch.outer(...).flatten().
        result = (result.unsqueeze(2) * rows[:, t].unsqueeze(1)).reshape(result.shape[0], -1)
    return result

# ====================================================================
# 4. HOI MODEL AND TRAINING (WITH WEIGHT STRATEGY)
# ====================================================================

def objective_efficient(P, simplices, device, precomputed_coef, weight_strategy='constant'):
    """Weighted HOI objective over cliques with 2 .. 5 vertices.

    For each size group ``simplices[i]`` (cliques with ``i + 1`` vertices,
    ``1 <= i < min(len(simplices), 5)``), it adds
    ``clique_weight[i] * sum(coef[i + 1] * outer-products of P rows)``. A group is
    skipped when it is empty, has no coefficients, or its coefficient count does
    not match the product width.

    Args:
        P: ``(num_nodes, num_classes)`` probability matrix.
        simplices: output of ``group_by_size``.
        device: device for the weight vector and the accumulator.
        precomputed_coef: output of ``precompute_hoi_coefficients``.
        weight_strategy: ``'linear'`` (weights 1..n), ``'exponential'``
            (``exp(0..n-1)``); anything else is treated as constant 1.

    Returns:
        0-d tensor. It is a constant zero (no gradient) when no group contributes,
        instead of the plain Python float returned previously.
    """
    K_MAX, n_max = 5, len(simplices)

    if weight_strategy == 'linear':
        clique_weight = torch.arange(1, n_max + 1, device=device).float()
    elif weight_strategy == 'exponential':
        clique_weight = torch.exp(torch.arange(n_max, device=device).float())
    else:  # constant
        clique_weight = torch.ones(n_max, device=device)

    total_obj = torch.zeros((), device=device, dtype=P.dtype)
    for i in range(1, min(n_max, K_MAX)):
        clique_size = i + 1
        if not simplices[i] or clique_size >= len(precomputed_coef) or precomputed_coef[clique_size].numel() == 0:
            continue
        prob_prod = generalized_outer_product(P, simplices[i])
        coef_slice = precomputed_coef[clique_size]
        if coef_slice.shape[0] == prob_prod.shape[1]:
            term_sum = (coef_slice * prob_prod).sum()
            total_obj = total_obj + clique_weight[i] * term_sum
    return total_obj

class HOIModel(nn.Module):
    """Per-node logits with clamped rows for labeled nodes and trainable rows for the rest.

    Args:
        device: device on which the logit matrix is assembled in ``forward``.
        initial_data: ``(n_V, n_L)`` array used as the starting logits. Rows of
            unlabeled nodes become ``trainable_params``.
        x_known: array of rows ``[node_id, label vector...]``, presumably from
            ``get_x_known`` (one-hot labels). These rows are never trained.
        precomputed_coef: coefficients forwarded to ``objective_efficient``.

    NOTE(review): the caller passes raw GNN *probabilities* as ``initial_data``,
    but they are used here as *logits* (softmaxed in ``forward``). Likewise the
    one-hot rows of ``x_known`` are softmaxed, so a known node's true class gets
    probability ``e / (e + n_L - 1)`` rather than 1. Confirm this smoothing is
    intended.
    """

    def __init__(self, device, initial_data, x_known, precomputed_coef):
        super(HOIModel, self).__init__()
        self.device = device
        self.precomputed_coef = precomputed_coef
        start_logits = torch.tensor(initial_data, dtype=torch.float32)
        self.n_V, self.n_L = start_logits.shape
        self.fixed_indices = torch.from_numpy(x_known[:, 0].astype(int))
        self.fixed_params = torch.tensor(x_known[:, 1:], dtype=torch.float32)
        is_free = torch.ones(self.n_V, dtype=torch.bool)
        is_free[self.fixed_indices] = False
        self.trainable_indices = torch.arange(self.n_V)[is_free]
        self.trainable_params = nn.Parameter(start_logits[is_free])

    def _assemble_logits(self):
        """Scatter the clamped and trainable rows into a full ``(n_V, n_L)`` logit matrix."""
        logits = torch.zeros((self.n_V, self.n_L), device=self.device)
        logits[self.fixed_indices] = self.fixed_params.to(self.device)
        logits[self.trainable_indices] = self.trainable_params
        return logits

    def forward(self, simplices, weight_strategy):
        """Return the HOI objective for the current softmax probabilities."""
        probabilities = F.softmax(self._assemble_logits(), dim=1)
        return objective_efficient(
            probabilities, simplices, self.device, self.precomputed_coef, weight_strategy
        )

def HOI_training(epochs, device, simplices, initial_data, x_known, lr, precomputed_coef, weight_strategy):
    """Minimise the HOI objective over the unlabeled nodes' logits and return probabilities.

    Runs up to ``epochs`` Adam steps (learning rate ``lr``) on
    ``HOIModel.trainable_params``. If the objective is constant for the given
    simplices (not a tensor that requires grad, e.g. no clique group
    contributes), optimisation stops instead of crashing in ``backward``.

    Returns:
        ``(n_V, n_L)`` softmax probabilities on ``device``. The labeled rows use
        the clamped ``x_known`` values.
    """
    model = HOIModel(device, initial_data, x_known, precomputed_coef).to(device)
    optimizer = optim.Adam([model.trainable_params], lr=lr)
    for _ in range(epochs):
        optimizer.zero_grad()
        loss = model(simplices, weight_strategy)
        if not (torch.is_tensor(loss) and loss.requires_grad):
            break  # nothing to optimise; further steps would be identical no-ops
        loss.backward()
        optimizer.step()

    with torch.no_grad():
        full_data = torch.zeros((model.n_V, model.n_L), device=model.device)
        full_data[model.fixed_indices] = model.fixed_params.to(model.device)
        full_data[model.trainable_indices] = model.trainable_params
        probs = F.softmax(full_data, dim=1)
    return probs

# ====================================================================
# 5. MAIN EXPERIMENT EXECUTION BLOCK
# ====================================================================
def _sym_normalize(adj):
    """Return ``D^-1/2 . adj . D^-1/2`` for a SciPy sparse matrix (zero-degree rows stay zero)."""
    deg = np.asarray(adj.sum(axis=1)).ravel().astype(np.float64)
    inv_sqrt = np.zeros_like(deg)
    inv_sqrt[deg > 0] = deg[deg > 0] ** -0.5
    scale = sp.diags(inv_sqrt)
    return (scale @ adj @ scale).tocsr()


def _row_normalize(adj):
    """Return ``D^-1 . adj`` so every non-empty row sums to 1."""
    deg = np.asarray(adj.sum(axis=1)).ravel().astype(np.float64)
    inv = np.zeros_like(deg)
    inv[deg > 0] = 1.0 / deg[deg > 0]
    return (sp.diags(inv) @ adj).tocsr()


def _to_torch_sparse(mat, device):
    """Convert a SciPy sparse matrix into a coalesced float32 torch COO tensor on ``device``."""
    coo = mat.tocoo()
    indices = torch.from_numpy(np.vstack([coo.row, coo.col]).astype(np.int64))
    values = torch.from_numpy(coo.data.astype(np.float32))
    return torch.sparse_coo_tensor(indices, values, coo.shape).coalesce().to(device)


def _build_mgnn_adjacencies(edge_index, num_nodes, device):
    """Sparse propagation matrices for MGNN: ``(adj_norm, [A^2, A^3, A^4])``.

    ``adj_norm`` is the symmetrically normalised ``A + I``. The motif matrices
    are row-normalised path-count matrices. All are torch sparse tensors, so
    ``torch.spmm`` never materialises an ``N x N`` dense matrix.
    NOTE(review): ``A^4`` can still have very many non-zeros on large graphs.
    """
    adj = to_scipy_sparse_matrix(edge_index.cpu(), num_nodes=num_nodes).tocsr()
    adj_norm = _sym_normalize(adj + sp.eye(num_nodes, format='csr'))
    a2 = adj @ adj
    a3 = a2 @ adj
    a4 = a2 @ a2
    motif_adjs = [_to_torch_sparse(_row_normalize(m), device) for m in (a2, a3, a4)]
    return _to_torch_sparse(adj_norm, device), motif_adjs


if __name__ == '__main__':
    # --- Basic Setup ---
    SEEDS = list(range(10))
    GNN_CONFIG = {
        'GAT': {'class': GAT, 'lr': 0.005, 'wd': 5e-4, 'epochs': 1000},
        'GCN': {'class': GCN, 'lr': 0.01, 'wd': 5e-4, 'epochs': 400},
        'SGC': {'class': SGC, 'lr': 0.1, 'wd': 5e-5, 'epochs': 200},
        'PlanetoidGCN': {'class': PlanetoidGCN, 'lr': 0.01, 'wd': 5e-4, 'epochs': 400},
        'MGNN': {'class': MGNN, 'lr': 0.01, 'wd': 5e-4, 'epochs': 200, 'params': {'d_hidden': 32, 'd_red': 8}},
        'SDMG': {'class': SDMG, 'lr': 0.005, 'wd': 5e-4, 'epochs': 200}, # Simplified training
        'CSSE': {'class': CSSE, 'lr': 1e-2, 'wd': 5e-4, 'epochs': 200, 'params': {'alr': 3e-4, 'hidden_dim': 128}}
    }
    METRICS = ['acc', 'f1', 'auc', 'kappa']
    # objective_efficient only scores cliques with 2..5 vertices, which needs HOI
    # coefficients up to index 5, i.e. precompute_hoi_coefficients(n_max=4).
    # The coefficient count grows as num_classes ** (n_max + 1), so using the
    # largest clique size in the graph could be intractable.
    HOI_MAX_CLIQUE_SIZE = 5

    # --- Data Loading and Pre-computation (run once) ---
    print("Loading Coauthor-Physics dataset...")
    dataset = Coauthor(root='dataset/Coauthor-Physics', name='Physics')
    data = dataset[0]
    num_nodes, num_features, num_classes = data.num_nodes, data.num_features, dataset.num_classes
    # The trainers move `data` to their device. Keep a CPU copy of the labels for
    # splitting: on CUDA machines, indexing the CPU permutation with a mask built
    # from GPU labels fails from the second run onwards.
    labels_cpu = data.y.cpu()

    print("Generating simplices (this may take a while)...")
    G = to_networkx(data.cpu(), to_undirected=True, remove_self_loops=True)
    # NOTE(review): this enumerates every sub-clique of every maximal clique
    # (2**m - 1 for a clique of size m); confirm the largest clique stays small.
    general_s, maximal_s, balanced_s = generate_all_simplex_types_greedy(G)

    print("Pre-computing HOI coefficients...")
    hoi_n_max = min(len(general_s), HOI_MAX_CLIQUE_SIZE - 1)
    precomputed_coef = precompute_hoi_coefficients(hoi_n_max, num_classes, device)

    # --- Results Storage ---
    results_storage = []
    mgnn_adjs = None  # (adj_norm, motif_adjs); seed-independent, built on first use

    # --- Main Loop ---
    for gnn_name, config in GNN_CONFIG.items():
        for seed in SEEDS:
            print(f"\n--- Running: GNN={gnn_name}, Seed={seed} ---")
            set_seed(seed)

            # Data Splitting: 20/class train from the first 60% of a permutation,
            # 100/class validation from the next 20%, and the last 20% as test.
            perm = torch.randperm(num_nodes)
            train_end, val_end = int(num_nodes * 0.6), int(num_nodes * 0.8)
            train_idx = sample_per_class_fixed(labels_cpu, perm[:train_end], num_classes, 20)
            val_idx = sample_per_class_fixed(labels_cpu, perm[train_end:val_end], num_classes, 100)
            test_idx = perm[val_end:]
            for mask_name, idx in (('train_mask', train_idx), ('val_mask', val_idx), ('test_mask', test_idx)):
                mask = torch.zeros(num_nodes, dtype=torch.bool)
                mask[idx] = True
                setattr(data, mask_name, mask)

            # --- Stage 1: Raw GNN Training ---
            current_device = 'cpu' if gnn_name == 'SGC' else device
            model_class = config['class']

            if gnn_name == 'PlanetoidGCN':
                gnn_model = model_class(num_features, num_classes, num_nodes)
                raw_probs = planetoid_gcn_training(gnn_model, data, G, current_device, config['lr'], config['wd'], config['epochs'])
            elif gnn_name == 'MGNN':
                # Sparse, normalised propagation matrices. The previous dense
                # .toarray() copies needed num_nodes**2 floats each (~4.8 GB for Physics).
                if mgnn_adjs is None:
                    mgnn_adjs = _build_mgnn_adjacencies(data.edge_index, num_nodes, current_device)
                data.adj_norm, data.motif_adjs = mgnn_adjs[0], list(mgnn_adjs[1])
                gnn_model = model_class(num_features, config['params']['d_hidden'], num_classes,
                                        d_red=config['params']['d_red'])
                raw_probs = standard_gnn_training(gnn_model, data, current_device, config['lr'], config['wd'], config['epochs'])
                # Remove them so later models do not carry (or move) the matrices.
                data.adj_norm, data.motif_adjs = None, None
            elif gnn_name == 'SDMG':
                gnn_model = model_class(num_features)
                raw_probs = sdmg_training_and_inference(gnn_model, data, current_device)
            elif gnn_name == 'CSSE':
                gnn_model = model_class(num_features, config['params']['hidden_dim'], num_classes)
                raw_probs = csse_training(gnn_model, data, current_device, config['lr'], config['wd'], config['params']['alr'], config['epochs'])
            else: # GAT, GCN, SGC
                gnn_model = model_class(num_features, num_classes)
                raw_probs = standard_gnn_training(gnn_model, data, current_device, config['lr'], config['wd'], config['epochs'])

            raw_preds = torch.argmax(raw_probs, dim=1)
            raw_metrics = calculate_all_metrics(raw_preds, raw_probs, data.y, data.test_mask)
            results_storage.append({'gnn': gnn_name, 'seed': seed, 'type': 'Raw', 'weight': 'N/A', **raw_metrics})
            print(f"  Raw {gnn_name} Accuracy: {raw_metrics['acc']:.4f}")

            # --- Stage 2: HOI Post-processing ---
            initial_data = raw_probs.detach().cpu().numpy()
            x_known = get_x_known(data, data.train_mask, data.val_mask, num_classes)

            for simplex_name, simplices in [('General', general_s), ('Maximal', maximal_s), ('Balanced', balanced_s)]:
                for weight_name in ['Constant', 'Linear', 'Exponential']:
                    hoi_probs = HOI_training(20, device, simplices, initial_data, x_known, 0.1, precomputed_coef, weight_name.lower())
                    hoi_preds = torch.argmax(hoi_probs, dim=1)
                    hoi_metrics = calculate_all_metrics(hoi_preds, hoi_probs, data.y, data.test_mask)
                    results_storage.append({'gnn': gnn_name, 'seed': seed, 'type': simplex_name, 'weight': weight_name, **hoi_metrics})

            del gnn_model, raw_probs, raw_preds; gc.collect()
            if torch.cuda.is_available(): torch.cuda.empty_cache()

        if gnn_name == 'MGNN':
            mgnn_adjs = None  # free the cached matrices before the next model family
            gc.collect()

    # --- Final Results Aggregation and Display ---
    df = pd.DataFrame(results_storage)
    summary = df.drop(columns='seed').groupby(['gnn', 'type', 'weight']).agg(['mean', 'std'])
    # Flatten ('acc', 'mean') -> 'acc_mean' so the column lookups and pivot_table
    # below work on plain labels rather than a MultiIndex.
    summary.columns = [f'{metric}_{stat}' for metric, stat in summary.columns]
    summary = summary.reset_index()

    # Formatting for display
    for metric in METRICS:
        summary[f'{metric}_str'] = [
            f"{mean:.4f} ± {std:.4f}"
            for mean, std in zip(summary[f'{metric}_mean'], summary[f'{metric}_std'])
        ]

    print("\n\n" + "="*80)
    print(" " * 25 + "FINAL COMPREHENSIVE RESULTS")
    print("="*80)

    for metric in METRICS:
        print(f"\n--- METRIC: {metric.upper()} ---")
        pivot = summary.pivot_table(index=['gnn', 'type'], columns='weight', values=f'{metric}_str', aggfunc='first')
        # Reorder columns for better readability
        cols_order = [c for c in ['N/A', 'Constant', 'Linear', 'Exponential'] if c in pivot.columns]
        pivot = pivot[cols_order]
        print(pivot)