<a href="https://colab.research.google.com/github/el-eshaano/ml/blob/main/BPE_Tokeniser.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2024-02-28 11:52:19--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt.1’


2024-02-28 11:52:20 (15.8 MB/s) - ‘input.txt.1’ saved [1115394/1115394]



In [None]:
from typing import Dict, Tuple
from tqdm.notebook import tqdm

In [None]:
# Read the whole corpus into memory as a single string.
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# Character-level base vocabulary: every distinct character, in sorted order.
unique_chars = sorted(set(text))
vocab_size = len(unique_chars)

# Two-way lookup tables between characters and their integer ids.
itos = dict(enumerate(unique_chars))
stoi = {ch: i for i, ch in itos.items()}

# The corpus re-expressed as a list of character ids.
tokens = [stoi[ch] for ch in text]

In [None]:
def get_pair_counts(tokens: list[int]) -> Dict[Tuple[int, int], int]:
    """Count how often each pair of neighbouring tokens occurs.

    Args:
        tokens: sequence of token ids.

    Returns:
        A dict mapping every (left, right) pair of adjacent tokens to its
        number of occurrences. Keys are inserted in order of first occurrence.
        The dict is empty when `tokens` has fewer than two elements.
    """
    counts = {}
    for pos in range(len(tokens) - 1):
        key = (tokens[pos], tokens[pos + 1])
        if key in counts:
            counts[key] += 1
        else:
            counts[key] = 1
    return counts

def replace_pair(arr, pair, replacement):
    """Return a new list where each occurrence of `pair` in `arr` becomes `replacement`.

    The scan runs left to right and a matched pair is consumed whole, so
    overlapping occurrences are not merged twice (e.g. three equal items
    yield one replacement followed by the leftover item). `arr` itself is
    not modified.

    Args:
        arr: sequence of tokens to scan.
        pair: two-item sequence (left, right) to look for.
        replacement: token emitted in place of each matched pair.

    Returns:
        The rewritten token list.
    """
    out = []
    n = len(arr)
    pos = 0
    while pos < n:
        # A match needs a right-hand neighbour, hence the pos + 1 bound check first.
        hit = pos + 1 < n and arr[pos] == pair[0] and arr[pos + 1] == pair[1]
        out.append(replacement if hit else arr[pos])
        pos += 2 if hit else 1
    return out

In [None]:
# Target size of the final vocabulary: the base characters plus one new token per merge.
desired_vocab_size = 276
# Size of the character-level vocabulary computed from the corpus.
current_vocab_size = vocab_size
# ---

# Each merge adds exactly one token, so this is the number of merges to attempt.
num_merges = desired_vocab_size - current_vocab_size
# Work on a copy so `tokens` keeps the original character-level encoding.
ids = list(tokens)

# (left id, right id) -> merged token id, inserted in the order the merges are learned.
merges = {}
for i in tqdm(range(num_merges), desc="Merges: ", leave=False):
    pair_counts = get_pair_counts(ids)
    # Most frequent adjacent pair; on ties max() keeps the first one it encounters.
    max_occuring_pair = max(pair_counts, key=pair_counts.get)

    # Stop early once no pair repeats: merging it would not shorten anything further.
    if pair_counts[max_occuring_pair] == 1:
        print("Max pair only occurs once, breaking from loop")
        break

    # New ids continue directly after the base character ids.
    new_token_code = current_vocab_size + i
    ids = replace_pair(ids, max_occuring_pair, new_token_code)
    merges[max_occuring_pair] = new_token_code



Merges:   0%|          | 0/211 [00:00<?, ?it/s]

In [None]:
# Start from the character table, then add the string spelled by every merged token.
vocab = itos.copy()
# Iterating `merges` in insertion order (guaranteed for dicts since Python 3.7) means
# both halves of a merged token are already in `vocab` when the token itself is added.
for (p0, p1), idx in merges.items():
    vocab[idx] = vocab[p0] + vocab[p1]

print(len(vocab))
print(vocab)

276
{0: '\n', 1: ' ', 2: '!', 3: '$', 4: '&', 5: "'", 6: ',', 7: '-', 8: '.', 9: '3', 10: ':', 11: ';', 12: '?', 13: 'A', 14: 'B', 15: 'C', 16: 'D', 17: 'E', 18: 'F', 19: 'G', 20: 'H', 21: 'I', 22: 'J', 23: 'K', 24: 'L', 25: 'M', 26: 'N', 27: 'O', 28: 'P', 29: 'Q', 30: 'R', 31: 'S', 32: 'T', 33: 'U', 34: 'V', 35: 'W', 36: 'X', 37: 'Y', 38: 'Z', 39: 'a', 40: 'b', 41: 'c', 42: 'd', 43: 'e', 44: 'f', 45: 'g', 46: 'h', 47: 'i', 48: 'j', 49: 'k', 50: 'l', 51: 'm', 52: 'n', 53: 'o', 54: 'p', 55: 'q', 56: 'r', 57: 's', 58: 't', 59: 'u', 60: 'v', 61: 'w', 62: 'x', 63: 'y', 64: 'z', 65: 'e ', 66: 'th', 67: 't ', 68: 's ', 69: 'd ', 70: ', ', 71: 'ou', 72: 'er', 73: 'in', 74: 'y ', 75: 'an', 76: ':\n', 77: 'or', 78: 'o ', 79: 'en', 80: '\n\n', 81: 'ar', 82: ' th', 83: 'on', 84: 'll', 85: 'ha', 86: ',\n', 87: '.\n\n', 88: 'is ', 89: 'es', 90: 'you', 91: ' s', 92: 'to ', 93: 'and ', 94: 'ow', 95: 'ea', 96: ' m', 97: ' w', 98: 'of', 99: ' h', 100: 'ing', 101: 'om', 102: ' a', 103: 'ch', 104: 'the '

In [None]:
def decode(ids):
    """Turn a list of token ids back into text.

    Every id is looked up in the module-level `vocab` table and the resulting
    strings are concatenated in order. An id missing from `vocab` raises
    KeyError.
    """
    pieces = [vocab[token_id] for token_id in ids]
    return "".join(pieces)

In [None]:
def encode(string):
    """Tokenise `string` using the module-level `stoi` and `merges` tables.

    The text is first mapped character by character through `stoi` (a
    character missing from `stoi` raises KeyError). Then, repeatedly, the
    adjacent pair with the smallest merged id in `merges` is replaced by that
    id, until no adjacent pair is present in `merges` or fewer than two
    tokens remain.

    Returns:
        The list of token ids.
    """
    toks = [stoi[ch] for ch in string]
    not_mergeable = float("inf")  # rank given to pairs that have no merge rule

    # With fewer than two tokens there are no pairs, and min() would fail on an empty dict.
    while len(toks) > 1:
        candidates = get_pair_counts(toks)
        best = min(candidates, key=lambda p: merges.get(p, not_mergeable))
        if best not in merges:
            # Even the best-ranked pair has no rule, so nothing else can merge.
            break
        toks = replace_pair(toks, best, merges[best])

    return toks

In [None]:
class BPETokeniser:
    """Character-level byte-pair-encoding tokeniser.

    `train` builds a base vocabulary from the distinct characters of a text
    and then repeatedly merges the most frequent adjacent token pair into a
    new token. `encode` applies the learned merges to new text and `decode`
    maps token ids back to a string.

    All state lives on the instance; the class relies only on the helper
    functions `get_pair_counts` and `replace_pair` and on `tqdm`.
    """

    def __init__(self):
        self.stoi = {}    # character -> base token id
        self.vocab = {}   # token id -> the string that id decodes to
        self.merges = {}  # (left id, right id) -> merged token id, in learned order

    def decode(self, ids) -> str:
        """Given ids (list of integers), return the Python string they spell.

        Raises:
            KeyError: if an id is not present in the trained vocabulary.
        """
        return "".join(self.vocab[idx] for idx in ids)

    def encode(self, string: str) -> list[int]:
        """Tokenise `string` with the merges learned by `train`.

        The text is mapped character by character to base ids, then the
        adjacent pair with the smallest merged id (i.e. the earliest learned
        merge) is replaced repeatedly until no adjacent pair has a merge rule.

        Raises:
            KeyError: if `string` contains a character not seen during training.
        """
        toks = [self.stoi[x] for x in string]
        while len(toks) >= 2:  # with fewer than two tokens there is no pair to merge
            pair_counts = get_pair_counts(toks)
            pair = min(pair_counts, key=lambda p: self.merges.get(p, float("inf")))
            if pair not in self.merges:
                break  # no more possible merges

            toks = replace_pair(toks, pair, self.merges[pair])
        return toks

    def train(self, text: str, desired_vocab_size: int) -> None:
        """Learn a vocabulary of up to `desired_vocab_size` tokens from `text`.

        Any previously learned state is discarded. Training stops early when
        fewer than two tokens remain or when the most frequent pair occurs
        only once.

        Args:
            text: training corpus.
            desired_vocab_size: target number of tokens (base characters plus
                merges). If it does not exceed the number of distinct
                characters, no merges are learned.
        """
        unique_chars = sorted(set(text))
        # Size of the character-level base vocabulary; merged ids start right after it.
        self.current_vocab_size = len(unique_chars)

        self.stoi = {ch: i for i, ch in enumerate(unique_chars)}
        self.vocab = {i: ch for i, ch in enumerate(unique_chars)}
        # Reset so that retraining never mixes in merges from an earlier corpus.
        self.merges = {}

        ids = [self.stoi[ch] for ch in text]
        num_merges = desired_vocab_size - self.current_vocab_size

        for i in tqdm(range(num_merges), desc="Merges: ", leave=False):
            pair_counts = get_pair_counts(ids)
            if not pair_counts:
                break  # fewer than two tokens left; max() would fail on an empty dict

            max_occuring_pair = max(pair_counts, key=pair_counts.get)

            if pair_counts[max_occuring_pair] == 1:
                print("Max pair only occurs once, breaking from loop")
                break

            new_token_code = self.current_vocab_size + i
            ids = replace_pair(ids, max_occuring_pair, new_token_code)
            self.merges[max_occuring_pair] = new_token_code

        # Dicts preserve insertion order, so both halves of a merged token are
        # already in the vocabulary by the time the token itself is added.
        for (p0, p1), idx in self.merges.items():
            self.vocab[idx] = self.vocab[p0] + self.vocab[p1]

In [None]:
# Fresh tokeniser: its stoi, vocab and merges tables start out empty.
bpe = BPETokeniser()

In [None]:
# Build the character vocabulary from `text` and learn merges towards a 276-token vocabulary.
bpe.train(text, desired_vocab_size=276)

Merges:   0%|          | 0/211 [00:00<?, ?it/s]

True
