<a target="_blank" href="https://colab.research.google.com/github/holmrenser/deep_learning/blob/main/tokenization.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

# Tokenization

In [121]:
from collections import Counter
from itertools import chain
from tqdm.auto import trange
import json
from dataclasses import dataclass, field
from typing import Generator

# Read the training corpus. The encoding is given explicitly so the decoded
# text does not depend on the platform's default locale encoding (PEP 597);
# it is re-encoded as UTF-8 below, so reading it as UTF-8 keeps the round trip
# byte-exact.
with open('input.txt', 'r', encoding='utf-8') as fh:
    data = fh.read()

# Byte-level view of the corpus: every UTF-8 byte (0-255) is one initial token.
text_bytes = data.encode('utf-8')
tokens = list(text_bytes)
print(data[:10])
print(tokens[:10])

First Citi
[70, 105, 114, 115, 116, 32, 67, 105, 116, 105]


In [125]:
def merge_tokens(tokens: list[int], token_pair: tuple[int,int], new_token: int) -> list[int]:
    """Replace every non-overlapping occurrence of ``token_pair`` with ``new_token``.

    The scan runs left to right and a matched pair is consumed whole, so
    merging ``(a, a)`` in ``[a, a, a]`` gives ``[new_token, a]``. A new list
    is returned; ``tokens`` itself is not modified.
    """
    merged: list[int] = []
    last = len(tokens) - 1
    pos = 0
    while pos <= last:
        # Only look ahead when a right-hand neighbour actually exists.
        if pos < last and token_pair == (tokens[pos], tokens[pos + 1]):
            merged.append(new_token)
            pos += 2
            continue
        merged.append(tokens[pos])
        pos += 1
    return merged

@dataclass
class BytePairEncoder:
    """Byte-level byte-pair-encoding (BPE) tokenizer.

    Token ids 0-255 are the raw UTF-8 bytes of the text; every id >= 256 is
    defined in ``merges`` as the concatenation of an ordered pair of earlier
    token ids.
    """

    # Learned token id -> (left, right) pair of token ids it stands for.
    merges: dict[int, tuple[int, int]] = field(default_factory=dict)

    @property
    def vocab(self) -> dict[int, str]:
        """Return a mapping from every token id to a printable string.

        Base ids 0-255 map to ``chr(id)`` (the Latin-1 code point, not a UTF-8
        decoding). Merged ids map to their decoded byte sequence, with invalid
        or partial UTF-8 replaced by U+FFFD (see ``decode``).
        """
        base_vocab = {token: chr(token) for token in range(256)}
        merge_vocab = {token: self.decode([token]) for token in self.merges}
        return base_vocab | merge_vocab

    def _token_to_bytes(self, token: int) -> Generator[int, None, None]:
        """Yield the raw byte values that ``token`` expands to, left to right."""
        if token not in self.merges:
            # A base byte. An unknown id >= 256 is yielded as-is and rejected
            # later by bytes() in decode().
            yield token
            return
        for pair_token in self.merges[token]:
            yield from self._token_to_bytes(pair_token)

    def train(self, input: str, vocab_size: int = 512) -> None:
        """Learn up to ``vocab_size - 256`` merges from ``input``.

        Training always starts from the raw UTF-8 bytes and numbers new tokens
        from 256. Any previously learned merges are therefore discarded first;
        otherwise stale higher-numbered merges would survive a retrain.
        Training stops early once fewer than two tokens remain, because there
        is no adjacent pair left to merge. In that case fewer merges than
        requested are learned.

        Raises:
            AssertionError: if ``vocab_size`` is not larger than 256.
        """
        assert vocab_size > 256, f'Invalid vocab_size: {vocab_size}, must be larger than 256'
        # Rebind instead of .clear(): the current dict may be shared with a caller.
        self.merges = {}
        tokens = list(input.encode('utf-8'))
        num_merges = vocab_size - 256
        for i in trange(num_merges):
            pair_counts = Counter(zip(tokens[:-1], tokens[1:]))
            if not pair_counts:
                # most_common(1) would be empty here and indexing it would raise.
                break
            merge_pair = pair_counts.most_common(1)[0][0]
            new_token = 256 + i
            self.merges[new_token] = merge_pair
            tokens = merge_tokens(tokens, merge_pair, new_token)

    def encode(self, input: str) -> list[int]:
        """Encode ``input`` as UTF-8 bytes, then replay the learned merges."""
        tokens = list(input.encode('utf-8'))
        # Merges must be applied in training order (ascending token id),
        # because later merges can refer to earlier merged tokens. Sort the ids
        # explicitly instead of trusting the dict's insertion order, which a
        # hand-built merges dict need not follow.
        for new_token in sorted(self.merges):
            tokens = merge_tokens(tokens, self.merges[new_token], new_token)
        return tokens

    def decode(self, tokens: list[int]) -> str:
        """Expand ``tokens`` to bytes and decode them as UTF-8.

        Invalid UTF-8 sequences are replaced with U+FFFD rather than raising.
        """
        decoded_tokens = chain.from_iterable(map(self._token_to_bytes, tokens))
        return bytes(decoded_tokens).decode('utf-8', errors='replace')

    def save(self, prefix: str) -> None:
        """Write ``{prefix}.vocab`` and ``{prefix}.model`` as JSON.

        The ``.vocab`` file is for inspection only; ``load`` reads just the
        ``.model`` file (the merges).
        """
        with open(f'{prefix}.vocab', 'w', encoding='utf-8') as fh:
            json.dump(self.vocab, fh)
        with open(f'{prefix}.model', 'w', encoding='utf-8') as fh:
            json.dump(self.merges, fh)

    @classmethod
    def load(cls, model_filename: str) -> 'BytePairEncoder':
        """Build an encoder from a ``.model`` file written by ``save``.

        JSON turns the int keys into strings and the pairs into lists, so both
        are converted back here.
        """
        with open(model_filename, 'r', encoding='utf-8') as fh:
            merges = json.load(fh)
        sanitized_merges = {int(k): tuple(v) for k, v in merges.items()}
        return cls(sanitized_merges)

# Fresh encoder with no learned merges: its vocab holds only the 256 byte tokens.
bpe = BytePairEncoder(merges={})
bpe.vocab

{0: '\x00',
 1: '\x01',
 2: '\x02',
 3: '\x03',
 4: '\x04',
 5: '\x05',
 6: '\x06',
 7: '\x07',
 8: '\x08',
 9: '\t',
 10: '\n',
 11: '\x0b',
 12: '\x0c',
 13: '\r',
 14: '\x0e',
 15: '\x0f',
 16: '\x10',
 17: '\x11',
 18: '\x12',
 19: '\x13',
 20: '\x14',
 21: '\x15',
 22: '\x16',
 23: '\x17',
 24: '\x18',
 25: '\x19',
 26: '\x1a',
 27: '\x1b',
 28: '\x1c',
 29: '\x1d',
 30: '\x1e',
 31: '\x1f',
 32: ' ',
 33: '!',
 34: '"',
 35: '#',
 36: '$',
 37: '%',
 38: '&',
 39: "'",
 40: '(',
 41: ')',
 42: '*',
 43: '+',
 44: ',',
 45: '-',
 46: '.',
 47: '/',
 48: '0',
 49: '1',
 50: '2',
 51: '3',
 52: '4',
 53: '5',
 54: '6',
 55: '7',
 56: '8',
 57: '9',
 58: ':',
 59: ';',
 60: '<',
 61: '=',
 62: '>',
 63: '?',
 64: '@',
 65: 'A',
 66: 'B',
 67: 'C',
 68: 'D',
 69: 'E',
 70: 'F',
 71: 'G',
 72: 'H',
 73: 'I',
 74: 'J',
 75: 'K',
 76: 'L',
 77: 'M',
 78: 'N',
 79: 'O',
 80: 'P',
 81: 'Q',
 82: 'R',
 83: 'S',
 84: 'T',
 85: 'U',
 86: 'V',
 87: 'W',
 88: 'X',
 89: 'Y',
 90: 'Z',
 91: '[',


In [126]:
# Learn (up to) 260 - 256 = 4 merges from the corpus, then write
# shakespeare_260.vocab and shakespeare_260.model.
bpe.train(input=data, vocab_size=260)
bpe.save(prefix='shakespeare_260')

  0%|          | 0/4 [00:00<?, ?it/s]

In [127]:
# Reload the saved merges from disk and display the resulting encoder.
BytePairEncoder.load(model_filename='./shakespeare_260.model')

BytePairEncoder(merges={256: (101, 32), 257: (116, 104), 258: (116, 32), 259: (115, 32)})

In [None]:
# Compare against SentencePiece's own BPE trainer on the same corpus.
import sys
# IPython shell escape: run pip with the interpreter that runs this kernel, so
# the package is installed into the same environment.
!{sys.executable} -m pip install sentencepiece

from sentencepiece import SentencePieceProcessor, SentencePieceTrainer

# Train a 200-token BPE model on input.txt; output files use the prefix
# shakespeare_200 (the .model file is loaded below).
SentencePieceTrainer.train('--input=input.txt --model_prefix=shakespeare_200 --vocab_size=200 --model_type=bpe')

sp = SentencePieceProcessor()
sp.load('shakespeare_200.model')

# NOTE(review): in a Jupyter cell only the final expression's value is shown
# automatically; the next two expressions are evaluated but not displayed.
# Wrap them in print() if their values are meant to be seen.
# Round trip: should reproduce the input, modulo SentencePiece's text
# normalization (confirm).
sp.decode(sp.encode('hello how are you'))

sp.vocab_size()

# Inspect the processor object's instance attributes.
sp.__dict__