# Text Summarization with CNN/DailyMail Dataset: Data & Task Overview

## The Original Task: Tweet Summarization
The original goal of this project is to develop a transformer model from scratch for abstractive summarization of tweets. However, a significant challenge in this domain is the lack of large-scale, high-quality datasets pairing tweets with their summaries.

## Using CNN/DailyMail as the Base Dataset
Due to the unavailability of a large twitter summarization dataset, we've chosen the CNN/DailyMail dataset to train our base model. This dataset consists of news articles from CNN and Daily Mail websites paired with human-written summaries (called "highlights"). These summaries are bullet-point style abstracts created by professional editors.

Key statistics:
- Training set: 287,113 article-summary pairs
- Validation set: 13,368 article-summary pairs
- Test set: 11,490 article-summary pairs

## The Task - Abstractive Text Summarization
This project implements an **abstractive summarization** model using a Transformer architecture. Unlike extractive summarization (which selects sentences from the source text), abstractive summarization generates new text that captures the essential information from the source document.

## Why CNN/DailyMail is Suitable
1. **Real-world application**: News summarization is a practical task with commercial applications.
2. **Well-defined task**: The highlights are professionally written and follow consistent patterns.
3. **Appropriate length ratio**: Articles are typically several paragraphs (400-800 words), while summaries are 3-4 bullet points (around 40-70 words), providing a reasonable compression ratio.
4. **Diverse topics**: The dataset covers a wide range of news topics, making models trained on it more generalizable.
5. **Multiple references**: Each article has multiple highlight points, allowing evaluation of different aspects of summarization.

Our plan is to first develop and train a strong base model on this dataset, which can later be adapted for the tweet summarization task. efficient and effective.

In [5]:
# Install Required Packages
!pip install datasets transformers tokenizers torch nltk rouge_score pandas numpy tqdm wandb tensorboardX sentencepiece einops matplotlib
!pip install accelerate rouge



In [1]:
# Environment Setup
import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import nltk
from nltk.tokenize import sent_tokenize
from datasets import load_dataset, Dataset
from transformers import PreTrainedTokenizerFast
from tokenizers import ByteLevelBPETokenizer, Tokenizer
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
import wandb  # Optional for experiment tracking
from tensorboardX import SummaryWriter
import logging
import regex as re
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from torch.cuda.amp import autocast, GradScaler  # For mixed precision training
from torch.nn.utils import clip_grad_norm_  # For gradient clipping
import torch.cuda.amp as amp  # For mixed precision
from torch.utils.checkpoint import checkpoint  # For gradient checkpointing
import time  # For timing training
import json  # For saving configs
import matplotlib.pyplot as plt  # For plotting
import shutil  # For disk usage information
import random  # For seed setting

## Random Seed Initialization

The code below sets fixed random seeds across all libraries used in this project. This is critical for:

- **Reproducibility**: Ensures the same results can be obtained across different runs
- **Consistent evaluation**: Guarantees that the generated summaries remain consistent for proper analysis and comparison
- **Reliable generation**: With our temperature-based sampling approach, fixed seeds ensure consistent token selection during text generation
- **Deterministic behavior**: Makes debugging and validation possible by eliminating randomness as a variable

For academic and research contexts, reproducibility is a fundamental requirement. Without these seeds, the model would produce different summaries each time, making proper analysis and comparison impossible.

The `deterministic` and `benchmark` settings specifically configure CUDA operations to prioritize consistent results over performance optimizations.

In [2]:
# Set random seeds 
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Download necessary NLTK data
nltk.download('punkt')

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Check GPU specs if available
if torch.cuda.is_available():
    print(f"GPU available: {torch.cuda.get_device_name(0)}")
    print(f"Total GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

# Check disk space
total, used, free = shutil.disk_usage("/")
print(f"Disk space - Total: {total/1e9:.1f} GB, Used: {used/1e9:.1f} GB, Free: {free/1e9:.1f} GB")

Using device: cuda
GPU available: NVIDIA GeForce RTX 3060 Laptop GPU
Total GPU memory: 6.44 GB
Disk space - Total: 372.8 GB, Used: 338.8 GB, Free: 34.0 GB


[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\nisha\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [3]:
# Load Dataset
# Load CNN/DailyMail dataset
cnn_dailymail = load_dataset("cnn_dailymail", "3.0.0")

# Print info
print("Dataset loaded successfully!")
print(f"Dataset structure: {cnn_dailymail}")
print(f"Number of training examples: {len(cnn_dailymail['train'])}")
print(f"Number of validation examples: {len(cnn_dailymail['validation'])}")
print(f"Number of test examples: {len(cnn_dailymail['test'])}")

# Sample 
sample = cnn_dailymail['train'][0]
print("\nSample article (first 300 chars):")
print(sample['article'][:300] + "...")
print("\nSample highlights (summary):")
print(sample['highlights'])

Dataset loaded successfully!
Dataset structure: DatasetDict({
    train: Dataset({
        features: ['article', 'highlights', 'id'],
        num_rows: 287113
    })
    validation: Dataset({
        features: ['article', 'highlights', 'id'],
        num_rows: 13368
    })
    test: Dataset({
        features: ['article', 'highlights', 'id'],
        num_rows: 11490
    })
})
Number of training examples: 287113
Number of validation examples: 13368
Number of test examples: 11490

Sample article (first 300 chars):
LONDON, England (Reuters) -- Harry Potter star Daniel Radcliffe gains access to a reported £20 million ($41.1 million) fortune as he turns 18 on Monday, but he insists the money won't cast a spell on him. Daniel Radcliffe as Harry Potter in "Harry Potter and the Order of the Phoenix" To the disappoi...

Sample highlights (summary):
Harry Potter star Daniel Radcliffe gets £20M fortune as he turns 18 Monday .
Young actor says he has no plans to fritter his cash away .
Radcliffe's

# Custom Tokenization Strategy for Abstractive Summarization

## The Tokenization Architecture

The project implements a custom `OptimizedTokenizer` class that wraps around Hugging Face's Rust-based tokenizers library instead of the more commonly used SentencePiece. This choice was deliberate and offers several advantages for our summarization task.

### Key Components of the Implementation:

1. **ByteLevelBPE Tokenizer**: The implementation uses Byte-Level Byte-Pair Encoding (BPE) algorithm for tokenization, which works by iteratively merging the most frequent pairs of bytes in the text.

2. **Special Tokens Management**: Four special tokens are explicitly defined with assigned IDs:
   - `<pad>` (ID 0): For padding sequences to uniform length
   - `<sos>` (ID 1): Start of sequence marker
   - `<eos>` (ID 2): End of sequence marker
   - `<unk>` (ID 3): For handling unknown tokens

3. **Fast Training Process**: The tokenizer is trained on a sampled subset of the data (around 200,000 texts including both articles and summaries).

4. **Vocabulary Size**: A vocabulary size of 32,000 tokens is used, which is large enough to capture the diversity of news language.

## Rationale for Avoiding SentencePiece

The implementation deliberately avoids SentencePiece for several reasons:

1. **Performance**: ByteLevelBPE with the Rust implementation provides significant performance gains during both training and inference compared to SentencePiece.

2. **Efficiency**: Training the tokenizer takes minutes rather than hours, making development iterations faster.

3. **Memory Usage**: The Rust-based implementation is more memory-efficient, allowing for processing larger batches during training.

4. **Fine-grained Control**: The implementation provides explicit control over special token IDs, which is important for the transformer architecture.

5. **HuggingFace Integration**: Using `PreTrainedTokenizerFast` wrapper ensures compatibility with the broader ecosystem.

## Implementation Details

The implementation includes methods for:
- Training the tokenizer on a corpus (`train`)
- Encoding single texts (`encode`)
- Batch encoding multiple texts (`batch_encode`)
- Decoding token IDs back to text (`decode`)
- Saving and loading the tokenizer (`save`, `load`)

When used in the pipeline, the tokenizer truncates inputs to 512 tokens and targets to 128 tokens, which aligns with the characteristics of CNN/DailyMail articles and summaries.

For a text summarization task like ours, having an efficient tokenization process is crucial due to the large amount of text data being processed. The chosen approach optimizes for both training speed and runtime performance. both training speed and runtime performance.

In [3]:
# Optimized Tokenizer Class
class OptimizedTokenizer:
    def __init__(self, vocab_size=32000):
        self.vocab_size = vocab_size
        self.tokenizer = None
        self.special_tokens = ["<pad>", "<sos>", "<eos>", "<unk>"]
        self.special_token_ids = {
            "<pad>": 0,
            "<sos>": 1,
            "<eos>": 2,
            "<unk>": 3
        }
        
    def train(self, texts, model_prefix="tokenizer", num_samples=None):
        """Train a ByteLevelBPE tokenizer (much faster than SentencePiece)"""
        os.makedirs(model_prefix, exist_ok=True)
        
        # Sample texts to speed up training if needed
        if num_samples and len(texts) > num_samples:
            import random
            random.seed(42)
            texts = random.sample(texts, num_samples)
        
        # Write sample texts to file
        corpus_path = "corpus.txt"
        print(f"Writing {len(texts)} texts to file...")
        with open(corpus_path, 'w', encoding='utf-8') as f:
            for text in texts:
                f.write(text + '\n')
        
        # Initialize and train the tokenizer (Rust implementation - very fast)
        print("Training tokenizer (this should take minutes, not hours)...")
        from tokenizers import Tokenizer
        from tokenizers.models import BPE
        from tokenizers.trainers import BpeTrainer
        from tokenizers.pre_tokenizers import Whitespace
        
        # Create a new BPE tokenizer
        tokenizer = Tokenizer(BPE(unk_token="<unk>"))
        tokenizer.pre_tokenizer = Whitespace()
        
        # Prepare the trainer
        trainer = BpeTrainer(
            vocab_size=self.vocab_size,
            special_tokens=self.special_tokens,
            min_frequency=2
        )
        
        # Train the tokenizer
        tokenizer.train(files=[corpus_path], trainer=trainer)
        
        # Save the tokenizer
        tokenizer_path = os.path.join(model_prefix, "tokenizer.json")
        tokenizer.save(tokenizer_path)
        print(f"Tokenizer saved to {tokenizer_path}")
        
        # Load the tokenizer
        from transformers import PreTrainedTokenizerFast
        self.tokenizer = PreTrainedTokenizerFast(
            tokenizer_file=tokenizer_path,
            bos_token="<sos>",
            eos_token="<eos>",
            pad_token="<pad>",
            unk_token="<unk>"
        )
        
        # Set special token IDs explicitly
        self.tokenizer.pad_token_id = 0
        self.tokenizer.bos_token_id = 1
        self.tokenizer.eos_token_id = 2
        self.tokenizer.unk_token_id = 3
        
        # Clean up
        if os.path.exists(corpus_path):
            os.remove(corpus_path)
            
        print(f"Tokenizer training complete!")
        print(f"Vocabulary size: {self.tokenizer.vocab_size}")
        
    def encode(self, text, max_length=None, padding="max_length", truncation=True):
        if self.tokenizer is None:
            raise ValueError("Tokenizer not trained. Call train() first.")
        
        # Use the HuggingFace tokenizer
        encoding = self.tokenizer(
            text,
            add_special_tokens=True,
            max_length=max_length,
            padding=padding,
            truncation=truncation,
            return_attention_mask=True,
            return_tensors=None
        )
        
        return {
            "input_ids": encoding["input_ids"],
            "attention_mask": encoding["attention_mask"]
        }
    
    def batch_encode(self, texts, max_length=None, padding="max_length", truncation=True):
        if self.tokenizer is None:
            raise ValueError("Tokenizer not trained. Call train() first.")
        
        # Batch encode
        encodings = self.tokenizer(
            texts,
            add_special_tokens=True,
            max_length=max_length,
            padding=padding,
            truncation=truncation,
            return_attention_mask=True,
            return_tensors=None
        )
        
        return {
            "input_ids": encodings["input_ids"],
            "attention_mask": encodings["attention_mask"]
        }
    
    def decode(self, token_ids, skip_special_tokens=True):
        if self.tokenizer is None:
            raise ValueError("Tokenizer not trained. Call train() first.")
            
        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.cpu().tolist()
            
        return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
    
    def save(self, path):
        if self.tokenizer is None:
            raise ValueError("Tokenizer not trained. Call train() first.")
        
        os.makedirs(path, exist_ok=True)
        self.tokenizer.save_pretrained(path)
        
        # Save the special token mapping separately
        with open(os.path.join(path, "special_tokens.json"), "w") as f:
            json.dump(self.special_token_ids, f)
    
    def load(self, path):
        from transformers import PreTrainedTokenizerFast
        
        tokenizer_path = os.path.join(path, "tokenizer.json")
        if os.path.exists(tokenizer_path):
            self.tokenizer = PreTrainedTokenizerFast(
                tokenizer_file=tokenizer_path,
                bos_token="<sos>",
                eos_token="<eos>",
                pad_token="<pad>",
                unk_token="<unk>"
            )
        else:
            # Try loading as a pretrained tokenizer
            self.tokenizer = PreTrainedTokenizerFast.from_pretrained(path)
        
        # Load special token mapping if it exists
        special_tokens_path = os.path.join(path, "special_tokens.json")
        if os.path.exists(special_tokens_path):
            with open(special_tokens_path, "r") as f:
                self.special_token_ids = json.load(f)
        
        # Ensure the tokenizer has the correct special tokens
        self.tokenizer.pad_token = "<pad>"
        self.tokenizer.bos_token = "<sos>"
        self.tokenizer.eos_token = "<eos>"
        self.tokenizer.unk_token = "<unk>"
        
        # Set special token IDs explicitly
        self.tokenizer.pad_token_id = self.special_token_ids["<pad>"]
        self.tokenizer.bos_token_id = self.special_token_ids["<sos>"]
        self.tokenizer.eos_token_id = self.special_token_ids["<eos>"]
        self.tokenizer.unk_token_id = self.special_token_ids["<unk>"]

# Tokenizer Training: Balancing Coverage and Efficiency

## Training Process Overview

The tokenizer training section implements a carefully balanced approach to building a vocabulary that can handle the complexities of news articles while maintaining computational efficiency. The code first checks if a previously trained tokenizer exists, and if not, proceeds to train a new one using a substantial sample of the dataset.

## Why a 32,000 Token Vocabulary?

The choice of a 32,000 token vocabulary represents a deliberate balance between several competing factors:

1. **Linguistic Coverage**: News articles contain diverse vocabulary including named entities, domain-specific terminology, and rare words. A larger vocabulary ensures better coverage of these elements without excessive use of unknown tokens.

2. **Model Size Considerations**: Each additional token in the vocabulary increases the size of the embedding matrices in the transformer model. At 32,000 tokens, we achieve good coverage while keeping the model parameters manageable.

3. **Training Efficiency**: Larger vocabularies require more computation during both training and inference. The 32,000 size allows for efficient training while preserving linguistic nuance.

4. **Industry Standard Alignment**: This size is in line with other successful language models - slightly larger than BERT (30,000) but smaller than GPT-2 (50,000).

5. **Subword Efficiency**: With BPE tokenization, 32,000 tokens provides sufficient granularity to reconstruct most words while capturing common subword patterns.

## Corpus Selection Strategy

The implementation samples 100,000 articles but also includes their summaries, effectively creating a 200,000 text corpus. This ensures the tokenizer learns patterns from both input articles and target summaries, which is crucial for the seq2seq nature of the summarization task.

The sample size was increased from an earlier version (from 50,000 to 100,000), indicating an empirical decision to improve vocabulary coverage based on observed performance during development.

In [4]:
# Train Tokenizer
# Initialize tokenizer
MAX_VOCAB_SIZE = 32000
tokenizer = OptimizedTokenizer(vocab_size=MAX_VOCAB_SIZE)

# Check if tokenizer already exists
if os.path.exists(r"C:\Users\nisha\Downloads\tokenizer\tokenizer.json"):
    print("Loading existing tokenizer...")
    tokenizer.load(r"C:\Users\nisha\Downloads\tokenizer")
    print(f"Tokenizer loaded with vocabulary size: {tokenizer.tokenizer.vocab_size}")
else:
    print("Preparing texts for tokenizer training...")
    # Use a larger subset of the dataset for better vocabulary coverage
    sample_size = 100000  
    train_samples = cnn_dailymail['train'].select(range(sample_size))
    
    texts = []
    for example in tqdm(train_samples):
        texts.append(example['article'])
        texts.append(example['highlights'])
    
    print(f"Training tokenizer on {len(texts)} texts...")
    start_time = time.time()
    tokenizer.train(texts, model_prefix="tokenizer", num_samples=200000)  # Increased sample size
    end_time = time.time()
    print(f"Tokenizer training completed in {end_time - start_time:.2f} seconds")
    
    # Save the tokenizer
    tokenizer.save("tokenizer")
    print("Tokenizer saved to 'tokenizer' directory")

Loading existing tokenizer...
Tokenizer loaded with vocabulary size: 32000


# Transformer Architecture: Advanced Components for Summarization

## Positional Encoding

The positional encoding component addresses a fundamental limitation of transformer models: they have no inherent understanding of sequence order. Unlike RNNs, which process tokens sequentially, transformers process all tokens in parallel.

This implementation uses sinusoidal positional encodings, which add position-dependent patterns to each embedding. The mathematical properties of these sine/cosine functions allow the model to attend to relative positions, making it possible to understand the sequential nature of text while retaining the benefits of parallel processing.

## Enhanced Multi-Head Attention

The `ImprovedMultiHeadAttention` class represents a refined implementation of the attention mechanism that forms the core of the transformer. Key enhancements include:

- **Numerical Stability**: Uses a smaller negative value (-1e4 instead of -1e9) for masked positions to prevent overflow in mixed precision training
- **Proper Initialization**: Weight matrices are initialized with Xavier uniform distribution to ensure stable gradient flow
- **Flexible Masking**: Supports multiple mask dimensions for different attention patterns

Multi-head attention allows the model to jointly attend to information from different representation subspaces at different positions. This enables capturing different types of dependencies in the text - some heads might focus on local syntactic patterns while others capture document-level semantic relationships.

## Position-wise Feed-Forward Networks

The feed-forward networks apply two linear transformations with a GELU activation in between. This creates a component that can model complex token-level transformations. Key features:

- **GELU Activation**: Uses Gaussian Error Linear Unit instead of ReLU, providing smoother gradients
- **Proper Dropout**: Applied after activation to improve regularization
- **Careful Initialization**: Parameter initialization designed to prevent vanishing/exploding gradients

## Encoder and Decoder Architecture

The encoder and decoder follow the classic transformer design but with several optimizations:

- **Pre-Layer Normalization**: Unlike the original transformer's post-layer norm, this implementation applies normalization before each sub-layer, significantly improving training stability
- **Residual Connections**: Carefully implemented skip connections help maintain gradient flow through deep networks
- **Shared Embeddings**: Input and output embeddings are shared to reduce parameters and improve regularization

## Complete Transformer Model

The `ImprovedTransformer` class brings everything together with several enhancements:

- **Efficient Masking**: Optimized logic for creating source and target masks
- **Three-Way Weight Tying**: Shares weights between encoder embeddings, decoder embeddings, and the output projection layer
- **Beam Search Generation**: Implements a sophisticated beam search algorithm with top-k sampling for better summary quality
- **Temperature Control**: Allows controlling the randomness in the generation process

These architectural choices reflect both the original transformer design principles and more recent improvements developed by the NLP community, creating a model particularly well-suited for abstractive summarization.

In [5]:
# Transformer Components
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_length=5000, dropout=0.1):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        
        # Create positional encoding matrix
        pe = torch.zeros(max_seq_length, d_model)
        position = torch.arange(0, max_seq_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        
        # Register buffer (not a parameter, but should be part of the module's state)
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        # x: [batch_size, seq_len, d_model]
        x = x + self.pe[:, :x.size(1), :].to(x.device)
        return self.dropout(x)

class ImprovedMultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super(ImprovedMultiHeadAttention, self).__init__()
        
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        
        # Linear projections with weight initialization
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
        # Initialize weights properly
        for module in [self.W_q, self.W_k, self.W_v, self.W_o]:
            nn.init.xavier_uniform_(module.weight)
            nn.init.zeros_(module.bias)
            
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, query, key, value, mask=None, return_attention=False):
        batch_size = query.size(0)
        
        # Linear projections
        Q = self.W_q(query)
        K = self.W_k(key)
        V = self.W_v(value)
        
        # Reshape for multi-head attention
        Q = Q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)  # [B, h, L_q, d_k]
        K = K.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)  # [B, h, L_k, d_k]
        V = V.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)  # [B, h, L_v, d_k]
        
        # Scaled dot-product attention
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)  # [B, h, L_q, L_k]
        
        if mask is not None:
            # Fix mask dimensions to match scores
            if mask.dim() == 2:  # [B, L]
                mask = mask.unsqueeze(1).unsqueeze(2)  # [B, 1, 1, L]
            elif mask.dim() == 3:  # [B, L_q, L_k]
                mask = mask.unsqueeze(1)  # [B, 1, L_q, L_k]
            
            # Use a smaller negative value for numerical stability in mixed precision
            # -1e4 is safe for float16, while -1e9 causes overflow
            scores = scores.masked_fill(mask == 0, -1e4)
        
        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = self.dropout(attn_weights)
        
        # Apply attention weights to values
        context = torch.matmul(attn_weights, V)  # [B, h, L_q, d_k]
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)  # [B, L_q, d_model]
        
        # Final projection
        output = self.W_o(context)
        
        if return_attention:
            return output, attn_weights
        else:
            return output

class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1, activation='gelu'):
        super(FeedForward, self).__init__()
        
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        
        # Choose activation function
        if activation == 'relu':
            self.activation = F.relu
        elif activation == 'gelu':
            self.activation = F.gelu
        else:
            raise ValueError(f"Unsupported activation: {activation}")
        
        # Initialize weights
        nn.init.xavier_uniform_(self.linear1.weight)
        nn.init.zeros_(self.linear1.bias)
        nn.init.xavier_uniform_(self.linear2.weight)
        nn.init.zeros_(self.linear2.bias)
        
    def forward(self, x):
        return self.linear2(self.dropout(self.activation(self.linear1(x))))

class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1, activation='gelu'):
        super(EncoderLayer, self).__init__()
        
        # Self-attention mechanism
        self.self_attn = ImprovedMultiHeadAttention(d_model, num_heads, dropout)
        
        # Feed-forward network
        self.feed_forward = FeedForward(d_model, d_ff, dropout, activation)
        
        # Layer normalization
        self.norm1 = nn.LayerNorm(d_model, eps=1e-6)
        self.norm2 = nn.LayerNorm(d_model, eps=1e-6)
        
        # Dropout for regularization
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        
    def forward(self, x, mask=None):
        # Pre-LayerNorm architecture (more stable for training)
        norm_x = self.norm1(x)
        attn_output = self.self_attn(norm_x, norm_x, norm_x, mask)
        x = x + self.dropout1(attn_output)
        
        norm_x = self.norm2(x)
        ff_output = self.feed_forward(norm_x)
        x = x + self.dropout2(ff_output)
        
        return x

class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1, activation='gelu'):
        super(DecoderLayer, self).__init__()
        
        # Self-attention mechanism
        self.self_attn = ImprovedMultiHeadAttention(d_model, num_heads, dropout)
        
        # Cross-attention mechanism (for encoder-decoder attention)
        self.cross_attn = ImprovedMultiHeadAttention(d_model, num_heads, dropout)
        
        # Feed-forward network
        self.feed_forward = FeedForward(d_model, d_ff, dropout, activation)
        
        # Layer normalization
        self.norm1 = nn.LayerNorm(d_model, eps=1e-6)
        self.norm2 = nn.LayerNorm(d_model, eps=1e-6)
        self.norm3 = nn.LayerNorm(d_model, eps=1e-6)
        
        # Dropout for regularization
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
        
    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        # Pre-LayerNorm architecture
        # Self-attention block
        norm_x = self.norm1(x)
        self_attn_output = self.self_attn(norm_x, norm_x, norm_x, tgt_mask)
        x = x + self.dropout1(self_attn_output)
        
        # Cross-attention block
        norm_x = self.norm2(x)
        cross_attn_output = self.cross_attn(norm_x, enc_output, enc_output, src_mask)
        x = x + self.dropout2(cross_attn_output)
        
        # Feed-forward block
        norm_x = self.norm3(x)
        ff_output = self.feed_forward(norm_x)
        x = x + self.dropout3(ff_output)
        
        return x

# Encoder-Decoder Architecture: The Core of Sequence Transformation

## Encoder: Processing Source Documents

The Encoder component is responsible for transforming the input article into a rich contextual representation. Its key characteristics include:

1. **Token and Positional Embedding Combination**: 
   - Input tokens are first embedded into a continuous vector space
   - These embeddings are scaled by √d_model to balance their magnitude with positional encodings
   - Positional encodings are added to provide sequence order information

2. **Stacked Processing Layers**:
   - Multiple identical encoder layers process the input sequentially
   - Each layer has self-attention and feed-forward sub-layers
   - Layer normalization and residual connections maintain gradient flow
   
3. **Final Normalization**: 
   - A layer normalization is applied to the output, producing the encoded representation
   - This stabilizes the encoder output before it's passed to the decoder

4. **Attention Masking**:
   - The encoder uses masking to ignore padding tokens, ensuring attention is only computed over actual content
   - This is crucial for processing variable-length articles efficiently

## Decoder: Generating Summaries

The Decoder takes the encoder output and autoregressively generates the summary. Its architecture includes:

1. **Embeddings with Positional Information**:
   - Like the encoder, it embeds tokens and adds positional encodings
   - During training, it processes the entire target sequence in parallel with causal masking
   - During inference, it generates tokens one by one

2. **Additional Complexity**:
   - Each decoder layer has three sub-components rather than two
   - Self-attention over previously generated tokens
   - Cross-attention to access the encoder's representation
   - Feed-forward network for transformation

3. **Causal Masking**:
   - A triangular attention mask ensures the decoder can only attend to positions that come before the current one
   - This prevents information leakage from future tokens during training

4. **Cross-Attention Mechanism**:
   - The critical component where the decoder attends to the encoded source
   - Allows the model to focus on relevant parts of the article when generating each summary token

## Encoder-Decoder Interaction

The interaction between these components creates a powerful mechanism for transforming articles into summaries:

1. **Information Flow**: 
   - The encoder processes the entire article at once
   - The decoder accesses this encoded information through cross-attention
   - This allows selective focusing on relevant article sections

2. **Parameter Sharing**:
   - The implementation shares embedding weights between encoder and decoder
   - This weight tying reduces model size and acts as a regularization technique

3. **Model Scaling**:
   - The architecture can be scaled by adjusting the number of layers, attention heads, and model dimension
   - The implementation uses 6 layers, giving sufficient depth for learning complex patterns without excessive computational requirements

This encoder-decoder architecture is particularly well-suited for summarization because it can process long documents, identify salient information, and generate fluent summaries that capture the essence of the source text.

In [6]:
# Encoder and Decoder Implementation
class Encoder(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, d_ff, num_layers, dropout=0.1, activation='gelu'):
        super(Encoder, self).__init__()
        
        # Token embedding + positional encoding
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.pos_encoding = PositionalEncoding(d_model, dropout=dropout)
        
        # Stack of encoder layers
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout, activation)
            for _ in range(num_layers)
        ])
        
        # Final layer normalization
        self.norm = nn.LayerNorm(d_model, eps=1e-6)
        
        # Initialize embeddings
        nn.init.normal_(self.embedding.weight, mean=0, std=d_model**-0.5)
        
    def forward(self, x, mask=None):
        # x: [batch_size, seq_len]
        
        # Embed tokens and add positional encoding
        x = self.embedding(x) * math.sqrt(self.embedding.embedding_dim)
        x = self.pos_encoding(x)
        
        # Apply encoder layers
        for layer in self.layers:
            x = layer(x, mask)
            
        # Apply final normalization
        x = self.norm(x)
            
        return x

class Decoder(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, d_ff, num_layers, dropout=0.1, activation='gelu'):
        super(Decoder, self).__init__()
        
        # Token embedding + positional encoding
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)
        self.pos_encoding = PositionalEncoding(d_model, dropout=dropout)
        
        # Stack of decoder layers
        self.layers = nn.ModuleList([
            DecoderLayer(d_model, num_heads, d_ff, dropout, activation)
            for _ in range(num_layers)
        ])
        
        # Final layer normalization
        self.norm = nn.LayerNorm(d_model, eps=1e-6)
        
        # Initialize embeddings
        nn.init.normal_(self.embedding.weight, mean=0, std=d_model**-0.5)
        
    def forward(self, x, enc_output, src_mask=None, tgt_mask=None):
        # x: [batch_size, seq_len]
        
        # Embed tokens and add positional encoding
        x = self.embedding(x) * math.sqrt(self.embedding.embedding_dim)
        x = self.pos_encoding(x)
        
        # Apply decoder layers
        for layer in self.layers:
            x = layer(x, enc_output, src_mask, tgt_mask)
            
        # Apply final normalization
        x = self.norm(x)
            
        return x

# Full Transformer Model Implementation

## ImprovedTransformer Class

The `ImprovedTransformer` class in Cell 8 implements a complete sequence-to-sequence transformer designed specifically for abstractive summarization. This class integrates the previously defined encoder and decoder components into a cohesive model.

## Core Components and Design Decisions

1. **Initialization Parameters**:
   - The model accepts configuration for vocabulary sizes, dimensions, heads, feed-forward size, and number of layers
   - Default values (d_model=512, num_heads=8, d_ff=2048, num_layers=6) follow the original transformer paper with modest adjustments

2. **Special Token Management**:
   - The class explicitly stores token IDs (pad=0, sos=1, eos=2) needed for masking and generation
   - These IDs align with the tokenizer's special tokens

3. **Encoder-Decoder Integration**:
   - The encoder processes the entire source text in parallel
   - The decoder works with the encoder outputs through cross-attention mechanisms

4. **Parameter Sharing Strategy**:
   - The `share_embeddings` parameter enables weight sharing between encoder and decoder embeddings
   - Three-way weight tying links the output projection layer with the decoder embeddings
   - This significantly reduces model parameters and improves regularization

5. **Masking Utilities**:
   - `create_src_mask`: Generates masks for handling padding in source sequences
   - `create_tgt_mask`: Creates combined causal and padding masks for target sequences
   - These ensure proper attention behavior during both training and inference

6. **Forward Method**:
   - Handles the complete sequence-to-sequence transformation
   - Connects encoder and decoder with appropriate masking
   - Returns logits over the vocabulary for each target position

7. **Generation Method**:
   - Implements beam search with additional controls (beam size, top-k, temperature)
   - Uses helper `Beam` class to track multiple generation hypotheses
   - Incorporates early stopping for efficiency

This implementation combines architectural best practices with practical optimizations for the summarization task, striking a balance between model expressiveness and computational efficiency.

In [7]:
# Full Transformer Model
class ImprovedTransformer(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, num_heads=8,
                 d_ff=2048, num_layers=6, dropout=0.1, activation='gelu',
                 share_embeddings=True):
        super(ImprovedTransformer, self).__init__()
        
        # Store special token IDs for mask creation and generation
        self.pad_token_id = 0
        self.sos_token_id = 1
        self.eos_token_id = 2
        
        # Encoder and decoder
        self.encoder = Encoder(src_vocab_size, d_model, num_heads, d_ff, num_layers, dropout, activation)
        self.decoder = Decoder(tgt_vocab_size, d_model, num_heads, d_ff, num_layers, dropout, activation)
        
        # Final projection to vocabulary
        self.final_layer = nn.Linear(d_model, tgt_vocab_size, bias=False)
        
        # Share embeddings between encoder and decoder (parameter sharing)
        if share_embeddings:
            self.encoder.embedding.weight = self.decoder.embedding.weight
            
        # Share embeddings with final output layer (three-way weight tying)
        self.final_layer.weight = self.decoder.embedding.weight
        
    def create_src_mask(self, src):
        # src: [batch_size, src_len]
        # Create source padding mask: 1 for tokens, 0 for padding
        src_mask = (src != self.pad_token_id).unsqueeze(1).unsqueeze(2)
        # src_mask shape: [batch_size, 1, 1, src_len]
        return src_mask
    
    def create_tgt_mask(self, tgt):
        # tgt: [batch_size, tgt_len]
        # Create target padding mask
        tgt_pad_mask = (tgt != self.pad_token_id).unsqueeze(1).unsqueeze(3)
        # tgt_pad_mask shape: [batch_size, 1, tgt_len, 1]
        
        # Create causal mask to prevent attending to future positions
        tgt_len = tgt.size(1)
        tgt_sub_mask = torch.tril(torch.ones((tgt_len, tgt_len), device=tgt.device)).bool()
        tgt_sub_mask = tgt_sub_mask.unsqueeze(0).unsqueeze(1)
        # tgt_sub_mask shape: [1, 1, tgt_len, tgt_len]
        
        # Combine padding mask and causal mask
        tgt_mask = tgt_pad_mask & tgt_sub_mask
        # tgt_mask shape: [batch_size, 1, tgt_len, tgt_len]
        return tgt_mask
    
    def forward(self, src, tgt):
        # src: [batch_size, src_len]
        # tgt: [batch_size, tgt_len]
        
        # Create masks
        src_mask = self.create_src_mask(src)
        tgt_mask = self.create_tgt_mask(tgt)
        
        # Encoder forward pass
        enc_output = self.encoder(src, src_mask)
        
        # Decoder forward pass
        dec_output = self.decoder(tgt, enc_output, src_mask, tgt_mask)
        
        # Final projection to vocabulary
        output = self.final_layer(dec_output)
        
        return output
    
    def generate(self, src, max_length=128, beam_size=5, top_k=50, temperature=1.0, early_stopping=True):
        """
        Generate sequence using beam search.
        
        Args:
            src: Tensor of shape [batch_size, src_len]
            max_length: Maximum length of generated sequence
            beam_size: Number of beams for beam search
            top_k: Sample from top k most likely tokens for diversity
            temperature: Temperature for sampling
            early_stopping: Whether to stop when all beams have generated EOS
            
        Returns:
            Tensor of shape [batch_size, max_length]
        """
        batch_size = src.size(0)
        
        # Process source sequence
        src_mask = self.create_src_mask(src)
        encoder_output = self.encoder(src, src_mask)
        
        # Initialize beam for each batch item
        beams = [Beam(beam_size, self.pad_token_id, self.sos_token_id, self.eos_token_id, device=src.device)
                for _ in range(batch_size)]
        
        # Run beam search for each batch item
        for i in range(batch_size):
            # Extract encoder output for this batch item
            encoder_output_i = encoder_output[i:i+1].expand(beam_size, -1, -1)
            src_mask_i = src_mask[i:i+1].expand(beam_size, -1, -1, -1)
            
            # This will store the decoder input tokens
            input_ids = torch.full((beam_size, 1), self.sos_token_id, dtype=torch.long, device=src.device)
            
            # Generate tokens step by step
            for step in range(max_length):
                # Create tgt_mask for the current step
                tgt_mask = self.create_tgt_mask(input_ids)
                
                # Compute decoder output
                decoder_output = self.decoder(input_ids, encoder_output_i, src_mask_i, tgt_mask)
                
                # Get output for the last position
                last_token_output = decoder_output[:, -1]
                
                # Project to vocabulary
                logits = self.final_layer(last_token_output)
                
                # Apply temperature
                logits = logits / temperature
                
                # Apply top-k sampling
                if top_k > 0:
                    top_k_logits, top_k_indices = torch.topk(logits, top_k)
                    logits = torch.full_like(logits, float('-inf'))
                    logits.scatter_(1, top_k_indices, top_k_logits)
                
                # Convert to probabilities
                probs = F.softmax(logits, dim=-1)
                
                # Update beam state
                beams[i].advance(probs)
                
                # Prepare for next step
                # Get the indices of the top beam_size hypotheses
                beam_indices = beams[i].get_current_state()
                
                # Select the next tokens and update the input for next step
                next_tokens = beams[i].get_current_tokens().unsqueeze(1)
                
                # Create new input by concatenating current input_ids with next tokens
                input_ids = torch.cat([input_ids[beam_indices], next_tokens], dim=1)
                
                # Check if all beams have generated EOS
                if early_stopping and beams[i].is_done():
                    break
            
        # Return the best hypothesis for each batch item
        output_ids = []
        for i in range(batch_size):
            output_ids.append(beams[i].get_best_hypothesis())
            
        # Pad sequences to max_length if needed
        max_len = max(len(ids) for ids in output_ids)
        padded_output = torch.full((batch_size, max_len), self.pad_token_id, dtype=torch.long, device=src.device)
        
        for i, ids in enumerate(output_ids):
            padded_output[i, :len(ids)] = torch.tensor(ids, dtype=torch.long, device=src.device)
            
        return padded_output

# Helper classes for beam search generation
class Beam:
    def __init__(self, beam_size, pad_token_id, sos_token_id, eos_token_id, device):
        self.beam_size = beam_size
        self.pad_token_id = pad_token_id
        self.sos_token_id = sos_token_id
        self.eos_token_id = eos_token_id
        self.device = device
        
        # Scores for each sequence
        self.scores = torch.zeros(beam_size, device=device)
        self.scores[1:] = -1e9  # Start with only one hypothesis
        
        # Backpointers at each step
        self.backpointers = []
        
        # Sequence tokens at each step
        self.tokens = torch.full((beam_size, 1), self.sos_token_id, dtype=torch.long, device=device)
        
        # Flags to indicate if a beam has reached EOS
        self.eos_top = False
        self.finished = []
        
    def get_current_state(self):
        """Get current output token indices"""
        return torch.arange(self.beam_size, device=self.device)
    
    def get_current_tokens(self):
        """Get the tokens for the current step"""
        return self.tokens[:, -1]
    
    def advance(self, word_probs):
        """Update beam state given log probabilities for the next token"""
        vocab_size = word_probs.size(-1)
        
        # Add current scores to the log probabilities
        scores = word_probs + self.scores.unsqueeze(1)
        
        # Flatten to find top beam_size candidates
        flat_scores = scores.view(-1)
        best_scores, best_scores_idx = flat_scores.topk(self.beam_size, largest=True, sorted=True)
        
        # Get beam indices and token indices
        beam_idx = best_scores_idx // vocab_size
        token_idx = best_scores_idx % vocab_size
        
        # Update backpointers
        self.backpointers.append(beam_idx)
        
        # Update tokens
        self.tokens = torch.cat([self.tokens[beam_idx], token_idx.unsqueeze(1)], dim=1)
        
        # Update scores
        self.scores = best_scores
        
        # Check if any hypothesis reached EOS
        eos_indices = torch.where(token_idx == self.eos_token_id)[0]
        if len(eos_indices) > 0:
            for idx in eos_indices:
                if not self.eos_top:
                    # Record the first EOS in the beam
                    self.eos_top = True
                    
                # Add finished hypothesis
                self.finished.append((self.scores[idx].item(), len(self.tokens[idx]) - 1, idx))
        
    def is_done(self):
        """Check if at least one beam has reached EOS"""
        return self.eos_top and len(self.finished) >= self.beam_size
    
    def get_best_hypothesis(self):
        """Get the best hypothesis"""
        if self.finished:
            # Sort finished hypotheses by score
            self.finished.sort(key=lambda x: -x[0])
            score, length, idx = self.finished[0]
            # Return best finished hypothesis
            return self.tokens[idx, 1:length+1].tolist()
        else:
            # Return best unfinished hypothesis
            return self.tokens[0, 1:].tolist()

# Data Preprocessing: Efficient Processing for Large-Scale Training

## Preprocessing Strategy

The data preprocessing implementation addresses the challenges of efficiently handling the large CNN/DailyMail dataset while preparing it for transformer training. The preprocessing pipeline includes several key optimizations:

1. **Length Constraints**:
   - Articles are capped at 512 tokens, capturing the majority of article content while keeping sequences manageable
   - Summaries are limited to 128 tokens, providing ample space for comprehensive summaries

2. **Label Processing for Training**:
   - Padding tokens in target sequences are replaced with -100, which PyTorch's loss functions automatically ignore
   - This prevents the model from learning to predict padding tokens
   - Token IDs are also verified to be within vocabulary size to prevent index errors

3. **Chunked Processing**:
   - The 287,113 training examples are processed in 50,000 example chunks
   - This chunking prevents memory overflow issues that would occur when processing the entire dataset at once
   - Each chunk is processed and then concatenated into the final dataset

4. **Parallelization**:
   - Multiple processes (num_proc=4) are used to accelerate tokenization
   - This leverages multicore processing capabilities, significantly reducing preprocessing time

5. **Persistence Strategy**:
   - Processed datasets are saved to disk after creation
   - The code checks for existing processed data before beginning the expensive preprocessing
   - This "process once, use many times" approach saves substantial time during development iterations

6. **Progress Monitoring**:
   - Detailed progress bars and time tracking provide visibility into the preprocessing status
   - This is particularly important given the time-intensive nature of processing large text datasets

## Implementation Details

The preprocessing handles several technical challenges:

1. **Memory Management**: By processing in chunks and using efficient data structures, the code avoids out-of-memory errors that would occur when naively processing the full dataset.

2. **Attention Mask Generation**: The tokenization process generates attention masks that identify which tokens are real content versus padding, essential for the transformer's attention mechanisms.

3. **Dataset Transformation**: The original dataset format is converted to a training-ready format with input_ids, attention_masks, and properly formatted labels.

This preprocessing approach balances thoroughness with efficiency, ensuring that the entire dataset is properly prepared for training while minimizing computational overhead.

In [8]:
# Data Preprocessing
# Define preprocessing parameters
MAX_INPUT_LENGTH = 512
MAX_TARGET_LENGTH = 128

def preprocess_function(examples):
    """
    Preprocess data examples for training:
    - Tokenize inputs (articles)
    - Tokenize targets (summaries)
    - Truncate to max length
    """
    # Tokenize the articles
    model_inputs = tokenizer.batch_encode(
        examples["article"], 
        max_length=MAX_INPUT_LENGTH, 
        padding="max_length", 
        truncation=True
    )
    
    # Tokenize the summaries
    labels = tokenizer.batch_encode(
        examples["highlights"], 
        max_length=MAX_TARGET_LENGTH, 
        padding="max_length", 
        truncation=True
    )
    
    # Set the labels
    model_inputs["labels"] = labels["input_ids"]
    
    # Replace padding token id with -100 for labels (PyTorch will ignore these in loss calculation)
    new_labels = []
    for label in labels["input_ids"]:
        # Important: Ensure label tokens don't exceed vocab size
        new_label = [l if l != tokenizer.tokenizer.pad_token_id and l < MAX_VOCAB_SIZE else -100 for l in label]
        new_labels.append(new_label)
    
    model_inputs["labels"] = new_labels
    
    return model_inputs

# Process datasets with progress tracking and multiple processes
print("Processing training dataset...")
start_time = time.time()

# Process all training data
train_dataset = cnn_dailymail["train"]
processed_data_dir = r"D:\NLP-Project\processed_text_summarization_data\processed_data"
train_data_path = os.path.join(processed_data_dir, "train")
val_data_path = os.path.join(processed_data_dir, "validation")

# Check if processed data already exists
if os.path.exists(train_data_path) and os.path.exists(val_data_path):
    print("Loading pre-processed datasets...")
    try:
        from datasets import load_from_disk
        tokenized_train_dataset = load_from_disk(train_data_path)
        tokenized_val_dataset = load_from_disk(val_data_path)
        print(f"Loaded {len(tokenized_train_dataset)} training examples and {len(tokenized_val_dataset)} validation examples.")
    except Exception as e:
        print(f"Error loading pre-processed data: {e}")
        raise e
else:
    print("Processing datasets from scratch...")
    os.makedirs(processed_data_dir, exist_ok=True)
    
    # Process training data in manageable chunks to avoid memory issues
    chunk_size = 50000
    num_chunks = math.ceil(len(train_dataset) / chunk_size)
    
    all_train_datasets = []
    for i in range(num_chunks):
        start_idx = i * chunk_size
        end_idx = min((i + 1) * chunk_size, len(train_dataset))
        
        print(f"Processing chunk {i+1}/{num_chunks} (examples {start_idx} to {end_idx})...")
        chunk = train_dataset.select(range(start_idx, end_idx))
        
        tokenized_chunk = chunk.map(
            preprocess_function,
            batched=True,
            batch_size=1000,
            remove_columns=train_dataset.column_names,
            desc=f"Processing chunk {i+1}/{num_chunks}",
            num_proc=4  # Use multiple processes
        )
        
        all_train_datasets.append(tokenized_chunk)
    
    # Combine all chunks
    from datasets import concatenate_datasets
    tokenized_train_dataset = concatenate_datasets(all_train_datasets)
    
    print("Processing validation dataset...")
    tokenized_val_dataset = cnn_dailymail["validation"].map(
        preprocess_function,
        batched=True,
        batch_size=1000,
        remove_columns=cnn_dailymail["validation"].column_names,
        desc="Processing validation data",
        num_proc=4
    )
    
    end_time = time.time()
    print(f"Data preprocessing completed in {end_time - start_time:.2f} seconds")
    
    # Save processed datasets
    print("Saving processed datasets to disk...")
    tokenized_train_dataset.save_to_disk(train_data_path)
    tokenized_val_dataset.save_to_disk(val_data_path)

# Calculate dataset sizes
print(f"Processed train dataset size: {len(tokenized_train_dataset)}")
print(f"Processed validation dataset size: {len(tokenized_val_dataset)}")
print("Preprocessing complete!")

Processing training dataset...
Loading pre-processed datasets...
Loaded 287113 training examples and 13368 validation examples.
Processed train dataset size: 287113
Processed validation dataset size: 13368
Preprocessing complete!


# Hyperparameter Selection: Balancing Performance and Efficiency

## Model Hyperparameters

The implementation carefully selects hyperparameters that balance modeling capacity with computational efficiency:

1. **Model Dimensions**:
   - `D_MODEL = 768`: Larger than the original transformer paper (512) to increase model capacity for handling complex news articles
   - `NUM_HEADS = 12`: Increased from the standard 8 heads to allow finer-grained attention patterns
   - `D_FF = 3072`: Feed-forward dimension set to 4x the model dimension, providing sufficient transformation capacity
   - `NUM_LAYERS = 6`: Using 6 encoder and decoder layers offers good depth without excessive computation

2. **Regularization Controls**:
   - `DROPOUT = 0.1`: Standard dropout rate that prevents overfitting while maintaining training signal
   - `ACTIVATION = 'gelu'`: Gaussian Error Linear Unit provides smoother gradients than ReLU

## Training Hyperparameters

The training configuration incorporates several advanced techniques:

1. **Batch Processing Strategy**:
   - `BATCH_SIZE = 16`: Direct batch size constrained by GPU memory
   - `GRADIENT_ACCUMULATION_STEPS = 4`: Accumulates gradients across 4 batches
   - `EFFECTIVE_BATCH_SIZE = 64`: The resulting effective batch size provides more stable gradients

2. **Optimization Settings**:
   - `LEARNING_RATE = 3e-4`: Slightly lower than typical 5e-4 for improved stability
   - `WEIGHT_DECAY = 0.01`: L2 regularization to prevent overfitting
   - `NUM_EPOCHS = 10`: Sufficient for convergence on this dataset
   - `WARMUP_RATIO = 0.1`: Gradual warmup for 10% of total steps prevents early instability
   - `MAX_GRAD_NORM = 1.0`: Gradient clipping to prevent exploding gradients
   - `MIXED_PRECISION = True`: Enables FP16 computation for speed while maintaining numerical stability

## Data Processing Optimizations

The implementation includes a specialized collate function that:

1. Prepares target inputs for teacher forcing during training
2. Handles padding and masking efficiently
3. Properly marks invalid positions using the -100 label for PyTorch's loss functions

This careful balance of hyperparameters enables effective training of a powerful abstractive summarization model within reasonable computational constraints.

In [9]:
# Training Setup and Hyperparameters
# Model hyperparameters - optimized values
VOCAB_SIZE = 32000  # Ensure this matches the tokenizer vocabulary size
D_MODEL = 768       # Larger model dimension
NUM_HEADS = 12      # More attention heads
D_FF = 3072         # Larger feed-forward dimension
NUM_LAYERS = 6      # Enough layers for good performance without overfitting
DROPOUT = 0.1       # Standard dropout rate
ACTIVATION = 'gelu' # Better activation function

# Training hyperparameters
BATCH_SIZE = 16     # Increased batch size
GRADIENT_ACCUMULATION_STEPS = 4  # Accumulate gradients for stable training
EFFECTIVE_BATCH_SIZE = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS
LEARNING_RATE = 3e-4  # Slightly lower learning rate for stability
WEIGHT_DECAY = 0.01   # L2 regularization
NUM_EPOCHS = 10
WARMUP_RATIO = 0.1   # Warmup for 10% of total steps
MAX_GRAD_NORM = 1.0
MIXED_PRECISION = True  # Enable mixed precision training for speed

# Create dataloaders with optimized collate function
def collate_fn(batch):
    input_ids = torch.tensor([example['input_ids'] for example in batch])
    attention_mask = torch.tensor([example['attention_mask'] for example in batch])
    labels = torch.tensor([example['labels'] for example in batch])
    
    # Calculate target input (for teacher forcing) and target output
    # Target input is the labels shifted right with SOS token at the beginning
    target_input = torch.zeros_like(labels)
    target_input[:, 0] = tokenizer.tokenizer.bos_token_id  # Start with SOS token
    
    # Fill in the rest of the target input
    valid_positions = (labels != -100)[:, :-1]
    target_input[:, 1:][valid_positions] = labels[:, :-1][valid_positions]
    
    # Set padding for invalid positions
    invalid_positions = ~valid_positions
    target_input[:, 1:][invalid_positions] = tokenizer.tokenizer.pad_token_id
    
    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'target_input': target_input,
        'labels': labels
    }

from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

# Create training dataloader
train_sampler = RandomSampler(tokenized_train_dataset)
train_dataloader = DataLoader(
    tokenized_train_dataset,
    sampler=train_sampler,
    batch_size=BATCH_SIZE,
    collate_fn=collate_fn,
    num_workers=4,
    pin_memory=True
)

# Create validation dataloader
val_sampler = SequentialSampler(tokenized_val_dataset)
val_dataloader = DataLoader(
    tokenized_val_dataset,
    sampler=val_sampler,
    batch_size=BATCH_SIZE,
    collate_fn=collate_fn,
    num_workers=4,
    pin_memory=True
)

# Calculate total training steps and warmup steps
total_steps = len(train_dataloader) * NUM_EPOCHS // GRADIENT_ACCUMULATION_STEPS
warmup_steps = int(total_steps * WARMUP_RATIO)

print(f"Total training steps: {total_steps}")
print(f"Warmup steps: {warmup_steps}")

Total training steps: 44862
Warmup steps: 4486


# Loss Function and Learning Rate Scheduler: Optimizing Training Dynamics

## Label Smoothing Loss

The implementation uses a custom `LabelSmoothingLoss` class that incorporates label smoothing, a technique that improves model generalization and confidence calibration:

1. **Regularization Mechanism**:
   - Rather than training the model to predict exactly 1.0 for the correct class and 0.0 for all others, label smoothing redistributes some probability to other tokens
   - The implementation uses a smoothing factor of 0.1, reserving 90% probability for the correct token and distributing 10% across all other tokens

2. **Overconfidence Prevention**:
   - Without smoothing, models tend to become overconfident in their predictions
   - By introducing uncertainty into the training targets, the model learns to be appropriately cautious
   - This is particularly important for summarization, where multiple valid phrasings may exist

3. **Handling of Padding Tokens**:
   - The loss function carefully handles padding with an `ignore_index` parameter (set to -100)
   - This prevents the model from wasting capacity learning to predict padding tokens
   - The implementation includes safeguards to prevent index errors by clamping target values

4. **Numerical Stability**:
   - The implementation uses stable computation patterns to avoid underflow/overflow
   - This is especially important when training with mixed precision

## Warmup Cosine Scheduler

The `WarmupCosineScheduler` class implements an advanced learning rate schedule combining initial warmup with cosine decay:

1. **Warmup Phase**:
   - Learning rate increases linearly from 0 to the base learning rate during the first 10% of training steps
   - This prevents unstable gradient updates at the beginning of training when weights are randomly initialized
   - Warmup is especially important for transformer models due to their complex gradient flow through attention mechanisms

2. **Cosine Annealing Phase**:
   - After warmup, learning rate follows a cosine curve that gradually decreases to a minimum
   - This smooth decay allows the model to settle into better minima than step-based schedules
   - The cosine shape provides initially slow decay that accelerates in the middle of training and slows again near the end

3. **Implementation Approach**:
   - The scheduler updates the optimizer's learning rate after each step
   - It tracks its state through the `current_step` variable
   - The schedule calculation accounts for both warmup and annealing phases with appropriate transitions

This combination of label smoothing and advanced learning rate scheduling contributes significantly to training stability and final model quality, addressing common challenges in transformer training.

In [10]:
# Loss Function and Scheduler
# Label smoothing loss function for better regularization
class LabelSmoothingLoss(nn.Module):
    def __init__(self, smoothing=0.1, ignore_index=-100, vocab_size=None):
        super(LabelSmoothingLoss, self).__init__()
        self.smoothing = smoothing
        self.ignore_index = ignore_index
        self.vocab_size = vocab_size
        self.confidence = 1.0 - smoothing
        
    def forward(self, output, target):
        # output: [batch_size, seq_len, vocab_size]
        # target: [batch_size, seq_len]
        
        batch_size, seq_len, vocab_size = output.size()
        output = output.reshape(-1, vocab_size)
        target = target.reshape(-1)
        
        # Create mask for ignored indices (padding)
        non_pad_mask = (target != self.ignore_index)
        
        # Get only valid targets
        target = target[non_pad_mask]
        output = output[non_pad_mask]
        
        # Count valid targets
        n_valid = target.size(0)
        if n_valid == 0:
            return torch.tensor(0.0, device=output.device, requires_grad=True)
        
        # Clamp target values to stay within vocabulary range
        target = torch.clamp(target, 0, vocab_size-1)
        
        # Create smoothed target distribution
        smooth_target = torch.zeros_like(output)
        smooth_target.fill_(self.smoothing / (vocab_size - 1))
        smooth_target.scatter_(1, target.unsqueeze(1), self.confidence)
        
        # Calculate loss using cross-entropy
        log_probs = F.log_softmax(output, dim=1)
        loss = -(smooth_target * log_probs).sum(dim=1).mean()
        
        return loss

# Improved learning rate scheduler with warmup
class WarmupCosineScheduler:
    def __init__(self, optimizer, warmup_steps, total_steps):
        self.optimizer = optimizer
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.current_step = 0
        
    def step(self):
        self.current_step += 1
        lr = self.get_lr()
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
            
    def get_lr(self):
        # Linear warmup
        if self.current_step < self.warmup_steps:
            return LEARNING_RATE * (self.current_step / max(1, self.warmup_steps))
        
        # Cosine annealing
        progress = (self.current_step - self.warmup_steps) / max(1, self.total_steps - self.warmup_steps)
        return LEARNING_RATE * 0.5 * (1.0 + math.cos(math.pi * progress))

# Model Initialization and Training Functions: Implementation Details

## Model Initialization Strategy

The model initialization code demonstrates a robust approach to creating and configuring the transformer:

1. **Model Creation with Fallback Mechanism**:
   - The implementation first attempts to create the model with the optimal hyperparameters
   - Error handling detects potential memory issues and provides a fallback to a smaller model configuration
   - This graceful degradation ensures training can proceed even with memory constraints

2. **Weight Initialization and Parameter Sharing**:
   - The model uses weight sharing between encoder and decoder embeddings
   - The output projection layer reuses decoder embedding weights (three-way weight tying)
   - Model size statistics are calculated to verify the expected parameter count (123.8M parameters)

3. **Optimizer Configuration**:
   - AdamW optimizer is used with parameter-specific weight decay
   - Parameters are grouped to apply weight decay only to appropriate tensors:
     - Weight decay applied to most parameters
     - No weight decay for bias terms and LayerNorm weights
   - This selective weight decay follows best practices for transformer models

4. **Mixed Precision Setup**:
   - The code tests mixed precision capability with a small forward pass
   - GradScaler is initialized when mixed precision is enabled
   - This proactive testing prevents potential runtime errors during training

5. **Monitoring and Logging Infrastructure**:
   - TensorBoard writer is initialized for tracking metrics
   - Checkpoint directory is cng instability, and the need for robust error handling.

In [11]:
# Model Initialization
# Create model with optimized architecture
print("Creating improved transformer model...")
try:
    # Initialize the model
    model = ImprovedTransformer(
        src_vocab_size=VOCAB_SIZE,
        tgt_vocab_size=VOCAB_SIZE,
        d_model=D_MODEL,
        num_heads=NUM_HEADS,
        d_ff=D_FF,
        num_layers=NUM_LAYERS,
        dropout=DROPOUT,
        activation=ACTIVATION,
        share_embeddings=True  # Enable weight sharing for efficiency
    )
    
    # Move model to device
    model = model.to(device)
    print(f"Model successfully moved to {device}")
    
except RuntimeError as e:
    print(f"Error during model initialization: {e}")
    print("Reducing model size and trying again...")
    
    # Try with smaller model if we hit memory issues
    D_MODEL = 512
    NUM_HEADS = 8
    D_FF = 2048
    NUM_LAYERS = 4
    
    model = ImprovedTransformer(
        src_vocab_size=VOCAB_SIZE,
        tgt_vocab_size=VOCAB_SIZE,
        d_model=D_MODEL,
        num_heads=NUM_HEADS,
        d_ff=D_FF,
        num_layers=NUM_LAYERS,
        dropout=DROPOUT,
        activation=ACTIVATION,
        share_embeddings=True
    ).to(device)

# Calculate model size
model_size = sum(p.numel() for p in model.parameters())
print(f"Model created with {model_size:,} parameters")
model_size_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
print(f"Model size in memory: {model_size_bytes / 1e9:.2f} GB")

# Initialize optimizer with weight decay
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
     'weight_decay': WEIGHT_DECAY},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
     'weight_decay': 0.0}
]

optimizer = AdamW(
    optimizer_grouped_parameters,
    lr=LEARNING_RATE,
    betas=(0.9, 0.999),
    eps=1e-8
)
print("Using AdamW optimizer with weight decay")

# Initialize scheduler
scheduler = WarmupCosineScheduler(optimizer, warmup_steps, total_steps)

# Initialize loss function
criterion = LabelSmoothingLoss(smoothing=0.1, ignore_index=-100, vocab_size=VOCAB_SIZE)

# Check if mixed precision is causing issues, and disable if needed
try:
    # Test if half precision works with a small forward pass
    with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
        test_input = torch.ones(1, 10, dtype=torch.long).to(device)
        test_output = model(test_input, test_input)
    
    # If it works, use mixed precision
    scaler = GradScaler() if MIXED_PRECISION else None
    print(f"Mixed precision training enabled: {MIXED_PRECISION}")
except RuntimeError as e:
    print(f"Mixed precision test failed: {e}")
    print("Disabling mixed precision training")
    MIXED_PRECISION = False
    scaler = None

# Initialize tensorboard for logging
writer = SummaryWriter(log_dir="runs/improved_transformer")

# Create directory for checkpoints
os.makedirs("checkpoints", exist_ok=True)

Creating improved transformer model...
Model successfully moved to cuda
Model created with 123,816,960 parameters
Model size in memory: 0.50 GB
Using AdamW optimizer with weight decay
Mixed precision training enabled: True


  scaler = GradScaler() if MIXED_PRECISION else None


## Training Functions Implementation

The training implementation includes several sophisticated functions:

1. **Checkpoint Management**:
   - `save_checkpoint`: Captures complete training state (model, optimizer, scheduler) along with metadata
   - `load_checkpoint`: Restores training from a saved state, enabling seamless resumption
   - Includes model configuration in checkpoints for accurate reconstruction

2. **Training Epoch Function**:
   - Implements a complete training loop with progress tracking
   - Handles mixed precision training with appropriate context managers
   - Implements gradient accumulation for effective larger batch sizes
   - Applies gradient clipping to prevent gradient explosions
   - Includes sophisticated error handling for numerical issues
   - Implements regular checkpointing during training
   - Logs detailed metrics for monitoring

3. **Evaluation Function**:
   - Clean implementation for validation that disables gradients
   - Uses the same batching and processing logic as training for consistency
   - Returns validation loss for model selection and early stopping decisions

These implementations incorporate numerous best practices for training large-scale transformer models, addressing common challenges like memory constraints, training instability, and the need for robust error handling.

In [12]:
# Training Functions
def save_checkpoint(model, optimizer, scheduler, epoch, step, loss, filename):
    torch.save({
        'epoch': epoch,
        'global_step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state': {
            'current_step': scheduler.current_step,
            'warmup_steps': scheduler.warmup_steps,
            'total_steps': scheduler.total_steps
        },
        'loss': loss,
        'model_config': {
            'vocab_size': VOCAB_SIZE,
            'd_model': D_MODEL,
            'num_heads': NUM_HEADS,
            'd_ff': D_FF,
            'num_layers': NUM_LAYERS,
            'dropout': DROPOUT,
            'activation': ACTIVATION
        }
    }, filename)
    print(f"Checkpoint saved: {filename}")

def load_checkpoint(filename, model, optimizer, scheduler):
    if os.path.exists(filename):
        print(f"Loading checkpoint: {filename}")
        checkpoint = torch.load(filename, map_location=device)
        
        # Load model state
        model.load_state_dict(checkpoint['model_state_dict'])
        
        # Load optimizer state
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        
        # Restore scheduler state if available
        if 'scheduler_state' in checkpoint:
            scheduler.current_step = checkpoint['scheduler_state']['current_step']
            scheduler.warmup_steps = checkpoint['scheduler_state']['warmup_steps']
            scheduler.total_steps = checkpoint['scheduler_state']['total_steps']
        
        start_epoch = checkpoint['epoch'] + 1
        global_step = checkpoint['global_step']
        print(f"Resuming from epoch {start_epoch}, step {global_step}")
        return start_epoch, global_step
    else:
        print("No checkpoint found, starting from scratch")
        return 0, 0

def train_epoch(model, dataloader, optimizer, criterion, scheduler, scaler=None, 
               gradient_accumulation_steps=1, max_grad_norm=1.0, epoch=0, global_step=0):
    model.train()
    epoch_loss = 0.0
    steps_per_epoch = len(dataloader)
    progress_bar = tqdm(enumerate(dataloader), total=steps_per_epoch, 
                       desc=f"Training Epoch {epoch+1}")
    
    # Reset gradients at the beginning of epoch
    optimizer.zero_grad()
    
    for step, batch in progress_bar:
        try:
            # Move batch to device
            input_ids = batch['input_ids'].to(device)
            target_input = batch['target_input'].to(device)
            labels = batch['labels'].to(device)
            
            # Forward pass with mixed precision if enabled
            if scaler is not None:
                # Use torch.amp.autocast('cuda') instead of torch.cuda.amp.autocast()
                with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
                    outputs = model(input_ids, target_input)
                    loss = criterion(outputs, labels)
                    loss = loss / gradient_accumulation_steps
                    
                # Backward pass with scaling
                scaler.scale(loss).backward()
                
                # Step optimizer and scaler after accumulation
                if (step + 1) % gradient_accumulation_steps == 0 or step == steps_per_epoch - 1:
                    # Unscale before gradient clipping
                    scaler.unscale_(optimizer)
                    clip_grad_norm_(model.parameters(), max_grad_norm)
                    
                    # Update parameters with scaler
                    scaler.step(optimizer)
                    scaler.update()
                    
                    # Step scheduler
                    scheduler.step()
                    
                    # Reset gradients
                    optimizer.zero_grad()
                    
                    # Update global step
                    global_step += 1
            else:
                # Standard forward pass
                outputs = model(input_ids, target_input)
                loss = criterion(outputs, labels)
                loss = loss / gradient_accumulation_steps
                
                # Backward pass
                loss.backward()
                
                # Step optimizer after accumulation
                if (step + 1) % gradient_accumulation_steps == 0 or step == steps_per_epoch - 1:
                    clip_grad_norm_(model.parameters(), max_grad_norm)
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
                    global_step += 1
            
            # Track loss
            epoch_loss += loss.item() * gradient_accumulation_steps
            
            # Update progress bar
            progress_bar.set_postfix({
                'loss': epoch_loss / (step + 1),
                'lr': optimizer.param_groups[0]['lr']
            })
            
            # Log to tensorboard every 100 steps
            if global_step % 100 == 0:
                writer.add_scalar('train/loss', loss.item() * gradient_accumulation_steps, global_step)
                writer.add_scalar('train/learning_rate', optimizer.param_groups[0]['lr'], global_step)
            
            # Save checkpoint periodically
            if global_step % 1000 == 0:
                save_checkpoint(
                    model, optimizer, scheduler, epoch, global_step,
                    loss.item() * gradient_accumulation_steps,
                    f"checkpoints/checkpoint_step_{global_step}.pt"
                )
                
        except RuntimeError as e:
            if "overflow" in str(e) or "underflow" in str(e) or "out of range" in str(e):
                print(f"Numerical error in batch (skipping): {e}")
                # Skip this batch and continue with next one
                optimizer.zero_grad()
                if scaler is not None:
                    # If we have overflow in mixed precision, reduce the scale
                    scaler.update()
                continue
            else:
                # For other runtime errors, re-raise
                raise e
    
    # Return the average loss for the epoch and the updated global step
    return epoch_loss / steps_per_epoch, global_step

def evaluate(model, dataloader, criterion):
    model.eval()
    eval_loss = 0.0
    steps = len(dataloader)
    progress_bar = tqdm(enumerate(dataloader), total=steps, desc="Evaluating")
    
    with torch.no_grad():
        for step, batch in progress_bar:
            # Move batch to device
            input_ids = batch['input_ids'].to(device)
            target_input = batch['target_input'].to(device)
            labels = batch['labels'].to(device)
            
            # Forward pass
            outputs = model(input_ids, target_input)
            loss = criterion(outputs, labels)
            
            # Track loss
            eval_loss += loss.item()
            
            # Update progress bar
            progress_bar.set_postfix({'loss': eval_loss / (step + 1)})
    
    # Return the average loss
    return eval_loss / steps

# The Training Loop: Comprehensive Implementation with Monitoring

## Training Loop Architecture

The training loop implementation in Cell 14 represents a comprehensive approach to transformer training, combining several advanced techniques:

1. **Checkpoint Recovery Mechanism**:
   - Training begins by attempting to load the latest checkpoint
   - If found, training resumes from the saved epoch and step
   - If not, training starts from scratch
   - This enables resilience against interruptions and allows for training continuation

2. **Progress Tracking and Metrics**:
   - Training statistics are stored in a JSON file that persists across sessions
   - The loop maintains and updates best validation loss for model selection
   - Detailed logging includes per-epoch metrics like training time and learning rates

3. **Main Training Procedure**:
   - Structured as a double loop: outer loop over epochs, inner loop via the `train_epoch` function
   - Each epoch follows a train-then-evaluate pattern
   - The results are logged to TensorBoard for visualization
   - Progress is displayed to the user with detailed metrics

4. **Model Persistence Strategy**:
   - Regular checkpoints save model state after each epoch
   - Special "best model" checkpoint preserves the model with lowest validation loss
   - The final model is saved regardless of performance
   - Emergency checkpoints are created if training is interrupted

5. **Comprehensive Error Handling**:
   - Try/except blocks capture both user interruptions and unexpected errors
   - Upon interruption, an emergency checkpoint is saved
   - Errors trigger diagnostics and attempt to preserve training progress
   - TensorBoard writer is properly closed in the finally block

## Implementation Details

The implementation addresses several practical challenges:

1. **Training Duration Tracking**:
   - Timestamps record the start of training and each epoch
   - Elapsed time is calculated and reported
   - This helps with planning and resource allocation

2. **Model Selection**:
   - Validation loss determines the "best" model
   - The loop tracks and updates this metric after each epoch
   - Clear logging signals when a new best model is found

3. **Resource Management**:
   - The loop implements a structured approach to manage memory and computation
   - Regular state saving prevents excessive resource waste on interruption
   - Properly closes resources with the finally block

This carefully structured training loop creates a robust training procedure capable of handling the lengthy process of training a large transformer model on the substantial CNN/DailyMail dataset, with appropriate safeguards against common failure modes.

In [19]:
# Training Loop
# Auto-resume from checkpoint
checkpoint_path = "checkpoints/latest_checkpoint.pt"
start_epoch, global_step = load_checkpoint(checkpoint_path, model, optimizer, scheduler) if os.path.exists(checkpoint_path) else (0, 0)

# Track best model
best_val_loss = float('inf')
training_stats = []

# Load existing stats if available
if os.path.exists('training_stats.json'):
    with open('training_stats.json', 'r') as f:
        training_stats = json.load(f)

# Training start time
train_start_time = time.time()

print(f"Starting training from epoch {start_epoch+1} with {len(train_dataloader)} batches per epoch")
print(f"Gradient accumulation steps: {GRADIENT_ACCUMULATION_STEPS} (effective batch size: {EFFECTIVE_BATCH_SIZE})")
print(f"Using mixed precision: {MIXED_PRECISION}")

# Training loop
try:
    for epoch in range(start_epoch, NUM_EPOCHS):
        epoch_start_time = time.time()
        
        # Train one epoch
        train_loss, global_step = train_epoch(
            model=model,
            dataloader=train_dataloader,
            optimizer=optimizer,
            criterion=criterion,
            scheduler=scheduler,
            scaler=scaler,
            gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
            max_grad_norm=MAX_GRAD_NORM,
            epoch=epoch,
            global_step=global_step
        )
        
        # Calculate epoch time
        epoch_time = time.time() - epoch_start_time
        
        print(f"Epoch {epoch+1}/{NUM_EPOCHS} completed in {epoch_time:.2f}s - Loss: {train_loss:.4f}")
        
        # Evaluate
        val_loss = evaluate(model, val_dataloader, criterion)
        print(f"Validation Loss: {val_loss:.4f}")
        
        # Log to tensorboard
        writer.add_scalar('epoch/train_loss', train_loss, epoch)
        writer.add_scalar('epoch/val_loss', val_loss, epoch)
        writer.add_scalar('epoch/time', epoch_time, epoch)
        
        # Save checkpoint after each epoch
        save_checkpoint(
            model, optimizer, scheduler, epoch, global_step, val_loss,
            checkpoint_path
        )
        
        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            print(f"New best validation loss: {best_val_loss:.4f} - Saving model...")
            save_checkpoint(
                model, optimizer, scheduler, epoch, global_step, val_loss,
                "checkpoints/best_model.pt"
            )
        
        # Track training stats
        training_stats.append({
            'epoch': epoch + 1,
            'train_loss': train_loss,
            'val_loss': val_loss,
            'epoch_time': epoch_time,
            'learning_rate': optimizer.param_groups[0]['lr']
        })
        
        # Save training stats
        with open('training_stats.json', 'w') as f:
            json.dump(training_stats, f)
    
    # Training complete
    total_training_time = time.time() - train_start_time
    print(f"Training completed in {total_training_time/60:.2f} minutes")
    print(f"Best validation loss: {best_val_loss:.4f}")
    
    # Save final model
    save_checkpoint(
        model, optimizer, scheduler, NUM_EPOCHS-1, global_step, val_loss,
        "checkpoints/final_model.pt"
    )

except KeyboardInterrupt:
    print("Training interrupted by user!")
    # Save emergency checkpoint
    try:
        save_checkpoint(
            model, optimizer, scheduler, epoch, global_step, train_loss,
            "checkpoints/interrupted_checkpoint.pt"
        )
        print("Emergency checkpoint saved")
    except Exception as e:
        print(f"Could not save emergency checkpoint: {e}")

except Exception as e:
    print(f"Error during training: {e}")
    # Try to save emergency checkpoint
    try:
        save_checkpoint(
            model, optimizer, scheduler, epoch, global_step, train_loss,
            "checkpoints/error_checkpoint.pt"
        )
        print("Emergency checkpoint saved")
    except Exception as e2:
        print(f"Could not save emergency checkpoint: {e2}")

finally:
    # Close tensorboard writer
    writer.close()

Starting training from epoch 1 with 17945 batches per epoch
Gradient accumulation steps: 4 (effective batch size: 64)
Using mixed precision: True


Training Epoch 1:   0%|          | 0/17945 [00:00<?, ?it/s]

Checkpoint saved: checkpoints/checkpoint_step_0.pt
Checkpoint saved: checkpoints/checkpoint_step_0.pt
Checkpoint saved: checkpoints/checkpoint_step_0.pt
Checkpoint saved: checkpoints/checkpoint_step_1000.pt
Checkpoint saved: checkpoints/checkpoint_step_1000.pt
Checkpoint saved: checkpoints/checkpoint_step_1000.pt
Checkpoint saved: checkpoints/checkpoint_step_1000.pt
Checkpoint saved: checkpoints/checkpoint_step_2000.pt
Checkpoint saved: checkpoints/checkpoint_step_2000.pt
Checkpoint saved: checkpoints/checkpoint_step_2000.pt
Checkpoint saved: checkpoints/checkpoint_step_2000.pt
Checkpoint saved: checkpoints/checkpoint_step_3000.pt
Checkpoint saved: checkpoints/checkpoint_step_3000.pt
Checkpoint saved: checkpoints/checkpoint_step_3000.pt
Checkpoint saved: checkpoints/checkpoint_step_3000.pt
Checkpoint saved: checkpoints/checkpoint_step_4000.pt
Checkpoint saved: checkpoints/checkpoint_step_4000.pt
Checkpoint saved: checkpoints/checkpoint_step_4000.pt
Checkpoint saved: checkpoints/checkpo

Evaluating:   0%|          | 0/836 [00:00<?, ?it/s]

Validation Loss: 5.1591
Checkpoint saved: checkpoints/latest_checkpoint.pt
New best validation loss: 5.1591 - Saving model...
Checkpoint saved: checkpoints/best_model.pt


Training Epoch 2:   0%|          | 0/17945 [00:00<?, ?it/s]

Checkpoint saved: checkpoints/checkpoint_step_5000.pt
Checkpoint saved: checkpoints/checkpoint_step_5000.pt
Checkpoint saved: checkpoints/checkpoint_step_5000.pt
Checkpoint saved: checkpoints/checkpoint_step_5000.pt
Checkpoint saved: checkpoints/checkpoint_step_6000.pt
Checkpoint saved: checkpoints/checkpoint_step_6000.pt
Checkpoint saved: checkpoints/checkpoint_step_6000.pt
Checkpoint saved: checkpoints/checkpoint_step_6000.pt
Checkpoint saved: checkpoints/checkpoint_step_7000.pt
Checkpoint saved: checkpoints/checkpoint_step_7000.pt
Checkpoint saved: checkpoints/checkpoint_step_7000.pt
Checkpoint saved: checkpoints/checkpoint_step_7000.pt
Checkpoint saved: checkpoints/checkpoint_step_8000.pt
Checkpoint saved: checkpoints/checkpoint_step_8000.pt
Checkpoint saved: checkpoints/checkpoint_step_8000.pt
Checkpoint saved: checkpoints/checkpoint_step_8000.pt
Epoch 2/10 completed in 1632.03s - Loss: 4.8443


Evaluating:   0%|          | 0/836 [00:00<?, ?it/s]

Validation Loss: 4.1304
Checkpoint saved: checkpoints/latest_checkpoint.pt
New best validation loss: 4.1304 - Saving model...
Checkpoint saved: checkpoints/best_model.pt


Training Epoch 3:   0%|          | 0/17945 [00:00<?, ?it/s]

Checkpoint saved: checkpoints/checkpoint_step_9000.pt
Checkpoint saved: checkpoints/checkpoint_step_9000.pt
Checkpoint saved: checkpoints/checkpoint_step_9000.pt
Checkpoint saved: checkpoints/checkpoint_step_9000.pt
Checkpoint saved: checkpoints/checkpoint_step_10000.pt
Checkpoint saved: checkpoints/checkpoint_step_10000.pt
Checkpoint saved: checkpoints/checkpoint_step_10000.pt
Checkpoint saved: checkpoints/checkpoint_step_10000.pt
Checkpoint saved: checkpoints/checkpoint_step_11000.pt
Checkpoint saved: checkpoints/checkpoint_step_11000.pt
Checkpoint saved: checkpoints/checkpoint_step_11000.pt
Checkpoint saved: checkpoints/checkpoint_step_11000.pt
Checkpoint saved: checkpoints/checkpoint_step_12000.pt
Checkpoint saved: checkpoints/checkpoint_step_12000.pt
Checkpoint saved: checkpoints/checkpoint_step_12000.pt
Checkpoint saved: checkpoints/checkpoint_step_12000.pt
Checkpoint saved: checkpoints/checkpoint_step_13000.pt
Checkpoint saved: checkpoints/checkpoint_step_13000.pt
Checkpoint sav

Evaluating:   0%|          | 0/836 [00:00<?, ?it/s]

Validation Loss: 3.7069
Checkpoint saved: checkpoints/latest_checkpoint.pt
New best validation loss: 3.7069 - Saving model...
Checkpoint saved: checkpoints/best_model.pt


Training Epoch 4:   0%|          | 0/17945 [00:00<?, ?it/s]

Checkpoint saved: checkpoints/checkpoint_step_14000.pt
Checkpoint saved: checkpoints/checkpoint_step_14000.pt
Checkpoint saved: checkpoints/checkpoint_step_14000.pt
Checkpoint saved: checkpoints/checkpoint_step_14000.pt
Checkpoint saved: checkpoints/checkpoint_step_15000.pt
Checkpoint saved: checkpoints/checkpoint_step_15000.pt
Checkpoint saved: checkpoints/checkpoint_step_15000.pt
Checkpoint saved: checkpoints/checkpoint_step_15000.pt
Checkpoint saved: checkpoints/checkpoint_step_16000.pt
Checkpoint saved: checkpoints/checkpoint_step_16000.pt
Checkpoint saved: checkpoints/checkpoint_step_16000.pt
Checkpoint saved: checkpoints/checkpoint_step_16000.pt
Checkpoint saved: checkpoints/checkpoint_step_17000.pt
Checkpoint saved: checkpoints/checkpoint_step_17000.pt
Checkpoint saved: checkpoints/checkpoint_step_17000.pt
Checkpoint saved: checkpoints/checkpoint_step_17000.pt
Epoch 4/10 completed in 1631.56s - Loss: 3.7231


Evaluating:   0%|          | 0/836 [00:00<?, ?it/s]

Validation Loss: 3.5773
Checkpoint saved: checkpoints/latest_checkpoint.pt
New best validation loss: 3.5773 - Saving model...
Checkpoint saved: checkpoints/best_model.pt


Training Epoch 5:   0%|          | 0/17945 [00:00<?, ?it/s]

Checkpoint saved: checkpoints/checkpoint_step_18000.pt
Checkpoint saved: checkpoints/checkpoint_step_18000.pt
Checkpoint saved: checkpoints/checkpoint_step_18000.pt
Checkpoint saved: checkpoints/checkpoint_step_18000.pt
Checkpoint saved: checkpoints/checkpoint_step_19000.pt
Checkpoint saved: checkpoints/checkpoint_step_19000.pt
Checkpoint saved: checkpoints/checkpoint_step_19000.pt
Checkpoint saved: checkpoints/checkpoint_step_19000.pt
Checkpoint saved: checkpoints/checkpoint_step_20000.pt
Checkpoint saved: checkpoints/checkpoint_step_20000.pt
Checkpoint saved: checkpoints/checkpoint_step_20000.pt
Checkpoint saved: checkpoints/checkpoint_step_20000.pt
Checkpoint saved: checkpoints/checkpoint_step_21000.pt
Checkpoint saved: checkpoints/checkpoint_step_21000.pt
Checkpoint saved: checkpoints/checkpoint_step_21000.pt
Checkpoint saved: checkpoints/checkpoint_step_21000.pt
Checkpoint saved: checkpoints/checkpoint_step_22000.pt
Checkpoint saved: checkpoints/checkpoint_step_22000.pt
Checkpoint

Evaluating:   0%|          | 0/836 [00:00<?, ?it/s]

Validation Loss: 3.4989
Checkpoint saved: checkpoints/latest_checkpoint.pt
New best validation loss: 3.4989 - Saving model...
Checkpoint saved: checkpoints/best_model.pt


Training Epoch 6:   0%|          | 0/17945 [00:00<?, ?it/s]

Checkpoint saved: checkpoints/checkpoint_step_23000.pt
Checkpoint saved: checkpoints/checkpoint_step_23000.pt
Checkpoint saved: checkpoints/checkpoint_step_23000.pt
Checkpoint saved: checkpoints/checkpoint_step_23000.pt
Checkpoint saved: checkpoints/checkpoint_step_24000.pt
Checkpoint saved: checkpoints/checkpoint_step_24000.pt
Checkpoint saved: checkpoints/checkpoint_step_24000.pt
Checkpoint saved: checkpoints/checkpoint_step_24000.pt
Checkpoint saved: checkpoints/checkpoint_step_25000.pt
Checkpoint saved: checkpoints/checkpoint_step_25000.pt
Checkpoint saved: checkpoints/checkpoint_step_25000.pt
Checkpoint saved: checkpoints/checkpoint_step_25000.pt
Checkpoint saved: checkpoints/checkpoint_step_26000.pt
Checkpoint saved: checkpoints/checkpoint_step_26000.pt
Checkpoint saved: checkpoints/checkpoint_step_26000.pt
Checkpoint saved: checkpoints/checkpoint_step_26000.pt
Epoch 6/10 completed in 1652.61s - Loss: 3.3816


Evaluating:   0%|          | 0/836 [00:00<?, ?it/s]

Validation Loss: 3.4541
Checkpoint saved: checkpoints/latest_checkpoint.pt
New best validation loss: 3.4541 - Saving model...
Checkpoint saved: checkpoints/best_model.pt


Training Epoch 7:   0%|          | 0/17945 [00:00<?, ?it/s]

Checkpoint saved: checkpoints/checkpoint_step_27000.pt
Checkpoint saved: checkpoints/checkpoint_step_27000.pt
Checkpoint saved: checkpoints/checkpoint_step_27000.pt
Checkpoint saved: checkpoints/checkpoint_step_27000.pt
Checkpoint saved: checkpoints/checkpoint_step_28000.pt
Checkpoint saved: checkpoints/checkpoint_step_28000.pt
Checkpoint saved: checkpoints/checkpoint_step_28000.pt
Checkpoint saved: checkpoints/checkpoint_step_28000.pt
Checkpoint saved: checkpoints/checkpoint_step_29000.pt
Checkpoint saved: checkpoints/checkpoint_step_29000.pt
Checkpoint saved: checkpoints/checkpoint_step_29000.pt
Checkpoint saved: checkpoints/checkpoint_step_29000.pt
Checkpoint saved: checkpoints/checkpoint_step_30000.pt
Checkpoint saved: checkpoints/checkpoint_step_30000.pt
Checkpoint saved: checkpoints/checkpoint_step_30000.pt
Checkpoint saved: checkpoints/checkpoint_step_30000.pt
Checkpoint saved: checkpoints/checkpoint_step_31000.pt
Checkpoint saved: checkpoints/checkpoint_step_31000.pt
Checkpoint

Evaluating:   0%|          | 0/836 [00:00<?, ?it/s]

Validation Loss: 3.4361
Checkpoint saved: checkpoints/latest_checkpoint.pt
New best validation loss: 3.4361 - Saving model...
Checkpoint saved: checkpoints/best_model.pt


Training Epoch 8:   0%|          | 0/17945 [00:00<?, ?it/s]

Checkpoint saved: checkpoints/checkpoint_step_32000.pt
Checkpoint saved: checkpoints/checkpoint_step_32000.pt
Checkpoint saved: checkpoints/checkpoint_step_32000.pt
Checkpoint saved: checkpoints/checkpoint_step_32000.pt
Checkpoint saved: checkpoints/checkpoint_step_33000.pt
Checkpoint saved: checkpoints/checkpoint_step_33000.pt
Checkpoint saved: checkpoints/checkpoint_step_33000.pt
Checkpoint saved: checkpoints/checkpoint_step_33000.pt
Checkpoint saved: checkpoints/checkpoint_step_34000.pt
Checkpoint saved: checkpoints/checkpoint_step_34000.pt
Checkpoint saved: checkpoints/checkpoint_step_34000.pt
Checkpoint saved: checkpoints/checkpoint_step_34000.pt
Checkpoint saved: checkpoints/checkpoint_step_35000.pt
Checkpoint saved: checkpoints/checkpoint_step_35000.pt
Checkpoint saved: checkpoints/checkpoint_step_35000.pt
Checkpoint saved: checkpoints/checkpoint_step_35000.pt
Epoch 8/10 completed in 1638.47s - Loss: 3.1483


Evaluating:   0%|          | 0/836 [00:00<?, ?it/s]

Validation Loss: 3.4274
Checkpoint saved: checkpoints/latest_checkpoint.pt
New best validation loss: 3.4274 - Saving model...
Checkpoint saved: checkpoints/best_model.pt


Training Epoch 9:   0%|          | 0/17945 [00:00<?, ?it/s]

Checkpoint saved: checkpoints/checkpoint_step_36000.pt
Checkpoint saved: checkpoints/checkpoint_step_36000.pt
Checkpoint saved: checkpoints/checkpoint_step_36000.pt
Checkpoint saved: checkpoints/checkpoint_step_36000.pt
Checkpoint saved: checkpoints/checkpoint_step_37000.pt
Checkpoint saved: checkpoints/checkpoint_step_37000.pt
Checkpoint saved: checkpoints/checkpoint_step_37000.pt
Checkpoint saved: checkpoints/checkpoint_step_37000.pt
Checkpoint saved: checkpoints/checkpoint_step_38000.pt
Checkpoint saved: checkpoints/checkpoint_step_38000.pt
Checkpoint saved: checkpoints/checkpoint_step_38000.pt
Checkpoint saved: checkpoints/checkpoint_step_38000.pt
Checkpoint saved: checkpoints/checkpoint_step_39000.pt
Checkpoint saved: checkpoints/checkpoint_step_39000.pt
Checkpoint saved: checkpoints/checkpoint_step_39000.pt
Checkpoint saved: checkpoints/checkpoint_step_39000.pt
Checkpoint saved: checkpoints/checkpoint_step_40000.pt
Checkpoint saved: checkpoints/checkpoint_step_40000.pt
Checkpoint

Evaluating:   0%|          | 0/836 [00:00<?, ?it/s]

Validation Loss: 3.4357
Checkpoint saved: checkpoints/latest_checkpoint.pt


Training Epoch 10:   0%|          | 0/17945 [00:00<?, ?it/s]

Checkpoint saved: checkpoints/checkpoint_step_41000.pt
Checkpoint saved: checkpoints/checkpoint_step_41000.pt
Checkpoint saved: checkpoints/checkpoint_step_41000.pt
Checkpoint saved: checkpoints/checkpoint_step_41000.pt
Checkpoint saved: checkpoints/checkpoint_step_42000.pt
Checkpoint saved: checkpoints/checkpoint_step_42000.pt
Checkpoint saved: checkpoints/checkpoint_step_42000.pt
Checkpoint saved: checkpoints/checkpoint_step_42000.pt
Checkpoint saved: checkpoints/checkpoint_step_43000.pt
Checkpoint saved: checkpoints/checkpoint_step_43000.pt
Checkpoint saved: checkpoints/checkpoint_step_43000.pt
Checkpoint saved: checkpoints/checkpoint_step_43000.pt
Checkpoint saved: checkpoints/checkpoint_step_44000.pt
Checkpoint saved: checkpoints/checkpoint_step_44000.pt
Checkpoint saved: checkpoints/checkpoint_step_44000.pt
Checkpoint saved: checkpoints/checkpoint_step_44000.pt
Epoch 10/10 completed in 1644.76s - Loss: 3.0394


Evaluating:   0%|          | 0/836 [00:00<?, ?it/s]

Validation Loss: 3.4382
Checkpoint saved: checkpoints/latest_checkpoint.pt
Training completed in 281.77 minutes
Best validation loss: 3.4274
Checkpoint saved: checkpoints/final_model.pt


## Approach 1: Built-in Generate Method

### Design Philosophy
This approach leverages the model's native generation capabilities, prioritizing implementation cleanliness and code maintainability over specialized controls for summarization.

### Key Implementation Features

- **Architectural Integration**: Directly uses the model's generation pipeline without additional custom logic

- **Simplified Parameter Control**: Offers configuration of standard generation parameters (beam size, temperature, top-k)

- **Minimal Post-Processing**: Applies only basic cleanup to the generated text

### Performance Analysis
The built-in approach achieves lower scores across all metrics:
- ROUGE-1: 0.2315
- ROUGE-2: 0.0716
- ROUGE-L: 0.2192

The generated summaries tend to be significantly longer (62.76 words) than references (34.98 words), which negatively impacts precision and overall quality.

## Technical Implications

The performance gap between these approaches demonstrates that text generation for specialized tasks like summarization benefits significantly from task-specific constraints and controls that go beyond general-purpose generation mechanisms.

The custom approach shows that:
1. **Length control is critical** for summarization quality
2. **Repetition prevention** substantially improves readability
3. **Task-specific beam search scoring** leads to more appropriate summary selection

In [17]:
# Generating with inbuilt generate() in the model
def generate_with_model(model, tokenizer, article, max_length=64, beam_size=5, 
                       top_k=50, temperature=0.7, early_stopping=True):
    """Generate a summary using the model's built-in generate() method"""
    # Tokenize input
    encoding = tokenizer.encode(
        article, 
        max_length=MAX_INPUT_LENGTH, 
        padding='max_length', 
        truncation=True
    )
    
    # Convert to tensor and move to device
    input_ids = torch.tensor(encoding['input_ids']).unsqueeze(0).to(device)
    
    # Generate summary using the model's built-in method
    with torch.no_grad():
        generated_ids = model.generate(
            input_ids,
            max_length=max_length,
            beam_size=beam_size,
            top_k=top_k,
            temperature=temperature,
            early_stopping=early_stopping
        )
    
    # Decode the generated tokens
    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    
    # Basic post-processing
    generated_text = re.sub(r'\s+', ' ', generated_text).strip()
    if generated_text and not generated_text[-1] in '.!?':
        generated_text += '.'
    
    return generated_text

def evaluate_rouge_scores(model, tokenizer, test_dataset, num_examples=100):
    """Evaluate model performance using ROUGE metrics"""
    # Import required libraries
    from rouge import Rouge
    rouge = Rouge()
    
    # Prepare lists to store results
    references = []
    summaries = []
    
    # Process examples
    print(f"Generating summaries for {num_examples} examples...")
    test_subset = test_dataset.select(range(min(num_examples, len(test_dataset))))
    
    for i, example in enumerate(tqdm(test_subset)):
        article = example['article']
        reference = example['highlights']
        
        # Generate summary
        generated_summary = generate_with_model(
            model, 
            tokenizer, 
            article,
            max_length=64,  # Adjust as needed
            beam_size=4,
            top_k=50,
            temperature=0.7
        )
        
        references.append(reference)
        summaries.append(generated_summary)
        
        # Display progress examples
        if i % 10 == 0 or i == 0:
            print(f"\nExample {i+1}:")
            print(f"Article (truncated): {article[:200]}...")
            print(f"Reference: {reference}")
            print(f"Generated: {generated_summary}")
    
    # Calculate ROUGE scores
    print("\nCalculating ROUGE scores...")
    scores = rouge.get_scores(summaries, references, avg=True)
    
    # Print scores
    print("\nROUGE Scores:")
    print(f"ROUGE-1: {scores['rouge-1']['f']:.4f}")
    print(f"ROUGE-2: {scores['rouge-2']['f']:.4f}")
    print(f"ROUGE-L: {scores['rouge-l']['f']:.4f}")
    
    # Calculate additional metrics
    ref_lengths = [len(r.split()) for r in references]
    summary_lengths = [len(s.split()) for s in summaries]
    
    print("\nLength Statistics:")
    print(f"Average Reference Length: {sum(ref_lengths)/len(ref_lengths):.2f} words")
    print(f"Average Summary Length: {sum(summary_lengths)/len(summary_lengths):.2f} words")
    
    # Save results to CSV
    results_df = pd.DataFrame({
        'article': [example['article'][:500] + '...' for example in test_subset],
        'reference': references,
        'generated': summaries,
        'rouge1': [scores['rouge-1']['f']] * len(summaries),
        'rouge2': [scores['rouge-2']['f']] * len(summaries),
        'rougeL': [scores['rouge-l']['f']] * len(summaries)
    })
    
    results_df.to_csv('model_evaluation_results.csv', index=False)
    print("Results saved to model_evaluation_results.csv")
    
    return scores

# Load model and run evaluation
print("Loading best model...")
best_model_path = r"C:\Users\nisha\Downloads\best_model (1).pt"
checkpoint = torch.load(best_model_path, map_location=device)

if 'model_config' in checkpoint:
    config = checkpoint['model_config']
    model = ImprovedTransformer(
        src_vocab_size=config['vocab_size'],
        tgt_vocab_size=config['vocab_size'],
        d_model=config['d_model'],
        num_heads=config['num_heads'],
        d_ff=config['d_ff'],
        num_layers=config['num_layers'],
        dropout=config['dropout'],
        activation=config['activation']
    ).to(device)

model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Single example test
sample_article = cnn_dailymail['test'][0]['article']
reference = cnn_dailymail['test'][0]['highlights']

print("\nSample Article (first 200 chars):")
print(sample_article[:200] + "...")

print("\nReference Summary:")
print(reference)

print("\nGenerated Summary:")
summary = generate_with_model(model, tokenizer, sample_article)
print(summary)

# Calculate ROUGE for single example
rouge = Rouge()
scores = rouge.get_scores(summary, reference)[0]
print("\nROUGE Scores for sample:")
print(f"ROUGE-1: {scores['rouge-1']['f']:.4f}")
print(f"ROUGE-2: {scores['rouge-2']['f']:.4f}")
print(f"ROUGE-L: {scores['rouge-l']['f']:.4f}")

# Full evaluation
print("\n=== Running full evaluation ===")
evaluate_rouge_scores(model, tokenizer, cnn_dailymail['test'], num_examples=50)

Loading best model...

Sample Article (first 200 chars):
(CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territor...

Reference Summary:
Membership gives the ICC jurisdiction over alleged crimes committed in Palestinian territories since last June .
Israel and the United States opposed the move, which could open the door to war crimes investigations against Israelis .

Generated Summary:
Palestinian Authority officially becomes a 12 3rd member of the International Criminal Court . Palestinians signed Rome Stat ute in January , when they also accepted its jurisdiction over alleged crimes . Palestinian Foreign Minister Netanyahu says . ad al - M alk i ' s efforts to join the body . The Netherlands ' efforts to join the body . Palestinians may be.

ROUGE Scores for sample:
ROUGE-1: 0.2000
ROUGE-2: 0.0690
ROUGE-L: 0.1750

=== Running full evaluatio

  2%|▏         | 1/50 [00:01<00:51,  1.04s/it]


Example 1:
Article (truncated): (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territor...
Reference: Membership gives the ICC jurisdiction over alleged crimes committed in Palestinian territories since last June .
Israel and the United States opposed the move, which could open the door to war crimes investigations against Israelis .
Generated: NEW : Palestinian Authority officially a State Party member of Palestine member of the International Criminal Court . Palestinians may be subject to counter - charges as well . Palestinians may be subject to counter - charges as well . Palestinian Foreign Minister Ri ad al - M alk i.


 22%|██▏       | 11/50 [00:11<00:40,  1.03s/it]


Example 11:
Article (truncated): London (CNN)A 19-year-old man was charged Wednesday with terror offenses after he was arrested as he returned to Britain from Turkey, London's Metropolitan Police said. Yahya Rashid, a UK national fro...
Reference: London's Metropolitan Police say the man was arrested at Luton airport after landing on a flight from Istanbul .
He's been charged with terror offenses allegedly committed since the start of November .
Generated: Yah ya Rashid , 19 , is due to appear in Westminster Magistrates ' Court on Wednesday . He was arrested at Westminster Magistrates ' Court on Wednesday . Rashid says . is due to appear in court on Wednesday . Rashid is due to appear in Westminster Magistrates ' Court on Wednesday . He is due to appear in Westminster Magistrates ' Court.


 42%|████▏     | 21/50 [00:21<00:30,  1.05s/it]


Example 21:
Article (truncated): Norfolk, Virginia (CNN)The second mate of the Houston Express probably couldn't believe what he was seeing. Hundreds of miles from land there was a small boat nearby. At first it looked abandoned. It ...
Reference: Father: "I know he went through what he went through"
Louis Jordan was found on his sailboat, which was listing and in bad shape, rescuer says .
He appears to be in good shape, physically and mentally .
Generated: The 1 , " He " looked good condition " It took so long " so he couldn ' d been drifting on the 35 - foot - foot capsized ' drifting on the 35 - foot Pearson sail boat for more than two months ' The weather wouldn ' father : " He looked good . Had n ' t lost too much.


 62%|██████▏   | 31/50 [00:32<00:21,  1.12s/it]


Example 31:
Article (truncated): (CNN)Police in the Indian city of Malegaon, in the western state of Maharashtra, are requiring identity cards for an unusual group of residents: Cattle. Following a recent state-wide ban on the sale a...
Reference: Authorities in the Indian city of Malegaon have asked residents to take a 'mugshot' of their cattle .
Cows are revered by the majority Hindu population, and many parts of the country have laws banning the slaughter of cattle .
Officials in Malegaon believe this is the best way to solve cow slaughter cases and enforce the law .
Generated: Mah ar ash tra are requiring identity cards for an unusual group of residents . Mah ar ash tra is the only way to solve cow slaughter cases and enforce the law . C ows are considered holy and revered by that state ' s majority Hindu population . Ban on the government doesn ' t have a right to interfere in an individual.


 82%|████████▏ | 41/50 [00:46<00:12,  1.43s/it]


Example 41:
Article (truncated): (CNN)A high temperature of 63.5 degrees Fahrenheit might sound like a pleasant day in early spring -- unless you're in Antarctica. The chilly continent recorded the temperature (15.5 degrees Celsius) ...
Reference: High temperatures are recorded on the northern tip of the Antarctica Peninsula .
The World Meteorological Organization will make the final determination .
Generated: The temperature was recorded at Argentina ' s Es per anza Base on the northern tip of the Antarctica Peninsula . The World Meteorological Organization , a specialized United Nations agency , is in the process of setting up an ad - hoc committee of about 10 blue - ribbon clim at ologists . The committee will examine the equipment used to measure the.


100%|██████████| 50/50 [00:59<00:00,  1.19s/it]


Calculating ROUGE scores...

ROUGE Scores:
ROUGE-1: 0.2315
ROUGE-2: 0.0716
ROUGE-L: 0.2192

Length Statistics:
Average Reference Length: 34.98 words
Average Summary Length: 62.76 words
Results saved to model_evaluation_results.csv





{'rouge-1': {'r': 0.2973948197047714,
  'p': 0.19780151505118926,
  'f': 0.23153300918980477},
 'rouge-2': {'r': 0.10313627817083003,
  'p': 0.057368809503389996,
  'f': 0.0716117097415556},
 'rouge-l': {'r': 0.28218548934238874,
  'p': 0.18721430976169418,
  'f': 0.21923602877855036}}

## Approach 2: Custom Beam Search Implementation

### Design Philosophy
The custom beam search algorithm provides fine-grained control over the generation process through several specialized mechanisms designed specifically for summarization. This approach prioritizes quality and readability over implementation simplicity.

### Key Implementation Features

- **Dynamic Length Control**: Adjusts generation length based on reference summary length, with configurable penalties to favor concise outputs
  
- **N-gram Repetition Prevention**: Actively blocks 3-gram repetitions that commonly occur in transformer outputs, dramatically improving readability

- **Sophisticated Beam Scoring**: Implements length normalization and adaptive scoring to balance between fluency and conciseness

- **Comprehensive Post-Processing Pipeline**: Includes sentence deduplication, proper ending punctuation, and removal of common generation artifacts

### Performance Analysis
The custom approach achieves superior results across all ROUGE metrics:
- ROUGE-1: 0.2483
- ROUGE-2: 0.0820
- ROUGE-L: 0.2310

Most notably, the generated summaries maintain lengths (average 42.62 words) much closer to reference summaries (34.46 words), contributing to higher precision scores.

In [14]:
# Improved Generation and Evaluation with Length Control
import nltk
from rouge_score import rouge_scorer
from rouge import Rouge
import pandas as pd
import re
from collections import defaultdict, Counter

def generate_summary(model, tokenizer, article, max_length=64, beam_size=5, 
                     min_length=20, length_penalty=0.6, 
                     top_k=0, temperature=0.5, early_stopping=True):

    # Prepare input
    encoding = tokenizer.encode(article, max_length=MAX_INPUT_LENGTH, padding='max_length', truncation=True)
    input_ids = torch.tensor(encoding['input_ids']).unsqueeze(0).to(device)
    
    # Generate with beam search
    with torch.no_grad():
        # Initialize generation parameters
        batch_size = input_ids.size(0)
        src_mask = model.create_src_mask(input_ids)
        encoder_output = model.encoder(input_ids, src_mask)
        
        # Initialize beams with SOS token
        beams = [{'tokens': [model.sos_token_id], 
                 'score': 0.0, 
                 'finished': False} for _ in range(beam_size)]
        
        # Generate up to max_length tokens
        for step in range(max_length):
            new_beams = []
            
            # Process each active beam
            for beam in beams:
                if beam['finished']:
                    new_beams.append(beam)
                    continue
                
                # Convert beam tokens to tensor
                curr_ids = torch.tensor([beam['tokens']], dtype=torch.long, device=device)
                
                # Forward pass through decoder
                tgt_mask = model.create_tgt_mask(curr_ids)
                decoder_output = model.decoder(curr_ids, encoder_output, src_mask, tgt_mask)
                logits = model.final_layer(decoder_output[:, -1])
                
                # Apply temperature
                logits = logits / temperature
                
                # Get top k tokens
                if top_k > 0:
                    topk_logits, topk_indices = torch.topk(logits, k=top_k)
                    logits = torch.full_like(logits, float('-inf'))
                    logits.scatter_(1, topk_indices, topk_logits)
                
                # Convert to probabilities
                probs = F.softmax(logits, dim=-1)
                
                # Get top beam_size candidates
                topk_probs, topk_ids = torch.topk(probs, k=beam_size * 2)
                
                # Create new candidates
                for i in range(topk_ids.size(1)):
                    token_id = topk_ids[0, i].item()
                    token_prob = topk_probs[0, i].item()
                    
                    # Skip tokens that would create 3-gram repetitions
                    if len(beam['tokens']) >= 3:
                        last_3gram = beam['tokens'][-2:]
                        skip = False
                        for j in range(len(beam['tokens']) - 2):
                            if beam['tokens'][j:j+2] == last_3gram and beam['tokens'][j+2] == token_id:
                                skip = True
                                break
                        if skip:
                            continue
                    
                    # Calculate new score
                    # Length normalization: (5 + len)^length_penalty / (5^length_penalty)
                    new_len = len(beam['tokens']) + 1
                    length_norm = ((5 + new_len) ** length_penalty) / (5 ** length_penalty)
                    new_score = beam['score'] - math.log(token_prob) / length_norm
                    
                    new_tokens = beam['tokens'] + [token_id]
                    is_finished = token_id == model.eos_token_id or new_len >= max_length
                    
                    # Only add EOS if we're past min_length
                    if token_id == model.eos_token_id and new_len < min_length:
                        continue
                        
                    new_beams.append({
                        'tokens': new_tokens,
                        'score': new_score,
                        'finished': is_finished
                    })
            
            # Keep top beam_size beams
            new_beams = sorted(new_beams, key=lambda x: x['score'])[:beam_size]
            beams = new_beams
            
            # Early stopping: if all beams are finished
            if early_stopping and all(beam['finished'] for beam in beams):
                break
        
        # Select best beam
        best_beam = min(beams, key=lambda x: x['score'])
        generated_ids = best_beam['tokens']
        
        # Remove SOS and EOS tokens for decoding
        if generated_ids[0] == model.sos_token_id:
            generated_ids = generated_ids[1:]
        if generated_ids and generated_ids[-1] == model.eos_token_id:
            generated_ids = generated_ids[:-1]
    
    # Decode and post-process
    generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
    
    # Post-processing
    generated_text = post_process_summary(generated_text)
    
    return generated_text

def post_process_summary(text):
    """Improve summary quality through post-processing"""
    # Fix spacing issues
    text = re.sub(r'\s+', ' ', text).strip()
    
    # Remove duplicate sentences
    sentences = re.split(r'(?<=[.!?])\s+', text)
    unique_sentences = []
    seen = set()
    for sent in sentences:
        sent_lower = sent.lower()
        # Skip empty, very short, or duplicate sentences
        if sent and len(sent) > 10 and sent_lower not in seen:
            unique_sentences.append(sent)
            seen.add(sent_lower)
    
    # Join sentences
    text = ' '.join(unique_sentences)
    
    # Make sure text ends with proper punctuation
    if text and not text[-1] in '.!?':
        text += '.'
    
    # Remove redundant references to "NEW:"
    text = text.replace('NEW : ', '').replace('NEW: ', '')
    
    return text

def evaluate_model(model, test_dataset, tokenizer, num_examples=100):

    # Load the ROUGE scorer
    rouge = Rouge()
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    
    # Prepare lists to store results
    references = []
    summaries = []
    
    # Track additional metrics
    avg_reference_length = 0
    avg_summary_length = 0
    
    # Make sure the model is in evaluation mode
    model.eval()
    
    # Evaluate on a subset of test data
    test_subset = test_dataset.select(range(min(num_examples, len(test_dataset))))
    
    # Generate summaries and calculate ROUGE scores
    print(f"Generating summaries for {len(test_subset)} examples...")
    for i, example in enumerate(tqdm(test_subset)):
        article = example['article']
        reference = example['highlights']
        
        # Get reference length to guide generation length
        ref_length = len(reference.split())
        target_length = min(ref_length + 10, 60)  # Target slightly longer than reference but capped
        
        generated_summary = generate_summary(
            model, tokenizer, article, 
            max_length=target_length,
            min_length=min(20, target_length-5),
            beam_size=5,
            length_penalty=0.7,  # Favor shorter summaries
            top_k=50,
            temperature=0.5,     # Lower temperature for more focused generation
            early_stopping=True
        )
        
        references.append(reference)
        summaries.append(generated_summary)
        
        # Calculate additional metrics
        avg_reference_length += len(reference.split())
        avg_summary_length += len(generated_summary.split())
        
        # Print example every 10 items
        if i % 10 == 0:
            print(f"\nExample {i+1}:")
            print(f"Article (truncated): {article[:200]}...")
            print(f"Reference: {reference}")
            print(f"Generated: {generated_summary}")
    
    # Calculate average metrics
    avg_reference_length /= len(test_subset)
    avg_summary_length /= len(test_subset)
    
    # Calculate ROUGE scores
    try:
        # Try using the Rouge library first (faster for multiple examples)
        rouge_scores = rouge.get_scores(summaries, references, avg=True)
        
        print("\nROUGE Scores:")
        print(f"ROUGE-1: {rouge_scores['rouge-1']['f']:.4f}")
        print(f"ROUGE-2: {rouge_scores['rouge-2']['f']:.4f}")
        print(f"ROUGE-L: {rouge_scores['rouge-l']['f']:.4f}")
        
    except Exception as e:
        print(f"Error with Rouge library: {e}")
        print("Falling back to rouge_scorer...")
        
        # Use the rouge_scorer library as fallback
        rouge1_scores = []
        rouge2_scores = []
        rougeL_scores = []
        
        for ref, hyp in zip(references, summaries):
            try:
                scores = scorer.score(ref, hyp)
                rouge1_scores.append(scores['rouge1'].fmeasure)
                rouge2_scores.append(scores['rouge2'].fmeasure)
                rougeL_scores.append(scores['rougeL'].fmeasure)
            except Exception as e:
                print(f"Error scoring example: {e}")
        
        # Calculate average scores
        avg_rouge1 = sum(rouge1_scores) / len(rouge1_scores) if rouge1_scores else 0
        avg_rouge2 = sum(rouge2_scores) / len(rouge2_scores) if rouge2_scores else 0
        avg_rougeL = sum(rougeL_scores) / len(rougeL_scores) if rougeL_scores else 0
        
        print("\nROUGE Scores:")
        print(f"ROUGE-1: {avg_rouge1:.4f}")
        print(f"ROUGE-2: {avg_rouge2:.4f}")
        print(f"ROUGE-L: {avg_rougeL:.4f}")
    
    # Print additional metrics
    print("\nAdditional Metrics:")
    print(f"Average Reference Length: {avg_reference_length:.2f} words")
    print(f"Average Summary Length: {avg_summary_length:.2f} words")
    
    # Save detailed results to CSV
    results_df = pd.DataFrame({
        'article': [example['article'][:500] + '...' for example in test_subset],
        'reference': references,
        'generated': summaries,
        'summary_length': [len(s.split()) for s in summaries],
        'reference_length': [len(r.split()) for r in references]
    })
    
    # Save results
    results_df.to_csv('evaluation_results.csv', index=False)
    print("Results saved to evaluation_results.csv")

# Load best model
try:
    print("Loading best model for evaluation...")
    best_model_path = r"C:\Users\nisha\Downloads\best_model (1).pt"
    checkpoint = torch.load(best_model_path, map_location=device)
    
    # Check if we need to recreate the model from config
    if 'model_config' in checkpoint:
        config = checkpoint['model_config']
        print(f"Recreating model from saved config: {config}")
        
        # Create model with saved config
        model = ImprovedTransformer(
            src_vocab_size=config['vocab_size'],
            tgt_vocab_size=config['vocab_size'],
            d_model=config['d_model'],
            num_heads=config['num_heads'],
            d_ff=config['d_ff'],
            num_layers=config['num_layers'],
            dropout=config['dropout'],
            activation=config['activation']
        ).to(device)
    
    # Load state dict
    model.load_state_dict(checkpoint['model_state_dict'])
    print("Model loaded successfully!")
    
    # Evaluate on test set
    evaluate_model(model, cnn_dailymail['test'], tokenizer, num_examples=100)
    
except Exception as e:
    print(f"Error during evaluation: {e}")
    import traceback
    traceback.print_exc()

Loading best model for evaluation...
Recreating model from saved config: {'vocab_size': 32000, 'd_model': 768, 'num_heads': 12, 'd_ff': 3072, 'num_layers': 6, 'dropout': 0.1, 'activation': 'gelu'}


INFO:absl:Using default tokenizer.


Model loaded successfully!
Generating summaries for 100 examples...


  1%|          | 1/100 [00:09<15:38,  9.48s/it]


Example 1:
Article (truncated): (CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territor...
Reference: Membership gives the ICC jurisdiction over alleged crimes committed in Palestinian territories since last June .
Israel and the United States opposed the move, which could open the door to war crimes investigations against Israelis .
Generated: The ICC officially becomes the 12 3rd member of the International Criminal Court . The formal access ion was marked with a ceremony at The Hague , in the Netherlands . Palestinians may be subject to counter - charges as well . Palestinian Foreign Minister.


 11%|█         | 11/100 [00:41<04:17,  2.90s/it]


Example 11:
Article (truncated): London (CNN)A 19-year-old man was charged Wednesday with terror offenses after he was arrested as he returned to Britain from Turkey, London's Metropolitan Police said. Yahya Rashid, a UK national fro...
Reference: London's Metropolitan Police say the man was arrested at Luton airport after landing on a flight from Istanbul .
He's been charged with terror offenses allegedly committed since the start of November .
Generated: Yah ya Rashid , 19 , was arrested at Luton airport on Tuesday . He is due to appear in Westminster Magistrates ' Court on Wednesday . Rashid is due in court on Wednesday , police say . He ' s been.


 21%|██        | 21/100 [01:12<04:08,  3.14s/it]


Example 21:
Article (truncated): Norfolk, Virginia (CNN)The second mate of the Houston Express probably couldn't believe what he was seeing. Hundreds of miles from land there was a small boat nearby. At first it looked abandoned. It ...
Reference: Father: "I know he went through what he went through"
Louis Jordan was found on his sailboat, which was listing and in bad shape, rescuer says .
He appears to be in good shape, physically and mentally .
Generated: Man in bad shape , listing to one side , saw there was a boat wrecked . He ' d been drifting on the 35 - foot Pearson sail boat for more than two months . His father says he was expecting his son to look different.


 31%|███       | 31/100 [01:42<03:27,  3.01s/it]


Example 31:
Article (truncated): (CNN)Police in the Indian city of Malegaon, in the western state of Maharashtra, are requiring identity cards for an unusual group of residents: Cattle. Following a recent state-wide ban on the sale a...
Reference: Authorities in the Indian city of Malegaon have asked residents to take a 'mugshot' of their cattle .
Cows are revered by the majority Hindu population, and many parts of the country have laws banning the slaughter of cattle .
Officials in Malegaon believe this is the best way to solve cow slaughter cases and enforce the law .
Generated: Indian police ask residents to take mug shot of their cattle and submit it to the police . C ows are holy and revered by that state ' s majority Hindu population . The ban on the sale and consumption of beef is still per missible . The slaughter of buff alo es is still a concern for the.


 41%|████      | 41/100 [02:12<02:46,  2.82s/it]


Example 41:
Article (truncated): (CNN)A high temperature of 63.5 degrees Fahrenheit might sound like a pleasant day in early spring -- unless you're in Antarctica. The chilly continent recorded the temperature (15.5 degrees Celsius) ...
Reference: High temperatures are recorded on the northern tip of the Antarctica Peninsula .
The World Meteorological Organization will make the final determination .
Generated: The temperature was recorded at Argentina ' s Es per anza Base on the northern tip of the Antarctica Peninsula . The agency is in the process of setting up an ad.


 51%|█████     | 51/100 [02:40<02:13,  2.72s/it]


Example 51:
Article (truncated): (CNN)According to an outside review by Columbia Journalism School professors, "(a)n institutional failure at Rolling Stone resulted in a deeply flawed article about a purported gang rape at the Univer...
Reference: An outside review found that a Rolling Stone article about campus rape was "deeply flawed"
Danny Cevallos says that there are obstacles to a successful libel case, should one be filed .
Generated: Columbia Journal ism School professors say the university ' s " failure " was a n institutional failure . The university says it ' s a " governmental entity " to eliminate U VA . The publication of the article is.


 61%|██████    | 61/100 [03:14<02:13,  3.42s/it]


Example 61:
Article (truncated): Hong Kong (CNN)Six people were hurt after an explosion at a controversial chemical plant in China's southeastern Fujian province sparked a huge fire, provincial authorities told state media. The plant...
Reference: A blast rocks a chemical plant in China's southeastern Fujian province for the second time in two years .
Six were injured after the explosion and are being hospitalized .
The explosion was triggered by an oil leak, though local media has not reported any toxic chemical spills .
Generated: The blast occurred at an oil storage facility Monday night . Residents living close to the plant had heard the explosion . The plant was hit by another explosion in July 2013 . The explosion sparked a huge fire in China ' s southeastern Fuj ian province . The Zhang zhou plant produces par ax yl ene (.


 71%|███████   | 71/100 [03:42<01:18,  2.71s/it]


Example 71:
Article (truncated): (CNN)A nuclear submarine being repaired at a Russian shipyard has caught on fire, according to a law enforcement source speaking to Russia's state-run news agency ITAR-Tass. "The submarine is in a dry...
Reference: Submarine is in Zvyozdochka shipyard, in northwestern Russia .
No "dangerous" substances on the submarine, shipyard spokesman told ITAR-Tass .
Generated: A nuclear submarine is being repaired at a Russian ship yard . The sub is being used as we lding work . The fire began on a sub in.


 81%|████████  | 81/100 [04:13<01:01,  3.25s/it]


Example 81:
Article (truncated): Cedar Falls, Iowa (CNN)As aides politely tried to rush Ted Cruz from an event in Cedar Falls to one in Cedar Rapids, Iowa, on Thursday, the presidential candidate continued shaking hands with anyone w...
Reference: Ted Cruz has built a brand as a stalwart conservative on fiscal issues .
But he's also eager to champion social issues at a time when many Republicans are eager to avoid them .
Cruz says the GOP needs to unite young libertarian-minded voters and evangelicals .
Generated: Ted Cruz drew crowds during his two - day swing across the state . The Iowa senator regularly avoids using a podium . He ' s the only official contender in the race , and is working a room . Cruz has built a brand as a stal wart conservative willing to buck GOP leadership .


 91%|█████████ | 91/100 [04:51<00:31,  3.53s/it]


Example 91:
Article (truncated): (CNN)They're not gonna take it anymore. Really. Twisted Sister says that its 2016 tour will be its last, according to a press release. Next year marks the band's 40th anniversary, and to celebrate, th...
Reference: Twisted Sister's 2016 tour will be its last .
Band will celebrate 40 years in 2016 .
Twisted Sister drummer A.J. Pero died in March .
Generated: Tw isted Sister says 2016 tour will be its last , according to a press release . The tour is titled " Forty and F * ck It " The band will play with a.


100%|██████████| 100/100 [05:22<00:00,  3.23s/it]


ROUGE Scores:
ROUGE-1: 0.2483
ROUGE-2: 0.0820
ROUGE-L: 0.2310

Additional Metrics:
Average Reference Length: 34.46 words
Average Summary Length: 42.62 words
Results saved to evaluation_results.csv



