In [1]:
from google.colab import drive
drive.mount('/content/gdrive')
%cd /content/gdrive/MyDrive/CS5242_Project

Mounted at /content/gdrive
/content/gdrive/MyDrive/CS5242_Project


In [2]:
import json
import jieba # Chinese text segmentation
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, f1_score
from collections import Counter
import numpy as np
import pandas as pd
import time
import os
from tqdm.auto import tqdm

In [3]:
# --- Configuration ---
CONFIG = {
    "train_data_path": "dataset/train.json",
    "val_data_path": "dataset/dev.json",
    "test_data_path": "dataset/test_public.json",
    "label_field": "label",
    "vocab_max_size": 10000,      # Max vocabulary size (incl. PAD, UNK, SEP)
    "vocab_min_freq": 2,          # Min frequency for word inclusion
    "embedding_dim": 128,         # Dimension of word embeddings
    "hidden_dim": 128,            # Dimension of RNN hidden state (total for single tower, per tower for two-tower affects FC layer input)
    "rnn_layers": 1,              # Number of RNN layers
    "rnn_bidirectional": True,   # Use True to get more context
    "rnn_nonlinearity": 'tanh',   # Activation for nn.RNN ('tanh' or 'relu')
    "dropout_prob": 0.3,          # Dropout probability
    "num_classes": 3,             # Relevance classes (0, 1, 2)
    "max_query_len": 30,          # Max length for query sequence
    "max_title_len": 50,          # Max length for title sequence
    "max_combined_len": 30 + 50 + 1, # Max length for query + SEP + title
    "batch_size": 64,
    "num_epochs": 10,             # Epochs
    "learning_rate": 0.001,
    "seed": 42,
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    "save_path_single": "best_single_tower_rnn.pth",
    "save_path_two": "best_two_tower_rnn.pth",
    "save_path_single_lstm": "best_single_tower_lstm.pth",
    "save_path_two_lstm": "best_two_tower_lstm.pth",
    "save_path_single_gru": "best_single_tower_gru.pth",
    "save_path_two_gru": "best_two_tower_gru.pth",
}

# Set random seed
np.random.seed(CONFIG["seed"])
torch.manual_seed(CONFIG["seed"])
if CONFIG["device"] == "cuda":
    torch.cuda.manual_seed_all(CONFIG["seed"])

print(f"Using device: {CONFIG['device']}")
print(f"RNN Bidirectional: {CONFIG['rnn_bidirectional']}")


Using device: cuda
RNN Bidirectional: True


###Data Loading

In [4]:
def load_data(filepath, label_field):
    queries, titles, labels = [], [], []
    print(f"Loading data from {filepath}...")
    line_num = 0

    def process_item(item, line_info=""):
        query = item.get("query", "")
        title = item.get("title", "")
        label_raw = item.get(label_field)
        label_int = -1
        valid_label = False
        if label_raw is not None:
            try:
                label_int = int(label_raw)
                if 0 <= label_int <= 2:
                    valid_label = True
            except (ValueError, TypeError):
                pass
        if valid_label:
            return query, title, label_int
        else:
            print(f"Warning: Skipping {line_info} due to invalid or missing label: {item}")
            return None

    with open(filepath, 'r', encoding='utf-8') as f:
        try:
            data_list = json.load(f)
            print("Loaded data as a single JSON list.")
            for i, item in enumerate(data_list):
                processed_data = process_item(item, line_info=f"item {i+1}")
                if processed_data:
                    q, t, l = processed_data
                    queries.append(q)
                    titles.append(t)
                    labels.append(l)

        except json.JSONDecodeError:
            print("Failed loading as list, trying JSON Lines format...")
            f.seek(0)
            for line_num, line in enumerate(f):
                line_strip = line.strip()
                if not line_strip: continue
                item = json.loads(line_strip)
                processed_data = process_item(item, line_info=f"line {line_num+1}")
                if processed_data:
                    q, t, l = processed_data
                    queries.append(q)
                    titles.append(t)
                    labels.append(l)



    print(f"Loaded {len(labels)} valid samples.")
    return queries, titles, labels

###Preprocessing

In [5]:
def tokenize(text):
    """Tokenizes Chinese text using jieba."""
    if text is None: return []
    # Use cut_for_search for potentially better recall
    return jieba.lcut_for_search(str(text)) # Ensure text is string

def build_vocab(texts, max_size, min_freq):
    """Builds a vocabulary including <PAD>, <UNK>, <SEP>."""
    word_counts = Counter()
    print("Building vocabulary...")
    for text in texts:
        word_counts.update(text)

    sorted_words = sorted(word_counts.items(), key=lambda x: (-x[1], x[0]))
    max_real_words = max_size - 3 # Reserve space for PAD, UNK, SEP
    vocab_words = [word for word, freq in sorted_words if freq >= min_freq]
    if len(vocab_words) > max_real_words:
        vocab_words = vocab_words[:max_real_words]

    word_to_idx = {'<PAD>': 0, '<UNK>': 1, '<SEP>': 2}
    idx_counter = 3
    for word in vocab_words:
        if word not in word_to_idx:
             word_to_idx[word] = idx_counter
             idx_counter += 1

    print(f"Vocabulary built with {len(word_to_idx)} words.")
    return word_to_idx

def tokens_to_indices(tokens, word_to_idx, max_len):
     """Converts tokens to indices, truncates."""
     indices = [word_to_idx.get(word, word_to_idx['<UNK>']) for word in tokens]
     return indices[:max_len]

def pad_sequence(indices, max_len, pad_idx):
    """Pads sequence to max_len."""
    current_len = len(indices)
    if current_len < max_len:
        return indices + [pad_idx] * (max_len - current_len)
    else:
        return indices[:max_len]

###Unified Dataset

In [6]:
class RelevanceDataset(data.Dataset):
    def __init__(self, queries, titles, labels, word_to_idx, max_query_len, max_title_len, max_combined_len):
        self.queries = queries
        self.titles = titles
        self.labels = labels
        self.word_to_idx = word_to_idx
        self.max_query_len = max_query_len
        self.max_title_len = max_title_len
        self.max_combined_len = max_combined_len
        self.pad_idx = word_to_idx['<PAD>']
        self.sep_idx = word_to_idx['<SEP>']

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        query = self.queries[index]
        title = self.titles[index]
        label = self.labels[index]

        query_tokens = tokenize(query)
        title_tokens = tokenize(title)

        # Indices before padding (for combining)
        query_indices_unpadded = tokens_to_indices(query_tokens, self.word_to_idx, self.max_query_len)
        title_indices_unpadded = tokens_to_indices(title_tokens, self.word_to_idx, self.max_title_len)

        # --- Prepare data for BOTH architectures ---

        # For Two-Tower Model (Separate Padded Sequences)
        query_padded = pad_sequence(query_indices_unpadded, self.max_query_len, self.pad_idx)
        title_padded = pad_sequence(title_indices_unpadded, self.max_title_len, self.pad_idx)

        # For Single-Tower Model (Combined Padded Sequence)
        combined_indices = query_indices_unpadded + [self.sep_idx] + title_indices_unpadded
        combined_padded = pad_sequence(combined_indices, self.max_combined_len, self.pad_idx)

        # Convert to tensors
        query_tensor = torch.tensor(query_padded, dtype=torch.long)
        title_tensor = torch.tensor(title_padded, dtype=torch.long)
        combined_tensor = torch.tensor(combined_padded, dtype=torch.long)
        label_tensor = torch.tensor(label, dtype=torch.long)

        return query_tensor, title_tensor, combined_tensor, label_tensor

###Model Definitions

####RNN(Single Tower+Tow Tower)

In [7]:
def _extract_final_hidden(hidden_state, num_layers, num_directions, batch_size):
    # hidden_state shape: (num_layers * num_directions, batch_size, rnn_hidden_dim)
    # Reshape to view layers and directions separately
    hidden = hidden_state.view(num_layers, num_directions, batch_size, -1)
    # Get the hidden state of the last layer (forward and backward if bidirectional)
    # Shape: (num_directions, batch_size, rnn_hidden_dim)
    hidden_last_layer = hidden[-1]
    # Permute and reshape to (batch_size, num_directions * rnn_hidden_dim)
    final_hidden = hidden_last_layer.permute(1, 0, 2).reshape(batch_size, -1)
    return final_hidden

class SingleTowerRNN(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_classes,
                 rnn_layers, dropout_prob, bidirectional, rnn_nonlinearity, pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)
        rnn_hidden_dim = hidden_dim // 2 if bidirectional else hidden_dim
        num_directions = 2 if bidirectional else 1
        self.rnn = nn.RNN(embedding_dim, rnn_hidden_dim, num_layers=rnn_layers,
                          nonlinearity=rnn_nonlinearity, batch_first=True,
                          dropout=dropout_prob if rnn_layers > 1 else 0,
                          bidirectional=bidirectional)
        self.dropout = nn.Dropout(dropout_prob)
        self.fc = nn.Linear(hidden_dim, num_classes) # Input is hidden_dim (num_directions * rnn_hidden_dim)
        self.rnn_layers = rnn_layers
        self.num_directions = num_directions

    def forward(self, combined_indices):
        # combined_indices: (batch_size, max_combined_len)
        embedded = self.dropout(self.embedding(combined_indices))
        output, hidden = self.rnn(embedded)
        # hidden: (num_layers * num_directions, batch_size, rnn_hidden_dim)
        final_hidden = _extract_final_hidden(hidden, self.rnn_layers, self.num_directions, combined_indices.size(0))
        # final_hidden: (batch_size, hidden_dim)
        logits = self.fc(self.dropout(final_hidden))
        return logits

class TwoTowerRNN(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_classes,
                 rnn_layers, dropout_prob, bidirectional, rnn_nonlinearity, pad_idx):
        super().__init__()
        # Share the embedding layer between towers
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)

        rnn_hidden_dim = hidden_dim // 2 if bidirectional else hidden_dim
        num_directions = 2 if bidirectional else 1

        self.rnn_query = nn.RNN(embedding_dim, rnn_hidden_dim, num_layers=rnn_layers,
                                nonlinearity=rnn_nonlinearity, batch_first=True,
                                dropout=dropout_prob if rnn_layers > 1 else 0,
                                bidirectional=bidirectional)
        self.rnn_title = nn.RNN(embedding_dim, rnn_hidden_dim, num_layers=rnn_layers,
                                nonlinearity=rnn_nonlinearity, batch_first=True,
                                dropout=dropout_prob if rnn_layers > 1 else 0,
                                bidirectional=bidirectional)

        self.dropout = nn.Dropout(dropout_prob)
        # Input to FC layer is concatenation of final hidden states from both towers
        self.fc = nn.Linear(hidden_dim * 2, num_classes)
        self.rnn_layers = rnn_layers
        self.num_directions = num_directions

    def forward(self, query_indices, title_indices):
        # query_indices: (batch_size, max_query_len)
        # title_indices: (batch_size, max_title_len)

        query_embedded = self.dropout(self.embedding(query_indices))
        title_embedded = self.dropout(self.embedding(title_indices))

        _, query_hidden = self.rnn_query(query_embedded)
        _, title_hidden = self.rnn_title(title_embedded)
        # hidden shapes: (num_layers * num_directions, batch_size, rnn_hidden_dim)

        query_final_hidden = _extract_final_hidden(query_hidden, self.rnn_layers, self.num_directions, query_indices.size(0))
        title_final_hidden = _extract_final_hidden(title_hidden, self.rnn_layers, self.num_directions, title_indices.size(0))
        # final hidden shapes: (batch_size, hidden_dim)

        # Concatenate the final states
        combined = torch.cat((query_final_hidden, title_final_hidden), dim=1)
        # combined: (batch_size, hidden_dim * 2)

        logits = self.fc(self.dropout(combined))
        return logits


####LSTM(Single Tower+Tow Tower)

In [8]:
class SingleTowerLSTM(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_classes,
                 rnn_layers, dropout_prob, bidirectional, pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)
        lstm_hidden_dim = hidden_dim // 2 if bidirectional else hidden_dim
        num_directions = 2 if bidirectional else 1
        self.rnn = nn.LSTM(embedding_dim, lstm_hidden_dim, num_layers=rnn_layers,
                           batch_first=True,
                           dropout=dropout_prob if rnn_layers > 1 else 0,
                           bidirectional=bidirectional)
        self.dropout = nn.Dropout(dropout_prob)
        # FC layer input dim is: total hidden_dim (num_directions * lstm_hidden_dim)
        self.fc = nn.Linear(hidden_dim, num_classes)
        self.rnn_layers = rnn_layers
        self.num_directions = num_directions

    def forward(self, combined_indices):
        # combined_indices: (batch_size, max_combined_len)
        embedded = self.dropout(self.embedding(combined_indices))
        # output shape: (batch_size, seq_len, num_directions * lstm_hidden_dim)
        # hidden_n shape: (num_layers * num_directions, batch_size, lstm_hidden_dim)
        # cell_n shape:   (num_layers * num_directions, batch_size, lstm_hidden_dim)
        output, (hidden_n, cell_n) = self.rnn(embedded)
        final_hidden = _extract_final_hidden(hidden_n, self.rnn_layers, self.num_directions, combined_indices.size(0))
        # final_hidden: (batch_size, hidden_dim)

        logits = self.fc(self.dropout(final_hidden))
        return logits

class TwoTowerLSTM(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_classes,
                 rnn_layers, dropout_prob, bidirectional, pad_idx): # Removed rnn_nonlinearity
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)

        lstm_hidden_dim = hidden_dim // 2 if bidirectional else hidden_dim
        num_directions = 2 if bidirectional else 1

        # Use LSTM layers for query and title
        self.rnn_query = nn.LSTM(embedding_dim, lstm_hidden_dim, num_layers=rnn_layers,
                                 batch_first=True,
                                 dropout=dropout_prob if rnn_layers > 1 else 0,
                                 bidirectional=bidirectional)
        self.rnn_title = nn.LSTM(embedding_dim, lstm_hidden_dim, num_layers=rnn_layers,
                                 batch_first=True,
                                 dropout=dropout_prob if rnn_layers > 1 else 0,
                                 bidirectional=bidirectional)

        self.dropout = nn.Dropout(dropout_prob)
        # Input to FC layer is concatenation of final hidden states (h_n) from both towers
        self.fc = nn.Linear(hidden_dim * 2, num_classes) # hidden_dim * 2 because of two towers
        self.rnn_layers = rnn_layers
        self.num_directions = num_directions

    def forward(self, query_indices, title_indices):
        # query_indices: (batch_size, max_query_len)
        # title_indices: (batch_size, max_title_len)

        query_embedded = self.dropout(self.embedding(query_indices))
        title_embedded = self.dropout(self.embedding(title_indices))

        _, (query_hidden_n, query_cell_n) = self.rnn_query(query_embedded)
        _, (title_hidden_n, title_cell_n) = self.rnn_title(title_embedded)
        # hidden_n shapes: (num_layers * num_directions, batch_size, lstm_hidden_dim)

        query_final_hidden = _extract_final_hidden(query_hidden_n, self.rnn_layers, self.num_directions, query_indices.size(0))
        title_final_hidden = _extract_final_hidden(title_hidden_n, self.rnn_layers, self.num_directions, title_indices.size(0))
        # final hidden shapes: (batch_size, hidden_dim)

        # Concatenate the final hidden states
        combined = torch.cat((query_final_hidden, title_final_hidden), dim=1)
        # combined: (batch_size, hidden_dim * 2)

        logits = self.fc(self.dropout(combined))
        return logits

####GRU(Single Tower+Tow Tower)

In [9]:
class SingleTowerGRU(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_classes,
                 rnn_layers, dropout_prob, bidirectional, pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)
        gru_hidden_dim = hidden_dim // 2 if bidirectional else hidden_dim
        num_directions = 2 if bidirectional else 1
        # Use nn.GRU layer
        self.rnn = nn.GRU(embedding_dim, gru_hidden_dim, num_layers=rnn_layers,
                          batch_first=True,
                          dropout=dropout_prob if rnn_layers > 1 else 0,
                          bidirectional=bidirectional)
        self.dropout = nn.Dropout(dropout_prob)
        self.fc = nn.Linear(hidden_dim, num_classes) # Input dim is total hidden_dim
        self.rnn_layers = rnn_layers
        self.num_directions = num_directions

    def forward(self, combined_indices):
        # combined_indices: (batch_size, max_combined_len)
        embedded = self.dropout(self.embedding(combined_indices))
        # output shape: (batch_size, seq_len, num_directions * gru_hidden_dim)
        # hidden_n shape: (num_layers * num_directions, batch_size, gru_hidden_dim)
        output, hidden_n = self.rnn(embedded)
        final_hidden = _extract_final_hidden(hidden_n, self.rnn_layers, self.num_directions, combined_indices.size(0))
        # final_hidden: (batch_size, hidden_dim)

        logits = self.fc(self.dropout(final_hidden))
        return logits

class TwoTowerGRU(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_classes,
                 rnn_layers, dropout_prob, bidirectional, pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)

        gru_hidden_dim = hidden_dim // 2 if bidirectional else hidden_dim
        num_directions = 2 if bidirectional else 1

        # Use GRU layers for query and title
        self.rnn_query = nn.GRU(embedding_dim, gru_hidden_dim, num_layers=rnn_layers,
                                batch_first=True,
                                dropout=dropout_prob if rnn_layers > 1 else 0,
                                bidirectional=bidirectional)
        self.rnn_title = nn.GRU(embedding_dim, gru_hidden_dim, num_layers=rnn_layers,
                                batch_first=True,
                                dropout=dropout_prob if rnn_layers > 1 else 0,
                                bidirectional=bidirectional)

        self.dropout = nn.Dropout(dropout_prob)
        # Input to FC layer is concatenation of final hidden states (h_n) from both towers
        self.fc = nn.Linear(hidden_dim * 2, num_classes)
        self.rnn_layers = rnn_layers
        self.num_directions = num_directions

    def forward(self, query_indices, title_indices):
        # query_indices: (batch_size, max_query_len)
        # title_indices: (batch_size, max_title_len)

        query_embedded = self.dropout(self.embedding(query_indices))
        title_embedded = self.dropout(self.embedding(title_indices))

        _, query_hidden_n = self.rnn_query(query_embedded)
        _, title_hidden_n = self.rnn_title(title_embedded)
        # hidden_n shapes: (num_layers * num_directions, batch_size, gru_hidden_dim)

        query_final_hidden = _extract_final_hidden(query_hidden_n, self.rnn_layers, self.num_directions, query_indices.size(0))
        title_final_hidden = _extract_final_hidden(title_hidden_n, self.rnn_layers, self.num_directions, title_indices.size(0))
        # final hidden shapes: (batch_size, hidden_dim)

        # Concatenate the final hidden states
        combined = torch.cat((query_final_hidden, title_final_hidden), dim=1)
        # combined: (batch_size, hidden_dim * 2)

        logits = self.fc(self.dropout(combined))
        return logits

###Training Function

In [10]:
def train_model(model, model_architecture, train_loader, val_loader, optimizer, criterion, num_epochs, device, save_path):
    print(f"\n--- Starting Training ({model_architecture.upper()}) ---")
    best_val_metric = -1.0 # Initialize with a value lower than any possible F1 score
    start_train_time = time.time()

    for epoch in range(num_epochs):
        epoch_start_time = time.time()
        model.train()
        epoch_train_loss = 0.0
        total_train_samples = 0

        # --- Training Loop with TQDM ---
        train_iterator = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} Training", leave=False)
        for batch_idx, batch in enumerate(train_iterator):
            query_batch, title_batch, combined_batch, labels_batch = batch
            labels_batch = labels_batch.to(device)

            optimizer.zero_grad()

            if model_architecture == 'single':
                model_input = combined_batch.to(device)
                logits = model(model_input)
            elif model_architecture == 'two':
                query_input = query_batch.to(device)
                title_input = title_batch.to(device)
                logits = model(query_input, title_input)
            else:
                raise ValueError("Invalid model_architecture specified")

            loss = criterion(logits, labels_batch)
            loss.backward()
            optimizer.step()

            epoch_train_loss += loss.item() * labels_batch.size(0)
            total_train_samples += labels_batch.size(0)

            # Update TQDM postfix with current batch loss
            train_iterator.set_postfix(loss=f"{loss.item():.4f}")
        # --- End Training Loop ---

        avg_train_loss = epoch_train_loss / total_train_samples
        print(f"Epoch {epoch+1}/{num_epochs} Training: loss={avg_train_loss:.4f}") # Print average loss after epoch

        # --- Validation Step with TQDM ---
        model.eval()
        all_val_preds = []
        all_val_labels = []

        val_iterator = tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} Evaluating", leave=False)
        with torch.no_grad():
            for batch in val_iterator:
                query_batch, title_batch, combined_batch, labels_batch = batch
                labels_batch = labels_batch.to(device)
                all_val_labels.extend(labels_batch.cpu().numpy()) # Collect labels

                if model_architecture == 'single':
                    model_input = combined_batch.to(device)
                    logits = model(model_input)
                elif model_architecture == 'two':
                    query_input = query_batch.to(device)
                    title_input = title_batch.to(device)
                    logits = model(query_input, title_input)
                else:
                    raise ValueError("Invalid model_architecture specified")

                preds = torch.argmax(logits, dim=1)
                all_val_preds.extend(preds.cpu().numpy()) # Collect predictions
        # --- End Validation Loop ---

        # Calculate Validation Metrics
        val_accuracy = accuracy_score(all_val_labels, all_val_preds)
        # Calculate weighted F1 score
        val_f1 = f1_score(all_val_labels, all_val_preds, average='weighted', zero_division=0)

        print(f"Validation Accuracy: {val_accuracy:.4f}")
        print(f"Validation F1 Score: {val_f1:.4f}")

        # Print Classification Report for Validation
        target_names = ['Relevance 0', 'Relevance 1', 'Relevance 2']
        try:
            unique_labels_in_data = sorted(list(set(all_val_labels) | set(all_val_preds)))
            current_target_names = [target_names[i] for i in unique_labels_in_data if i < len(target_names)]
            report = classification_report(all_val_labels, all_val_preds, target_names=current_target_names, labels=unique_labels_in_data, digits=4, zero_division=0)
            print(report)
        except Exception as e:
            print(f"Could not generate full classification report (validation): {e}")
            print(classification_report(all_val_labels, all_val_preds, digits=4, zero_division=0))


        # Save best model based on weighted F1 score
        if val_f1 > best_val_metric:
            best_val_metric = val_f1
            torch.save(model.state_dict(), save_path)
            print(f"save best model to {save_path} (Val F1: {best_val_metric:.4f})") # Indicate saving

    total_train_duration = time.time() - start_train_time
    print(f"--- Training Finished ({model_architecture.upper()}) in {total_train_duration:.2f}s ---")
    print(f"Best Validation F1 Score ({model_architecture.upper()}): {best_val_metric:.4f}")

###Evaluation Function

In [11]:
def evaluate_model(model, model_architecture, test_loader, criterion, device):
    print(f"\n--- Starting Evaluation on Test Set ({model_architecture.upper()}) ---")
    model.eval()
    all_preds = []
    all_labels = []

    # --- Evaluation Loop with TQDM ---
    test_iterator = tqdm(test_loader, desc="Evaluating Test Set", leave=True) # leave=True to keep bar after completion
    with torch.no_grad():
        for batch in test_iterator:
            query_batch, title_batch, combined_batch, labels_batch = batch
            labels_batch = labels_batch.to(device)
            all_labels.extend(labels_batch.cpu().numpy()) # Collect labels

            if model_architecture == 'single':
                model_input = combined_batch.to(device)
                logits = model(model_input)
            elif model_architecture == 'two':
                query_input = query_batch.to(device)
                title_input = title_batch.to(device)
                logits = model(query_input, title_input)
            else:
                raise ValueError("Invalid model_architecture specified")

            # loss = criterion(logits, labels_batch)
            # total_loss += loss.item() * labels_batch.size(0)

            preds = torch.argmax(logits, dim=1)
            all_preds.extend(preds.cpu().numpy()) # Collect predictions
    # --- End Evaluation Loop ---


    if not all_labels:
        print("Error: No samples found in the test loader for evaluation.")
        return 0.0, 0.0, [], [] # Return zeros or handle appropriately

    # Calculate Test Metrics
    accuracy = accuracy_score(all_labels, all_preds)
    test_f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0) # Weighted F1

    print(f"Test Set Accuracy: {accuracy:.4f}")
    print(f"Test Set F1 Score: {test_f1:.4f}")

    # Print Classification Report for Test Set
    target_names = ['Relevance 0', 'Relevance 1', 'Relevance 2']
    try:
        unique_labels_in_data = sorted(list(set(all_labels) | set(all_preds)))
        current_target_names = [target_names[i] for i in unique_labels_in_data if i < len(target_names)]
        report = classification_report(all_labels, all_preds, target_names=current_target_names, labels=unique_labels_in_data, digits=4, zero_division=0)
        print(report)
    except Exception as e:
        print(f"Could not generate full classification report (test): {e}")
        print(classification_report(all_labels, all_preds, digits=4, zero_division=0))

    # print("\nConfusion Matrix:")
    # cm = confusion_matrix(all_labels, all_preds, labels=[0, 1, 2])
    # print(cm)

    print(f"--- Evaluation Finished ({model_architecture.upper()}) ---")
    return accuracy, test_f1, all_preds, all_labels

###Main Execution

In [12]:
# Load Data
queries1, titles1, labels1 = load_data(CONFIG["train_data_path"], CONFIG["label_field"])
train_df = pd.DataFrame({'query': queries1, 'title': titles1, 'label': labels1})

queries2, titles2, labels2 = load_data(CONFIG["val_data_path"], CONFIG["label_field"])
val_df = pd.DataFrame({'query': queries2, 'title': titles2, 'label': labels2})

queries3, titles3, labels3 = load_data(CONFIG["test_data_path"], CONFIG["label_field"])
test_df = pd.DataFrame({'query': queries3, 'title': titles3, 'label': labels3})

print(f"  Training samples:   {len(train_df)}")
print(f"  Validation samples: {len(val_df)}")
print(f"  Test samples:       {len(test_df)}")
print(train_df.head())
print(val_df.head())
print(test_df.head())

Loading data from dataset/train.json...
Failed loading as list, trying JSON Lines format...
Loaded 180000 valid samples.
Loading data from dataset/dev.json...
Failed loading as list, trying JSON Lines format...
Loaded 20000 valid samples.
Loading data from dataset/test_public.json...
Failed loading as list, trying JSON Lines format...
Loaded 5000 valid samples.
  Training samples:   180000
  Validation samples: 20000
  Test samples:       5000
             query                           title  label
0            应届生实习                    实习生招聘-应届生求职网      1
1  ln1+x-ln1+y=x-y  已知函数fx=1lnx+1-x则y=fx的图像高考吧百度贴吧      0
2         大秦之悍卒189                   起点中文网阅文集团旗下网站      0
3             出门经咒                     快快乐乐出门咒-豆丁网      1
4           盖中盖广告词              谁知道盖中盖所有的广告词急用百度知道      1
              query                                              title  label
0            小孩咳嗽感冒                              小孩感冒过后久咳嗽该吃什么药育儿问答宝宝树      1
1      前列腺癌根治术后能活多久                    前列腺癌转移能活多

In [13]:
# Build Vocabulary from Training Data Only
print("\nTokenizing training data for vocabulary...")
train_query_tokens = [tokenize(q) for q in train_df['query']]
train_title_tokens = [tokenize(t) for t in train_df['title']]
word_to_idx = build_vocab(train_query_tokens + train_title_tokens,
                          CONFIG["vocab_max_size"], CONFIG["vocab_min_freq"])
vocab_size = len(word_to_idx)
pad_idx = word_to_idx['<PAD>'] # Get pad index for models

# Create Datasets and DataLoaders (Unified Dataset)
train_dataset = RelevanceDataset(train_df['query'].tolist(), train_df['title'].tolist(), train_df['label'].tolist(),
                                  word_to_idx, CONFIG["max_query_len"], CONFIG["max_title_len"], CONFIG["max_combined_len"])
val_dataset = RelevanceDataset(val_df['query'].tolist(), val_df['title'].tolist(), val_df['label'].tolist(),
                                word_to_idx, CONFIG["max_query_len"], CONFIG["max_title_len"], CONFIG["max_combined_len"])
test_dataset = RelevanceDataset(test_df['query'].tolist(), test_df['title'].tolist(), test_df['label'].tolist(),
                                word_to_idx, CONFIG["max_query_len"], CONFIG["max_title_len"], CONFIG["max_combined_len"])

train_loader = data.DataLoader(train_dataset, batch_size=CONFIG["batch_size"], shuffle=True, num_workers=0, pin_memory=True)
val_loader = data.DataLoader(val_dataset, batch_size=CONFIG["batch_size"], shuffle=False, num_workers=0, pin_memory=True)
test_loader = data.DataLoader(test_dataset, batch_size=CONFIG["batch_size"], shuffle=False, num_workers=0, pin_memory=True)


# --- Shared Components ---
criterion = nn.CrossEntropyLoss()

Building prefix dict from the default dictionary ...
DEBUG:jieba:Building prefix dict from the default dictionary ...



Tokenizing training data for vocabulary...


Dumping model to file cache /tmp/jieba.cache
DEBUG:jieba:Dumping model to file cache /tmp/jieba.cache
Loading model cost 0.653 seconds.
DEBUG:jieba:Loading model cost 0.653 seconds.
Prefix dict has been built successfully.
DEBUG:jieba:Prefix dict has been built successfully.


Building vocabulary...
Vocabulary built with 10000 words.


####--- Single Tower RNN ---

In [None]:
print("\n" + "="*20 + " SINGLE TOWER RNN " + "="*20)
single_tower_model = SingleTowerRNN(
    vocab_size=vocab_size, embedding_dim=CONFIG["embedding_dim"], hidden_dim=CONFIG["hidden_dim"],
    num_classes=CONFIG["num_classes"], rnn_layers=CONFIG["rnn_layers"], dropout_prob=CONFIG["dropout_prob"],
    bidirectional=CONFIG["rnn_bidirectional"], rnn_nonlinearity=CONFIG["rnn_nonlinearity"], pad_idx=pad_idx
).to(CONFIG["device"])
optimizer_single = optim.Adam(single_tower_model.parameters(), lr=CONFIG["learning_rate"])

# Train Single Tower
train_model(single_tower_model, 'single', train_loader, val_loader, optimizer_single, criterion,
            CONFIG["num_epochs"], CONFIG["device"], CONFIG["save_path_single"])

# Evaluate Single Tower
print("\nLoading best single-tower model for final evaluation...")
try:
    single_tower_model.load_state_dict(torch.load(CONFIG["save_path_single"], map_location=CONFIG["device"]))
except Exception as e:
    print(f"Could not load saved single-tower model state: {e}. Evaluating model state after last epoch.")
accuracy_rnn_single, f1_rnn_single, _, _ = evaluate_model(single_tower_model, 'single', test_loader, criterion, CONFIG["device"])



--- Starting Training (SINGLE) ---


Epoch 1/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 1/10 Training: loss=0.8758


Epoch 1/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6317
Validation F1 Score: 0.4961
              precision    recall  f1-score   support

 Relevance 0     0.6025    0.0198    0.0384      4894
 Relevance 1     0.6319    0.9956    0.7731     12592
 Relevance 2     0.0000    0.0000    0.0000      2514

    accuracy                         0.6317     20000
   macro avg     0.4115    0.3385    0.2705     20000
weighted avg     0.5453    0.6317    0.4961     20000

save best model to best_single_tower_rnn.pth (Val F1: 0.4961)


Epoch 2/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 2/10 Training: loss=0.8417


Epoch 2/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6408
Validation F1 Score: 0.5610
              precision    recall  f1-score   support

 Relevance 0     0.5219    0.2260    0.3154      4894
 Relevance 1     0.6548    0.9299    0.7685     12592
 Relevance 2     0.0000    0.0000    0.0000      2514

    accuracy                         0.6408     20000
   macro avg     0.3923    0.3853    0.3613     20000
weighted avg     0.5400    0.6408    0.5610     20000

save best model to best_single_tower_rnn.pth (Val F1: 0.5610)


Epoch 3/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 3/10 Training: loss=0.8206


Epoch 3/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6475
Validation F1 Score: 0.5732
              precision    recall  f1-score   support

 Relevance 0     0.5456    0.2638    0.3556      4894
 Relevance 1     0.6614    0.9256    0.7715     12592
 Relevance 2     0.4545    0.0020    0.0040      2514

    accuracy                         0.6475     20000
   macro avg     0.5538    0.3971    0.3770     20000
weighted avg     0.6070    0.6475    0.5732     20000

save best model to best_single_tower_rnn.pth (Val F1: 0.5732)


Epoch 4/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 4/10 Training: loss=0.8060


Epoch 4/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6358
Validation F1 Score: 0.5846
              precision    recall  f1-score   support

 Relevance 0     0.4818    0.3897    0.4309      4894
 Relevance 1     0.6749    0.8549    0.7543     12592
 Relevance 2     0.4835    0.0175    0.0338      2514

    accuracy                         0.6358     20000
   macro avg     0.5467    0.4207    0.4063     20000
weighted avg     0.6036    0.6358    0.5846     20000

save best model to best_single_tower_rnn.pth (Val F1: 0.5846)


Epoch 5/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 5/10 Training: loss=0.7935


Epoch 5/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6430
Validation F1 Score: 0.5952
              precision    recall  f1-score   support

 Relevance 0     0.5314    0.2540    0.3437      4894
 Relevance 1     0.6696    0.8899    0.7642     12592
 Relevance 2     0.4428    0.1631    0.2384      2514

    accuracy                         0.6430     20000
   macro avg     0.5479    0.4357    0.4488     20000
weighted avg     0.6073    0.6430    0.5952     20000

save best model to best_single_tower_rnn.pth (Val F1: 0.5952)


Epoch 6/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 6/10 Training: loss=0.7862


Epoch 6/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6462
Validation F1 Score: 0.6030
              precision    recall  f1-score   support

 Relevance 0     0.5257    0.3222    0.3995      4894
 Relevance 1     0.6757    0.8761    0.7630     12592
 Relevance 2     0.4674    0.1253    0.1976      2514

    accuracy                         0.6462     20000
   macro avg     0.5563    0.4412    0.4534     20000
weighted avg     0.6128    0.6462    0.6030     20000

save best model to best_single_tower_rnn.pth (Val F1: 0.6030)


Epoch 7/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 7/10 Training: loss=0.7795


Epoch 7/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6412
Validation F1 Score: 0.6082
              precision    recall  f1-score   support

 Relevance 0     0.5071    0.3784    0.4334      4894
 Relevance 1     0.6845    0.8437    0.7558     12592
 Relevance 2     0.4215    0.1388    0.2089      2514

    accuracy                         0.6412     20000
   macro avg     0.5377    0.4537    0.4660     20000
weighted avg     0.6081    0.6412    0.6082     20000

save best model to best_single_tower_rnn.pth (Val F1: 0.6082)


Epoch 8/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 8/10 Training: loss=0.7710


Epoch 8/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6448
Validation F1 Score: 0.6138
              precision    recall  f1-score   support

 Relevance 0     0.5146    0.3782    0.4360      4894
 Relevance 1     0.6874    0.8443    0.7578     12592
 Relevance 2     0.4402    0.1639    0.2388      2514

    accuracy                         0.6448     20000
   macro avg     0.5474    0.4621    0.4776     20000
weighted avg     0.6140    0.6448    0.6138     20000

save best model to best_single_tower_rnn.pth (Val F1: 0.6138)


Epoch 9/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 9/10 Training: loss=0.7650


Epoch 9/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6462
Validation F1 Score: 0.6168
              precision    recall  f1-score   support

 Relevance 0     0.5073    0.4003    0.4475      4894
 Relevance 1     0.6911    0.8385    0.7577     12592
 Relevance 2     0.4733    0.1619    0.2413      2514

    accuracy                         0.6462     20000
   macro avg     0.5572    0.4669    0.4821     20000
weighted avg     0.6187    0.6462    0.6168     20000

save best model to best_single_tower_rnn.pth (Val F1: 0.6168)


Epoch 10/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 10/10 Training: loss=0.7583


Epoch 10/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6379
Validation F1 Score: 0.6156
              precision    recall  f1-score   support

 Relevance 0     0.4999    0.4142    0.4530      4894
 Relevance 1     0.6949    0.8135    0.7496     12592
 Relevance 2     0.4037    0.1933    0.2614      2514

    accuracy                         0.6379     20000
   macro avg     0.5328    0.4737    0.4880     20000
weighted avg     0.6106    0.6379    0.6156     20000

--- Training Finished (SINGLE) in 602.04s ---
Best Validation F1 Score (SINGLE): 0.6168

Loading best single-tower model for final evaluation...

--- Starting Evaluation on Test Set (SINGLE) ---


Evaluating Test Set:   0%|          | 0/79 [00:00<?, ?it/s]

Test Set Accuracy: 0.6440
Test Set F1 Score: 0.6167
              precision    recall  f1-score   support

 Relevance 0     0.5082    0.4119    0.4550      1209
 Relevance 1     0.6938    0.8297    0.7557      3159
 Relevance 2     0.4174    0.1598    0.2311       632

    accuracy                         0.6440      5000
   macro avg     0.5398    0.4671    0.4806      5000
weighted avg     0.6139    0.6440    0.6167      5000

--- Evaluation Finished (SINGLE) ---


#### --- Two Tower RNN ---

In [None]:
print("\n" + "="*20 + " TWO TOWER RNN " + "="*20)
two_tower_model = TwoTowerRNN(
    vocab_size=vocab_size, embedding_dim=CONFIG["embedding_dim"], hidden_dim=CONFIG["hidden_dim"],
    num_classes=CONFIG["num_classes"], rnn_layers=CONFIG["rnn_layers"], dropout_prob=CONFIG["dropout_prob"],
    bidirectional=CONFIG["rnn_bidirectional"], rnn_nonlinearity=CONFIG["rnn_nonlinearity"], pad_idx=pad_idx
).to(CONFIG["device"])
optimizer_two = optim.Adam(two_tower_model.parameters(), lr=CONFIG["learning_rate"])

# Train Two Tower
train_model(two_tower_model, 'two', train_loader, val_loader, optimizer_two, criterion,
            CONFIG["num_epochs"], CONFIG["device"], CONFIG["save_path_two"])

# Evaluate Two Tower
print("\nLoading best two-tower model for final evaluation...")
try:
    two_tower_model.load_state_dict(torch.load(CONFIG["save_path_two"], map_location=CONFIG["device"]))
except Exception as e:
      print(f"Could not load saved two-tower model state: {e}. Evaluating model state after last epoch.")
accuracy_rnn_two, f1_rnn_two, _, _ = evaluate_model(two_tower_model, 'two', test_loader, criterion, CONFIG["device"])



--- Starting Training (TWO) ---


Epoch 1/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 1/10 Training: loss=0.8726


Epoch 1/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6341
Validation F1 Score: 0.5473
              precision    recall  f1-score   support

 Relevance 0     0.4956    0.1620    0.2442      4894
 Relevance 1     0.6474    0.9401    0.7668     12592
 Relevance 2     0.4348    0.0199    0.0380      2514

    accuracy                         0.6341     20000
   macro avg     0.5259    0.3740    0.3497     20000
weighted avg     0.5835    0.6341    0.5473     20000

save best model to best_two_tower_rnn.pth (Val F1: 0.5473)


Epoch 2/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 2/10 Training: loss=0.8350


Epoch 2/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6399
Validation F1 Score: 0.5384
              precision    recall  f1-score   support

 Relevance 0     0.5873    0.1058    0.1794      4894
 Relevance 1     0.6436    0.9688    0.7734     12592
 Relevance 2     0.4848    0.0318    0.0597      2514

    accuracy                         0.6399     20000
   macro avg     0.5719    0.3688    0.3375     20000
weighted avg     0.6099    0.6399    0.5384     20000



Epoch 3/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 3/10 Training: loss=0.8116


Epoch 3/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6440
Validation F1 Score: 0.5616
              precision    recall  f1-score   support

 Relevance 0     0.5657    0.1716    0.2634      4894
 Relevance 1     0.6529    0.9469    0.7728     12592
 Relevance 2     0.4603    0.0461    0.0839      2514

    accuracy                         0.6440     20000
   macro avg     0.5596    0.3882    0.3734     20000
weighted avg     0.6073    0.6440    0.5616     20000

save best model to best_two_tower_rnn.pth (Val F1: 0.5616)


Epoch 4/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 4/10 Training: loss=0.7928


Epoch 4/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6473
Validation F1 Score: 0.5934
              precision    recall  f1-score   support

 Relevance 0     0.5626    0.2166    0.3128      4894
 Relevance 1     0.6688    0.9085    0.7704     12592
 Relevance 2     0.4411    0.1774    0.2530      2514

    accuracy                         0.6473     20000
   macro avg     0.5575    0.4342    0.4454     20000
weighted avg     0.6142    0.6473    0.5934     20000

save best model to best_two_tower_rnn.pth (Val F1: 0.5934)


Epoch 5/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 5/10 Training: loss=0.7795


Epoch 5/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6446
Validation F1 Score: 0.5946
              precision    recall  f1-score   support

 Relevance 0     0.5239    0.2908    0.3740      4894
 Relevance 1     0.6715    0.8901    0.7655     12592
 Relevance 2     0.4394    0.1038    0.1680      2514

    accuracy                         0.6446     20000
   macro avg     0.5450    0.4282    0.4358     20000
weighted avg     0.6062    0.6446    0.5946     20000

save best model to best_two_tower_rnn.pth (Val F1: 0.5946)


Epoch 6/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 6/10 Training: loss=0.7666


Epoch 6/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6427
Validation F1 Score: 0.6022
              precision    recall  f1-score   support

 Relevance 0     0.5148    0.3102    0.3871      4894
 Relevance 1     0.6756    0.8698    0.7605     12592
 Relevance 2     0.4560    0.1523    0.2284      2514

    accuracy                         0.6427     20000
   macro avg     0.5488    0.4441    0.4587     20000
weighted avg     0.6086    0.6427    0.6022     20000

save best model to best_two_tower_rnn.pth (Val F1: 0.6022)


Epoch 7/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 7/10 Training: loss=0.7583


Epoch 7/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6455
Validation F1 Score: 0.6070
              precision    recall  f1-score   support

 Relevance 0     0.5304    0.3169    0.3968      4894
 Relevance 1     0.6796    0.8697    0.7630     12592
 Relevance 2     0.4252    0.1627    0.2353      2514

    accuracy                         0.6455     20000
   macro avg     0.5451    0.4498    0.4650     20000
weighted avg     0.6111    0.6455    0.6070     20000

save best model to best_two_tower_rnn.pth (Val F1: 0.6070)


Epoch 8/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 8/10 Training: loss=0.7482


Epoch 8/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6455
Validation F1 Score: 0.6120
              precision    recall  f1-score   support

 Relevance 0     0.5511    0.2720    0.3642      4894
 Relevance 1     0.6825    0.8627    0.7621     12592
 Relevance 2     0.4293    0.2848    0.3424      2514

    accuracy                         0.6455     20000
   macro avg     0.5543    0.4732    0.4896     20000
weighted avg     0.6185    0.6455    0.6120     20000

save best model to best_two_tower_rnn.pth (Val F1: 0.6120)


Epoch 9/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 9/10 Training: loss=0.7390


Epoch 9/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6498
Validation F1 Score: 0.6175
              precision    recall  f1-score   support

 Relevance 0     0.5321    0.3731    0.4386      4894
 Relevance 1     0.6893    0.8543    0.7630     12592
 Relevance 2     0.4283    0.1639    0.2371      2514

    accuracy                         0.6498     20000
   macro avg     0.5499    0.4638    0.4795     20000
weighted avg     0.6180    0.6498    0.6175     20000

save best model to best_two_tower_rnn.pth (Val F1: 0.6175)


Epoch 10/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 10/10 Training: loss=0.7300


Epoch 10/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6497
Validation F1 Score: 0.6205
              precision    recall  f1-score   support

 Relevance 0     0.5447    0.3537    0.4289      4894
 Relevance 1     0.6891    0.8515    0.7617     12592
 Relevance 2     0.4287    0.2152    0.2865      2514

    accuracy                         0.6497     20000
   macro avg     0.5541    0.4735    0.4924     20000
weighted avg     0.6210    0.6497    0.6205     20000

save best model to best_two_tower_rnn.pth (Val F1: 0.6205)
--- Training Finished (TWO) in 633.35s ---
Best Validation F1 Score (TWO): 0.6205

Loading best two-tower model for final evaluation...

--- Starting Evaluation on Test Set (TWO) ---


Evaluating Test Set:   0%|          | 0/79 [00:00<?, ?it/s]

Test Set Accuracy: 0.6562
Test Set F1 Score: 0.6277
              precision    recall  f1-score   support

 Relevance 0     0.5679    0.3598    0.4405      1209
 Relevance 1     0.6939    0.8563    0.7666      3159
 Relevance 2     0.4196    0.2231    0.2913       632

    accuracy                         0.6562      5000
   macro avg     0.5605    0.4797    0.4995      5000
weighted avg     0.6288    0.6562    0.6277      5000

--- Evaluation Finished (TWO) ---


#### --- Single Tower LSTM ---

In [None]:
print("\n" + "="*20 + " SINGLE TOWER LSTM " + "="*20)
single_tower_lstm_model = SingleTowerLSTM(
    vocab_size=vocab_size, embedding_dim=CONFIG["embedding_dim"], hidden_dim=CONFIG["hidden_dim"],
    num_classes=CONFIG["num_classes"], rnn_layers=CONFIG["rnn_layers"], dropout_prob=CONFIG["dropout_prob"],
    bidirectional=CONFIG["rnn_bidirectional"],
    pad_idx=pad_idx
).to(CONFIG["device"])
optimizer_single_lstm = optim.Adam(single_tower_lstm_model.parameters(), lr=CONFIG["learning_rate"])

# Train Single Tower LSTM
train_model(single_tower_lstm_model, 'single', train_loader, val_loader, optimizer_single_lstm, criterion,
            CONFIG["num_epochs"], CONFIG["device"], CONFIG["save_path_single_lstm"])

# Evaluate Single Tower LSTM
print("\nLoading best single-tower LSTM model for final evaluation...")
try:
    single_tower_lstm_model.load_state_dict(torch.load(CONFIG["save_path_single_lstm"], map_location=CONFIG["device"]))
except Exception as e:
    print(f"Could not load saved single-tower LSTM model state: {e}. Evaluating model state after last epoch.")
accuracy_lstm_single, f1_lstm_single, _, _ = evaluate_model(single_tower_lstm_model, 'single', test_loader, criterion, CONFIG["device"])



--- Starting Training (SINGLE) ---


Epoch 1/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 1/10 Training: loss=0.8401


Epoch 1/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6462
Validation F1 Score: 0.5908
              precision    recall  f1-score   support

 Relevance 0     0.5175    0.3020    0.3814      4894
 Relevance 1     0.6701    0.8956    0.7666     12592
 Relevance 2     0.5335    0.0664    0.1181      2514

    accuracy                         0.6462     20000
   macro avg     0.5737    0.4214    0.4221     20000
weighted avg     0.6156    0.6462    0.5908     20000

save best model to best_single_tower_lstm.pth (Val F1: 0.5908)


Epoch 2/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 2/10 Training: loss=0.7912


Epoch 2/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6513
Validation F1 Score: 0.6079
              precision    recall  f1-score   support

 Relevance 0     0.5271    0.3441    0.4164      4894
 Relevance 1     0.6810    0.8783    0.7672     12592
 Relevance 2     0.4991    0.1122    0.1832      2514

    accuracy                         0.6513     20000
   macro avg     0.5691    0.4449    0.4556     20000
weighted avg     0.6205    0.6513    0.6079     20000

save best model to best_single_tower_lstm.pth (Val F1: 0.6079)


Epoch 3/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 3/10 Training: loss=0.7660


Epoch 3/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6558
Validation F1 Score: 0.6201
              precision    recall  f1-score   support

 Relevance 0     0.5571    0.2989    0.3891      4894
 Relevance 1     0.6844    0.8769    0.7688     12592
 Relevance 2     0.4915    0.2426    0.3249      2514

    accuracy                         0.6558     20000
   macro avg     0.5777    0.4728    0.4943     20000
weighted avg     0.6290    0.6558    0.6201     20000

save best model to best_single_tower_lstm.pth (Val F1: 0.6201)


Epoch 4/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 4/10 Training: loss=0.7474


Epoch 4/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6585
Validation F1 Score: 0.6192
              precision    recall  f1-score   support

 Relevance 0     0.5711    0.2846    0.3799      4894
 Relevance 1     0.6829    0.8888    0.7724     12592
 Relevance 2     0.4996    0.2331    0.3179      2514

    accuracy                         0.6585     20000
   macro avg     0.5845    0.4688    0.4901     20000
weighted avg     0.6325    0.6585    0.6192     20000



Epoch 5/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 5/10 Training: loss=0.7313


Epoch 5/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6590
Validation F1 Score: 0.6349
              precision    recall  f1-score   support

 Relevance 0     0.5474    0.3682    0.4403      4894
 Relevance 1     0.6985    0.8475    0.7658     12592
 Relevance 2     0.4941    0.2808    0.3581      2514

    accuracy                         0.6590     20000
   macro avg     0.5800    0.4989    0.5214     20000
weighted avg     0.6358    0.6590    0.6349     20000

save best model to best_single_tower_lstm.pth (Val F1: 0.6349)


Epoch 6/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 6/10 Training: loss=0.7174


Epoch 6/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6585
Validation F1 Score: 0.6294
              precision    recall  f1-score   support

 Relevance 0     0.5512    0.3396    0.4203      4894
 Relevance 1     0.6928    0.8621    0.7682     12592
 Relevance 2     0.4966    0.2601    0.3414      2514

    accuracy                         0.6585     20000
   macro avg     0.5802    0.4873    0.5100     20000
weighted avg     0.6335    0.6585    0.6294     20000



Epoch 7/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 7/10 Training: loss=0.7051


Epoch 7/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6562
Validation F1 Score: 0.6378
              precision    recall  f1-score   support

 Relevance 0     0.5266    0.4422    0.4807      4894
 Relevance 1     0.7068    0.8194    0.7590     12592
 Relevance 2     0.4957    0.2550    0.3367      2514

    accuracy                         0.6562     20000
   macro avg     0.5764    0.5055    0.5255     20000
weighted avg     0.6362    0.6562    0.6378     20000

save best model to best_single_tower_lstm.pth (Val F1: 0.6378)


Epoch 8/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 8/10 Training: loss=0.6939


Epoch 8/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6567
Validation F1 Score: 0.6360
              precision    recall  f1-score   support

 Relevance 0     0.5326    0.4119    0.4646      4894
 Relevance 1     0.7031    0.8312    0.7618     12592
 Relevance 2     0.4910    0.2593    0.3394      2514

    accuracy                         0.6567     20000
   macro avg     0.5756    0.5008    0.5219     20000
weighted avg     0.6347    0.6567    0.6360     20000



Epoch 9/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 9/10 Training: loss=0.6851


Epoch 9/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6541
Validation F1 Score: 0.6375
              precision    recall  f1-score   support

 Relevance 0     0.5197    0.4555    0.4855      4894
 Relevance 1     0.7080    0.8102    0.7557     12592
 Relevance 2     0.5000    0.2589    0.3412      2514

    accuracy                         0.6541     20000
   macro avg     0.5759    0.5082    0.5274     20000
weighted avg     0.6358    0.6541    0.6375     20000



Epoch 10/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 10/10 Training: loss=0.6759


Epoch 10/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6560
Validation F1 Score: 0.6406
              precision    recall  f1-score   support

 Relevance 0     0.5286    0.4397    0.4801      4894
 Relevance 1     0.7104    0.8123    0.7580     12592
 Relevance 2     0.4824    0.2936    0.3650      2514

    accuracy                         0.6560     20000
   macro avg     0.5738    0.5152    0.5343     20000
weighted avg     0.6373    0.6560    0.6406     20000

save best model to best_single_tower_lstm.pth (Val F1: 0.6406)
--- Training Finished (SINGLE) in 635.59s ---
Best Validation F1 Score (SINGLE): 0.6406

Loading best single-tower LSTM model for final evaluation...

--- Starting Evaluation on Test Set (SINGLE) ---


Evaluating Test Set:   0%|          | 0/79 [00:00<?, ?it/s]

Test Set Accuracy: 0.6634
Test Set F1 Score: 0.6487
              precision    recall  f1-score   support

 Relevance 0     0.5357    0.4467    0.4871      1209
 Relevance 1     0.7185    0.8177    0.7649      3159
 Relevance 2     0.4887    0.3070    0.3771       632

    accuracy                         0.6634      5000
   macro avg     0.5810    0.5238    0.5430      5000
weighted avg     0.6452    0.6634    0.6487      5000

--- Evaluation Finished (SINGLE) ---


#### --- Two Tower LSTM ---

In [None]:
print("\n" + "="*20 + " TWO TOWER LSTM " + "="*20)
two_tower_lstm_model = TwoTowerLSTM(
    vocab_size=vocab_size, embedding_dim=CONFIG["embedding_dim"], hidden_dim=CONFIG["hidden_dim"],
    num_classes=CONFIG["num_classes"], rnn_layers=CONFIG["rnn_layers"], dropout_prob=CONFIG["dropout_prob"],
    bidirectional=CONFIG["rnn_bidirectional"],
    pad_idx=pad_idx
).to(CONFIG["device"])
optimizer_two_lstm = optim.Adam(two_tower_lstm_model.parameters(), lr=CONFIG["learning_rate"])

# Train Two Tower LSTM
train_model(two_tower_lstm_model, 'two', train_loader, val_loader, optimizer_two_lstm, criterion,
            CONFIG["num_epochs"], CONFIG["device"], CONFIG["save_path_two_lstm"])

# Evaluate Two Tower LSTM
print("\nLoading best two-tower LSTM model for final evaluation...")
try:
    two_tower_lstm_model.load_state_dict(torch.load(CONFIG["save_path_two_lstm"], map_location=CONFIG["device"]))
except Exception as e:
      print(f"Could not load saved two-tower LSTM model state: {e}. Evaluating model state after last epoch.")
accuracy_lstm_two, f1_lstm_two, _, _ = evaluate_model(two_tower_lstm_model, 'two', test_loader, criterion, CONFIG["device"])



--- Starting Training (TWO) ---


Epoch 1/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 1/10 Training: loss=0.8342


Epoch 1/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6482
Validation F1 Score: 0.6041
              precision    recall  f1-score   support

 Relevance 0     0.5530    0.2366    0.3314      4894
 Relevance 1     0.6758    0.8901    0.7683     12592
 Relevance 2     0.4531    0.2383    0.3123      2514

    accuracy                         0.6482     20000
   macro avg     0.5606    0.4550    0.4707     20000
weighted avg     0.6178    0.6482    0.6041     20000

save best model to best_two_tower_lstm.pth (Val F1: 0.6041)


Epoch 2/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 2/10 Training: loss=0.7886


Epoch 2/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6515
Validation F1 Score: 0.6171
              precision    recall  f1-score   support

 Relevance 0     0.5614    0.2738    0.3681      4894
 Relevance 1     0.6862    0.8715    0.7678     12592
 Relevance 2     0.4426    0.2852    0.3469      2514

    accuracy                         0.6515     20000
   macro avg     0.5634    0.4768    0.4943     20000
weighted avg     0.6250    0.6515    0.6171     20000

save best model to best_two_tower_lstm.pth (Val F1: 0.6171)


Epoch 3/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 3/10 Training: loss=0.7632


Epoch 3/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6618
Validation F1 Score: 0.6260
              precision    recall  f1-score   support

 Relevance 0     0.5755    0.3202    0.4114      4894
 Relevance 1     0.6891    0.8829    0.7741     12592
 Relevance 2     0.4821    0.2196    0.3017      2514

    accuracy                         0.6618     20000
   macro avg     0.5822    0.4742    0.4957     20000
weighted avg     0.6353    0.6618    0.6260     20000

save best model to best_two_tower_lstm.pth (Val F1: 0.6260)


Epoch 4/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 4/10 Training: loss=0.7428


Epoch 4/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6604
Validation F1 Score: 0.6338
              precision    recall  f1-score   support

 Relevance 0     0.5584    0.3555    0.4345      4894
 Relevance 1     0.6980    0.8574    0.7695     12592
 Relevance 2     0.4746    0.2677    0.3423      2514

    accuracy                         0.6604     20000
   macro avg     0.5770    0.4935    0.5154     20000
weighted avg     0.6358    0.6604    0.6338     20000

save best model to best_two_tower_lstm.pth (Val F1: 0.6338)


Epoch 5/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 5/10 Training: loss=0.7254


Epoch 5/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6635
Validation F1 Score: 0.6306
              precision    recall  f1-score   support

 Relevance 0     0.5769    0.3304    0.4202      4894
 Relevance 1     0.6928    0.8779    0.7744     12592
 Relevance 2     0.4835    0.2387    0.3196      2514

    accuracy                         0.6635     20000
   macro avg     0.5844    0.4823    0.5047     20000
weighted avg     0.6381    0.6635    0.6306     20000



Epoch 6/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 6/10 Training: loss=0.7109


Epoch 6/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6634
Validation F1 Score: 0.6425
              precision    recall  f1-score   support

 Relevance 0     0.5525    0.3948    0.4605      4894
 Relevance 1     0.7063    0.8416    0.7680     12592
 Relevance 2     0.4927    0.2940    0.3682      2514

    accuracy                         0.6634     20000
   macro avg     0.5838    0.5101    0.5322     20000
weighted avg     0.6418    0.6634    0.6425     20000

save best model to best_two_tower_lstm.pth (Val F1: 0.6425)


Epoch 7/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 7/10 Training: loss=0.6975


Epoch 7/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6636
Validation F1 Score: 0.6404
              precision    recall  f1-score   support

 Relevance 0     0.5826    0.3562    0.4420      4894
 Relevance 1     0.7031    0.8524    0.7706     12592
 Relevance 2     0.4575    0.3170    0.3745      2514

    accuracy                         0.6636     20000
   macro avg     0.5810    0.5085    0.5290     20000
weighted avg     0.6427    0.6636    0.6404     20000



Epoch 8/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 8/10 Training: loss=0.6864


Epoch 8/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6667
Validation F1 Score: 0.6453
              precision    recall  f1-score   support

 Relevance 0     0.5695    0.3954    0.4667      4894
 Relevance 1     0.7080    0.8472    0.7714     12592
 Relevance 2     0.4762    0.2908    0.3611      2514

    accuracy                         0.6667     20000
   macro avg     0.5846    0.5111    0.5331     20000
weighted avg     0.6450    0.6667    0.6453     20000

save best model to best_two_tower_lstm.pth (Val F1: 0.6453)


Epoch 9/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 9/10 Training: loss=0.6748


Epoch 9/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6626
Validation F1 Score: 0.6371
              precision    recall  f1-score   support

 Relevance 0     0.5627    0.3833    0.4560      4894
 Relevance 1     0.7007    0.8540    0.7698     12592
 Relevance 2     0.4716    0.2474    0.3245      2514

    accuracy                         0.6626     20000
   macro avg     0.5783    0.4949    0.5168     20000
weighted avg     0.6381    0.6626    0.6371     20000



Epoch 10/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 10/10 Training: loss=0.6660


Epoch 10/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6643
Validation F1 Score: 0.6445
              precision    recall  f1-score   support

 Relevance 0     0.5712    0.3911    0.4643      4894
 Relevance 1     0.7085    0.8407    0.7689     12592
 Relevance 2     0.4599    0.3123    0.3719      2514

    accuracy                         0.6643     20000
   macro avg     0.5798    0.5147    0.5351     20000
weighted avg     0.6436    0.6643    0.6445     20000

--- Training Finished (TWO) in 670.58s ---
Best Validation F1 Score (TWO): 0.6453

Loading best two-tower LSTM model for final evaluation...

--- Starting Evaluation on Test Set (TWO) ---


Evaluating Test Set:   0%|          | 0/79 [00:00<?, ?it/s]

Test Set Accuracy: 0.6646
Test Set F1 Score: 0.6434
              precision    recall  f1-score   support

 Relevance 0     0.5707    0.3937    0.4660      1209
 Relevance 1     0.7057    0.8433    0.7684      3159
 Relevance 2     0.4680    0.2896    0.3578       632

    accuracy                         0.6646      5000
   macro avg     0.5815    0.5089    0.5307      5000
weighted avg     0.6430    0.6646    0.6434      5000

--- Evaluation Finished (TWO) ---


#### --- Single Tower GRU ---

In [14]:
print("\n" + "="*20 + " SINGLE TOWER GRU " + "="*20)
single_tower_gru_model = SingleTowerGRU(
    vocab_size=vocab_size, embedding_dim=CONFIG["embedding_dim"], hidden_dim=CONFIG["hidden_dim"],
    num_classes=CONFIG["num_classes"], rnn_layers=CONFIG["rnn_layers"], dropout_prob=CONFIG["dropout_prob"],
    bidirectional=CONFIG["rnn_bidirectional"],
    pad_idx=pad_idx
).to(CONFIG["device"])
optimizer_single_gru = optim.Adam(single_tower_gru_model.parameters(), lr=CONFIG["learning_rate"])

# Train Single Tower GRU
train_model(single_tower_gru_model, 'single', train_loader, val_loader, optimizer_single_gru, criterion,
            CONFIG["num_epochs"], CONFIG["device"], CONFIG["save_path_single_gru"])

# Evaluate Single Tower GRU
print("\nLoading best single-tower GRU model for final evaluation...")
try:
    single_tower_gru_model.load_state_dict(torch.load(CONFIG["save_path_single_gru"], map_location=CONFIG["device"]))
except Exception as e:
    print(f"Could not load saved single-tower GRU model state: {e}. Evaluating model state after last epoch.")
accuracy_gru_single, f1_gru_single, _, _ = evaluate_model(single_tower_gru_model, 'single', test_loader, criterion, CONFIG["device"])



--- Starting Training (SINGLE) ---


Epoch 1/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 1/10 Training: loss=0.8395


Epoch 1/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6505
Validation F1 Score: 0.6004
              precision    recall  f1-score   support

 Relevance 0     0.5572    0.2327    0.3283      4894
 Relevance 1     0.6717    0.9042    0.7708     12592
 Relevance 2     0.4831    0.1933    0.2761      2514

    accuracy                         0.6505     20000
   macro avg     0.5707    0.4434    0.4584     20000
weighted avg     0.6200    0.6505    0.6004     20000

save best model to best_single_tower_gru.pth (Val F1: 0.6004)


Epoch 2/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 2/10 Training: loss=0.7910


Epoch 2/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6506
Validation F1 Score: 0.6152
              precision    recall  f1-score   support

 Relevance 0     0.5279    0.3327    0.4081      4894
 Relevance 1     0.6839    0.8671    0.7646     12592
 Relevance 2     0.4911    0.1858    0.2696      2514

    accuracy                         0.6506     20000
   macro avg     0.5676    0.4618    0.4808     20000
weighted avg     0.6215    0.6506    0.6152     20000

save best model to best_single_tower_gru.pth (Val F1: 0.6152)


Epoch 3/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 3/10 Training: loss=0.7687


Epoch 3/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6550
Validation F1 Score: 0.6206
              precision    recall  f1-score   support

 Relevance 0     0.5562    0.2942    0.3849      4894
 Relevance 1     0.6852    0.8735    0.7680     12592
 Relevance 2     0.4867    0.2629    0.3414      2514

    accuracy                         0.6550     20000
   macro avg     0.5760    0.4769    0.4981     20000
weighted avg     0.6287    0.6550    0.6206     20000

save best model to best_single_tower_gru.pth (Val F1: 0.6206)


Epoch 4/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 4/10 Training: loss=0.7509


Epoch 4/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6612
Validation F1 Score: 0.6255
              precision    recall  f1-score   support

 Relevance 0     0.5809    0.2934    0.3899      4894
 Relevance 1     0.6872    0.8833    0.7730     12592
 Relevance 2     0.4952    0.2645    0.3448      2514

    accuracy                         0.6612     20000
   macro avg     0.5878    0.4804    0.5026     20000
weighted avg     0.6371    0.6612    0.6255     20000

save best model to best_single_tower_gru.pth (Val F1: 0.6255)


Epoch 5/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 5/10 Training: loss=0.7359


Epoch 5/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6600
Validation F1 Score: 0.6341
              precision    recall  f1-score   support

 Relevance 0     0.5482    0.3670    0.4397      4894
 Relevance 1     0.6977    0.8534    0.7677     12592
 Relevance 2     0.4977    0.2617    0.3431      2514

    accuracy                         0.6600     20000
   macro avg     0.5812    0.4940    0.5168     20000
weighted avg     0.6360    0.6600    0.6341     20000

save best model to best_single_tower_gru.pth (Val F1: 0.6341)


Epoch 6/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 6/10 Training: loss=0.7243


Epoch 6/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6622
Validation F1 Score: 0.6293
              precision    recall  f1-score   support

 Relevance 0     0.5833    0.2983    0.3948      4894
 Relevance 1     0.6911    0.8777    0.7733     12592
 Relevance 2     0.4864    0.2912    0.3643      2514

    accuracy                         0.6622     20000
   macro avg     0.5869    0.4891    0.5108     20000
weighted avg     0.6390    0.6622    0.6293     20000



Epoch 7/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 7/10 Training: loss=0.7134


Epoch 7/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6596
Validation F1 Score: 0.6412
              precision    recall  f1-score   support

 Relevance 0     0.5430    0.4130    0.4691      4894
 Relevance 1     0.7072    0.8284    0.7630     12592
 Relevance 2     0.4840    0.2944    0.3661      2514

    accuracy                         0.6596     20000
   macro avg     0.5781    0.5119    0.5327     20000
weighted avg     0.6390    0.6596    0.6412     20000

save best model to best_single_tower_gru.pth (Val F1: 0.6412)


Epoch 8/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 8/10 Training: loss=0.7027


Epoch 8/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6589
Validation F1 Score: 0.6423
              precision    recall  f1-score   support

 Relevance 0     0.5564    0.3880    0.4572      4894
 Relevance 1     0.7092    0.8254    0.7629     12592
 Relevance 2     0.4584    0.3524    0.3985      2514

    accuracy                         0.6589     20000
   macro avg     0.5747    0.5219    0.5395     20000
weighted avg     0.6403    0.6589    0.6423     20000

save best model to best_single_tower_gru.pth (Val F1: 0.6423)


Epoch 9/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 9/10 Training: loss=0.6946


Epoch 9/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6602
Validation F1 Score: 0.6455
              precision    recall  f1-score   support

 Relevance 0     0.5391    0.4324    0.4799      4894
 Relevance 1     0.7145    0.8166    0.7621     12592
 Relevance 2     0.4780    0.3202    0.3835      2514

    accuracy                         0.6602     20000
   macro avg     0.5772    0.5230    0.5418     20000
weighted avg     0.6418    0.6602    0.6455     20000

save best model to best_single_tower_gru.pth (Val F1: 0.6455)


Epoch 10/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 10/10 Training: loss=0.6857


Epoch 10/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6607
Validation F1 Score: 0.6398
              precision    recall  f1-score   support

 Relevance 0     0.5533    0.3915    0.4585      4894
 Relevance 1     0.7046    0.8395    0.7662     12592
 Relevance 2     0.4743    0.2896    0.3596      2514

    accuracy                         0.6607     20000
   macro avg     0.5774    0.5069    0.5281     20000
weighted avg     0.6386    0.6607    0.6398     20000

--- Training Finished (SINGLE) in 638.16s ---
Best Validation F1 Score (SINGLE): 0.6455

Loading best single-tower GRU model for final evaluation...

--- Starting Evaluation on Test Set (SINGLE) ---


Evaluating Test Set:   0%|          | 0/79 [00:00<?, ?it/s]

Test Set Accuracy: 0.6640
Test Set F1 Score: 0.6497
              precision    recall  f1-score   support

 Relevance 0     0.5445    0.4351    0.4837      1209
 Relevance 1     0.7197    0.8192    0.7662      3159
 Relevance 2     0.4703    0.3259    0.3850       632

    accuracy                         0.6640      5000
   macro avg     0.5782    0.5268    0.5450      5000
weighted avg     0.6458    0.6640    0.6497      5000

--- Evaluation Finished (SINGLE) ---


#### --- Two Tower GRU ---

In [15]:
print("\n" + "="*20 + " TWO TOWER GRU " + "="*20)
two_tower_gru_model = TwoTowerGRU(
    vocab_size=vocab_size, embedding_dim=CONFIG["embedding_dim"], hidden_dim=CONFIG["hidden_dim"],
    num_classes=CONFIG["num_classes"], rnn_layers=CONFIG["rnn_layers"], dropout_prob=CONFIG["dropout_prob"],
    bidirectional=CONFIG["rnn_bidirectional"],
    pad_idx=pad_idx
).to(CONFIG["device"])
optimizer_two_gru = optim.Adam(two_tower_gru_model.parameters(), lr=CONFIG["learning_rate"])

# Train Two Tower GRU
train_model(two_tower_gru_model, 'two', train_loader, val_loader, optimizer_two_gru, criterion,
            CONFIG["num_epochs"], CONFIG["device"], CONFIG["save_path_two_gru"])

# Evaluate Two Tower GRU
print("\nLoading best two-tower GRU model for final evaluation...")
try:
    two_tower_gru_model.load_state_dict(torch.load(CONFIG["save_path_two_gru"], map_location=CONFIG["device"]))
except Exception as e:
      print(f"Could not load saved two-tower GRU model state: {e}. Evaluating model state after last epoch.")
accuracy_gru_two, f1_gru_two, _, _ = evaluate_model(two_tower_gru_model, 'two', test_loader, criterion, CONFIG["device"])



--- Starting Training (TWO) ---


Epoch 1/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 1/10 Training: loss=0.8347


Epoch 1/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6507
Validation F1 Score: 0.5941
              precision    recall  f1-score   support

 Relevance 0     0.5598    0.2131    0.3087      4894
 Relevance 1     0.6693    0.9171    0.7739     12592
 Relevance 2     0.4785    0.1683    0.2490      2514

    accuracy                         0.6507     20000
   macro avg     0.5692    0.4328    0.4439     20000
weighted avg     0.6186    0.6507    0.5941     20000

save best model to best_two_tower_gru.pth (Val F1: 0.5941)


Epoch 2/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 2/10 Training: loss=0.7873


Epoch 2/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6550
Validation F1 Score: 0.6231
              precision    recall  f1-score   support

 Relevance 0     0.5555    0.2934    0.3840      4894
 Relevance 1     0.6895    0.8677    0.7684     12592
 Relevance 2     0.4700    0.2932    0.3611      2514

    accuracy                         0.6550     20000
   macro avg     0.5717    0.4848    0.5045     20000
weighted avg     0.6291    0.6550    0.6231     20000

save best model to best_two_tower_gru.pth (Val F1: 0.6231)


Epoch 3/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 3/10 Training: loss=0.7628


Epoch 3/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6584
Validation F1 Score: 0.6263
              precision    recall  f1-score   support

 Relevance 0     0.5595    0.3163    0.4041      4894
 Relevance 1     0.6901    0.8710    0.7701     12592
 Relevance 2     0.4877    0.2597    0.3390      2514

    accuracy                         0.6584     20000
   macro avg     0.5791    0.4824    0.5044     20000
weighted avg     0.6327    0.6584    0.6263     20000

save best model to best_two_tower_gru.pth (Val F1: 0.6263)


Epoch 4/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 4/10 Training: loss=0.7443


Epoch 4/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6622
Validation F1 Score: 0.6304
              precision    recall  f1-score   support

 Relevance 0     0.5891    0.2924    0.3908      4894
 Relevance 1     0.6938    0.8750    0.7739     12592
 Relevance 2     0.4698    0.3158    0.3777      2514

    accuracy                         0.6622     20000
   macro avg     0.5842    0.4944    0.5142     20000
weighted avg     0.6400    0.6622    0.6304     20000

save best model to best_two_tower_gru.pth (Val F1: 0.6304)


Epoch 5/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 5/10 Training: loss=0.7270


Epoch 5/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6635
Validation F1 Score: 0.6306
              precision    recall  f1-score   support

 Relevance 0     0.5661    0.3641    0.4432      4894
 Relevance 1     0.6938    0.8729    0.7731     12592
 Relevance 2     0.4921    0.1977    0.2821      2514

    accuracy                         0.6635     20000
   macro avg     0.5840    0.4782    0.4994     20000
weighted avg     0.6372    0.6635    0.6306     20000

save best model to best_two_tower_gru.pth (Val F1: 0.6306)


Epoch 6/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 6/10 Training: loss=0.7129


Epoch 6/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6637
Validation F1 Score: 0.6422
              precision    recall  f1-score   support

 Relevance 0     0.5665    0.3754    0.4515      4894
 Relevance 1     0.7049    0.8459    0.7690     12592
 Relevance 2     0.4778    0.3126    0.3780      2514

    accuracy                         0.6637     20000
   macro avg     0.5830    0.5113    0.5328     20000
weighted avg     0.6425    0.6637    0.6422     20000

save best model to best_two_tower_gru.pth (Val F1: 0.6422)


Epoch 7/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 7/10 Training: loss=0.7009


Epoch 7/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6635
Validation F1 Score: 0.6433
              precision    recall  f1-score   support

 Relevance 0     0.5544    0.4291    0.4838      4894
 Relevance 1     0.7098    0.8359    0.7677     12592
 Relevance 2     0.4657    0.2562    0.3305      2514

    accuracy                         0.6635     20000
   macro avg     0.5766    0.5071    0.5273     20000
weighted avg     0.6411    0.6635    0.6433     20000

save best model to best_two_tower_gru.pth (Val F1: 0.6433)


Epoch 8/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 8/10 Training: loss=0.6886


Epoch 8/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6641
Validation F1 Score: 0.6466
              precision    recall  f1-score   support

 Relevance 0     0.5604    0.4150    0.4769      4894
 Relevance 1     0.7116    0.8314    0.7669     12592
 Relevance 2     0.4697    0.3111    0.3743      2514

    accuracy                         0.6641     20000
   macro avg     0.5806    0.5192    0.5393     20000
weighted avg     0.6442    0.6641    0.6466     20000

save best model to best_two_tower_gru.pth (Val F1: 0.6466)


Epoch 9/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 9/10 Training: loss=0.6803


Epoch 9/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6677
Validation F1 Score: 0.6435
              precision    recall  f1-score   support

 Relevance 0     0.5911    0.3625    0.4494      4894
 Relevance 1     0.7029    0.8586    0.7730     12592
 Relevance 2     0.4753    0.3059    0.3722      2514

    accuracy                         0.6677     20000
   macro avg     0.5898    0.5090    0.5315     20000
weighted avg     0.6470    0.6677    0.6435     20000



Epoch 10/10 Training:   0%|          | 0/2813 [00:00<?, ?it/s]

Epoch 10/10 Training: loss=0.6704


Epoch 10/10 Evaluating:   0%|          | 0/313 [00:00<?, ?it/s]

Validation Accuracy: 0.6658
Validation F1 Score: 0.6419
              precision    recall  f1-score   support

 Relevance 0     0.5866    0.3792    0.4607      4894
 Relevance 1     0.7031    0.8544    0.7714     12592
 Relevance 2     0.4573    0.2788    0.3464      2514

    accuracy                         0.6658     20000
   macro avg     0.5823    0.5042    0.5262     20000
weighted avg     0.6437    0.6658    0.6419     20000

--- Training Finished (TWO) in 661.87s ---
Best Validation F1 Score (TWO): 0.6466

Loading best two-tower GRU model for final evaluation...

--- Starting Evaluation on Test Set (TWO) ---


Evaluating Test Set:   0%|          | 0/79 [00:00<?, ?it/s]

Test Set Accuracy: 0.6708
Test Set F1 Score: 0.6533
              precision    recall  f1-score   support

 Relevance 0     0.5727    0.4202    0.4847      1209
 Relevance 1     0.7158    0.8370    0.7716      3159
 Relevance 2     0.4821    0.3196    0.3844       632

    accuracy                         0.6708      5000
   macro avg     0.5902    0.5256    0.5469      5000
weighted avg     0.6516    0.6708    0.6533      5000

--- Evaluation Finished (TWO) ---
