# Link Prediction - Data Preparation & Feature Engineering

# Imports

In [1]:
import gzip
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import os
import pandas as pd
import seaborn as sns
import sys
import torch

from collections import Counter
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.nn import SAGEConv
from torch_geometric.nn.models import Node2Vec
from torch_geometric.utils import from_networkx, negative_sampling
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm

# Data Preparation

### Recreating the community subgraph

In [2]:
def load_graph(file_path):
    edges = []
    with gzip.open(file_path, "rt") as f:
        for line in f:
            if line.startswith("#"):
                continue
            source, target = map(int, line.strip().split())
            edges.append((source, target))
    return nx.Graph(edges)


def extract_top_communities_subgraph(G, community_path):

    community_map = {}
    community_id = 0

    # Read top-k communities and build the node set
    with gzip.open(community_path, "rt") as f:
        for line in f:
            nodes = list(map(int, line.strip().split("\t")))
            for node in nodes:
                if node in G and node not in community_map:
                    community_map[node] = community_id
            community_id += 1

    top_k_nodes = set(community_map.keys())

    # Include all edges where at least one node is in top-k
    expanded_edges = [
        (u, v) for u, v in G.edges() if u in top_k_nodes or v in top_k_nodes
    ]

    # Build the new graph
    G_expanded = nx.Graph()
    G_expanded.add_edges_from(expanded_edges)

    return G_expanded, community_map

In [3]:
graph = load_graph(file_path="data/raw/com-amazon.ungraph.txt.gz")
G, community_map = extract_top_communities_subgraph(
    graph, "data/raw/com-amazon.top5000.cmty.txt.gz"
)

print(
    f"Top-5000 expanded subgraph has {G.number_of_nodes()} nodes and {G.number_of_edges()} edges"
)

Top-5000 expanded subgraph has 19905 nodes and 53780 edges


In [4]:
data = from_networkx(G)
data.num_nodes = G.number_of_nodes()
print(data)

Data(edge_index=[2, 107560], num_nodes=19905)


## Feature Engineering

Since you're working with the Amazon co-purchase graph (sparse edges, metadata available), and training a GNN for link prediction, use the following hybrid feature set:

✅ data.x = [Node2Vec Embedding] + [Log(Salesrank)] + [Product Group One-Hot] + [Degree Centrality]
Feature Source	Dimensionality	Why It Helps
Node2Vec Embedding	64–128	Captures local+global node context
Salesrank (log)	1	Popularity signal
Product Group (one-hot)	~3–5	Category indicator (Book/Music/etc)
Degree Centrality	1	Captures node popularity structurally

### 1. Node2Vec Training

In [8]:
node2vec = Node2Vec(
    data.edge_index,
    embedding_dim=64,
    walk_length=20,
    context_size=10,
    walks_per_node=10,
)
loader = node2vec.loader(batch_size=256, shuffle=True)

optimizer = torch.optim.Adam(node2vec.parameters(), lr=0.01)
for epoch in tqdm(range(100)):
    for pos_rw, neg_rw in loader:
        loss = node2vec.loss(pos_rw, neg_rw)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

100%|██████████| 100/100 [23:48<00:00, 14.28s/it]


In [9]:
def parse_amazon_meta_gz(path):
    metadata = {}
    current_id = None

    with gzip.open(path, "rt", encoding="latin-1") as f:
        for line in f:
            line = line.strip()

            if line.startswith("Id:"):
                current_id = int(line.split("Id:")[1].strip())
                metadata[current_id] = {}
            elif line.startswith("ASIN:") and current_id is not None:
                metadata[current_id]["asin"] = line.split("ASIN:")[1].strip()
            elif line.startswith("title:") and current_id is not None:
                metadata[current_id]["title"] = line.split("title:")[1].strip()
            elif line.startswith("group:") and current_id is not None:
                metadata[current_id]["group"] = line.split("group:")[1].strip()
            elif line.startswith("salesrank:") and current_id is not None:
                try:
                    metadata[current_id]["salesrank"] = int(
                        line.split("salesrank:")[1].strip()
                    )
                except ValueError:
                    metadata[current_id]["salesrank"] = None

    df = (
        pd.DataFrame.from_dict(metadata, orient="index")
        .reset_index()
        .rename(columns={"index": "id"})
    )
    return df

In [13]:
embeddings = node2vec.embedding.weight.detach()
print(embeddings.shape)

torch.Size([19905, 64])


In [10]:
metadata_df = parse_amazon_meta_gz("data/raw/amazon-meta.txt.gz")

In [15]:
df = metadata_df.set_index("id").loc[range(embeddings.shape[0])]  # Ensure node ordering
salesrank = np.log1p(df["salesrank"].fillna(df["salesrank"].max())).values.reshape(
    -1, 1
)
group_enc = OneHotEncoder(sparse_output=False).fit_transform(
    df["group"].fillna("Unknown").values.reshape(-1, 1)
)

  result = getattr(ufunc, method)(*inputs, **kwargs)


In [28]:
degree_centrality = nx.degree_centrality(G)
closeness_centrality = nx.closeness_centrality(G)
eigenvector_centrality = nx.eigenvector_centrality(G, max_iter=500, tol=1e-6)

centrality_df = pd.DataFrame(
    {
        "node": list(degree_centrality.keys()),
        "node_id": list(range(G.number_of_nodes())),
        "degree": list(degree_centrality.values()),
        "closeness": [closeness_centrality[n] for n in degree_centrality],
        "eigenvector": [eigenvector_centrality.get(n, 0.0) for n in degree_centrality],
    }
)
# centrality_df = centrality_df.set_index('node_id').sort_index()
print(centrality_df)

         node  node_id    degree  closeness    eigenvector
0           4        0  0.000301   0.000656   1.859320e-41
1       16050        1  0.000603   0.000875   2.923846e-41
2      286286        2  0.000151   0.000505   9.377731e-42
3      310803        3  0.000402   0.000625   2.280970e-41
4      320519        4  0.000703   0.000757   3.432783e-41
...       ...      ...       ...        ...            ...
19900  530702    19900  0.000100   0.000113  7.254087e-102
19901  502918    19901  0.000050   0.000090  3.911323e-102
19902  546178    19902  0.000201   0.000356   1.000616e-43
19903  547107    19903  0.000201   0.000356   1.000616e-43
19904  548142    19904  0.000201   0.000356   1.000616e-43

[19905 rows x 5 columns]


In [37]:
d_c = np.array(centrality_df["degree"]).reshape(-1, 1)
c_c = np.array(centrality_df["closeness"]).reshape(-1, 1)
e_c = np.array(centrality_df["eigenvector"]).reshape(-1, 1)

centrality_vec = np.hstack([d_c, c_c, e_c])
print(centrality_vec.shape)

(19905, 3)


In [38]:
embeddings.shape, salesrank.shape, group_enc.shape, centrality_vec.shape

(torch.Size([19905, 64]), (19905, 1), (19905, 6), (19905, 3))

In [39]:
features = np.hstack([embeddings.numpy(), salesrank, group_enc, centrality_vec])

In [44]:
features_df = pd.DataFrame(features)
features_df.replace([np.inf, -np.inf], np.nan, inplace=True)
features_df.fillna(0, inplace=True)

In [45]:
features.max(), features.min()

(15.148747965099776, -1.3330624103546143)

In [74]:
scaler = StandardScaler()
scaled_features = scaler.fit_transform(features)

data.x = torch.tensor(scaled_features, dtype=torch.float)
print(data)

Data(edge_index=[2, 77714], num_nodes=19905, x=[19905, 74], pos_edge_label=[38857], pos_edge_label_index=[2, 38857], neg_edge_label=[38857], neg_edge_label_index=[2, 38857])


In [75]:
transform = RandomLinkSplit(
    num_val=0.05,
    num_test=0.1,
    is_undirected=True,
    split_labels=True,
    add_negative_train_samples=True,
)

train_data, val_data, test_data = transform(data)

In [76]:
train_data

Data(edge_index=[2, 66060], num_nodes=19905, x=[19905, 74], pos_edge_label=[33030], pos_edge_label_index=[2, 33030], neg_edge_label=[33030], neg_edge_label_index=[2, 33030])

In [77]:
val_data

Data(edge_index=[2, 66060], num_nodes=19905, x=[19905, 74], pos_edge_label=[1942], pos_edge_label_index=[2, 1942], neg_edge_label=[1942], neg_edge_label_index=[2, 1942])

In [78]:
test_data

Data(edge_index=[2, 69944], num_nodes=19905, x=[19905, 74], pos_edge_label=[3885], pos_edge_label_index=[2, 3885], neg_edge_label=[3885], neg_edge_label_index=[2, 3885])

# Training

In [90]:
class GNNEncoder(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        return x


class DotProductDecoder(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(2 * embed_dim, 64), nn.ReLU(), nn.Linear(64, 1)
        )

    def forward(self, z, edge_index):
        src = z[edge_index[0]]
        dst = z[edge_index[1]]
        out = torch.cat([src, dst], dim=1)
        return self.mlp(out).squeeze()


class LinkPredictionModel(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.encoder = GNNEncoder(in_channels, hidden_channels, out_channels)
        self.decoder = DotProductDecoder(out_channels)  # or use dot-product

    def forward(self, x, edge_index, edge_label_index):
        z = self.encoder(x, edge_index)
        return torch.sigmoid(self.decoder(z, edge_label_index))  # outputs probs

In [108]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = LinkPredictionModel(
    in_channels=74,
    hidden_channels=64,
    out_channels=32,
)
model = model.to(device)
train_data = train_data.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0, weight_decay=5e-4)
loss_fn = nn.BCEWithLogitsLoss()

In [109]:
def train_link_prediction_model(
    model, optimizer, train_data, val_data, epochs=1000, log_every=100
):
    train_losses, val_losses = [], []
    train_aucs, val_aucs = [], []

    for epoch in range(1, epochs + 1):
        model.train()
        optimizer.zero_grad()

        pos_score = model(
            train_data.x, train_data.edge_index, train_data.pos_edge_label_index
        )
        neg_edge = negative_sampling(
            edge_index=train_data.edge_index,
            num_nodes=train_data.num_nodes,
            num_neg_samples=pos_score.size(0),
            method="sparse",
        )
        neg_score = model(train_data.x, train_data.edge_index, neg_edge)

        # Loss
        y_true = torch.cat(
            [
                torch.ones(pos_score.size(0), device=pos_score.device),
                torch.zeros(neg_score.size(0), device=neg_score.device),
            ]
        )
        y_pred = torch.cat([pos_score, neg_score])
        loss = loss_fn(y_true, y_pred)

        loss.backward()
        optimizer.step()

        train_losses.append(loss.item())

        # AUC
        train_auc = roc_auc_score(y_true.cpu(), y_pred.cpu().detach())
        train_aucs.append(train_auc)

        # --- Validation ---
        model.eval()
        with torch.no_grad():
            pos_val_score = model(
                val_data.x, val_data.edge_index, val_data.pos_edge_label_index
            )
            neg_val_score = model(
                val_data.x, val_data.edge_index, val_data.neg_edge_label_index
            )

            val_loss = compute_loss(pos_val_score, neg_val_score)
            val_losses.append(val_loss.item())

            y_true_val = torch.cat(
                [torch.ones_like(pos_val_score), torch.zeros_like(neg_val_score)]
            )
            y_score_val = torch.cat([pos_val_score, neg_val_score])
            val_auc = roc_auc_score(y_true_val.cpu(), y_score_val.cpu().detach())
            val_aucs.append(val_auc)

        if epoch % log_every == 0 or epoch == 1:
            print(
                f"Epoch {epoch:04d} | "
                f"Train Loss: {loss.item():.4f} | Val Loss: {val_loss.item():.4f} | "
                f"Train AUC: {train_auc:.4f} | Val AUC: {val_auc:.4f}"
            )

    return train_losses, val_losses, train_aucs, val_aucs

In [110]:
train_losses, val_losses, train_aucs, val_aucs = train_link_prediction_model(
    model, optimizer, train_data, val_data, epochs=1000, log_every=10
)

Epoch 0001 | Train Loss: 0.7531 | Val Loss: 0.7230 | Train AUC: 0.4568 | Val AUC: 0.4871
Epoch 0010 | Train Loss: 0.7531 | Val Loss: 0.7230 | Train AUC: 0.4540 | Val AUC: 0.4871
Epoch 0020 | Train Loss: 0.7531 | Val Loss: 0.7230 | Train AUC: 0.4544 | Val AUC: 0.4871
Epoch 0030 | Train Loss: 0.7531 | Val Loss: 0.7230 | Train AUC: 0.4558 | Val AUC: 0.4871
Epoch 0040 | Train Loss: 0.7531 | Val Loss: 0.7230 | Train AUC: 0.4548 | Val AUC: 0.4871
Epoch 0050 | Train Loss: 0.7531 | Val Loss: 0.7230 | Train AUC: 0.4577 | Val AUC: 0.4871
Epoch 0060 | Train Loss: 0.7531 | Val Loss: 0.7230 | Train AUC: 0.4561 | Val AUC: 0.4871
Epoch 0070 | Train Loss: 0.7531 | Val Loss: 0.7230 | Train AUC: 0.4575 | Val AUC: 0.4871
Epoch 0080 | Train Loss: 0.7531 | Val Loss: 0.7230 | Train AUC: 0.4537 | Val AUC: 0.4871


KeyboardInterrupt: 