# Graph Embeddings, Node Classification & Link Prediction
### Semester 2 Extension — Massive Graph Management and Analytics

---

**Author:** Olha Baliasina and Samuel Chapuis

**Datasets:**
- **Amazon Co-Purchasing Network** (~335K nodes, ~926K edges) — Large
- **CA-HepTh Collaboration Network** (~6K–9K authors) — Small

**This notebook covers:**
1. Shallow Embeddings: DeepWalk, Node2Vec
2. Spectral Embeddings: Laplacian Eigenmaps
3. GNN-based: GCN, GraphSAGE, GAT
4. Node Classification on both datasets
5. Link Prediction on both datasets
6. Comprehensive comparison & analysis

---
# Part 0: Setup & Dependencies


In [1]:
# 1) Pin NumPy to what torch wheels commonly expect, and reinstall compiled stack to match it
%pip install -q --no-cache-dir --force-reinstall \
  numpy==1.26.4 pandas scipy scikit-learn

# 2) Install PyTorch (CUDA 12.6)
%pip install -q --no-cache-dir \
  torch==2.8.0 torchvision==0.23.0 torchaudio==2.8.0 \
  --index-url https://download.pytorch.org/whl/cu126

# 3) Install PyG compiled wheels (binary only, no source builds)
%pip install -q --no-cache-dir --only-binary=:all: \
  pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv \
  -f https://data.pyg.org/whl/torch-2.8.0+cu126.html

%pip install -q --no-cache-dir torch-geometric

# 4) The remaining libraries you import/use
%pip install -q --no-cache-dir umap-learn node2vec gensim python-louvain python-igraph


[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 kB[0m [31m81.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.5/79.5 kB[0m [31m229.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.1/62.1 kB[0m [31m178.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m18.0/18.0 MB[0m [31m197.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.9/10.9 MB[0m [31m175.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m35.0/35.0 MB[0m [31m214.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.9/8.9 MB[0m [31m147.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m309.1/309.1 kB[0m [31m346.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━

In [1]:
# import numpy
# print("numpy:", numpy.__version__)

numpy: 2.0.2


In [4]:
# # ---- install PyTorch CUDA 12.6 wheels ----
# %pip install -q --no-cache-dir torch==2.8.0 torchvision==0.23.0 torchaudio==2.8.0 --index-url https://download.pytorch.org/whl/cu126

# # ---- install PyG compiled wheels (BINARY ONLY: no source builds) ----
# %pip install -q --no-cache-dir --only-binary=:all: pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.8.0+cu126.html
# %pip install -q torch-geometric

# # ---- the remaining libs you use ----
# %pip install -q node2vec gensim umap-learn python-louvain python-igraph


In [2]:
# # PyTorch (CUDA 12.6) + PyG wheels that match it (no source builds)
# %pip install -q --no-cache-dir torch==2.8.0 torchvision==0.23.0 torchaudio==2.8.0 --index-url https://download.pytorch.org/whl/cu126
# %pip install -q --no-cache-dir --only-binary=:all: pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv \
#   -f https://data.pyg.org/whl/torch-2.8.0+cu126.html
# %pip install -q torch-geometric

# # Your extra libs
# %pip install -q umap-learn node2vec gensim python-louvain python-igraph


[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m322.4/322.4 MB[0m [31m127.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m821.8/821.8 MB[0m [31m238.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.4/7.4 MB[0m [31m183.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.5/3.5 MB[0m [31m253.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m155.6/155.6 MB[0m [31m225.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.8/4.8 MB[0m [31m95.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.9/10.9 MB[0m [31m215.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.2/5.2 MB[0m [31m277.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [3]:
# import numpy as np, pandas as pd
# import torch, torch_geometric, umap
# print("numpy", np.__version__)
# print("pandas", pd.__version__)
# print("torch", torch.__version__, "cuda", torch.version.cuda)
# print("pyg", torch_geometric.__version__)


ValueError: numpy.dtype size changed, may indicate binary incompatibility. Expected 96 from C header, got 88 from PyObject

In [3]:
# %pip install -q torch-geometric



[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.7/63.7 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.3 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m1.3/1.3 MB[0m [31m55.5 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m36.1 MB/s[0m eta [36m0:00:00[0m
[?25h

In [5]:
# %pip install -q node2vec gensim umap-learn python-louvain python-igraph


[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m27.9/27.9 MB[0m [31m48.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.7/5.7 MB[0m [31m100.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m18.0/18.0 MB[0m [31m76.2 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
opencv-python 4.13.0.90 requires numpy>=2; python_version >= "3.9", but you have numpy 1.26.4 which is incompatible.
pytensor 2.37.0 requires numpy>=2.0, but you have numpy 1.26.4 which is incompatible.
rasterio 1.5.0 requires numpy>=2, but you have numpy 1.26.4 which is incompatible.
opencv-python-headless 4.13.0.90 requires numpy>=2; 

In [6]:
# import re, torch

# torch_ver = re.match(r"(\d+\.\d+\.\d+)", torch.__version__).group(1)
# cuda_ver = torch.version.cuda

# if cuda_ver is None:
#     cuda_tag = "cpu"
# else:
#     cuda_tag = "cu" + cuda_ver.replace(".", "")

# url = f"https://data.pyg.org/whl/torch-{torch_ver}+{cuda_tag}.html"
# print("Using wheels from:", url)

# %pip install -q torch-scatter torch-sparse torch-cluster torch-spline-conv -f {url}
# %pip install -q torch-geometric


In [4]:
# %pip install -q node2vec gensim umap-learn
# %pip install -q python-louvain python-igraph
# %pip install -q numpy


In [11]:
# !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# !pip install torch-geometric
# !pip install torch-scatter torch-sparse torch-cluster torch-spline-conv -f https://data.pyg.org/whl/torch-2.1.0+cu118.html
# !pip install node2vec gensim umap-learn
# !pip install community python-louvain igraph


In [1]:
# ============================================================
# 0.2  Imports
# ============================================================
import os
import time
import warnings
import pickle
from collections import Counter, defaultdict, deque
from itertools import combinations

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import networkx as nx
from scipy import sparse
from scipy.sparse.linalg import eigsh

from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
    accuracy_score, f1_score, classification_report,
    roc_auc_score, average_precision_score
)
from sklearn.preprocessing import LabelEncoder, StandardScaler

import torch
import torch.nn as nn
import torch.nn.functional as F

# PyTorch Geometric
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.loader import NeighborLoader, LinkNeighborLoader
from torch_geometric.nn import GCNConv, SAGEConv, GATConv
from torch_geometric.utils import (
    from_networkx, negative_sampling, train_test_split_edges,
    to_undirected, add_self_loops
)
from torch_geometric.transforms import RandomLinkSplit

# Gensim for shallow embeddings
from gensim.models import Word2Vec

# Node2Vec
from node2vec import Node2Vec as Node2VecLib

# UMAP for visualization
import umap

warnings.filterwarnings('ignore')
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette('husl')

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")
print(f"PyTorch version: {torch.__version__}")
print(f"PyG version: {torch_geometric.__version__}")

SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)

Using device: cuda
PyTorch version: 2.8.0+cu126
PyG version: 2.7.0


---
# Part 1: Data Loading & Preprocessing

We load both datasets and prepare them for the experiments.

### Data files required:
- `data/com-amazon.ungraph.txt` — Amazon edge list
- `data/com-amazon.all.dedup.cmty.txt` — Amazon ground-truth communities
- `arxiv_hepth_meta.csv` — HepTh paper metadata
- (Optional) `data/CA-HepTh.txt` — HepTh original edge list

## 1.1 Amazon Co-Purchasing Network

In [2]:
# ============================================================
# 1.1  Load Amazon graph
# ============================================================
GRAPH_FILE = 'data/com-amazon.ungraph.txt'
COMMUNITY_FILE = 'data/com-amazon.all.dedup.cmty.txt'

print("Loading Amazon co-purchasing network...")
t0 = time.time()
G_amazon = nx.read_edgelist(
    GRAPH_FILE, comments='#', delimiter='\t',
    create_using=nx.Graph(), nodetype=int
)
print(f"Loaded in {time.time()-t0:.1f}s — {G_amazon.number_of_nodes():,} nodes, {G_amazon.number_of_edges():,} edges")

# Load ground-truth communities
ground_truth_communities = []
with open(COMMUNITY_FILE, 'r') as f:
    for line in f:
        line = line.strip()
        if line and not line.startswith('#'):
            members = [int(x) for x in line.split('\t')]
            ground_truth_communities.append(set(members))

print(f"Loaded {len(ground_truth_communities):,} ground-truth communities")

Loading Amazon co-purchasing network...
Loaded in 6.5s — 334,863 nodes, 925,872 edges
Loaded 75,149 ground-truth communities


In [3]:
# ============================================================
# 1.2  Construct node labels for Amazon (node classification)
# ============================================================
# Each node can belong to multiple communities (overlapping).
# Strategy: assign each node the community it shares with the
# MOST neighbors (i.e., the community most structurally relevant).
# Then keep only the top-K most frequent labels.

# Build node -> list of community IDs
node_to_comms = defaultdict(list)
for cid, comm in enumerate(ground_truth_communities):
    for node in comm:
        if node in G_amazon:
            node_to_comms[node].append(cid)

# For each node, pick the community that most of its neighbors also belong to
def assign_best_community(G, node_to_comms):
    labels = {}
    for node in G.nodes():
        if node not in node_to_comms or len(node_to_comms[node]) == 0:
            continue
        neighbors = set(G.neighbors(node))
        best_cid, best_score = None, -1
        for cid in node_to_comms[node]:
            comm_set = ground_truth_communities[cid]
            overlap = len(neighbors & comm_set)
            if overlap > best_score:
                best_score = overlap
                best_cid = cid
        if best_cid is not None:
            labels[node] = best_cid
    return labels

print("Assigning primary community labels to nodes...")
t0 = time.time()
amazon_node_labels = assign_best_community(G_amazon, node_to_comms)
print(f"Done in {time.time()-t0:.1f}s — {len(amazon_node_labels):,} nodes labeled")

# Keep top-K most frequent labels
TOP_K = 20
label_counts = Counter(amazon_node_labels.values())
top_labels = set([l for l, _ in label_counts.most_common(TOP_K)])
amazon_node_labels_filtered = {n: l for n, l in amazon_node_labels.items() if l in top_labels}

# Re-encode labels to 0..K-1
le_amazon = LabelEncoder()
labeled_nodes_amazon = sorted(amazon_node_labels_filtered.keys())
raw_labels = [amazon_node_labels_filtered[n] for n in labeled_nodes_amazon]
encoded_labels_amazon = le_amazon.fit_transform(raw_labels)

print(f"\nFiltered to top-{TOP_K} categories: {len(amazon_node_labels_filtered):,} labeled nodes")
print(f"Label distribution:")
for label_id, count in sorted(Counter(encoded_labels_amazon).items(), key=lambda x: -x[1])[:10]:
    print(f"  Class {label_id}: {count:,} nodes")

Assigning primary community labels to nodes...
Done in 5.4s — 317,194 nodes labeled

Filtered to top-20 categories: 173,905 labeled nodes
Label distribution:
  Class 5: 22,914 nodes
  Class 12: 16,623 nodes
  Class 9: 16,363 nodes
  Class 11: 14,782 nodes
  Class 2: 13,967 nodes
  Class 14: 13,873 nodes
  Class 18: 12,292 nodes
  Class 1: 11,471 nodes
  Class 8: 7,718 nodes
  Class 17: 7,245 nodes


## 1.2 CA-HepTh Collaboration Network

In [4]:
# ============================================================
# 1.3  Load and build HepTh collaboration graph
# ============================================================
import ast

meta_df = pd.read_csv("data/arxiv_hepth_meta.csv")

# Clean authors_list column
def safe_parse(x):
    if isinstance(x, list):
        return x
    try:
        parsed = ast.literal_eval(str(x))
        if isinstance(parsed, list):
            return [str(a).strip() for a in parsed if len(str(a).strip()) > 1]
    except:
        pass
    return []

meta_df['authors_clean'] = meta_df['authors_list'].apply(safe_parse)
meta_df['n_authors_clean'] = meta_df['authors_clean'].apply(len)

# Build co-authorship graph
G_hepth = nx.Graph()
for _, row in meta_df.iterrows():
    authors = row['authors_clean']
    if len(authors) >= 2:
        for a1, a2 in combinations(authors, 2):
            if G_hepth.has_edge(a1, a2):
                G_hepth[a1][a2]['weight'] += 1
            else:
                G_hepth.add_edge(a1, a2, weight=1)

# Use largest connected component
components = sorted(nx.connected_components(G_hepth), key=len, reverse=True)
G_hepth_lcc = G_hepth.subgraph(components[0]).copy()

print(f"HepTh full graph: {G_hepth.number_of_nodes():,} nodes, {G_hepth.number_of_edges():,} edges")
print(f"HepTh LCC:        {G_hepth_lcc.number_of_nodes():,} nodes, {G_hepth_lcc.number_of_edges():,} edges")
print(f"Components: {len(components)}, LCC covers {G_hepth_lcc.number_of_nodes()/G_hepth.number_of_nodes():.1%}")

HepTh full graph: 16,210 nodes, 27,789 edges
HepTh LCC:        10,397 nodes, 22,243 edges
Components: 1919, LCC covers 64.1%


In [5]:
# ============================================================
# 1.4  Construct node labels for HepTh (node classification)
# ============================================================
# Use Louvain community detection to create pseudo-labels
# (since there is no explicit node-level ground truth beyond subject class)
import community.community_louvain as community_louvain

partition_hepth = community_louvain.best_partition(G_hepth_lcc, random_state=SEED)
n_comms = len(set(partition_hepth.values()))
print(f"Louvain found {n_comms} communities on HepTh LCC")

# Keep communities with at least 20 members for meaningful classification
comm_sizes = Counter(partition_hepth.values())
valid_comms = {c for c, s in comm_sizes.items() if s >= 20}
hepth_node_labels = {n: c for n, c in partition_hepth.items() if c in valid_comms}

le_hepth = LabelEncoder()
labeled_nodes_hepth = sorted(hepth_node_labels.keys())
encoded_labels_hepth = le_hepth.fit_transform([hepth_node_labels[n] for n in labeled_nodes_hepth])

print(f"Keeping {len(valid_comms)} communities (>=20 members): {len(hepth_node_labels):,} labeled nodes")
print(f"Number of classes: {len(set(encoded_labels_hepth))}")

Louvain found 72 communities on HepTh LCC
Keeping 66 communities (>=20 members): 10,333 labeled nodes
Number of classes: 66


---
# Part 2: Structural Feature Engineering

Since neither dataset has inherent node attributes, we compute structural features
that will serve as input to GNN models and as baselines for comparison.

In [7]:
# ============================================================
# 2.1  Compute structural features for both graphs
# ============================================================
def compute_structural_features(G, name="graph"):
    """Compute structural node features for a graph."""
    print(f"Computing structural features for {name}...")
    t0 = time.time()

    # Work on a copy so you don't mutate the original graph
    G = G.copy()
    n_loops = nx.number_of_selfloops(G)
    if n_loops > 0:
        print(f"  Removing {n_loops} self-loops...")
        G.remove_edges_from(nx.selfloop_edges(G))

    features = {}

    degrees = dict(G.degree())
    features['degree'] = degrees
    print("  Degree: done")

    clustering = nx.clustering(G)
    features['clustering'] = clustering
    print("  Clustering: done")

    pagerank = nx.pagerank(G, max_iter=100)
    features['pagerank'] = pagerank
    print("  PageRank: done")

    core = nx.core_number(G)
    features['core_number'] = core
    print("  Core number: done")

    print(f"  Total time: {time.time()-t0:.1f}s")
    return features


# For Amazon — compute on full graph (all these scale well)
amazon_features = compute_structural_features(G_amazon, "Amazon")

# For HepTh — compute on LCC
hepth_features = compute_structural_features(G_hepth_lcc, "HepTh LCC")

Computing structural features for Amazon...
  Degree: done
  Clustering: done
  PageRank: done
  Core number: done
  Total time: 21.8s
Computing structural features for HepTh LCC...
  Removing 12 self-loops...
  Degree: done
  Clustering: done
  PageRank: done
  Core number: done
  Total time: 0.3s


In [8]:
# ============================================================
# 2.2  Build feature matrices
# ============================================================
def build_feature_matrix(G, features, node_list=None):
    """Build a numpy feature matrix from structural features."""
    if node_list is None:
        node_list = sorted(G.nodes())

    feat_names = ['degree', 'clustering', 'pagerank', 'core_number']
    X = np.zeros((len(node_list), len(feat_names)))

    for j, fname in enumerate(feat_names):
        for i, node in enumerate(node_list):
            X[i, j] = features[fname].get(node, 0.0)

    # Normalize
    scaler = StandardScaler()
    X = scaler.fit_transform(X)

    return X, feat_names, node_list

# Amazon feature matrix (for all nodes, ordered)
amazon_nodes_sorted = sorted(G_amazon.nodes())
amazon_node_to_idx = {n: i for i, n in enumerate(amazon_nodes_sorted)}
X_amazon, feat_names, _ = build_feature_matrix(G_amazon, amazon_features, amazon_nodes_sorted)
print(f"Amazon feature matrix: {X_amazon.shape}")

# HepTh feature matrix
hepth_nodes_sorted = sorted(G_hepth_lcc.nodes())
hepth_node_to_idx = {n: i for i, n in enumerate(hepth_nodes_sorted)}
X_hepth, _, _ = build_feature_matrix(G_hepth_lcc, hepth_features, hepth_nodes_sorted)
print(f"HepTh feature matrix: {X_hepth.shape}")

Amazon feature matrix: (334863, 4)
HepTh feature matrix: (10397, 4)


---
# Part 3: Shallow Embeddings — DeepWalk & Node2Vec

## 3.1 Theory

**DeepWalk** performs uniform random walks and feeds the node sequences into Word2Vec (Skip-Gram with negative sampling). It captures community structure — nodes in the same densely-connected region tend to co-occur in walks.

**Node2Vec** extends this with biased random walks:
- **p** (return parameter): High p → less likely to backtrack → explores further
- **q** (in-out parameter): Low q → BFS-like (structural equivalence); High q → DFS-like (homophily)

Both are **unsupervised** and **transductive** — they learn fixed embeddings per node.

## 3.2 DeepWalk Implementation

In [9]:
# ============================================================
# 3.1  DeepWalk (uniform random walks + Word2Vec)
# ============================================================
def deepwalk(G, dimensions=128, walk_length=40, num_walks=10,
             window=5, workers=4, seed=42):
    """
    DeepWalk: uniform random walks + Skip-Gram.

    Parameters
    ----------
    G : nx.Graph
    dimensions : int — embedding dimensionality
    walk_length : int — length of each random walk
    num_walks : int — number of walks per node
    window : int — Word2Vec context window size

    Returns
    -------
    embeddings : dict {node: np.array}
    model : Word2Vec model
    """
    print(f"DeepWalk: {num_walks} walks × {walk_length} steps, dim={dimensions}")
    t0 = time.time()

    nodes = list(G.nodes())
    rng = np.random.default_rng(seed)

    # Generate random walks
    walks = []
    for _ in range(num_walks):
        rng.shuffle(nodes)
        for start_node in nodes:
            walk = [start_node]
            current = start_node
            for _ in range(walk_length - 1):
                neighbors = list(G.neighbors(current))
                if len(neighbors) == 0:
                    break
                current = neighbors[rng.integers(len(neighbors))]
                walk.append(current)
            walks.append([str(n) for n in walk])  # Word2Vec needs strings

    print(f"  Generated {len(walks):,} walks in {time.time()-t0:.1f}s")

    # Train Word2Vec
    t1 = time.time()
    model = Word2Vec(
        walks, vector_size=dimensions, window=window,
        min_count=0, sg=1, workers=workers, seed=seed, epochs=5
    )
    print(f"  Word2Vec trained in {time.time()-t1:.1f}s")

    # Extract embeddings
    embeddings = {}
    for node in G.nodes():
        key = str(node)
        if key in model.wv:
            embeddings[node] = model.wv[key]

    print(f"  Total time: {time.time()-t0:.1f}s, embedded {len(embeddings):,} nodes")
    return embeddings, model

## 3.3 Node2Vec Implementation

In [10]:
# ============================================================
# 3.2  Node2Vec (biased random walks + Word2Vec)
# ============================================================
def node2vec_embed(G, dimensions=128, walk_length=40, num_walks=10,
                   p=1.0, q=1.0, window=5, workers=4, seed=42):
    """
    Node2Vec with biased random walks.

    Parameters
    ----------
    p : float — return parameter (high = less backtracking)
    q : float — in-out parameter (low = BFS-like, high = DFS-like)
    """
    print(f"Node2Vec: p={p}, q={q}, dim={dimensions}, walks={num_walks}×{walk_length}")
    t0 = time.time()

    node2vec = Node2VecLib(
        G, dimensions=dimensions, walk_length=walk_length,
        num_walks=num_walks, p=p, q=q, workers=workers, seed=seed,
        quiet=True
    )

    t1 = time.time()
    print(f"  Walks generated in {t1-t0:.1f}s")

    model = node2vec.fit(window=window, min_count=0, batch_words=4, seed=seed)
    print(f"  Model trained in {time.time()-t1:.1f}s")

    embeddings = {}
    for node in G.nodes():
        key = str(node)
        if key in model.wv:
            embeddings[node] = model.wv[key]

    print(f"  Total: {time.time()-t0:.1f}s, embedded {len(embeddings):,} nodes")
    return embeddings, model

## 3.4 Run Shallow Embeddings on Both Datasets

In [12]:
%pip -q install tqdm
from tqdm.auto import tqdm
import random

In [13]:
from gensim.models import Word2Vec

def _random_walk(G, start, walk_length):
    walk = [start]
    for _ in range(walk_length - 1):
        cur = walk[-1]
        nbrs = list(G.neighbors(cur))
        if not nbrs:
            break
        walk.append(random.choice(nbrs))
    return walk

def deepwalk_with_pbar(G, dimensions=128, walk_length=40, num_walks=10, window=5, workers=2, seed=42):
    random.seed(seed)
    nodes = list(G.nodes())

    total = num_walks * len(nodes)
    walks = []
    pbar = tqdm(total=total, desc="Generating walks", leave=True)

    for _ in range(num_walks):
        random.shuffle(nodes)
        for v in nodes:
            walks.append([str(x) for x in _random_walk(G, v, walk_length)])
            pbar.update(1)

    pbar.close()

    print("Training Word2Vec...")
    model = Word2Vec(
        sentences=walks,
        vector_size=dimensions,
        window=window,
        min_count=0,
        sg=1,
        workers=workers
    )

    emb = {int(k) if k.isdigit() else k: model.wv[k] for k in model.wv.index_to_key}
    return emb, model


In [None]:
# ============================================================
# 3.3  Run DeepWalk on both datasets
# ============================================================
EMB_DIM = 128

print("=" * 60)
print("AMAZON — DeepWalk")
print("=" * 60)
amazon_dw_emb, _ = deepwalk_with_pbar(G_amazon, dimensions=EMB_DIM, walk_length=40, num_walks=10, window=5)

print("\n" + "=" * 60)
print("HEPTH — DeepWalk")
print("=" * 60)
hepth_dw_emb, _ = deepwalk_with_pbar(G_hepth_lcc, dimensions=EMB_DIM, walk_length=40, num_walks=10, window=5)


AMAZON — DeepWalk


Generating walks:   0%|          | 0/3348630 [00:00<?, ?it/s]

Training Word2Vec...


---
# Part 4: Spectral Embeddings — Laplacian Eigenmaps

## 4.1 Theory

Laplacian Eigenmaps compute the embedding by finding the k smallest non-trivial eigenvectors
of the graph Laplacian matrix L = D − A (or its normalized variant L_norm = D^(-½) L D^(-½)).

The optimization objective is:
$$\min_Y \sum_{(i,j) \in E} w_{ij} \|y_i - y_j\|^2 = \min_Y \text{tr}(Y^T L Y)$$
subject to Y^T D Y = I.

This ensures connected nodes are close in the embedding space. The solution is given by
the eigenvectors corresponding to the smallest non-zero eigenvalues of L.

**Scalability note:** Eigendecomposition of L is O(n³) for dense methods. For the Amazon
graph (335K nodes), this is infeasible. We use scipy's sparse eigensolver (`eigsh`) which
works well up to ~50K nodes, and restrict Amazon spectral embeddings to a subgraph.

In [None]:
# ============================================================
# 4.1  Laplacian Eigenmaps implementation
# ============================================================
def laplacian_eigenmaps(G, dimensions=128, normalized=True):
    """
    Compute Laplacian Eigenmaps for graph G.

    Parameters
    ----------
    G : nx.Graph
    dimensions : int — number of embedding dimensions
    normalized : bool — use normalized Laplacian if True

    Returns
    -------
    embeddings : dict {node: np.array}
    eigenvalues : np.array
    """
    n = G.number_of_nodes()
    print(f"Laplacian Eigenmaps: {n:,} nodes, dim={dimensions}, normalized={normalized}")
    t0 = time.time()

    nodes = sorted(G.nodes())
    node_to_idx = {n: i for i, n in enumerate(nodes)}

    # Build sparse adjacency matrix
    rows, cols, vals = [], [], []
    for u, v, d in G.edges(data=True):
        w = d.get('weight', 1.0)
        i, j = node_to_idx[u], node_to_idx[v]
        rows.extend([i, j])
        cols.extend([j, i])
        vals.extend([w, w])

    A = sparse.csr_matrix((vals, (rows, cols)), shape=(n, n))
    D = sparse.diags(np.array(A.sum(axis=1)).flatten())
    L = D - A

    if normalized:
        # Normalized Laplacian: D^(-1/2) L D^(-1/2)
        D_inv_sqrt = sparse.diags(1.0 / np.sqrt(np.maximum(np.array(A.sum(axis=1)).flatten(), 1e-10)))
        L = D_inv_sqrt @ L @ D_inv_sqrt

    # Compute smallest eigenvectors (skip the first trivial one)
    k = min(dimensions + 1, n - 1)
    print(f"  Computing {k} smallest eigenvectors...")
    eigenvalues, eigenvectors = eigsh(L, k=k, which='SM', tol=1e-6)

    # Sort by eigenvalue and skip the first (zero eigenvalue)
    idx = np.argsort(eigenvalues)
    eigenvalues = eigenvalues[idx]
    eigenvectors = eigenvectors[:, idx]

    # Skip first eigenvector (constant), take next 'dimensions'
    embedding_matrix = eigenvectors[:, 1:dimensions+1]

    # Build embeddings dict
    embeddings = {}
    for i, node in enumerate(nodes):
        embeddings[node] = embedding_matrix[i]

    print(f"  Done in {time.time()-t0:.1f}s")
    print(f"  Smallest eigenvalues: {eigenvalues[:5].round(6)}")

    return embeddings, eigenvalues

In [None]:
# ============================================================
# 4.2  Run spectral embeddings
# ============================================================
# --- HepTh (full LCC — fits in memory) ---
print("=" * 60)
print("HEPTH — Laplacian Eigenmaps (full LCC)")
print("=" * 60)
hepth_spectral_emb, hepth_eigenvalues = laplacian_eigenmaps(
    G_hepth_lcc, dimensions=EMB_DIM, normalized=True
)

# --- Amazon (sampled subgraph — full graph is too large) ---
print("\n" + "=" * 60)
print("AMAZON — Laplacian Eigenmaps (sampled subgraph)")
print("=" * 60)

# BFS sample ~15K nodes for spectral embedding
def bfs_sample(G, n_nodes=15000, seed=42):
    rng = np.random.default_rng(seed)
    start = rng.choice(list(G.nodes()))
    visited = {start}
    q = deque([start])
    while q and len(visited) < n_nodes:
        u = q.popleft()
        for v in G.neighbors(u):
            if v not in visited:
                visited.add(v)
                q.append(v)
                if len(visited) >= n_nodes:
                    break
    return G.subgraph(visited).copy()

G_amazon_sub = bfs_sample(G_amazon, n_nodes=15000, seed=SEED)
print(f"Amazon subgraph for spectral: {G_amazon_sub.number_of_nodes():,} nodes, {G_amazon_sub.number_of_edges():,} edges")

amazon_spectral_emb, amazon_eigenvalues = laplacian_eigenmaps(
    G_amazon_sub, dimensions=EMB_DIM, normalized=True
)

In [None]:
# ============================================================
# 4.3  Visualize eigenvalue spectrum (spectral gap)
# ============================================================
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# HepTh
ax = axes[0]
ax.plot(range(1, min(51, len(hepth_eigenvalues))),
        hepth_eigenvalues[1:min(51, len(hepth_eigenvalues))],
        'o-', markersize=4, color='steelblue')
ax.set_xlabel("Eigenvalue index")
ax.set_ylabel("Eigenvalue")
ax.set_title("HepTh — Laplacian Eigenvalue Spectrum")
ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)

# Amazon subgraph
ax = axes[1]
ax.plot(range(1, min(51, len(amazon_eigenvalues))),
        amazon_eigenvalues[1:min(51, len(amazon_eigenvalues))],
        'o-', markersize=4, color='coral')
ax.set_xlabel("Eigenvalue index")
ax.set_ylabel("Eigenvalue")
ax.set_title("Amazon (subgraph) — Laplacian Eigenvalue Spectrum")
ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)

plt.tight_layout()
plt.savefig("eigenvalue_spectrum.png", dpi=150, bbox_inches='tight')
plt.show()

print("The spectral gap (difference between 1st and 2nd smallest eigenvalues)")
print("indicates community structure strength. A larger gap = clearer communities.")

---
# Part 5: Graph Neural Networks — GCN, GraphSAGE, GAT

## 5.1 Theory Overview

All GNNs follow the **message-passing** paradigm: each layer updates a node's representation
by aggregating information from its neighbors.

**GCN** (Graph Convolutional Network): Uses a fixed, symmetric normalization:
`h_v^(l+1) = σ(Σ_{u∈N(v)∪{v}} (1/√(d_u · d_v)) · h_u^(l) · W^(l))`

**GraphSAGE**: Samples a fixed number of neighbors and uses a learnable aggregator
(mean, LSTM, or max-pool). Key advantage: **inductive** + **scalable** via mini-batching.

**GAT**: Uses multi-head **attention** to learn neighbor importance weights:
`α_{ij} = softmax(LeakyReLU(a^T [W·h_i || W·h_j]))`
More expressive but more expensive. Multiple attention heads provide stability.

## 5.2 Model Definitions

In [None]:
# ============================================================
# 5.1  GCN Model
# ============================================================
class GCN(nn.Module):
    """
    Graph Convolutional Network for node classification.

    Architecture: Input -> GCN -> ReLU -> Dropout -> GCN -> ReLU -> Dropout -> GCN -> Output
    """
    def __init__(self, in_channels, hidden_channels, out_channels,
                 num_layers=2, dropout=0.5):
        super().__init__()
        self.convs = nn.ModuleList()
        self.convs.append(GCNConv(in_channels, hidden_channels))
        for _ in range(num_layers - 2):
            self.convs.append(GCNConv(hidden_channels, hidden_channels))
        self.convs.append(GCNConv(hidden_channels, out_channels))
        self.dropout = dropout

    def forward(self, x, edge_index):
        for i, conv in enumerate(self.convs[:-1]):
            x = conv(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[-1](x, edge_index)
        return x

    def get_embedding(self, x, edge_index):
        """Return the penultimate layer embedding."""
        for conv in self.convs[:-1]:
            x = conv(x, edge_index)
            x = F.relu(x)
        return x

In [None]:
# ============================================================
# 5.2  GraphSAGE Model
# ============================================================
class GraphSAGE(nn.Module):
    """
    GraphSAGE for node classification with neighbor sampling.

    Uses mean aggregation — the most common and scalable variant.
    """
    def __init__(self, in_channels, hidden_channels, out_channels,
                 num_layers=2, dropout=0.5):
        super().__init__()
        self.convs = nn.ModuleList()
        self.convs.append(SAGEConv(in_channels, hidden_channels))
        for _ in range(num_layers - 2):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels))
        self.convs.append(SAGEConv(hidden_channels, out_channels))
        self.dropout = dropout

    def forward(self, x, edge_index):
        for i, conv in enumerate(self.convs[:-1]):
            x = conv(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[-1](x, edge_index)
        return x

    def get_embedding(self, x, edge_index):
        for conv in self.convs[:-1]:
            x = conv(x, edge_index)
            x = F.relu(x)
        return x

In [None]:
# ============================================================
# 5.3  GAT Model
# ============================================================
class GAT(nn.Module):
    """
    Graph Attention Network for node classification.

    Uses multi-head attention to learn neighbor importance.
    First layer uses 'heads' attention heads (outputs are concatenated).
    Last layer uses 1 head (output is averaged for stability).
    """
    def __init__(self, in_channels, hidden_channels, out_channels,
                 num_layers=2, heads=4, dropout=0.5):
        super().__init__()
        self.convs = nn.ModuleList()
        # First layer: multi-head attention, output concatenated
        self.convs.append(GATConv(in_channels, hidden_channels, heads=heads, dropout=dropout))
        for _ in range(num_layers - 2):
            self.convs.append(GATConv(hidden_channels * heads, hidden_channels, heads=heads, dropout=dropout))
        # Last layer: single head, average
        self.convs.append(GATConv(hidden_channels * heads, out_channels, heads=1, concat=False, dropout=dropout))
        self.dropout = dropout

    def forward(self, x, edge_index):
        for i, conv in enumerate(self.convs[:-1]):
            x = conv(x, edge_index)
            x = F.elu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[-1](x, edge_index)
        return x

    def get_embedding(self, x, edge_index):
        for conv in self.convs[:-1]:
            x = conv(x, edge_index)
            x = F.elu(x)
        return x

## 5.3 Training Utilities

In [6]:
# ============================================================
# 5.4  GNN Training utilities
# ============================================================
def train_gnn_epoch(model, data, optimizer, criterion, mask):
    """Train one epoch for node classification."""
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = criterion(out[mask], data.y[mask])
    loss.backward()
    optimizer.step()
    return loss.item()

@torch.no_grad()
def eval_gnn(model, data, mask):
    """Evaluate node classification performance."""
    model.eval()
    out = model(data.x, data.edge_index)
    pred = out[mask].argmax(dim=1)
    y_true = data.y[mask].cpu().numpy()
    y_pred = pred.cpu().numpy()

    acc = accuracy_score(y_true, y_pred)
    f1_macro = f1_score(y_true, y_pred, average='macro', zero_division=0)
    f1_micro = f1_score(y_true, y_pred, average='micro', zero_division=0)

    return {'accuracy': acc, 'f1_macro': f1_macro, 'f1_micro': f1_micro}

def train_gnn_full(model, data, epochs=200, lr=0.01, weight_decay=5e-4,
                   patience=20, verbose=True):
    """
    Full training loop with early stopping.

    Returns the best model state (by validation F1) and training history.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = nn.CrossEntropyLoss()

    best_val_f1 = 0
    best_state = None
    patience_counter = 0
    history = {'train_loss': [], 'val_acc': [], 'val_f1': []}

    for epoch in range(1, epochs + 1):
        loss = train_gnn_epoch(model, data, optimizer, criterion, data.train_mask)
        val_metrics = eval_gnn(model, data, data.val_mask)

        history['train_loss'].append(loss)
        history['val_acc'].append(val_metrics['accuracy'])
        history['val_f1'].append(val_metrics['f1_macro'])

        if val_metrics['f1_macro'] > best_val_f1:
            best_val_f1 = val_metrics['f1_macro']
            best_state = {k: v.clone() for k, v in model.state_dict().items()}
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            if verbose:
                print(f"  Early stopping at epoch {epoch}")
            break

        if verbose and epoch % 50 == 0:
            print(f"  Epoch {epoch:3d}: loss={loss:.4f}, val_acc={val_metrics['accuracy']:.4f}, val_f1={val_metrics['f1_macro']:.4f}")

    # Restore best model
    if best_state is not None:
        model.load_state_dict(best_state)

    # Final test evaluation
    test_metrics = eval_gnn(model, data, data.test_mask)

    return model, history, test_metrics

In [None]:
# ============================================================
# 5.5  Prepare PyG Data objects
# ============================================================
def prepare_pyg_data(G, X_features, node_labels, labeled_nodes,
                     node_to_idx, encoded_labels, train_ratio=0.6,
                     val_ratio=0.2, seed=42):
    """
    Convert a NetworkX graph + features + labels into a PyG Data object
    with train/val/test masks.
    """
    nodes_sorted = sorted(G.nodes())
    n = len(nodes_sorted)

    # Edge index
    edges = list(G.edges())
    src = [node_to_idx[u] for u, v in edges]
    dst = [node_to_idx[v] for u, v in edges]
    # Make undirected
    edge_index = torch.tensor([src + dst, dst + src], dtype=torch.long)

    # Node features
    x = torch.tensor(X_features, dtype=torch.float)

    # Labels (full array, -1 for unlabeled)
    y = torch.full((n,), -1, dtype=torch.long)
    labeled_indices = [node_to_idx[node] for node in labeled_nodes]
    for idx, label in zip(labeled_indices, encoded_labels):
        y[idx] = label

    # Create masks
    labeled_indices = np.array(labeled_indices)
    labels_for_split = encoded_labels

    idx_train, idx_temp, y_train, y_temp = train_test_split(
        labeled_indices, labels_for_split,
        train_size=train_ratio, stratify=labels_for_split, random_state=seed
    )
    val_size = val_ratio / (1 - train_ratio)
    idx_val, idx_test, _, _ = train_test_split(
        idx_temp, y_temp,
        train_size=val_size, stratify=y_temp, random_state=seed
    )

    train_mask = torch.zeros(n, dtype=torch.bool)
    val_mask = torch.zeros(n, dtype=torch.bool)
    test_mask = torch.zeros(n, dtype=torch.bool)
    train_mask[idx_train] = True
    val_mask[idx_val] = True
    test_mask[idx_test] = True

    data = Data(x=x, edge_index=edge_index, y=y,
                train_mask=train_mask, val_mask=val_mask, test_mask=test_mask)

    print(f"PyG Data: {data}")
    print(f"  Train: {train_mask.sum().item():,}, Val: {val_mask.sum().item():,}, Test: {test_mask.sum().item():,}")
    print(f"  Num classes: {len(set(encoded_labels))}")

    return data

# --- HepTh PyG Data ---
print("Preparing HepTh PyG data...")
data_hepth = prepare_pyg_data(
    G_hepth_lcc, X_hepth, hepth_node_labels, labeled_nodes_hepth,
    hepth_node_to_idx, encoded_labels_hepth
)
data_hepth = data_hepth.to(DEVICE)

# --- Amazon PyG Data ---
print("\nPreparing Amazon PyG data...")
data_amazon = prepare_pyg_data(
    G_amazon, X_amazon, amazon_node_labels_filtered, labeled_nodes_amazon,
    amazon_node_to_idx, encoded_labels_amazon
)
data_amazon = data_amazon.to(DEVICE)

---
# Part 6: Node Classification Experiments

We now evaluate all embedding methods on the node classification task.

**Protocol:**
1. **Shallow + Spectral embeddings → Logistic Regression**: Use the learned embeddings as
   features and train a downstream classifier. This is the standard evaluation protocol
   for unsupervised embeddings.
2. **GNNs → End-to-end**: Train the GNN directly for classification (supervised).

This comparison is fair because we report the same metrics (accuracy, macro-F1) on the
same test set for all methods.

## 6.1 Embedding-based Classification (Logistic Regression)

In [None]:
# ============================================================
# 6.1  Evaluate embeddings with Logistic Regression
# ============================================================
def evaluate_embedding_classification(embeddings, labeled_nodes, encoded_labels,
                                      name="Embedding", test_size=0.2, val_size=0.2):
    """
    Evaluate node embeddings for classification using Logistic Regression.

    Returns dict with accuracy, f1_macro, f1_micro on the test set.
    """
    # Filter to nodes present in embedding
    valid = [(n, l) for n, l in zip(labeled_nodes, encoded_labels) if n in embeddings]
    if len(valid) == 0:
        print(f"  {name}: No valid nodes found!")
        return None

    nodes, labels = zip(*valid)
    X = np.array([embeddings[n] for n in nodes])
    y = np.array(labels)

    # Split
    X_train, X_temp, y_train, y_temp = train_test_split(
        X, y, test_size=(test_size + val_size), stratify=y, random_state=SEED
    )
    X_val, X_test, y_val, y_test = train_test_split(
        X_temp, y_temp, test_size=test_size/(test_size+val_size),
        stratify=y_temp, random_state=SEED
    )

    # Train Logistic Regression
    clf = LogisticRegression(max_iter=1000, solver='lbfgs', multi_class='multinomial', random_state=SEED)
    clf.fit(X_train, y_train)

    y_pred = clf.predict(X_test)

    results = {
        'accuracy': accuracy_score(y_test, y_pred),
        'f1_macro': f1_score(y_test, y_pred, average='macro', zero_division=0),
        'f1_micro': f1_score(y_test, y_pred, average='micro', zero_division=0),
        'n_test': len(y_test),
    }

    print(f"  {name}: acc={results['accuracy']:.4f}, F1-macro={results['f1_macro']:.4f}, F1-micro={results['f1_micro']:.4f} (n={results['n_test']})")
    return results

In [None]:
# ============================================================
# 6.2  Node Classification — HepTh (all methods)
# ============================================================
print("=" * 70)
print("NODE CLASSIFICATION — HepTh")
print("=" * 70)

hepth_nc_results = {}

# --- Structural features baseline ---
struct_emb_hepth = {n: X_hepth[hepth_node_to_idx[n]] for n in hepth_nodes_sorted}
hepth_nc_results['Structural Features'] = evaluate_embedding_classification(
    struct_emb_hepth, labeled_nodes_hepth, encoded_labels_hepth, "Structural Features"
)

# --- DeepWalk ---
hepth_nc_results['DeepWalk'] = evaluate_embedding_classification(
    hepth_dw_emb, labeled_nodes_hepth, encoded_labels_hepth, "DeepWalk"
)

# --- Node2Vec variants ---
for config_name, emb in hepth_n2v_embs.items():
    name = f"Node2Vec ({config_name})"
    hepth_nc_results[name] = evaluate_embedding_classification(
        emb, labeled_nodes_hepth, encoded_labels_hepth, name
    )

# --- Spectral ---
hepth_nc_results['Spectral (Laplacian)'] = evaluate_embedding_classification(
    hepth_spectral_emb, labeled_nodes_hepth, encoded_labels_hepth, "Spectral (Laplacian)"
)

In [None]:
# ============================================================
# 6.3  Node Classification — HepTh — GNNs
# ============================================================
num_classes_hepth = len(set(encoded_labels_hepth))
in_channels_hepth = data_hepth.x.shape[1]

print("\n--- GCN ---")
model_gcn_hepth = GCN(in_channels_hepth, 128, num_classes_hepth, num_layers=2, dropout=0.5).to(DEVICE)
model_gcn_hepth, hist_gcn, test_gcn = train_gnn_full(
    model_gcn_hepth, data_hepth, epochs=300, lr=0.01, patience=30
)
hepth_nc_results['GCN'] = test_gcn
print(f"  GCN: acc={test_gcn['accuracy']:.4f}, F1-macro={test_gcn['f1_macro']:.4f}")

print("\n--- GraphSAGE ---")
model_sage_hepth = GraphSAGE(in_channels_hepth, 128, num_classes_hepth, num_layers=2, dropout=0.5).to(DEVICE)
model_sage_hepth, hist_sage, test_sage = train_gnn_full(
    model_sage_hepth, data_hepth, epochs=300, lr=0.01, patience=30
)
hepth_nc_results['GraphSAGE'] = test_sage
print(f"  GraphSAGE: acc={test_sage['accuracy']:.4f}, F1-macro={test_sage['f1_macro']:.4f}")

print("\n--- GAT ---")
model_gat_hepth = GAT(in_channels_hepth, 32, num_classes_hepth, num_layers=2, heads=4, dropout=0.5).to(DEVICE)
model_gat_hepth, hist_gat, test_gat = train_gnn_full(
    model_gat_hepth, data_hepth, epochs=300, lr=0.005, patience=30
)
hepth_nc_results['GAT'] = test_gat
print(f"  GAT: acc={test_gat['accuracy']:.4f}, F1-macro={test_gat['f1_macro']:.4f}")

In [None]:
# ============================================================
# 6.4  Node Classification — Amazon (all methods)
# ============================================================
print("=" * 70)
print("NODE CLASSIFICATION — Amazon")
print("=" * 70)

amazon_nc_results = {}

# --- Structural features baseline ---
struct_emb_amazon = {n: X_amazon[amazon_node_to_idx[n]] for n in amazon_nodes_sorted}
amazon_nc_results['Structural Features'] = evaluate_embedding_classification(
    struct_emb_amazon, labeled_nodes_amazon, encoded_labels_amazon, "Structural Features"
)

# --- DeepWalk ---
amazon_nc_results['DeepWalk'] = evaluate_embedding_classification(
    amazon_dw_emb, labeled_nodes_amazon, encoded_labels_amazon, "DeepWalk"
)

# --- Node2Vec variants ---
for config_name, emb in amazon_n2v_embs.items():
    name = f"Node2Vec ({config_name})"
    amazon_nc_results[name] = evaluate_embedding_classification(
        emb, labeled_nodes_amazon, encoded_labels_amazon, name
    )

# --- Spectral (subgraph only) ---
# Only evaluate nodes in the subgraph
sub_labeled = [n for n in labeled_nodes_amazon if n in amazon_spectral_emb]
sub_labels = [encoded_labels_amazon[labeled_nodes_amazon.index(n)] for n in sub_labeled]
if len(sub_labeled) > 100:
    amazon_nc_results['Spectral (subgraph)'] = evaluate_embedding_classification(
        amazon_spectral_emb, sub_labeled, sub_labels, "Spectral (subgraph)"
    )
else:
    print("  Spectral: too few labeled nodes in subgraph, skipping")

In [None]:
# ============================================================
# 6.5  Node Classification — Amazon — GNNs
# ============================================================
num_classes_amazon = len(set(encoded_labels_amazon))
in_channels_amazon = data_amazon.x.shape[1]

# For the large Amazon graph, we use GraphSAGE with mini-batch training
# GCN and GAT are run on a subgraph for comparison

print("\n--- GraphSAGE (full graph, mini-batch) ---")
model_sage_amazon = GraphSAGE(in_channels_amazon, 128, num_classes_amazon, num_layers=2, dropout=0.5).to(DEVICE)

# Mini-batch training with NeighborLoader
train_loader = NeighborLoader(
    data_amazon.cpu(),
    num_neighbors=[15, 10],
    batch_size=1024,
    input_nodes=data_amazon.train_mask.cpu(),
    shuffle=True,
)

optimizer = torch.optim.Adam(model_sage_amazon.parameters(), lr=0.01, weight_decay=5e-4)
criterion = nn.CrossEntropyLoss()

best_val_f1 = 0
best_state = None

for epoch in range(1, 101):
    model_sage_amazon.train()
    total_loss = 0
    for batch in train_loader:
        batch = batch.to(DEVICE)
        optimizer.zero_grad()
        out = model_sage_amazon(batch.x, batch.edge_index)
        # Only compute loss on seed nodes (first batch_size nodes)
        mask = batch.train_mask[:batch.batch_size] if hasattr(batch, 'train_mask') else slice(None, batch.batch_size)
        loss = criterion(out[:batch.batch_size], batch.y[:batch.batch_size])
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Validate (full-batch on CPU if needed)
    if epoch % 10 == 0:
        val_metrics = eval_gnn(model_sage_amazon, data_amazon.to(DEVICE), data_amazon.val_mask.to(DEVICE))
        if val_metrics['f1_macro'] > best_val_f1:
            best_val_f1 = val_metrics['f1_macro']
            best_state = {k: v.clone() for k, v in model_sage_amazon.state_dict().items()}
        print(f"  Epoch {epoch}: loss={total_loss/len(train_loader):.4f}, val_F1={val_metrics['f1_macro']:.4f}")

if best_state:
    model_sage_amazon.load_state_dict(best_state)
test_sage_amazon = eval_gnn(model_sage_amazon, data_amazon.to(DEVICE), data_amazon.test_mask.to(DEVICE))
amazon_nc_results['GraphSAGE'] = test_sage_amazon
print(f"  GraphSAGE: acc={test_sage_amazon['accuracy']:.4f}, F1-macro={test_sage_amazon['f1_macro']:.4f}")

In [None]:
# ============================================================
# 6.6  Amazon GCN (on subgraph for comparison)
# ============================================================
# Prepare subgraph PyG data
sub_nodes = sorted(G_amazon_sub.nodes())
sub_node_to_idx = {n: i for i, n in enumerate(sub_nodes)}

# Build subgraph features
X_sub = np.array([X_amazon[amazon_node_to_idx[n]] for n in sub_nodes])

# Filter labeled nodes to subgraph
sub_labeled_nodes = [n for n in labeled_nodes_amazon if n in sub_node_to_idx]
sub_encoded = [encoded_labels_amazon[labeled_nodes_amazon.index(n)] for n in sub_labeled_nodes]

if len(sub_labeled_nodes) > 200:
    data_amazon_sub = prepare_pyg_data(
        G_amazon_sub, X_sub,
        {n: amazon_node_labels_filtered[n] for n in sub_labeled_nodes if n in amazon_node_labels_filtered},
        sub_labeled_nodes, sub_node_to_idx, sub_encoded
    ).to(DEVICE)

    num_classes_sub = len(set(sub_encoded))

    print("\n--- GCN (Amazon subgraph) ---")
    model_gcn_amazon = GCN(X_sub.shape[1], 128, num_classes_sub, num_layers=2, dropout=0.5).to(DEVICE)
    model_gcn_amazon, _, test_gcn_amazon = train_gnn_full(
        model_gcn_amazon, data_amazon_sub, epochs=300, lr=0.01, patience=30
    )
    amazon_nc_results['GCN (subgraph)'] = test_gcn_amazon
    print(f"  GCN (subgraph): acc={test_gcn_amazon['accuracy']:.4f}, F1-macro={test_gcn_amazon['f1_macro']:.4f}")

    print("\n--- GAT (Amazon subgraph) ---")
    model_gat_amazon = GAT(X_sub.shape[1], 32, num_classes_sub, num_layers=2, heads=4, dropout=0.5).to(DEVICE)
    model_gat_amazon, _, test_gat_amazon = train_gnn_full(
        model_gat_amazon, data_amazon_sub, epochs=300, lr=0.005, patience=30
    )
    amazon_nc_results['GAT (subgraph)'] = test_gat_amazon
    print(f"  GAT (subgraph): acc={test_gat_amazon['accuracy']:.4f}, F1-macro={test_gat_amazon['f1_macro']:.4f}")
else:
    print("  Too few labeled nodes in Amazon subgraph for GCN/GAT training")

---
# Part 7: Link Prediction Experiments

Link prediction evaluates whether an embedding can determine if two nodes should be connected.

**Protocol:**
1. **Edge split**: Remove a fraction of edges as positive test examples. Sample an equal
   number of non-edges as negative test examples.
2. **For shallow/spectral embeddings**: Compute edge scores using dot product, cosine
   similarity, or Hadamard product of node embeddings → Logistic Regression.
3. **For GNNs**: Train a link-prediction-specific model or use embedding similarity.

**Metrics**: AUC-ROC and Average Precision (AP).

In [None]:
# ============================================================
# 7.1  Link prediction utilities
# ============================================================
def prepare_link_prediction_data(G, test_ratio=0.1, val_ratio=0.05, seed=42):
    """
    Split edges into train/val/test for link prediction.

    Returns:
    - G_train: graph with test+val edges removed
    - pos_val_edges, neg_val_edges
    - pos_test_edges, neg_test_edges
    """
    rng = np.random.default_rng(seed)
    edges = list(G.edges())
    rng.shuffle(edges)

    n_test = int(len(edges) * test_ratio)
    n_val = int(len(edges) * val_ratio)

    test_edges = edges[:n_test]
    val_edges = edges[n_test:n_test + n_val]
    train_edges = edges[n_test + n_val:]

    # Build training graph
    G_train = nx.Graph()
    G_train.add_nodes_from(G.nodes())
    G_train.add_edges_from(train_edges)

    # Ensure connectivity of training graph (add back critical edges)
    # For simplicity, just use the edges as-is

    # Sample negative edges
    nodes = list(G.nodes())
    non_edges_set = set()
    existing_edges_set = set(G.edges()) | set((v, u) for u, v in G.edges())

    while len(non_edges_set) < n_test + n_val:
        u = nodes[rng.integers(len(nodes))]
        v = nodes[rng.integers(len(nodes))]
        if u != v and (u, v) not in existing_edges_set and (u, v) not in non_edges_set:
            non_edges_set.add((u, v))

    non_edges = list(non_edges_set)
    neg_test_edges = non_edges[:n_test]
    neg_val_edges = non_edges[n_test:n_test + n_val]

    print(f"Link prediction split:")
    print(f"  Train edges: {len(train_edges):,}")
    print(f"  Val edges: {len(val_edges):,} pos + {len(neg_val_edges):,} neg")
    print(f"  Test edges: {len(test_edges):,} pos + {len(neg_test_edges):,} neg")

    return G_train, val_edges, neg_val_edges, test_edges, neg_test_edges

def evaluate_link_prediction(embeddings, pos_edges, neg_edges, method='dot', name=""):
    """
    Evaluate link prediction using embedding similarity.

    Methods: 'dot' (dot product), 'cosine', 'hadamard' (+ LR)
    """
    def get_scores(edges, embeddings, method):
        scores = []
        valid = 0
        for u, v in edges:
            if u in embeddings and v in embeddings:
                eu, ev = embeddings[u], embeddings[v]
                if method == 'dot':
                    scores.append(np.dot(eu, ev))
                elif method == 'cosine':
                    norm = np.linalg.norm(eu) * np.linalg.norm(ev)
                    scores.append(np.dot(eu, ev) / max(norm, 1e-10))
                valid += 1
            else:
                scores.append(0.0)
        return np.array(scores), valid

    pos_scores, n_pos = get_scores(pos_edges, embeddings, method)
    neg_scores, n_neg = get_scores(neg_edges, embeddings, method)

    y_true = np.concatenate([np.ones(len(pos_scores)), np.zeros(len(neg_scores))])
    y_scores = np.concatenate([pos_scores, neg_scores])

    auc = roc_auc_score(y_true, y_scores)
    ap = average_precision_score(y_true, y_scores)

    print(f"  {name} ({method}): AUC={auc:.4f}, AP={ap:.4f} (pos={n_pos}, neg={n_neg})")
    return {'auc': auc, 'ap': ap, 'method': method}

In [None]:
# ============================================================
# 7.2  Link Prediction — HepTh
# ============================================================
print("=" * 70)
print("LINK PREDICTION — HepTh")
print("=" * 70)

# Prepare edge split
G_hepth_train, val_e, neg_val_e, test_e, neg_test_e = prepare_link_prediction_data(
    G_hepth_lcc, test_ratio=0.1, val_ratio=0.05
)

# Re-train embeddings on the training graph (no data leakage!)
print("\nRetraining embeddings on training graph...")

hepth_lp_dw, _ = deepwalk(G_hepth_train, dimensions=EMB_DIM, walk_length=40, num_walks=10)
hepth_lp_n2v, _ = node2vec_embed(G_hepth_train, dimensions=EMB_DIM, walk_length=40, num_walks=10, p=1.0, q=0.5)
hepth_lp_spectral, _ = laplacian_eigenmaps(G_hepth_train, dimensions=EMB_DIM, normalized=True)

hepth_lp_results = {}

print("\nEvaluating on test edges:")
for emb_name, emb in [
    ("DeepWalk", hepth_lp_dw),
    ("Node2Vec (BFS)", hepth_lp_n2v),
    ("Spectral", hepth_lp_spectral),
]:
    for method in ['dot', 'cosine']:
        key = f"{emb_name} ({method})"
        hepth_lp_results[key] = evaluate_link_prediction(emb, test_e, neg_test_e, method=method, name=emb_name)

In [None]:
# ============================================================
# 7.3  Link Prediction — HepTh — GNN-based
# ============================================================
print("\n--- GNN Link Prediction (HepTh) ---")

class LinkPredictor(nn.Module):
    """Simple MLP link predictor on top of GNN embeddings."""
    def __init__(self, in_channels):
        super().__init__()
        self.lin1 = nn.Linear(2 * in_channels, in_channels)
        self.lin2 = nn.Linear(in_channels, 1)

    def forward(self, z_u, z_v):
        x = torch.cat([z_u, z_v], dim=-1)
        x = F.relu(self.lin1(x))
        return self.lin2(x).squeeze(-1)

class GNNLinkPredictor(nn.Module):
    """GNN encoder + Link predictor."""
    def __init__(self, encoder, predictor):
        super().__init__()
        self.encoder = encoder
        self.predictor = predictor

    def encode(self, x, edge_index):
        return self.encoder.get_embedding(x, edge_index)

    def predict(self, z, edge_label_index):
        return self.predictor(z[edge_label_index[0]], z[edge_label_index[1]])

# Prepare PyG data for link prediction (training graph)
hepth_train_nodes = sorted(G_hepth_train.nodes())
hepth_train_n2i = {n: i for i, n in enumerate(hepth_train_nodes)}

train_src = [hepth_train_n2i[u] for u, v in G_hepth_train.edges()]
train_dst = [hepth_train_n2i[v] for u, v in G_hepth_train.edges()]
train_edge_index = torch.tensor([train_src + train_dst, train_dst + train_src], dtype=torch.long)

# Features
X_hepth_train = np.array([X_hepth[hepth_node_to_idx[n]] if n in hepth_node_to_idx else np.zeros(X_hepth.shape[1])
                           for n in hepth_train_nodes])
x_train = torch.tensor(X_hepth_train, dtype=torch.float).to(DEVICE)
train_edge_index = train_edge_index.to(DEVICE)

# Positive and negative test edges as indices
def edges_to_index(edges, n2i):
    valid_edges = [(n2i[u], n2i[v]) for u, v in edges if u in n2i and v in n2i]
    if len(valid_edges) == 0:
        return torch.zeros((2, 0), dtype=torch.long)
    src, dst = zip(*valid_edges)
    return torch.tensor([src, dst], dtype=torch.long)

pos_test_idx = edges_to_index(test_e, hepth_train_n2i).to(DEVICE)
neg_test_idx = edges_to_index(neg_test_e, hepth_train_n2i).to(DEVICE)

# Train GNN for link prediction
encoder = GraphSAGE(X_hepth_train.shape[1], 128, 64, num_layers=2, dropout=0.3).to(DEVICE)
predictor = LinkPredictor(64).to(DEVICE)
lp_model = GNNLinkPredictor(encoder, predictor).to(DEVICE)

optimizer = torch.optim.Adam(lp_model.parameters(), lr=0.01)

# Training loop
for epoch in range(1, 201):
    lp_model.train()
    optimizer.zero_grad()

    z = lp_model.encode(x_train, train_edge_index)

    # Positive edges (sample from training)
    pos_idx = train_edge_index[:, :len(train_src)]  # use first half (undirected)
    pos_pred = lp_model.predict(z, pos_idx)

    # Negative sampling
    neg_idx = negative_sampling(train_edge_index, num_nodes=len(hepth_train_nodes),
                                num_neg_samples=pos_idx.shape[1])
    neg_pred = lp_model.predict(z, neg_idx)

    loss = F.binary_cross_entropy_with_logits(
        torch.cat([pos_pred, neg_pred]),
        torch.cat([torch.ones_like(pos_pred), torch.zeros_like(neg_pred)])
    )
    loss.backward()
    optimizer.step()

    if epoch % 50 == 0:
        print(f"  Epoch {epoch}: loss={loss.item():.4f}")

# Evaluate
lp_model.eval()
with torch.no_grad():
    z = lp_model.encode(x_train, train_edge_index)
    pos_scores = torch.sigmoid(lp_model.predict(z, pos_test_idx)).cpu().numpy()
    neg_scores = torch.sigmoid(lp_model.predict(z, neg_test_idx)).cpu().numpy()

y_true = np.concatenate([np.ones(len(pos_scores)), np.zeros(len(neg_scores))])
y_scores = np.concatenate([pos_scores, neg_scores])

gnn_auc = roc_auc_score(y_true, y_scores)
gnn_ap = average_precision_score(y_true, y_scores)
print(f"\n  GraphSAGE LP: AUC={gnn_auc:.4f}, AP={gnn_ap:.4f}")
hepth_lp_results['GraphSAGE (GNN)'] = {'auc': gnn_auc, 'ap': gnn_ap}

In [None]:
# ============================================================
# 7.4  Link Prediction — Amazon
# ============================================================
print("=" * 70)
print("LINK PREDICTION — Amazon")
print("=" * 70)

G_amazon_train, amz_val_e, amz_neg_val_e, amz_test_e, amz_neg_test_e = prepare_link_prediction_data(
    G_amazon, test_ratio=0.05, val_ratio=0.02  # smaller ratios for large graph
)

# Re-train embeddings on training graph
print("\nRetraining DeepWalk on Amazon training graph...")
amazon_lp_dw, _ = deepwalk(G_amazon_train, dimensions=EMB_DIM, walk_length=40, num_walks=10)

print("\nRetraining Node2Vec (BFS-like) on Amazon training graph...")
amazon_lp_n2v, _ = node2vec_embed(G_amazon_train, dimensions=EMB_DIM, walk_length=40, num_walks=10, p=1.0, q=0.5)

amazon_lp_results = {}

print("\nEvaluating on test edges:")
for emb_name, emb in [
    ("DeepWalk", amazon_lp_dw),
    ("Node2Vec (BFS)", amazon_lp_n2v),
]:
    for method in ['dot', 'cosine']:
        key = f"{emb_name} ({method})"
        amazon_lp_results[key] = evaluate_link_prediction(emb, amz_test_e, amz_neg_test_e, method=method, name=emb_name)

In [None]:
---
# Part 8: Results Comparison & Analysis

In [None]:
# ============================================================
# 8.1  Summary tables
# ============================================================
def results_to_df(results_dict, task="NC"):
    """Convert results dict to a pretty DataFrame."""
    rows = []
    for method, metrics in results_dict.items():
        if metrics is None:
            continue
        row = {'Method': method}
        row.update(metrics)
        rows.append(row)
    return pd.DataFrame(rows)

# --- Node Classification ---
print("=" * 70)
print("NODE CLASSIFICATION RESULTS")
print("=" * 70)

print("\n--- HepTh ---")
df_nc_hepth = results_to_df(hepth_nc_results)
if 'n_test' in df_nc_hepth.columns:
    df_nc_hepth = df_nc_hepth.drop(columns=['n_test'])
df_nc_hepth = df_nc_hepth.sort_values('f1_macro', ascending=False)
print(df_nc_hepth.to_string(index=False, float_format='%.4f'))

print("\n--- Amazon ---")
df_nc_amazon = results_to_df(amazon_nc_results)
if 'n_test' in df_nc_amazon.columns:
    df_nc_amazon = df_nc_amazon.drop(columns=['n_test'])
df_nc_amazon = df_nc_amazon.sort_values('f1_macro', ascending=False)
print(df_nc_amazon.to_string(index=False, float_format='%.4f'))

# --- Link Prediction ---
print("\n" + "=" * 70)
print("LINK PREDICTION RESULTS")
print("=" * 70)

print("\n--- HepTh ---")
df_lp_hepth = results_to_df(hepth_lp_results)
if 'method' in df_lp_hepth.columns:
    df_lp_hepth = df_lp_hepth.drop(columns=['method'])
df_lp_hepth = df_lp_hepth.sort_values('auc', ascending=False)
print(df_lp_hepth.to_string(index=False, float_format='%.4f'))

print("\n--- Amazon ---")
df_lp_amazon = results_to_df(amazon_lp_results)
if 'method' in df_lp_amazon.columns:
    df_lp_amazon = df_lp_amazon.drop(columns=['method'])
df_lp_amazon = df_lp_amazon.sort_values('auc', ascending=False)
print(df_lp_amazon.to_string(index=False, float_format='%.4f'))

In [None]:
# ============================================================
# 8.2  Visualization — Node Classification comparison bar chart
# ============================================================
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# HepTh
ax = axes[0]
if not df_nc_hepth.empty:
    methods = df_nc_hepth['Method'].values
    x = np.arange(len(methods))
    width = 0.35
    ax.bar(x - width/2, df_nc_hepth['accuracy'].values, width, label='Accuracy', color='steelblue', alpha=0.8)
    ax.bar(x + width/2, df_nc_hepth['f1_macro'].values, width, label='F1-macro', color='coral', alpha=0.8)
    ax.set_xticks(x)
    ax.set_xticklabels(methods, rotation=45, ha='right', fontsize=9)
    ax.set_ylabel('Score')
    ax.set_title('Node Classification — HepTh', fontweight='bold')
    ax.legend()
    ax.set_ylim(0, 1.05)

# Amazon
ax = axes[1]
if not df_nc_amazon.empty:
    methods = df_nc_amazon['Method'].values
    x = np.arange(len(methods))
    ax.bar(x - width/2, df_nc_amazon['accuracy'].values, width, label='Accuracy', color='steelblue', alpha=0.8)
    ax.bar(x + width/2, df_nc_amazon['f1_macro'].values, width, label='F1-macro', color='coral', alpha=0.8)
    ax.set_xticks(x)
    ax.set_xticklabels(methods, rotation=45, ha='right', fontsize=9)
    ax.set_ylabel('Score')
    ax.set_title('Node Classification — Amazon', fontweight='bold')
    ax.legend()
    ax.set_ylim(0, 1.05)

plt.tight_layout()
plt.savefig("node_classification_comparison.png", dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# ============================================================
# 8.3  Visualization — Link Prediction comparison
# ============================================================
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# HepTh LP
ax = axes[0]
if not df_lp_hepth.empty:
    methods = df_lp_hepth['Method'].values
    x = np.arange(len(methods))
    width = 0.35
    ax.bar(x - width/2, df_lp_hepth['auc'].values, width, label='AUC-ROC', color='steelblue', alpha=0.8)
    ax.bar(x + width/2, df_lp_hepth['ap'].values, width, label='Avg Precision', color='coral', alpha=0.8)
    ax.set_xticks(x)
    ax.set_xticklabels(methods, rotation=45, ha='right', fontsize=9)
    ax.set_ylabel('Score')
    ax.set_title('Link Prediction — HepTh', fontweight='bold')
    ax.legend()
    ax.set_ylim(0, 1.05)

# Amazon LP
ax = axes[1]
if not df_lp_amazon.empty:
    methods = df_lp_amazon['Method'].values
    x = np.arange(len(methods))
    ax.bar(x - width/2, df_lp_amazon['auc'].values, width, label='AUC-ROC', color='steelblue', alpha=0.8)
    ax.bar(x + width/2, df_lp_amazon['ap'].values, width, label='Avg Precision', color='coral', alpha=0.8)
    ax.set_xticks(x)
    ax.set_xticklabels(methods, rotation=45, ha='right', fontsize=9)
    ax.set_ylabel('Score')
    ax.set_title('Link Prediction — Amazon', fontweight='bold')
    ax.legend()
    ax.set_ylim(0, 1.05)

plt.tight_layout()
plt.savefig("link_prediction_comparison.png", dpi=150, bbox_inches='tight')
plt.show()

---
# Part 9: Embedding Visualization

We use UMAP to project high-dimensional embeddings to 2D, colored by class label.
This gives an intuitive view of how well each method separates communities.

In [None]:
# ============================================================
# 9.1  UMAP Visualization of embeddings
# ============================================================
def plot_embedding_umap(embeddings, labeled_nodes, labels, title="", ax=None, max_points=5000):
    """
    Project embeddings to 2D with UMAP and plot, colored by label.
    """
    # Filter to labeled nodes present in embeddings
    valid = [(n, l) for n, l in zip(labeled_nodes, labels) if n in embeddings]
    if len(valid) == 0:
        return

    # Subsample if too many
    rng = np.random.default_rng(SEED)
    if len(valid) > max_points:
        idx = rng.choice(len(valid), size=max_points, replace=False)
        valid = [valid[i] for i in idx]

    nodes, labs = zip(*valid)
    X = np.array([embeddings[n] for n in nodes])
    y = np.array(labs)

    # UMAP
    reducer = umap.UMAP(n_components=2, random_state=SEED, n_neighbors=15, min_dist=0.1)
    X_2d = reducer.fit_transform(X)

    if ax is None:
        fig, ax = plt.subplots(figsize=(8, 6))

    scatter = ax.scatter(X_2d[:, 0], X_2d[:, 1], c=y, cmap='tab20', s=5, alpha=0.6)
    ax.set_title(title, fontsize=11, fontweight='bold')
    ax.set_xticks([])
    ax.set_yticks([])

    return ax

In [None]:
# ============================================================
# 9.2  Visualize HepTh embeddings
# ============================================================
fig, axes = plt.subplots(2, 3, figsize=(18, 12))

embedding_configs = [
    ("DeepWalk", hepth_dw_emb),
    ("Node2Vec (BFS)", hepth_n2v_embs.get("BFS-like", {})),
    ("Node2Vec (DFS)", hepth_n2v_embs.get("DFS-like", {})),
    ("Spectral", hepth_spectral_emb),
    ("Node2Vec (Uniform)", hepth_n2v_embs.get("Uniform", {})),
]

for idx, (name, emb) in enumerate(embedding_configs):
    row, col = divmod(idx, 3)
    if emb:
        plot_embedding_umap(emb, labeled_nodes_hepth, encoded_labels_hepth,
                           title=f"HepTh — {name}", ax=axes[row][col])
    else:
        axes[row][col].text(0.5, 0.5, "N/A", ha='center', va='center')
        axes[row][col].set_title(f"HepTh — {name}")

# GNN embedding (extract from trained model)
model_sage_hepth.eval()
with torch.no_grad():
    gnn_emb_tensor = model_sage_hepth.get_embedding(data_hepth.x, data_hepth.edge_index)
    gnn_emb_np = gnn_emb_tensor.cpu().numpy()

gnn_emb_dict = {hepth_nodes_sorted[i]: gnn_emb_np[i] for i in range(len(hepth_nodes_sorted))}
plot_embedding_umap(gnn_emb_dict, labeled_nodes_hepth, encoded_labels_hepth,
                   title="HepTh — GraphSAGE (GNN)", ax=axes[1][2])

plt.suptitle("UMAP Projections of Node Embeddings (HepTh)", fontsize=14, fontweight='bold', y=1.01)
plt.tight_layout()
plt.savefig("hepth_embedding_umap.png", dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# ============================================================
# 9.3  Visualize Amazon embeddings (sampled)
# ============================================================
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

amazon_emb_configs = [
    ("DeepWalk", amazon_dw_emb),
    ("Node2Vec (BFS)", amazon_n2v_embs.get("BFS-like", {})),
    ("Spectral (subgraph)", amazon_spectral_emb),
]

for idx, (name, emb) in enumerate(amazon_emb_configs):
    if emb:
        plot_embedding_umap(emb, labeled_nodes_amazon, encoded_labels_amazon,
                           title=f"Amazon — {name}", ax=axes[idx], max_points=8000)
    else:
        axes[idx].text(0.5, 0.5, "N/A", ha='center', va='center')
        axes[idx].set_title(f"Amazon — {name}")

plt.suptitle("UMAP Projections of Node Embeddings (Amazon)", fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.savefig("amazon_embedding_umap.png", dpi=150, bbox_inches='tight')
plt.show()

---
# Part 10: GNN Training Analysis

In [None]:
# ============================================================
# 10.1  Plot GNN training curves (HepTh)
# ============================================================
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

histories = {
    'GCN': hist_gcn,
    'GraphSAGE': hist_sage,
    'GAT': hist_gat,
}

for idx, (name, hist) in enumerate(histories.items()):
    ax = axes[idx]
    epochs = range(1, len(hist['train_loss']) + 1)

    ax2 = ax.twinx()
    ax.plot(epochs, hist['train_loss'], 'b-', alpha=0.6, label='Train Loss')
    ax2.plot(epochs, hist['val_f1'], 'r-', alpha=0.8, label='Val F1-macro')
    ax2.plot(epochs, hist['val_acc'], 'g--', alpha=0.6, label='Val Accuracy')

    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss', color='blue')
    ax2.set_ylabel('Score', color='red')
    ax.set_title(f'{name} Training Curves (HepTh)', fontweight='bold')

    lines1, labels1 = ax.get_legend_handles_labels()
    lines2, labels2 = ax2.get_legend_handles_labels()
    ax.legend(lines1 + lines2, labels1 + labels2, loc='center right', fontsize=8)

plt.tight_layout()
plt.savefig("gnn_training_curves.png", dpi=150, bbox_inches='tight')
plt.show()