In [40]:
import torch
import torch.nn.functional as F
import pandas as pd
import numpy as np
import pickle
from urllib.parse import unquote
from torch.utils.data import Dataset, random_split
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool, GraphNorm
from tqdm import tqdm
import random
from torch_geometric.nn.models import GCN

# Load Data

In [41]:
# Load text data
data = pd.read_csv(f"../../data/full_text_data.csv")

In [42]:
# Load links
links = pd.read_csv("../../data/Wikispeedia/links.tsv", sep="\t", names=["src", "tgt"], skiprows=12)
links["src"] = links["src"].map(lambda x: unquote(x))
links["tgt"] = links["tgt"].map(lambda x: unquote(x))

# Create adjacency matrix
ordered_data_titles = data["title"].tolist()
src_indices = links["src"].map(lambda x: ordered_data_titles.index(x))
tgt_indices = links["tgt"].map(lambda x: ordered_data_titles.index(x))
A = torch.zeros((len(ordered_data_titles), len(ordered_data_titles)))
A[src_indices, tgt_indices] = 1

In [43]:
# Load coherence graph
with open("../../data/coherence_graph.pkl", 'rb') as handle:
    coherence_graph = pickle.load(handle)

# Combine coherence graph with base links
edge_features = A * coherence_graph

In [44]:
# Load node embeddings
with open("../../data/gpt4_embeddings.pkl", 'rb') as handle:
    obj = pickle.load(handle)
    node_static_embeddings = obj["embeddings"]
    del obj
node_static_embeddings = torch.tensor(node_static_embeddings, dtype=torch.float)

  node_static_embeddings = torch.tensor(node_static_embeddings, dtype=torch.float)


In [45]:
# Load user-extracted paths
paths_data = pd.read_csv(f"../../data/paths_no_back_links.tsv", sep="\t")
paths_data = paths_data[~(paths_data["rating"].isna())]

# Filter paths with at least four distinct pages
paths_data = paths_data[paths_data["path"].apply(lambda x: len(set(x.split(";"))) >= 4)]

# Map titles to indices
title_to_index = {unquote(title): idx for idx, title in enumerate(data['title'])}
paths = paths_data['path'].apply(lambda path: [title_to_index[unquote(title)] for title in path.split(';')]).tolist()
ratings = (paths_data['rating'] - 1).tolist()  # 0-indexed ratings

In [46]:
class PathDataset(Dataset):
    def __init__(self, paths, ratings, node_embeddings, edge_features):
        self.paths = paths
        self.ratings = ratings
        self.node_embeddings = node_embeddings
        self.edge_features = edge_features

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        path = self.paths[idx]
        rating = self.ratings[idx]
        nodes = list(set(path))
        x = self.node_embeddings[nodes]

        # Create edge index for subgraph
        edge_index = self.get_subgraph_edges(nodes)

        data = Data(
            x=x,
            edge_index=edge_index,
            y=torch.tensor([rating], dtype=torch.long)
        )
        return data

    def get_subgraph_edges(self, nodes):
        node_set = set(nodes)
        edges = []
        for i in node_set:
            for j in node_set:
                if self.edge_features[i, j] > 0:
                    edges.append([nodes.index(i), nodes.index(j)])
        if edges:
            edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
        else:
            edge_index = torch.empty((2, 0), dtype=torch.long)
        return edge_index

# Create dataset
dataset = PathDataset(paths, ratings, node_static_embeddings, edge_features)

# Split dataset
train_ratio = 0.85
val_ratio = 0.05
test_ratio = 0.1
total_size = len(dataset)
train_size = int(train_ratio * total_size)
val_size = int(val_ratio * total_size)
test_size = total_size - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(
    dataset, [train_size, val_size, test_size], generator=torch.Generator().manual_seed(42)
)

# Create data loaders
batch_size = 6
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)



# GCN model

In [47]:
class GCNModel(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_classes, dropout=0.2):
        super(GCNModel, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.norm1 = GraphNorm(hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)
        self.norm2 = GraphNorm(out_channels)
        self.classifier = torch.nn.Linear(out_channels, num_classes)
        self.dropout = dropout

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index)
        x = self.norm1(x)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        x = self.conv2(x, edge_index)
        x = self.norm2(x)
        x = F.relu(x)

        x = global_mean_pool(x, batch=data.batch)

        x = self.classifier(x)
        return F.log_softmax(x, dim=1)

# Training

In [48]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCNModel(
    in_channels=node_static_embeddings.shape[1],
    hidden_channels=64,
    out_channels=32,
    num_classes=5  # Assuming ratings are from 1 to 5
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
criterion = torch.nn.NLLLoss()

def train():
    model.train()
    total_loss = 0
    for data in tqdm(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data)
        loss = criterion(out, data.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * data.num_graphs
    return total_loss / len(train_loader.dataset)

def evaluate(loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            out = model(data)
            pred = out.argmax(dim=1)
            correct += (pred == data.y).sum().item()
            total += data.num_graphs
    return correct / total

best_val_acc = 0
for epoch in range(1, 201):
    loss = train()
    val_acc = evaluate(val_loader)
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        # Save the best model
        torch.save(model.state_dict(), 'best_model.pth')
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val Acc: {val_acc:.4f}')

100%|██████████| 3669/3669 [00:14<00:00, 254.59it/s]


Epoch: 001, Loss: 1.4356, Val Acc: 0.3037


100%|██████████| 3669/3669 [00:15<00:00, 238.67it/s]


Epoch: 002, Loss: 1.4327, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:17<00:00, 208.52it/s]


Epoch: 003, Loss: 1.4316, Val Acc: 0.3045


100%|██████████| 3669/3669 [00:15<00:00, 236.48it/s]


Epoch: 004, Loss: 1.4328, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:16<00:00, 227.59it/s]


Epoch: 005, Loss: 1.4369, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:15<00:00, 238.58it/s]


Epoch: 006, Loss: 1.4369, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:13<00:00, 264.51it/s]


Epoch: 007, Loss: 1.4372, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:13<00:00, 274.02it/s]


Epoch: 008, Loss: 1.4370, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:13<00:00, 276.24it/s]


Epoch: 009, Loss: 1.4369, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:15<00:00, 238.04it/s]


Epoch: 010, Loss: 1.4370, Val Acc: 0.2898


100%|██████████| 3669/3669 [00:14<00:00, 260.82it/s]


Epoch: 011, Loss: 1.4368, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:13<00:00, 265.57it/s]


Epoch: 012, Loss: 1.4370, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:13<00:00, 277.66it/s]


Epoch: 013, Loss: 1.4363, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:14<00:00, 244.88it/s]


Epoch: 014, Loss: 1.4366, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:14<00:00, 255.96it/s]


Epoch: 015, Loss: 1.4367, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:13<00:00, 262.41it/s]


Epoch: 016, Loss: 1.4367, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:13<00:00, 266.77it/s]


Epoch: 017, Loss: 1.4370, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:13<00:00, 269.29it/s]


Epoch: 018, Loss: 1.4365, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:14<00:00, 257.93it/s]


Epoch: 019, Loss: 1.4366, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:14<00:00, 247.50it/s]


Epoch: 020, Loss: 1.4371, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:14<00:00, 250.81it/s]


Epoch: 021, Loss: 1.4366, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:15<00:00, 239.38it/s]


Epoch: 022, Loss: 1.4365, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:15<00:00, 237.33it/s]


Epoch: 023, Loss: 1.4371, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:14<00:00, 258.51it/s]


Epoch: 024, Loss: 1.4367, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:14<00:00, 260.08it/s]


Epoch: 025, Loss: 1.4370, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:13<00:00, 269.93it/s]


Epoch: 026, Loss: 1.4365, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:13<00:00, 275.10it/s]


Epoch: 027, Loss: 1.4369, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:13<00:00, 275.50it/s]


Epoch: 028, Loss: 1.4359, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:13<00:00, 279.47it/s]


Epoch: 029, Loss: 1.4369, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:13<00:00, 280.40it/s]


Epoch: 030, Loss: 1.4368, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:13<00:00, 281.37it/s]


Epoch: 031, Loss: 1.4366, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:13<00:00, 279.06it/s]


Epoch: 032, Loss: 1.4368, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:13<00:00, 277.71it/s]


Epoch: 033, Loss: 1.4365, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:13<00:00, 281.92it/s]


Epoch: 034, Loss: 1.4362, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:13<00:00, 271.91it/s]


Epoch: 035, Loss: 1.4366, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:13<00:00, 266.27it/s]


Epoch: 036, Loss: 1.4368, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:13<00:00, 274.52it/s]


Epoch: 037, Loss: 1.4367, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:13<00:00, 280.72it/s]


Epoch: 038, Loss: 1.4369, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:13<00:00, 275.03it/s]


Epoch: 039, Loss: 1.4369, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:13<00:00, 279.17it/s]


Epoch: 040, Loss: 1.4367, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:13<00:00, 277.48it/s]


Epoch: 041, Loss: 1.4369, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:13<00:00, 281.32it/s]


Epoch: 042, Loss: 1.4369, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:13<00:00, 279.21it/s]


Epoch: 043, Loss: 1.4366, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:13<00:00, 275.70it/s]


Epoch: 044, Loss: 1.4369, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:13<00:00, 278.89it/s]


Epoch: 045, Loss: 1.4370, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:14<00:00, 244.87it/s]


Epoch: 046, Loss: 1.4366, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:13<00:00, 272.71it/s]


Epoch: 047, Loss: 1.4362, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:13<00:00, 270.87it/s]


Epoch: 048, Loss: 1.4365, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:14<00:00, 246.80it/s]


Epoch: 049, Loss: 1.4368, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:13<00:00, 274.86it/s]


Epoch: 050, Loss: 1.4369, Val Acc: 0.2898


100%|██████████| 3669/3669 [00:13<00:00, 274.93it/s]


Epoch: 051, Loss: 1.4365, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:13<00:00, 269.36it/s]


Epoch: 052, Loss: 1.4367, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:13<00:00, 278.28it/s]


Epoch: 053, Loss: 1.4370, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:13<00:00, 279.67it/s]


Epoch: 054, Loss: 1.4367, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:13<00:00, 272.60it/s]


Epoch: 055, Loss: 1.4370, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:13<00:00, 279.21it/s]


Epoch: 056, Loss: 1.4365, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:13<00:00, 279.53it/s]


Epoch: 057, Loss: 1.4367, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:13<00:00, 279.05it/s]


Epoch: 058, Loss: 1.4365, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:13<00:00, 279.23it/s]


Epoch: 059, Loss: 1.4367, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:13<00:00, 277.46it/s]


Epoch: 060, Loss: 1.4368, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:13<00:00, 277.46it/s]


Epoch: 061, Loss: 1.4366, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:14<00:00, 255.26it/s]


Epoch: 062, Loss: 1.4370, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:14<00:00, 253.89it/s]


Epoch: 063, Loss: 1.4368, Val Acc: 0.3184


 27%|██▋       | 999/3669 [00:04<00:11, 242.38it/s]


KeyboardInterrupt: 

In [None]:
# Load the best model
model.load_state_dict(torch.load('best_model.pth'))

test_acc = evaluate(test_loader)
print(f'Test Accuracy: {test_acc:.4f}')