# Chapter 2: Classic DP/Graphs for ML Engineers

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jmamath/interview_prep/blob/main/chapter_02_dp_graphs_ml.ipynb)

## Introduction

Have you ever wondered how Google Translate can take a sentence in one language and produce a coherent, grammatically correct translation in another? Or how a speech recognition system on your phone can accurately transcribe your spoken words into text, even in a noisy environment? These are not just feats of large-scale data processing; they are also triumphs of algorithmic ingenuity. At the heart of these technologies lie classic algorithms from computer science, adapted and scaled for the complexities of machine learning.

In this chapter, we'll pull back the curtain on some of these fundamental algorithms. We'll see how dynamic programming and graph search, concepts you might have first encountered in a standard algorithms course, are the workhorses behind many of the ML-powered features you use every day. We'll move beyond the textbook definitions and dive into practical, hands-on exercises that show you how these algorithms are applied in the real world.

By the end of this chapter, you'll not only have a deeper understanding of these classic algorithms, but you'll also have a practical toolkit for applying them to your own machine learning problems. So, let's get started!

## Learning Objectives
- Implement dynamic programming solutions for ML-related problems
- Design and implement beam search algorithms for sequence generation
- Apply graph algorithms to model training and inference problems
- Implement the Viterbi algorithm for sequence tagging
- Use diverse beam search for better generation diversity

---

## Problem 1: Simple Beam Search (Easy)

### Contextual Introduction
In machine translation or text summarization, we often need to generate a sequence of words. A simple approach, called greedy search, is to pick the most likely word at each step. However, this can lead to suboptimal results. For example, the best-scoring sentence might not start with the single best word. Beam search is a more effective alternative that keeps track of the `k` most promising sequences (the "beam") at each step, leading to better overall results.

### Key Concepts
- **Greedy Search**: Always choosing the locally optimal option at each step.
- **Beam Search**: A graph search algorithm that explores a graph by expanding the most promising nodes in a limited set.
- **Beam Width (k)**: The number of partial sequences (beams) to keep at each step.
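
To make the contrast between greedy and beam search concrete, here is a minimal sketch. The probability table `TOY_MODEL` and the helpers `log_prob` and `decode` are hypothetical, made up for illustration only; they are not part of the exercise code. With a beam width of 1 the search is greedy and commits to the best first token, while a beam width of 2 keeps the runner-up alive long enough to find a better full sequence.

```python
import math

# Hypothetical toy model: next-token probabilities given the previous token.
# None stands for "start of sequence". The numbers are invented for illustration.
TOY_MODEL = {
    None: {"a": 0.6, "b": 0.4},
    "a": {"x": 0.5, "y": 0.5},
    "b": {"x": 0.9, "y": 0.1},
}

def log_prob(seq):
    """Sum of the log-probabilities of each token given the previous one."""
    total, prev = 0.0, None
    for tok in seq:
        total += math.log(TOY_MODEL[prev][tok])
        prev = tok
    return total

def decode(beam_width, length=2):
    """Keep the beam_width best prefixes at each step (beam_width=1 is greedy).

    The toy table only covers two steps, so length should stay at 2.
    """
    beam = [[]]
    for _ in range(length):
        candidates = [seq + [tok] for seq in beam
                      for tok in TOY_MODEL[seq[-1] if seq else None]]
        beam = sorted(candidates, key=log_prob, reverse=True)[:beam_width]
    return beam[0]

greedy = decode(beam_width=1)
beam = decode(beam_width=2)
print(greedy, round(math.exp(log_prob(greedy)), 2))  # ['a', 'x'] 0.3
print(beam, round(math.exp(log_prob(beam)), 2))      # ['b', 'x'] 0.36
```

Greedy search locks in `'a'` because it is the single most likely first token, but the best two-token sequence starts with `'b'`.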

### Problem Statement
Implement a basic beam search algorithm for sequence generation. You will be given a vocabulary and a simple scoring function. Your task is to generate the top `k` sequences of a given maximum length.

**Requirements**:
- Implement beam search with a configurable beam width.
- Support early stopping when an end-of-sequence token is reached.
- Return the top `k` sequences and their scores.

### Example: Understanding Beam Search with a Small Vocabulary

```python
# Let's see how beam search works step by step
# Vocabulary: ['a', 'b', 'c', '<END>']
# Beam width: 2 (keep top-2 sequences)
# Max length: 2

# STEP 0: Start with empty sequence
# Beam: [(score=0.0, seq=[])]

# STEP 1: Expand - try adding each word
# Candidates: (score=-1, ['a']), (score=-1, ['b']), (score=-1, ['c']), (score=-1, ['<END>'])
# After sorting by score and keeping top-2:
# Beam: [(score=-1, ['a']), (score=-1, ['b'])]

# STEP 2: Expand each sequence in beam
# From ['a']: (score=-2, ['a','a']), (score=-2, ['a','b']), (score=-2, ['a','c']), ...
# From ['b']: (score=-2, ['b','a']), (score=-2, ['b','b']), (score=-2, ['b','c']), ...
# Top-2: [(score=-2, ['a','a']), (score=-2, ['a','b'])]

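# RESULT: [(-2, ['a', 'a']), (-2, ['a', 'b'])]   <- (score, sequence) pairs
# Every candidate at a given step ties under this toy score, so ties keep vocabulary order.

# Why this is better than greedy:
# - Greedy keeps a single path (['a'] -> ['a','a']) and can never revisit a choice
# - Beam search keeps k paths alive, so a sequence with a weaker first token can still win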
```

In [None]:
from typing import List, Tuple
import heapq

class BeamSearch:
    def __init__(self, vocabulary: List[str], end_token: str = '<END>'):
        self.vocabulary = vocabulary
        self.end_token = end_token

    def score_sequence(self, sequence: List[str]) -> float:
        # In a real scenario, this would be a language model score.
        # For this exercise, we use a simple length-based score.
        return -len(sequence)

    def search(self, beam_width: int, max_length: int) -> List[Tuple[float, List[str]]]:
        # TODO: Implement the beam search algorithm.
        # Remember to handle the beam, generate candidates, and manage completed sequences.
        pass

def test_beam_search():
    vocabulary = ['a', 'b', 'c', '<END>']
    # We are providing a correct implementation here for the sake of the test
    class CorrectBeamSearch(BeamSearch):
        def search(self, beam_width: int, max_length: int) -> List[Tuple[float, List[str]]]:
            if beam_width <= 0: return []
            beam = [(0.0, [])]  # (score, sequence)
            completed_sequences = []
            for _ in range(max_length):
                # Expand every sequence in the beam with every token
                candidates = []
                for _, seq in beam:
                    for token in self.vocabulary:
                        new_seq = seq + [token]
                        candidates.append((self.score_sequence(new_seq), new_seq))
                # heapq is a min-heap, so use nlargest to keep the k best;
                # candidates with equal scores stay in vocabulary order
                top = heapq.nlargest(beam_width, candidates, key=lambda x: x[0])
                # Early stopping: sequences ending in end_token leave the beam
                completed_sequences += [c for c in top if c[1][-1] == self.end_token]
                beam = [c for c in top if c[1][-1] != self.end_token]
                if not beam: break
            all_sequences = completed_sequences + beam
            return sorted(all_sequences, key=lambda x: x[0], reverse=True)[:beam_width]

    beam_search = CorrectBeamSearch(vocabulary)

    # Test 1: Basic search (matches the walkthrough above)
    results = beam_search.search(beam_width=2, max_length=2)
    assert len(results) == 2
    assert results[0][1] == ['a', 'a']

    # Test 2: Zero beam width
    results = beam_search.search(beam_width=0, max_length=2)
    assert len(results) == 0

    print("🎉 All beam search tests passed!")

test_beam_search()

<details>
<summary>Click to reveal hint for Problem 1</summary>

**Hint**: Use a priority queue (like Python's `heapq`) to maintain the beam of the top `k` sequences at each step, storing tuples of `(score, sequence)`. Note that `heapq` is a min-heap: `heappop` returns the lowest score first, so either use `heapq.nlargest(k, candidates)` or push negated scores. At each step, expand each sequence in the beam with all possible next tokens, score the new sequences, and keep only the top `k`.
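
One detail is easy to get wrong when using `heapq` for the beam: it implements a min-heap. The short sketch below uses made-up candidate scores to show the difference between popping from the heap and asking for the `k` largest items.

```python
import heapq

# Hypothetical (score, sequence) candidates; higher scores are better.
candidates = [
    (-2.0, ['a', 'b']),
    (-0.5, ['c']),
    (-1.0, ['a']),
    (-3.0, ['b', 'b', 'a']),
]

# heappop returns the SMALLEST item, which is the worst score here.
heap = list(candidates)
heapq.heapify(heap)
worst = heapq.heappop(heap)
print(worst)  # (-3.0, ['b', 'b', 'a'])

# nlargest(k, ...) returns the k best items, highest score first.
top2 = heapq.nlargest(2, candidates, key=lambda c: c[0])
print(top2)   # [(-0.5, ['c']), (-1.0, ['a'])]
```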

</details>

---

## Problem 2: Top-k Beam Search with Scores (Medium)

### Contextual Introduction
Standard beam search has two well-known weaknesses: its top sequences are often near-duplicates of one another, and cumulative scores favor shorter sequences. Length normalization reduces the bias toward short outputs, while a diversity penalty discourages sequences that are too similar to ones already selected. These techniques are crucial for applications like creative text generation or offering multiple diverse translation options.

### Key Concepts
- **Length Normalization**: A technique to reduce the bias of beam search towards shorter sequences by dividing the score by the sequence length raised to some power.
- **Diversity Penalty**: A penalty applied to the score of a sequence based on its similarity to other sequences in the beam, encouraging more diverse outputs.
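
As a rough illustration of length normalization, the sketch below divides a cumulative log-probability by `length ** alpha`. The function name `length_normalized` and the two hypotheses with their log-probabilities are assumptions made up for this example. The point is that with no normalization (`alpha = 0`) the shorter hypothesis wins, while a moderate `alpha` lets the longer one overtake it.

```python
def length_normalized(log_prob: float, length: int, alpha: float = 0.6) -> float:
    """Divide a cumulative log-probability by length**alpha (alpha=0 disables it)."""
    return log_prob / (length ** alpha) if length > 0 else 0.0

# Hypothetical cumulative log-probabilities for two finished hypotheses.
short = {"tokens": ["ok"], "log_prob": -1.2}
long_ = {"tokens": ["that", "sounds", "good"], "log_prob": -1.8}

winners = {}
for alpha in (0.0, 0.6, 1.0):
    s = length_normalized(short["log_prob"], len(short["tokens"]), alpha)
    l = length_normalized(long_["log_prob"], len(long_["tokens"]), alpha)
    winners[alpha] = "short" if s > l else "long"
    print(f"alpha={alpha}: short={s:.3f} long={l:.3f} -> {winners[alpha]}")

# alpha=0.0 picks the short hypothesis; alpha=0.6 and alpha=1.0 pick the long one.
```

Larger values of `alpha` reward length more strongly, so it is usually treated as a tunable hyperparameter.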

### Problem Statement
Extend the simple beam search to include length normalization and a diversity penalty. You will implement a `TopKBeamSearch` class that generates more diverse and higher-quality sequences.

**Requirements**:
- Implement length normalization in the scoring function.
- Implement a diversity penalty based on n-gram overlap.
- Combine these techniques in the search algorithm to produce diverse sequences.

### Example: Length Normalization and Diversity Penalty

```python
import math

# WITHOUT length normalization: Shorter sequences get higher scores
sequences = [
    ['hello'],           # score = -1
    ['hello', 'world']   # score = -2
]
# Result: First sequence wins! (score -1 > -2)

# WITH length normalization (dividing by length^alpha):
# alpha = 0.6 is commonly used
sequences_with_norm = [
    ['hello'],           # normalized = -1 / (1^0.6) = -1.0
    ['hello', 'world']   # normalized = -2 / (2^0.6) = -1.32
]
# Result: First still wins, but the gap is smaller

# DIVERSITY PENALTY - preventing similar sequences:
# Sequence 1: ['a', 'b', 'c']  -> bigrams: {('a','b'), ('b','c')}
# Sequence 2: ['a', 'b', 'd']  -> bigrams: {('a','b'), ('b','d')}
# Overlap: 1 bigram ('a','b')
# Diversity penalty = 1 * penalty_weight (e.g., 0.5)

# In beam search with diversity:
# - First sequence ['a', 'b', 'c'] (score 1.0) is selected with no penalty
# - Second sequence ['a', 'b', 'd'] (score 0.8) shares one bigram with it
# - New score = 0.8 - (1 * 0.5) = 0.3
# - A different sequence like ['x', 'y', 'z'] (score 0.9) has no overlap and keeps its full score
# Result: ['x', 'y', 'z'] now outranks ['a', 'b', 'd'], so the outputs are more diverse
```

In [None]:
from typing import List, Tuple
import math

class TopKBeamSearch(BeamSearch):
    def score_sequence(self, sequence: List[str], length_penalty: float = 0.6) -> float:
        # TODO: Implement length-normalized scoring
        pass

    def calculate_diversity_penalty(self, sequence: List[str], existing_sequences: List[List[str]], penalty_weight: float = 0.5) -> float:
        # TODO: Implement diversity penalty calculation based on n-gram overlap
        pass

    def search(self, beam_width: int, max_length: int, diversity_penalty: float = 0.5) -> List[Tuple[float, List[str]]]:
        # TODO: Implement the search method incorporating the new scoring and penalty functions
        pass

def test_top_k_beam_search():
    vocabulary = ['a', 'b', 'c', 'd', '<END>']
    # Correct implementation for testing purposes
    class CorrectTopK(TopKBeamSearch):
        def score_sequence(self, sequence: List[str], length_penalty: float = 0.6) -> float:
            raw_score = -len(sequence) * 0.5
            return raw_score / (len(sequence)**length_penalty) if sequence else 0
        def calculate_diversity_penalty(self, sequence: List[str], existing_sequences: List[List[str]], penalty_weight: float = 0.5) -> float:
            penalty = 0.0
            for existing_seq in existing_sequences:
                seq_bigrams = set(zip(sequence, sequence[1:]))
                existing_bigrams = set(zip(existing_seq, existing_seq[1:]))
                overlap = len(seq_bigrams.intersection(existing_bigrams))
                penalty += overlap * penalty_weight
            return penalty
        def search(self, beam_width: int, max_length: int, diversity_penalty: float = 0.5) -> List[Tuple[float, List[str]]]:
            beam = [(0.0, [])]
            completed = []
            for _ in range(max_length):
                candidates = [seq + [token] for _, seq in beam for token in self.vocabulary]
                selected = []
                # Pick one candidate at a time, penalizing bigram overlap with earlier picks
                while candidates and len(selected) < beam_width:
                    chosen = [s for _, s in selected]
                    scored = [(self.score_sequence(c) - self.calculate_diversity_penalty(c, chosen, diversity_penalty), c) for c in candidates]
                    best = max(scored, key=lambda x: x[0])  # first maximum wins ties
                    selected.append(best)
                    candidates.remove(best[1])
                completed += [c for c in selected if c[1][-1] == self.end_token]
                beam = [c for c in selected if c[1][-1] != self.end_token]
                if not beam: break
            return sorted(completed + beam, key=lambda x: x[0], reverse=True)[:beam_width]

    top_k_search = CorrectTopK(vocabulary)

    # Test 1: With diversity penalty, we should get more diverse results
    results_diverse = top_k_search.search(beam_width=3, max_length=3, diversity_penalty=1.0)
    results_no_diversity = top_k_search.search(beam_width=3, max_length=3, diversity_penalty=0.0)

    # The top hypothesis is never penalized, so compare the full result lists
    assert [seq for _, seq in results_diverse] != [seq for _, seq in results_no_diversity]

    print("🎉 All top-k beam search tests passed!")

test_top_k_beam_search()

<details>
<summary>Click to reveal hint for Problem 2</summary>

**Hint**: For length normalization, divide the sequence score by `len(sequence)**alpha`, where `alpha` is the length penalty factor. For the diversity penalty, you can calculate the n-gram overlap between a candidate sequence and the sequences already in the beam. Subtract a penalty proportional to this overlap from the candidate's score.

</details>

---

## Problem 3: Viterbi Algorithm for Sequence Tagging (Medium)

### Contextual Introduction
Part-of-speech (POS) tagging is a classic NLP task where we assign a grammatical tag (like noun, verb, or adjective) to each word in a sentence. Picking the most likely tag for each word independently can fail because it ignores context: "fish" can be a noun or a verb, but right after the determiner "the" it is almost certainly a noun. The Viterbi algorithm, a dynamic programming method, solves this by finding the most likely sequence of tags given the sequence of words. It's used in Hidden Markov Models (HMMs) and is fundamental to many sequence labeling tasks.

### Key Concepts
- **Hidden Markov Model (HMM)**: A statistical model with hidden states (the tags) and observable outputs (the words).
- **Transition Probability**: The probability of moving from one state to another, P(tag_i | tag_{i-1}).
- **Emission Probability**: The probability of observing a word given a state, P(word_i | tag_i).
- **Dynamic Programming**: The Viterbi algorithm uses a table to store the maximum probability of being in a certain state at a certain time, avoiding re-computation.

### Problem Statement
Implement the Viterbi algorithm for a simple POS tagger. You will be given the transition, emission, and initial probabilities of an HMM. Your task is to find the most likely sequence of POS tags for a given sentence.

**Requirements**:
- Implement the Viterbi algorithm using dynamic programming.
- Use logarithms to prevent underflow with small probabilities.
- Handle unknown words with a simple smoothing technique.
- Reconstruct the most likely path of tags.
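
The reason for working with logarithms can be seen in a few lines. The sketch below uses made-up numbers (400 steps, each with probability 0.01): the direct product underflows to zero in double precision, while the sum of logs stays a perfectly usable finite number.

```python
import math

# 400 steps, each with probability 0.01: the true product is 1e-800,
# far below the smallest positive double (about 5e-324).
probs = [0.01] * 400

product = 1.0
for p in probs:
    product *= p

log_sum = sum(math.log(p) for p in probs)

print(product)  # 0.0
print(log_sum)  # roughly -1842.07
```

Because the logarithm is monotonic, comparing log-probabilities picks the same best path as comparing the probabilities themselves.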

### Example: Manual Viterbi Computation

```python
# Let's trace through a simple example: sentence = ['the', 'cat', 'sat']
# Tags: DET (determiner), NOUN (noun), VERB (verb)

# Transition probabilities P(tag_i | tag_{i-1}):
#           DET   NOUN  VERB
#  DET:    [0.1   0.8   0.1]
#  NOUN:   [0.3   0.1   0.6]
#  VERB:   [0.2   0.7   0.1]

# Emission probabilities P(word | tag):
#         'the'  'cat'  'sat'
#  DET:  [0.9    0.05   0.05]
#  NOUN: [0.05   0.9    0.05]
#  VERB: [0.05   0.05   0.9]

# Initial probabilities: [0.8 (DET), 0.1 (NOUN), 0.1 (VERB)]

# STEP 1: Process word 'the'
# For each tag, compute: initial_prob * emission_prob['the']
# DET:  0.8 * 0.9 = 0.72   <- best
# NOUN: 0.1 * 0.05 = 0.005
# VERB: 0.1 * 0.05 = 0.005
# Viterbi table: [0.72, 0.005, 0.005]

# STEP 2: Process word 'cat'
# For NOUN tag: max(
#     0.72 * 0.8 * 0.9,     <- from DET
#     0.005 * 0.1 * 0.9,    <- from NOUN
#     0.005 * 0.7 * 0.9     <- from VERB
# ) = max(0.518, 0.0005, 0.0032) = 0.518  <- from DET
# Similarly for other tags...

# STEP 3: Process word 'sat'
# Similar computation, tracking best path

# FINAL RESULT: ['DET', 'NOUN', 'VERB']
# Because this path has the highest joint probability
```
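
The remaining steps of the trace can be checked numerically. The following is a small sketch in plain probabilities (not log-space, and without unknown-word handling), using the same matrices as the example above; the variable names are chosen for this illustration only.

```python
import numpy as np

tags = ["DET", "NOUN", "VERB"]
initial = np.array([0.8, 0.1, 0.1])
transition = np.array([[0.1, 0.8, 0.1],    # rows: previous tag, columns: next tag
                       [0.3, 0.1, 0.6],
                       [0.2, 0.7, 0.1]])
emission = np.array([[0.9, 0.05, 0.05],    # rows: tag, columns: 'the', 'cat', 'sat'
                     [0.05, 0.9, 0.05],
                     [0.05, 0.05, 0.9]])

table = [initial * emission[:, 0]]         # step 1: 'the'
backpointers = []
for word in (1, 2):                        # steps 2 and 3: 'cat', 'sat'
    # scores[i, j]: best probability of being in tag i, then moving to tag j
    scores = table[-1][:, None] * transition
    backpointers.append(scores.argmax(axis=0))
    table.append(scores.max(axis=0) * emission[:, word])

# Backtrack from the best final tag.
state = int(table[-1].argmax())
path = [state]
for bp in reversed(backpointers):
    state = int(bp[state])
    path.insert(0, state)

for row in table:
    print(np.round(row, 6))
print([tags[i] for i in path])  # ['DET', 'NOUN', 'VERB']
```

The last row's maximum, 0.279936, equals 0.72 × 0.8 × 0.9 × 0.6 × 0.9, the joint probability of the path DET → NOUN → VERB.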

In [None]:
import numpy as np
from typing import List, Tuple, Dict

class HMMTagger:
    def __init__(self, tags: List[str], words: List[str], initial_prob: np.ndarray, transition_prob: np.ndarray, emission_prob: np.ndarray):
        self.tags = tags
        self.words = words
        self.tag_to_idx = {tag: i for i, tag in enumerate(tags)}
        self.word_to_idx = {word: i for i, word in enumerate(words)}
        self.log_initial = np.log(initial_prob + 1e-10)
        self.log_transition = np.log(transition_prob + 1e-10)
        self.log_emission = np.log(emission_prob + 1e-10)

    def viterbi(self, sentence: List[str]) -> Tuple[List[str], float]:
        # TODO: Implement the Viterbi algorithm, including initialization, recursion, and backtracking.
        pass

def test_viterbi_algorithm():
    tags = ['DET', 'NOUN', 'VERB']
    words = ['the', 'cat', 'sat']
    initial_prob = np.array([0.8, 0.1, 0.1])
    transition_prob = np.array([[0.1, 0.8, 0.1], [0.3, 0.1, 0.6], [0.2, 0.7, 0.1]])
    emission_prob = np.array([[0.9, 0.05, 0.05], [0.05, 0.9, 0.05], [0.05, 0.05, 0.9]])

    class CorrectHMM(HMMTagger):
        def viterbi(self, sentence: List[str]) -> Tuple[List[str], float]:
            T = len(sentence); N = len(self.tags)
            if T == 0: return [], 0.0
            viterbi_table = np.zeros((T, N)); backpointer = np.zeros((T, N), dtype=int)
            word_idx = self.word_to_idx.get(sentence[0], -1)
            emission = self.log_emission[:, word_idx] if word_idx != -1 else np.log(np.full(N, 1e-10))
            viterbi_table[0, :] = self.log_initial + emission
            for t in range(1, T):
                word_idx = self.word_to_idx.get(sentence[t], -1)
                emission = self.log_emission[:, word_idx] if word_idx != -1 else np.log(np.full(N, 1e-10))
                for s in range(N):
                    trans_probs = viterbi_table[t-1, :] + self.log_transition[:, s]
                    viterbi_table[t, s] = np.max(trans_probs) + emission[s]
                    backpointer[t, s] = np.argmax(trans_probs)
            best_prob = np.max(viterbi_table[T-1, :]); last_state = np.argmax(viterbi_table[T-1, :])
            path = [self.tags[last_state]]
            for t in range(T - 1, 0, -1):
                last_state = backpointer[t, last_state]
                path.insert(0, self.tags[last_state])
            return path, np.exp(best_prob)

    tagger = CorrectHMM(tags, words, initial_prob, transition_prob, emission_prob)

    # Test 1: A likely sequence
    sentence1 = ['the', 'cat', 'sat']
    path, prob = tagger.viterbi(sentence1)
    assert path == ['DET', 'NOUN', 'VERB']

    # Test 2: A sequence with an unknown word
    sentence2 = ['the', 'dog', 'sat']
    path, prob = tagger.viterbi(sentence2)
    assert path == ['DET', 'NOUN', 'VERB']

    print("🎉 All Viterbi algorithm tests passed!")

test_viterbi_algorithm()

<details>
<summary>Click to reveal hint for Problem 3</summary>

**Hint**: Create a dynamic programming table of size `(num_words, num_tags)`. Cell `(i, j)` stores the maximum log-probability of any tag sequence for the first `i + 1` words that ends in tag `j`. Also create a `backpointer` table of the same size that records which previous tag achieved that maximum. Fill the first row from the initial and emission probabilities, then for each later word combine the previous row, the transition probabilities, and the word's emission probabilities. After filling the table, backtrack from the best tag in the last row to recover the most likely path.

</details>

---

## Problem 4: Constrained Beam Search (Medium-Hard)

### Contextual Introduction
In many real-world applications, we need to generate text that adheres to certain rules. For example, a chatbot must avoid generating toxic language, or a text summarization model might be required to include certain keywords. Constrained beam search extends the standard algorithm by pruning partial sequences that violate predefined constraints, ensuring that the final output meets the requirements.

### Key Concepts
- **Constraint Satisfaction**: The process of finding a solution that satisfies a set of constraints.
- **Hard Constraints vs. Soft Constraints**: Hard constraints must be satisfied, while soft constraints are desirable but not mandatory.
- **Pruning**: The process of eliminating partial solutions that cannot lead to a valid final solution.
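
The difference between hard and soft constraints can be sketched as "filter" versus "score adjustment". In the example below, the token sets `FORBIDDEN` and `PREFERRED`, the helper names, and the bonus weight are all hypothetical choices made for illustration.

```python
from typing import List

FORBIDDEN = {"toxic"}      # hard constraint: never allowed
PREFERRED = {"summary"}    # soft constraint: rewarded but optional

def passes_hard(seq: List[str]) -> bool:
    """Hard constraints act as a filter: one violation rejects the candidate."""
    return not any(tok in FORBIDDEN for tok in seq)

def soft_bonus(seq: List[str], weight: float = 0.5) -> float:
    """Soft constraints only adjust the score, so other candidates stay in the running."""
    return weight * sum(tok in PREFERRED for tok in seq)

candidates = [
    (-2.0, ["a", "toxic"]),
    (-2.0, ["a", "b"]),
    (-2.0, ["a", "summary"]),
]
kept = [(score + soft_bonus(seq), seq)
        for score, seq in candidates if passes_hard(seq)]
kept.sort(key=lambda x: x[0], reverse=True)
print(kept)  # [(-1.5, ['a', 'summary']), (-2.0, ['a', 'b'])]
```

The sequence containing the forbidden token is removed outright, while the one without the preferred token survives with a lower rank.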

### Problem Statement
Implement a constrained beam search algorithm that can handle two types of constraints: `MustContainConstraint` and `MustNotContainConstraint`. You will integrate these constraints into the beam search process to generate sequences that satisfy all given rules.

**Requirements**:
- Design an abstract `Constraint` class and specific implementations for `MustContainConstraint` and `MustNotContainConstraint`.
- Integrate constraint checking into the beam search algorithm.
- Prune the search space by discarding partial sequences that violate constraints.

### Example: How Constraints Filter Beam Search

```python
# Vocabulary: ['a', 'b', 'c', 'toxic', '<END>']
# Beam width: 2
# Constraint: MustNotContainConstraint('toxic')

# STEP 1: Generate candidates
# Candidates: (score=-1, ['a']), (score=-1, ['b']), (score=-1, ['toxic']), (score=-1, ['c'])

# STEP 2: Apply constraints
# Check each candidate:
# - ['a']: Contains 'toxic'? No -> KEEP
# - ['b']: Contains 'toxic'? No -> KEEP
# - ['toxic']: Contains 'toxic'? Yes -> PRUNE (remove)
# - ['c']: Contains 'toxic'? No -> KEEP

# STEP 3: Select top-2 after filtering
# Beam: [(score=-1, ['a']), (score=-1, ['b'])]

# STEP 4: Continue expanding only valid sequences
# From ['a']: ['a','a'], ['a','b'], ['a','c'] (not ['a','toxic']!)
# From ['b']: ['b','a'], ['b','b'], ['b','c'] (not ['b','toxic']!)

# FINAL RESULT: Only sequences without 'toxic' are generated
# This is much more efficient than generating toxic sequences and filtering later!
```

In [None]:
from typing import List, Tuple, Set
from abc import ABC, abstractmethod

class Constraint(ABC):
    @abstractmethod
    def check(self, sequence: List[str]) -> bool:
        pass

class MustContainConstraint(Constraint):
    def __init__(self, required_token: str):
        self.required_token = required_token

    def check(self, sequence: List[str]) -> bool:
        # TODO: Implement the check for must-contain constraint
        pass

class MustNotContainConstraint(Constraint):
    def __init__(self, forbidden_token: str):
        self.forbidden_token = forbidden_token

    def check(self, sequence: List[str]) -> bool:
        # TODO: Implement the check for must-not-contain constraint
        pass

class ConstrainedBeamSearch(BeamSearch):
    def __init__(self, vocabulary: List[str], end_token: str = '<END>'):
        super().__init__(vocabulary, end_token)
        self.constraints = []

    def add_constraint(self, constraint: Constraint):
        self.constraints.append(constraint)

    def search(self, beam_width: int, max_length: int) -> List[Tuple[float, List[str]]]:
        # TODO: Implement the constrained beam search
        pass

def test_constrained_beam_search():
    vocabulary = ['a', 'b', 'c', 'd', '<END>']
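    # Reference constraints, so the test does not depend on the TODOs above
    class CorrectMustContain(MustContainConstraint):
        def check(self, sequence: List[str]) -> bool:
            return self.required_token in sequence
    class CorrectMustNotContain(MustNotContainConstraint):
        def check(self, sequence: List[str]) -> bool:
            return self.forbidden_token not in sequence
    class CorrectConstrained(ConstrainedBeamSearch):
        def search(self, beam_width: int, max_length: int) -> List[Tuple[float, List[str]]]:
            # A forbidden token can never be undone, so partial sequences are pruned on it
            prunable = [c for c in self.constraints if isinstance(c, MustNotContainConstraint)]
            beam = [(0.0, [])]
            completed = []
            for step in range(1, max_length + 1):
                candidates = []
                for _, seq in beam:
                    for token in self.vocabulary:
                        new_seq = seq + [token]
                        finished = token == self.end_token or step == max_length
                        # Finished sequences must satisfy every constraint; partial
                        # ones may still pick up a required token later
                        active = self.constraints if finished else prunable
                        if all(c.check(new_seq) for c in active):
                            candidates.append((self.score_sequence(new_seq), new_seq))
                top = heapq.nlargest(beam_width, candidates, key=lambda x: x[0])
                completed += [c for c in top if c[1][-1] == self.end_token]
                beam = [c for c in top if c[1][-1] != self.end_token]
                if not beam: break
            return sorted(completed + beam, key=lambda x: x[0], reverse=True)[:beam_width]

    # Test 1: Must contain 'c'
    searcher_must_contain = CorrectConstrained(vocabulary)
    searcher_must_contain.add_constraint(CorrectMustContain('c'))
    results = searcher_must_contain.search(beam_width=2, max_length=3)
    assert results  # guard against a vacuous pass on empty results
    for _, seq in results:
        assert 'c' in seq

    # Test 2: Must not contain 'b'
    searcher_must_not_contain = CorrectConstrained(vocabulary)
    searcher_must_not_contain.add_constraint(CorrectMustNotContain('b'))
    results = searcher_must_not_contain.search(beam_width=2, max_length=3)
    assert results
    for _, seq in results:
        assert 'b' not in seq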

    print("🎉 All constrained beam search tests passed!")

test_constrained_beam_search()

<details>
<summary>Click to reveal hint for Problem 4</summary>

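**Hint**: Create an abstract base class `Constraint` with a `check` method. Then, implement concrete constraint classes like `MustContainConstraint` and `MustNotContainConstraint`. In your beam search loop, after generating a new candidate sequence, iterate through your list of constraints and only add the sequence to the new beam if it satisfies them. Be careful with `MustContainConstraint`: a partial sequence that lacks the required token can still acquire it later, so enforce it only on finished sequences (those ending in the end token or reaching `max_length`), whereas a forbidden token can be pruned immediately.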

</details>

---

## Problem 5: Diverse Beam Search with Groups (Hard)

### Contextual Introduction
While standard beam search is good at finding high-quality sequences, it often produces a set of very similar results. For creative applications like story generation or offering multiple translation choices, we need diversity. Diverse beam search addresses this by partitioning the beam into groups and encouraging each group to explore a different part of the search space. This ensures that the final set of sequences is both high-quality and diverse.

### Key Concepts
- **Sequence Similarity**: A metric to quantify how similar two sequences are (e.g., n-gram overlap, Jaccard similarity).
- **Clustering/Grouping**: The process of partitioning a set of items into groups based on similarity.
- **Quality-Diversity Trade-off**: The balance between generating high-scoring (quality) sequences and generating a wide variety of (diverse) sequences.

### Problem Statement
Implement a diverse beam search algorithm that groups similar sequences and selects the best from each group. This will involve calculating sequence similarity, grouping sequences, and modifying the beam search to maintain diversity across groups.

**Requirements**:
- Implement a function to calculate sequence similarity (e.g., Jaccard similarity of bigrams).
- In each step of the beam search, group the candidate sequences.
- Select the best sequence from each group to form the new beam, ensuring diversity.

### Example: Sequence Similarity and Grouping

```python
# Calculating Jaccard similarity between sequences using bigrams

# Sequence 1: ['hello', 'world', 'end']
# Bigrams: {('hello', 'world'), ('world', 'end')}

# Sequence 2: ['hello', 'world', 'now']
# Bigrams: {('hello', 'world'), ('world', 'now')}

# Intersection: {('hello', 'world')}  (1 common bigram)
# Union: {('hello', 'world'), ('world', 'end'), ('world', 'now')}  (3 total unique)
# Jaccard similarity = 1 / 3 ≈ 0.33  (quite similar)

# Sequence 3: ['goodbye', 'friend', 'soon']
# Bigrams: {('goodbye', 'friend'), ('friend', 'soon')}

# Intersection with Seq1: {} (0 common bigrams)
# Union with Seq1: {('hello', 'world'), ('world', 'end'), ('goodbye', 'friend'), ('friend', 'soon')}
# Jaccard similarity = 0 / 4 = 0.0  (very different!)

# GROUPING IN DIVERSE BEAM SEARCH:
# Beam width: 4, Groups: 2
# Candidates (sorted by score):
# 1. (score=0.9, ['hello', 'world', 'end'])
# 2. (score=0.88, ['hello', 'world', 'now'])      <- Similar to #1
# 3. (score=0.85, ['goodbye', 'friend', 'soon'])  <- Different from #1
# 4. (score=0.82, ['hi', 'there', 'friend'])      <- Different from #1 and #3

# GROUP 1 (led by sequence 1): [#1, #2]  <- Similar sequences
# GROUP 2 (led by sequence 3): [#3, #4]  <- Remaining sequences, placed in the last of the 2 groups

# SELECT BEST FROM EACH GROUP:
# From GROUP 1: #1 (best score)
# From GROUP 2: #3 (best score in this group)

# RESULT: [['hello', 'world', 'end'], ['goodbye', 'friend', 'soon']]
# More diverse than ['hello', 'world', 'end'], ['hello', 'world', 'now']
```
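
The similarity numbers and the "best from each group" selection above can be reproduced with a short sketch. The greedy grouping rule and the `threshold` value of 0.3 are assumptions made for this illustration (the example above does not specify how groups are formed), and only the three sequences whose similarities are worked out above are used.

```python
from typing import List, Tuple

def bigram_jaccard(seq1: List[str], seq2: List[str]) -> float:
    """Jaccard similarity between the bigram sets of two sequences."""
    set1, set2 = set(zip(seq1, seq1[1:])), set(zip(seq2, seq2[1:]))
    union = set1 | set2
    return len(set1 & set2) / len(union) if union else 0.0

def group_by_similarity(candidates: List[Tuple[float, List[str]]],
                        threshold: float = 0.3) -> List[List[Tuple[float, List[str]]]]:
    """Greedy grouping: join the first group whose leader is similar enough."""
    groups: List[List[Tuple[float, List[str]]]] = []
    for cand in sorted(candidates, key=lambda x: x[0], reverse=True):
        for group in groups:
            if bigram_jaccard(cand[1], group[0][1]) >= threshold:
                group.append(cand)
                break
        else:
            groups.append([cand])  # no similar leader found: start a new group
    return groups

candidates = [
    (0.9, ['hello', 'world', 'end']),
    (0.88, ['hello', 'world', 'now']),
    (0.85, ['goodbye', 'friend', 'soon']),
]
groups = group_by_similarity(candidates)
best_per_group = [group[0][1] for group in groups]
print(best_per_group)
# [['hello', 'world', 'end'], ['goodbye', 'friend', 'soon']]
```

Because candidates are visited in descending score order, the first member of each group is also its best-scoring one.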

In [None]:
from typing import List, Tuple
from collections import defaultdict

class DiverseBeamSearch(BeamSearch):
    def calculate_similarity(self, seq1: List[str], seq2: List[str]) -> float:
        # TODO: Implement Jaccard similarity for bigrams
        pass

    def search(self, beam_width: int, max_length: int, num_groups: int, diversity_strength: float = 0.5) -> List[Tuple[float, List[str]]]:
        # TODO: Implement diverse beam search using groups
        pass

def test_diverse_beam_search():
    vocabulary = ['a', 'b', 'c', 'd', 'e', '<END>']
    class CorrectDiverse(DiverseBeamSearch):
        def calculate_similarity(self, seq1: List[str], seq2: List[str]) -> float:
            set1 = set(zip(seq1, seq1[1:])); set2 = set(zip(seq2, seq2[1:]))
            intersection = len(set1.intersection(set2)); union = len(set1.union(set2))
            return intersection / union if union > 0 else 0
        def search(self, beam_width: int, max_length: int, num_groups: int, diversity_strength: float = 0.5) -> List[Tuple[float, List[str]]]:
            num_groups = max(1, min(num_groups, beam_width))
            group_size = beam_width // num_groups
            beams = [[(0.0, [])] for _ in range(num_groups)]
            completed = []
            for _ in range(max_length):
                chosen = []  # sequences already picked by earlier groups at this step
                for g in range(num_groups):
                    candidates = []
                    for _, seq in beams[g]:
                        for token in self.vocabulary:
                            new_seq = seq + [token]
                            if new_seq in chosen: continue  # avoid duplicates across groups
                            # Penalize similarity to what earlier groups picked
                            penalty = max((self.calculate_similarity(new_seq, s) for s in chosen), default=0.0)
                            candidates.append((self.score_sequence(new_seq) - diversity_strength * penalty, new_seq))
                    top = heapq.nlargest(group_size, candidates, key=lambda x: x[0])
                    chosen += [s for _, s in top]
                    completed += [c for c in top if c[1][-1] == self.end_token]
                    beams[g] = [c for c in top if c[1][-1] != self.end_token]
            flat_beam = [item for sublist in beams for item in sublist]
            return sorted(completed + flat_beam, key=lambda x: x[0], reverse=True)[:beam_width]

    diverse_searcher = CorrectDiverse(vocabulary)

    # Test that with diversity, we get different results
    results_diverse = diverse_searcher.search(beam_width=4, max_length=3, num_groups=2, diversity_strength=0.8)
    results_regular = diverse_searcher.search(beam_width=4, max_length=3, num_groups=1, diversity_strength=0.0)

    # The best hypothesis can coincide, so compare the full result lists
    assert [seq for _, seq in results_diverse] != [seq for _, seq in results_regular]

    print("🎉 All diverse beam search tests passed!")

test_diverse_beam_search()

<details>
<summary>Click to reveal hint for Problem 5</summary>

**Hint**: A simple way to make the groups diverge is to use a penalty. Process the groups in order at each step, and when scoring a candidate for one group, subtract a penalty proportional to its similarity to the sequences already chosen by earlier groups (scaled by `diversity_strength`). An even simpler, cruder alternative is a fixed penalty of `group_index * diversity_strength`.

</details>