## Graph transformer

In [None]:
import os
import tarfile
import urllib
import urllib.request

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.manifold import TSNE


In [None]:
# --- DEVICE SETUP ---
# Pick the compute backend in priority order: MPS first, then CUDA, and
# finally the CPU when neither accelerator is reported as available.
device = torch.device(
    "mps"
    if torch.backends.mps.is_available()
    else "cuda"
    if torch.cuda.is_available()
    else "cpu"
)
print(f"Using device: {device}")

In [None]:
# Download and extract the Cora archive if the "cora" directory is missing
if not os.path.exists("cora"):
    print("Downloading Cora dataset...")
    urllib.request.urlretrieve(
        "https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz", "cora.tgz"
    )
    with tarfile.open("cora.tgz", "r:gz") as tar:
        # The archive was just fetched over the network, so extract it with the
        # "data" filter (rejects absolute paths, ".." traversal and special
        # files) whenever the running Python supports extraction filters.
        if hasattr(tarfile, "data_filter"):
            tar.extractall(filter="data")
        else:
            tar.extractall()

# Load node features and labels (cora.content).
# Each row is read as strings: column 0 is the paper ID, the last column is
# the class name, and everything in between is the feature vector.
content_path = "cora/cora.content"
content = np.genfromtxt(content_path, dtype=str)

# Load edge list (cora.cites): two integer paper IDs per row
cites_path = "cora/cora.cites"
cites = np.genfromtxt(cites_path, dtype=int)

# 1. Map original paper IDs to contiguous integer indices (0 to N-1),
#    following the row order of cora.content
paper_ids = content[:, 0]
id_to_idx = {int(p_id): i for i, p_id in enumerate(paper_ids)}
N = len(paper_ids)  # Number of nodes

# 2. Extract features (all columns between the ID and the label) as float32
#    and place them on the selected device
features = torch.tensor(content[:, 1:-1].astype(np.float32)).to(device)
D = features.size(1)  # Feature dimension

# 3. Extract and encode labels
labels_text = content[:, -1]
# Map text labels to integer indices (0 to C-1); np.unique returns the names
# sorted, so the encoding is deterministic
label_map = {name: i for i, name in enumerate(np.unique(labels_text))}
labels = np.array([label_map[name] for name in labels_text])
# Move labels tensor to the selected device
labels = torch.tensor(labels, dtype=torch.long).to(device)
C = len(label_map)  # Number of classes
class_names = list(label_map.keys())  # For visualization legend

# 4. Process edges and symmetrize.
# Translate both endpoint columns from paper IDs to node indices; an ID that
# does not appear in cora.content raises KeyError here.
source_nodes = np.array([id_to_idx[src] for src in cites[:, 0]])
target_nodes = np.array([id_to_idx[tgt] for tgt in cites[:, 1]])

# Index tensors for both endpoint columns (left on the CPU)
source_tensor = torch.tensor(source_nodes, dtype=torch.long)
target_tensor = torch.tensor(target_nodes, dtype=torch.long)

edge_index = torch.stack([source_tensor, target_tensor], dim=0)

# Symmetrize: append every edge reversed, then drop duplicate columns
reversed_edges = torch.stack([target_tensor, source_tensor], dim=0)
edge_index = torch.cat([edge_index, reversed_edges], dim=1)
edge_index = torch.unique(edge_index, dim=1)


# 5. Create Adjacency Matrix and Attention Mask
def to_dense_adj(edge_index, num_nodes, device):
    """Build a dense float adjacency matrix from an edge index.

    Args:
        edge_index: Integer tensor whose row 0 holds source node indices and
            row 1 holds target node indices.
        num_nodes: Side length of the square matrix to create.
        device: Device on which the matrix is allocated.

    Returns:
        A (num_nodes, num_nodes) float tensor on `device` with 1.0 at every
        (source, target) position listed in `edge_index` and 0.0 elsewhere.
    """
    # The index rows are moved to the same device as the matrix before they
    # are used for the scatter-style assignment.
    src = edge_index[0].to(device)
    dst = edge_index[1].to(device)
    dense = torch.zeros(num_nodes, num_nodes, dtype=torch.float, device=device)
    dense[src, dst] = 1.0
    return dense


# Dense adjacency on the selected device, with the identity added so that
# every node also has a self-loop
adj = to_dense_adj(edge_index, N, device) + torch.eye(N, device=device)
# Additive attention mask: -1e9 wherever adj is zero (no edge), zero otherwise
attention_mask = (adj == 0).float() * (-1e9)


# 6. Create train / validation / test masks
# The masks select contiguous index ranges in cora.content row order:
# 140 training nodes, then 500 validation nodes, then 1000 test nodes.
# NOTE(review): these sizes mirror the Planetoid split, but the node ordering
# here comes from cora.content, so the selected nodes may not match the
# canonical Planetoid split - confirm before comparing to published numbers.
train_mask, val_mask, test_mask = (
    torch.zeros(N, dtype=torch.bool, device=device) for _ in range(3)
)

train_mask[:140] = True
val_mask[140:640] = True
test_mask[640:1640] = True

In [None]:
# --- 2. Pure PyTorch Graph Transformer Implementation ---


class AttentionHead(nn.Module):
    """Single scaled dot-product attention head with an additive mask.

    Queries, keys and values are bias-free linear projections of the same
    input. `D_k` stores the square root of the head width and is used as the
    divisor for the raw attention scores.
    """

    def __init__(self, in_features, head_out_features):
        super().__init__()
        projection_shape = (in_features, head_out_features)
        self.W_q = nn.Linear(*projection_shape, bias=False)
        self.W_k = nn.Linear(*projection_shape, bias=False)
        self.W_v = nn.Linear(*projection_shape, bias=False)
        self.D_k = head_out_features**0.5

    def forward(self, x, mask):
        """Attend over the rows of `x`.

        Args:
            x: Input whose dims 0 and 1 are (rows, in_features); keys are
                transposed over exactly those two dims.
            mask: Tensor added to the score matrix before the softmax, so
                large negative entries suppress the corresponding pairs.

        Returns:
            The attention-weighted combination of the value projections.
        """
        queries, keys, values = self.W_q(x), self.W_k(x), self.W_v(x)
        logits = torch.matmul(queries, keys.transpose(0, 1)) / self.D_k
        logits = logits + mask  # graph structure enters here
        # Normalise each row so the weights over attended nodes sum to one
        weights = F.softmax(logits, dim=1)
        return torch.matmul(weights, values)


class MultiHeadAttention(nn.Module):
    """Run several attention heads in parallel and merge their outputs.

    Each head maps `in_features` to `head_out_features`; the concatenated
    head outputs are projected back to `in_features` by `linear_proj`.
    """

    def __init__(self, in_features, n_heads, head_out_features):
        super().__init__()
        self.heads = nn.ModuleList()
        for _ in range(n_heads):
            self.heads.append(AttentionHead(in_features, head_out_features))
        self.linear_proj = nn.Linear(n_heads * head_out_features, in_features)

    def forward(self, x, mask):
        """Apply every head to `x` with the same `mask` and fuse the results."""
        # Concatenate along the last (feature) dimension, head by head
        merged = torch.cat([head(x, mask) for head in self.heads], dim=-1)
        return self.linear_proj(merged)


class GraphTransformerLayer(nn.Module):
    """One transformer block: masked attention followed by a feed-forward net.

    Both sub-layers use a residual connection with dropout on the sub-layer
    output, and layer normalisation is applied after the residual sum.
    """

    def __init__(self, in_features, n_heads, head_out_features, dropout_rate=0.1):
        super().__init__()
        # Attention sub-layer
        self.mha = MultiHeadAttention(in_features, n_heads, head_out_features)
        self.norm1 = nn.LayerNorm(in_features)
        self.dropout1 = nn.Dropout(dropout_rate)

        # Feed-forward sub-layer: expand to 4x the width, then project back
        ffn_width = in_features * 4
        self.ffn = nn.Sequential(
            nn.Linear(in_features, ffn_width),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(ffn_width, in_features),
        )
        self.norm2 = nn.LayerNorm(in_features)
        self.dropout2 = nn.Dropout(dropout_rate)

    def forward(self, x, mask):
        """Return `x` after the attention and feed-forward sub-layers."""
        attended = self.dropout1(self.mha(x, mask))
        x = self.norm1(x + attended)

        transformed = self.dropout2(self.ffn(x))
        return self.norm2(x + transformed)


class GraphTransformer(nn.Module):
    """
    Graph Transformer node classifier that also exposes its hidden embeddings.

    The input is projected to `hidden_dim`, passed through `n_layers`
    GraphTransformerLayer blocks that all share the same mask, and finally
    mapped to `out_classes` scores by a linear classifier.
    """

    def __init__(
        self,
        in_features,
        hidden_dim,
        out_classes,
        n_layers,
        n_heads,
        head_out_features,
        dropout_rate=0.1,
    ):
        super().__init__()
        self.input_proj = nn.Linear(in_features, hidden_dim)

        blocks = []
        for _ in range(n_layers):
            blocks.append(
                GraphTransformerLayer(
                    hidden_dim, n_heads, head_out_features, dropout_rate
                )
            )
        self.layers = nn.ModuleList(blocks)

        self.output_proj = nn.Linear(hidden_dim, out_classes)

    def forward(self, x, mask):
        """Return `(log_probs, embeddings)` for the inputs `x`.

        `log_probs` is the log-softmax over dim 1 of the classifier output;
        `embeddings` is the hidden representation fed to the classifier.
        """
        hidden = F.relu(self.input_proj(x))

        for block in self.layers:
            hidden = block(hidden, mask)

        # `hidden` is kept as-is so callers can inspect the embeddings
        class_scores = self.output_proj(hidden)
        return F.log_softmax(class_scores, dim=1), hidden


In [None]:
# --- 3. Model Initialization and Training ---

# Hyperparameters
HIDDEN_DIM = 64  # width of the hidden node representation
N_HEADS = 8  # attention heads per layer
HEAD_OUT_FEATURES = 8  # output width of each head
N_LAYERS = 2  # number of stacked transformer layers
DROPOUT = 0.6  # dropout probability used throughout the model
LR = 0.005  # Adam learning rate
EPOCHS = 200  # number of full-graph training steps

# Build the model from the dataset dimensions, then move it to the device
model = GraphTransformer(
    in_features=D,
    hidden_dim=HIDDEN_DIM,
    out_classes=C,
    n_layers=N_LAYERS,
    n_heads=N_HEADS,
    head_out_features=HEAD_OUT_FEATURES,
    dropout_rate=DROPOUT,
)
model = model.to(device)

# Adam over all model parameters, with L2 weight decay
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=LR,
    weight_decay=5e-4,
)


# Training function
def train(model, x, mask, y, train_mask, optimizer):
    """Run one full-graph optimisation step and return the training loss.

    Args:
        model: Module called as `model(x, mask)` that returns a pair whose
            first element holds per-node log-probabilities.
        x: Node features passed to the model.
        mask: Attention mask passed to the model unchanged.
        y: Integer class label for every node.
        train_mask: Boolean selector of the nodes that contribute to the loss.
        optimizer: Optimizer that updates the model's parameters.

    Returns:
        The negative log-likelihood loss on the selected nodes as a float.
    """
    model.train()
    optimizer.zero_grad()
    # The second element of the model output (embeddings) is not needed here
    log_probs, _embeddings = model(x, mask)
    train_loss = F.nll_loss(log_probs[train_mask], y[train_mask])
    train_loss.backward()
    optimizer.step()
    return train_loss.item()


# Evaluation function
@torch.no_grad()
def test(model, x, mask, y, test_mask):
    """Return the classification accuracy on the nodes selected by `test_mask`.

    The model is switched to evaluation mode, the class with the highest
    score is taken as the prediction for each node, and the fraction of
    selected nodes whose prediction equals the label in `y` is returned.
    Gradients are not tracked while this function runs.
    """
    model.eval()
    log_probs, _embeddings = model(x, mask)
    predicted = log_probs.argmax(dim=1)
    n_correct = predicted[test_mask].eq(y[test_mask]).sum().item()
    n_selected = test_mask.sum().item()
    return n_correct / n_selected


# --- 4. Main Execution and Visualization ---

print(f"Starting training on Cora dataset (Nodes: {N}, Features: {D}, Classes: {C})...")
# Track the best validation accuracy and the test accuracy observed at that
# same epoch, so the reported test score is selected by validation only.
best_val_acc = best_test_acc = 0.0

for epoch in range(1, EPOCHS + 1):
    # One optimisation step, then evaluate on both held-out masks
    loss = train(model, features, attention_mask, labels, train_mask, optimizer)
    val_acc = test(model, features, attention_mask, labels, val_mask)
    current_test_acc = test(model, features, attention_mask, labels, test_mask)

    if val_acc > best_val_acc:
        best_val_acc, best_test_acc = val_acc, current_test_acc

    # Report every 50th epoch and always on the last one
    if epoch == EPOCHS or epoch % 50 == 0:
        print(
            f"Epoch: {epoch:03d}, Loss: {loss:.4f}, Val Acc: {val_acc:.4f}, Test Acc: {current_test_acc:.4f}"
        )

print(f"\n--- Training Complete ---")
print(f"Best Test Accuracy (based on max validation accuracy): **{best_test_acc:.4f}**")

In [None]:
# Visualization function
def visualize_embeddings(embeddings, labels, class_names, title):
    """Project embeddings to 2-D with t-SNE and show a per-class scatter plot.

    Args:
        embeddings: 2-D array of shape (n_samples, n_features).
        labels: Array of integer class indices, one per row of `embeddings`.
        class_names: Sequence indexed by class index, used for the legend.
        title: Title of the figure.

    The figure is displayed with `plt.show()`; nothing is returned.
    """
    print(
        f"\n--- Running t-SNE for Visualization (N={embeddings.shape[0]}, D={embeddings.shape[1]}) ---"
    )

    # Run t-SNE to reduce high-dimensional embeddings (N, D_hidden) to (N, 2).
    # Perplexity must stay below the number of samples, so the usual value of
    # 30 is clamped for small inputs instead of letting t-SNE reject them.
    perplexity = min(30, embeddings.shape[0] - 1)
    tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
    embeddings_2d = tsne.fit_transform(embeddings)

    plt.figure(figsize=(10, 8))

    # Get unique labels and a colormap resampled to one colour per label.
    # plt.get_cmap is used because matplotlib.cm.get_cmap is no longer
    # available in recent Matplotlib releases.
    unique_labels = np.unique(labels)
    colors = plt.get_cmap("viridis", len(unique_labels))

    # Scatter plot each class separately so each gets its own legend entry
    for i, label in enumerate(unique_labels):
        indices = labels == label
        plt.scatter(
            embeddings_2d[indices, 0],
            embeddings_2d[indices, 1],
            c=[colors(i)],
            label=class_names[label],
            s=20,
            alpha=0.7,
            edgecolors="w",
            linewidths=0.5,
        )

    plt.title(title)
    plt.legend(title="Research Area")
    plt.xlabel("t-SNE Dimension 1")
    plt.ylabel("t-SNE Dimension 2")
    plt.grid(True, linestyle="--", alpha=0.5)
    plt.show()

In [None]:
# Compute the final node embeddings in evaluation mode, without gradients
model.eval()
with torch.no_grad():
    _, final_embeddings_tensor = model(features, attention_mask)

# t-SNE and Matplotlib work on NumPy arrays, so copy both tensors to the CPU
labels_np = labels.cpu().numpy()
final_embeddings_np = final_embeddings_tensor.cpu().numpy()

# Run Visualization
visualize_embeddings(
    embeddings=final_embeddings_np,
    labels=labels_np,
    class_names=class_names,
    title="t-SNE Visualization of Node Embeddings (Graph Transformer on Cora)",
)