# Encoder-Only Transformer for Classification

This notebook implements a bidirectional encoder-only transformer (similar to BERT) for text classification tasks. Unlike decoder-only transformers (like GPT), encoder-only models use bidirectional attention, allowing each token to attend to all other tokens in the sequence.

## Key Differences from Decoder-Only Transformers

1. **Bidirectional Attention**: No causal masking - tokens can see both past and future tokens
2. **Classification Task**: Trained for classification rather than next-token prediction
3. **Pooling Strategy**: Uses CLS token or mean pooling to create sequence-level representations
4. **No Autoregressive Generation**: Processes entire sequences at once


In [None]:
import csv
import json
import math
import os
import pickle
import random
import re
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from sklearn.metrics import confusion_matrix, classification_report
from tokenizers import Tokenizer
from tokenizers.models import WordPiece
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import WordPieceTrainer
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm


## Dataset Loading

We'll reuse the ArXiv dataset loading functions from the decoder-only notebook. For classification, we categorize abstracts based on **topic keywords** into different research areas:

- **Class 0: Neural Networks & Deep Learning** - Papers about neural networks, CNNs, deep learning architectures
- **Class 1: Natural Language Processing** - Papers about NLP, language models, text processing
- **Class 2: Computer Vision** - Papers about image processing, object detection, visual recognition
- **Class 3: Reinforcement Learning** - Papers about RL, agents, policies, rewards


In [None]:
# DATASET LOADING

def load_arxiv_huggingface(num_samples=10000):
    """Load ML arXiv abstracts from Hugging Face, falling back to synthetic text.

    Downloads the ``train`` split of ``CShorten/ML-ArXiv-Papers`` and keeps up
    to ``num_samples`` abstracts whose stripped length is strictly between 100
    and 5000 characters. Runs of whitespace are collapsed to single spaces.

    Any exception while loading or iterating the dataset falls back to
    ``generate_synthetic_abstracts(num_samples)``. So does the case where no
    abstract passes the filter. Previously that case surfaced as a misleading
    "division by zero" error.

    Args:
        num_samples: Maximum number of abstracts to return.

    Returns:
        A list of abstract strings.
    """
    print("=" * 60)
    print("LOADING ARXIV DATASET FROM HUGGING FACE")
    print("=" * 60)
    print("Dataset: CShorten/ML-ArXiv-Papers")
    print(f"Target samples: {num_samples}\n")

    try:
        print("Downloading dataset...")
        dataset = load_dataset("CShorten/ML-ArXiv-Papers", split="train")
        print(f"[OK] Dataset loaded: {len(dataset)} papers available")

        abstracts = []
        for paper in dataset:
            if len(abstracts) >= num_samples:
                break
            # Some rows may carry a null abstract; treat it as empty so it is filtered out.
            abstract = (paper['abstract'] or '').strip()
            if 100 < len(abstract) < 5000:
                abstracts.append(' '.join(abstract.split()))

        if not abstracts:
            print("[WARNING] No abstracts passed the length filter")
            print("\nFalling back to synthetic generation...")
            return generate_synthetic_abstracts(num_samples)

        print(f"[OK] Extracted {len(abstracts)} quality abstracts")
        print(f"[OK] Average length: {sum(len(a.split()) for a in abstracts) / len(abstracts):.1f} words")
        print("=" * 60)
        return abstracts

    except Exception as e:
        # Deliberate best-effort boundary: network/dataset problems must not stop the notebook.
        print(f"[ERROR] Loading dataset: {e}")
        print("\nFalling back to synthetic generation...")
        return generate_synthetic_abstracts(num_samples)


def generate_synthetic_abstracts(num_samples=1000):
    """Produce ``num_samples`` template-based pseudo-abstracts as a fallback corpus.

    For every abstract, one of three sentence templates is drawn at random. Each
    placeholder slot is then filled with a random phrase from a fixed phrase
    bank. The module-level ``random`` generator is used, so seeding it makes the
    output reproducible.

    Args:
        num_samples: How many abstracts to generate.

    Returns:
        A list of ``num_samples`` strings.
    """
    print("\n[WARNING] Generating synthetic abstracts as fallback")

    sentence_templates = (
        "We present a novel approach to {topic} using {method}. "
        "Our experiments demonstrate that {result}. "
        "The proposed framework {contribution} and achieves {performance}. "
        "We analyze the {analysis} and provide {insights}. "
        "Experimental results on {datasets} show that our method {comparison}.",
        "This paper introduces a new method for {topic} based on {method}. "
        "The key innovation is {innovation}. "
        "We evaluate our approach on {datasets} and show that {result}. "
        "Compared to existing methods, our approach {comparison}. "
        "The main contributions include {contribution}.",
        "In this work, we propose {method} for {topic}. "
        "Our method {advantage} while {constraint}. "
        "We demonstrate {result} through extensive experiments on {datasets}. "
        "The results show that {performance}. "
        "We also provide {analysis} of {insights}.",
    )

    # Slot order matters: it fixes the sequence of random draws per abstract.
    phrase_bank = {
        'topic': ("machine learning", "deep neural networks", "natural language processing",
                  "computer vision", "reinforcement learning", "optimization algorithms"),
        'method': ("transformer architectures", "self-supervised learning", "contrastive learning",
                   "adversarial training", "multi-task learning", "few-shot learning"),
        'result': ("significant improvements", "state-of-the-art performance", "competitive results"),
        'contribution': ("combines multiple techniques", "introduces a novel loss function",
                         "provides theoretical guarantees"),
        'performance': ("10% improvement in accuracy", "better sample efficiency", "superior generalization"),
        'datasets': ("standard benchmarks", "multiple datasets", "real-world scenarios"),
        'comparison': ("outperforms previous methods", "achieves comparable results", "demonstrates robustness"),
        'analysis': ("model behavior", "learned representations", "convergence properties"),
        'insights': ("theoretical justification", "practical guidelines", "design principles"),
        'innovation': ("a new architecture design", "an improved training procedure", "a novel technique"),
        'advantage': ("improves accuracy", "reduces computational cost", "enhances interpretability"),
        'constraint': ("maintaining efficiency", "preserving simplicity", "ensuring stability"),
    }

    def _fill(template):
        # Called after the template has been drawn, so draw order is template first, then slots.
        slot_values = {slot: random.choice(options) for slot, options in phrase_bank.items()}
        return template.format(**slot_values)

    generated = [_fill(random.choice(sentence_templates)) for _ in range(num_samples)]

    print(f"[OK] Generated {len(generated)} synthetic abstracts")
    return generated


## Dataset Class for Classification

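The dataset class is adapted for topic-based classification. We create labels by detecting **topic keywords** in each abstract and counting how many distinct keywords of each research area appear. The abstract is assigned to the area with the highest count. Ties go to the lower class ID, and abstracts that match no keyword at all default to Class 0, so Class 0 also acts as a catch-all: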

| Class | Topic | Example Keywords |
|-------|-------|------------------|
| 0 | Neural Networks & Deep Learning | neural network, deep learning, CNN, RNN, LSTM |
| 1 | Natural Language Processing | NLP, language model, text, word embedding, BERT |
| 2 | Computer Vision | image, vision, object detection, segmentation |
| 3 | Reinforcement Learning | reinforcement learning, agent, policy, reward |

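Each sequence is wrapped as `[CLS] tokens … [SEP]` and then padded with `[PAD]` to a fixed length. With CLS pooling, the model's final hidden state at the `[CLS]` position is used as the sequence representation for classification.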


In [None]:
# DATASET CLASS FOR CLASSIFICATION (TOPIC-BASED)

# Topic taxonomy used to derive weak labels from abstract text.
# Maps class_id -> {'name': display name, 'keywords': lower-case phrases}.
# classify_by_topic lower-cases the abstract before matching these phrases,
# so keywords must be written in lower case.
# Insertion order (0..3) matters: TOPIC_NAMES is built by class id, and
# classify_by_topic breaks score ties in favour of the earlier entry.
# NOTE(review): some phrases overlap across classes (e.g. 'transformer' under
# NLP, 'convolution' under CV vs 'convolutional neural' under NN), and very
# short keywords ('rl', 'ner', 'gan', 'gpt', 'mlp') are only reliable if the
# matcher respects word boundaries.
TOPIC_CATEGORIES = {
    0: {
        'name': 'Neural Networks & Deep Learning',
        'keywords': [
            'neural network', 'deep learning', 'cnn', 'convolutional neural',
            'rnn', 'recurrent neural', 'lstm', 'gru', 'feedforward',
            'backpropagation', 'gradient descent', 'activation function',
            'batch normalization', 'dropout', 'deep neural', 'mlp',
            'perceptron', 'autoencoder', 'variational autoencoder', 'vae',
            'generative adversarial', 'gan', 'discriminator', 'generator'
        ]
    },
    1: {
        'name': 'Natural Language Processing',
        'keywords': [
            'natural language', 'nlp', 'language model', 'text classification',
            'sentiment analysis', 'named entity', 'ner', 'machine translation',
            'word embedding', 'word2vec', 'glove', 'bert', 'gpt', 'transformer',
            'attention mechanism', 'sequence to sequence', 'seq2seq',
            'text generation', 'question answering', 'summarization',
            'tokenization', 'parsing', 'syntax', 'semantic', 'corpus',
            'vocabulary', 'embedding', 'language understanding'
        ]
    },
    2: {
        'name': 'Computer Vision',
        'keywords': [
            'computer vision', 'image classification', 'object detection',
            'image segmentation', 'semantic segmentation', 'instance segmentation',
            'face recognition', 'facial', 'pose estimation', 'image recognition',
            'visual', 'pixel', 'convolution', 'resnet', 'vgg', 'inception',
            'yolo', 'faster rcnn', 'mask rcnn', 'feature extraction',
            'image processing', 'video', 'optical flow', 'depth estimation',
            '3d reconstruction', 'scene understanding', 'visual question'
        ]
    },
    3: {
        'name': 'Reinforcement Learning',
        'keywords': [
            'reinforcement learning', 'rl', 'q-learning', 'deep q',
            'dqn', 'policy gradient', 'actor critic', 'a2c', 'a3c', 'ppo',
            'proximal policy', 'reward', 'agent', 'environment', 'mdp',
            'markov decision', 'exploration', 'exploitation', 'epsilon greedy',
            'value function', 'bellman', 'temporal difference', 'td learning',
            'monte carlo', 'sarsa', 'multi-agent', 'game playing', 'control'
        ]
    }
}

# Display names indexed by class id (assumes ids are the contiguous range 0..N-1).
TOPIC_NAMES = [TOPIC_CATEGORIES[i]['name'] for i in range(len(TOPIC_CATEGORIES))]


def _keyword_matches(keyword, text):
    """Return True if ``keyword`` occurs in ``text`` starting at a word boundary."""
    # Cheap substring test first; the regex only runs when the keyword is present at all.
    return keyword in text and re.search(r'\b' + re.escape(keyword), text) is not None


def classify_by_topic(abstract, categories=None):
    """Return the topic class id whose keywords best match ``abstract``.

    Each category scores one point per distinct keyword found in the
    lower-cased abstract. A keyword must begin at a word boundary. Short
    keywords such as 'rl', 'ner' or 'gan' therefore no longer fire inside
    'world', 'general' or 'organ'. Plural and inflected forms such as
    'agents' or 'neural networks' still match.

    Args:
        abstract: Raw abstract text.
        categories: Mapping ``class_id -> {'keywords': [...], ...}``.
            Defaults to ``TOPIC_CATEGORIES``.

    Returns:
        The class id with the highest score. Ties go to the category that
        comes first in ``categories``. An abstract with no matches returns 0.
    """
    if categories is None:
        categories = TOPIC_CATEGORIES
    text = abstract.lower()
    scores = {
        class_id: sum(
            1 for keyword in category['keywords']
            if _keyword_matches(keyword.lower(), text)
        )
        for class_id, category in categories.items()
    }
    best_class = max(scores, key=scores.get)  # first maximum in insertion order
    if scores[best_class] == 0:
        return 0
    return best_class


class ArXivClassificationDataset(Dataset):
    """Dataset of arXiv abstracts paired with keyword-derived topic labels.

    Each item is a fixed-length sequence ``[CLS] tokens... [SEP] [PAD]...`` of
    length ``max_len`` (longer texts are truncated). It comes with an
    attention mask that is 1 for every non-padding position, plus the integer
    topic label.
    """

    def __init__(self, abstracts, tokenizer, max_len=256, num_classes=4):
        """Tokenizer-backed dataset; labels are computed eagerly from ``abstracts``.

        Args:
            abstracts: List of abstract strings.
            tokenizer: Object exposing ``token_to_id(str)`` and ``encode(str).ids``.
            max_len: Total sequence length, including [CLS] and [SEP].
            num_classes: Number of topic classes (stored for callers).

        Raises:
            ValueError: If the tokenizer has no [CLS]/[BOS], [SEP]/[EOS] or [PAD] token.
        """
        self.abstracts = abstracts
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.num_classes = num_classes
        # Prefer BERT-style markers; fall back to BOS/EOS for tokenizers that lack them.
        self.cls_id = self._lookup_token_id(tokenizer, "[CLS]", "[BOS]")
        self.pad_id = tokenizer.token_to_id("[PAD]")
        self.sep_id = self._lookup_token_id(tokenizer, "[SEP]", "[EOS]")
        missing = [
            name for name, token_id in (
                ("[CLS]/[BOS]", self.cls_id),
                ("[SEP]/[EOS]", self.sep_id),
                ("[PAD]", self.pad_id),
            )
            if token_id is None
        ]
        if missing:
            raise ValueError(f"Tokenizer vocabulary is missing special token(s): {', '.join(missing)}")
        self.labels = self._create_topic_labels(abstracts)
        self._print_class_distribution()

    @staticmethod
    def _lookup_token_id(tokenizer, primary, fallback):
        """Return the id of ``primary``, or of ``fallback`` when ``primary`` is not in the vocab."""
        token_id = tokenizer.token_to_id(primary)
        return token_id if token_id is not None else tokenizer.token_to_id(fallback)

    def _create_topic_labels(self, abstracts):
        """Create classification labels based on topic keywords."""
        return [classify_by_topic(abstract) for abstract in abstracts]

    def _print_class_distribution(self):
        """Print how many abstracts fall into each topic class."""
        from collections import Counter
        distribution = Counter(self.labels)
        print("\nTopic Distribution:")
        for class_id in sorted(distribution.keys()):
            count = distribution[class_id]
            percentage = count / len(self.labels) * 100
            print(f"  - {TOPIC_NAMES[class_id]}: {count} ({percentage:.1f}%)")

    def __len__(self):
        return len(self.abstracts)

    def __getitem__(self, idx):
        """Return ``{'input_ids', 'attention_mask', 'label'}`` tensors for one abstract."""
        body = list(self.tokenizer.encode(self.abstracts[idx]).ids)
        body = body[:max(self.max_len - 2, 0)]  # leave room for [CLS] and [SEP]
        tokens = [self.cls_id] + body + [self.sep_id]
        num_real = len(tokens)
        pad_len = self.max_len - num_real
        input_ids = tokens + [self.pad_id] * pad_len
        # Derive the mask from the real length rather than by comparing ids to pad_id,
        # so a text token that happens to share the pad id is never masked out.
        attention_mask = [1] * num_real + [0] * pad_len
        return {
            'input_ids': torch.tensor(input_ids, dtype=torch.long),
            'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
            'label': torch.tensor(self.labels[idx], dtype=torch.long)
        }


def build_tokenizer(abstracts, vocab_size=8000):
    """Train a fresh WordPiece tokenizer on ``abstracts``.

    Text is pre-split by the library's ``Whitespace`` pre-tokenizer before
    WordPiece training. The trainer registers the special tokens [PAD], [CLS],
    [SEP], [UNK], [BOS] and [EOS]. [UNK] is the model's unknown-token fallback.

    Args:
        abstracts: Iterable of training strings.
        vocab_size: Target vocabulary size for the trainer.

    Returns:
        The trained ``tokenizers.Tokenizer``.
    """
    reserved_tokens = ["[PAD]", "[CLS]", "[SEP]", "[UNK]", "[BOS]", "[EOS]"]
    wordpiece_trainer = WordPieceTrainer(vocab_size=vocab_size, special_tokens=reserved_tokens)

    wordpiece_tokenizer = Tokenizer(WordPiece(unk_token="[UNK]"))
    wordpiece_tokenizer.pre_tokenizer = Whitespace()
    wordpiece_tokenizer.train_from_iterator(abstracts, trainer=wordpiece_trainer)
    return wordpiece_tokenizer


### Positional Encoding

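The positional encoding is identical to the one in the decoder-only transformer. It uses fixed (non-learned) sinusoidal functions to encode position information:

$$PE_{(pos,\,2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)$$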


In [None]:
# MODEL COMPONENTS

class SinusoidalPositionalEncoding(nn.Module):
    """Add fixed sinusoidal position embeddings to the input.

    Even feature indices receive ``sin(pos / 10000^(2i/d_model))`` and odd
    indices the matching cosine. The table is registered as a buffer of shape
    (1, max_len, d_model). It is therefore not trained, but it does follow the
    module across devices and is included in the state dict.
    """

    def __init__(self, d_model, max_len=256):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With odd d_model there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x + pe`` for the first ``x.size(1)`` positions.

        Args:
            x: Tensor of shape (batch, seq_len, d_model).

        Raises:
            ValueError: If ``seq_len`` exceeds the ``max_len`` the table was built for.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"Sequence length {seq_len} exceeds positional encoding max_len {self.pe.size(1)}"
            )
        return x + self.pe[:, :seq_len, :]


### Multi-Head Self-Attention (Bidirectional)

**Key Difference from Decoder-Only**: This attention mechanism is **bidirectional** - tokens can attend to all other tokens in the sequence, not just previous ones. There is no causal masking.

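The attention mechanism computes:
- Query ($Q$), Key ($K$), and Value ($V$) matrices via learned linear projections of the input, split across heads
- Attention scores $= \dfrac{QK^\top}{\sqrt{d_k}}$
- Apply the padding mask (but NO causal mask), so no query can attend to padding positions
- Softmax over the keys to get attention weights
- Weighted sum of the values, with heads concatenated and passed through an output projection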


In [None]:
class MultiHeadAttention(nn.Module):
    """Bidirectional multi-head self-attention (no causal masking).

    Every query position may attend to every key position that is not masked
    out as padding.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        # Raise instead of assert so the check survives ``python -O``.
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def _split_heads(self, t, batch_size, seq_len):
        """Reshape (batch, seq_len, d_model) into (batch, num_heads, seq_len, d_k)."""
        return t.view(batch_size, seq_len, self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, x, mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape (batch, seq_len, d_model).
            mask: Optional padding mask of shape (batch, seq_len). Key positions
                where it equals 0 are excluded from attention.

        Returns:
            Tensor of shape (batch, seq_len, d_model).
        """
        batch_size, seq_len, _ = x.shape
        Q = self._split_heads(self.W_q(x), batch_size, seq_len)
        K = self._split_heads(self.W_k(x), batch_size, seq_len)
        V = self._split_heads(self.W_v(x), batch_size, seq_len)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            # (batch, seq_len) -> (batch, 1, 1, seq_len): broadcast over heads and queries.
            key_mask = mask.unsqueeze(1).unsqueeze(2)
            scores = scores.masked_fill(key_mask == 0, float('-inf'))
        attn_weights = F.softmax(scores, dim=-1)
        # A query whose keys are all masked gives softmax over all -inf = NaN; zero those rows.
        attn_weights = torch.nan_to_num(attn_weights)
        attn_output = torch.matmul(attn_weights, V)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        return self.W_o(attn_output)


### Feed-Forward Network

The feed-forward network is identical to the decoder-only transformer - a two-layer MLP with ReLU activation.


In [None]:
class FeedForward(nn.Module):
    """Position-wise feed-forward sublayer: Linear(d_model -> d_ff), ReLU, Linear(d_ff -> d_model).

    Both Linear layers act on the last dimension, so the same transform is
    applied independently at every sequence position.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()
        # Attribute names form the state_dict keys; creation order fixes init order.
        self.linear1 = nn.Linear(d_model, d_ff)   # expand to the inner width
        self.linear2 = nn.Linear(d_ff, d_model)   # project back to the model width

    def forward(self, x):
        hidden = F.relu(self.linear1(x))
        return self.linear2(hidden)


### Encoder Block

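The encoder block combines:
1. **Bidirectional Multi-Head Attention** (no causal masking)
2. **Residual connection followed by Layer Normalization** (post-norm: `norm1(x + dropout(attention(x)))`)
3. **Feed-Forward Network**
4. **Residual connection followed by Layer Normalization** (post-norm: `norm2(h + dropout(ffn(h)))`)
5. **Dropout** applied to each sublayer's output before it is added to the residual, for regularization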

This is similar to the decoder block, but uses bidirectional attention instead of masked attention.


In [None]:
class EncoderBlock(nn.Module):
    """One post-norm Transformer encoder layer with bidirectional self-attention.

    Computes ``h = norm1(x + dropout(attention(x, mask)))`` and then
    ``norm2(h + dropout(ffn(h)))``. A single dropout module is shared by both
    residual branches.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        # Submodule names are state_dict keys, and creation order fixes parameter init order.
        self.attention = MultiHeadAttention(d_model, num_heads)
        self.norm1 = nn.LayerNorm(d_model)
        self.ffn = FeedForward(d_model, d_ff)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Run the attention and feed-forward sublayers; ``mask`` is passed to the attention."""
        hidden = self.norm1(x + self.dropout(self.attention(x, mask)))
        return self.norm2(hidden + self.dropout(self.ffn(hidden)))


### Encoder-Only Transformer

The complete encoder-only transformer model includes:

1. **Token Embeddings**: Maps token IDs to dense vectors
2. **Positional Encoding**: Adds position information
3. **Stack of Encoder Blocks**: Multiple layers of bidirectional attention
4. **Pooling Layer**: Extracts sequence-level representation
   - **CLS token pooling**: Uses the first token (CLS) representation
   - **Mean pooling**: Averages all token representations
5. **Classification Head**: Linear layer mapping to number of classes

**Key Architectural Choices**:
- All hyperparameters are configurable (d_model, num_layers, num_heads, etc.)
- Supports both CLS token and mean pooling strategies
- Designed for classification tasks


In [None]:
class EncoderOnlyTransformer(nn.Module):
    """Encoder-only Transformer for sequence classification.

    Pipeline: token embedding, sinusoidal positional encoding, dropout, then
    ``num_layers`` bidirectional encoder blocks, then pooling, then a
    dropout + linear classifier.
    """

    def __init__(self, vocab_size, d_model=256, num_layers=4, num_heads=8,
                 d_ff=1024, max_seq_len=256, num_classes=3, dropout=0.1, pooling_type='cls'):
        """Initialize encoder transformer with configurable architecture.

        Args:
            vocab_size: Rows in the token embedding table.
            d_model: Hidden size shared by all layers.
            num_layers: Number of stacked encoder blocks.
            num_heads: Attention heads per block.
            d_ff: Inner width of each feed-forward sublayer.
            max_seq_len: Longest sequence the positional encoding supports.
            num_classes: Number of output logits.
                NOTE(review): the default of 3 differs from the 4 topic classes
                used by the dataset; callers should pass it explicitly.
            dropout: Dropout probability used throughout.
            pooling_type: ``'cls'`` pools position 0. Any other value uses
                masked mean pooling.
        """
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.num_classes = num_classes
        self.pooling_type = pooling_type
        # Creation order below fixes parameter init order and state_dict keys; keep it.
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = SinusoidalPositionalEncoding(d_model, max_seq_len)
        self.encoder_blocks = nn.ModuleList([
            EncoderBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(d_model, num_classes)
        )
        self.dropout = nn.Dropout(dropout)
        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform init for every parameter with more than one dimension (biases and LayerNorm untouched)."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def _pool(self, x, attention_mask):
        """Reduce (batch, seq_len, d_model) hidden states to (batch, d_model)."""
        if self.pooling_type == 'cls':
            return x[:, 0, :]
        if attention_mask is None:
            return x.mean(dim=1)
        mask = attention_mask.unsqueeze(-1).float()
        summed = (x * mask).sum(dim=1)
        # Clamp so a row with no real tokens yields zeros instead of NaN (0 / 0).
        lengths = mask.sum(dim=1).clamp(min=1.0)
        return summed / lengths

    def forward(self, input_ids, attention_mask=None):
        """Return classification logits of shape (batch, num_classes).

        Args:
            input_ids: Token ids of shape (batch, seq_len).
            attention_mask: Optional (batch, seq_len) mask, 1 for real tokens and
                0 for padding. It is used by attention and by mean pooling.
        """
        x = self.dropout(self.pos_encoding(self.token_embedding(input_ids)))
        for encoder_block in self.encoder_blocks:
            x = encoder_block(x, attention_mask)
        return self.classifier(self._pool(x, attention_mask))


In [None]:
# This cell is kept for reference but config is defined near main() for easy editing


## Training Functions

Training functions for classification tasks. We use cross-entropy loss (`F.cross_entropy`) for multi-class classification, clip the gradient norm during training (to 1.0 by default), and track both loss and accuracy metrics.


In [None]:
# TRAINING FUNCTIONS

def train_epoch(model, dataloader, optimizer, device, *, max_grad_norm=1.0):
    """Train ``model`` for one pass over ``dataloader``.

    Each batch is a dict with ``'input_ids'``, ``'attention_mask'`` and
    ``'label'`` tensors. Gradients are clipped to ``max_grad_norm`` before each
    optimizer step.

    Args:
        model: Module called as ``model(input_ids, attention_mask)`` returning logits.
        dataloader: Iterable of batches.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device the batch tensors are moved to.
        max_grad_norm: Gradient-norm clipping threshold.

    Returns:
        ``(avg_loss, accuracy, batch_losses)``. ``avg_loss`` is the per-sample
        mean loss, so a smaller final batch is weighted correctly and matches
        how accuracy is computed. ``batch_losses`` holds each batch's mean loss.

    Raises:
        ValueError: If the dataloader yields no samples.
    """
    model.train()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    batch_losses = []
    progress_bar = tqdm(dataloader, desc="Training", leave=False)

    for batch in progress_bar:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)
        logits = model(input_ids, attention_mask)
        loss = F.cross_entropy(logits, labels)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()

        batch_size = labels.size(0)
        batch_loss = loss.item()  # single host sync per batch
        correct = (torch.argmax(logits, dim=-1) == labels).sum().item()
        total_loss += batch_loss * batch_size
        total_correct += correct
        total_samples += batch_size
        batch_losses.append(batch_loss)
        progress_bar.set_postfix({'loss': f'{batch_loss:.4f}', 'acc': f'{correct / batch_size:.4f}'})

    if total_samples == 0:
        raise ValueError("train_epoch: dataloader yielded no samples")
    return total_loss / total_samples, total_correct / total_samples, batch_losses


def evaluate_model(model, dataloader, device):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Each batch is a dict with ``'input_ids'``, ``'attention_mask'`` and
    ``'label'`` tensors.

    Returns:
        ``(avg_loss, accuracy, predictions, labels)``. ``avg_loss`` is the
        per-sample mean cross-entropy, so batches of different sizes are
        weighted correctly. ``predictions`` and ``labels`` are flat lists of
        numpy scalars in dataloader order.

    Raises:
        ValueError: If the dataloader yields no samples.
    """
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)
            logits = model(input_ids, attention_mask)
            # Sum (not mean) so each sample counts once regardless of batch size.
            total_loss += F.cross_entropy(logits, labels, reduction='sum').item()
            predictions = torch.argmax(logits, dim=-1)
            total_correct += (predictions == labels).sum().item()
            total_samples += labels.size(0)
            all_predictions.extend(predictions.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    if total_samples == 0:
        raise ValueError("evaluate_model: dataloader yielded no samples")
    return total_loss / total_samples, total_correct / total_samples, all_predictions, all_labels


## Visualization Functions

Visualization functions for training curves, accuracy metrics, and confusion matrices.


In [None]:
# VISUALIZATION FUNCTIONS

def _print_training_statistics(train_losses, train_accuracies, val_accuracies=None):
    """Print initial/final/best loss and accuracy figures for a training run."""
    losses = list(train_losses)
    accuracies = list(train_accuracies)
    best_loss = min(losses)

    print("\n" + "="*60)
    print("TRAINING STATISTICS")
    print("="*60)
    print(f"Initial Loss: {losses[0]:.4f}")
    print(f"Final Loss: {losses[-1]:.4f}")
    if losses[0] != 0:
        print(f"Loss Reduction: {((losses[0] - losses[-1]) / losses[0] * 100):.2f}%")
    else:
        print("Loss Reduction: n/a (initial loss is 0)")
    print(f"Best Loss: {best_loss:.4f} (Epoch {losses.index(best_loss) + 1})")
    print(f"\nInitial Accuracy: {accuracies[0]:.4f}")
    print(f"Final Accuracy: {accuracies[-1]:.4f}")
    print(f"Accuracy Improvement: {((accuracies[-1] - accuracies[0]) * 100):.2f}%")
    # Explicit None/len check also works for numpy arrays (whose truth value is ambiguous).
    if val_accuracies is not None and len(val_accuracies) > 0:
        val_list = list(val_accuracies)
        best_val = max(val_list)
        print(f"Best Validation Accuracy: {best_val:.4f} (Epoch {val_list.index(best_val) + 1})")
    print("="*60)


def plot_training_curves(train_losses, train_accuracies, val_losses=None,
                         val_accuracies=None, save_dir=None):
    """Plot per-epoch loss/accuracy curves, optionally save them, and print a summary.

    Draws a 2x2 grid with four panels: linear-scale loss, accuracy, log-scale
    loss, and an accuracy comparison. The comparison panel shades the
    train/validation gap when validation accuracies are given.

    Args:
        train_losses: Sequence of training loss per epoch.
        train_accuracies: Sequence of training accuracy per epoch, in [0, 1].
        val_losses: Optional sequence of validation loss per epoch.
        val_accuracies: Optional sequence of validation accuracy per epoch.
        save_dir: If given, the figure is written to
            ``<save_dir>/training_curves.png``. The directory must already exist.
    """
    if len(train_losses) == 0:
        print("[WARNING] No training history to plot")
        return

    has_val_loss = val_losses is not None
    has_val_acc = val_accuracies is not None

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle('Encoder-Only Transformer Training Curves', fontsize=16, fontweight='bold')
    epochs = range(1, len(train_losses) + 1)

    # Loss plot
    ax1 = axes[0, 0]
    ax1.plot(epochs, train_losses, 'b-o', linewidth=2, markersize=6, label='Training Loss')
    if has_val_loss:
        ax1.plot(epochs, val_losses, 'r-s', linewidth=2, markersize=6, label='Validation Loss')
    ax1.set_xlabel('Epoch', fontsize=12)
    ax1.set_ylabel('Loss', fontsize=12)
    ax1.set_title('Training Loss per Epoch', fontsize=13, fontweight='bold')
    ax1.grid(True, alpha=0.3)
    ax1.legend()

    # Accuracy plot
    ax2 = axes[0, 1]
    ax2.plot(epochs, train_accuracies, 'g-o', linewidth=2, markersize=6, label='Training Accuracy')
    if has_val_acc:
        ax2.plot(epochs, val_accuracies, 'm-s', linewidth=2, markersize=6, label='Validation Accuracy')
    ax2.set_xlabel('Epoch', fontsize=12)
    ax2.set_ylabel('Accuracy', fontsize=12)
    ax2.set_title('Training Accuracy per Epoch', fontsize=13, fontweight='bold')
    ax2.grid(True, alpha=0.3)
    ax2.legend()
    ax2.set_ylim([0, 1])

    # Loss log scale
    ax3 = axes[1, 0]
    ax3.semilogy(epochs, train_losses, 'b-o', linewidth=2, markersize=6, label='Training Loss')
    if has_val_loss:
        ax3.semilogy(epochs, val_losses, 'r-s', linewidth=2, markersize=6, label='Validation Loss')
    ax3.set_xlabel('Epoch', fontsize=12)
    ax3.set_ylabel('Loss (log scale)', fontsize=12)
    ax3.set_title('Loss Convergence (Log Scale)', fontsize=13, fontweight='bold')
    ax3.grid(True, alpha=0.3, which='both')
    ax3.legend()

    # Accuracy comparison
    ax4 = axes[1, 1]
    if has_val_acc:
        ax4.plot(epochs, train_accuracies, 'g-o', linewidth=2, markersize=6, label='Training Accuracy', alpha=0.7)
        ax4.plot(epochs, val_accuracies, 'm-s', linewidth=2, markersize=6, label='Validation Accuracy')
        ax4.fill_between(epochs, train_accuracies, val_accuracies, alpha=0.2)
    else:
        ax4.plot(epochs, train_accuracies, 'g-o', linewidth=2, markersize=6, label='Training Accuracy')
    ax4.set_xlabel('Epoch', fontsize=12)
    ax4.set_ylabel('Accuracy', fontsize=12)
    ax4.set_title('Accuracy Over Time', fontsize=13, fontweight='bold')
    ax4.grid(True, alpha=0.3)
    ax4.legend()
    ax4.set_ylim([0, 1])

    plt.tight_layout()

    if save_dir:
        curves_path = os.path.join(save_dir, "training_curves.png")
        plt.savefig(curves_path, dpi=300, bbox_inches='tight')
        print(f"[OK] Training curves saved to {curves_path}")

    plt.show()

    _print_training_statistics(train_losses, train_accuracies, val_accuracies)


def plot_confusion_matrix(y_true, y_pred, num_classes, class_names=None, save_dir=None):
    """Plot a confusion-matrix heatmap, optionally save it, and print a classification report.

    Labels are assumed to be integer class ids ``0..num_classes-1``. All of
    them are passed to sklearn explicitly. The matrix therefore always has
    shape (num_classes, num_classes), and the report accepts ``class_names``
    even when some class is absent from both ``y_true`` and ``y_pred``.
    Without explicit labels the matrix shrinks and ``classification_report``
    raises a ``target_names`` size mismatch.

    Args:
        y_true: Ground-truth class ids.
        y_pred: Predicted class ids.
        num_classes: Number of classes.
        class_names: Optional display names, one per class; defaults to "Class i".
        save_dir: If given, the figure is written to ``<save_dir>/confusion_matrix.png``.

    Returns:
        The (num_classes, num_classes) confusion matrix array.
    """
    labels = list(range(num_classes))
    names = class_names or [f'Class {i}' for i in labels]
    cm = confusion_matrix(y_true, y_pred, labels=labels)

    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=names,
                yticklabels=names)
    plt.title('Confusion Matrix', fontsize=14, fontweight='bold')
    plt.ylabel('True Label', fontsize=12)
    plt.xlabel('Predicted Label', fontsize=12)
    plt.tight_layout()

    if save_dir:
        cm_path = os.path.join(save_dir, "confusion_matrix.png")
        plt.savefig(cm_path, dpi=300, bbox_inches='tight')
        print(f"[OK] Confusion matrix saved to {cm_path}")

    plt.show()

    # Print classification report
    print("\n" + "="*60)
    print("CLASSIFICATION REPORT")
    print("="*60)
    # zero_division=0 reports the same 0.0 values as the default, without the warning.
    report = classification_report(y_true, y_pred, labels=labels,
                                   target_names=names, zero_division=0)
    print(report)
    print("="*60)

    return cm


def _summarize_metric_history(values, best):
    """Return ``[initial, final, best, 1-based epoch of best]`` for one metric history.

    ``best`` is ``min`` or ``max``. A None or empty history yields four NaNs,
    so the summary table keeps a fixed set of rows.
    """
    if values is None or len(values) == 0:
        return [float('nan')] * 4
    history = list(values)
    best_value = best(history)
    return [history[0], history[-1], best_value, history.index(best_value) + 1]


def save_training_metrics_csv(save_dir, train_losses, train_accuracies, val_losses, val_accuracies, config):
    """Save all training metrics to CSV files for report generation.

    Files written to ``save_dir``, which must already exist:
        - ``training_metrics_epochs.csv``: one row per epoch.
        - ``training_summary.csv``: ``metric``/``value`` pairs, i.e.
          initial/final/best/best-epoch for each series, plus loss reduction,
          accuracy improvement and epoch count.
        - ``config.csv``: ``parameter``/``value`` pairs from ``config``.

    ``val_losses`` and ``val_accuracies`` may be None. Their columns and
    summary rows are then NaN. ``loss_reduction_pct`` is NaN when the initial
    training loss is 0.

    Returns:
        Tuple ``(epochs_csv, summary_csv, config_csv)`` of the written paths.
    """
    num_epochs = len(train_losses)

    def _column(values):
        # Missing validation history -> NaN column, so the CSV schema stays stable.
        return list(values) if values is not None else [float('nan')] * num_epochs

    # Epoch-by-epoch metrics
    epochs_data = {
        'epoch': list(range(1, num_epochs + 1)),
        'train_loss': list(train_losses),
        'train_accuracy': list(train_accuracies),
        'val_loss': _column(val_losses),
        'val_accuracy': _column(val_accuracies),
    }
    df_epochs = pd.DataFrame(epochs_data)
    epochs_csv = os.path.join(save_dir, "training_metrics_epochs.csv")
    df_epochs.to_csv(epochs_csv, index=False)
    print(f"[OK] Epoch metrics saved to {epochs_csv}")

    # Summary statistics
    train_loss_stats = _summarize_metric_history(train_losses, min)
    train_acc_stats = _summarize_metric_history(train_accuracies, max)
    val_loss_stats = _summarize_metric_history(val_losses, min)
    val_acc_stats = _summarize_metric_history(val_accuracies, max)

    initial_loss, final_loss = train_loss_stats[0], train_loss_stats[1]
    # Relative reduction is undefined for a zero initial loss (NaN propagates naturally).
    loss_reduction_pct = ((initial_loss - final_loss) / initial_loss * 100
                          if initial_loss != 0 else float('nan'))
    accuracy_improvement_pct = (train_acc_stats[1] - train_acc_stats[0]) * 100

    summary_data = {
        'metric': [
            'initial_train_loss', 'final_train_loss', 'best_train_loss', 'best_train_loss_epoch',
            'initial_train_accuracy', 'final_train_accuracy', 'best_train_accuracy', 'best_train_accuracy_epoch',
            'initial_val_loss', 'final_val_loss', 'best_val_loss', 'best_val_loss_epoch',
            'initial_val_accuracy', 'final_val_accuracy', 'best_val_accuracy', 'best_val_accuracy_epoch',
            'loss_reduction_pct', 'accuracy_improvement_pct', 'total_epochs'
        ],
        'value': (
            train_loss_stats + train_acc_stats + val_loss_stats + val_acc_stats
            + [loss_reduction_pct, accuracy_improvement_pct, num_epochs]
        ),
    }
    df_summary = pd.DataFrame(summary_data)
    summary_csv = os.path.join(save_dir, "training_summary.csv")
    df_summary.to_csv(summary_csv, index=False)
    print(f"[OK] Summary statistics saved to {summary_csv}")

    # Config as CSV
    config_data = {'parameter': list(config.keys()), 'value': list(config.values())}
    df_config = pd.DataFrame(config_data)
    config_csv = os.path.join(save_dir, "config.csv")
    df_config.to_csv(config_csv, index=False)
    print(f"[OK] Config saved to {config_csv}")

    return epochs_csv, summary_csv, config_csv


def save_classification_report_csv(save_dir, y_true, y_pred, class_names):
    """Save the classification report and confusion matrix as CSV files in ``save_dir``.

    Labels are assumed to be integer class ids ``0..len(class_names)-1``.
    They are passed to sklearn explicitly, so a class absent from both
    ``y_true`` and ``y_pred`` still gets a row and column. Without explicit
    labels, ``classification_report`` would raise on the ``target_names``
    size mismatch, and the DataFrame index would not fit the smaller matrix.

    Returns:
        Tuple ``(report_csv, cm_csv)`` of the written paths.
    """
    labels = list(range(len(class_names)))
    # zero_division=0 reports the same 0.0 values as the default, without the warning.
    report_dict = classification_report(y_true, y_pred, labels=labels, target_names=class_names,
                                        output_dict=True, zero_division=0)
    df_report = pd.DataFrame(report_dict).transpose()
    report_csv = os.path.join(save_dir, "classification_report.csv")
    df_report.to_csv(report_csv)
    print(f"[OK] Classification report saved to {report_csv}")

    # Save confusion matrix as CSV
    cm = confusion_matrix(y_true, y_pred, labels=labels)
    df_cm = pd.DataFrame(cm, index=class_names, columns=class_names)
    cm_csv = os.path.join(save_dir, "confusion_matrix.csv")
    df_cm.to_csv(cm_csv)
    print(f"[OK] Confusion matrix CSV saved to {cm_csv}")

    return report_csv, cm_csv


## Model Saving Functions

Functions to save the trained model with a timestamped folder and metadata file containing training information.


In [None]:
# MODEL SAVING FUNCTIONS

def get_model_save_dir(base_dir="models"):
    """Return ``<base_dir>/encoder_<YYYYmmdd_HHMMSS>`` stamped with the current local time.

    Only the path is built; nothing is created on disk.
    """
    return os.path.join(base_dir, f"encoder_{datetime.now():%Y%m%d_%H%M%S}")


def save_model(model, tokenizer, config, save_dir, training_history=None,
               best_epoch=None, best_val_accuracy=None):
    """Save weights, tokenizer, config and a metadata summary into ``save_dir``.

    Files written: ``model.pt`` (the model's ``state_dict``), ``tokenizer.json``,
    ``config.json`` and ``metadata.json``.

    Args:
        model: Trained ``torch.nn.Module``; only its ``state_dict`` is saved.
        tokenizer: Tokenizer providing ``save(path)`` and ``get_vocab_size()``.
        config: Hyperparameter dict; must contain every key read below.
        save_dir: Target directory; created if missing.
        training_history: Optional dict of per-epoch lists under
            ``train_losses``, ``train_accuracies``, ``val_losses`` and
            ``val_accuracies``.
        best_epoch: Optional epoch number of the best validation accuracy
            (``main`` passes a 1-based number).
        best_val_accuracy: Optional best validation accuracy.

    Returns:
        ``save_dir``.
    """
    os.makedirs(save_dir, exist_ok=True)

    model_path = os.path.join(save_dir, "model.pt")
    torch.save(model.state_dict(), model_path)
    print(f"[OK] Model weights saved to {model_path}")

    tokenizer_path = os.path.join(save_dir, "tokenizer.json")
    tokenizer.save(tokenizer_path)
    print(f"[OK] Tokenizer saved to {tokenizer_path}")

    config_path = os.path.join(save_dir, "config.json")
    with open(config_path, 'w') as f:
        json.dump(config, f, indent=2)
    print(f"[OK] Config saved to {config_path}")

    params = list(model.parameters())
    num_parameters = sum(p.numel() for p in params)
    # element_size() is bytes per element, so the size stays correct for
    # non-float32 weights (fp16/bf16) instead of assuming 4 bytes.
    size_bytes = sum(p.numel() * p.element_size() for p in params)

    metadata = {
        "model_type": "EncoderOnlyTransformer",
        "task": "Topic Classification",
        "classes": TOPIC_NAMES,
        "num_classes": len(TOPIC_NAMES),
        "created_at": datetime.now().isoformat(),
        "architecture": {
            # Record the vocabulary the embedding was actually built with.
            # The trained tokenizer's size can differ from the requested
            # config['vocab_size'], and load_saved_model sizes the model
            # from the tokenizer, not from the config.
            "vocab_size": tokenizer.get_vocab_size(),
            "d_model": config['d_model'],
            "num_layers": config['num_layers'],
            "num_heads": config['num_heads'],
            "d_ff": config['d_ff'],
            "max_seq_len": config['max_seq_len'],
            "dropout": config['dropout'],
            "pooling_type": config['pooling_type']
        },
        "training": {
            "num_epochs": config['num_epochs'],
            "batch_size": config['batch_size'],
            "learning_rate": config['learning_rate'],
            "weight_decay": config['weight_decay'],
            "num_samples": config['num_samples'],
            "train_split": config['train_split']
        },
        "num_parameters": num_parameters,
        "model_size_mb": size_bytes / 1024 / 1024
    }

    if best_epoch is not None:
        metadata["best_epoch"] = best_epoch
    if best_val_accuracy is not None:
        metadata["best_val_accuracy"] = best_val_accuracy

    if training_history is not None:
        metadata["training_history"] = {
            "train_losses": training_history.get('train_losses', []),
            "train_accuracies": training_history.get('train_accuracies', []),
            "val_losses": training_history.get('val_losses', []),
            "val_accuracies": training_history.get('val_accuracies', [])
        }
        # Final values are only recorded for non-empty histories.
        if training_history.get('train_losses'):
            metadata["final_train_loss"] = training_history['train_losses'][-1]
        if training_history.get('train_accuracies'):
            metadata["final_train_accuracy"] = training_history['train_accuracies'][-1]
        if training_history.get('val_losses'):
            metadata["final_val_loss"] = training_history['val_losses'][-1]
        if training_history.get('val_accuracies'):
            metadata["final_val_accuracy"] = training_history['val_accuracies'][-1]

    metadata_path = os.path.join(save_dir, "metadata.json")
    with open(metadata_path, 'w') as f:
        json.dump(metadata, f, indent=2)
    print(f"[OK] Metadata saved to {metadata_path}")

    print(f"\n[OK] All files saved to: {save_dir}")
    return save_dir


def load_saved_model(save_dir, device=None):
    """Rebuild a model saved by ``save_model`` and return it in eval mode.

    Args:
        save_dir: Directory containing ``config.json``, ``tokenizer.json``,
            ``metadata.json`` and ``model.pt``.
        device: Device to load onto. When None, uses ``'cuda'`` if CUDA is
            available and ``'cpu'`` otherwise. This mirrors the device choice
            in ``main``, so a model trained on a CPU-only machine can be
            loaded there with the default arguments.

    Returns:
        Tuple ``(model, tokenizer, config, metadata)``.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    print(f"Loading model from {save_dir}...")

    config_path = os.path.join(save_dir, "config.json")
    with open(config_path, 'r') as f:
        config = json.load(f)
    print(f"[OK] Config loaded")

    tokenizer_path = os.path.join(save_dir, "tokenizer.json")
    tokenizer = Tokenizer.from_file(tokenizer_path)
    print(f"[OK] Tokenizer loaded")

    metadata_path = os.path.join(save_dir, "metadata.json")
    with open(metadata_path, 'r') as f:
        metadata = json.load(f)
    print(f"[OK] Metadata loaded")

    # The embedding is sized from the saved tokenizer, exactly as in main().
    model = EncoderOnlyTransformer(
        vocab_size=tokenizer.get_vocab_size(),
        d_model=config['d_model'],
        num_layers=config['num_layers'],
        num_heads=config['num_heads'],
        d_ff=config['d_ff'],
        max_seq_len=config['max_seq_len'],
        num_classes=config['num_classes'],
        dropout=config['dropout'],
        pooling_type=config['pooling_type']
    )

    model_path = os.path.join(save_dir, "model.pt")
    # map_location lets GPU-saved weights load onto a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model = model.to(device)
    model.eval()
    print(f"[OK] Model weights loaded")

    print(f"\n[OK] Model loaded successfully!")
    print(f"  - Created: {metadata.get('created_at', 'Unknown')}")
    print(f"  - Best Epoch: {metadata.get('best_epoch', 'Unknown')}")
    # Compare against None so a genuine accuracy of 0.0 is still printed.
    best_val = metadata.get('best_val_accuracy')
    if best_val is not None:
        print(f"  - Best Val Accuracy: {best_val:.4f}")
    else:
        print("  - Best Val Accuracy: Unknown")

    return model, tokenizer, config, metadata


In [None]:
# CONFIGURATION - Edit these values and re-run to train with different settings

def get_default_config():
    """Return a freshly built dict of default hyperparameters.

    Settings are assembled in three groups (model architecture, optimisation,
    dataset) and merged in that order. A new dict is created on every call,
    so callers may mutate the result without affecting later calls.
    """
    architecture = dict(
        d_model=256,          # embedding dimension
        num_layers=4,         # number of encoder blocks
        num_heads=8,          # number of attention heads
        d_ff=1024,            # feed-forward dimension
        max_seq_len=256,      # maximum sequence length
        vocab_size=8000,      # vocabulary size
        num_classes=4,        # number of topic classes (NN, NLP, CV, RL)
        dropout=0.1,          # dropout rate
        pooling_type='cls',   # 'cls' or 'mean' pooling
    )
    optimisation = dict(
        batch_size=32,        # batch size
        num_epochs=20,        # number of training epochs
        learning_rate=3e-4,   # learning rate
        weight_decay=0.01,    # weight decay for regularization
    )
    dataset = dict(
        num_samples=10000,    # number of samples to load
        train_split=0.9,      # train/validation split ratio
    )
    return {**architecture, **optimisation, **dataset}


def print_config(config):
    """Print the hyperparameters in ``config`` as a sectioned, human-readable table.

    Raises ``KeyError`` if any expected key is missing from ``config``.
    """
    banner = "=" * 60
    # (section title, ((label, config key), ...)) in display order.
    sections = (
        ("Architecture", (
            ("Embedding Dimension (d_model)", 'd_model'),
            ("Number of Layers", 'num_layers'),
            ("Number of Attention Heads", 'num_heads'),
            ("Feed-Forward Dimension", 'd_ff'),
            ("Max Sequence Length", 'max_seq_len'),
            ("Vocabulary Size", 'vocab_size'),
            ("Number of Classes", 'num_classes'),
            ("Dropout Rate", 'dropout'),
            ("Pooling Type", 'pooling_type'),
        )),
        ("Training", (
            ("Batch Size", 'batch_size'),
            ("Number of Epochs", 'num_epochs'),
            ("Learning Rate", 'learning_rate'),
            ("Weight Decay", 'weight_decay'),
        )),
        ("Dataset", (
            ("Number of Samples", 'num_samples'),
        )),
    )

    print(banner)
    print("MODEL CONFIGURATION")
    print(banner)
    for title, rows in sections:
        print(f"\n{title}:")
        for label, key in rows:
            print(f"  - {label}: {config[key]}")
    # The split ratio is shown as a whole-number percentage, not the raw value.
    print(f"  - Train Split: {config['train_split'] * 100:.0f}%")
    print(banner)


# MAIN TRAINING FUNCTION

def main():
    """Train the encoder-only transformer end to end and save every artifact.

    Steps:
    1. Load ArXiv abstracts, shuffle them and split them into train/validation.
    2. Train a WordPiece tokenizer on the training split only.
    3. Train for ``config['num_epochs']`` epochs with AdamW and
       ReduceLROnPlateau on the validation loss.
    4. Write plots, CSV metrics and the model bundle to a new timestamped
       directory.

    The saved (and returned) weights come from the epoch with the highest
    validation accuracy. The confusion matrix and classification report use
    that same epoch's validation predictions, so all artifacts agree with the
    ``best_epoch`` stored in the metadata. If no epoch beats an accuracy of
    0.0, the final epoch's weights and predictions are used instead.

    Returns:
        Tuple ``(model, tokenizer, config, train_losses, train_accuracies,
        val_losses, val_accuracies, save_dir)``.

    Raises:
        ValueError: If ``config['num_epochs']`` is less than 1.
    """
    import time

    config = get_default_config()
    print_config(config)
    # With zero epochs there is no trained model or validation output to save.
    if config['num_epochs'] < 1:
        raise ValueError("num_epochs must be at least 1 to train and save a model")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"\nUsing device: {device}\n")

    # Load dataset
    print("="*60)
    print("LOADING ARXIV DATASET")
    print("="*60 + "\n")

    all_abstracts = load_arxiv_huggingface(num_samples=config['num_samples'])

    # Split data (shuffled first so the split does not follow source order)
    random.shuffle(all_abstracts)
    split_idx = int(config['train_split'] * len(all_abstracts))
    train_abstracts = all_abstracts[:split_idx]
    val_abstracts = all_abstracts[split_idx:]

    print(f"\n[OK] Dataset split:")
    print(f"  Training: {len(train_abstracts)} abstracts")
    print(f"  Validation: {len(val_abstracts)} abstracts")

    # Build tokenizer on training data only, so validation text stays unseen
    print("\nTraining WordPiece tokenizer...")
    tokenizer = build_tokenizer(train_abstracts, vocab_size=config['vocab_size'])
    print(f"[OK] Vocabulary size: {tokenizer.get_vocab_size()}")

    # Create datasets
    train_dataset = ArXivClassificationDataset(
        train_abstracts, tokenizer,
        max_len=config['max_seq_len'],
        num_classes=config['num_classes']
    )
    val_dataset = ArXivClassificationDataset(
        val_abstracts, tokenizer,
        max_len=config['max_seq_len'],
        num_classes=config['num_classes']
    )

    train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config['batch_size'], shuffle=False)

    # Initialize model (embedding sized from the trained tokenizer)
    model = EncoderOnlyTransformer(
        vocab_size=tokenizer.get_vocab_size(),
        d_model=config['d_model'],
        num_layers=config['num_layers'],
        num_heads=config['num_heads'],
        d_ff=config['d_ff'],
        max_seq_len=config['max_seq_len'],
        num_classes=config['num_classes'],
        dropout=config['dropout'],
        pooling_type=config['pooling_type']
    ).to(device)

    num_params = sum(p.numel() for p in model.parameters())
    print(f"\n[OK] Model parameters: {num_params:,}")
    print(f"[OK] Model size: ~{num_params * 4 / 1024 / 1024:.2f} MB")

    # Training setup
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config['learning_rate'],
        weight_decay=config['weight_decay']
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=3
    )

    # Training loop
    train_losses = []
    train_accuracies = []
    val_losses = []
    val_accuracies = []
    best_val_accuracy = 0.0
    best_epoch = 0
    # Snapshot of the best epoch's weights and validation outputs.
    best_state = None
    best_val_predictions = None
    best_val_labels = None

    print("\n" + "="*60)
    print("STARTING TRAINING")
    print("="*60 + "\n")

    for epoch in range(config['num_epochs']):
        epoch_start_time = time.time()

        train_loss, train_acc, _batch_losses = train_epoch(model, train_loader, optimizer, device)
        train_losses.append(train_loss)
        train_accuracies.append(train_acc)

        val_loss, val_acc, val_predictions, val_labels = evaluate_model(model, val_loader, device)
        val_losses.append(val_loss)
        val_accuracies.append(val_acc)

        scheduler.step(val_loss)
        epoch_time = time.time() - epoch_start_time

        print(f"Epoch {epoch+1}/{config['num_epochs']}:")
        print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
        print(f"  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
        print(f"  Time: {epoch_time:.1f}s")

        if val_acc > best_val_accuracy:
            best_val_accuracy = val_acc
            best_epoch = epoch + 1
            # state_dict() returns live tensors that later optimizer steps
            # overwrite in place, so take detached CPU copies.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            best_val_predictions = val_predictions
            best_val_labels = val_labels
            print(f"  [NEW BEST] Epoch {best_epoch}, Val Acc: {best_val_accuracy:.4f}")

        print()

    # Restore the best checkpoint so model.pt and the reports match best_epoch.
    if best_state is not None:
        model.load_state_dict(best_state)
        print(f"[OK] Restored model weights from best epoch {best_epoch}")
    else:
        best_val_predictions, best_val_labels = val_predictions, val_labels

    # Create save directory first for saving graphs and CSVs
    save_dir = get_model_save_dir()
    os.makedirs(save_dir, exist_ok=True)

    # Save training curves and metrics
    print("\nGenerating and saving training curves...")
    plot_training_curves(train_losses, train_accuracies, val_losses, val_accuracies, save_dir=save_dir)

    print("\nGenerating and saving confusion matrix...")
    plot_confusion_matrix(best_val_labels, best_val_predictions, config['num_classes'], TOPIC_NAMES, save_dir=save_dir)

    # Save all metrics as CSV for reports
    print("\nSaving training metrics to CSV files...")
    save_training_metrics_csv(save_dir, train_losses, train_accuracies, val_losses, val_accuracies, config)
    save_classification_report_csv(save_dir, best_val_labels, best_val_predictions, TOPIC_NAMES)

    # Save model
    print("\n" + "="*60)
    print("SAVING MODEL")
    print("="*60)

    training_history = {
        'train_losses': train_losses,
        'train_accuracies': train_accuracies,
        'val_losses': val_losses,
        'val_accuracies': val_accuracies
    }
    save_model(
        model=model,
        tokenizer=tokenizer,
        config=config,
        save_dir=save_dir,
        training_history=training_history,
        best_epoch=best_epoch,
        best_val_accuracy=best_val_accuracy
    )

    # Final summary
    print("\n" + "="*60)
    print("TRAINING COMPLETE!")
    print("="*60)
    print(f"\nFinal Results:")
    print(f"  Best Model: Epoch {best_epoch} (Val Accuracy: {best_val_accuracy:.4f})")
    print(f"  Final Train Loss: {train_losses[-1]:.4f}")
    print(f"  Final Train Accuracy: {train_accuracies[-1]:.4f}")
    print(f"  Final Val Accuracy: {val_accuracies[-1]:.4f}")
    print(f"  Model saved to: {save_dir}")
    print("\nSaved files:")
    print(f"  - model.pt (model weights)")
    print(f"  - tokenizer.json (trained tokenizer)")
    print(f"  - config.json (model configuration)")
    print(f"  - metadata.json (training metadata)")
    print(f"  - training_curves.png (training visualization)")
    print(f"  - confusion_matrix.png (classification matrix)")
    print(f"  - training_metrics_epochs.csv (per-epoch metrics)")
    print(f"  - training_summary.csv (summary statistics)")
    print(f"  - config.csv (configuration)")
    print(f"  - classification_report.csv (classification metrics)")
    print(f"  - confusion_matrix.csv (confusion matrix data)")
    print("="*60)

    return model, tokenizer, config, train_losses, train_accuracies, val_losses, val_accuracies, save_dir


In [None]:
if __name__ == "__main__":
    # Keep every returned value as a module-level name so later notebook
    # cells can use the training curves and the trained model directly.
    (model, tokenizer, config,
     train_losses, train_accuracies,
     val_losses, val_accuracies,
     save_dir) = main()
    print("\nModel ready for inference!")
    print("To load this model later, use:")
    print(f"   model, tokenizer, config, metadata = load_saved_model('{save_dir}')")


## Understanding Encoder-Only vs Decoder-Only Transformers

### Key Architectural Differences

**Encoder-Only Transformers (BERT-style)**:
- **Bidirectional Attention**: Each token can attend to all tokens in the sequence (past and future)
- **Use Case**: Understanding tasks (classification, NER, QA)
- **Training**: Masked Language Modeling (MLM) or direct classification
- **No Autoregressive Generation**: Processes entire sequence at once

**Decoder-Only Transformers (GPT-style)**:
- **Causal/Masked Attention**: Each token can only attend to previous tokens
- **Use Case**: Generation tasks (text generation, translation)
- **Training**: Next-token prediction (autoregressive)
- **Autoregressive Generation**: Generates tokens one at a time

### Why Bidirectional Attention for Classification?

Bidirectional attention allows the model to use **full context** from both directions when making classification decisions. For example, when classifying a document, the model can consider:
- The beginning of the document when processing the end
- The end of the document when processing the beginning
- All tokens simultaneously to make a holistic decision

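This is different from decoder-only models, where each token can attend only to earlier positions (past context). Decoder-only models still process a whole sequence in parallel during training; only generation proceeds one token at a time.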


## Pooling Strategies

After processing the sequence through encoder blocks, we need to extract a **single vector representation** for classification. Two common strategies:

### 1. CLS Token Pooling
- Add a special `[CLS]` token at the beginning of the sequence
- After encoding, use the CLS token's representation as the sequence embedding
- **Advantage**: The CLS token can learn to aggregate sequence information
- **Used by**: BERT

### 2. Mean Pooling
- Average all token representations (excluding padding)
- **Advantage**: Uses all tokens equally
- **Used by**: Sentence-embedding models such as Sentence-BERT, and as a fallback when no dedicated CLS token is available

In this implementation, both strategies are supported and can be configured via the `pooling_type` parameter.


## Configurable Hyperparameters

All model dimensions and training hyperparameters are configurable:

### Model Architecture
- **d_model**: Embedding dimension (e.g., 256, 512, 768)
- **num_layers**: Number of encoder blocks (e.g., 4, 6, 12)
- **num_heads**: Number of attention heads (must divide d_model)
- **d_ff**: Feed-forward dimension (typically 4× d_model)
- **num_classes**: Number of classification classes
- **pooling_type**: 'cls' or 'mean'

### Training
- **batch_size**: Training batch size
- **num_epochs**: Number of training epochs
- **learning_rate**: Learning rate for optimizer
- **weight_decay**: Decoupled weight-decay strength for AdamW (similar in spirit to, but not identical to, L2 regularization)

### Dataset
- **num_samples**: Number of samples to load
- **max_seq_len**: Maximum sequence length
- **vocab_size**: Vocabulary size for tokenizer

You can change the defaults in `get_default_config()`, or override individual entries of the returned dict at the start of `main()`, before the data is loaded and the model is built.


## Attention Mechanism: Bidirectional vs Causal

### Bidirectional Attention (Encoder-Only)

```
Token positions:  [0]  [1]  [2]  [3]
Can attend to:    ✓✓✓✓ ✓✓✓✓ ✓✓✓✓ ✓✓✓✓
                 (all positions)
```

Each token can attend to **all tokens** in the sequence, including future tokens. This is achieved by:
1. Computing attention scores for all pairs of tokens
2. **NOT** applying a causal mask (upper triangular mask)
3. Only applying padding masks to ignore padding tokens

### Causal Attention (Decoder-Only)

```
Token positions:  [0]  [1]  [2]  [3]
Can attend to:    ✓    ✓✓   ✓✓✓  ✓✓✓✓
                 (only past)
```

Each token can only attend to **previous tokens**. This is achieved by:
1. Computing attention scores for all pairs
2. **Applying** a causal mask (upper triangular mask set to -inf)
3. This prevents tokens from seeing future information

### Why This Matters

- **Bidirectional**: Better for understanding tasks (classification, sentiment analysis)
- **Causal**: Required for generation tasks (text generation, where future tokens don't exist yet)


## Model Architecture Summary

```
Input Text
    ↓
Tokenization + CLS/SEP tokens
    ↓
Token Embeddings (vocab_size → d_model)
    ↓
+ Positional Encoding
    ↓
┌─────────────────┐
│ Encoder Block 1 │ ← Bidirectional Attention + FFN
└─────────────────┘
    ↓
┌─────────────────┐
│ Encoder Block 2 │ ← Bidirectional Attention + FFN
└─────────────────┘
    ↓
    ...
    ↓
┌─────────────────┐
│ Encoder Block N │ ← Bidirectional Attention + FFN
└─────────────────┘
    ↓
Pooling (CLS token or Mean)
    ↓
Classification Head (d_model → num_classes)
    ↓
Class Predictions
```

### Data Flow

1. **Input**: Raw text abstracts
2. **Tokenization**: Convert to token IDs, add CLS/SEP tokens
3. **Embedding**: Map token IDs to dense vectors (d_model dimensions)
4. **Positional Encoding**: Add position information
5. **Encoder Blocks**: Process through N layers of bidirectional attention
6. **Pooling**: Extract sequence-level representation
7. **Classification**: Map to one logit per class; a softmax over the logits gives class probabilities


## Paper Classification Inference

Now that we have a trained model, we can use it to classify new paper abstracts. The `classify_paper` function takes a paper abstract as input and returns the predicted research topic with confidence scores.


In [None]:
# INFERENCE FUNCTION

def classify_paper(abstract, model, tokenizer, config, device=None):
    """Predict the research topic of a single abstract.

    The abstract is tokenized and truncated to ``config['max_seq_len'] - 2``
    tokens. It is then wrapped in ``[CLS] ... [SEP]`` (falling back to
    ``[BOS]``/``[EOS]`` when those ids are absent) and right-padded with
    ``[PAD]`` to ``max_seq_len``. Positions holding the pad id get attention
    mask 0. NOTE(review): this is assumed to mirror the encoding used by
    ArXivClassificationDataset during training; confirm if that class changes.

    Args:
        abstract: Raw abstract text.
        model: Trained classifier called as ``model(input_ids, attention_mask)``
            and returning logits; it is switched to eval mode.
        tokenizer: Tokenizer used during training.
        config: Dict providing ``max_seq_len``.
        device: Target device; defaults to the device of the model's parameters.

    Returns:
        Dict with ``predicted_class`` (int), ``predicted_topic`` (str),
        ``confidence`` (float) and ``all_probabilities`` (topic name -> float).
    """
    target_device = next(model.parameters()).device if device is None else device

    model.eval()

    def _special(primary, fallback):
        # Prefer the BERT-style token, else the GPT-style equivalent.
        found = tokenizer.token_to_id(primary)
        return tokenizer.token_to_id(fallback) if found is None else found

    cls_id = _special("[CLS]", "[BOS]")
    pad_id = tokenizer.token_to_id("[PAD]")
    sep_id = _special("[SEP]", "[EOS]")

    max_len = config['max_seq_len']
    # Leave room for the two boundary tokens.
    body = tokenizer.encode(abstract).ids[:max_len - 2]

    sequence = [cls_id, *body, sep_id]
    sequence += [pad_id] * (max_len - len(sequence))
    mask = [0 if tok == pad_id else 1 for tok in sequence]

    input_ids = torch.tensor([sequence], dtype=torch.long).to(target_device)
    mask_tensor = torch.tensor([mask], dtype=torch.long).to(target_device)

    with torch.no_grad():
        probs = F.softmax(model(input_ids, mask_tensor), dim=-1)
        row = probs[0]
        best = probs.argmax(dim=-1).item()
        score = row[best].item()
        per_topic = {name: row[idx].item() for idx, name in enumerate(TOPIC_NAMES)}

    return {
        'predicted_class': best,
        'predicted_topic': TOPIC_NAMES[best],
        'confidence': score,
        'all_probabilities': per_topic
    }


def print_classification_result(result):
    """Print a ``classify_paper`` result as a report with a text bar chart.

    Topics are listed from most to least probable; each bar has one ``#``
    per full 1/30 of probability.
    """
    divider = "=" * 70
    print("\n" + divider)
    print("PAPER CLASSIFICATION RESULT")
    print(divider)
    print(f"\nPredicted Topic: {result['predicted_topic']}")
    print(f"Confidence: {result['confidence']*100:.2f}%")
    print("\nAll Topic Probabilities:")

    ranked = sorted(result['all_probabilities'].items(),
                    key=lambda pair: pair[1], reverse=True)
    for topic, prob in ranked:
        meter = "#" * int(prob * 30)
        print(f"  {topic:35s} {prob*100:6.2f}% {meter}")
    print(divider)
