In [None]:
%load_ext autoreload
%autoreload 2 
%load_ext tensorboard

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torch.nn.functional import binary_cross_entropy_with_logits
import urllib.request
import os
import tarfile
import random
from icecream import ic
from torch.nn.functional import cross_entropy
from sklearn.model_selection import train_test_split

# Compute device used by every model/tensor below: Apple's Metal backend (MPS)
# when PyTorch reports it as available, otherwise the CPU.
# NOTE(review): CUDA is never selected here, even on a machine that has it.
device = "mps" if torch.backends.mps.is_available() else "cpu"

# Cora dataset

https://graphsandnetworks.com/the-cora-dataset/

The Cora dataset consists of 2708 scientific publications classified into one of seven classes. The citation network consists of 5429 links. Each publication in the dataset is described by a 0/1-valued word vector indicating the absence/presence of the corresponding word from the dictionary. The dictionary consists of 1433 unique words.

Load cora:

- `D_inv` is the inverse $D^{-1}$ of the diagonal degree matrix $D$, where $D_{v,v}=\deg(v)=|N(v)|$ is the number of neighbours of node $v$.
- `adj_bar` is $D^{-1}A$. It is reused in $H^{(k+1)}=D^{-1}AH^{(k)}$, i.e. it multiplies $H$ at every layer.
- `x`: the node features (the 0/1 word vector of each document).

In [None]:
# Load Cora dataset

# Download and extract if not present.
# The check is on the extracted "cora" directory, so re-running this cell
# does not download the archive again once extraction has succeeded.
if not os.path.exists("cora"):
    print("Downloading Cora dataset...")
    urllib.request.urlretrieve(
        "https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz", "cora.tgz"
    )
    with tarfile.open("cora.tgz", "r:gz") as tar:
        # NOTE(review): extractall() is called without an extraction filter, so
        # member paths inside the archive are trusted as-is. Consider passing
        # filter="data" on Python versions that support it.
        tar.extractall()

# Load node features and labels
# Every field is read as a string here; the id and word columns are converted
# to numeric types in a later cell.
content_path = "cora/cora.content"
content = np.genfromtxt(content_path, dtype=str)
content

In [None]:
# Shape of the raw string table read from cora.content.
content.shape

In [None]:
# Split the raw table column-wise: paper id | word indicators | class name.
node_ids = content[:, 0].astype(int)
class_strs = content[:, -1]
features = content[:, 1:-1].astype(np.float32)

# Paper ids are arbitrary integers; node_map sends each id to its row index.
N = len(node_ids)
node_map = dict(zip(node_ids, range(N)))

# Feature matrix, one row per paper: already in shape (N, 1433)
X = features

# The 7 class names, numbered 0..6 in the order listed below.
class_map = {
    name: code
    for code, name in enumerate(
        [
            "Case_Based",
            "Genetic_Algorithms",
            "Neural_Networks",
            "Probabilistic_Methods",
            "Reinforcement_Learning",
            "Rule_Learning",
            "Theory",
        ]
    )
}
labels = np.array([class_map[name] for name in class_strs])
labels.shape

In [None]:
# Spot check: row index that node_map assigns to paper id 35.
node_map[35]

In [None]:
# Read the citation list (directed) and translate paper ids to row indices.
cites_path = "cora/cora.cites"
cites = np.genfromtxt(cites_path, dtype=int)

# A citation is kept only when both of its endpoints are known papers.
edges = [
    [node_map[src], node_map[dst]]
    for src, dst in cites
    if src in node_map and dst in node_map
]
edges


In [None]:
# Adjacency matrix built from `edges`; no self-loops are added here.
# Only A[u, v] is set for each [u, v] pair: the symmetric assignment below is
# commented out, so A is directed as written (NOT symmetrised).
A = np.zeros((N, N), dtype=np.float32)
for u, v in edges:
    A[u, v] = 1.0
    # A[v, u] = 1.0  # Make undirected
A

# try smoothing

In [None]:
# Adjacency with self-loops (A + I) and the vector of its row sums.
I = np.identity(N)
A_tilde = A + I
D = A_tilde.sum(axis=1)
D


In [None]:
# Symmetric normalisation: A_norm = D^{-1/2} @ A_tilde @ D^{-1/2}.
# Element-wise power on the row-sum vector D.
D_inv_sqrt = np.power(D, -0.5)
# A row sum of 0 would produce inf above; such entries are replaced by 0.
D_inv_sqrt = np.where(np.isinf(D_inv_sqrt), 0.0, D_inv_sqrt)
# Promote the vector to a diagonal matrix so it can be applied with `@`.
D_inv_sqrt = np.diag(D_inv_sqrt)
A_norm = D_inv_sqrt @ A_tilde @ D_inv_sqrt

A_norm

   A_tilde = A + I  # Add identity matrix

   D_tilde = degree_matrix(A_tilde)  # Now all nodes have degree ≥ 1
   
   D_inv = D_tilde^{-1}

In [None]:
# see page 50 of 03-GNN1.pdf
# Normalization D^-1 A: each row of A is divided by that row's sum.

# Row sums of A (A has no self-loops added, so a row may sum to 0).
degrees = np.sum(A, axis=1)
# Avoid division by zero: rows that sum to 0 are given degree 1.
degrees[degrees == 0] = 1.0
# degrees = degrees + 1
degrees
# Diagonal matrix holding 1 / degree for every node.
D_inv = np.diag(1.0 / (degrees))
D_inv


In [None]:
# adj_bar = D^{-1} A; it will be used in $$D^{-1}AH^{(k)}$$

adj_bar = D_inv @ A
# Convert the NumPy result into a torch float tensor for the models below.
adj_bar = torch.FloatTensor(adj_bar)
adj_bar.shape


In [None]:
# Node features as a torch float tensor.
x = torch.FloatTensor(X)
# Edge list as an array with one [src_idx, dst_idx] row per kept citation.
pos_edges = np.array(edges)  # Directed for positive samples
pos_edges

# x, adj_bar, pos_edges, labels, N = load_cora()
# Print the shapes of everything loaded so far, plus the set of label values.
ic(x.shape, adj_bar.shape, pos_edges.shape, labels.shape, degrees.shape, N, set(labels))
# Number of input features per node (second dimension of x).
features_nb = x.shape[1]
ic(features_nb);


In [None]:
# The bare `degrees` expression is not the last line, so only the percentile is displayed.
degrees
# 80th percentile of the degree vector (the threshold used for the subgraph plot).
np.percentile(degrees, 80)

In [None]:
# Network visualization of a subset of the graph
import networkx as nx
from matplotlib.patches import Patch

# Class names, indexed by the integer codes stored in `labels`
class_names = [
    "Case_Based",
    "Genetic_Algorithms",
    "Neural_Networks",
    "Probabilistic_Methods",
    "Reinforcement_Learning",
    "Rule_Learning",
    "Theory",
]

# Create a NetworkX graph from a subset of nodes (for visualization purposes).
# Only high-degree nodes are kept so the drawing stays manageable: the threshold
# is the 80th percentile of `degrees`, i.e. roughly the top 20% of nodes.
high_degree_threshold = np.percentile(degrees, 80)  # Top 20% by degree

high_degree_nodes = np.where(degrees >= high_degree_threshold)[0]
# Set copy for O(1) membership tests; `x in ndarray` scans the whole array.
high_degree_set = set(high_degree_nodes.tolist())

print(
    f"Visualizing subgraph with {len(high_degree_nodes)} high-degree nodes (degree >= {high_degree_threshold:.0f})"
)

# Create subgraph (undirected nx.Graph: edge direction is dropped for drawing)
G = nx.Graph()
G.add_nodes_from(high_degree_nodes)

# Add only the edges whose two endpoints are both high-degree nodes
for u, v in pos_edges:
    if u in high_degree_set and v in high_degree_set:
        G.add_edge(u, v)

print(f"Subgraph has {G.number_of_nodes()} nodes and {G.number_of_edges()} edges")

# Create node colors based on class labels (one colour per class code)
node_colors = [labels[node] for node in G.nodes()]
colors = ["red", "blue", "green", "orange", "purple", "brown", "pink"]

plt.figure(figsize=(12, 10))
pos = nx.spring_layout(G, k=1, iterations=50)

nx.draw(
    G,
    pos,
    node_color=[colors[c] for c in node_colors],
    node_size=50,
    edge_color="gray",
    alpha=0.8,
    with_labels=False,
)

# Create legend: one patch per class, in the same order as `colors`

legend_elements = [
    Patch(facecolor=colors[i], label=class_names[i]) for i in range(len(class_names))
]
plt.legend(handles=legend_elements, loc="upper left", bbox_to_anchor=(1, 1))
plt.title(
    "Cora Citation Network (High-degree nodes subgraph)\nColored by Publication Class"
)
plt.axis("off")
plt.tight_layout()
plt.show()

In [None]:
def get_alias_edge(G, src, curr, p, q):
    """Return the Node2Vec-biased transition distribution out of ``curr``.

    Args:
        G: graph object supporting ``G[node]`` (iterating yields the node's
            neighbours) and ``G.has_edge(u, v)``.
        src: node visited immediately before ``curr``.
        curr: node the walk is currently standing on.
        p: return parameter; stepping back to ``src`` gets weight ``1 / p``,
            so a lower ``p`` makes backtracking more likely.
        q: in-out parameter; a neighbour that is not adjacent to ``src`` gets
            weight ``1 / q``, so a lower ``q`` favours moving away from ``src``.
            A neighbour that is adjacent to ``src`` gets weight ``1``.

    Returns:
        ``(neighbours, probabilities)``: two lists in the same order, the
        second being the weights divided by their sum.
    """
    candidates = list(G[curr])

    def _bias(nbr):
        # Evaluated lazily so 1/p and 1/q are only computed when needed.
        if nbr == src:
            return 1.0 / p
        if G.has_edge(nbr, src):
            return 1.0
        return 1.0 / q

    raw_weights = [_bias(nbr) for nbr in candidates]
    total = sum(raw_weights)
    return candidates, [w / total for w in raw_weights]


def node2vec_walk(G, start, walk_length, p, q):
    """Simulate one biased random walk that begins at ``start``.

    The first step is chosen uniformly among the neighbours of ``start``.
    Every later step is sampled from the distribution returned by
    ``get_alias_edge`` for the previous and current node. The walk ends when
    it holds ``walk_length`` nodes or reaches a node with no neighbours.

    Returns:
        The list of visited nodes, starting with ``start``.
    """
    path = [start]
    while len(path) < walk_length:
        here = path[-1]
        nbrs = list(G.neighbors(here))
        if not nbrs:
            # Dead end: return the (shorter) walk collected so far.
            break
        if len(path) == 1:
            step = random.choice(nbrs)
        else:
            options, weights = get_alias_edge(G, src=path[-2], curr=here, p=p, q=q)
            step = random.choices(options, weights=weights, k=1)[0]
        path.append(step)
    return path


def generate_random_walks(adj: np.ndarray, walk_length=10, num_walks=5, p=1, q=1):
    """Generate Node2Vec walks from every node of the graph described by ``adj``.

    Args:
        adj: adjacency matrix as a NumPy array; it is handed to
            ``nx.from_numpy_array`` with that function's default graph type.
        walk_length: maximum number of nodes per walk.
        num_walks: number of passes over the node set; each pass starts one
            walk from every node, visiting the nodes in a freshly shuffled order.
        p: Node2Vec return parameter, forwarded to ``node2vec_walk``.
        q: Node2Vec in-out parameter, forwarded to ``node2vec_walk``.

    Returns:
        A list of walks (each a list of nodes), ``num_walks`` per graph node.
    """
    # The annotation used to be `np.numarray`, which is not a NumPy attribute
    # and made the `def` statement itself raise AttributeError.
    G = nx.from_numpy_array(adj)
    walks = []
    for _ in range(num_walks):
        nodes = list(G.nodes())
        random.shuffle(nodes)
        walks.extend(node2vec_walk(G, node, walk_length, p=p, q=q) for node in nodes)
    return walks


In [None]:
# Node2Vec settings for this run.
p = 1  # return parameter
q = 0.2  # in-out parameter
walk_length = 10  # maximum number of nodes per walk
num_walks = 2  # walks started from each node
# NOTE(review): a later cell rebinds adj_bar to `adj_bar.to(device)`; re-running
# this cell afterwards with a non-CPU device would likely make `.numpy()` fail.
walks = generate_random_walks(
    adj=adj_bar.numpy(), walk_length=walk_length, num_walks=num_walks, p=p, q=q
)
len(walks)

# walks

In [None]:
# Show how many walks were generated and what the first one looks like.
ic(len(walks), walks[0]);

In [None]:
def gen_walk_pairs(walk, window_size, trace=False):
    """Build (target, context) pairs from one walk using a sliding window.

    For every position ``i`` in ``walk`` the context positions are those at
    most ``window_size`` steps before or after ``i``. A pair is emitted for
    each context position whose node differs from the target node, so
    repeated visits never produce ``(n, n)`` pairs.

    Args:
        walk: sequence of nodes.
        window_size: number of positions considered on each side of the target.
        trace: when True, print per-step debug output through ``ic``. The
            default is False so the per-pair debug calls are skipped entirely.

    Returns:
        A list of ``(target, context)`` tuples, in walk order.
    """
    walk_pairs = []
    for i, target in enumerate(walk):
        # Clamp the window to the bounds of the walk.
        start = max(0, i - window_size)
        end = min(len(walk), i + window_size + 1)
        if trace:
            ic(i, target, start, end)
        for j in range(start, end):
            if j == i:
                continue
            if trace:
                ic("pair", target, j, walk[j])
            if walk[j] != target:  # Avoid pairs of the same node
                walk_pairs.append((target, walk[j]))

    return walk_pairs


In [None]:
# enable() immediately followed by disable(): icecream output ends up switched
# off for the pair generation below.
ic.enable()
ic.disable()
# Collect the window pairs of every walk into one flat list.
pos_pairs = []
for walk in walks[:]:
    pairs = gen_walk_pairs(walk, window_size=2)
    pos_pairs.extend(pairs)

num_positive_samples = len(pos_pairs)
ic(num_positive_samples)

In [None]:
# Total number of positive (target, context) pairs, duplicates included.
len(pos_pairs)

In [None]:
# Negative samples (random non-walk pairs)
ic.enable()
# Set of distinct positive pairs, used for fast rejection below.
pos_pair_set = set(tuple(p) for p in pos_pairs)
ic(len(pos_pair_set))

# Rejection sampling: draw random node pairs until there are as many negatives
# as positive pairs, discarding self-pairs and pairs that occur in a walk window.
# NOTE(review): no random seed is set, so the sampled negatives differ between runs.
neg_pairs = []
num_negative_samples = 0
while num_negative_samples < num_positive_samples:
    u, v = np.random.randint(0, N, 2)
    if u != v and (u, v) not in pos_pair_set:
        neg_pairs.append([u, v])
        num_negative_samples += 1

ic(len(neg_pairs), neg_pairs[:5]);


## Model

In [None]:
# One message-passing layer: separate weights for neighbours (W) and for the node itself (B)
class GNNLayer(nn.Module):
    """Graph layer computing ``adj_bar @ h @ W + h @ B``.

    ``W`` transforms the aggregated neighbour features and ``B`` transforms
    the node's own features; both have shape ``(in_dim, out_dim)`` and are
    Xavier-uniform initialised. No bias and no activation are applied.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        # Allocate first, then initialise in the order W, B.
        self.W = nn.Parameter(torch.empty(in_dim, out_dim, dtype=torch.float32))
        self.B = nn.Parameter(torch.empty(in_dim, out_dim, dtype=torch.float32))
        for weight in (self.W, self.B):
            nn.init.xavier_uniform_(weight)

    def forward(self, h, adj_bar):
        """Return the updated node representations.

        Args:
            h: 2-D tensor of node features, one row per node.
            adj_bar: 2-D aggregation matrix applied to ``h`` from the left
                (the notebook passes D^{-1} A).
        """
        # torch.mm is the strict 2-D matrix product.
        from_neighbours = torch.mm(torch.mm(adj_bar, h), self.W)
        from_self = torch.mm(h, self.B)
        return from_neighbours + from_self


# Three stacked GNN layers producing node embeddings
class GNNEncoder(nn.Module):
    """Encoder: ``input_dim -> hidden_dim -> hidden_dim -> embedding_dim``.

    ReLU follows the first two layers; the third layer's output is returned
    as-is so the embeddings are not clipped at zero.
    """

    def __init__(self, input_dim, hidden_dim, embedding_dim):
        super().__init__()
        # Feature size -> hidden size
        self.layer1 = GNNLayer(input_dim, hidden_dim)
        self.layer2 = GNNLayer(hidden_dim, hidden_dim)
        # Hidden size -> embedding size
        self.layer3 = GNNLayer(hidden_dim, embedding_dim)
        self.activation = nn.ReLU()

    def forward(self, x, adj_bar):
        """Return one embedding row per node of ``x``."""
        hidden = self.activation(self.layer1(x, adj_bar))
        hidden = self.activation(self.layer2(hidden, adj_bar))
        # No activation on the last layer.
        return self.layer3(hidden, adj_bar)


In [None]:
# Smoke test: push the whole graph through an untrained encoder and check the output shape.
model = GNNEncoder(input_dim=features_nb, hidden_dim=32, embedding_dim=16).to(device)
z = model(x.to(device), adj_bar.to(device))
ic(
    z.shape
)  # Should be (N, 16), N being number of nodes. As expected, we have a 16-dim embedding for each node.

In [None]:
# Decoder: the similarity logit of a pair is the dot product of its two embeddings,
# as in node2vec-style models.
def decode(z, edges):
    """Return one dot-product logit per row of ``edges``.

    Args:
        z: 2-D tensor of node embeddings, indexed by node.
        edges: 2-D integer tensor; column 0 and column 1 hold the two node
            indices of each pair.
    """
    left = z[edges[:, 0]]
    right = z[edges[:, 1]]
    return (left * right).sum(dim=1)


# Sanity check of decode() on the citation edges, using the untrained embeddings `z`.
ic("There are ", pos_edges.shape[0], " edges in the graph. -> positive samples.")
ic(pos_edges.shape)
# Index tensors must be integer (long) tensors on the same device as `z`.
pos_edges_tensor = torch.LongTensor(pos_edges).to(device)
ic(decode(z, pos_edges_tensor).shape);  # Should be (num_pos,)

In [None]:
# Training: learn embeddings so that random-walk pairs score high and random pairs score low

# x, adj_bar, pos_edges, labels, N = load_cora()
# data has been loaded above
model = GNNEncoder(input_dim=features_nb, hidden_dim=32, embedding_dim=16).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.01)


# Use random-walk pairs as positives and the sampled non-walk pairs as negatives
num_positive_pairs = len(pos_pairs)
num_negative_pairs = len(neg_pairs)
ic(num_positive_pairs, num_negative_pairs)
neg_pairs_tensor = torch.LongTensor(neg_pairs).to(device)
pos_pairs_tensor = torch.LongTensor(pos_pairs).to(device)

x = x.to(device)
adj_bar = adj_bar.to(device)  # (where adj_bar = D^{-1} A)
# labels = labels.to(device)

# Targets: 1 for pos, 0 for neg. They never change between epochs, so they are
# built once, directly on the target device, instead of on every iteration.
pos_labels = torch.ones(num_positive_pairs, device=device)
neg_labels = torch.zeros(num_negative_pairs, device=device)

for epoch in range(500):
    model.train()
    optimizer.zero_grad()

    # Full-batch forward pass: embeddings for every node
    z = model(x, adj_bar)

    # for positive pairs, decode similarity should be high (1)
    pos_logits = decode(z, pos_pairs_tensor)

    # for negative pairs, decode similarity should be low (0)
    neg_logits = decode(z, neg_pairs_tensor)

    # Loss: sum of the mean BCE over positives and the mean BCE over negatives
    loss = binary_cross_entropy_with_logits(
        pos_logits, pos_labels
    ) + binary_cross_entropy_with_logits(neg_logits, neg_labels)

    loss.backward()
    optimizer.step()

    if epoch % 20 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

# return model, x, adj_bar, labels
ic(model, x, adj_bar, labels);


In [None]:
# Inference: compute embeddings without tracking gradients
def inference(model, x, adj_bar):
    """Run ``model`` in eval mode and return its output as a NumPy array.

    The model is switched to evaluation mode (and left in it), the forward
    pass runs under ``torch.no_grad()``, and the result is moved to the CPU
    before conversion.
    """
    model.eval()
    with torch.no_grad():
        return model(x, adj_bar).cpu().numpy()


In [None]:
from sklearn.manifold import TSNE


def visualize_tsne_embeddings(z, labels):
    """Scatter-plot a 2-D t-SNE projection of ``z``, one colour per label value.

    Args:
        z: array of embeddings, one row per node.
        labels: NumPy array of class codes aligned with the rows of ``z``
            (it is compared element-wise against each distinct value).
    """
    palette = ["red", "blue", "green", "orange", "purple", "brown", "pink"]
    projector = TSNE(n_components=2, random_state=42, perplexity=30, max_iter=1000)
    points = projector.fit_transform(z)

    plt.figure(figsize=(10, 8))
    # One scatter call per class so each class gets its own legend entry;
    # colours cycle if there are more classes than palette entries.
    for rank, cls in enumerate(np.unique(labels)):
        members = labels == cls
        plt.scatter(
            points[members, 0],
            points[members, 1],
            c=palette[rank % len(palette)],
            label=f"Class {cls}",
            alpha=0.7,
        )
    plt.legend()
    plt.title("2D Projection of Node Embeddings on Cora (t-SNE, colored by class)")
    plt.xlabel("t-SNE Component 1")
    plt.ylabel("t-SNE Component 2")
    plt.show()


# Embeddings of the link-prediction encoder trained above, projected to 2-D.
z = inference(model, x, adj_bar)
visualize_tsne_embeddings(z, labels)

In [None]:
from mpl_toolkits.mplot3d import Axes3D


def visualize_tsne_embeddings_3d(z, labels):
    """Scatter-plot a 3-D t-SNE projection of ``z``, one colour per label value.

    Args:
        z: array of embeddings, one row per node.
        labels: NumPy array of class codes aligned with the rows of ``z``.
    """
    palette = ["red", "blue", "green", "orange", "purple", "brown", "pink"]
    projector = TSNE(n_components=3, random_state=42, perplexity=30, max_iter=1000)
    points = projector.fit_transform(z)

    fig = plt.figure(figsize=(12, 10))
    ax = fig.add_subplot(111, projection="3d")

    # One scatter call per class so each class gets its own legend entry.
    for rank, cls in enumerate(np.unique(labels)):
        members = labels == cls
        ax.scatter(
            points[members, 0],
            points[members, 1],
            points[members, 2],
            c=palette[rank % len(palette)],
            label=f"Class {cls}",
            alpha=0.7,
            s=20,  # marker size
        )

    ax.legend()
    ax.set_title("3D Projection of Node Embeddings on Cora (t-SNE, colored by class)")
    ax.set_xlabel("t-SNE Component 1")
    ax.set_ylabel("t-SNE Component 2")
    ax.set_zlabel("t-SNE Component 3")

    # Initial camera position (elevation / azimuth, in degrees)
    ax.view_init(elev=20, azim=45)
    plt.show()


# Generate 3D t-SNE visualization of the same embeddings `z` plotted in 2-D above
visualize_tsne_embeddings_3d(z, labels)

In [None]:
import torch
from torch.utils.tensorboard import SummaryWriter
import os


def visualize_embeddings_tensorboard(z, labels, class_names, log_dir="runs/embeddings"):
    """Export embeddings to TensorBoard's embedding projector.

    Args:
        z: embeddings as a NumPy array (converted to a float tensor) or any
            other object, which is passed to ``add_embedding`` unchanged.
        labels: iterable of class codes, one per embedding row.
        class_names: sequence indexed by class code; the looked-up name is
            stored as the metadata of each point.
        log_dir: directory the TensorBoard event files are written to; it is
            created if missing.
    """
    # Create log directory
    os.makedirs(log_dir, exist_ok=True)

    # Prepare every input before opening the writer, so a bad label or array
    # fails before any event file is created.
    embeddings = torch.FloatTensor(z) if isinstance(z, np.ndarray) else z
    metadata = [class_names[label] for label in labels]

    # The context manager closes (and flushes) the writer even if
    # add_embedding raises.
    with SummaryWriter(log_dir) as writer:
        writer.add_embedding(embeddings, metadata=metadata, tag="Node_Embeddings")

    print("Embeddings saved to TensorBoard. Run the following command to view:")
    # NOTE(review): the "nbks/" prefix assumes the command is run from the
    # parent of the notebook directory — confirm for your layout.
    print(f"tensorboard --logdir=nbks/{log_dir}")
    print("Then open your browser to http://localhost:6006")
    print("Navigate to the 'Projector' tab to see the interactive 3D visualization")


# Generate TensorBoard visualization
# Class names indexed by the integer codes in `labels`
# (same list as the one defined in the network-visualization cell).
class_names = [
    "Case_Based",
    "Genetic_Algorithms",
    "Neural_Networks",
    "Probabilistic_Methods",
    "Reinforcement_Learning",
    "Rule_Learning",
    "Theory",
]

# Export the link-prediction embeddings `z` computed above.
visualize_embeddings_tensorboard(z, labels, class_names)

# uv run tensorboard --logdir nbks/runs/embeddings

# use classification instead of similarity as target


In [None]:
class GNNClassifier(nn.Module):
    """Three GNN layers followed by a linear classification head.

    The layer stack mirrors ``GNNEncoder`` (ReLU after the first two layers,
    none after the third). The head maps each embedding to ``num_classes``
    logits.
    """

    def __init__(self, input_dim, hidden_dim, embedding_dim, num_classes):
        super().__init__()
        self.layer1 = GNNLayer(input_dim, hidden_dim)
        self.layer2 = GNNLayer(hidden_dim, hidden_dim)
        self.layer3 = GNNLayer(hidden_dim, embedding_dim)
        self.classifier = nn.Linear(embedding_dim, num_classes)
        self.activation = nn.ReLU()

    def forward(self, x, adj_bar):
        """Return ``(embeddings, logits)`` for every node of ``x``."""
        hidden = self.activation(self.layer1(x, adj_bar))
        hidden = self.activation(self.layer2(hidden, adj_bar))
        embeddings = self.layer3(hidden, adj_bar)
        # Unnormalised class scores computed from the embeddings.
        return embeddings, self.classifier(embeddings)

In [None]:
# Number of distinct class codes present in `labels`.
nb_classes = len(set(labels.tolist()))
# Class codes as a long tensor on the compute device (the form cross_entropy is given below).
y = torch.LongTensor(labels).to(device)
y

In [None]:
# Smoke test: run an untrained classifier on the whole graph and check both output shapes.
model = GNNClassifier(
    input_dim=features_nb, hidden_dim=32, embedding_dim=16, num_classes=nb_classes
).to(device)
h, train_logits = model(x.to(device), adj_bar.to(device))
ic(h.shape, train_logits.shape)  # h: (N, 16), logits: (N, 7)

In [None]:
# Fresh classifier; sizes come from the loaded data rather than hard-coded literals.
model = GNNClassifier(
    input_dim=features_nb, hidden_dim=32, embedding_dim=16, num_classes=nb_classes
).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Split into train/val/test (60%/20%/20%)
idx = np.arange(N)
train_idx, test_idx = train_test_split(idx, test_size=0.4, random_state=42)
val_idx, test_idx = train_test_split(test_idx, test_size=0.5, random_state=42)
train_idx = torch.LongTensor(train_idx).to(device)
val_idx = torch.LongTensor(val_idx).to(device)
test_idx = torch.LongTensor(test_idx).to(device)

# Subset train data
# NOTE(review): slicing adj_bar keeps only edges between nodes of the same
# split; the rows were normalised on the full graph and are not re-normalised.
train_x = x[train_idx]
train_adj = adj_bar[train_idx][:, train_idx]  # Subgraph adjacency
train_y = y[train_idx]
train_labels = labels[train_idx.cpu().numpy()]

val_x = x[val_idx]
val_adj = adj_bar[val_idx][:, val_idx]  # Subgraph adjacency
val_y = y[val_idx]
val_labels = labels[val_idx.cpu().numpy()]


def _snapshot_weights(module):
    """Return a detached copy of the module's state dict.

    state_dict() hands back tensors that share storage with the live
    parameters, so keeping that dict would silently track later updates.
    """
    return {key: value.detach().clone() for key, value in module.state_dict().items()}


best_val_loss = float("inf")
# Start from the initial weights so `best_model` exists even if the
# validation loss never improves (e.g. it is NaN from the first epoch).
best_model = _snapshot_weights(model)
for epoch in range(200):
    model.train()
    optimizer.zero_grad()

    train_embeddings, train_logits = model(train_x, train_adj)  # Train on subgraph

    # Compute loss on training set
    train_loss = cross_entropy(train_logits, train_y)

    train_loss.backward()
    optimizer.step()

    # Validation on the validation subgraph (not on the full graph)
    model.eval()
    with torch.no_grad():
        val_embeddings, val_logits = model(val_x, val_adj)
        val_loss = cross_entropy(val_logits, val_y)

    # Keep the weights of the epoch with the lowest validation loss
    if val_loss.item() < best_val_loss:
        best_val_loss = val_loss.item()
        best_model = _snapshot_weights(model)

    if epoch % 20 == 0:
        print(
            f"Epoch {epoch}, Train Loss: {train_loss.item():.4f}, Val Loss: {val_loss.item():.4f}"
        )

# Load best model and generate embeddings for the full graph
model.load_state_dict(best_model)
model.eval()
with torch.no_grad():
    full_embeddings, _ = model(x, adj_bar)

In [None]:
# Training-subgraph embeddings from the last training epoch (detach() because
# they were produced with gradient tracking on).
visualize_tsne_embeddings(train_embeddings.detach().cpu().numpy(), train_labels)

In [None]:
# Validation-subgraph embeddings from the last epoch's validation pass.
visualize_tsne_embeddings(val_embeddings.cpu().numpy(), val_labels)

In [None]:
# Embeddings of every node, computed on the full graph after load_state_dict(best_model).
visualize_tsne_embeddings(full_embeddings.cpu().numpy(), labels)

# GraphSage

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# NOTE(review): this section rebinds names used by earlier cells (`N`, `features`,
# `D`, `labels`, `class_names`). In particular `labels` becomes a torch tensor and
# `D` becomes the feature dimension, so earlier cells should not be re-run after
# this one without re-running the notebook from the top.

# 1. Map original paper IDs to contiguous integer indices (0 to N-1)
paper_ids = content[:, 0]
id_to_idx = {int(p_id): i for i, p_id in enumerate(paper_ids)}
N = len(paper_ids)  # Number of nodes

# 2. Extract Features (Bag-of-Words)
features = torch.tensor(content[:, 1:-1].astype(np.float32)).to(device)
D = features.size(1)  # Feature dimension

# 3. Extract and Encode Labels
# Codes follow the sorted order of the distinct class names returned by np.unique.
labels_text = content[:, -1]
label_map = {name: i for i, name in enumerate(np.unique(labels_text))}
labels = np.array([label_map[name] for name in labels_text])
labels = torch.tensor(labels, dtype=torch.long).to(device)
C = len(label_map)  # Number of classes
class_names = list(label_map.keys())  # For visualization legend

# 4. Process Edges and Symmetrize
# Unlike the earlier edge-loading cell, ids are looked up without a membership
# check, so an id missing from cora.content would raise KeyError here.
source_nodes = np.array([id_to_idx[src] for src in cites[:, 0]])
target_nodes = np.array([id_to_idx[tgt] for tgt in cites[:, 1]])

# Ensure tensors for stacking
source_tensor = torch.tensor(source_nodes, dtype=torch.long)
target_tensor = torch.tensor(target_nodes, dtype=torch.long)

# Row 0 holds source indices, row 1 holds target indices: shape (2, num_edges)
edge_index = torch.stack([source_tensor, target_tensor], dim=0)

# Symmetrize edges: append every edge reversed, then drop duplicate columns
reversed_edges = torch.stack([target_tensor, source_tensor], dim=0)
edge_index = torch.cat([edge_index, reversed_edges], dim=1)
edge_index = torch.unique(edge_index, dim=1)


# 5. Create Adjacency Matrix and Normalized Adjacency (D^{-1}A)
def create_normalized_adj(edge_index, num_nodes, device):
    """Build the row-normalised adjacency matrix with self-loops.

    Args:
        edge_index: integer tensor of shape (2, num_edges); row 0 holds source
            node indices and row 1 the matching target indices.
        num_nodes: number of nodes; the result has shape (num_nodes, num_nodes).
        device: device on which the dense matrices are allocated.

    Returns:
        ``D_hat^{-1} (A + I)`` as a dense float tensor. Every row sums to 1,
        so multiplying features by it averages each node with its neighbours.
    """
    # Create the adjacency matrix (A)
    adj = torch.zeros(num_nodes, num_nodes, dtype=torch.float, device=device)
    adj[edge_index[0].to(device), edge_index[1].to(device)] = 1.0

    # Add self-loops (A_hat = A + I) so a node's own feature is part of the
    # aggregation; this also guarantees every row sum is at least 1.
    A_hat = adj + torch.eye(num_nodes, device=device)

    # Divide each row by its sum. Broadcasting a (num_nodes, 1) column gives the
    # same result as diag(1 / degree) @ A_hat (up to floating-point rounding)
    # without building a second dense matrix or doing a dense matrix product.
    row_degree = A_hat.sum(dim=1, keepdim=True)
    return A_hat / row_degree


# Create the normalized adjacency matrix for aggregation
# (this rebinds the NumPy `A_norm` computed in the earlier smoothing cell)
A_norm = create_normalized_adj(edge_index, N, device)


# 6. Create train/val/test masks with Planetoid-style sizes (140 / 500 / 1000)
# NOTE(review): the masks select rows by their position in cora.content; this
# matches the sizes of the Planetoid split but not necessarily its node
# membership — confirm before comparing accuracy with published numbers.
train_mask = torch.zeros(N, dtype=torch.bool, device=device)
val_mask = torch.zeros(N, dtype=torch.bool, device=device)
test_mask = torch.zeros(N, dtype=torch.bool, device=device)

# Three consecutive, non-overlapping index ranges; rows from 1640 onwards are unused.
train_mask[:140] = True
val_mask[140 : 140 + 500] = True
test_mask[140 + 500 : 140 + 500 + 1000] = True


# --- 2. Pure PyTorch GraphSAGE Implementation ---


class SageConv(nn.Module):
    """GraphSAGE-style convolution with mean aggregation, in plain PyTorch.

    Each node's own features are concatenated with ``A_norm @ x`` (its
    aggregated neighbourhood) and passed through a single linear layer:
    ``out = Linear(CONCAT(x, A_norm @ x))``. One weight matrix of input width
    ``2 * in_features`` therefore plays the role of separate "self" and
    "aggregate" weights. No activation is applied inside this layer.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        # Own features + aggregated features are concatenated, hence 2 * in_features.
        self.linear = nn.Linear(2 * in_features, out_features)

    def forward(self, x, A_norm):
        """Apply the layer.

        Args:
            x: node features, shape (N, D_in).
            A_norm: (N, N) aggregation matrix. The notebook passes a
                row-normalised adjacency that includes self-loops, so the
                aggregate is a mean over the node and its neighbours.

        Returns:
            Tensor of shape (N, D_out).
        """
        # (N, N) @ (N, D_in) -> (N, D_in)
        neighbourhood = torch.matmul(A_norm, x)
        # (N, 2 * D_in) -> (N, D_out)
        return self.linear(torch.cat([x, neighbourhood], dim=1))


class GraphSAGE(nn.Module):
    """Two-layer GraphSAGE node classifier.

    Each stage is ``SageConv -> BatchNorm1d -> ReLU -> Dropout``; a final
    linear layer maps the hidden representation to class scores.
    """

    def __init__(self, in_features, hidden_dim, out_classes, dropout_rate=0.5):
        super().__init__()

        # Stage 1: in_features -> hidden_dim
        self.conv1 = SageConv(in_features, hidden_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dim)
        self.dropout1 = nn.Dropout(dropout_rate)

        # Stage 2: hidden_dim -> hidden_dim
        self.conv2 = SageConv(hidden_dim, hidden_dim)
        self.bn2 = nn.BatchNorm1d(hidden_dim)
        self.dropout2 = nn.Dropout(dropout_rate)

        # Classification head: hidden_dim -> out_classes
        self.output_proj = nn.Linear(hidden_dim, out_classes)

    def forward(self, x, A_norm):
        """Return ``(log_probabilities, embeddings)``.

        ``embeddings`` is the output of the second stage (after dropout) and
        ``log_probabilities`` is the log-softmax over classes of the head's
        output, suitable for ``F.nll_loss``.
        """
        stages = (
            (self.conv1, self.bn1, self.dropout1),
            (self.conv2, self.bn2, self.dropout2),
        )
        for conv, norm, drop in stages:
            x = drop(F.relu(norm(conv(x, A_norm))))

        # Hidden representation that feeds the classifier
        final_embeddings = x
        class_scores = self.output_proj(final_embeddings)
        return F.log_softmax(class_scores, dim=1), final_embeddings


# --- 3. Model Initialization and Training ---

# Hyperparameters (Adjusted for GraphSAGE common practice)
HIDDEN_DIM = 128  # Typically wider layers than attention for SAGE
DROPOUT = 0.5  # dropout probability used after each of the two stages
LR = 0.01  # Adam learning rate
EPOCHS = 200  # number of full-batch training steps

# Initialize Model and move it to the selected device
# (D = feature dimension, C = number of classes, both set in the data section above)
model = GraphSAGE(
    in_features=D, hidden_dim=HIDDEN_DIM, out_classes=C, dropout_rate=DROPOUT
).to(device)

# weight_decay adds L2 regularisation to every parameter
optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=5e-4)


# One full-batch optimisation step
def train(model, x, A_norm, y, train_mask, optimizer):
    """Run a single training step and return the loss as a Python float.

    The forward pass covers every node, but only rows selected by
    ``train_mask`` contribute to the negative log-likelihood loss. The model
    is left in training mode.
    """
    model.train()
    optimizer.zero_grad()
    # The model returns (log-probabilities, embeddings); only the former is needed.
    log_probs, _embeddings = model(x, A_norm)
    step_loss = F.nll_loss(log_probs[train_mask], y[train_mask])
    step_loss.backward()
    optimizer.step()
    return step_loss.item()


# Evaluation function
@torch.no_grad()
def test(model, x, A_norm, y, test_mask):
    model.eval()
    out, _ = model(x, A_norm)
    pred = out.argmax(dim=1)
    correct = pred[test_mask].eq(y[test_mask]).sum().item()
    acc = correct / test_mask.sum().item()
    return acc


# Visualization function: 2-D t-SNE scatter plot with class names in the legend
def visualize_embeddings(embeddings, labels, class_names, title):
    """Project ``embeddings`` to 2-D with t-SNE and plot them coloured by class.

    Args:
        embeddings: 2-D NumPy array, one row per node.
        labels: NumPy array of integer class codes aligned with ``embeddings``.
        class_names: sequence indexed by class code, used for the legend.
        title: figure title.
    """
    print(
        f"\n--- Running t-SNE for Visualization (N={embeddings.shape[0]}, D={embeddings.shape[1]}) ---"
    )

    tsne = TSNE(n_components=2, random_state=42, perplexity=30)
    embeddings_2d = tsne.fit_transform(embeddings)

    plt.figure(figsize=(10, 8))

    unique_labels = np.unique(labels)
    # plt.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap accepts the
    # same (name, number of colours) arguments and returns a resampled colormap.
    colors = plt.get_cmap("viridis", len(unique_labels))

    # One scatter call per class so each class gets its own legend entry.
    for i, label in enumerate(unique_labels):
        indices = labels == label
        plt.scatter(
            embeddings_2d[indices, 0],
            embeddings_2d[indices, 1],
            c=[colors(i)],
            label=class_names[label],
            s=20,
            alpha=0.7,
            edgecolors="w",
            linewidths=0.5,
        )

    plt.title(title)
    plt.legend(title="Research Area")
    plt.xlabel("t-SNE Dimension 1")
    plt.ylabel("t-SNE Dimension 2")
    plt.grid(True, linestyle="--", alpha=0.5)
    plt.show()


# --- 4. Main Execution and Visualization ---

print(f"Starting training on Cora dataset (Nodes: {N}, Features: {D}, Classes: {C})...")
# Model selection: report the test accuracy measured at the epoch with the
# highest validation accuracy (the weights of that epoch are not saved).
best_val_acc = 0.0
best_test_acc = 0.0

for epoch in range(1, EPOCHS + 1):
    # One full-batch training step, then accuracy on the val and test masks.
    loss = train(model, features, A_norm, labels, train_mask, optimizer)
    val_acc = test(model, features, A_norm, labels, val_mask)
    current_test_acc = test(model, features, A_norm, labels, test_mask)

    # Strict ">" keeps the earliest epoch when validation accuracy ties.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_test_acc = current_test_acc

    if epoch % 50 == 0 or epoch == EPOCHS:
        print(
            f"Epoch: {epoch:03d}, Loss: {loss:.4f}, Val Acc: {val_acc:.4f}, Test Acc: {current_test_acc:.4f}"
        )

print(f"\n--- Training Complete ---")
print(f"Best Test Accuracy (based on max validation accuracy): **{best_test_acc:.4f}**")

# Get the final embeddings and labels for visualization
# (these come from the weights after the last epoch, not the best-validation epoch)
model.eval()
with torch.no_grad():
    _, final_embeddings_tensor = model(features, A_norm)

# Move tensors back to CPU and convert to numpy for t-SNE and Matplotlib
final_embeddings_np = final_embeddings_tensor.cpu().numpy()
labels_np = labels.cpu().numpy()

# Run Visualization
visualize_embeddings(
    final_embeddings_np,
    labels_np,
    class_names,
    "t-SNE Visualization of Node Embeddings (GraphSAGE on Cora)",
)