In [15]:
import torch
import torch.nn.functional as F
import pandas as pd
import numpy as np
import pickle
from urllib.parse import unquote
from torch.utils.data import Dataset, random_split
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool, GraphNorm
from tqdm import tqdm
import random

from torch_geometric.nn import TransformerConv, global_mean_pool

# Load Data

In [16]:
# Load text data
data = pd.read_csv(f"../../data/full_text_data.csv")

In [17]:
# Load links
links = pd.read_csv("../../data/Wikispeedia/links.tsv", sep="\t", names=["src", "tgt"], skiprows=12)
links["src"] = links["src"].map(lambda x: unquote(x))
links["tgt"] = links["tgt"].map(lambda x: unquote(x))

# Create adjacency matrix
ordered_data_titles = data["title"].tolist()
src_indices = links["src"].map(lambda x: ordered_data_titles.index(x))
tgt_indices = links["tgt"].map(lambda x: ordered_data_titles.index(x))
A = torch.zeros((len(ordered_data_titles), len(ordered_data_titles)))
A[src_indices, tgt_indices] = 1

In [18]:
# Load coherence graph
with open("../../data/coherence_graph.pkl", 'rb') as handle:
    coherence_graph = pickle.load(handle)

# Combine coherence graph with base links
edge_features = A * coherence_graph

  edge_features = A * coherence_graph


In [19]:
# Load node embeddings
with open("../../data/gpt4_embeddings.pkl", 'rb') as handle:
    obj = pickle.load(handle)
    node_static_embeddings = obj["embeddings"]
    del obj
node_static_embeddings = torch.tensor(node_static_embeddings, dtype=torch.float)

  node_static_embeddings = torch.tensor(node_static_embeddings, dtype=torch.float)


In [20]:
# Load user-extracted paths
paths_data = pd.read_csv(f"../../data/paths_no_back_links.tsv", sep="\t")
paths_data = paths_data[~(paths_data["rating"].isna())]

# Filter paths with at least four distinct pages
paths_data = paths_data[paths_data["path"].apply(lambda x: len(set(x.split(";"))) >= 4)]

# Map titles to indices
title_to_index = {unquote(title): idx for idx, title in enumerate(data['title'])}
paths = paths_data['path'].apply(lambda path: [title_to_index[unquote(title)] for title in path.split(';')]).tolist()
ratings = (paths_data['rating'] - 1).tolist()  # 0-indexed ratings

In [21]:
class PathDataset(Dataset):
    def __init__(self, paths, ratings, node_embeddings, edge_features):
        self.paths = paths
        self.ratings = ratings
        self.node_embeddings = node_embeddings
        self.edge_features = edge_features

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        path = self.paths[idx]
        rating = self.ratings[idx]
        nodes, edge_index, edge_weight = self.get_subgraph_edges(path)

        x = self.node_embeddings[nodes]

        data = Data(
            x=x,
            edge_index=edge_index,
            edge_weight=edge_weight,
            y=torch.tensor([rating], dtype=torch.long)
        )
        return data

    def get_subgraph_edges(self, path):
        nodes = list(set(path))
        node_to_idx = {node: idx for idx, node in enumerate(nodes)}
        edges = []
        edge_weights = []
        for i in nodes:
            for j in nodes:
                weight = self.edge_features[i, j]
                if weight > 0:
                    edges.append([node_to_idx[i], node_to_idx[j]])
                    edge_weights.append(weight)
        if edges:
            edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
            edge_weight = torch.tensor(edge_weights, dtype=torch.float)
        else:
            edge_index = torch.empty((2, 0), dtype=torch.long)
            edge_weight = torch.tensor([], dtype=torch.float)
        return nodes, edge_index, edge_weight

# Create dataset
dataset = PathDataset(paths, ratings, node_static_embeddings, edge_features)

# Split dataset
train_ratio = 0.85
val_ratio = 0.05
test_ratio = 0.1
total_size = len(dataset)
train_size = int(train_ratio * total_size)
val_size = int(val_ratio * total_size)
test_size = total_size - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(
    dataset, [train_size, val_size, test_size], generator=torch.Generator().manual_seed(42)
)

# Create data loaders
batch_size = 6
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)



In [22]:
# Counts occurrences of each class
class_counts = torch.bincount(
    torch.tensor(ratings)[train_dataset.indices].to(torch.int64)
)

# Calculate weights as the inverse of class frequencies
class_weights = 1.0 / class_counts.float()

# Normalize the weights so that they sum to the number of classes
class_weights = class_weights / class_weights.sum() * len(class_counts)

class_weights

tensor([0.4479, 0.3379, 0.4016, 1.0386, 2.7740])

#  Model

In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import TransformerConv, global_mean_pool, GraphNorm

class GraphTransformerModel(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, num_classes, heads=4, dropout=0.2):
        super(GraphTransformerModel, self).__init__()
        self.conv1 = TransformerConv(
            in_channels, hidden_channels // heads, heads=heads, edge_dim=1, dropout=dropout)
        self.norm1 = GraphNorm(hidden_channels)
        self.conv2 = TransformerConv(
            hidden_channels, hidden_channels // heads, heads=heads, edge_dim=1, dropout=dropout)
        self.norm2 = GraphNorm(hidden_channels)
        self.classifier = torch.nn.Linear(hidden_channels, num_classes)
        self.dropout = dropout

    def forward(self, data):
        x, edge_index, edge_weight = data.x, data.edge_index, data.edge_weight

        # Reshape edge_weight to [num_edges, 1]
        if edge_weight is not None:
            edge_attr = edge_weight.view(-1, 1)
        else:
            edge_attr = None

        x = self.conv1(x, edge_index, edge_attr)
        x = self.norm1(x)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        x = self.conv2(x, edge_index, edge_attr)
        x = self.norm2(x)
        x = F.relu(x)

        x = global_mean_pool(x, batch=data.batch)

        x = self.classifier(x)
        return x  # Return raw logits

# Training

In [24]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class_weights = class_weights.to(device)

model = GraphTransformerModel(
    in_channels=node_static_embeddings.shape[1],
    hidden_channels=64,
    num_classes=5,
    heads=4,
    dropout=0.2
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss(weight=class_weights)

def train():
    model.train()
    total_loss = 0
    for data in tqdm(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data)
        loss = criterion(out, data.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * data.num_graphs
    return total_loss / len(train_loader.dataset)

def evaluate(loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            out = model(data)
            pred = out.argmax(dim=1)
            correct += (pred == data.y).sum().item()
            total += data.num_graphs
    return correct / total

best_val_acc = 0
for epoch in range(1, 201):
    loss = train()
    val_acc = evaluate(val_loader)
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        # Save the best model
        torch.save(model.state_dict(), 'best_model.pth')
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val Acc: {val_acc:.4f}')

100%|██████████| 3669/3669 [01:02<00:00, 58.34it/s] 


Epoch: 001, Loss: 1.5648, Val Acc: 0.3083


100%|██████████| 3669/3669 [01:05<00:00, 56.04it/s] 


Epoch: 002, Loss: 1.5523, Val Acc: 0.3354


100%|██████████| 3669/3669 [01:09<00:00, 52.45it/s] 


Epoch: 003, Loss: 1.5516, Val Acc: 0.2326


100%|██████████| 3669/3669 [01:13<00:00, 50.08it/s] 


Epoch: 004, Loss: 1.5515, Val Acc: 0.3122


100%|██████████| 3669/3669 [01:11<00:00, 51.25it/s] 


Epoch: 005, Loss: 1.5489, Val Acc: 0.2743


100%|██████████| 3669/3669 [00:58<00:00, 63.07it/s] 


Epoch: 006, Loss: 1.5483, Val Acc: 0.2883


100%|██████████| 3669/3669 [01:05<00:00, 56.25it/s] 


Epoch: 007, Loss: 1.5485, Val Acc: 0.2952


100%|██████████| 3669/3669 [01:14<00:00, 49.54it/s] 


Epoch: 008, Loss: 1.5482, Val Acc: 0.2867


100%|██████████| 3669/3669 [01:10<00:00, 51.80it/s] 


Epoch: 009, Loss: 1.5462, Val Acc: 0.2952


100%|██████████| 3669/3669 [00:51<00:00, 70.88it/s] 


Epoch: 010, Loss: 1.5459, Val Acc: 0.3060


100%|██████████| 3669/3669 [01:05<00:00, 55.91it/s] 


Epoch: 011, Loss: 1.5473, Val Acc: 0.2929


100%|██████████| 3669/3669 [01:11<00:00, 51.19it/s] 


Epoch: 012, Loss: 1.5469, Val Acc: 0.3138


100%|██████████| 3669/3669 [01:12<00:00, 50.78it/s] 


Epoch: 013, Loss: 1.5468, Val Acc: 0.2998


100%|██████████| 3669/3669 [01:11<00:00, 51.13it/s] 


Epoch: 014, Loss: 1.5475, Val Acc: 0.2890


100%|██████████| 3669/3669 [01:11<00:00, 51.22it/s] 


Epoch: 015, Loss: 1.5468, Val Acc: 0.3308


100%|██████████| 3669/3669 [01:04<00:00, 56.80it/s] 


Epoch: 016, Loss: 1.5474, Val Acc: 0.1994


100%|██████████| 3669/3669 [00:49<00:00, 74.03it/s] 


Epoch: 017, Loss: 1.5507, Val Acc: 0.2326


100%|██████████| 3669/3669 [00:52<00:00, 69.70it/s] 


Epoch: 018, Loss: 1.5468, Val Acc: 0.2759


100%|██████████| 3669/3669 [01:00<00:00, 60.52it/s] 


Epoch: 019, Loss: 1.5452, Val Acc: 0.3022


100%|██████████| 3669/3669 [00:42<00:00, 86.87it/s] 


Epoch: 020, Loss: 1.5460, Val Acc: 0.2883


100%|██████████| 3669/3669 [00:38<00:00, 96.52it/s] 


Epoch: 021, Loss: 1.5465, Val Acc: 0.2983


100%|██████████| 3669/3669 [00:38<00:00, 96.11it/s] 


Epoch: 022, Loss: 1.5472, Val Acc: 0.3107


100%|██████████| 3669/3669 [00:38<00:00, 96.22it/s] 


Epoch: 023, Loss: 1.5480, Val Acc: 0.2921


100%|██████████| 3669/3669 [00:37<00:00, 98.99it/s] 


Epoch: 024, Loss: 1.5459, Val Acc: 0.3168


100%|██████████| 3669/3669 [00:38<00:00, 95.59it/s] 


Epoch: 025, Loss: 1.5450, Val Acc: 0.3006


100%|██████████| 3669/3669 [00:38<00:00, 95.32it/s] 


Epoch: 026, Loss: 1.5472, Val Acc: 0.3199


100%|██████████| 3669/3669 [00:38<00:00, 95.99it/s] 


Epoch: 027, Loss: 1.5470, Val Acc: 0.2705


100%|██████████| 3669/3669 [00:37<00:00, 98.20it/s] 


Epoch: 028, Loss: 1.5456, Val Acc: 0.2666


100%|██████████| 3669/3669 [00:37<00:00, 98.46it/s] 


Epoch: 029, Loss: 1.5488, Val Acc: 0.2759


100%|██████████| 3669/3669 [00:37<00:00, 98.97it/s] 


Epoch: 030, Loss: 1.5487, Val Acc: 0.2975


100%|██████████| 3669/3669 [00:37<00:00, 97.24it/s] 


Epoch: 031, Loss: 1.5461, Val Acc: 0.2682


100%|██████████| 3669/3669 [00:36<00:00, 101.32it/s]


Epoch: 032, Loss: 1.5485, Val Acc: 0.2975


100%|██████████| 3669/3669 [00:37<00:00, 98.09it/s] 


Epoch: 033, Loss: 1.5472, Val Acc: 0.3431


100%|██████████| 3669/3669 [00:37<00:00, 97.55it/s] 


Epoch: 034, Loss: 1.5497, Val Acc: 0.2790


100%|██████████| 3669/3669 [00:38<00:00, 96.13it/s] 


Epoch: 035, Loss: 1.5457, Val Acc: 0.3014


100%|██████████| 3669/3669 [00:36<00:00, 99.70it/s] 


Epoch: 036, Loss: 1.5483, Val Acc: 0.2295


100%|██████████| 3669/3669 [00:37<00:00, 96.61it/s] 


Epoch: 037, Loss: 1.5470, Val Acc: 0.3130


100%|██████████| 3669/3669 [00:39<00:00, 93.71it/s] 


Epoch: 038, Loss: 1.5432, Val Acc: 0.2604


100%|██████████| 3669/3669 [00:36<00:00, 99.96it/s] 


Epoch: 039, Loss: 1.5468, Val Acc: 0.3292


100%|██████████| 3669/3669 [00:49<00:00, 73.70it/s] 


Epoch: 040, Loss: 1.5478, Val Acc: 0.2450


100%|██████████| 3669/3669 [00:52<00:00, 69.80it/s] 


Epoch: 041, Loss: 1.5470, Val Acc: 0.2496


100%|██████████| 3669/3669 [00:45<00:00, 79.98it/s] 


Epoch: 042, Loss: 1.5490, Val Acc: 0.3138


100%|██████████| 3669/3669 [00:36<00:00, 100.88it/s]


Epoch: 043, Loss: 1.5472, Val Acc: 0.2944


100%|██████████| 3669/3669 [00:37<00:00, 96.78it/s] 


Epoch: 044, Loss: 1.5495, Val Acc: 0.2689


100%|██████████| 3669/3669 [00:37<00:00, 97.75it/s] 


Epoch: 045, Loss: 1.5510, Val Acc: 0.3029


100%|██████████| 3669/3669 [00:38<00:00, 96.45it/s] 


Epoch: 046, Loss: 1.5484, Val Acc: 0.3199


100%|██████████| 3669/3669 [00:36<00:00, 99.82it/s] 


Epoch: 047, Loss: 1.5481, Val Acc: 0.3192


100%|██████████| 3669/3669 [00:37<00:00, 96.58it/s] 


Epoch: 048, Loss: 1.5471, Val Acc: 0.2867


100%|██████████| 3669/3669 [00:37<00:00, 96.59it/s] 


Epoch: 049, Loss: 1.5471, Val Acc: 0.3238


100%|██████████| 3669/3669 [00:38<00:00, 95.64it/s] 


Epoch: 050, Loss: 1.5440, Val Acc: 0.2257


100%|██████████| 3669/3669 [00:37<00:00, 97.74it/s] 


Epoch: 051, Loss: 1.5457, Val Acc: 0.3161


100%|██████████| 3669/3669 [00:37<00:00, 96.79it/s] 


Epoch: 052, Loss: 1.5460, Val Acc: 0.1399


100%|██████████| 3669/3669 [00:37<00:00, 96.57it/s] 


Epoch: 053, Loss: 1.5460, Val Acc: 0.2859


100%|██████████| 3669/3669 [00:36<00:00, 100.15it/s]


Epoch: 054, Loss: 1.5454, Val Acc: 0.3825


100%|██████████| 3669/3669 [00:38<00:00, 96.48it/s] 


Epoch: 055, Loss: 1.5465, Val Acc: 0.2937


100%|██████████| 3669/3669 [00:40<00:00, 89.67it/s] 


Epoch: 056, Loss: 1.5429, Val Acc: 0.2535


100%|██████████| 3669/3669 [00:33<00:00, 108.97it/s]


Epoch: 057, Loss: 1.5437, Val Acc: 0.3354


100%|██████████| 3669/3669 [00:49<00:00, 74.55it/s] 


Epoch: 058, Loss: 1.5507, Val Acc: 0.2566


100%|██████████| 3669/3669 [00:55<00:00, 65.76it/s] 


Epoch: 059, Loss: 1.5469, Val Acc: 0.1051


100%|██████████| 3669/3669 [00:37<00:00, 98.99it/s] 


Epoch: 060, Loss: 1.5506, Val Acc: 0.3083


100%|██████████| 3669/3669 [00:36<00:00, 99.37it/s] 


Epoch: 061, Loss: 1.5465, Val Acc: 0.2937


100%|██████████| 3669/3669 [00:37<00:00, 99.13it/s] 


Epoch: 062, Loss: 1.5467, Val Acc: 0.3354


100%|██████████| 3669/3669 [00:36<00:00, 99.82it/s] 


Epoch: 063, Loss: 1.5470, Val Acc: 0.2852


100%|██████████| 3669/3669 [00:36<00:00, 99.66it/s] 


Epoch: 064, Loss: 1.5471, Val Acc: 0.2890


100%|██████████| 3669/3669 [00:37<00:00, 98.98it/s] 


Epoch: 065, Loss: 1.5482, Val Acc: 0.2782


100%|██████████| 3669/3669 [00:36<00:00, 99.93it/s] 


Epoch: 066, Loss: 1.5513, Val Acc: 0.2210


100%|██████████| 3669/3669 [00:35<00:00, 103.46it/s]


Epoch: 067, Loss: 1.5536, Val Acc: 0.2774


100%|██████████| 3669/3669 [00:36<00:00, 100.24it/s]


Epoch: 068, Loss: 1.5566, Val Acc: 0.2929


100%|██████████| 3669/3669 [00:36<00:00, 100.60it/s]


Epoch: 069, Loss: 1.5523, Val Acc: 0.2342


100%|██████████| 3669/3669 [00:36<00:00, 100.18it/s]


Epoch: 070, Loss: 1.5492, Val Acc: 0.2782


100%|██████████| 3669/3669 [00:36<00:00, 100.39it/s]


Epoch: 071, Loss: 1.5494, Val Acc: 0.3431


100%|██████████| 3669/3669 [00:36<00:00, 99.44it/s] 


Epoch: 072, Loss: 1.5457, Val Acc: 0.2821


100%|██████████| 3669/3669 [00:37<00:00, 98.23it/s] 


Epoch: 073, Loss: 1.5451, Val Acc: 0.2813


100%|██████████| 3669/3669 [00:36<00:00, 99.73it/s] 


Epoch: 074, Loss: 1.5477, Val Acc: 0.2195


100%|██████████| 3669/3669 [00:36<00:00, 99.94it/s] 


Epoch: 075, Loss: 1.5487, Val Acc: 0.3060


100%|██████████| 3669/3669 [00:35<00:00, 102.93it/s]


Epoch: 076, Loss: 1.5493, Val Acc: 0.2488


100%|██████████| 3669/3669 [00:36<00:00, 99.40it/s] 


Epoch: 077, Loss: 1.5474, Val Acc: 0.3122


100%|██████████| 3669/3669 [00:36<00:00, 99.70it/s] 


Epoch: 078, Loss: 1.5492, Val Acc: 0.2403


100%|██████████| 3669/3669 [00:36<00:00, 99.59it/s] 


Epoch: 079, Loss: 1.5460, Val Acc: 0.2828


100%|██████████| 3669/3669 [00:36<00:00, 99.37it/s] 


Epoch: 080, Loss: 1.5492, Val Acc: 0.1893


100%|██████████| 3669/3669 [00:36<00:00, 99.49it/s] 


Epoch: 081, Loss: 1.5491, Val Acc: 0.2334


100%|██████████| 3669/3669 [00:36<00:00, 99.98it/s] 


Epoch: 082, Loss: 1.5478, Val Acc: 0.1816


100%|██████████| 3669/3669 [00:36<00:00, 99.75it/s] 


Epoch: 083, Loss: 1.5461, Val Acc: 0.2852


100%|██████████| 3669/3669 [00:37<00:00, 98.96it/s] 


Epoch: 084, Loss: 1.5475, Val Acc: 0.2550


100%|██████████| 3669/3669 [00:35<00:00, 102.17it/s]


Epoch: 085, Loss: 1.5483, Val Acc: 0.3192


100%|██████████| 3669/3669 [00:37<00:00, 98.71it/s] 


Epoch: 086, Loss: 1.5498, Val Acc: 0.2318


100%|██████████| 3669/3669 [00:37<00:00, 98.01it/s] 


Epoch: 087, Loss: 1.5446, Val Acc: 0.3215


100%|██████████| 3669/3669 [00:42<00:00, 87.02it/s] 


Epoch: 088, Loss: 1.5445, Val Acc: 0.3114


100%|██████████| 3669/3669 [01:13<00:00, 50.12it/s] 


Epoch: 089, Loss: 1.5478, Val Acc: 0.3516


100%|██████████| 3669/3669 [00:51<00:00, 71.73it/s] 


Epoch: 090, Loss: 1.5482, Val Acc: 0.2867


100%|██████████| 3669/3669 [00:59<00:00, 61.17it/s] 


Epoch: 091, Loss: 1.5484, Val Acc: 0.3122


100%|██████████| 3669/3669 [01:05<00:00, 55.80it/s] 


Epoch: 092, Loss: 1.5457, Val Acc: 0.3114


100%|██████████| 3669/3669 [01:04<00:00, 56.69it/s] 


Epoch: 093, Loss: 1.5476, Val Acc: 0.2249


100%|██████████| 3669/3669 [01:10<00:00, 52.30it/s] 


Epoch: 094, Loss: 1.5414, Val Acc: 0.2017


100%|██████████| 3669/3669 [01:08<00:00, 53.51it/s] 


Epoch: 095, Loss: 1.5472, Val Acc: 0.2859


100%|██████████| 3669/3669 [01:10<00:00, 52.11it/s] 


Epoch: 096, Loss: 1.5483, Val Acc: 0.2388


100%|██████████| 3669/3669 [01:10<00:00, 51.84it/s] 


Epoch: 097, Loss: 1.5458, Val Acc: 0.2002


100%|██████████| 3669/3669 [01:02<00:00, 59.02it/s] 


Epoch: 098, Loss: 1.5463, Val Acc: 0.2403


100%|██████████| 3669/3669 [01:11<00:00, 51.38it/s] 


Epoch: 099, Loss: 1.5491, Val Acc: 0.2257


100%|██████████| 3669/3669 [01:11<00:00, 51.23it/s] 


Epoch: 100, Loss: 1.5489, Val Acc: 0.3037


100%|██████████| 3669/3669 [01:08<00:00, 53.80it/s] 


Epoch: 101, Loss: 1.5485, Val Acc: 0.2906


  9%|▉         | 347/3669 [00:07<01:07, 49.31it/s]


KeyboardInterrupt: 

In [None]:
# Load the best model
model.load_state_dict(torch.load('best_model.pth'))

test_acc = evaluate(test_loader)
print(f'Test Accuracy: {test_acc:.4f}')

  model.load_state_dict(torch.load('best_model.pth'))


Test Accuracy: 0.3201
