<a href="https://colab.research.google.com/github/el-eshaano/ml/blob/main/BPE_Tokeniser.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from typing import Dict, Tuple
from tqdm.notebook import tqdm

In [3]:
# Read the whole training corpus into memory as a single string.
with open('../datasets/Shakespeare/input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# The base vocabulary is every distinct character, in sorted order so that
# the id assignment is reproducible between runs.
unique_chars = sorted(set(text))
vocab_size = len(unique_chars)

# Two-way lookup tables between characters and their integer ids.
itos = dict(enumerate(unique_chars))
stoi = {ch: i for i, ch in itos.items()}

# The corpus expressed as a list of character-level token ids.
tokens = [stoi[ch] for ch in text]

In [4]:
def get_pair_counts(tokens: list[int]) -> Dict[Tuple[int, int], int]:
    """Count how often each pair of neighbouring tokens occurs.

    Args:
        tokens: Token ids in sequence order.

    Returns:
        A dict mapping each adjacent ``(left, right)`` pair to its number of
        occurrences. Overlapping pairs are counted individually, keys appear
        in order of first occurrence, and the dict is empty when fewer than
        two tokens are given.
    """
    counts = {}
    for left, right in zip(tokens, tokens[1:]):
        key = (left, right)
        if key in counts:
            counts[key] += 1
        else:
            counts[key] = 1
    return counts

def replace_pair(arr, pair, replacement):
    """Return a copy of ``arr`` with each occurrence of ``pair`` collapsed.

    The sequence is scanned left to right; whenever ``pair[0]`` is immediately
    followed by ``pair[1]`` the two elements are replaced by ``replacement``
    and scanning resumes after them, so matches never overlap. ``arr`` itself
    is not modified.
    """
    out = []
    n = len(arr)
    pos = 0
    while pos < n:
        current = arr[pos]
        # A match needs a right-hand neighbour, so the last element can only
        # ever be copied through.
        if pos + 1 < n and current == pair[0] and arr[pos + 1] == pair[1]:
            out.append(replacement)
            pos += 2
            continue
        out.append(current)
        pos += 1
    return out

In [5]:
# Vocabulary size to aim for, and the size of the character-level vocabulary
# that merging starts from.
desired_vocab_size = 256
current_vocab_size = vocab_size
# ---

# Every merge mints exactly one new token id.
num_merges = desired_vocab_size - current_vocab_size
ids = list(tokens)  # working copy, so the character-level tokens stay intact

# Maps (left_id, right_id) -> id of the merged token, in the order learned.
merges = {}
progress = tqdm(range(num_merges), desc="Merges: ", leave=False)
for step in progress:
    counts = get_pair_counts(ids)
    # Most frequent adjacent pair; on ties the first pair encountered wins.
    top_pair, top_count = max(counts.items(), key=lambda item: item[1])

    # Nothing repeats any more, so further merges would not shorten the data.
    if top_count == 1:
        print("Max pair only occurs once, breaking from loop")
        break

    merged_id = current_vocab_size + step
    ids = replace_pair(ids, top_pair, merged_id)
    merges[top_pair] = merged_id



Merges:   0%|          | 0/211 [00:00<?, ?it/s]

In [6]:
# Begin with the single-character entries, then add one string per merge.
vocab = dict(itos)
# Replay the merges in the order they were learned (dicts preserve insertion
# order, guaranteed since Python 3.7): a later merge may be built from tokens
# that an earlier merge created, so those entries must already exist.
for pair, merged_id in merges.items():
    left, right = pair
    vocab[merged_id] = vocab[left] + vocab[right]

print(len(vocab))
# print(vocab)

276


In [7]:
def decode(ids):
    """Turn a sequence of token ids back into the string they represent.

    Each id is looked up in the module-level ``vocab`` table and the resulting
    pieces are concatenated in order.
    """
    pieces = [vocab[token_id] for token_id in ids]
    return "".join(pieces)

In [8]:
def encode(string):
    """Tokenise ``string`` using the module-level ``stoi`` and ``merges`` tables.

    The string is first mapped to character-level ids, then the learned merges
    are applied repeatedly, always choosing the applicable merge with the
    lowest merged-token id, until no adjacent pair has a merge rule.
    """
    ids = [stoi[ch] for ch in string]
    # A lone token (or none at all) has no neighbour to merge with.
    while len(ids) > 1:
        candidates = get_pair_counts(ids)
        # Pairs without a merge rule rank as infinity, so they are only
        # picked when nothing mergeable is left.
        best = min(candidates, key=lambda p: merges.get(p, float("inf")))
        if best in merges:
            ids = replace_pair(ids, best, merges[best])
        else:
            break  # no more possible merges
    return ids

In [9]:
class BPETokeniser:
    """Character-level byte-pair-encoding (BPE) tokeniser.

    All lookup tables are empty until ``train`` is called:

    * ``stoi`` maps each base character to its token id.
    * ``vocab`` maps every token id (base or merged) to the string it
      decodes to.
    * ``merges`` maps a pair of token ids to the id of the token that
      replaces it, in the order the merges were learned.
    * ``current_vocab_size`` is the number of base (single-character) tokens.
    """

    def __init__(self):
        self.stoi = {}
        self.vocab = {}
        self.merges = {}
        self.current_vocab_size = 0

    def decode(self, ids):
        """Given ids (an iterable of integers), return the decoded string.

        Raises:
            KeyError: if an id is not present in ``self.vocab``.
        """
        return "".join(self.vocab[idx] for idx in ids)

    def encode(self, string):
        """Tokenise ``string`` into a list of token ids.

        The string is mapped to character-level ids, then the learned merges
        are applied repeatedly, always choosing the applicable merge with the
        lowest merged-token id, until no adjacent pair has a merge rule.

        Raises:
            KeyError: if ``string`` contains a character that is not in
                ``self.stoi``.
        """
        toks = [self.stoi[ch] for ch in string]
        while len(toks) >= 2:  # a single token has nothing to merge with
            pair_counts = get_pair_counts(toks)
            # Pairs without a merge rule rank as infinity, so one is only
            # selected when nothing mergeable remains.
            pair = min(pair_counts, key=lambda p: self.merges.get(p, float("inf")))
            if pair not in self.merges:
                break  # no more possible merges

            toks = replace_pair(toks, pair, self.merges[pair])
        return toks

    def train(self, text, desired_vocab_size):
        """Learn the base vocabulary and the merge rules from ``text``.

        Any previously learned state is discarded. Training stops early when
        no adjacent pair occurs more than once, so the final vocabulary may be
        smaller than ``desired_vocab_size``.

        Args:
            text: Training corpus as a single string.
            desired_vocab_size: Target number of tokens (base characters plus
                merged tokens).
        """
        unique_chars = sorted(set(text))
        self.current_vocab_size = len(unique_chars)

        self.stoi = {ch: i for i, ch in enumerate(unique_chars)}
        self.vocab = dict(enumerate(unique_chars))
        # Start from a clean slate so retraining never mixes in stale merges.
        self.merges = {}

        ids = [self.stoi[ch] for ch in text]
        num_merges = desired_vocab_size - self.current_vocab_size

        for i in tqdm(range(num_merges), desc="Merges: ", leave=False):
            pair_counts = get_pair_counts(ids)
            if not pair_counts:
                # Fewer than two tokens left: there is no pair to merge.
                break
            max_occuring_pair = max(pair_counts, key=pair_counts.get)

            if pair_counts[max_occuring_pair] == 1:
                print("Max pair only occurs once, breaking from loop")
                break

            new_token_code = self.current_vocab_size + i
            ids = replace_pair(ids, max_occuring_pair, new_token_code)
            self.merges[max_occuring_pair] = new_token_code

            # Both halves already have vocab entries (base characters or
            # earlier merges), so the merged string can be recorded right away.
            p0, p1 = max_occuring_pair
            self.vocab[new_token_code] = self.vocab[p0] + self.vocab[p1]

In [10]:
# Create an untrained tokeniser; its lookup tables start out empty.
bpe = BPETokeniser()

In [None]:
# Fit the tokeniser on the loaded corpus, targeting a vocabulary of 276 tokens.
bpe.train(text, desired_vocab_size=276)

Merges:   0%|          | 0/211 [00:00<?, ?it/s]