In [1]:
import torch
import torch.nn.functional as F
import pandas as pd
import numpy as np
import pickle
from urllib.parse import unquote
from torch.utils.data import Dataset, random_split
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool, GraphNorm
from tqdm import tqdm
import random

from torch_geometric.nn import TransformerConv, global_mean_pool

# Load Data

In [2]:
# Load text data
data = pd.read_csv(f"../../data/full_text_data.csv")

In [3]:
# Load links
links = pd.read_csv("../../data/Wikispeedia/links.tsv", sep="\t", names=["src", "tgt"], skiprows=12)
links["src"] = links["src"].map(lambda x: unquote(x))
links["tgt"] = links["tgt"].map(lambda x: unquote(x))

# Create adjacency matrix
ordered_data_titles = data["title"].tolist()
src_indices = links["src"].map(lambda x: ordered_data_titles.index(x))
tgt_indices = links["tgt"].map(lambda x: ordered_data_titles.index(x))
A = torch.zeros((len(ordered_data_titles), len(ordered_data_titles)))
A[src_indices, tgt_indices] = 1

In [4]:
# Load coherence graph
with open("../../data/coherence_graph.pkl", 'rb') as handle:
    coherence_graph = pickle.load(handle)

# Combine coherence graph with base links
edge_features = A * coherence_graph

  edge_features = A * coherence_graph


In [5]:
# Load node embeddings
with open("../../data/gpt4_embeddings.pkl", 'rb') as handle:
    obj = pickle.load(handle)
    node_static_embeddings = obj["embeddings"]
    del obj
node_static_embeddings = torch.tensor(node_static_embeddings, dtype=torch.float)

  node_static_embeddings = torch.tensor(node_static_embeddings, dtype=torch.float)


In [6]:
# Load user-extracted paths
paths_data = pd.read_csv(f"../../data/paths_no_back_links.tsv", sep="\t")
paths_data = paths_data[~(paths_data["rating"].isna())]

# Filter paths with at least four distinct pages
paths_data = paths_data[paths_data["path"].apply(lambda x: len(set(x.split(";"))) >= 4)]

# Map titles to indices
title_to_index = {unquote(title): idx for idx, title in enumerate(data['title'])}
paths = paths_data['path'].apply(lambda path: [title_to_index[unquote(title)] for title in path.split(';')]).tolist()
ratings = (paths_data['rating'] - 1).tolist()  # 0-indexed ratings

In [7]:
class PathDataset(Dataset):
    def __init__(self, paths, ratings, node_embeddings, edge_features):
        self.paths = paths
        self.ratings = ratings
        self.node_embeddings = node_embeddings
        self.edge_features = edge_features

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        path = self.paths[idx]
        rating = self.ratings[idx]
        nodes, edge_index, edge_weight = self.get_subgraph_edges(path)

        x = self.node_embeddings[nodes]

        data = Data(
            x=x,
            edge_index=edge_index,
            edge_weight=edge_weight,
            y=torch.tensor([rating], dtype=torch.long)
        )
        return data

    def get_subgraph_edges(self, path):
        nodes = list(set(path))
        node_to_idx = {node: idx for idx, node in enumerate(nodes)}
        edges = []
        edge_weights = []
        for i in nodes:
            for j in nodes:
                weight = self.edge_features[i, j]
                if weight > 0:
                    edges.append([node_to_idx[i], node_to_idx[j]])
                    edge_weights.append(weight)
        if edges:
            edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
            edge_weight = torch.tensor(edge_weights, dtype=torch.float)
        else:
            edge_index = torch.empty((2, 0), dtype=torch.long)
            edge_weight = torch.tensor([], dtype=torch.float)
        return nodes, edge_index, edge_weight

# Create dataset
dataset = PathDataset(paths, ratings, node_static_embeddings, edge_features)

# Split dataset
train_ratio = 0.85
val_ratio = 0.05
test_ratio = 0.1
total_size = len(dataset)
train_size = int(train_ratio * total_size)
val_size = int(val_ratio * total_size)
test_size = total_size - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(
    dataset, [train_size, val_size, test_size], generator=torch.Generator().manual_seed(42)
)

# Create data loaders
batch_size = 6
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)



In [8]:
# Counts occurrences of each class
class_counts = torch.bincount(
    torch.tensor(ratings)[train_dataset.indices].to(torch.int64)
)

# Calculate weights as the inverse of class frequencies
class_weights = 1.0 / class_counts.float()

# Normalize the weights so that they sum to the number of classes
class_weights = class_weights / class_weights.sum() * len(class_counts)

class_weights

tensor([0.4479, 0.3379, 0.4016, 1.0386, 2.7740])

#  Model

In [9]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import TransformerConv, global_mean_pool, GraphNorm

class GraphTransformerModel(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, num_classes, heads=4, dropout=0.2):
        super(GraphTransformerModel, self).__init__()
        self.conv1 = TransformerConv(
            in_channels, hidden_channels // heads, heads=heads, edge_dim=1, dropout=dropout)
        self.norm1 = GraphNorm(hidden_channels)
        self.conv2 = TransformerConv(
            hidden_channels, hidden_channels // heads, heads=heads, edge_dim=1, dropout=dropout)
        self.norm2 = GraphNorm(hidden_channels)
        self.classifier = torch.nn.Linear(hidden_channels, num_classes)
        self.dropout = dropout

    def forward(self, data):
        x, edge_index, edge_weight = data.x, data.edge_index, data.edge_weight

        # Reshape edge_weight to [num_edges, 1]
        if edge_weight is not None:
            edge_attr = edge_weight.view(-1, 1)
        else:
            edge_attr = None

        x = self.conv1(x, edge_index, edge_attr)
        x = self.norm1(x)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        x = self.conv2(x, edge_index, edge_attr)
        x = self.norm2(x)
        x = F.relu(x)

        x = global_mean_pool(x, batch=data.batch)

        x = self.classifier(x)
        return x  # Return raw logits

# Training

In [10]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class_weights = class_weights.to(device)

model = GraphTransformerModel(
    in_channels=node_static_embeddings.shape[1],
    hidden_channels=64,
    num_classes=5,
    heads=4,
    dropout=0.2
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss(weight=class_weights)

def train():
    model.train()
    total_loss = 0
    for data in tqdm(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data)
        loss = criterion(out, data.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * data.num_graphs
    return total_loss / len(train_loader.dataset)

def evaluate(loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            out = model(data)
            pred = out.argmax(dim=1)
            correct += (pred == data.y).sum().item()
            total += data.num_graphs
    return correct / total

best_val_acc = 0
for epoch in range(1, 201):
    loss = train()
    val_acc = evaluate(val_loader)
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        # Save the best model
        torch.save(model.state_dict(), 'best_model.pth')
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val Acc: {val_acc:.4f}')

  y=torch.tensor([rating], dtype=torch.long)
100%|██████████| 3669/3669 [00:18<00:00, 197.30it/s]


Epoch: 001, Loss: 1.5618, Val Acc: 0.2728


100%|██████████| 3669/3669 [00:18<00:00, 199.26it/s]


Epoch: 002, Loss: 1.5539, Val Acc: 0.3331


100%|██████████| 3669/3669 [00:18<00:00, 203.26it/s]


Epoch: 003, Loss: 1.5548, Val Acc: 0.3076


100%|██████████| 3669/3669 [00:18<00:00, 201.05it/s]


Epoch: 004, Loss: 1.5512, Val Acc: 0.2998


100%|██████████| 3669/3669 [00:18<00:00, 199.88it/s]


Epoch: 005, Loss: 1.5515, Val Acc: 0.2357


100%|██████████| 3669/3669 [00:19<00:00, 190.87it/s]


Epoch: 006, Loss: 1.5504, Val Acc: 0.2921


100%|██████████| 3669/3669 [00:19<00:00, 192.86it/s]


Epoch: 007, Loss: 1.5512, Val Acc: 0.2852


100%|██████████| 3669/3669 [00:18<00:00, 202.66it/s]


Epoch: 008, Loss: 1.5537, Val Acc: 0.2991


100%|██████████| 3669/3669 [00:18<00:00, 202.03it/s]


Epoch: 009, Loss: 1.5489, Val Acc: 0.2998


100%|██████████| 3669/3669 [00:18<00:00, 203.34it/s]


Epoch: 010, Loss: 1.5531, Val Acc: 0.2597


100%|██████████| 3669/3669 [00:18<00:00, 203.03it/s]


Epoch: 011, Loss: 1.5498, Val Acc: 0.1924


100%|██████████| 3669/3669 [00:18<00:00, 200.77it/s]


Epoch: 012, Loss: 1.5491, Val Acc: 0.2689


100%|██████████| 3669/3669 [00:18<00:00, 201.95it/s]


Epoch: 013, Loss: 1.5490, Val Acc: 0.3122


100%|██████████| 3669/3669 [00:18<00:00, 198.61it/s]


Epoch: 014, Loss: 1.5487, Val Acc: 0.3037


100%|██████████| 3669/3669 [00:18<00:00, 196.45it/s]


Epoch: 015, Loss: 1.5489, Val Acc: 0.2303


100%|██████████| 3669/3669 [00:18<00:00, 201.56it/s]


Epoch: 016, Loss: 1.5472, Val Acc: 0.3029


100%|██████████| 3669/3669 [00:18<00:00, 200.40it/s]


Epoch: 017, Loss: 1.5465, Val Acc: 0.2983


100%|██████████| 3669/3669 [00:17<00:00, 203.87it/s]


Epoch: 018, Loss: 1.5499, Val Acc: 0.2821


100%|██████████| 3669/3669 [00:18<00:00, 202.06it/s]


Epoch: 019, Loss: 1.5461, Val Acc: 0.2581


100%|██████████| 3669/3669 [00:18<00:00, 197.06it/s]


Epoch: 020, Loss: 1.5483, Val Acc: 0.2202


100%|██████████| 3669/3669 [00:18<00:00, 194.70it/s]


Epoch: 021, Loss: 1.5481, Val Acc: 0.3006


100%|██████████| 3669/3669 [00:19<00:00, 184.83it/s]


Epoch: 022, Loss: 1.5469, Val Acc: 0.3168


100%|██████████| 3669/3669 [00:18<00:00, 203.28it/s]


Epoch: 023, Loss: 1.5504, Val Acc: 0.2898


100%|██████████| 3669/3669 [00:18<00:00, 199.62it/s]


Epoch: 024, Loss: 1.5508, Val Acc: 0.2543


100%|██████████| 3669/3669 [00:18<00:00, 195.40it/s]


Epoch: 025, Loss: 1.5502, Val Acc: 0.2759


100%|██████████| 3669/3669 [00:18<00:00, 193.83it/s]


Epoch: 026, Loss: 1.5477, Val Acc: 0.2604


100%|██████████| 3669/3669 [00:19<00:00, 193.00it/s]


Epoch: 027, Loss: 1.5511, Val Acc: 0.2604


100%|██████████| 3669/3669 [00:18<00:00, 194.13it/s]


Epoch: 028, Loss: 1.5476, Val Acc: 0.2581


100%|██████████| 3669/3669 [00:18<00:00, 198.46it/s]


Epoch: 029, Loss: 1.5469, Val Acc: 0.3323


100%|██████████| 3669/3669 [00:18<00:00, 197.55it/s]


Epoch: 030, Loss: 1.5508, Val Acc: 0.3308


100%|██████████| 3669/3669 [00:18<00:00, 198.73it/s]


Epoch: 031, Loss: 1.5487, Val Acc: 0.2813


100%|██████████| 3669/3669 [00:18<00:00, 194.98it/s]


Epoch: 032, Loss: 1.5471, Val Acc: 0.2759


100%|██████████| 3669/3669 [00:18<00:00, 196.91it/s]


Epoch: 033, Loss: 1.5499, Val Acc: 0.3037


100%|██████████| 3669/3669 [00:18<00:00, 198.80it/s]


Epoch: 034, Loss: 1.5515, Val Acc: 0.2651


100%|██████████| 3669/3669 [00:18<00:00, 197.89it/s]


Epoch: 035, Loss: 1.5490, Val Acc: 0.3037


100%|██████████| 3669/3669 [00:18<00:00, 199.66it/s]


Epoch: 036, Loss: 1.5516, Val Acc: 0.3037


100%|██████████| 3669/3669 [00:18<00:00, 194.10it/s]


Epoch: 037, Loss: 1.5487, Val Acc: 0.2944


100%|██████████| 3669/3669 [00:18<00:00, 201.71it/s]


Epoch: 038, Loss: 1.5464, Val Acc: 0.3060


100%|██████████| 3669/3669 [00:18<00:00, 196.59it/s]


Epoch: 039, Loss: 1.5507, Val Acc: 0.2048


100%|██████████| 3669/3669 [00:17<00:00, 208.07it/s]


Epoch: 040, Loss: 1.5481, Val Acc: 0.2612


100%|██████████| 3669/3669 [00:17<00:00, 208.68it/s]


Epoch: 041, Loss: 1.5495, Val Acc: 0.2921


100%|██████████| 3669/3669 [00:17<00:00, 208.94it/s]


Epoch: 042, Loss: 1.5491, Val Acc: 0.3207


100%|██████████| 3669/3669 [00:18<00:00, 197.80it/s]


Epoch: 043, Loss: 1.5471, Val Acc: 0.2921


100%|██████████| 3669/3669 [00:19<00:00, 187.82it/s]


Epoch: 044, Loss: 1.5479, Val Acc: 0.3022


100%|██████████| 3669/3669 [00:22<00:00, 161.44it/s]


Epoch: 045, Loss: 1.5454, Val Acc: 0.3107


100%|██████████| 3669/3669 [00:22<00:00, 165.06it/s]


Epoch: 046, Loss: 1.5443, Val Acc: 0.2913


100%|██████████| 3669/3669 [00:23<00:00, 152.91it/s]


Epoch: 047, Loss: 1.5507, Val Acc: 0.3338


100%|██████████| 3669/3669 [00:21<00:00, 170.88it/s]


Epoch: 048, Loss: 1.5482, Val Acc: 0.2821


100%|██████████| 3669/3669 [00:21<00:00, 171.37it/s]


Epoch: 049, Loss: 1.5484, Val Acc: 0.2342


100%|██████████| 3669/3669 [00:18<00:00, 193.90it/s]


Epoch: 050, Loss: 1.5494, Val Acc: 0.2651


100%|██████████| 3669/3669 [00:20<00:00, 181.04it/s]


Epoch: 051, Loss: 1.5481, Val Acc: 0.2002


100%|██████████| 3669/3669 [00:18<00:00, 198.61it/s]


Epoch: 052, Loss: 1.5514, Val Acc: 0.2357


100%|██████████| 3669/3669 [00:17<00:00, 205.00it/s]


Epoch: 053, Loss: 1.5487, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:18<00:00, 203.52it/s]


Epoch: 054, Loss: 1.5504, Val Acc: 0.2303


100%|██████████| 3669/3669 [00:18<00:00, 200.92it/s]


Epoch: 055, Loss: 1.5461, Val Acc: 0.2396


100%|██████████| 3669/3669 [00:18<00:00, 202.68it/s]


Epoch: 056, Loss: 1.5449, Val Acc: 0.3114


100%|██████████| 3669/3669 [00:18<00:00, 203.22it/s]


Epoch: 057, Loss: 1.5439, Val Acc: 0.3138


100%|██████████| 3669/3669 [00:18<00:00, 201.13it/s]


Epoch: 058, Loss: 1.5463, Val Acc: 0.1553


100%|██████████| 3669/3669 [00:18<00:00, 202.44it/s]


Epoch: 059, Loss: 1.5380, Val Acc: 0.2697


100%|██████████| 3669/3669 [00:18<00:00, 202.85it/s]


Epoch: 060, Loss: 1.5447, Val Acc: 0.3153


100%|██████████| 3669/3669 [00:18<00:00, 203.77it/s]


Epoch: 061, Loss: 1.5464, Val Acc: 0.2581


100%|██████████| 3669/3669 [00:18<00:00, 201.54it/s]


Epoch: 062, Loss: 1.5470, Val Acc: 0.3369


100%|██████████| 3669/3669 [00:18<00:00, 202.63it/s]


Epoch: 063, Loss: 1.5480, Val Acc: 0.1963


100%|██████████| 3669/3669 [00:18<00:00, 203.09it/s]


Epoch: 064, Loss: 1.5466, Val Acc: 0.2883


100%|██████████| 3669/3669 [00:18<00:00, 202.69it/s]


Epoch: 065, Loss: 1.5460, Val Acc: 0.3037


100%|██████████| 3669/3669 [00:18<00:00, 203.07it/s]


Epoch: 066, Loss: 1.5467, Val Acc: 0.3269


100%|██████████| 3669/3669 [00:18<00:00, 202.64it/s]


Epoch: 067, Loss: 1.5484, Val Acc: 0.2651


100%|██████████| 3669/3669 [00:18<00:00, 196.20it/s]


Epoch: 068, Loss: 1.5458, Val Acc: 0.2597


100%|██████████| 3669/3669 [00:18<00:00, 197.61it/s]


Epoch: 069, Loss: 1.5478, Val Acc: 0.2326


100%|██████████| 3669/3669 [00:18<00:00, 203.12it/s]


Epoch: 070, Loss: 1.5487, Val Acc: 0.2844


100%|██████████| 3669/3669 [00:17<00:00, 204.60it/s]


Epoch: 071, Loss: 1.5459, Val Acc: 0.2921


100%|██████████| 3669/3669 [00:18<00:00, 203.46it/s]


Epoch: 072, Loss: 1.5470, Val Acc: 0.3416


100%|██████████| 3669/3669 [00:18<00:00, 200.30it/s]


Epoch: 073, Loss: 1.5478, Val Acc: 0.2643


100%|██████████| 3669/3669 [00:18<00:00, 199.27it/s]


Epoch: 074, Loss: 1.5476, Val Acc: 0.2767


100%|██████████| 3669/3669 [00:18<00:00, 200.20it/s]


Epoch: 075, Loss: 1.5459, Val Acc: 0.3199


100%|██████████| 3669/3669 [00:18<00:00, 200.79it/s]


Epoch: 076, Loss: 1.5468, Val Acc: 0.3053


100%|██████████| 3669/3669 [00:18<00:00, 202.38it/s]


Epoch: 077, Loss: 1.5455, Val Acc: 0.2481


100%|██████████| 3669/3669 [00:18<00:00, 200.25it/s]


Epoch: 078, Loss: 1.5444, Val Acc: 0.2759


100%|██████████| 3669/3669 [00:18<00:00, 197.24it/s]


Epoch: 079, Loss: 1.5499, Val Acc: 0.3083


100%|██████████| 3669/3669 [00:17<00:00, 204.11it/s]


Epoch: 080, Loss: 1.5458, Val Acc: 0.2666


100%|██████████| 3669/3669 [00:18<00:00, 194.71it/s]


Epoch: 081, Loss: 1.5508, Val Acc: 0.3207


100%|██████████| 3669/3669 [00:19<00:00, 187.57it/s]


Epoch: 082, Loss: 1.5508, Val Acc: 0.3199


100%|██████████| 3669/3669 [00:18<00:00, 194.05it/s]


Epoch: 083, Loss: 1.5484, Val Acc: 0.2921


100%|██████████| 3669/3669 [00:17<00:00, 205.69it/s]


Epoch: 084, Loss: 1.5468, Val Acc: 0.2859


100%|██████████| 3669/3669 [00:17<00:00, 205.39it/s]


Epoch: 085, Loss: 1.5471, Val Acc: 0.2844


100%|██████████| 3669/3669 [00:17<00:00, 204.62it/s]


Epoch: 086, Loss: 1.5461, Val Acc: 0.2056


100%|██████████| 3669/3669 [00:18<00:00, 203.07it/s]


Epoch: 087, Loss: 1.5471, Val Acc: 0.3076


100%|██████████| 3669/3669 [00:18<00:00, 201.21it/s]


Epoch: 088, Loss: 1.5472, Val Acc: 0.1886


100%|██████████| 3669/3669 [00:18<00:00, 202.32it/s]


Epoch: 089, Loss: 1.5494, Val Acc: 0.3076


100%|██████████| 3669/3669 [00:18<00:00, 199.69it/s]


Epoch: 090, Loss: 1.5475, Val Acc: 0.3083


100%|██████████| 3669/3669 [00:18<00:00, 199.22it/s]


Epoch: 091, Loss: 1.5476, Val Acc: 0.2643


100%|██████████| 3669/3669 [00:18<00:00, 202.68it/s]


Epoch: 092, Loss: 1.5450, Val Acc: 0.2960


100%|██████████| 3669/3669 [00:18<00:00, 202.78it/s]


Epoch: 093, Loss: 1.5472, Val Acc: 0.3532


100%|██████████| 3669/3669 [00:18<00:00, 198.65it/s]


Epoch: 094, Loss: 1.5482, Val Acc: 0.3068


100%|██████████| 3669/3669 [00:20<00:00, 178.43it/s]


Epoch: 095, Loss: 1.5467, Val Acc: 0.2921


100%|██████████| 3669/3669 [00:18<00:00, 193.19it/s]


Epoch: 096, Loss: 1.5473, Val Acc: 0.2798


100%|██████████| 3669/3669 [00:20<00:00, 179.85it/s]


Epoch: 097, Loss: 1.5481, Val Acc: 0.2836


100%|██████████| 3669/3669 [00:19<00:00, 188.19it/s]


Epoch: 098, Loss: 1.5459, Val Acc: 0.2937


100%|██████████| 3669/3669 [00:18<00:00, 196.66it/s]


Epoch: 099, Loss: 1.5444, Val Acc: 0.2998


100%|██████████| 3669/3669 [00:18<00:00, 199.04it/s]


Epoch: 100, Loss: 1.5438, Val Acc: 0.2287


100%|██████████| 3669/3669 [00:18<00:00, 198.15it/s]


Epoch: 101, Loss: 1.5476, Val Acc: 0.3192


100%|██████████| 3669/3669 [00:19<00:00, 189.08it/s]


Epoch: 102, Loss: 1.5470, Val Acc: 0.2403


100%|██████████| 3669/3669 [00:17<00:00, 204.90it/s]


Epoch: 103, Loss: 1.5444, Val Acc: 0.2450


100%|██████████| 3669/3669 [00:18<00:00, 203.22it/s]


Epoch: 104, Loss: 1.5443, Val Acc: 0.2187


100%|██████████| 3669/3669 [00:18<00:00, 198.78it/s]


Epoch: 105, Loss: 1.5460, Val Acc: 0.2774


100%|██████████| 3669/3669 [00:18<00:00, 196.32it/s]


Epoch: 106, Loss: 1.5441, Val Acc: 0.3253


100%|██████████| 3669/3669 [00:19<00:00, 192.56it/s]


Epoch: 107, Loss: 1.5419, Val Acc: 0.2983


100%|██████████| 3669/3669 [00:17<00:00, 204.49it/s]


Epoch: 108, Loss: 1.5475, Val Acc: 0.2658


100%|██████████| 3669/3669 [00:18<00:00, 203.36it/s]


Epoch: 109, Loss: 1.5484, Val Acc: 0.3223


100%|██████████| 3669/3669 [00:17<00:00, 205.53it/s]


Epoch: 110, Loss: 1.5438, Val Acc: 0.2883


100%|██████████| 3669/3669 [00:17<00:00, 204.24it/s]


Epoch: 111, Loss: 1.5459, Val Acc: 0.3230


100%|██████████| 3669/3669 [00:17<00:00, 204.54it/s]


Epoch: 112, Loss: 1.5454, Val Acc: 0.2790


100%|██████████| 3669/3669 [00:18<00:00, 202.86it/s]


Epoch: 113, Loss: 1.5476, Val Acc: 0.3091


100%|██████████| 3669/3669 [00:18<00:00, 201.85it/s]


Epoch: 114, Loss: 1.5443, Val Acc: 0.2427


100%|██████████| 3669/3669 [00:18<00:00, 203.52it/s]


Epoch: 115, Loss: 1.5444, Val Acc: 0.2929


100%|██████████| 3669/3669 [00:18<00:00, 203.09it/s]


Epoch: 116, Loss: 1.5484, Val Acc: 0.3068


100%|██████████| 3669/3669 [00:20<00:00, 181.27it/s]


Epoch: 117, Loss: 1.5457, Val Acc: 0.2960


100%|██████████| 3669/3669 [00:21<00:00, 174.42it/s]


Epoch: 118, Loss: 1.5476, Val Acc: 0.3130


100%|██████████| 3669/3669 [00:19<00:00, 184.58it/s]


Epoch: 119, Loss: 1.5487, Val Acc: 0.2944


100%|██████████| 3669/3669 [00:19<00:00, 192.78it/s]


Epoch: 120, Loss: 1.5476, Val Acc: 0.2550


100%|██████████| 3669/3669 [00:21<00:00, 173.34it/s]


Epoch: 121, Loss: 1.5474, Val Acc: 0.1801


100%|██████████| 3669/3669 [00:18<00:00, 193.44it/s]


Epoch: 122, Loss: 1.5505, Val Acc: 0.3091


100%|██████████| 3669/3669 [00:17<00:00, 205.04it/s]


Epoch: 123, Loss: 1.5478, Val Acc: 0.2968


100%|██████████| 3669/3669 [00:17<00:00, 204.61it/s]


Epoch: 124, Loss: 1.5452, Val Acc: 0.2798


100%|██████████| 3669/3669 [00:17<00:00, 204.05it/s]


Epoch: 125, Loss: 1.5456, Val Acc: 0.1801


100%|██████████| 3669/3669 [00:18<00:00, 202.85it/s]


Epoch: 126, Loss: 1.5484, Val Acc: 0.2496


100%|██████████| 3669/3669 [00:21<00:00, 173.43it/s]


Epoch: 127, Loss: 1.5458, Val Acc: 0.2597


100%|██████████| 3669/3669 [00:19<00:00, 184.31it/s]


Epoch: 128, Loss: 1.5498, Val Acc: 0.2326


100%|██████████| 3669/3669 [00:19<00:00, 185.11it/s]


Epoch: 129, Loss: 1.5471, Val Acc: 0.2326


100%|██████████| 3669/3669 [00:20<00:00, 183.09it/s]


Epoch: 130, Loss: 1.5476, Val Acc: 0.2682


100%|██████████| 3669/3669 [00:21<00:00, 168.41it/s]


Epoch: 131, Loss: 1.5448, Val Acc: 0.2921


100%|██████████| 3669/3669 [00:22<00:00, 160.00it/s]


Epoch: 132, Loss: 1.5439, Val Acc: 0.3068


100%|██████████| 3669/3669 [00:21<00:00, 168.05it/s]


Epoch: 133, Loss: 1.5495, Val Acc: 0.3439


100%|██████████| 3669/3669 [00:20<00:00, 178.95it/s]


Epoch: 134, Loss: 1.5464, Val Acc: 0.2983


100%|██████████| 3669/3669 [00:20<00:00, 180.60it/s]


Epoch: 135, Loss: 1.5476, Val Acc: 0.3315


100%|██████████| 3669/3669 [00:20<00:00, 180.25it/s]


Epoch: 136, Loss: 1.5490, Val Acc: 0.2960


100%|██████████| 3669/3669 [00:20<00:00, 182.39it/s]


Epoch: 137, Loss: 1.5479, Val Acc: 0.2496


100%|██████████| 3669/3669 [00:20<00:00, 180.26it/s]


Epoch: 138, Loss: 1.5484, Val Acc: 0.3354


100%|██████████| 3669/3669 [00:20<00:00, 174.92it/s]


Epoch: 139, Loss: 1.5465, Val Acc: 0.2821


100%|██████████| 3669/3669 [00:20<00:00, 180.57it/s]


Epoch: 140, Loss: 1.5477, Val Acc: 0.2944


100%|██████████| 3669/3669 [00:20<00:00, 176.73it/s]


Epoch: 141, Loss: 1.5505, Val Acc: 0.3284


100%|██████████| 3669/3669 [00:20<00:00, 180.77it/s]


Epoch: 142, Loss: 1.5446, Val Acc: 0.1847


100%|██████████| 3669/3669 [00:20<00:00, 177.79it/s]


Epoch: 143, Loss: 1.5450, Val Acc: 0.2388


100%|██████████| 3669/3669 [00:20<00:00, 178.26it/s]


Epoch: 144, Loss: 1.5446, Val Acc: 0.3076


100%|██████████| 3669/3669 [00:20<00:00, 180.47it/s]


Epoch: 145, Loss: 1.5470, Val Acc: 0.3246


100%|██████████| 3669/3669 [00:20<00:00, 177.10it/s]


Epoch: 146, Loss: 1.5465, Val Acc: 0.2875


100%|██████████| 3669/3669 [00:20<00:00, 180.21it/s]


Epoch: 147, Loss: 1.5442, Val Acc: 0.3223


100%|██████████| 3669/3669 [00:20<00:00, 181.01it/s]


Epoch: 148, Loss: 1.5460, Val Acc: 0.2349


100%|██████████| 3669/3669 [00:20<00:00, 179.63it/s]


Epoch: 149, Loss: 1.5459, Val Acc: 0.2998


100%|██████████| 3669/3669 [00:21<00:00, 173.67it/s]


Epoch: 150, Loss: 1.5456, Val Acc: 0.3199


100%|██████████| 3669/3669 [00:21<00:00, 172.49it/s]


Epoch: 151, Loss: 1.5456, Val Acc: 0.3400


100%|██████████| 3669/3669 [00:19<00:00, 189.07it/s]


Epoch: 152, Loss: 1.5474, Val Acc: 0.2535


100%|██████████| 3669/3669 [00:19<00:00, 186.63it/s]


Epoch: 153, Loss: 1.5450, Val Acc: 0.2597


100%|██████████| 3669/3669 [00:19<00:00, 184.88it/s]


Epoch: 154, Loss: 1.5463, Val Acc: 0.2295


100%|██████████| 3669/3669 [00:19<00:00, 185.07it/s]


Epoch: 155, Loss: 1.5493, Val Acc: 0.3199


100%|██████████| 3669/3669 [00:20<00:00, 179.97it/s]


Epoch: 156, Loss: 1.5486, Val Acc: 0.2952


100%|██████████| 3669/3669 [00:20<00:00, 180.02it/s]


Epoch: 157, Loss: 1.5442, Val Acc: 0.3068


100%|██████████| 3669/3669 [00:20<00:00, 178.35it/s]


Epoch: 158, Loss: 1.5456, Val Acc: 0.3006


100%|██████████| 3669/3669 [00:20<00:00, 180.91it/s]


Epoch: 159, Loss: 1.5447, Val Acc: 0.2473


100%|██████████| 3669/3669 [00:21<00:00, 167.55it/s]


Epoch: 160, Loss: 1.5460, Val Acc: 0.3138


100%|██████████| 3669/3669 [00:20<00:00, 177.20it/s]


Epoch: 161, Loss: 1.5446, Val Acc: 0.2929


100%|██████████| 3669/3669 [00:20<00:00, 175.52it/s]


Epoch: 162, Loss: 1.5457, Val Acc: 0.2620


100%|██████████| 3669/3669 [00:20<00:00, 175.56it/s]


Epoch: 163, Loss: 1.5484, Val Acc: 0.2094


100%|██████████| 3669/3669 [00:19<00:00, 186.23it/s]


Epoch: 164, Loss: 1.5498, Val Acc: 0.2991


100%|██████████| 3669/3669 [00:20<00:00, 179.86it/s]


Epoch: 165, Loss: 1.5481, Val Acc: 0.2334


100%|██████████| 3669/3669 [00:19<00:00, 185.19it/s]


Epoch: 166, Loss: 1.5439, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:20<00:00, 179.47it/s]


Epoch: 167, Loss: 1.5452, Val Acc: 0.2210


100%|██████████| 3669/3669 [00:19<00:00, 184.81it/s]


Epoch: 168, Loss: 1.5463, Val Acc: 0.2604


100%|██████████| 3669/3669 [00:20<00:00, 178.07it/s]


Epoch: 169, Loss: 1.5460, Val Acc: 0.2025


100%|██████████| 3669/3669 [00:20<00:00, 182.19it/s]


Epoch: 170, Loss: 1.5491, Val Acc: 0.2543


100%|██████████| 3669/3669 [00:20<00:00, 181.55it/s]


Epoch: 171, Loss: 1.5452, Val Acc: 0.3053


100%|██████████| 3669/3669 [00:20<00:00, 182.81it/s]


Epoch: 172, Loss: 1.5462, Val Acc: 0.2944


100%|██████████| 3669/3669 [00:20<00:00, 182.92it/s]


Epoch: 173, Loss: 1.5464, Val Acc: 0.2921


100%|██████████| 3669/3669 [00:20<00:00, 178.87it/s]


Epoch: 174, Loss: 1.5487, Val Acc: 0.2929


100%|██████████| 3669/3669 [00:20<00:00, 183.12it/s]


Epoch: 175, Loss: 1.5463, Val Acc: 0.2620


100%|██████████| 3669/3669 [00:19<00:00, 183.92it/s]


Epoch: 176, Loss: 1.5451, Val Acc: 0.2628


100%|██████████| 3669/3669 [00:19<00:00, 184.57it/s]


Epoch: 177, Loss: 1.5438, Val Acc: 0.2071


100%|██████████| 3669/3669 [00:21<00:00, 168.46it/s]


Epoch: 178, Loss: 1.5489, Val Acc: 0.2921


100%|██████████| 3669/3669 [00:19<00:00, 185.15it/s]


Epoch: 179, Loss: 1.5491, Val Acc: 0.3393


100%|██████████| 3669/3669 [00:19<00:00, 188.92it/s]


Epoch: 180, Loss: 1.5446, Val Acc: 0.2767


100%|██████████| 3669/3669 [00:19<00:00, 189.10it/s]


Epoch: 181, Loss: 1.5484, Val Acc: 0.2921


100%|██████████| 3669/3669 [00:19<00:00, 187.07it/s]


Epoch: 182, Loss: 1.5480, Val Acc: 0.3346


100%|██████████| 3669/3669 [00:19<00:00, 187.84it/s]


Epoch: 183, Loss: 1.5468, Val Acc: 0.2141


100%|██████████| 3669/3669 [00:19<00:00, 185.60it/s]


Epoch: 184, Loss: 1.5460, Val Acc: 0.3184


100%|██████████| 3669/3669 [00:19<00:00, 185.24it/s]


Epoch: 185, Loss: 1.5475, Val Acc: 0.2558


100%|██████████| 3669/3669 [00:19<00:00, 189.16it/s]


Epoch: 186, Loss: 1.5446, Val Acc: 0.2496


100%|██████████| 3669/3669 [00:21<00:00, 174.11it/s]


Epoch: 187, Loss: 1.5454, Val Acc: 0.2071


100%|██████████| 3669/3669 [00:22<00:00, 165.96it/s]


Epoch: 188, Loss: 1.5483, Val Acc: 0.3393


100%|██████████| 3669/3669 [00:25<00:00, 145.54it/s]


Epoch: 189, Loss: 1.5476, Val Acc: 0.2975


100%|██████████| 3669/3669 [00:23<00:00, 153.63it/s]


Epoch: 190, Loss: 1.5452, Val Acc: 0.2620


100%|██████████| 3669/3669 [00:20<00:00, 180.86it/s]


Epoch: 191, Loss: 1.5469, Val Acc: 0.3099


100%|██████████| 3669/3669 [00:20<00:00, 181.76it/s]


Epoch: 192, Loss: 1.5481, Val Acc: 0.3439


100%|██████████| 3669/3669 [00:20<00:00, 179.45it/s]


Epoch: 193, Loss: 1.5456, Val Acc: 0.2442


100%|██████████| 3669/3669 [00:20<00:00, 177.47it/s]


Epoch: 194, Loss: 1.5473, Val Acc: 0.2643


100%|██████████| 3669/3669 [00:18<00:00, 199.57it/s]


Epoch: 195, Loss: 1.5449, Val Acc: 0.2867


100%|██████████| 3669/3669 [00:20<00:00, 175.51it/s]


Epoch: 196, Loss: 1.5452, Val Acc: 0.2705


100%|██████████| 3669/3669 [00:22<00:00, 166.71it/s]


Epoch: 197, Loss: 1.5489, Val Acc: 0.2241


100%|██████████| 3669/3669 [00:22<00:00, 162.41it/s]


Epoch: 198, Loss: 1.5500, Val Acc: 0.2040


100%|██████████| 3669/3669 [00:22<00:00, 161.23it/s]


Epoch: 199, Loss: 1.5467, Val Acc: 0.1406


100%|██████████| 3669/3669 [00:22<00:00, 163.49it/s]


Epoch: 200, Loss: 1.5431, Val Acc: 0.2774


In [11]:
# Load the best model
model.load_state_dict(torch.load('best_model.pth'))

test_acc = evaluate(test_loader)
print(f'Test Accuracy: {test_acc:.4f}')

  model.load_state_dict(torch.load('best_model.pth'))
  y=torch.tensor([rating], dtype=torch.long)


Test Accuracy: 0.3506
