In [3]:
# get different pair counts
def get_stats(ids, counts=None):
    """Count how often each pair of adjacent items occurs in ``ids``.

    Args:
        ids: A sliceable sequence of token ids.
        counts: Optional dict to accumulate into. When given, it is updated
            in place and returned, which lets callers tally several
            sequences into one table.

    Returns:
        A dict mapping ``(left, right)`` tuples to their occurrence count.
        Sequences with fewer than two items contribute nothing.
    """
    if counts is None:
        counts = {}
    for left, right in zip(ids, ids[1:]):
        key = (left, right)
        if key in counts:
            counts[key] += 1
        else:
            counts[key] = 1
    return counts

# debug
# Hand-checkable sequence: (2, 3) and (3, 4) each occur twice, every other
# adjacent pair occurs once.
example = [1, 2, 3, 4, 6, 2, 3, 4]
counts = get_stats(example)
print(f'counts: \n\t{counts}')
# Fail loudly on Restart & Run All if the pair counting ever regresses,
# instead of relying on someone eyeballing the printed dict.
assert counts == {(1, 2): 1, (2, 3): 2, (3, 4): 2, (4, 6): 1, (6, 2): 1}

counts: 
	{(1, 2): 1, (2, 3): 2, (3, 4): 2, (4, 6): 1, (6, 2): 1}


In [4]:
# merge according to stats, merge results
def merge(ids, pair, idx):
    """Replace each occurrence of ``pair`` in ``ids`` with the single id ``idx``.

    The scan runs left to right and matches do not overlap: after a match the
    scan resumes two positions later, so ``[1, 1, 1]`` with pair ``(1, 1)``
    becomes ``[idx, 1]``.

    Args:
        ids: Sequence of token ids. It is not modified.
        pair: Two-item sequence ``(first, second)`` to look for.
        idx: Token id that replaces each matched pair.

    Returns:
        A new list with every matched pair replaced by ``idx``.
    """
    newids = []
    last = len(ids) - 1
    i = 0
    while i < len(ids):
        # The bounds check must come first: it short-circuits the ids[i + 1]
        # lookup, which would otherwise raise IndexError whenever the final
        # element equals pair[0].
        if i < last and ids[i] == pair[0] and ids[i + 1] == pair[1]:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids

# debug
# Both occurrences of (1, 2) should collapse into the new id 4.
ids = [1, 2, 3, 1, 2]
pair = (1, 2)
newids = merge(ids, pair, 4)
print(f'newids: \n\t{newids}')
# Fail loudly on Restart & Run All if merging ever regresses, instead of
# relying on someone eyeballing the printed list.
assert newids == [4, 3, 4]

newids: 
	[4, 3, 4]


In [None]:
# build vocab
def build_vocab(merges=None, num_base_tokens=8):
    """Build the id -> bytes table for a set of BPE merges.

    Args:
        merges: Optional dict mapping ``(p0, p1)`` id pairs to the id of the
            token that replaces them. Entries are applied in dict order, so
            both parts of a pair must already be in the table when it is
            reached (a base token or an earlier merge); otherwise a KeyError
            is raised. Defaults to no merges.
        num_base_tokens: Number of single-byte base tokens ``0..n-1``.
            Defaults to 8, the toy size this cell originally hard-coded;
            must not exceed 256 because each base token is one byte.

    Returns:
        A dict mapping every token id to the bytes it stands for.
    """
    if merges is None:
        merges = {}
    vocab = {idx: bytes([idx]) for idx in range(num_base_tokens)}
    for (p0, p1), idx in merges.items():
        # A merged token's bytes are the concatenation of its two parts.
        vocab[idx] = vocab[p0] + vocab[p1]
    return vocab

In [6]:
# BPE
class BasicTokenizer():
    """Minimal byte-level BPE tokenizer.

    Attributes are ``merges`` (pair of token ids -> id of the merged token)
    and ``vocab`` (token id -> the bytes that token stands for). Ids 0..255
    are the raw byte values; learned merges get ids from 256 upward.
    """

    # Returned by text2id when no token matches. Kept for backward
    # compatibility; note it would collide with a real token id if the
    # vocabulary ever grew past 10086 entries.
    UNKNOWN_ID = 10086

    def __init__(self):
        self.merges = {}
        self.vocab = self.build_vocab()

    def build_vocab(self):
        """Rebuild the id -> bytes table from ``self.merges``.

        Merges are applied in dict order, so both parts of each pair must
        already be present (a base byte or an earlier merge).
        """
        vocab = {idx: bytes([idx]) for idx in range(256)}
        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        return vocab

    def train(self, text, vocab_size, verbose=False):
        """Learn up to ``vocab_size - 256`` merges from ``text``.

        Each round replaces the most frequent adjacent pair of ids with a
        new id. Training stops early once fewer than two ids remain, so the
        resulting vocabulary can be smaller than ``vocab_size``. Any
        previously learned merges are discarded.

        Args:
            text: Training string; it is encoded as UTF-8.
            vocab_size: Target vocabulary size, at least 256.
            verbose: When true, print one line per merge.
        """
        assert vocab_size >= 256
        num_merges = vocab_size - 256
        ids = list(text.encode("utf-8"))

        merges = {}
        vocab = {idx: bytes([idx]) for idx in range(256)}
        for i in range(num_merges):
            stats = get_stats(ids)
            if not stats:
                # Nothing left to pair up; max() on an empty dict would
                # raise ValueError.
                break
            pair = max(stats, key=stats.get)
            idx = 256 + i
            # every occurrence of the pair in ids is replaced by the new id
            ids = merge(ids, pair, idx)
            # merges records which pair each new id came from
            merges[pair] = idx
            # vocab maps every id, base or merged, to its bytes
            vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
            if verbose:
                print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")
        # Assigned after the loop so that a run with zero merges still
        # resets state left over from an earlier training run.
        self.merges = merges
        self.vocab = vocab

    def text2id(self, text):
        """Return the id of the token whose bytes equal ``text`` in UTF-8.

        Only an exact whole-token match is found; ``text`` is not split into
        several tokens. Returns ``UNKNOWN_ID`` (10086) when nothing matches.
        """
        target = text.encode('utf-8')
        for token_id, token_bytes in self.vocab.items():
            if token_bytes == target:
                return token_id
        # Reached only after every entry has been checked.
        return self.UNKNOWN_ID
# debug
bpe = BasicTokenizer()
text = '''   
Large Language Models is all you need,
what can i say, manba out. 
Attention is All you need.
Vision Transformers, 
Generative Pretrained Transformers,
Reinforcement leraning from human feedback
chain of thought is basic resoning tool.
LLMs can evaluate NLP results.
Richard Sutton Refinforcement Learning Introduction edition 2.
encoder-only
'''
# Toy setting: the target vocab size is the UTF-8 byte length of the raw
# text (newlines included), while training runs on the text with newlines
# stripped.
vocab_size = len(text.encode('utf-8'))
bpe.train(text.replace('\n', ''), vocab_size=vocab_size)
# Show the bytes of every learned token in the order it was created,
# followed by the pair -> id merge table.
for i in range(256, vocab_size):
    print(bpe.vocab[i])
print(bpe.merges)

b'n '
b'in'
b'an'
b's '
b'ed'
b'en'
b'on '
b'er'
b'e '
b'ou'
b't '
b'ti'
b'fo'
b'for'
b'ar'
b'od'
b'is '
b'eed'
b'an '
b'ba'
b'tion '
b're'
b'ing'
b'ing '
b'  '
b'ge '
b'ua'
b'al'
b'l '
b'l y'
b'l you'
b'l you '
b'l you n'
b'l you need'
b'ha'
b'can '
b', '
b'tt'
b'si'
b'Tr'
b'Tran'
b'Trans'
b'Transfor'
b'Transform'
b'Transformer'
b'Transformers'
b'tr'
b'Re'
b'infor'
b'inforc'
b'inforce'
b'inforcem'
b'inforcemen'
b'inforcement '
b' re'
b' res'
b'on'
b'   '
b'   L'
b'   Lar'
b'   Large '
b'   Large L'
b'   Large Lan'
b'   Large Lang'
b'   Large Langua'
b'   Large Language '
b'   Large Language M'
b'   Large Language Mod'
b'   Large Language Mode'
b'   Large Language Model'
b'   Large Language Models '
b'   Large Language Models is '
b'   Large Language Models is al'
b'   Large Language Models is all you need'
b'   Large Language Models is all you need,'
b'   Large Language Models is all you need,w'
b'   Large Language Models is all you need,wha'
b'   Large Language Models is all you need

In [7]:
# encoder: turn text into BPE token ids using the merges learned by `bpe`
# step 1: UTF-8 bytes
text = 'i love transfromers'
text_bytes = text.encode("utf-8") # raw bytes

# step 2: start with one token per byte, then apply learned merges
ids = list(text_bytes) # list of integers in range 0..255
while len(ids) >= 2:
    stats = get_stats(ids)
    # Only the keys of stats (the adjacent pairs present in ids) are used
    # here; the counts are ignored.
    #
    # Worked example with ids = [2, 3, 4, 5]:
    #   adjacent pairs: (2, 3), (3, 4), (4, 5)
    #   each pair is ranked by its merge id, e.g.
    #     bpe.merges.get((2, 3)) -> 268
    #     bpe.merges.get((3, 4)) -> 269
    #     bpe.merges.get((4, 5)) -> not a learned merge, ranked as inf
    #
    # min() picks the pair with the LOWEST merge id, i.e. the merge that was
    # learned earliest in training (train assigns ids in increasing order).
    # It is not the most frequent pair in this text.
    pair = min(stats, key=lambda p: bpe.merges.get(p, float("inf"))) 
    # Debug trace. It is printed before the membership check below, so the
    # last pair shown is the one that ended the loop, not one that was merged.
    print(pair)
    print(bpe.vocab[pair[0]], bpe.vocab[pair[1]])
    # If even the best-ranked pair is not a learned merge, every pair ranked
    # as inf and nothing more can be merged.
    if pair not in bpe.merges:
        break 
    idx = bpe.merges[pair] # e.g. (2, 3) -> 268
    ids = merge(ids, pair, idx) # e.g. [2, 3, 4, 5] -> [268, 4, 5]
print(ids)

(97, 110)
b'a' b'n'
(101, 114)
b'e' b'r'
(101, 32)
b'e' b' '
(116, 114)
b't' b'r'
(105, 32)
b'i' b' '
[105, 32, 108, 111, 118, 264, 302, 258, 115, 102, 114, 111, 109, 263, 115]


In [8]:
# decoder
# Look up the bytes of each token id and concatenate them, then decode.
# Token boundaries need not align with UTF-8 character boundaries, so any
# invalid byte sequence is rendered as U+FFFD instead of raising.
text_bytes = b"".join([bpe.vocab[token_id] for token_id in ids])
decode_text = text_bytes.decode(encoding="utf-8", errors="replace")
print(decode_text)

i love transfromers
