<a href="https://colab.research.google.com/github/chen-star/llm_model_trainings/blob/main/impl_byte_pair_encoding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Byte Pair Encoding (BPE)

Byte Pair Encoding (BPE) is a subword tokenization method widely used in Natural Language Processing (NLP), particularly in Large Language Models (LLMs) like GPT, RoBERTa, and others.

## How it Works

1.  **Initialization**: Start with a vocabulary of individual characters.
2.  **Counting**: Count the frequency of all adjacent pairs of symbols in the text.
3.  **Merging**: Identify the most frequent pair and merge them into a new symbol.
4.  **Iteration**: Repeat steps 2 and 3 for a fixed number of merges (hyperparameter) or until a desired vocabulary size is reached.

## Example Walkthrough

Let's assume we have a small corpus with the following word frequencies:

*   "low": 5
*   "lower": 2
*   "newest": 6
*   "widest": 3

### Step 1: Initialization

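We split each word into characters and append a special end-of-word symbol `</w>`. This keeps a subword at the end of a word (such as `est</w>` in "newest") distinct from the same characters appearing inside a word.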

*   `l o w </w>`: 5
*   `l o w e r </w>`: 2
*   `n e w e s t </w>`: 6
*   `w i d e s t </w>`: 3

**Vocabulary**: `l, o, w, e, r, n, s, t, i, d, </w>`
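
A minimal Python sketch of this initialization, assuming the corpus is given as a word-to-frequency dict. The names `word_freqs`, `corpus`, and `vocabulary` are illustrative and are not taken from the implementation later in this notebook.

```python
# Word frequencies from the example corpus above.
word_freqs = {"low": 5, "lower": 2, "newest": 6, "widest": 3}

# Split each word into characters and append the end-of-word marker.
# Tuples are used so each symbol sequence can serve as a dict key.
corpus = {tuple(word) + ("</w>",): freq for word, freq in word_freqs.items()}

# The initial vocabulary is every distinct symbol in the split corpus.
vocabulary = sorted({symbol for symbols in corpus for symbol in symbols})

for symbols, freq in corpus.items():
    print(" ".join(symbols), freq)
print(vocabulary)
```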

### Step 2: Counting Pairs

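We count the frequency of every adjacent pair of symbols, weighting each occurrence by the frequency of the word it appears in.

*   `e s`: 6 (newest) + 3 (widest) = **9**
*   `s t`: 6 (newest) + 3 (widest) = **9**
*   `t </w>`: 6 (newest) + 3 (widest) = **9**
*   `w e`: 2 (lower) + 6 (newest) = 8
*   `l o`: 5 (low) + 2 (lower) = 7
*   `o w`: 5 (low) + 2 (lower) = 7
*   ...

Only pairs of two adjacent symbols are counted; longer units such as `est` emerge through repeated merges.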
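
The weighted count can be sketched as follows. `count_pairs` is a hypothetical helper that keys pairs by `(left, right)` tuples and adds the word's frequency once for every occurrence of the pair in that word.

```python
from collections import Counter

# Example corpus from Step 1: symbol tuples mapped to word frequencies.
corpus = {
    ("l", "o", "w", "</w>"): 5,
    ("l", "o", "w", "e", "r", "</w>"): 2,
    ("n", "e", "w", "e", "s", "t", "</w>"): 6,
    ("w", "i", "d", "e", "s", "t", "</w>"): 3,
}


def count_pairs(corpus: dict[tuple[str, ...], int]) -> Counter:
    """Count adjacent symbol pairs, weighting each occurrence by word frequency."""
    counts: Counter = Counter()
    for symbols, freq in corpus.items():
        # zip(symbols, symbols[1:]) yields each adjacent (left, right) pair.
        for left, right in zip(symbols, symbols[1:]):
            counts[(left, right)] += freq
    return counts


pair_counts = count_pairs(corpus)
for pair, count in pair_counts.most_common(5):
    print(pair, count)
```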

### Step 3: Merging

The most frequent pairs are `e s`, `s t`, and `t </w>`, tied at 9 occurrences each. Any consistent tie-breaking rule works; let's pick `e` and `s` and merge them into `es`.

*   `l o w </w>`: 5
*   `l o w e r </w>`: 2
*   `n e w es t </w>`: 6
*   `w i d es t </w>`: 3

**New Token**: `es`
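
Merging can be sketched as rewriting every word's symbol sequence. `merge_pair` is a hypothetical helper that scans each word left to right and replaces non-overlapping occurrences of the chosen pair with the concatenated symbol.

```python
# Example corpus from Step 1: symbol tuples mapped to word frequencies.
corpus = {
    ("l", "o", "w", "</w>"): 5,
    ("l", "o", "w", "e", "r", "</w>"): 2,
    ("n", "e", "w", "e", "s", "t", "</w>"): 6,
    ("w", "i", "d", "e", "s", "t", "</w>"): 3,
}


def merge_pair(
    corpus: dict[tuple[str, ...], int], pair: tuple[str, str]
) -> dict[tuple[str, ...], int]:
    """Replace non-overlapping occurrences of `pair`, scanning left to right."""
    merged_symbol = pair[0] + pair[1]
    new_corpus: dict[tuple[str, ...], int] = {}
    for symbols, freq in corpus.items():
        out: list[str] = []
        i = 0
        while i < len(symbols):
            # Merge only when the current and next symbols match the pair exactly.
            if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
                out.append(merged_symbol)
                i += 2
            else:
                out.append(symbols[i])
                i += 1
        key = tuple(out)
        new_corpus[key] = new_corpus.get(key, 0) + freq
    return new_corpus


merged = merge_pair(corpus, ("e", "s"))
for symbols, freq in merged.items():
    print(" ".join(symbols), freq)
```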

### Step 4: Iteration

Now we count pairs again on the updated corpus. The pair `es t` appears 9 times (6 in newest + 3 in widest), tied with `t </w>`, which also appears 9 times. We choose `es` and `t`.

Merge `es` and `t` -> `est`.

*   `l o w </w>`: 5
*   `l o w e r </w>`: 2
*   `n e w est </w>`: 6
*   `w i d est </w>`: 3

**New Token**: `est`

We continue this process until we reach a desired vocabulary size.
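
Putting the steps together, a minimal training loop over the example corpus might look like this. It is a sketch under stated assumptions: ties go to the first maximal pair in insertion order, and the loop stops at a target vocabulary size or when no pairs remain. `train_bpe` and its helpers are hypothetical names.

```python
from collections import Counter

Corpus = dict[tuple[str, ...], int]  # symbol sequence -> word frequency


def count_pairs(corpus: Corpus) -> Counter:
    """Count adjacent symbol pairs, weighted by word frequency."""
    counts: Counter = Counter()
    for symbols, freq in corpus.items():
        for pair in zip(symbols, symbols[1:]):
            counts[pair] += freq
    return counts


def merge_pair(corpus: Corpus, pair: tuple[str, str]) -> Corpus:
    """Replace non-overlapping occurrences of `pair`, scanning left to right."""
    new_corpus: Corpus = {}
    for symbols, freq in corpus.items():
        out: list[str] = []
        i = 0
        while i < len(symbols):
            if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == pair:
                out.append(pair[0] + pair[1])
                i += 2
            else:
                out.append(symbols[i])
                i += 1
        new_corpus[tuple(out)] = new_corpus.get(tuple(out), 0) + freq
    return new_corpus


def train_bpe(word_freqs: dict[str, int], target_vocab_size: int):
    """Merge the most frequent pair until the vocabulary reaches the target size."""
    corpus: Corpus = {tuple(w) + ("</w>",): f for w, f in word_freqs.items()}
    vocab = {s for symbols in corpus for s in symbols}
    merges: list[tuple[str, str]] = []
    while len(vocab) < target_vocab_size:
        counts = count_pairs(corpus)
        if not counts:
            break  # every word is a single symbol; nothing left to merge
        # max() keeps the first maximal key in insertion order, so ties go to
        # the pair that was encountered first in the corpus.
        best = max(counts, key=counts.get)
        corpus = merge_pair(corpus, best)
        vocab.add(best[0] + best[1])
        merges.append(best)
    return merges, vocab, corpus


word_freqs = {"low": 5, "lower": 2, "newest": 6, "widest": 3}
merges, vocab, corpus = train_bpe(word_freqs, target_vocab_size=15)
print(merges)
for symbols, freq in corpus.items():
    print(" ".join(symbols), freq)
```

With this tie-breaking rule the first two merges are `e s` and `es t`, matching the walkthrough. Raising `target_vocab_size` simply runs more iterations.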

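# [1] 🗺 Implement BPE

Unlike the walkthrough above, this implementation works directly on the raw text: spaces and newlines are treated as ordinary characters, no `</w>` marker is added, and pairs can merge across word boundaries (for example, the token `'s '` in the output below).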

In [1]:
import numpy as np

In [2]:
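def count_pair_freq(text: str | list[str]) -> dict[str, int]:
    """Count adjacent pairs in `text`.

    `text` is a string before the first merge and a list of tokens afterwards.
    Each pair is keyed by the concatenation of its two tokens.
    """
    pair_counts = {}

    for i in range(len(text) - 1):
        pair = text[i] + text[i + 1]
        pair_counts[pair] = pair_counts.get(pair, 0) + 1

    return pair_counts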

In [3]:
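def select_most_frequent_pair(pair_counts: dict[str, int]) -> str:
    # np.argmax returns the first index of the maximum count, and dicts keep
    # insertion order, so ties go to the pair that appeared first in the text.
    max_idx = np.argmax(list(pair_counts.values()))
    return list(pair_counts.keys())[max_idx]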
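
`select_most_frequent_pair` breaks ties by position: `np.argmax` returns the index of the first maximum. The sketch below checks that the built-in `max(..., key=...)` follows the same first-maximum rule and shows one order-independent alternative (highest count, then lexicographically smallest pair). That alternative is an illustrative choice, not a standard, and the counts are made up for the example.

```python
import numpy as np

# Hypothetical counts with a tie between "s " and "or".
pair_counts = {"e ": 3, "s ": 5, "or": 5}

# The notebook's rule: np.argmax picks the first maximum in insertion order.
keys = list(pair_counts.keys())
via_argmax = keys[np.argmax(list(pair_counts.values()))]

# Same first-maximum rule without NumPy: max() returns the first maximal key.
via_max = max(pair_counts, key=pair_counts.get)

# Order-independent alternative: highest count, then lexicographically smallest pair.
via_lexicographic = min(pair_counts, key=lambda p: (-pair_counts[p], p))

print(repr(via_argmax), repr(via_max), repr(via_lexicographic))
```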

In [4]:
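def update_vocabulary(vocabulary: dict[str, int], pair: str) -> dict[str, int]:
    # Give the merged token the next unused id (updates the dict in place).
    # Tokens already present keep their id instead of being overwritten.
    if pair not in vocabulary:
        vocabulary[pair] = max(vocabulary.values()) + 1
    return vocabulary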

In [5]:
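def generate_merged_text(text: str | list[str], pair: str) -> list[str]:
    # Scan left to right, replacing each non-overlapping occurrence of `pair`
    # with a single merged token. Always returns a list of tokens.
    merged_text = []

    i = 0
    while i < (len(text) - 1):
        curr_pair = text[i] + text[i + 1]

        if curr_pair == pair:
            merged_text.append(curr_pair)
            i += 2
        else:
            merged_text.append(text[i])
            i += 1

    # Keep the final token if the loop stopped before consuming it.
    if i == (len(text) - 1):
        merged_text.append(text[i])

    return merged_text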

In [6]:
def bpe(text: str, target_vocab_size: int = 30) -> dict[str, int]:
    # Initialize the vocabulary with every distinct character in the text.
    chars = sorted(set(text))
    vocabulary = {c: i for i, c in enumerate(chars)}

    # Merge the most frequent pair until the target vocabulary size is reached.
    updated_text = text
    while len(vocabulary) < target_vocab_size:
        pair_counts = count_pair_freq(updated_text)
        if not pair_counts:
            break  # the text is a single token; nothing left to merge
        max_pair = select_most_frequent_pair(pair_counts)
        vocabulary = update_vocabulary(vocabulary, max_pair)
        updated_text = generate_merged_text(updated_text, max_pair)
        print(f"*** vocab_size={len(vocabulary)}:")
        print(f"\t {vocabulary}")

    return vocabulary
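
`bpe` returns only the vocabulary, not the ordered list of merged pairs needed to tokenize new text the same way. A minimal sketch, assuming the same raw-character setup (spaces are ordinary symbols) but a fixed number of merges, records each merged pair and replays the merges in training order to encode new input. `train_merges`, `encode`, and `merge_tokens` are hypothetical names.

```python
from collections import Counter


def merge_tokens(tokens: list[str], pair: tuple[str, str]) -> list[str]:
    """Merge non-overlapping occurrences of `pair`, scanning left to right."""
    merged: list[str] = []
    i = 0
    while i < len(tokens):
        if i + 1 < len(tokens) and (tokens[i], tokens[i + 1]) == pair:
            merged.append(pair[0] + pair[1])
            i += 2
        else:
            merged.append(tokens[i])
            i += 1
    return merged


def train_merges(text: str, num_merges: int) -> tuple[list[tuple[str, str]], list[str]]:
    """Character-level BPE on raw text, recording merged pairs in order."""
    tokens = list(text)
    merges: list[tuple[str, str]] = []
    for _ in range(num_merges):
        counts = Counter(zip(tokens, tokens[1:]))
        if not counts:
            break  # a single token is left; nothing to merge
        best = max(counts, key=counts.get)  # ties: first pair seen in the text
        tokens = merge_tokens(tokens, best)
        merges.append(best)
    return merges, tokens


def encode(text: str, merges: list[tuple[str, str]]) -> list[str]:
    """Tokenize text by replaying the learned merges in training order."""
    tokens = list(text)
    for pair in merges:
        tokens = merge_tokens(tokens, pair)
    return tokens


training_text = "low lower lowest newest widest"
merges, trained_tokens = train_merges(training_text, num_merges=8)
print(merges)
print(encode("slowest", merges))
```

Replaying the merges on the training text reproduces the training segmentation exactly, since each merge is a deterministic function of the current token list.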

# [2] 🧪 Test BPE

In [7]:
text = """
This is a much more complex example text for byte pair encoding demonstration.
We will observe how it tokenizes common words and subword units effectively.
"""
print(f"Original text: {text}")

Original text: 
This is a much more complex example text for byte pair encoding demonstration. 
We will observe how it tokenizes common words and subword units effectively.



In [8]:
final_vocabulary = bpe(text, target_vocab_size=40)

*** vocab_size=30:
	 {'\n': 0, ' ': 1, '.': 2, 'T': 3, 'W': 4, 'a': 5, 'b': 6, 'c': 7, 'd': 8, 'e': 9, 'f': 10, 'g': 11, 'h': 12, 'i': 13, 'k': 14, 'l': 15, 'm': 16, 'n': 17, 'o': 18, 'p': 19, 'r': 20, 's': 21, 't': 22, 'u': 23, 'v': 24, 'w': 25, 'x': 26, 'y': 27, 'z': 28, 's ': 29}
*** vocab_size=31:
	 {'\n': 0, ' ': 1, '.': 2, 'T': 3, 'W': 4, 'a': 5, 'b': 6, 'c': 7, 'd': 8, 'e': 9, 'f': 10, 'g': 11, 'h': 12, 'i': 13, 'k': 14, 'l': 15, 'm': 16, 'n': 17, 'o': 18, 'p': 19, 'r': 20, 's': 21, 't': 22, 'u': 23, 'v': 24, 'w': 25, 'x': 26, 'y': 27, 'z': 28, 's ': 29, 'e ': 30}
*** vocab_size=32:
	 {'\n': 0, ' ': 1, '.': 2, 'T': 3, 'W': 4, 'a': 5, 'b': 6, 'c': 7, 'd': 8, 'e': 9, 'f': 10, 'g': 11, 'h': 12, 'i': 13, 'k': 14, 'l': 15, 'm': 16, 'n': 17, 'o': 18, 'p': 19, 'r': 20, 's': 21, 't': 22, 'u': 23, 'v': 24, 'w': 25, 'x': 26, 'y': 27, 'z': 28, 's ': 29, 'e ': 30, 'or': 31}
*** vocab_size=33:
	 {'\n': 0, ' ': 1, '.': 2, 'T': 3, 'W': 4, 'a': 5, 'b': 6, 'c': 7, 'd': 8, 'e': 9, 'f': 10, 'g': 1