# NLP Project: Extractive Summarization with BERT + BiLSTM + Attention

## Project Goal
- Build an extractive summarization model for messy study notes, lecture transcripts, and textbook excerpts.
- **Core Model**: Frozen BERT + BiLSTM + Attention.
- **Datasets**: Webis-TLDR-17 + WikiHow.
- **Improvements**: 
    1. Pre-trained BERT Embeddings (Better Semantics)
    2. Trigram Blocking (Redundancy Removal)
    3. Validation Split & Scheduler (Better Training)

## Methodology
1. **Data Loading**: Load and combine datasets.
2. **Preprocessing**: Convert abstractive summaries to extractive labels (Oracle extraction).
3. **Tokenization**: Use BERT tokenizer.
4. **Model**: BERT -> BiLSTM -> Attention -> Sentence Importance Score.
5. **Training**: Binary Cross Entropy Loss with Validation Checkpointing.

## 1. Install Dependencies
Note: The second command installs the CUDA 12.1 build of PyTorch for GPU support on Windows/Linux. Adjust the index URL if your CUDA version differs.

In [12]:
%pip install datasets transformers rouge-score nltk
%pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

Note: you may need to restart the kernel to use updated packages.





Looking in indexes: https://download.pytorch.org/whl/cu121


## 2. Imports & Setup

In [13]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from datasets import load_dataset, concatenate_datasets
from transformers import BertTokenizer, BertModel
from rouge_score import rouge_scorer
import nltk
import numpy as np
import pandas as pd
from tqdm import tqdm
import re
import os

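# Download NLTK sentence-tokenizer data.
# Each resource is checked on its own, so 'punkt_tab' is still fetched when 'punkt' is already installed.
for resource in ('punkt', 'punkt_tab'):
    try:
        nltk.data.find(f'tokenizers/{resource}')
    except LookupError:
        nltk.download(resource)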

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")

Using device: cuda
GPU: NVIDIA GeForce RTX 3060 Laptop GPU


## 3. Load Datasets
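We load Webis-TLDR-17 and WikiHow and map both to uniform columns: `text` and `summary`.
Note: If WikiHow is unavailable, we fall back to CNN/DailyMail. In the run recorded below, both Hub loads failed, so the combined dataset consists entirely of CNN/DailyMail articles.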

In [14]:
# Configuration
MAX_SAMPLES = 6000 # Increased slightly for better training

# 1. Load Webis-TLDR-17
print("Loading Webis-TLDR-17...")
try:
    dataset_webis = load_dataset("webis/tldr-17", split="train[:5%]")
except Exception as e:
    print(f"Webis dataset load failed: {e}. Proceeding without it.")
    dataset_webis = None

# 2. Load WikiHow (with Fallback)
print("Loading WikiHow...")
try:
    dataset_wikihow = load_dataset("wikihow", "all", split="train[:10%]")
except Exception as e:
    print(f"WikiHow load failed: {e}. Switching to CNN/DailyMail as fallback.")
    dataset_wikihow = load_dataset("cnn_dailymail", "3.0.0", split="train[:10%]")

# 3. Unify Columns
def unify_columns(dataset, text_col, summary_col):
    cols_to_keep = [text_col, summary_col]
    dataset = dataset.remove_columns([c for c in dataset.column_names if c not in cols_to_keep])
    dataset = dataset.rename_column(text_col, "text")
    dataset = dataset.rename_column(summary_col, "summary")
    return dataset

if dataset_webis is not None:
    if 'content' in dataset_webis.column_names:
        dataset_webis = unify_columns(dataset_webis, 'content', 'summary')
    elif 'body' in dataset_webis.column_names:
        dataset_webis = unify_columns(dataset_webis, 'body', 'summary')

# Check for WikiHow or Fallback columns
if 'headline' in dataset_wikihow.column_names: 
    dataset_wikihow = unify_columns(dataset_wikihow, 'text', 'headline')
elif 'highlights' in dataset_wikihow.column_names:
    dataset_wikihow = unify_columns(dataset_wikihow, 'article', 'highlights')

# 4. Concatenate
if dataset_webis is not None:
    combined_dataset = concatenate_datasets([dataset_webis, dataset_wikihow])
else:
    combined_dataset = dataset_wikihow

# Shuffle and limit
combined_dataset = combined_dataset.shuffle(seed=42).select(range(min(len(combined_dataset), MAX_SAMPLES)))

print(f"Combined Dataset Size: {len(combined_dataset)}")

Loading Webis-TLDR-17...
Webis dataset load failed: Dataset scripts are no longer supported, but found tldr-17.py. Proceeding without it.
Loading WikiHow...
WikiHow load failed: Dataset 'wikihow' doesn't exist on the Hub or cannot be accessed.. Switching to CNN/DailyMail as fallback.
Combined Dataset Size: 6000


## 4. Preprocessing: Oracle Label Generation
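We derive sentence-level labels (0/1) from the abstractive summary: a sentence is labeled 1 when its ROUGE-1 or ROUGE-L recall against the summary exceeds 0.15. Despite the function name `greedy_extraction`, this is a per-sentence threshold rule, not a greedy search over sentence subsets.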

In [15]:
scorer = rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True)

def greedy_extraction(text, summary):
    # Split sentences
    sentences = nltk.sent_tokenize(text)
    
    # Strip whitespace and drop sentences of three words or fewer (no lowercasing happens here)
    sentences = [s.strip() for s in sentences if len(s.split()) > 3] 
    
    if not sentences:
        return [], []

    labels = []
    for sent in sentences:
        scores = scorer.score(summary, sent)
        r1_rec = scores['rouge1'].recall
        rL_rec = scores['rougeL'].recall
        
        # Label 1 if sentence overlaps significantly with summary
        if r1_rec > 0.15 or rL_rec > 0.15: # Slightly stricter threshold for better quality
            labels.append(1)
        else:
            labels.append(0)
            
    return sentences, labels
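
To make the labeling rule concrete, here is a minimal sketch that swaps in a plain unigram-recall function for `rouge_scorer`. It assumes whitespace tokenization, no stemming, and no ROUGE-L term, so its numbers will not match the library exactly. The helper names `unigram_recall` and `label_sentences` are hypothetical and exist only for illustration.

```python
from collections import Counter

def unigram_recall(reference, candidate):
    """Fraction of reference tokens that also appear in the candidate (clipped counts)."""
    ref_counts = Counter(reference.lower().split())
    cand_counts = Counter(candidate.lower().split())
    if not ref_counts:
        return 0.0
    overlap = sum(min(count, cand_counts[tok]) for tok, count in ref_counts.items())
    return overlap / sum(ref_counts.values())

def label_sentences(sentences, summary, threshold=0.15):
    """Label a sentence 1 when it covers more than `threshold` of the summary tokens."""
    return [1 if unigram_recall(summary, s) > threshold else 0 for s in sentences]

summary = "the reactor was shut down after a small gas leak"
sentences = [
    "Operators shut down the reactor after a leak was detected",  # covers 8 of 10 summary tokens
    "The plant employs about two thousand people",                # covers 1 of 10 summary tokens
]
print(label_sentences(sentences, summary))  # [1, 0]
```

Because recall is measured against the summary, long sentences clear a 0.15 threshold easily, which is one reason the positive class can end up larger than a true oracle selection would produce.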

## 5. Dataset Class & Tokenizer
We tokenize each sentence separately with `BertTokenizer` (`bert-base-uncased`), padding or truncating to 64 tokens.

In [16]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

class ExtractiveDataset(Dataset):
    def __init__(self, hf_dataset, tokenizer, max_len=64):
        self.samples = []
        self.tokenizer = tokenizer
        self.max_len = max_len
        
        print("Preprocessing data...")
        for item in tqdm(hf_dataset):
            text = item['text']
            summary = item['summary']
            
            sents, labels = greedy_extraction(text, summary)
            for sent, label in zip(sents, labels):
                self.samples.append((sent, label))
            
        print(f"Processed {len(self.samples)} sentences.")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sentence, label = self.samples[idx]
        
        encoding = self.tokenizer(
            sentence,
            truncation=True,
            padding='max_length',
            max_length=self.max_len,
            return_tensors='pt'
        )
        
        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.float)
        }

# Create Dataset and Split
full_dataset = ExtractiveDataset(combined_dataset, tokenizer, max_len=64)

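# 90/10 Train/Val Split (at sentence level, so sentences from one article can land in both splits)
train_size = int(0.9 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(
    full_dataset, [train_size, val_size],
    generator=torch.Generator().manual_seed(42)  # fixed seed so the split is reproducible
)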

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

print(f"Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")

Preprocessing data...


100%|██████████| 6000/6000 [01:52<00:00, 53.53it/s]

Processed 178970 sentences.
Train samples: 161073, Val samples: 17897
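
The effect of `padding='max_length'` and `truncation=True` can be illustrated without the tokenizer itself. The sketch below works on token ids that are assumed to be already produced by word-piece tokenization. `pad_or_truncate` is a hypothetical helper, the example ids are arbitrary, and the special-token ids (101 for `[CLS]`, 102 for `[SEP]`, 0 for `[PAD]`) are those of `bert-base-uncased`.

```python
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0  # special-token ids in bert-base-uncased

def pad_or_truncate(token_ids, max_len):
    """Return (input_ids, attention_mask), both exactly max_len long."""
    body = token_ids[: max_len - 2]          # leave room for [CLS] and [SEP]
    input_ids = [CLS_ID] + body + [SEP_ID]
    attention_mask = [1] * len(input_ids)    # 1 marks real tokens
    padding = max_len - len(input_ids)
    input_ids += [PAD_ID] * padding
    attention_mask += [0] * padding          # 0 marks padding
    return input_ids, attention_mask

ids, mask = pad_or_truncate([7592, 2088], max_len=6)
print(ids)   # [101, 7592, 2088, 102, 0, 0]
print(mask)  # [1, 1, 1, 1, 0, 0]
```

The attention mask is what later lets the model ignore padded positions when it computes attention weights.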





## 6. Improved Model: BERT + BiLSTM + Attention
We replace the basic embedding layer with **BERT** to capture deep context. We freeze BERT to save compute/memory, treating it as a feature extractor.

In [17]:
class BiLSTMBertAttention(nn.Module):
    def __init__(self, hidden_dim):
        super(BiLSTMBertAttention, self).__init__()
        # 1. Load Pre-trained BERT
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        
        # Freeze BERT parameters
        for param in self.bert.parameters():
            param.requires_grad = False
            
        bert_dim = 768 # bert-base output dimension
        
        # 2. BiLSTM
        self.lstm = nn.LSTM(bert_dim, hidden_dim, batch_first=True, bidirectional=True)
        
        # 3. Attention Weights
        self.W_w = nn.Linear(hidden_dim * 2, hidden_dim * 2)
        self.u_w = nn.Linear(hidden_dim * 2, 1, bias=False)
        
        # 4. Classification Head
        self.fc = nn.Linear(hidden_dim * 2, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids, attention_mask):
        # Get BERT features | [Batch, SeqLen, 768]
        with torch.no_grad():
            bert_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
            embeddings = bert_outputs.last_hidden_state
        
        # LSTM | [Batch, SeqLen, Hidden*2]
        lstm_out, _ = self.lstm(embeddings)
        
        # Attention
        # u = tanh(W * h)
        u = torch.tanh(self.W_w(lstm_out))
        # score = u * u_w
        scores = self.u_w(u) # [Batch, SeqLen, 1]
        
        # Masking
        mask = attention_mask.unsqueeze(-1)
        scores = scores.masked_fill(mask == 0, -1e9)
            
        alpha = torch.softmax(scores, dim=1) # [Batch, SeqLen, 1]
        
        # Sentence Vector
        sentence_vector = torch.sum(lstm_out * alpha, dim=1) # [Batch, Hidden*2]
        
        # Prediction
        prediction = self.sigmoid(self.fc(sentence_vector))
        
        # Squeeze only the last dimension so a batch of one keeps its batch axis
        return prediction.squeeze(-1)
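
The masking and pooling steps in `forward` can be traced by hand on a single sentence. The following is a minimal sketch using plain Python lists instead of tensors. `masked_attention_pool` is a hypothetical helper, and the hidden states and scores are made-up numbers chosen so the arithmetic is easy to follow.

```python
import math

def masked_attention_pool(hidden_states, scores, attention_mask):
    """Softmax over non-padding positions, then a weighted sum of hidden states."""
    # Padding positions get a very negative score, as masked_fill(mask == 0, -1e9) does
    masked = [s if m == 1 else -1e9 for s, m in zip(scores, attention_mask)]
    peak = max(masked)                               # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in masked]
    total = sum(exps)
    alpha = [e / total for e in exps]                # attention weights, sum to 1
    dim = len(hidden_states[0])
    pooled = [sum(a * h[d] for a, h in zip(alpha, hidden_states)) for d in range(dim)]
    return alpha, pooled

hidden = [[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]]        # last position is padding
alpha, pooled = masked_attention_pool(hidden, scores=[0.0, 0.0, 3.0], attention_mask=[1, 1, 0])
print(alpha)   # [0.5, 0.5, 0.0]
print(pooled)  # [0.5, 0.5]
```

Even though the padded position has the highest raw score, it receives zero weight, so the sentence vector depends only on real tokens.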

## 7. Improved Training Loop
- Validation Loss tracking
- ReduceLROnPlateau Scheduler
- Save Best Model Checkpoint

In [18]:
model = BiLSTMBertAttention(hidden_dim=128).to(device)

criterion = nn.BCELoss()
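# Only the BiLSTM, attention, and classifier layers are trainable; BERT is frozen
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=1e-4)
# `verbose` is omitted because it is deprecated in recent PyTorch releases
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=1)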

EPOCHS = 4
best_val_loss = float('inf')

for epoch in range(EPOCHS):
    # --- Training ---
    model.train()
    train_loss = 0
    progress = tqdm(train_loader, desc=f"Epoch {epoch+1} Train")
    
    for batch in progress:
        input_ids = batch['input_ids'].to(device)
        mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        
        optimizer.zero_grad()
        outputs = model(input_ids, mask)
        
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        progress.set_postfix(loss=loss.item())
        
    avg_train_loss = train_loss / len(train_loader)
    
    # --- Validation ---
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch['input_ids'].to(device)
            mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            
            outputs = model(input_ids, mask)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            
    avg_val_loss = val_loss / len(val_loader)
    
    # Scheduler Step
    scheduler.step(avg_val_loss)
    
    print(f"Epoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")
    
    # Checkpointing
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(model.state_dict(), "best_model_bert.pth")
        print("  -> Saved Best Model!")

Epoch 1 Train: 100%|██████████| 5034/5034 [06:46<00:00, 12.39it/s, loss=0.464]


Epoch 1 | Train Loss: 0.4190 | Val Loss: 0.4125
  -> Saved Best Model!


Epoch 2 Train: 100%|██████████| 5034/5034 [06:56<00:00, 12.09it/s, loss=0.263]


Epoch 2 | Train Loss: 0.4028 | Val Loss: 0.3974
  -> Saved Best Model!


Epoch 3 Train: 100%|██████████| 5034/5034 [07:15<00:00, 11.56it/s, loss=0.556] 


Epoch 3 | Train Loss: 0.3964 | Val Loss: 0.3971
  -> Saved Best Model!


Epoch 4 Train: 100%|██████████| 5034/5034 [07:21<00:00, 11.39it/s, loss=0.342]


Epoch 4 | Train Loss: 0.3896 | Val Loss: 0.3971
  -> Saved Best Model!
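
To see how the plateau rule interacts with the logged losses, here is a simplified re-implementation of the scheduling logic in plain Python. It is a sketch that assumes the scheduler's default relative improvement threshold of `1e-4` and ignores cooldown and minimum learning rate. `plateau_schedule` is a hypothetical helper, not part of PyTorch.

```python
def plateau_schedule(val_losses, lr=1e-4, factor=0.5, patience=1, threshold=1e-4):
    """Return the learning rate in effect after each epoch's validation loss."""
    best = float("inf")
    bad_epochs = 0
    history = []
    for loss in val_losses:
        if loss < best * (1 - threshold):   # counts as an improvement
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
        if bad_epochs > patience:           # more than `patience` epochs without improvement
            lr *= factor
            bad_epochs = 0
        history.append(lr)
    return history

# Validation losses as printed above (rounded to four decimals)
print(plateau_schedule([0.4125, 0.3974, 0.3971, 0.3971]))
# A longer plateau does trigger a halving
print(plateau_schedule([0.5, 0.4, 0.4, 0.4, 0.4]))
```

With the rounded losses from this run, the sketch never reduces the learning rate within four epochs, since at most one epoch fails to improve on the best loss.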


## 8. Trigram Blocking Evaluation
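We implement Trigram Blocking to reduce repetition in the final summary: candidates are taken in descending score order, and a sentence is skipped if it shares any trigram with a sentence already selected.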

In [19]:
def get_trigrams(sentence):
    tokens = sentence.lower().split()
    if len(tokens) < 3:
        return set()
    return set(tuple(tokens[i:i+3]) for i in range(len(tokens)-2))

def has_trigram_overlap(sent, selected_sents):
    # Check if 'sent' shares any trigram with already selected sentences
    new_trigrams = get_trigrams(sent)
    if not new_trigrams: return False
    
    for existing in selected_sents:
        existing_trigrams = get_trigrams(existing)
        if not new_trigrams.isdisjoint(existing_trigrams):
            return True
    return False

def summarize_improved(text, model, tokenizer, device, top_k=3):
    model.eval()
    sentences = nltk.sent_tokenize(text)
    cleaned_sentences = [s.strip() for s in sentences if len(s.split()) > 3]
    
    if not cleaned_sentences: return "No content"
    
    inputs = tokenizer(
        cleaned_sentences,
        return_tensors='pt',
        truncation=True,
        padding='max_length',
        max_length=64
    )
    
    with torch.no_grad():
        input_ids = inputs['input_ids'].to(device)
        mask = inputs['attention_mask'].to(device)
        scores = model(input_ids, mask)
        
    if scores.ndim == 0: scores = scores.unsqueeze(0)
    scores = scores.cpu().numpy()
    
    sorted_indices = scores.argsort()[::-1]
    
    selected_sents = []
    for idx in sorted_indices:
        candidate = cleaned_sentences[idx]
        
        # Trigram Blocking
        if not has_trigram_overlap(candidate, selected_sents):
            selected_sents.append(candidate)
            
        if len(selected_sents) >= top_k:
            break
            
    return " ".join(selected_sents)

# Load Best Model for Inference
model.load_state_dict(torch.load("best_model_bert.pth", map_location=device, weights_only=True))

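# Sample Test (this article is part of the training pool, so treat it as a smoke test, not a held-out evaluation)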
test_idx = 0
txt = combined_dataset[test_idx]['text']
ref = combined_dataset[test_idx]['summary']
pred = summarize_improved(txt, model, tokenizer, device)

print("Reference:", ref)
print("Prediction:", pred)
print("-"*30)

# Quick ROUGE eval
eval_scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
scores = eval_scorer.score(ref, pred)
print("ROUGE-1:", scores['rouge1'].fmeasure)
print("ROUGE-L:", scores['rougeL'].fmeasure)



Reference: A small amount of radioactive gas escaped from a steam generator, the NRC says .
The leak does not pose any threat to human health, an NRC spokesman says .
Operators shut down the No. 3 reactor at California's San Onofre plant as a result .
Prediction: (CNN) -- A small amount of radioactive gas escaped from a steam generator at Southern California's San Onofre nuclear power plant during a water leak, but there was no threat to public health, federal regulators said Wednesday. The water leak occurred in the thousands of tubes that carry heated water from the reactor core through the steam generator, a 65-foot-tall, 640-ton piece of equipment that boils water used to drive the unit's turbines. Though leaking tubes periodically occur in older units, Dricks said, Southern California Edison replaced the steam generators at San Onofre between 2009 and 2011.
------------------------------
ROUGE-1: 0.37241379310344824
ROUGE-L: 0.27586206896551724
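
The trigram blocking step from section 8 can be checked on a toy example that needs no model. This is a minimal sketch: the sentences are invented and assumed to be already sorted by descending score. `select_with_blocking` is a hypothetical helper that keeps one running set of seen trigrams instead of re-scanning each selected sentence, which gives the same selections.

```python
def get_trigrams(sentence):
    """Set of lowercase word trigrams; empty for sentences shorter than three words."""
    tokens = sentence.lower().split()
    return {tuple(tokens[i:i + 3]) for i in range(len(tokens) - 2)}

def select_with_blocking(ranked_sentences, top_k=2):
    """Walk sentences from highest to lowest score, skipping any that repeat a trigram."""
    selected, seen = [], set()
    for sent in ranked_sentences:
        trigrams = get_trigrams(sent)
        if trigrams & seen:
            continue                # shares a trigram with an earlier pick
        selected.append(sent)
        seen |= trigrams
        if len(selected) >= top_k:
            break
    return selected

ranked = [
    "The leak posed no threat to public health",
    "Officials said the leak posed no threat at all",  # repeats "the leak posed"
    "The reactor was shut down on Tuesday",
]
print(select_with_blocking(ranked))
```

The second sentence is skipped despite its rank, and the third is selected in its place.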
