In [1]:
from collections import Counter, deque

### Let's replace the spaces between words with **Ġ**

In [2]:
def preprocess(text):
    """
    Replace spaces with a special marker 'Ġ' (except at the start) and return the processed string.
    """
    processed = []
    for i, c in enumerate(text):
        if c == ' ' and i != 0:
            processed.append('Ġ')
        if c != ' ':
            processed.append(c)
    return ''.join(processed)
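As a sanity check, the loop in `preprocess()` can be compared against a shorter formulation built on `str.replace`. `preprocess_compact` is a hypothetical name used only for this comparison. Like the original, it assumes that only a space at index 0 is dropped.

```python
def preprocess(text):
    """Original loop: replace spaces with 'Ġ', dropping a space at index 0."""
    processed = []
    for i, c in enumerate(text):
        if c == ' ' and i != 0:
            processed.append('Ġ')
        if c != ' ':
            processed.append(c)
    return ''.join(processed)


def preprocess_compact(text):
    """Hypothetical one-liner: drop one leading space, then replace the rest."""
    if text.startswith(' '):
        text = text[1:]
    return text.replace(' ', 'Ġ')


examples = ['banana banana', ' banana', '  two leading', 'a  b', 'trailing ', '']
for s in examples:
    # Both versions should agree on every input
    assert preprocess(s) == preprocess_compact(s)
    print(repr(s), '->', repr(preprocess(s)))
```

Only the first character is special-cased. A second leading space still becomes `Ġ`, so `'  two leading'` maps to `'ĠtwoĠleading'`.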

### Building the vocabulary

In [3]:
def build_initial_vocab(processed_text, allowed_special):
    """
    Build the initial character-level vocabulary plus any allowed special tokens.
    Returns vocab (id->token) and inverse_vocab (token->id).
    """
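    # Start with the 256 characters chr(0)..chr(255), one per byte value
    # (this is a character-level BPE, so these are characters, not raw bytes)
    unique_chars = [chr(i) for i in range(256)]
    # Append characters from the text that are not already present (e.g. 'Ġ', which is chr(288))
    unique_chars.extend(ch for ch in sorted(set(processed_text)) if ch not in unique_chars)
    # Always reserve an ID for the space marker, even if the text had no spaces
    if 'Ġ' not in unique_chars:
        unique_chars.append('Ġ')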

    vocab = {i: ch for i, ch in enumerate(unique_chars)}
    inverse_vocab = {ch: i for i, ch in vocab.items()}

    # Add special tokens
    for tok in allowed_special:
        if tok not in inverse_vocab:
            nid = len(vocab)
            vocab[nid] = tok
            inverse_vocab[tok] = nid

    return vocab, inverse_vocab
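To see which IDs this function assigns, here is a short check that repeats `build_initial_vocab()` and inspects the result. The second input, `'a€b'`, is a made-up example containing a character above `chr(255)`.

```python
def build_initial_vocab(processed_text, allowed_special):
    # Same logic as the notebook cell above
    unique_chars = [chr(i) for i in range(256)]
    unique_chars.extend(ch for ch in sorted(set(processed_text)) if ch not in unique_chars)
    if 'Ġ' not in unique_chars:
        unique_chars.append('Ġ')
    vocab = {i: ch for i, ch in enumerate(unique_chars)}
    inverse_vocab = {ch: i for i, ch in vocab.items()}
    for tok in allowed_special:
        if tok not in inverse_vocab:
            nid = len(vocab)
            vocab[nid] = tok
            inverse_vocab[tok] = nid
    return vocab, inverse_vocab


vocab, inv = build_initial_vocab('bananaĠbanana', ['<|endoftext|>'])
# ASCII letters keep their code points; 'Ġ' (U+0120) is the first extra character
print(inv['a'], inv['b'], inv['n'], inv['Ġ'], inv['<|endoftext|>'], len(vocab))

vocab2, inv2 = build_initial_vocab('a€b', ['<|endoftext|>'])
# '€' is above chr(255), so it is appended first; 'Ġ' is still reserved afterwards
print(inv2['€'], inv2['Ġ'], inv2['<|endoftext|>'], len(vocab2))
```

For the `banana banana` sample, IDs 0 to 257 are taken before training starts, so the first merge it learns gets ID 258.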

In [4]:
def replace_pair(token_ids, pair_id, new_id):
    """
    Given a list of token IDs, merge occurrences of pair_id into new_id.
    """
    dq = deque(token_ids)
    out = []
    while dq:
        cur = dq.popleft()
        if dq and (cur, dq[0]) == pair_id:
            out.append(new_id)
            dq.popleft()
        else:
            out.append(cur)
    return out

## How `replace_pair()` Works: A Short Example

Below is a step-by-step illustration of how `replace_pair()` takes a sequence of token IDs, looks for a specific adjacent pair, and merges every occurrence into a single new ID.

---

### 1. The Inputs

- **Token ID sequence**: `[98, 97, 110, 97, 110, 97]` (`'b' 'a' 'n' 'a' 'n' 'a'`)
- **Pair to merge**: `(97, 110)` (that is, `'a'` + `'n'`)
- **New token ID**: `258` (which we'll think of as the subword `'an'`)

---

### 2. Step by Step

Each row shows the deque just before `popleft()` removes `cur`.

| Step | Deque before `popleft()`     | `cur`, next | Action                                      | Output list          |
|:----:|:-----------------------------|:------------|:--------------------------------------------|:---------------------|
|  1   | `[98, 97, 110, 97, 110, 97]` | `98`, `97`  | `(98, 97) ≠ (97, 110)` → emit `98`          | `[98]`               |
|  2   | `[97, 110, 97, 110, 97]`     | `97`, `110` | matches `(97, 110)` → emit `258`, pop `110` | `[98, 258]`          |
|  3   | `[97, 110, 97]`              | `97`, `110` | matches `(97, 110)` → emit `258`, pop `110` | `[98, 258, 258]`     |
|  4   | `[97]`                       | `97`, none  | no next token to pair with → emit `97`      | `[98, 258, 258, 97]` |

---

### 3. The Result

```text
[98, 258, 258, 97]
'b', 'an', 'an', 'a'
```
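The trace can be reproduced by calling `replace_pair()` directly. The second call uses a made-up overlapping input, `[97, 97, 97]` with the pair `(97, 97)` and an arbitrary new ID `300`, to show that matching is greedy from left to right: once two tokens are merged, the second one cannot start another match.

```python
from collections import deque


def replace_pair(token_ids, pair_id, new_id):
    # Same deque-based merge as the notebook cell above
    dq = deque(token_ids)
    out = []
    while dq:
        cur = dq.popleft()
        if dq and (cur, dq[0]) == pair_id:
            out.append(new_id)   # merge cur with the next token
            dq.popleft()         # and consume that next token
        else:
            out.append(cur)
    return out


print(replace_pair([98, 97, 110, 97, 110, 97], (97, 110), 258))
print(replace_pair([97, 97, 97], (97, 97), 300))
```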

**deque:** a double-ended queue (`collections.deque`) that lets you append or pop items at either the front or the back in O(1) time.

**popleft():** a `deque` method that removes and returns the leftmost element in O(1) time; `list.pop(0)` does the same on a list but takes O(n), because every remaining element shifts left.
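A minimal illustration of the two `deque` operations that `replace_pair()` relies on: `popleft()` to take the current token and `dq[0]` to peek at the next one without removing it.

```python
from collections import deque

dq = deque([98, 97, 110])
cur = dq.popleft()     # removes and returns 98 from the left end
peek = dq[0]           # looks at 97 without removing it
print(cur, peek, list(dq))
```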


## Training the BPE

In [34]:
def train_bpe(text, vocab_size, allowed_special=('<|endoftext|>',)):
    """
    Train a BPE tokenizer on `text` until the vocabulary reaches `vocab_size`
    tokens, or until no adjacent pairs remain to merge.
    Returns:
      - vocab: dict mapping id -> token
      - inverse_vocab: dict mapping token -> id
      - bpe_merges: dict mapping (p0_id, p1_id) -> new_id
    """
    processed_text = preprocess(text)
    vocab, inverse_vocab = build_initial_vocab(processed_text, allowed_special)
    token_ids = [inverse_vocab[ch] for ch in processed_text]
    bpe_merges = {}

    for new_id in range(len(vocab), vocab_size):
        pairs = Counter(zip(token_ids, token_ids[1:]))
        if not pairs:
            break
        # Most frequent pair
        pair_id = max(pairs.items(), key=lambda x: x[1])[0]
        p0, p1 = pair_id
        # Record merge
        bpe_merges[pair_id] = new_id
        # Add new token
        merged = vocab[p0] + vocab[p1]
        vocab[new_id] = merged
        inverse_vocab[merged] = new_id
        # Apply merge to token sequence
        token_ids = replace_pair(token_ids, pair_id, new_id)

    return vocab, inverse_vocab, bpe_merges


## How `train_bpe()` Works: A Short Example

### 1. Inputs
- **Text**: `banana banana`  
- **Preprocess**: replace spaces with `Ġ` → `bananaĠbanana`  
- **Target `vocab_size`**: initial vocab + 2 merges  

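### 2. Initial Vocab (simplified, IDs 0–4)

To keep the numbers small, this walkthrough assumes a toy initial vocabulary: the characters of the text in order of first appearance, followed by the special token. The real `build_initial_vocab()` starts from the 256 characters `chr(0)` to `chr(255)`, which is why the actual run below gives `'Ġ'` ID 256, `<|endoftext|>` ID 257, and the first merge ID 258.

| ID | Token             |
|----|-------------------|
| 0  | `b`               |
| 1  | `a`               |
| 2  | `n`               |
| 3  | `Ġ`               |
| 4  | `<\|endoftext\|>` |

Initial token IDs: `[0, 1, 2, 1, 2, 1, 3, 0, 1, 2, 1, 2, 1]`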


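### 3. Merge Iterations

| Step | Most frequent pair | New ID | New token | Token IDs after merge       |
|:----:|:------------------:|:------:|:---------:|:----------------------------|
| 1    | (1, 2)             | 5      | `an`      | [0, 5, 5, 1, 3, 0, 5, 5, 1] |
| 2    | (0, 5)             | 6      | `ban`     | [6, 5, 1, 3, 6, 5, 1]       |

Both steps involve ties. In step 1, `(1, 2)` and `(2, 1)` each occur 4 times; in step 2, `(0, 5)`, `(5, 5)` and `(5, 1)` each occur twice. `max()` returns the first maximal item it sees, and `Counter` keeps pairs in first-seen order, so the pair that appears earliest in the sequence wins.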

### 4. Result
- **Final vocab (size 7):**  
  `{0:'b', 1:'a', 2:'n', 3:'Ġ', 4:'<|endoftext|>', 5:'an', 6:'ban'}`
- **Recorded merges:**  
  `(1,2) → 5`  
  `(0,5) → 6`
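The walkthrough can be checked with a stripped-down version of the training loop. The toy vocabulary (characters in order of first appearance, then the special token) is an assumption made only to reproduce IDs 0 to 6; it is not what `build_initial_vocab()` returns.

```python
from collections import Counter, deque


def replace_pair(token_ids, pair_id, new_id):
    # Same deque-based merge as replace_pair() in this notebook
    dq = deque(token_ids)
    out = []
    while dq:
        cur = dq.popleft()
        if dq and (cur, dq[0]) == pair_id:
            out.append(new_id)
            dq.popleft()
        else:
            out.append(cur)
    return out


processed = 'bananaĠbanana'

# Toy vocabulary (assumption): characters in order of first appearance, then the special token
vocab = {}
for ch in processed:
    if ch not in vocab.values():
        vocab[len(vocab)] = ch
vocab[len(vocab)] = '<|endoftext|>'
inverse_vocab = {tok: i for i, tok in vocab.items()}

token_ids = [inverse_vocab[ch] for ch in processed]
merges = {}
for new_id in range(len(vocab), len(vocab) + 2):   # initial vocab + 2 merges
    pairs = Counter(zip(token_ids, token_ids[1:]))
    pair = max(pairs.items(), key=lambda x: x[1])[0]   # ties go to the first pair seen
    merges[pair] = new_id
    vocab[new_id] = vocab[pair[0]] + vocab[pair[1]]
    token_ids = replace_pair(token_ids, pair, new_id)
    print(new_id, pair, repr(vocab[new_id]), token_ids)
```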


### Visualizing the merges as trees

In [6]:
def display_merge_trees(vocab, bpe_merges):
    """
    Print ASCII trees showing how each merged token is composed.
    """
    # Reverse map: new_id -> (p0, p1)
    rev = {new_id: pair for pair, new_id in bpe_merges.items()}

    def recurse(tid, prefix='', is_last=True):
        token = vocab[tid]
        connector = '└─ ' if is_last else '├─ '
        print(prefix + connector + repr(token))
        if tid in rev:
            p0, p1 = rev[tid]
            new_prefix = prefix + ('   ' if is_last else '│  ')
            recurse(p0, new_prefix, is_last=False)
            recurse(p1, new_prefix, is_last=True)

    print('\nBPE Merge Trees:')
    for (p0, p1), mid in sorted(bpe_merges.items(), key=lambda x: x[1]):
        print(f"\nMerge ID {mid}: ({vocab[p0]!r}, {vocab[p1]!r}) → {mid!r} '{vocab[mid]}'")
        recurse(mid, prefix='', is_last=True)
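The trees printed by `display_merge_trees()` can also be flattened. `expand` below is a hypothetical helper, not part of the notebook, that follows the same reverse map (`new_id -> (p0, p1)`) recursively and returns the base token IDs at the leaves. The merges are copied from the `merges` output shown later in this notebook, and the base vocabulary lists only the four characters that run uses.

```python
def expand(token_id, rev):
    """Hypothetical helper: follow merges back down to base token IDs."""
    if token_id not in rev:          # a base character or special token
        return [token_id]
    left, right = rev[token_id]
    return expand(left, rev) + expand(right, rev)


# Merges copied from the `merges` output in this notebook
merges = {(97, 110): 258, (98, 258): 259, (259, 258): 260,
          (260, 97): 261, (261, 256): 262, (262, 261): 263}
vocab = {97: 'a', 98: 'b', 110: 'n', 256: 'Ġ'}   # base tokens used here
for (p0, p1), mid in sorted(merges.items(), key=lambda kv: kv[1]):
    vocab[mid] = vocab[p0] + vocab[p1]          # as train_bpe() builds them

rev = {mid: pair for pair, mid in merges.items()}  # same reverse map as above
for mid in sorted(rev):
    leaves = expand(mid, rev)
    # The leaves spell out the merged token, left to right
    assert ''.join(vocab[t] for t in leaves) == vocab[mid]
    print(mid, repr(vocab[mid]), leaves)
```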

### Encoding and Decoding

In [7]:
def encode(text, inverse_vocab, bpe_merges):
    """
    Encode `text` to a list of token IDs using the trained BPE merges.
    """
    processed_text = preprocess(text)
    token_ids = [inverse_vocab[ch] for ch in processed_text]
    # Apply merges in order of creation
    for pair, new_id in sorted(bpe_merges.items(), key=lambda x: x[1]):
        token_ids = replace_pair(token_ids, pair, new_id)
    return token_ids
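Why the merges are sorted by ID before being applied: a later merge can only fire once the earlier merges that build its parts have run. This sketch uses the merges printed later in this notebook and the base IDs for `'banana'` (`ord()` of each character, which is what `build_initial_vocab()` assigns to characters up to `chr(255)`); `apply_merges` is a hypothetical helper name.

```python
from collections import deque


def replace_pair(token_ids, pair_id, new_id):
    # Same deque-based merge as replace_pair() in this notebook
    dq = deque(token_ids)
    out = []
    while dq:
        cur = dq.popleft()
        if dq and (cur, dq[0]) == pair_id:
            out.append(new_id)
            dq.popleft()
        else:
            out.append(cur)
    return out


def apply_merges(token_ids, ordered_merges):
    # Apply each (pair, new_id) once, in the order given
    for pair, new_id in ordered_merges:
        token_ids = replace_pair(token_ids, pair, new_id)
    return token_ids


merges = {(97, 110): 258, (98, 258): 259, (259, 258): 260,
          (260, 97): 261, (261, 256): 262, (262, 261): 263}
banana = [ord(ch) for ch in 'banana']  # [98, 97, 110, 97, 110, 97]

in_order = sorted(merges.items(), key=lambda kv: kv[1])
print(apply_merges(banana, in_order))                   # creation order
print(apply_merges(banana, list(reversed(in_order))))   # reversed order
```

In creation order, `'banana'` collapses to the single token 261; in reverse order only the `('a', 'n')` merge ever finds a match. Here the dict is already in creation order because merges are inserted as they are learned, but sorting by ID keeps `encode()` correct even if the mapping is rebuilt in a different order.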

In [31]:
def decode(token_ids, vocab):
    """
    Decode a list of token IDs back to a string.
    """
    text = ''.join(vocab[t] for t in token_ids)
    # Restore spaces
    return text.replace('Ġ', ' ')
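Because every merged token's string is the concatenation of its two parts, `decode(encode(text))` should equal `preprocess(text)` with `Ġ` turned back into spaces, whatever merges were learned (assuming every character is in the vocabulary). The sketch below checks that property without training; `round_trip` is a hypothetical name for that composition. It exposes two edge cases: a leading space is lost, and a literal `Ġ` in the input comes back as a space.

```python
def preprocess(text):
    # Same as preprocess() in this notebook
    processed = []
    for i, c in enumerate(text):
        if c == ' ' and i != 0:
            processed.append('Ġ')
        if c != ' ':
            processed.append(c)
    return ''.join(processed)


def round_trip(text):
    # What decode(encode(text)) returns: merges only regroup characters,
    # so the joined tokens equal preprocess(text) before 'Ġ' is restored
    return preprocess(text).replace('Ġ', ' ')


for s in ['banana banana', ' banana', 'aĠb']:
    print(repr(s), '->', repr(round_trip(s)))
```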

### Result

In [32]:
sample = 'banana banana'
vocab_size = 500  # upper bound; training stops early once no pairs are left to merge
vocab, inv_vocab, merges = train_bpe(sample, vocab_size)
print('Vocabulary size:', len(vocab))
display_merge_trees(vocab, merges)

Vocabulary size: 264

BPE Merge Trees:

Merge ID 258: ('a', 'n') → 258 'an'
└─ 'an'
   ├─ 'a'
   └─ 'n'

Merge ID 259: ('b', 'an') → 259 'ban'
└─ 'ban'
   ├─ 'b'
   └─ 'an'
      ├─ 'a'
      └─ 'n'

Merge ID 260: ('ban', 'an') → 260 'banan'
└─ 'banan'
   ├─ 'ban'
   │  ├─ 'b'
   │  └─ 'an'
   │     ├─ 'a'
   │     └─ 'n'
   └─ 'an'
      ├─ 'a'
      └─ 'n'

Merge ID 261: ('banan', 'a') → 261 'banana'
└─ 'banana'
   ├─ 'banan'
   │  ├─ 'ban'
   │  │  ├─ 'b'
   │  │  └─ 'an'
   │  │     ├─ 'a'
   │  │     └─ 'n'
   │  └─ 'an'
   │     ├─ 'a'
   │     └─ 'n'
   └─ 'a'

Merge ID 262: ('banana', 'Ġ') → 262 'bananaĠ'
└─ 'bananaĠ'
   ├─ 'banana'
   │  ├─ 'banan'
   │  │  ├─ 'ban'
   │  │  │  ├─ 'b'
   │  │  │  └─ 'an'
   │  │  │     ├─ 'a'
   │  │  │     └─ 'n'
   │  │  └─ 'an'
   │  │     ├─ 'a'
   │  │     └─ 'n'
   │  └─ 'a'
   └─ 'Ġ'

Merge ID 263: ('bananaĠ', 'banana') → 263 'bananaĠbanana'
└─ 'bananaĠbanana'
   ├─ 'bananaĠ'
   │  ├─ 'banana'
   │  │  ├─ 'banan'
   │  │  │  ├─ 'ban'
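   │  │  │  │  ├─ 'b'
   │  │  │  │  └─ 'an'
   │  │  │  │     ├─ 'a'
   │  │  │  │     └─ 'n'
   │  │  │  └─ 'an'
   │  │  │     ├─ 'a'
   │  │  │     └─ 'n'
   │  │  └─ 'a'
   │  └─ 'Ġ'
   └─ 'banana'
      ├─ 'banan'
      │  ├─ 'ban'
      │  │  ├─ 'b'
      │  │  └─ 'an'
      │  │     ├─ 'a'
      │  │     └─ 'n'
      │  └─ 'an'
      │     ├─ 'a'
      │     └─ 'n'
      └─ 'a'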

## The Merges

In [33]:
merges

{(97, 110): 258,
 (98, 258): 259,
 (259, 258): 260,
 (260, 97): 261,
 (261, 256): 262,
 (262, 261): 263}

In [37]:
for (p0, p1), mid in merges.items():
    left  = decode([p0], vocab)
    right = decode([p1], vocab)
    merged = decode([mid], vocab)
    print(f"({left!r}, {right!r}) -> {merged!r}")

('a', 'n') -> 'an'
('b', 'an') -> 'ban'
('ban', 'an') -> 'banan'
('banan', 'a') -> 'banana'
('banana', ' ') -> 'banana '
('banana ', 'banana') -> 'banana banana'


### Encoding and Decoding Results

In [38]:
encode("banana", inv_vocab, merges)

[261]

In [39]:
decode([261], vocab)

'banana'
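A word that never appeared in training, such as `bandana` (an arbitrary example), should still encode: merges apply wherever their pairs occur, and anything left over stays as smaller tokens. This sketch reuses the learned merges printed above and the `ord()`-based base IDs; `encode_ids` is a hypothetical helper that mirrors the merge loop in `encode()` without the vocabulary lookup. Since `'d'` is among `chr(0)` to `chr(255)`, the real `encode()` would not raise a KeyError on this word either.

```python
from collections import deque


def replace_pair(token_ids, pair_id, new_id):
    # Same deque-based merge as replace_pair() in this notebook
    dq = deque(token_ids)
    out = []
    while dq:
        cur = dq.popleft()
        if dq and (cur, dq[0]) == pair_id:
            out.append(new_id)
            dq.popleft()
        else:
            out.append(cur)
    return out


def encode_ids(token_ids, merges):
    # Mirror of the merge loop in encode(): apply merges in creation order
    for pair, new_id in sorted(merges.items(), key=lambda kv: kv[1]):
        token_ids = replace_pair(token_ids, pair, new_id)
    return token_ids


merges = {(97, 110): 258, (98, 258): 259, (259, 258): 260,
          (260, 97): 261, (261, 256): 262, (262, 261): 263}
# Strings for the IDs involved (base characters plus the learned merges)
vocab = {97: 'a', 98: 'b', 100: 'd', 110: 'n', 256: 'Ġ',
         258: 'an', 259: 'ban', 260: 'banan', 261: 'banana'}

ids = encode_ids([ord(ch) for ch in 'bandana'], merges)
print(ids, [vocab[t] for t in ids])
```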

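This walks through a minimal, character-level BPE tokenizer end to end: `train_bpe()` learns merges from the text, `encode()` applies them in the order they were created, and `decode()` maps the IDs back to text.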