In [1]:
import torch
import torch.nn.functional as F
import pandas as pd
import numpy as np
import pickle
from urllib.parse import unquote
from torch.utils.data import Dataset, random_split
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool, GraphNorm
from tqdm import tqdm
import random
from torch_geometric.nn.models import GCN

# Load Data

In [2]:
# Load text data
data = pd.read_csv(f"../../data/full_text_data.csv")

In [3]:
# Load links
links = pd.read_csv("../../data/Wikispeedia/links.tsv", sep="\t", names=["src", "tgt"], skiprows=12)
links["src"] = links["src"].map(lambda x: unquote(x))
links["tgt"] = links["tgt"].map(lambda x: unquote(x))

# Create adjacency matrix
ordered_data_titles = data["title"].tolist()
src_indices = links["src"].map(lambda x: ordered_data_titles.index(x))
tgt_indices = links["tgt"].map(lambda x: ordered_data_titles.index(x))
A = torch.zeros((len(ordered_data_titles), len(ordered_data_titles)))
A[src_indices, tgt_indices] = 1

In [4]:
# Load coherence graph
with open("../../data/coherence_graph.pkl", 'rb') as handle:
    coherence_graph = pickle.load(handle)

# Combine coherence graph with base links
edge_features = A * coherence_graph

  edge_features = A * coherence_graph


In [5]:
# Load node embeddings
with open("../../data/gpt4_embeddings.pkl", 'rb') as handle:
    obj = pickle.load(handle)
    node_static_embeddings = obj["embeddings"]
    del obj
node_static_embeddings = torch.tensor(node_static_embeddings, dtype=torch.float)
node_static_embeddings

  node_static_embeddings = torch.tensor(node_static_embeddings, dtype=torch.float)


tensor([[ 0.0038,  0.0096,  0.0519,  ..., -0.0275,  0.0512, -0.0136],
        [-0.0361,  0.0142,  0.0874,  ..., -0.0140, -0.0216,  0.0194],
        [ 0.0341,  0.0069,  0.0078,  ...,  0.0070,  0.0373, -0.0153],
        ...,
        [-0.0165,  0.0102,  0.0054,  ...,  0.0072,  0.0459, -0.0378],
        [-0.0212,  0.0331,  0.0266,  ..., -0.0187,  0.0237, -0.0141],
        [ 0.0375,  0.0523,  0.0282,  ...,  0.0183, -0.0107, -0.0030]])

In [6]:
# Load user-extracted paths
paths_data = pd.read_csv(f"../../data/paths_no_back_links.tsv", sep="\t")
paths_data = paths_data[~(paths_data["rating"].isna())]

# Filter paths with at least four distinct pages
# paths_data = paths_data[paths_data["path"].apply(lambda x: len(set(x.split(";"))) >= 4)]

# Map titles to indices
title_to_index = {unquote(title): idx for idx, title in enumerate(data['title'])}
paths = paths_data['path'].apply(
    lambda path: [title_to_index[unquote(title)] for title in path.split(';')]
).tolist()

# Convert ratings to binary labels
ratings = (paths_data['rating'] - 1).tolist()  # 0-indexed ratings from 0 to 4

# Map ratings to new labels
def map_rating(r):
    if r == 0:
        return 0
    elif r in [1, 2, 3]:
        return 1
    elif r == 4:
        return 2

ratings = [map_rating(r) for r in ratings]

# # Map ratings [0, 1] to 0, and ratings [2, 3, 4] to 1
# ratings = [0 if r <= 1 else 1 for r in ratings]

In [7]:
class PathDataset(Dataset):
    def __init__(self, paths, ratings, node_embeddings, edge_features):
        self.paths = paths
        self.ratings = ratings
        self.node_embeddings = node_embeddings
        self.edge_features = edge_features

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        path = self.paths[idx]
        rating = self.ratings[idx]
        nodes, edge_index, edge_weight = self.get_subgraph_edges(path)

        x = self.node_embeddings[nodes]

        data = Data(
            x=x,
            edge_index=edge_index,
            edge_weight=edge_weight,
            y=torch.tensor([rating], dtype=torch.long)
        )
        return data

    def get_subgraph_edges(self, path):
        nodes = list(set(path))
        node_to_idx = {node: idx for idx, node in enumerate(nodes)}
        edges = []
        edge_weights = []
        for i in nodes:
            for j in nodes:
                weight = self.edge_features[i, j]
                if weight > 0:
                    edges.append([node_to_idx[i], node_to_idx[j]])
                    edge_weights.append(weight)
        if edges:
            edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
            edge_weight = torch.tensor(edge_weights, dtype=torch.float)
        else:
            edge_index = torch.empty((2, 0), dtype=torch.long)
            edge_weight = torch.tensor([], dtype=torch.float)
        return nodes, edge_index, edge_weight

# Create dataset
dataset = PathDataset(paths, ratings, node_static_embeddings, edge_features)

# Split dataset
train_ratio = 0.85
val_ratio = 0.05
test_ratio = 0.1
total_size = len(dataset)
train_size = int(train_ratio * total_size)
val_size = int(val_ratio * total_size)
test_size = total_size - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(
    dataset, [train_size, val_size, test_size], generator=torch.Generator().manual_seed(42)
)

# Create data loaders
batch_size = 6
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)



In [8]:
# Counts occurrences of each class
train_labels = torch.tensor(ratings)[train_dataset.indices].to(torch.int64)
class_counts = torch.bincount(train_labels)

# Handle any potential zero counts
class_counts = class_counts + 1e-6  # Add a small epsilon to avoid division by zero

# Calculate weights as the inverse of class frequencies
class_weights = 1.0 / class_counts.float()

# # Normalize the weights so that they sum to the number of classes (2)
# class_weights = class_weights / class_weights.sum() * 3

# GCN model

In [9]:
class GCNModel(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_classes, dropout=0.2):
        super(GCNModel, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.norm1 = GraphNorm(hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)
        self.norm2 = GraphNorm(out_channels)
        self.classifier = torch.nn.Linear(out_channels, num_classes)
        self.dropout = dropout

    def forward(self, data):
        x, edge_index, edge_weight = data.x, data.edge_index, data.edge_weight

        x = self.conv1(x, edge_index, edge_weight=edge_weight)
        x = self.norm1(x)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        x = self.conv2(x, edge_index, edge_weight=edge_weight)
        x = self.norm2(x)
        x = F.relu(x)

        x = global_mean_pool(x, batch=data.batch)

        x = self.classifier(x)
        return x  # Return raw logits

# Training

In [10]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class_weights = class_weights.to(device)

model = GCNModel(
    in_channels=node_static_embeddings.shape[1],
    hidden_channels=64,
    out_channels=32,
    num_classes=3
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss(weight=class_weights)

def train():
    model.train()
    total_loss = 0
    for data in tqdm(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data)
        loss = criterion(out, data.y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * data.num_graphs
    return total_loss / len(train_loader.dataset)

def evaluate(loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            out = model(data)
            pred = out.argmax(dim=1)
            correct += (pred == data.y).sum().item()
            total += data.num_graphs
    return correct / total

best_val_acc = 0
for epoch in range(1, 201):
    loss = train()
    val_acc = evaluate(val_loader)
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        # Save the best model
        torch.save(model.state_dict(), 'best_model.pth')
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val Acc: {val_acc:.4f}')

100%|██████████| 4038/4038 [00:21<00:00, 191.37it/s]


Epoch: 001, Loss: 1.0183, Val Acc: 0.5881


100%|██████████| 4038/4038 [00:21<00:00, 191.93it/s]


Epoch: 002, Loss: 1.0183, Val Acc: 0.5614


100%|██████████| 4038/4038 [00:21<00:00, 187.22it/s]


Epoch: 003, Loss: 1.0175, Val Acc: 0.6379


100%|██████████| 4038/4038 [00:21<00:00, 191.95it/s]


Epoch: 004, Loss: 1.0189, Val Acc: 0.6161


100%|██████████| 4038/4038 [00:20<00:00, 193.93it/s]


Epoch: 005, Loss: 1.0160, Val Acc: 0.5775


100%|██████████| 4038/4038 [00:20<00:00, 194.39it/s]


Epoch: 006, Loss: 1.0175, Val Acc: 0.6540


100%|██████████| 4038/4038 [00:20<00:00, 195.13it/s]


Epoch: 007, Loss: 1.0210, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:20<00:00, 196.28it/s]


Epoch: 008, Loss: 1.0195, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:20<00:00, 193.54it/s]


Epoch: 009, Loss: 1.0209, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:20<00:00, 193.98it/s]


Epoch: 010, Loss: 1.0219, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:20<00:00, 194.17it/s]


Epoch: 011, Loss: 1.0191, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:20<00:00, 194.22it/s]


Epoch: 012, Loss: 1.0233, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:20<00:00, 192.40it/s]


Epoch: 013, Loss: 1.0207, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:21<00:00, 191.37it/s]


Epoch: 014, Loss: 1.0207, Val Acc: 0.3130


100%|██████████| 4038/4038 [00:20<00:00, 193.24it/s]


Epoch: 015, Loss: 1.0233, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:20<00:00, 196.13it/s]


Epoch: 016, Loss: 1.0194, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:20<00:00, 195.01it/s]


Epoch: 017, Loss: 1.0213, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:21<00:00, 186.50it/s]


Epoch: 018, Loss: 1.0203, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:21<00:00, 192.27it/s]


Epoch: 019, Loss: 1.0201, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:20<00:00, 194.64it/s]


Epoch: 020, Loss: 1.0218, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:21<00:00, 188.49it/s]


Epoch: 021, Loss: 1.0200, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:20<00:00, 192.67it/s]


Epoch: 022, Loss: 1.0200, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:20<00:00, 192.97it/s]


Epoch: 023, Loss: 1.0212, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:21<00:00, 192.15it/s]


Epoch: 024, Loss: 1.0209, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:20<00:00, 194.05it/s]


Epoch: 025, Loss: 1.0192, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:20<00:00, 193.49it/s]


Epoch: 026, Loss: 1.0205, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:20<00:00, 193.57it/s]


Epoch: 027, Loss: 1.0218, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:20<00:00, 195.61it/s]


Epoch: 028, Loss: 1.0215, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:20<00:00, 192.71it/s]


Epoch: 029, Loss: 1.0195, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:20<00:00, 194.65it/s]


Epoch: 030, Loss: 1.0185, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:20<00:00, 196.56it/s]


Epoch: 031, Loss: 1.0212, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:20<00:00, 200.41it/s]


Epoch: 032, Loss: 1.0184, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:20<00:00, 196.16it/s]


Epoch: 033, Loss: 1.0200, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:20<00:00, 194.73it/s]


Epoch: 034, Loss: 1.0219, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:20<00:00, 193.00it/s]


Epoch: 035, Loss: 1.0201, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 160.70it/s]


Epoch: 036, Loss: 1.0208, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:27<00:00, 147.58it/s]


Epoch: 037, Loss: 1.0190, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:27<00:00, 147.93it/s]


Epoch: 038, Loss: 1.0206, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:27<00:00, 147.86it/s]


Epoch: 039, Loss: 1.0205, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:27<00:00, 147.53it/s]


Epoch: 040, Loss: 1.0194, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:27<00:00, 147.81it/s]


Epoch: 041, Loss: 1.0218, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:27<00:00, 148.78it/s]


Epoch: 042, Loss: 1.0189, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:26<00:00, 150.01it/s]


Epoch: 043, Loss: 1.0234, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:26<00:00, 150.42it/s]


Epoch: 044, Loss: 1.0181, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:26<00:00, 152.73it/s]


Epoch: 045, Loss: 1.0221, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:27<00:00, 147.92it/s]


Epoch: 046, Loss: 1.0194, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:26<00:00, 153.70it/s]


Epoch: 047, Loss: 1.0197, Val Acc: 0.3130


100%|██████████| 4038/4038 [00:25<00:00, 156.91it/s]


Epoch: 048, Loss: 1.0201, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 157.25it/s]


Epoch: 049, Loss: 1.0215, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 156.11it/s]


Epoch: 050, Loss: 1.0216, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 161.44it/s]


Epoch: 051, Loss: 1.0210, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 156.40it/s]


Epoch: 052, Loss: 1.0220, Val Acc: 0.3130


100%|██████████| 4038/4038 [00:25<00:00, 157.76it/s]


Epoch: 053, Loss: 1.0210, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 156.85it/s]


Epoch: 054, Loss: 1.0206, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 155.35it/s]


Epoch: 055, Loss: 1.0204, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 156.18it/s]


Epoch: 056, Loss: 1.0215, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 155.58it/s]


Epoch: 057, Loss: 1.0213, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:26<00:00, 153.72it/s]


Epoch: 058, Loss: 1.0217, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 156.03it/s]


Epoch: 059, Loss: 1.0197, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 155.92it/s]


Epoch: 060, Loss: 1.0193, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 156.06it/s]


Epoch: 061, Loss: 1.0213, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 155.86it/s]


Epoch: 062, Loss: 1.0217, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 156.34it/s]


Epoch: 063, Loss: 1.0238, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 155.71it/s]


Epoch: 064, Loss: 1.0206, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 155.31it/s]


Epoch: 065, Loss: 1.0208, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 156.98it/s]


Epoch: 066, Loss: 1.0197, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 157.09it/s]


Epoch: 067, Loss: 1.0207, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:26<00:00, 153.80it/s]


Epoch: 068, Loss: 1.0232, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 155.77it/s]


Epoch: 069, Loss: 1.0211, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:26<00:00, 154.83it/s]


Epoch: 070, Loss: 1.0215, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 156.64it/s]


Epoch: 071, Loss: 1.0222, Val Acc: 0.3130


100%|██████████| 4038/4038 [00:25<00:00, 156.37it/s]


Epoch: 072, Loss: 1.0223, Val Acc: 0.3130


100%|██████████| 4038/4038 [00:25<00:00, 157.00it/s]


Epoch: 073, Loss: 1.0209, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 157.42it/s]


Epoch: 074, Loss: 1.0206, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 157.65it/s]


Epoch: 075, Loss: 1.0228, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 156.34it/s]


Epoch: 076, Loss: 1.0198, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 157.39it/s]


Epoch: 077, Loss: 1.0207, Val Acc: 0.3130


100%|██████████| 4038/4038 [00:25<00:00, 155.32it/s]


Epoch: 078, Loss: 1.0193, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 156.90it/s]


Epoch: 079, Loss: 1.0206, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 157.26it/s]


Epoch: 080, Loss: 1.0212, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 155.76it/s]


Epoch: 081, Loss: 1.0203, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 156.21it/s]


Epoch: 082, Loss: 1.0195, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 156.48it/s]


Epoch: 083, Loss: 1.0207, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 155.83it/s]


Epoch: 084, Loss: 1.0185, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 156.51it/s]


Epoch: 085, Loss: 1.0205, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:26<00:00, 155.21it/s]


Epoch: 086, Loss: 1.0202, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 156.76it/s]


Epoch: 087, Loss: 1.0199, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 156.43it/s]


Epoch: 088, Loss: 1.0204, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:26<00:00, 155.14it/s]


Epoch: 089, Loss: 1.0210, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 156.49it/s]


Epoch: 090, Loss: 1.0216, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 156.22it/s]


Epoch: 091, Loss: 1.0213, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 157.10it/s]


Epoch: 092, Loss: 1.0209, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 155.40it/s]


Epoch: 093, Loss: 1.0197, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 156.07it/s]


Epoch: 094, Loss: 1.0212, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:26<00:00, 155.20it/s]


Epoch: 095, Loss: 1.0225, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 156.34it/s]


Epoch: 096, Loss: 1.0207, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 156.20it/s]


Epoch: 097, Loss: 1.0217, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 155.36it/s]


Epoch: 098, Loss: 1.0226, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:26<00:00, 155.28it/s]


Epoch: 099, Loss: 1.0214, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 156.70it/s]


Epoch: 100, Loss: 1.0203, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 156.69it/s]


Epoch: 101, Loss: 1.0214, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:26<00:00, 151.84it/s]


Epoch: 102, Loss: 1.0226, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 155.32it/s]


Epoch: 103, Loss: 1.0203, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:26<00:00, 155.06it/s]


Epoch: 104, Loss: 1.0186, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 156.38it/s]


Epoch: 105, Loss: 1.0204, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 155.33it/s]


Epoch: 106, Loss: 1.0208, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:26<00:00, 154.86it/s]


Epoch: 107, Loss: 1.0194, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:26<00:00, 153.63it/s]


Epoch: 108, Loss: 1.0213, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:23<00:00, 169.21it/s]


Epoch: 109, Loss: 1.0205, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 157.58it/s]


Epoch: 110, Loss: 1.0220, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 156.43it/s]


Epoch: 111, Loss: 1.0203, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 156.39it/s]


Epoch: 112, Loss: 1.0204, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 156.16it/s]


Epoch: 113, Loss: 1.0207, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 157.72it/s]


Epoch: 114, Loss: 1.0202, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 157.56it/s]


Epoch: 115, Loss: 1.0212, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 158.00it/s]


Epoch: 116, Loss: 1.0205, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 157.76it/s]


Epoch: 117, Loss: 1.0199, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 156.33it/s]


Epoch: 118, Loss: 1.0197, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 156.65it/s]


Epoch: 119, Loss: 1.0225, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 157.26it/s]


Epoch: 120, Loss: 1.0203, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 157.81it/s]


Epoch: 121, Loss: 1.0223, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 156.60it/s]


Epoch: 122, Loss: 1.0213, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 157.95it/s]


Epoch: 123, Loss: 1.0200, Val Acc: 0.3130


100%|██████████| 4038/4038 [00:26<00:00, 154.93it/s]


Epoch: 124, Loss: 1.0211, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 156.44it/s]


Epoch: 125, Loss: 1.0203, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 156.03it/s]


Epoch: 126, Loss: 1.0183, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 155.65it/s]


Epoch: 127, Loss: 1.0227, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 157.05it/s]


Epoch: 128, Loss: 1.0212, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 155.65it/s]


Epoch: 129, Loss: 1.0191, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 157.63it/s]


Epoch: 130, Loss: 1.0193, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 156.28it/s]


Epoch: 131, Loss: 1.0218, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 156.26it/s]


Epoch: 132, Loss: 1.0202, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 156.87it/s]


Epoch: 133, Loss: 1.0219, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 159.11it/s]


Epoch: 134, Loss: 1.0194, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 157.96it/s]


Epoch: 135, Loss: 1.0205, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 157.22it/s]


Epoch: 136, Loss: 1.0214, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 158.42it/s]


Epoch: 137, Loss: 1.0201, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 156.95it/s]


Epoch: 138, Loss: 1.0201, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 158.83it/s]


Epoch: 139, Loss: 1.0211, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 156.30it/s]


Epoch: 140, Loss: 1.0224, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 157.97it/s]


Epoch: 141, Loss: 1.0216, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 157.94it/s]


Epoch: 142, Loss: 1.0202, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 158.36it/s]


Epoch: 143, Loss: 1.0206, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 159.02it/s]


Epoch: 144, Loss: 1.0220, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 158.96it/s]


Epoch: 145, Loss: 1.0204, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 160.04it/s]


Epoch: 146, Loss: 1.0201, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 158.15it/s]


Epoch: 147, Loss: 1.0203, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 158.31it/s]


Epoch: 148, Loss: 1.0210, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 155.44it/s]


Epoch: 149, Loss: 1.0178, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 155.59it/s]


Epoch: 150, Loss: 1.0174, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 155.34it/s]


Epoch: 151, Loss: 1.0196, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 157.82it/s]


Epoch: 152, Loss: 1.0202, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 156.87it/s]


Epoch: 153, Loss: 1.0210, Val Acc: 0.3130


100%|██████████| 4038/4038 [00:25<00:00, 157.53it/s]


Epoch: 154, Loss: 1.0200, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 156.09it/s]


Epoch: 155, Loss: 1.0206, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 157.00it/s]


Epoch: 156, Loss: 1.0211, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 157.66it/s]


Epoch: 157, Loss: 1.0217, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 157.88it/s]


Epoch: 158, Loss: 1.0214, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 160.06it/s]


Epoch: 159, Loss: 1.0210, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:24<00:00, 167.75it/s]


Epoch: 160, Loss: 1.0191, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 160.92it/s]


Epoch: 161, Loss: 1.0232, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 159.19it/s]


Epoch: 162, Loss: 1.0185, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 158.21it/s]


Epoch: 163, Loss: 1.0201, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 155.95it/s]


Epoch: 164, Loss: 1.0220, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 156.90it/s]


Epoch: 165, Loss: 1.0197, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:26<00:00, 153.16it/s]


Epoch: 166, Loss: 1.0199, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:26<00:00, 154.31it/s]


Epoch: 167, Loss: 1.0205, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 156.50it/s]


Epoch: 168, Loss: 1.0216, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:24<00:00, 164.93it/s]


Epoch: 169, Loss: 1.0211, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:24<00:00, 165.12it/s]


Epoch: 170, Loss: 1.0200, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 158.29it/s]


Epoch: 171, Loss: 1.0221, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:26<00:00, 154.14it/s]


Epoch: 172, Loss: 1.0223, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 155.45it/s]


Epoch: 173, Loss: 1.0217, Val Acc: 0.3130


100%|██████████| 4038/4038 [00:25<00:00, 161.39it/s]


Epoch: 174, Loss: 1.0213, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:24<00:00, 163.10it/s]


Epoch: 175, Loss: 1.0195, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:24<00:00, 165.38it/s]


Epoch: 176, Loss: 1.0203, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:24<00:00, 166.17it/s]


Epoch: 177, Loss: 1.0223, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:24<00:00, 164.56it/s]


Epoch: 178, Loss: 1.0204, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:24<00:00, 162.31it/s]


Epoch: 179, Loss: 1.0222, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:24<00:00, 163.72it/s]


Epoch: 180, Loss: 1.0203, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:24<00:00, 164.94it/s]


Epoch: 181, Loss: 1.0195, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 161.07it/s]


Epoch: 182, Loss: 1.0219, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 157.98it/s]


Epoch: 183, Loss: 1.0187, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 157.37it/s]


Epoch: 184, Loss: 1.0201, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 157.02it/s]


Epoch: 185, Loss: 1.0192, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:26<00:00, 152.27it/s]


Epoch: 186, Loss: 1.0205, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 158.85it/s]


Epoch: 187, Loss: 1.0194, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:26<00:00, 154.57it/s]


Epoch: 188, Loss: 1.0212, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 156.03it/s]


Epoch: 189, Loss: 1.0210, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 157.24it/s]


Epoch: 190, Loss: 1.0218, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 157.55it/s]


Epoch: 191, Loss: 1.0184, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 158.33it/s]


Epoch: 192, Loss: 1.0214, Val Acc: 0.3130


100%|██████████| 4038/4038 [00:25<00:00, 157.57it/s]


Epoch: 193, Loss: 1.0189, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 158.48it/s]


Epoch: 194, Loss: 1.0209, Val Acc: 0.3130


100%|██████████| 4038/4038 [00:25<00:00, 157.48it/s]


Epoch: 195, Loss: 1.0203, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 157.85it/s]


Epoch: 196, Loss: 1.0227, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 157.97it/s]


Epoch: 197, Loss: 1.0205, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 157.08it/s]


Epoch: 198, Loss: 1.0213, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 158.83it/s]


Epoch: 199, Loss: 1.0237, Val Acc: 0.6491


100%|██████████| 4038/4038 [00:25<00:00, 161.32it/s]


Epoch: 200, Loss: 1.0194, Val Acc: 0.6491


In [11]:
# Load the best model
model.load_state_dict(torch.load('best_model.pth'))

test_acc = evaluate(test_loader)
print(f'Test Accuracy: {test_acc:.4f}')

  model.load_state_dict(torch.load('best_model.pth'))


Test Accuracy: 0.6499
