In [None]:
!pip install --quiet torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric -f https://data.pyg.org/whl/torch-2.1.0+cu118.html


# Imports

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import gc
from torch_geometric.transforms import ToUndirected, RandomNodeSplit
from torch_geometric.nn import SAGEConv, to_hetero
from sklearn.metrics import f1_score, roc_auc_score
from torch_geometric.datasets import OGB_MAG



# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Device:", device)


# Load/download OGB-MAG (metapath2vec)

In [None]:
# Load OGB-MAG under data/OGB with metapath2vec preprocessing; ToUndirected is
# applied as the dataset transform.
dataset = OGB_MAG(
    root="data/OGB",
    preprocess="metapath2vec",
    transform=ToUndirected(),
)
data = dataset[0]

# Report the schema of the full graph before any sub-sampling.
for label, schema in (
    ("Original node types:", data.node_types),
    ("Original edge types:", data.edge_types),
):
    print(label, schema)

# Subgraph Selection

In [None]:
# Number of paper nodes we decided to keep (centered set)
num_papers = 60_000
# The centered set is simply the first `num_papers` paper indices: 0 .. num_papers - 1.
paper_ids = torch.arange(0, num_papers)

# 1-Hop Heterogeneous Subgraph Creation

In [None]:
def _merge_node_ids(kept, node_type, found):
    """Union `found` into `kept[node_type]`, creating the entry when it is absent."""
    if node_type in kept:
        kept[node_type] = torch.unique(torch.cat([kept[node_type], found]))
    else:
        kept[node_type] = found


# Start from the centered papers and collect, per node type, every node that
# shares an edge with one of them.
nodes_to_keep = {'paper': paper_ids}
for edge_type, edges in data.edge_index_dict.items():
    src_type, _, dst_type = edge_type
    sources, targets = edges[0], edges[1]

    # Edges leaving a centered paper: keep their destination nodes.
    if src_type == 'paper':
        from_center = torch.isin(sources, paper_ids)
        if from_center.any():
            _merge_node_ids(nodes_to_keep, dst_type, torch.unique(targets[from_center]))

    # Edges arriving at a centered paper: keep their source nodes.
    if dst_type == 'paper':
        into_center = torch.isin(targets, paper_ids)
        if into_center.any():
            _merge_node_ids(nodes_to_keep, src_type, torch.unique(sources[into_center]))


# Induce the heterogeneous subgraph on the collected node sets.
data_sub = data.subgraph(nodes_to_keep)
print("1-hop heterogeneous subgraph created.")
print("Num papers:", data_sub["paper"].num_nodes)


# Venue Distribution

In [None]:
# Count how many papers of the subgraph carry each venue label.
y = data_sub["paper"].y
venue_ids, venue_counts = y.unique(return_counts=True)

df_venues = pd.DataFrame(
    {
        "venue_id": venue_ids.cpu().tolist(),
        "count": venue_counts.cpu().tolist(),
    }
)
df_venues.to_csv("venue_distribution.csv", index=False)
print("Save venue_distribution.csv")

# The number of classes is derived from the largest label id present.
largest_label = y.max().item()
num_classes = int(largest_label + 1)
print("Number of venues:", num_classes)


# Train/Validation/Test Split and Data Preparation

In [None]:
# Fractions of the paper nodes held out for validation and testing; the
# remaining papers form the training set (split="train_rest").
VAL_FRACTION = 0.15
TEST_FRACTION = 0.15
# Fixed seed so the random split (and therefore the reported metrics) can be
# reproduced after a kernel restart.
SPLIT_SEED = 42

num_nodes = data_sub["paper"].num_nodes
num_val = int(num_nodes * VAL_FRACTION)
num_test = int(num_nodes * TEST_FRACTION)

# Seed torch's global RNG right before the split.
# NOTE(review): this assumes RandomNodeSplit draws from torch's global RNG — confirm against the installed PyG version.
torch.manual_seed(SPLIT_SEED)

# Apply a random node split for train/val/test
split_transform = RandomNodeSplit(split="train_rest", num_val=num_val, num_test=num_test)
data_sub = split_transform(data_sub)

print("Train =", data_sub["paper"].train_mask.sum().item())
print("Val   =", data_sub["paper"].val_mask.sum().item())
print("Test  =", data_sub["paper"].test_mask.sum().item())

# Move the whole subgraph (features, edges, labels, masks) to the compute device.
data_sub = data_sub.to(device)

# Extract node features and edge indices
x_dict = data_sub.x_dict
edge_index_dict = data_sub.edge_index_dict

# Extract labels and masks for papers
y = data_sub['paper'].y
train_mask = data_sub['paper'].train_mask
val_mask = data_sub['paper'].val_mask
test_mask = data_sub['paper'].test_mask

# Get the number of features for paper nodes
num_features = x_dict['paper'].shape[1]


# GraphSAGE

In [None]:
class GraphSAGEEnc(nn.Module):
    """Two-layer GraphSAGE encoder: SAGEConv -> ReLU -> SAGEConv.

    Args:
        in_channels: Size of the input node features.
        hidden_channels: Output size of the first convolution.
        out_channels: Output size of the second convolution.
        aggr: Aggregation scheme forwarded to both SAGEConv layers.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, aggr='mean'):
        super().__init__()
        # Attribute names are part of the saved state-dict keys; keep them stable.
        self.conv1 = SAGEConv(in_channels, hidden_channels, aggr=aggr)
        self.conv2 = SAGEConv(hidden_channels, out_channels, aggr=aggr)

    def forward(self, x, edge_index):
        """Return node embeddings for features `x` over the edges in `edge_index`."""
        hidden = self.conv1(x, edge_index).relu()
        return self.conv2(hidden, edge_index)


# Metrics Computation

In [None]:
def compute_metrics(logits, y_true):
    """Compute accuracy, macro-F1 and one-vs-rest ROC AUC for a set of predictions.

    Args:
        logits: Tensor of unnormalised class scores; the last dimension indexes
            the classes.
        y_true: Tensor of integer class labels, one per row of ``logits``.

    Returns:
        Tuple ``(accuracy, f1, auc)``. All three are NaN when ``y_true`` is
        empty. ``auc`` is NaN whenever scikit-learn rejects the input with a
        ``ValueError``.
        NOTE(review): a label set that does not cover every logit column is
        presumably the usual trigger — confirm against the installed
        scikit-learn version.
    """
    # Detach so the function also works outside a torch.no_grad() block, and
    # move to the CPU once instead of per metric.
    logits = logits.detach().cpu()
    y_true = y_true.detach().cpu()

    # An empty selection (e.g. an all-False mask) has no defined metrics;
    # return NaNs instead of dividing by zero.
    if y_true.size(0) == 0:
        nan = float("nan")
        return nan, nan, nan

    preds = logits.argmax(dim=-1)
    accuracy = (preds == y_true).sum().item() / y_true.size(0)

    labels_np = y_true.numpy()
    f1 = f1_score(labels_np, preds.numpy(), average="macro")

    y_prob = torch.softmax(logits, dim=-1).numpy()
    try:
        auc = roc_auc_score(labels_np, y_prob, multi_class="ovr")
    except ValueError:
        # Best-effort metric: keep going when AUC is undefined for this subset,
        # but let every other kind of error surface.
        auc = float("nan")

    return accuracy, f1, auc


# Training Heterogeneous GraphSAGE

In [None]:
EPOCHS = 200
AGGREGATIONS = ["mean", "sum"]
HIDDEN_DIM = 128       # width of both GraphSAGE layers and of the classifier input
LEARNING_RATE = 0.001

# Train one heterogeneous GraphSAGE + linear head per aggregation scheme and
# persist its weights and per-epoch metrics.
for aggr in AGGREGATIONS:
    print(f"\n=== Training heterogeneous GraphSAGE with aggr={aggr} ===")
    base_model = GraphSAGEEnc(num_features, HIDDEN_DIM, HIDDEN_DIM, aggr=aggr)
    model_hetero = to_hetero(base_model, data_sub.metadata(), aggr=aggr).to(device)
    head = nn.Linear(HIDDEN_DIM, num_classes).to(device)
    optimizer = torch.optim.Adam(
        list(model_hetero.parameters()) + list(head.parameters()),
        lr=LEARNING_RATE,
    )

    metrics = []

    for epoch in range(1, EPOCHS + 1):
        # ---- Optimisation step (full-batch) ----
        model_hetero.train()
        head.train()
        optimizer.zero_grad()

        h_dict = model_hetero(x_dict, edge_index_dict)
        logits = head(h_dict["paper"])
        loss = F.cross_entropy(logits[train_mask], y[train_mask])
        loss.backward()
        optimizer.step()
        train_loss = float(loss.item())

        # ---- Evaluation ----
        # Run a fresh forward pass in eval mode so the metrics describe the
        # weights *after* this epoch's update, not the pre-update logits
        # produced by the training pass above.
        model_hetero.eval()
        head.eval()
        with torch.no_grad():
            eval_logits = head(model_hetero(x_dict, edge_index_dict)["paper"])
            train_acc, train_f1, train_auc = compute_metrics(eval_logits[train_mask], y[train_mask])
            val_acc, val_f1, val_auc = compute_metrics(eval_logits[val_mask], y[val_mask])

        metrics.append([
            epoch,
            train_loss,
            train_acc, train_f1, train_auc,
            val_acc, val_f1, val_auc
        ])

        # Print metrics every 10 epochs
        if epoch % 10 == 0:
            print(
                f"Epoch {epoch:03d} | Loss={train_loss:.4f} | "
                f"Train Acc={train_acc:.4f} F1={train_f1:.4f} AUC={train_auc:.4f} | "
                f"Val Acc={val_acc:.4f} F1={val_f1:.4f} AUC={val_auc:.4f}"
            )

    # Save models
    torch.save(model_hetero.state_dict(), f"model_{aggr}.pth")
    torch.save(head.state_dict(), f"head_{aggr}.pth")
    print(f"Models saved for aggr={aggr}.")

    # Save CSV
    df_metrics = pd.DataFrame(
        metrics,
        columns=[
            "epoch",
            "loss",
            "train_acc", "train_f1", "train_auc",
            "val_acc", "val_f1", "val_auc"
        ]
    )
    df_metrics.to_csv(f"metrics_{aggr}.csv", index=False)
    print(f"Metrics saved: metrics_{aggr}.csv")

    # Cleanup to free memory: drop every reference that keeps tensors alive,
    # collect garbage first, then release the cached CUDA blocks.
    del base_model, model_hetero, head, optimizer, metrics
    del h_dict, logits, loss, eval_logits
    gc.collect()
    torch.cuda.empty_cache()


# Testing

In [None]:
# Test-set metrics per aggregation scheme: {aggr: {"acc": ..., "f1": ..., "auc": ...}}
results = {}

for aggr in AGGREGATIONS:
    print(f"\n=== Testing models using the aggregation {aggr} ===")

    # Rebuild the same architecture that was trained so the state dicts match.
    base_model = GraphSAGEEnc(num_features, 128, 128, aggr=aggr)
    model_hetero = to_hetero(base_model, data_sub.metadata(), aggr=aggr).to(device)
    head = nn.Linear(128, num_classes).to(device)

    # Load the weights from the same relative paths the training cell writes
    # to (no hard-coded platform directory). map_location lets checkpoints
    # saved on a GPU load on a CPU-only runtime as well.
    model_hetero.load_state_dict(torch.load(f"model_{aggr}.pth", map_location=device))
    head.load_state_dict(torch.load(f"head_{aggr}.pth", map_location=device))

    model_hetero.eval()
    head.eval()

    with torch.no_grad():
        h_dict = model_hetero(x_dict, edge_index_dict)
        logits = head(h_dict["paper"])

        # Calculate Metrics on the held-out test papers only
        test_acc, test_f1, test_auc = compute_metrics(logits[test_mask], y[test_mask])

    print(f"Test Acc={test_acc:.4f} | Test F1={test_f1:.4f} | Test AUC={test_auc:.4f}")

    # Save results
    results[aggr] = {
        "acc": test_acc,
        "f1": test_f1,
        "auc": test_auc
    }
