# Assignment 2: Implementing Decoding Strategies for Summarization (40 points)

In this assignment, you will **implement decoding algorithms** used for text summarization using a pretrained Transformer model.

---

### Your Task
1. **Load** the `sshleifer/distilbart-cnn-12-6` summarization model.
2. **Implement the following decoding strategies from scratch** (no `.generate()` allowed!). Explain your implementation of each function in your own words:
   - Greedy decoding (3 points)
   - Top-k sampling (3 points)
   - Top-p (nucleus) sampling (3 points)
   - Beam search (3 points)
   - Beam search with n-gram blocking (3 points)
3. Use your decoder to summarize 200 articles from the CNN/DailyMail dataset.
4. Implement the following ROUGE metrics and evaluate your summaries using your own ROUGE metric implementation.
  - Implementation
    - ROUGE-n (2 points): e.g., ROUGE-1 & ROUGE-2
    - ROUGE-L (4 points)
  - Explanation of your implementation (2 points)
  - Discussion of how these metrics could be improved for better evaluation (2 points)
5. Discussion (15 points)

---

**Note:**
  - Regarding the decoding strategies, you are expected to work directly with model logits and sampling logic. Do not use `model.generate()` or any other pre-built decoding function. **Hint**: use `outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)`.
  - For each function, write clear and concise comments about your implementation. These need not be line-by-line; comment on meaningful chunks of code.
  - For each question, justify your answer with explanations.

### **We highly recommend using 'cpu' as the default for development, and switching to 'gpu' only for evaluating summarization performance. This is due to the limited GPU availability in the Colab environment.**


In [None]:
!pip install transformers datasets --quiet


In [None]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from datasets import load_dataset
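import torch
import json
import torch.nn.functional as F
import random

# Device used by the decoding functions below ('cpu' is the recommended default for development)
device = torch.device("cpu")  # switch to torch.device("cuda") for the full evaluation

# Load model and tokenizer
model_name = "sshleifer/distilbart-cnn-12-6"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
model.to(device)
model.eval()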


In this assignment, we use the CNN/DailyMail summarization dataset, a widely-used benchmark for training and evaluating text summarization models.

Each data sample consists of:

- article: A news story, usually between 300 and 800 words.

- highlights: A bullet-style abstractive summary of the article, written by human editors.

This dataset is ideal for testing decoding strategies because:

- The summaries are relatively short and factual

- The dataset is large enough to support diverse decoding behavior

- It has been used in many summarization papers, making ROUGE scores easy to compare

In this assignment, we will use a small subset of 200 samples from the validation set for quick experimentation and evaluation.

In [None]:
# Upload cnn_dm_200.json file

# Load 200-sample dataset
with open("cnn_dm_200.json", "r") as f:
    dataset = json.load(f)

print(f"Loaded {len(dataset)} samples.")
print("\nExample article:")
print(dataset[0]["article"])
print("\nReference summary:")
print(dataset[0]["highlights"])


Loaded 200 samples.

Example article:
(CNN)The Palestinian Authority officially became the 123rd member of the International Criminal Court on Wednesday, a step that gives the court jurisdiction over alleged crimes in Palestinian territories. The formal accession was marked with a ceremony at The Hague, in the Netherlands, where the court is based. The Palestinians signed the ICC's founding Rome Statute in January, when they also accepted its jurisdiction over alleged crimes committed "in the occupied Palestinian territory, including East Jerusalem, since June 13, 2014." Later that month, the ICC opened a preliminary examination into the situation in Palestinian territories, paving the way for possible war crimes investigations against Israelis. As members of the court, Palestinians may be subject to counter-charges as well. Israel and the United States, neither of which is an ICC member, opposed the Palestinians' efforts to join the body. But Palestinian Foreign Minister Riad al-Malki

## Decoding in Language Models
Pretrained language models like BART or GPT generate text one token at a time: at each step the model outputs a probability distribution over the next token, and a decoding strategy chooses a token from it.
However, simply choosing the most probable token at each step often leads to deterministic, repetitive, or generic outputs.

To overcome this, several decoding strategies have been proposed to balance:

- fluency vs. diversity

- coherence vs. novelty

In this assignment, you will implement five widely-used decoding strategies from scratch and compare their effects.

### Greedy Decoding

In [None]:
def greedy_decode(input_ids, max_length=128):
    """
    Perform greedy decoding from the model using logits.

    Args:
        input_ids (torch.Tensor): Tokenized input tensor of shape [1, seq_len].
        max_length (int): Maximum length of the generated sequence.

    Returns:
        str: Decoded summary text.
    """
    # TODO: Implement Greedy Decoding from scratch
    pass


In [None]:
def greedy_decode(input_ids, max_length=128):
    """
    Perform greedy decoding from the model using logits.

    Args:
        input_ids (torch.Tensor): Tokenized input tensor of shape [1, seq_len].
        max_length (int): Maximum length of the generated sequence.

    Returns:
        str: Decoded summary text.
    """
    input_ids = input_ids.to(device)
    # Start the decoder with the model's start token; shape [1, 1]
    decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]], device=device)
    eos_token_id = model.config.eos_token_id  # decoding stops once this token is produced
    with torch.no_grad():
        for _ in range(max_length):
            output = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
            # Logits for the last position only; shape [1, vocab_size].
            # Softmax is monotonic, so the argmax of the logits is also the most probable token.
            next_token_logits = output.logits[:, -1, :]
            next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)  # shape [1, 1]

            # Append the chosen token to the generated sequence
            decoder_input_ids = torch.cat([decoder_input_ids, next_token], dim=-1)

            # Stop as soon as the model generates EOS
            if next_token.item() == eos_token_id:
                break

    # Convert token ids to text; special tokens such as <s> and </s> are dropped
    # so that they do not leak into the ROUGE computation
    summary = tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True)
    return summary


Explain your implementation

### Top-k Decoding

In [None]:
def top_k_decode(input_ids, k=50, max_length=128, temperature=1.0):
    """
    Perform Top-k sampling decoding from the model.

    Args:
        input_ids (torch.Tensor): Tokenized input tensor of shape [1, seq_len].
        k (int): Number of top tokens to sample from.
        max_length (int): Maximum length of the generated sequence.
        temperature (float): Softmax temperature for sampling.

    Returns:
        str: Decoded summary text.
    """
    # TODO: Implement Top-k Sampling Decoding
    pass

In [None]:
def top_k_decode(input_ids, k=2, max_length=128, temperature=1.0):
    """
    Perform Top-k sampling decoding from the model.

    Args:
        input_ids (torch.Tensor): Tokenized input tensor of shape [1, seq_len].
        k (int): Number of top tokens to sample from.
        max_length (int): Maximum length of the generated sequence.
        temperature (float): Softmax temperature for sampling.

    Returns:
        str: Decoded summary text.
    """
    input_ids = input_ids.to(device)
    decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]], device=device)  # shape [1, 1]
    eos_token_id = model.config.eos_token_id  # decoding stops once this token is produced
    with torch.no_grad():
        for _ in range(max_length):
            output = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
            # Temperature-scaled distribution over the next token; shape [1, vocab_size]
            probs = torch.softmax(output.logits[:, -1, :] / temperature, dim=-1)

            # Keep the k most probable tokens and renormalize their probabilities
            top_probs, top_indices = torch.topk(probs, k=k, dim=-1)  # both shape [1, k]
            top_probs = top_probs / top_probs.sum(dim=-1, keepdim=True)

            # Sample a position inside the top-k set, then map it back to a vocabulary id
            local_index = torch.multinomial(top_probs, num_samples=1)  # shape [1, 1]
            next_token = top_indices.gather(-1, local_index)  # shape [1, 1]

            decoder_input_ids = torch.cat([decoder_input_ids, next_token], dim=-1)

            if next_token.item() == eos_token_id:
                break

    # Special tokens are dropped so that they do not leak into the ROUGE computation
    summary = tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True)
    return summary

        
    

- Explain your implementation.
- What is the role of each hyperparameter in this strategy, and what happens if you change its value?
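To build intuition for the second question, the effect of `k` and `temperature` can be inspected on a toy distribution before running the full model. The sketch below is pure Python and independent of the model; `toy_logits` and the helper names are hypothetical and chosen only for illustration.

```python
import math
import random

def softmax(logits, temperature=1.0):
    """Convert logits to probabilities after dividing them by the temperature."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def top_k_distribution(logits, k, temperature=1.0):
    """Return {token_id: probability} over the k most likely tokens, renormalized."""
    probs = softmax(logits, temperature)
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    mass = sum(probs[i] for i in top)
    return {i: probs[i] / mass for i in top}

def sample(distribution, rng=random):
    """Draw one token id from a {token_id: probability} mapping."""
    ids = list(distribution)
    weights = [distribution[i] for i in ids]
    return rng.choices(ids, weights=weights, k=1)[0]

# Hypothetical logits for a 4-token vocabulary (an assumption for illustration)
toy_logits = [2.0, 1.0, 0.5, -1.0]

print(top_k_distribution(toy_logits, k=1))                   # k=1 behaves like greedy decoding
print(top_k_distribution(toy_logits, k=3, temperature=0.5))  # low temperature: sharper distribution
print(top_k_distribution(toy_logits, k=3, temperature=2.0))  # high temperature: flatter distribution
```

With `k=1` only the most likely token survives, so sampling becomes deterministic. A larger `k` or a higher temperature spreads probability mass over more tokens, which tends to increase diversity at the cost of coherence.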

### Top-p Decoding

In [None]:
def top_p_decode(input_ids, p=0.9, max_length=128, temperature=1.0):
    """
    Perform Top-p (nucleus) sampling decoding from the model.

    Args:
        input_ids (torch.Tensor): Tokenized input tensor of shape [1, seq_len].
        p (float): Cumulative probability threshold for sampling.
        max_length (int): Maximum length of the generated sequence.
        temperature (float): Softmax temperature for sampling.

    Returns:
        str: Decoded summary text.
    """
    # TODO: Implement Top-p Sampling Decoding
    pass

In [None]:
def top_p_decode(input_ids, p=0.9, max_length=128, temperature=1.0):
    """
    Perform Top-p (nucleus) sampling decoding from the model.

    Args:
        input_ids (torch.Tensor): Tokenized input tensor of shape [1, seq_len].
        p (float): Cumulative probability threshold for sampling.
        max_length (int): Maximum length of the generated sequence.
        temperature (float): Softmax temperature for sampling.

    Returns:
        str: Decoded summary text.
    """
    input_ids = input_ids.to(device)
    decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]], device=device)  # shape [1, 1]
    eos_token_id = model.config.eos_token_id  # decoding stops once this token is produced
    with torch.no_grad():
        for _ in range(max_length):
            output = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
            # Temperature-scaled distribution over the next token; shape [vocab_size]
            probs = torch.softmax(output.logits[0, -1, :] / temperature, dim=-1)

            # Sort tokens by probability and accumulate their mass
            sorted_probs, sorted_indices = torch.sort(probs, descending=True)
            cumulative_probs = torch.cumsum(sorted_probs, dim=0)

            # Keep every token whose preceding cumulative mass is at most p. This keeps all
            # tokens up to and including the first one that pushes the total past p, and it
            # always keeps the most probable token, even when no cumulative value exceeds p.
            nucleus_mask = (cumulative_probs - sorted_probs) <= p
            nucleus_probs = sorted_probs[nucleus_mask]
            nucleus_indices = sorted_indices[nucleus_mask]
            nucleus_probs = nucleus_probs / nucleus_probs.sum()  # renormalize inside the nucleus

            # Sample a position inside the nucleus, then map it back to a vocabulary id
            local_index = torch.multinomial(nucleus_probs, num_samples=1)  # shape [1]
            next_token = nucleus_indices[local_index].view(1, 1)
            decoder_input_ids = torch.cat([decoder_input_ids, next_token], dim=-1)

            if next_token.item() == eos_token_id:
                break

    # Special tokens are dropped so that they do not leak into the ROUGE computation
    summary = tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True)
    return summary



- Explain your implementation.
- What is the role of each hyperparameter in this strategy, and what happens if you change its value?
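Unlike top-k, the number of candidate tokens in nucleus sampling adapts to the shape of the distribution. The following pure-Python sketch illustrates this on two hypothetical distributions (`peaked` and `flat` are made-up values, not model outputs); the helper `nucleus` is an illustrative name, not part of the assignment template.

```python
def nucleus(probs, p):
    """Return the smallest set of most-likely token ids whose cumulative
    probability reaches p, with probabilities renormalized inside the set."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= p:
            break
    mass = sum(probs[i] for i in kept)
    return {i: probs[i] / mass for i in kept}

# Hypothetical next-token distributions (assumptions for illustration)
peaked = [0.90, 0.05, 0.03, 0.02]        # the model is confident
flat = [0.30, 0.25, 0.20, 0.15, 0.10]    # the model is uncertain

print(len(nucleus(peaked, p=0.7)))  # few candidates when the distribution is peaked
print(len(nucleus(flat, p=0.7)))    # more candidates when the distribution is flat
```

Raising `p` enlarges the nucleus and admits lower-probability tokens; lowering it moves the behavior toward greedy decoding.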

### Beam Search Decoding

In [None]:
def beam_search_decode(input_ids, beam_size=4, max_length=128):
    """
    Perform beam search decoding from the model.

    Args:
        input_ids (torch.Tensor): Tokenized input tensor of shape [1, seq_len].
        beam_size (int): Number of beams to explore.
        max_length (int): Maximum length of the generated sequence.

    Returns:
        str: Decoded summary text.
    """
    # TODO: Implement Beam Search Decoding
    pass


In [None]:
def beam_search_decode(input_ids, beam_size=4, max_length=128):
    """
    Perform beam search decoding from the model.

    Args:
        input_ids (torch.Tensor): Tokenized input tensor of shape [1, seq_len].
        beam_size (int): Number of beams to explore.
        max_length (int): Maximum length of the generated sequence.

    Returns:
        str: Decoded summary text.
    """
    input_ids = input_ids.to(device)
    start_token_id = model.config.decoder_start_token_id
    eos_token_id = model.config.eos_token_id  # a beam is finished once it ends with this token

    # Each beam is (list of token ids, cumulative log-probability).
    # Log-probabilities are summed instead of multiplying probabilities to avoid underflow.
    beams = [([start_token_id], 0.0)]
    finished = []

    with torch.no_grad():
        for step in range(max_length):
            candidates = []
            for tokens, score in beams:
                decoder_input_ids = torch.tensor([tokens], device=device)
                output = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
                log_probs = F.log_softmax(output.logits[0, -1, :], dim=-1)  # shape [vocab_size]

                # Only the beam_size best continuations of a beam can survive pruning,
                # so there is no need to enumerate the whole vocabulary
                top_log_probs, top_ids = torch.topk(log_probs, k=beam_size)
                for log_prob, token_id in zip(top_log_probs.tolist(), top_ids.tolist()):
                    candidates.append((tokens + [token_id], score + log_prob))

            # Keep the beam_size best candidates; beams that end with EOS are set aside
            candidates.sort(key=lambda item: item[1], reverse=True)
            beams = []
            for tokens, score in candidates[:beam_size]:
                if tokens[-1] == eos_token_id:
                    finished.append((tokens, score))
                else:
                    beams.append((tokens, score))

            if not beams:
                break
            # Scores can only decrease as beams grow, so once a finished hypothesis is at
            # least as good as the best live beam, no live beam can overtake it
            if finished and max(item[1] for item in finished) >= beams[0][1]:
                break

    # Return the highest-scoring hypothesis (scores are not length-normalized)
    best_tokens, best_score = max(finished + beams, key=lambda item: item[1])
    summary = tokenizer.decode(best_tokens, skip_special_tokens=True)
    return summary



- Explain your implementation.
- What is the role of each hyperparameter in this strategy, and what happens if you change its value?
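The effect of `beam_size` is easiest to see on a toy problem where greedy decoding (a beam of size 1) commits to a locally best token and misses the globally best sequence. The sketch below uses a hypothetical lookup table, `TOY_MODEL`, in place of the real model; all probabilities in it are made up for illustration.

```python
import math

# Hypothetical next-token probabilities, keyed by the prefix generated so far
TOY_MODEL = {
    (): {"A": 0.6, "B": 0.4},
    ("A",): {"x": 0.5, "y": 0.5},
    ("B",): {"x": 0.9, "y": 0.1},
}

def toy_beam_search(beam_size, steps=2):
    """Return (best sequence, its log-probability) after a fixed number of steps."""
    beams = [((), 0.0)]
    for _ in range(steps):
        candidates = []
        for prefix, score in beams:
            for token, prob in TOY_MODEL[prefix].items():
                # Scores are summed log-probabilities
                candidates.append((prefix + (token,), score + math.log(prob)))
        # Keep only the beam_size highest-scoring partial sequences
        candidates.sort(key=lambda item: item[1], reverse=True)
        beams = candidates[:beam_size]
    return beams[0]

for size in (1, 2):
    sequence, score = toy_beam_search(size)
    print(size, sequence, round(math.exp(score), 2))
```

With `beam_size=1` the search keeps only `A` after the first step, although the sequence starting with `B` ends up more probable overall. A larger beam explores more alternatives but costs one model call per live beam at every step.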

### Beam Search with N-gram Blocking

In [None]:
def beam_search_ngram_block(input_ids, beam_size=4, max_length=128, no_repeat_ngram_size=3):
    """
    Perform beam search decoding with n-gram repetition blocking.

    Args:
        input_ids (torch.Tensor): Tokenized input tensor of shape [1, seq_len].
        beam_size (int): Number of beams to explore.
        max_length (int): Maximum length of the generated sequence.
        no_repeat_ngram_size (int): Size of n-gram to prevent from repeating.

    Returns:
        str: Decoded summary text.
    """
    # TODO: Implement Beam Search with n-gram blocking
    pass

- Explain your implementation.
- What is the role of each hyperparameter in this strategy, and what happens if you change its value?
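One possible building block for n-gram blocking is a helper that, given the tokens of a beam, lists the next tokens that would recreate an n-gram already present in that beam. The sketch below is a minimal pure-Python version; the name `banned_next_tokens` is hypothetical and not part of the assignment template.

```python
def banned_next_tokens(tokens, n):
    """Return the set of token ids that would complete an n-gram
    that already occurs in `tokens`."""
    if n <= 0 or len(tokens) < n:
        return set()  # no complete n-gram has been generated yet
    # The last n-1 tokens form the prefix of the n-gram about to be completed
    prefix = tuple(tokens[len(tokens) - (n - 1):])
    banned = set()
    for i in range(len(tokens) - n + 1):
        ngram = tokens[i:i + n]
        if tuple(ngram[:-1]) == prefix:
            banned.add(ngram[-1])
    return banned

# The beam ends in (5, 6) and the trigram (5, 6, 7) was already generated,
# so generating 7 next would repeat it
print(banned_next_tokens([5, 6, 7, 5, 6], n=3))
```

Inside a beam search loop, the log-probabilities of the banned ids could be set to negative infinity before the top candidates of a beam are selected, so that no hypothesis repeats an n-gram of size `no_repeat_ngram_size`.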

In [None]:
# Example usage with the above decoding functions
input_text = dataset[0]['article']
input_ids = tokenizer(input_text, return_tensors="pt", truncation=True, max_length=1024).input_ids

print("Greedy Search:")
print(greedy_decode(input_ids))

print("\nBeam Search + N-gram Blocking:")
print(beam_search_ngram_block(input_ids))


## ROUGE Metric Evaluation

In [None]:
import re

def tokenize(text):
    # Lowercase the text and split it into alphanumeric word tokens
    return re.findall(r'\w+', text.lower())

def compute_rouge_n(reference: str, generated: str, n: int = 1) -> dict:
    """
    Compute ROUGE-N score between reference and generated text.

    Args:
        reference (str): The reference summary.
        generated (str): The generated summary.
        n (int): The n-gram size (e.g., 1 for ROUGE-1).

    Returns:
        dict: Dictionary with 'precision', 'recall', and 'f1' scores.
    """
    # HINT: Implement ROUGE-N calculation using n-gram overlap
    pass

def compute_rouge_l(reference: str, generated: str) -> dict:
    """
    Compute ROUGE-L (Longest Common Subsequence) score.

    Args:
        reference (str): The ground truth summary.
        generated (str): The generated summary by the model.

    Returns:
        dict: Dictionary with precision, recall, and f1 scores.
    """
    # HINT: Implement Longest Common Subsequence algorithm
    pass


- Explain your implementation.
- Discuss how these metrics could be improved for better evaluation.
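As a reference point for the two metrics, the sketch below shows one common way to compute them: ROUGE-N from clipped n-gram overlap and ROUGE-L from the length of the longest common subsequence (LCS). It is a simplified, sentence-level version with plain F1 and no stemming; the function names (`rouge_n_sketch`, `rouge_l_sketch`) are hypothetical and kept distinct from the assignment stubs.

```python
import re
from collections import Counter

def tokenize(text):
    return re.findall(r"\w+", text.lower())

def ngram_counts(tokens, n):
    """Count every n-gram (as a tuple) in a token list."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))

def prf(overlap, ref_total, gen_total):
    """Turn an overlap count into precision, recall, and F1."""
    precision = overlap / gen_total if gen_total else 0.0
    recall = overlap / ref_total if ref_total else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"precision": precision, "recall": recall, "f1": f1}

def rouge_n_sketch(reference, generated, n=1):
    ref = ngram_counts(tokenize(reference), n)
    gen = ngram_counts(tokenize(generated), n)
    # Counter intersection clips each n-gram to its minimum count in both texts
    overlap = sum((ref & gen).values())
    return prf(overlap, sum(ref.values()), sum(gen.values()))

def lcs_length(a, b):
    """Length of the longest common subsequence, using one DP row at a time."""
    prev = [0] * (len(b) + 1)
    for x in a:
        curr = [0]
        for j, y in enumerate(b, start=1):
            curr.append(prev[j - 1] + 1 if x == y else max(prev[j], curr[j - 1]))
        prev = curr
    return prev[-1]

def rouge_l_sketch(reference, generated):
    ref, gen = tokenize(reference), tokenize(generated)
    return prf(lcs_length(ref, gen), len(ref), len(gen))

reference = "the cat sat on the mat"
generated = "the cat is on the mat"
print(rouge_n_sketch(reference, generated, n=1))
print(rouge_n_sketch(reference, generated, n=2))
print(rouge_l_sketch(reference, generated))
```

Clipping matters when a summary repeats a word: without it, a degenerate output such as "the the the" would receive full precision credit for every repetition.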

## Evaluation

In [None]:
def generate_summary_custom(article_text, strategy="greedy"):
    input_ids = tokenizer(article_text, return_tensors="pt", truncation=True, max_length=1024).input_ids
    if strategy == "greedy":
        return greedy_decode(input_ids)
    elif strategy == "top_k":
        return top_k_decode(input_ids)
    elif strategy == "top_p":
        return top_p_decode(input_ids)
    elif strategy == "beam":
        return beam_search_decode(input_ids)
    elif strategy == "beam_block":
        return beam_search_ngram_block(input_ids)
    else:
        raise ValueError("Unknown decoding strategy")


In [None]:
# Compare summaries across all decoding strategies for one article
sample_article = dataset[0]["article"]
for strategy in ["greedy", "top_k", "top_p", "beam", "beam_block"]:
    print(f"\n[{strategy.upper()}]")
    print(generate_summary_custom(sample_article, strategy=strategy))



In [None]:
import pandas as pd
import numpy as np

# Decoding strategies
strategies = ["greedy", "top_k", "top_p", "beam", "beam_block"]

# Save evaluation results
results = {s: [] for s in strategies}

# Evaluation loop
for i, sample in enumerate(dataset):
    article = sample["article"]
    reference = sample["highlights"]

    input_ids = tokenizer(article, return_tensors="pt", truncation=True, max_length=1024).input_ids

    for strategy in strategies:
        try:
            generated = generate_summary_custom(article, strategy=strategy)

            rouge1 = compute_rouge_n(reference, generated, n=1)["f1"]
            rouge2 = compute_rouge_n(reference, generated, n=2)["f1"]
            rougel = compute_rouge_l(reference, generated)["f1"]

            results[strategy].append({
                "rouge1": rouge1,
                "rouge2": rouge2,
                "rougeL": rougel
            })
        except Exception as e:
            print(f"[{strategy}] Error on sample {i}: {e}")


In [None]:
summary = {}
for strategy in strategies:
    if results[strategy]:
        rouge1s = [x["rouge1"] for x in results[strategy]]
        rouge2s = [x["rouge2"] for x in results[strategy]]
        rougels = [x["rougeL"] for x in results[strategy]]

        summary[strategy] = {
            "ROUGE-1": np.mean(rouge1s),
            "ROUGE-2": np.mean(rouge2s),
            "ROUGE-L": np.mean(rougels)
        }

df = pd.DataFrame(summary).T.sort_values("ROUGE-L", ascending=False)
display(df)


## Discussion (Total: 15 points)
Answer the following questions based on your decoding outputs and analysis. Be clear and support your claims with examples. You may provide evidence by implementing additional functions. For example, you may draw a plot or report statistics.

1. Compare the different decoding strategies. Present and justify your findings using examples from your generated summaries. (9 points)
  - Which strategies produce more diverse outputs?

  - Which ones tend to repeat phrases or truncate early?

  - Which are more stable across different runs?


2. Analyze the impact of decoding parameters. Suggest complementary evaluation methods that might improve reliability. (3 points)
  - How does increasing beam size or sampling temperature affect the results?

  - Use at least one example to illustrate the effect.

3. Discuss the limitations of ROUGE as an evaluation metric. Suggest complementary evaluation methods that might improve reliability. (3 points)
  - What aspects of summarization quality does ROUGE fail to capture?
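For the questions on diversity and repetition, simple statistics over the generated summaries can serve as evidence alongside hand-picked examples. The sketch below shows two such statistics under the assumption of whitespace tokenization: a distinct-n ratio across a set of summaries and a repeated n-gram rate within one summary. Both helper names are hypothetical.

```python
def distinct_n(texts, n=2):
    """Fraction of n-grams across `texts` that are unique (higher means more diverse)."""
    total, unique = 0, set()
    for text in texts:
        tokens = text.lower().split()
        for i in range(len(tokens) - n + 1):
            unique.add(tuple(tokens[i:i + n]))
            total += 1
    return len(unique) / total if total else 0.0

def repeated_ngram_rate(text, n=3):
    """Fraction of n-grams in one summary that duplicate an earlier n-gram."""
    tokens = text.lower().split()
    ngrams = [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]
    if not ngrams:
        return 0.0
    return 1 - len(set(ngrams)) / len(ngrams)

# Made-up strings standing in for generated summaries
print(distinct_n(["a b a b", "a b c"], n=2))
print(repeated_ngram_rate("a b c a b c", n=3))
```

Running a sampling-based strategy several times on the same article and comparing these values across runs is one way to quantify stability; deterministic strategies such as greedy and beam search should give identical outputs on every run.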


