In [2]:
def get_stats(ids, counts=None):
    """Count how often each pair of adjacent elements occurs in ``ids``.

    Args:
        ids: sequence of token ids (anything sliceable, e.g. a list of ints).
        counts: optional dict to accumulate into. When given, it is updated
            in place and returned, so counts from several sequences add up.

    Returns:
        dict mapping ``(left, right)`` -> number of occurrences.
    """
    if counts is None:
        counts = {}
    # pairing the sequence with itself shifted by one yields every adjacent pair
    for left, right in zip(ids, ids[1:]):
        key = (left, right)
        counts[key] = counts.get(key, 0) + 1
    return counts

In [4]:
# ids  = list of integers (token ids)
# pair = the pair of consecutive token ids that we are going to replace
# idx  = the new token id that replaces every occurrence of pair
def merge(ids, pair, idx):
    """Return a copy of ``ids`` in which each occurrence of ``pair`` becomes ``idx``.

    The scan runs left to right and matches do not overlap: once a pair has
    been replaced, scanning resumes after both of its elements.
    Example: merge([5, 6, 6, 7, 9, 1], (6, 7), 99) -> [5, 6, 99, 9, 1]

    Args:
        ids: list of token ids; it is not modified.
        pair: two consecutive token ids to look for.
        idx: token id written in place of each matched pair.

    Returns:
        A new list with the replacements applied.
    """
    out = []
    total = len(ids)
    cursor = 0
    while cursor < total:
        current = ids[cursor]
        # the middle condition guards the look-ahead at the last position
        if current == pair[0] and cursor + 1 < total and ids[cursor + 1] == pair[1]:
            out.append(idx)
            cursor += 2
            continue
        out.append(current)
        cursor += 1
    return out

In [5]:
# control characters = non-printing characters such as \n
# we don't want to print them raw, so we replace them with \uXXXX escapes
def replace_control_characters(s: str) -> str:
    """Return ``s`` with every character of Unicode category ``C*`` escaped.

    A character whose ``unicodedata.category`` starts with "C" is replaced by
    a literal backslash-u followed by its code point in lowercase hex,
    zero-padded to at least four digits. All other characters pass through.
    """
    return "".join(
        ch if unicodedata.category(ch)[0] != "C" else f"\\u{ord(ch):04x}"
        for ch in s
    )


In [6]:
# nicer way to decode / print the final text by not printing control characters
def render_token(t: bytes) -> str:
    """Turn token bytes into a printable string.

    The bytes are decoded as UTF-8 with ``errors='replace'`` (so byte
    sequences that are not valid UTF-8 do not raise), and the result is
    passed through ``replace_control_characters``.
    """
    return replace_control_characters(t.decode('utf-8', errors='replace'))

In [8]:
class Tokenizer:
    """Base class for the tokenizers in this notebook.

    Holds the shared state (merges, split pattern, special tokens, vocab) and
    declares the train / encode / decode interface that subclasses implement.
    """

    def __init__(self):
        # by default: no merges, no split pattern, no special tokens
        self.merges = {}          # (int, int) -> int
        self.pattern = ""         # str
        self.special_tokens = {}  # str -> int, e.g. {'<|endoftext|>': 100257}
        self.vocab = self._build_vocab()  # int -> bytes

    def train(self, text, vocab_size, verbose=False):
        """Train a vocabulary of size ``vocab_size`` from ``text`` (abstract)."""
        raise NotImplementedError

    def encode(self, text):
        """Encode a string into a list of integers (abstract)."""
        raise NotImplementedError

    def decode(self, ids):
        """Decode a list of integers into a string (abstract)."""
        raise NotImplementedError

    def _build_vocab(self):
        """Derive the int -> bytes vocab deterministically from the current state.

        Starts from the 256 single-byte tokens, adds one entry per merge (the
        concatenation of its two parts, in ``self.merges`` iteration order),
        then adds the UTF-8 bytes of each special token.
        """
        vocab = {}
        for byte_value in range(256):
            vocab[byte_value] = bytes([byte_value])

        for pair, idx in self.merges.items():
            left, right = pair
            vocab[idx] = vocab[left] + vocab[right]

        for special, idx in self.special_tokens.items():
            vocab[idx] = special.encode("utf-8")

        return vocab

In [70]:
class BasicTokenizer(Tokenizer):
    """Byte-level BPE tokenizer that treats the whole text as one byte stream.

    There is no regex pre-splitting and no special-token handling: training
    repeatedly merges the most frequent adjacent pair of token ids.
    """

    def __init__(self):
        super().__init__()

    def train(self,text,vocab_size,verbose = False):
        """Learn up to ``vocab_size - 256`` merges from ``text``.

        Args:
            text: training string.
            vocab_size: target vocabulary size; must be >= 256 because the
                256 single-byte tokens are always present.
            verbose: when True, print one line per merge.

        Training stops early if the text has collapsed to fewer than two
        tokens, in which case the vocabulary ends up smaller than
        ``vocab_size``.

        Raises:
            AssertionError: if ``vocab_size`` < 256.
        """
        assert vocab_size >= 256
        num_merges = vocab_size - 256

        # input text pre-processing
        text_bytes = text.encode("utf-8") # raw bytes
        ids = list(text_bytes) # list of integers in range 0..255

        # iteratively merge the most common pairs to create new tokens
        merges = {} # (int, int) -> int
        vocab = {idx: bytes([idx]) for idx in range(256)}

        for i in range(num_merges):
            # count up the number of times every consecutive pair appears
            stats = get_stats(ids)

            if not stats:
                # fewer than two tokens remain, so there is nothing left to
                # merge; max() on an empty dict would raise ValueError
                break

            # find the pair with the highest count
            pair = max(stats,key = stats.get)

            # mint a new token: assign it the next available id
            idx = 256 +i

            # replace all occurrences of pair in ids with idx
            ids = merge(ids,pair,idx)
            # save the merge
            merges[pair] = idx
            vocab[idx] = vocab[pair[0]] + vocab[pair[1]]

            # print - easier for debugging
            if verbose:
                print(f"merge {i+1}/{num_merges} : {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurences")

        # save class variables for future use
        self.merges = merges  # it will be used in encode()
        self.vocab = vocab   # it will be used in decode()

    def decode(self,ids):
        """Given ids (list of integers), return a Python string.

        The token bytes are concatenated before decoding, so a character
        whose UTF-8 bytes span several tokens is decoded correctly; invalid
        byte sequences become the Unicode replacement character.
        """
        text_bytes = b"".join(self.vocab[idx] for idx in ids)
        text = text_bytes.decode("utf-8",errors = "replace")
        return text

    def encode(self,text):
        """Given a string, return its token ids (list of integers)."""
        text_bytes = text.encode("utf-8") # raw bytes
        ids = list(text_bytes) # list of integers in range 0..255
        while len(ids) >=2:
            # find the pair with the lowest merge index, i.e. the merge that
            # was learned earliest; merges must be replayed in training order
            # because later merges can be built on the results of earlier ones
            stats = get_stats(ids)
            pair = min(stats,key = lambda p: self.merges.get(p,float("inf")))
            # if there are no more merges available, the key will
            # result in an inf for every single pair, and the min will be
            # just the first pair in the list

            if pair not in self.merges:
                break # nothing else can be merged anymore

            # otherwise let's merge the best pair (lowest merge index)
            idx = self.merges[pair]
            ids = merge(ids,pair,idx)

        return ids
            

In [106]:
# Load the training corpus. The `with` block closes the file handle as soon as
# the read finishes instead of leaving that to garbage collection.
with open('taylorswift.txt','r',encoding="utf-8") as corpus_file:
    text = corpus_file.read()

In [44]:
# Peek at the first 20 characters to confirm the corpus loaded.
text[:20]

'Copy paste of the Wi'

In [66]:
# BasicTokenizer.encode(text)

In [68]:
# Peek at the first 100 characters of the corpus.
text[:100]

'Copy paste of the Wikipedia article on Taylor Swift, as of Feb 16, 2024.\n---\n\nMain menu\n\nWikipediaTh'

In [71]:
# Untrained byte-level tokenizer: no merges yet, vocab holds only the 256 single-byte tokens.
tokenizer = BasicTokenizer()

In [72]:
# vocab_size=512 -> 512 - 256 = 256 merges on top of the 256 raw byte tokens;
# verbose=True prints one line per merge.
tokenizer.train(text,512,verbose=True)

merge 1/256 : (101, 32) -> 256 (b'e ') had 2981 occurences
merge 2/256 : (44, 32) -> 257 (b', ') had 2961 occurences
merge 3/256 : (100, 32) -> 258 (b'd ') had 2617 occurences
merge 4/256 : (46, 32) -> 259 (b'. ') had 2560 occurences
merge 5/256 : (114, 32) -> 260 (b'r ') had 2428 occurences
merge 6/256 : (50, 48) -> 261 (b'20') had 2365 occurences
merge 7/256 : (115, 32) -> 262 (b's ') had 2053 occurences
merge 8/256 : (105, 110) -> 263 (b'in') had 2006 occurences
merge 9/256 : (111, 110) -> 264 (b'on') had 1815 occurences
merge 10/256 : (114, 105) -> 265 (b'ri') had 1805 occurences
merge 11/256 : (116, 32) -> 266 (b't ') had 1802 occurences
merge 12/256 : (116, 104) -> 267 (b'th') had 1737 occurences
merge 13/256 : (101, 258) -> 268 (b'ed ') had 1736 occurences
merge 14/256 : (257, 261) -> 269 (b', 20') had 1705 occurences
merge 15/256 : (97, 110) -> 270 (b'an') had 1487 occurences
merge 16/256 : (97, 114) -> 271 (b'ar') had 1360 occurences
merge 17/256 : (101, 260) -> 272 (b'er ') h

In [73]:
# let's see our vocab: list only the tokens added by training, skipping the
# single-byte tokens (ids 0..255)
for token, token_bytes in tokenizer.vocab.items():
    if token <= 255:
        continue
    print(f"Token ID: {token} -> Merged Bytes: {token_bytes} -> As String: '{token_bytes.decode('utf-8', 'replace')}'")


Token ID: 256 -> Merged Bytes: b'e ' -> As String: 'e '
Token ID: 257 -> Merged Bytes: b', ' -> As String: ', '
Token ID: 258 -> Merged Bytes: b'd ' -> As String: 'd '
Token ID: 259 -> Merged Bytes: b'. ' -> As String: '. '
Token ID: 260 -> Merged Bytes: b'r ' -> As String: 'r '
Token ID: 261 -> Merged Bytes: b'20' -> As String: '20'
Token ID: 262 -> Merged Bytes: b's ' -> As String: 's '
Token ID: 263 -> Merged Bytes: b'in' -> As String: 'in'
Token ID: 264 -> Merged Bytes: b'on' -> As String: 'on'
Token ID: 265 -> Merged Bytes: b'ri' -> As String: 'ri'
Token ID: 266 -> Merged Bytes: b't ' -> As String: 't '
Token ID: 267 -> Merged Bytes: b'th' -> As String: 'th'
Token ID: 268 -> Merged Bytes: b'ed ' -> As String: 'ed '
Token ID: 269 -> Merged Bytes: b', 20' -> As String: ', 20'
Token ID: 270 -> Merged Bytes: b'an' -> As String: 'an'
Token ID: 271 -> Merged Bytes: b'ar' -> As String: 'ar'
Token ID: 272 -> Merged Bytes: b'er ' -> As String: 'er '
Token ID: 273 -> Merged Bytes: b'y ' -> 

In [155]:
import json
# Persist the trained BasicTokenizer (merges + vocab) as a JSON file.
data = {
    # JSON object keys must be strings, so each (p0, p1) pair is written as "p0,p1"
    "merges": {f"{p0},{p1}": idx for (p0,p1), idx in tokenizer.merges.items()},
    # latin-1 maps every byte value 0..255 to exactly one code point, so arbitrary
    # token bytes (including invalid UTF-8) can be stored as a str
    "vocab": {str(idx): token_bytes.decode("latin-1") for idx, token_bytes in tokenizer.vocab.items()},
}
with open("basic_tokenizer.json", "w") as f:
    json.dump(data, f)


In [58]:
import regex as re

In [59]:
# Split pattern named after GPT-2. Alternatives, tried in order: the contractions
# 's 'd 'm 't 'll 've 're; an optional space followed by letters; an optional
# space followed by digits; an optional space followed by other non-space
# symbols; runs of whitespace.
# The \p{L} / \p{N} Unicode property classes are presumably why the `regex`
# package is imported as `re` in this notebook.
GPT2_SPLIT_PATTERN = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

# Split pattern named after GPT-4; RegexTokenizer uses it when no pattern is
# given. Compared with the pattern above: contractions match case-insensitively
# ((?i:...)), digit runs are capped at three characters (\p{N}{1,3}), and
# \r / \n get dedicated alternatives.
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""

In [107]:
class RegexTokenizer(Tokenizer):
    """BPE tokenizer that splits text with a regex before merging bytes.

    The text is first cut into chunks by ``self.compiled_pattern``; byte-pair
    merges are learned and applied inside each chunk only, so a merge never
    spans a chunk boundary. Special tokens (str -> int) can be registered
    after construction and are handled by ``encode`` / ``decode``.
    """

    def __init__(self, pattern = None):
        """
        Args:
            pattern: optional split-pattern string; ``GPT4_SPLIT_PATTERN`` is
                used when it is None.
        """
        super().__init__()
        self.pattern = GPT4_SPLIT_PATTERN if pattern is None else pattern
        self.compiled_pattern = re.compile(self.pattern)
        # these names must match what register_special_tokens(), encode() and
        # decode() read; initialising them here lets decode() reach its
        # ValueError branch even when no special tokens were ever registered
        self.special_tokens = {}
        self.inverse_special_tokens = {}

    def train(self,text,vocab_size,verbose = False):
        """Learn up to ``vocab_size - 256`` merges from ``text``.

        Args:
            text: training string.
            vocab_size: target vocabulary size; must be >= 256.
            verbose: when True, print one line per merge.

        Training stops early when no chunk contains an adjacent pair any more,
        in which case the vocabulary ends up smaller than ``vocab_size``.

        Raises:
            AssertionError: if ``vocab_size`` < 256.
        """
        assert vocab_size >= 256
        num_merges = vocab_size - 256

        # split the text up into text chunks
        text_chunks = self.compiled_pattern.findall(text)

        # input text preprocessing: one list of byte values per chunk
        ids = [list(ch.encode("utf-8")) for ch in text_chunks]

        # iteratively merge the most common pairs to create new tokens
        merges = {} # (int, int) -> int
        vocab = {idx: bytes([idx]) for idx in range(256)} # idx -> bytes
        for i in range(num_merges):
            # count the number of times every consecutive pair appears
            stats = {}
            for chunks_ids in ids:
                # passing in stats will update it in place, adding up counts
                get_stats(chunks_ids,stats)

            if not stats:
                # every chunk is down to a single token, so nothing is left to
                # merge; max() on an empty dict would raise ValueError
                break

            # find the pair with the highest count
            pair = max(stats,key = stats.get)
            # mint a new token: assign it the next available id
            idx = 256 +i
            # replace all occurrences of pair in ids with idx
            ids = [merge(chunks_ids,pair,idx) for chunks_ids in ids]
            # save the merge
            merges[pair] = idx
            vocab[idx] = vocab[pair[0]] + vocab[pair[1]]
            # prints
            if verbose:
                print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab[idx]}) had {stats[pair]} occurrences")

        # save class variables once, after the loop, so they are also reset
        # when num_merges is 0
        self.merges = merges # used in encode()
        self.vocab = vocab # used in decode()

    def register_special_tokens(self,special_tokens):
        """Register special tokens.

        Args:
            special_tokens: dict of str -> int,
                e.g. {"<|endoftext|>": 100257}
        """
        self.special_tokens = special_tokens
        self.inverse_special_tokens = {v:k for k,v in special_tokens.items()}

    def decode(self,ids):
        """Given ids (list of integers), return a Python string.

        Raises:
            ValueError: if an id is neither in the vocab nor a registered
                special token.
        """
        part_bytes = []
        for idx in ids:
            if idx in self.vocab:
                part_bytes.append(self.vocab[idx])
            elif idx in self.inverse_special_tokens:
                part_bytes.append(self.inverse_special_tokens[idx].encode("utf-8"))
            else:
                raise ValueError(f"invalid token id: {idx}")

        # decode once at the end: a character's bytes may span several tokens
        text_bytes = b"".join(part_bytes)
        text = text_bytes.decode("utf-8",errors = "replace")
        return text

    def _encode_chunk(self,text_bytes):
        """Return the token ids for the bytes of a single chunk."""
        # first, convert all bytes to integers in range 0..255
        ids = list(text_bytes)
        while len(ids) >= 2:
            # find the pair with the lowest merge index (the earliest learned merge)
            stats = get_stats(ids)
            pair = min(stats, key = lambda p: self.merges.get(p,float("inf")))
            # when no pair is mergeable every key is inf and min() returns
            # some unmergeable pair, which ends the loop
            if pair not in self.merges:
                break
            idx = self.merges[pair]
            ids = merge(ids,pair,idx)

        return ids

    def encode_ordinary(self,text):
        """Encoding that ignores any special tokens."""
        # split text into chunks of text by categories defined in regex pattern
        text_chunks = self.compiled_pattern.findall(text)
        # all chunks of text are encoded separately, then results are joined
        ids = []
        for chunks in text_chunks:
            chunks_bytes = chunks.encode("utf-8")
            chunks_ids = self._encode_chunk(chunks_bytes)
            ids.extend(chunks_ids)
        return ids

    def encode(self,text,allowed_special = "none_raise"):
        """Encode ``text`` into token ids, optionally honouring special tokens.

        Args:
            text: string to encode.
            allowed_special: one of
                "all"        - every registered special token is recognised;
                "none"       - special tokens are encoded as ordinary text;
                "none_raise" - like "none", but assert that no registered
                               special token occurs in ``text`` (default);
                a set of str - only the listed special tokens are recognised.

        Raises:
            AssertionError: in "none_raise" mode when ``text`` contains a
                registered special token.
            ValueError: if ``allowed_special`` is none of the above.
        """
        special = None
        if allowed_special == "all":
            special = self.special_tokens

        elif allowed_special == "none":
            special = {}
        elif allowed_special == "none_raise":
            special = {}
            assert all(token not in text for token in self.special_tokens)
        elif isinstance(allowed_special,set):
            special = {k:v for k,v in self.special_tokens.items() if k in allowed_special}
        else:
            raise ValueError(f"allowed_special={allowed_special} not understood")
        if not special:
            # shortcut: if no special tokens, just use the ordinary encoding
            return self.encode_ordinary(text)

        # the capturing group makes re.split keep the special tokens in its
        # output; re.escape is needed because they contain regex
        # metacharacters such as "|"
        special_pattern = "(" + "|".join(re.escape(k) for k in special) + ")"
        special_chunks = re.split(special_pattern,text)
        # now all the special tokens are separated from the rest of the text
        # all chunks of text are encoded separately, then results are joined
        ids = []
        for part in special_chunks:
            if part in special:
                # this is a special token, encode it separately as a special case
                ids.append(special[part])
            else:
                # this is an ordinary sequence, encode it normally
                ids.extend(self.encode_ordinary(part))
        return ids
                    
        
        

In [108]:
# No pattern argument, so the tokenizer splits text with GPT4_SPLIT_PATTERN.
tokenizer2 = RegexTokenizer()

In [109]:
# Same corpus and vocab size as the BasicTokenizer run (256 merges), but merges
# are learned within regex chunks, so the learned tokens differ.
tokenizer2.train(text,512,verbose=True)

merge 1/256: (101, 114) -> 256 (b'er') had 2359 occurrences
merge 2/256: (50, 48) -> 257 (b'20') had 2187 occurrences
merge 3/256: (111, 114) -> 258 (b'or') had 2076 occurrences
merge 4/256: (105, 110) -> 259 (b'in') had 2006 occurrences
merge 5/256: (101, 100) -> 260 (b'ed') had 1876 occurrences
merge 6/256: (32, 116) -> 261 (b' t') had 1824 occurrences
merge 7/256: (111, 110) -> 262 (b'on') had 1815 occurrences
merge 8/256: (104, 101) -> 263 (b'he') had 1772 occurrences
merge 9/256: (32, 83) -> 264 (b' S') had 1633 occurrences
merge 10/256: (97, 114) -> 265 (b'ar') had 1519 occurrences
merge 11/256: (97, 110) -> 266 (b'an') had 1487 occurrences
merge 12/256: (32, 65) -> 267 (b' A') had 1335 occurrences
merge 13/256: (261, 263) -> 268 (b' the') had 1169 occurrences
merge 14/256: (97, 108) -> 269 (b'al') had 1164 occurrences
merge 15/256: (114, 105) -> 270 (b'ri') had 1156 occurrences
merge 16/256: (118, 260) -> 271 (b'ved') had 1104 occurrences
merge 17/256: (115, 116) -> 272 (b'st') 

In [110]:
# let's see our vocab: list only the tokens added by training, skipping the
# single-byte tokens (ids 0..255)
for token, token_bytes in tokenizer2.vocab.items():
    if token <= 255:
        continue
    print(f"Token ID: {token} -> Merged Bytes: {token_bytes} -> As String: '{token_bytes.decode('utf-8', 'replace')}'")


Token ID: 256 -> Merged Bytes: b'er' -> As String: 'er'
Token ID: 257 -> Merged Bytes: b'20' -> As String: '20'
Token ID: 258 -> Merged Bytes: b'or' -> As String: 'or'
Token ID: 259 -> Merged Bytes: b'in' -> As String: 'in'
Token ID: 260 -> Merged Bytes: b'ed' -> As String: 'ed'
Token ID: 261 -> Merged Bytes: b' t' -> As String: ' t'
Token ID: 262 -> Merged Bytes: b'on' -> As String: 'on'
Token ID: 263 -> Merged Bytes: b'he' -> As String: 'he'
Token ID: 264 -> Merged Bytes: b' S' -> As String: ' S'
Token ID: 265 -> Merged Bytes: b'ar' -> As String: 'ar'
Token ID: 266 -> Merged Bytes: b'an' -> As String: 'an'
Token ID: 267 -> Merged Bytes: b' A' -> As String: ' A'
Token ID: 268 -> Merged Bytes: b' the' -> As String: ' the'
Token ID: 269 -> Merged Bytes: b'al' -> As String: 'al'
Token ID: 270 -> Merged Bytes: b'ri' -> As String: 'ri'
Token ID: 271 -> Merged Bytes: b'ved' -> As String: 'ved'
Token ID: 272 -> Merged Bytes: b'st' -> As String: 'st'
Token ID: 273 -> Merged Bytes: b'wi' -> As

In [111]:
# match this
# Reference behaviour: round-trip a mixed ASCII / Korean / emoji string through
# tiktoken's cl100k_base encoding.
import tiktoken
enc = tiktoken.get_encoding("cl100k_base") # this is the GPT-4 tokenizer
ids = enc.encode("hello world!!!? (안녕하세요!) lol123 😉")
text1 = enc.decode(ids) # get the same text back
print(text1)

hello world!!!? (안녕하세요!) lol123 😉


In [114]:
# Round-trip the same string through the RegexTokenizer trained above; with the
# default allowed_special="none_raise" and no special tokens registered this is
# plain encode_ordinary().
ids2 = tokenizer2.encode("hello world!!!? (안녕하세요!) lol123 😉")
text2 = tokenizer2.decode(ids2)
print(text2)

hello world!!!? (안녕하세요!) lol123 😉


In [156]:
import json
# Persist the trained RegexTokenizer (merges, vocab, special tokens) as a JSON file.
data = {
    # JSON object keys must be strings, so each (p0, p1) pair is written as "p0,p1"
    "merges": {f"{p0},{p1}": idx for (p0,p1), idx in tokenizer2.merges.items()},
    # latin-1 maps every byte value 0..255 to exactly one code point, so arbitrary
    # token bytes (including invalid UTF-8) can be stored as a str
    "vocab": {str(idx): token_bytes.decode("latin-1") for idx, token_bytes in tokenizer2.vocab.items()},
    # str -> int mapping; empty unless register_special_tokens() was called
    "special_tokens": tokenizer2.special_tokens
}
with open("regex_tokenizer.json", "w") as f:
    json.dump(data, f)


In [118]:
def bpe(mergeable_ranks, token, max_rank):
    """Split ``token`` into parts by greedily applying the lowest-ranked merges.

    Args:
        mergeable_ranks: dict mapping bytes -> rank (int).
        token: bytes object to split.
        max_rank: stop before applying any merge whose rank is >= this value;
            None means no limit.

    Returns:
        list of bytes objects whose concatenation equals ``token``.
    """
    pieces = [bytes([byte]) for byte in token]
    while True:
        best_pos, best_rank = None, None
        for pos in range(len(pieces) - 1):
            candidate = mergeable_ranks.get(pieces[pos] + pieces[pos + 1])
            if candidate is None:
                continue  # this adjacent pair is not a known token
            # strict "<" keeps the left-most pair when ranks tie
            if best_rank is None or candidate < best_rank:
                best_pos, best_rank = pos, candidate

        if best_rank is None:
            break  # nothing mergeable is left
        if max_rank is not None and best_rank >= max_rank:
            break  # the best remaining merge is not allowed by the rank limit
        assert best_pos is not None
        # fuse the two neighbouring pieces into one
        pieces[best_pos:best_pos + 2] = [pieces[best_pos] + pieces[best_pos + 1]]

    return pieces


In [119]:
def recover_merges(mergeable_ranks):
    """Rebuild the (rank, rank) -> rank merge table from a bytes -> rank table.

    For every multi-byte token, ``bpe`` is run with the token's own rank as
    the limit, which is expected to leave exactly the two parts the token was
    merged from. The ranks of those two parts form the key; the token's rank
    is the value.

    Raises:
        AssertionError: if a token does not split into exactly two parts.
    """
    recovered = {}
    for token_bytes, token_rank in mergeable_ranks.items():
        if len(token_bytes) != 1:  # single bytes are leaves, not merge results
            halves = tuple(bpe(mergeable_ranks, token_bytes, token_rank))
            assert len(halves) == 2
            left_rank = mergeable_ranks[halves[0]]
            right_rank = mergeable_ranks[halves[1]]
            recovered[(left_rank, right_rank)] = token_rank
    return recovered

In [120]:
# Same value as the GPT4_SPLIT_PATTERN assigned in the earlier cell; repeated
# here so it sits next to the special tokens.
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
# Special tokens as str -> token id; GPT4Tokenizer passes this dict to
# register_special_tokens(). Note that the ids are not contiguous.
GPT4_SPECIAL_TOKENS = {
    '<|endoftext|>': 100257,
    '<|fim_prefix|>': 100258,
    '<|fim_middle|>': 100259,
    '<|fim_suffix|>': 100260,
    '<|endofprompt|>': 100276
}

In [145]:
class GPT4Tokenizer(RegexTokenizer):
    """RegexTokenizer preloaded with the merges of tiktoken's ``cl100k_base``.

    The merges are recovered from tiktoken's bytes -> rank table rather than
    trained, so ``train``, ``save`` and ``load`` are disabled.
    """

    def __init__(self):
        super().__init__()
        enc = tiktoken.get_encoding("cl100k_base")
        mergeable_ranks = enc._mergeable_ranks
        self.merges = recover_merges(mergeable_ranks)
        # reconstruct the vocab from the merges
        vocab = {idx:bytes([idx]) for idx in range(256)}
        for (p0,p1),idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        self.vocab = vocab
        # mergeable_ranks gives each single byte its own rank, which need not
        # equal the byte value; byte_shuffle maps raw byte -> rank and
        # inverse_byte_shuffle maps rank -> raw byte
        self.byte_shuffle = {i:mergeable_ranks[bytes([i])] for i in range(256)}
        self.inverse_byte_shuffle = {v:k for  k,v in self.byte_shuffle.items()}
        self.register_special_tokens(GPT4_SPECIAL_TOKENS)

    def _encode_chunk(self,text_bytes):
        """Encode one chunk after permuting its bytes into rank space."""
        text_bytes = bytes(self.byte_shuffle[b] for b in text_bytes)
        ids = super()._encode_chunk(text_bytes)
        return ids

    def decode(self,ids):
        """Given ids (list of integers), return a Python string.

        Registered special-token ids (which ``encode`` can emit when
        ``allowed_special`` permits them) decode to their string form;
        previously they raised KeyError because they are not part of
        ``self.vocab``.

        Raises:
            KeyError: if an id is neither in the vocab nor a registered
                special token.
        """
        parts = []
        for idx in ids:
            special = self.inverse_special_tokens.get(idx)
            if special is not None:
                # special tokens are plain strings: no byte un-permutation
                parts.append(special.encode("utf-8"))
            else:
                # vocab bytes live in rank space; map them back to raw bytes
                parts.append(bytes(self.inverse_byte_shuffle[b] for b in self.vocab[idx]))
        # decode once at the end: a character's bytes may span several tokens
        text_bytes = b"".join(parts)
        text = text_bytes.decode("utf-8",errors="replace")
        return text

    def train(self,text,vocab_size,verbose=False):
        """Not supported: the merges come from tiktoken."""
        raise NotImplementedError

    def save(self,file_prefix):
        raise NotImplementedError("GPT4Tokenizer cannot be saved.")

    def load(self, model_file):
        raise NotImplementedError("GPT4Tokenizer cannot be loaded.")

    def save_vocab(self,vocab_file):
        """Write a human-readable listing of the vocab to ``vocab_file``.

        Each merged token is written as ``[left][right] -> [token] idx`` and
        each single-byte token as ``[token] idx``, rendered via
        ``render_token``.
        """
        # build the vocab from un-permuted single bytes so that the rendered
        # tokens show the real byte content
        vocab = {idx:bytes([self.inverse_byte_shuffle[idx]]) for idx in range(256)}
        for (p0,p1),idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        # now merge the shuffled bytes and write to file
        inverted_merges = {idx: pair for pair, idx in self.merges.items()}
        with open(vocab_file, "w", encoding="utf-8") as f:
            for idx, token in vocab.items():
                s = render_token(token)
                if idx in inverted_merges:
                    idx0, idx1 = inverted_merges[idx]
                    s0 = render_token(vocab[idx0])
                    s1 = render_token(vocab[idx1])
                    f.write(f"[{s0}][{s1}] -> [{s}] {idx}\n")
                else:
                    f.write(f"[{s}] {idx}\n")
        

In [146]:
# Builds the tokenizer from tiktoken's cl100k_base merges; nothing is trained here.
tokenizer3 = GPT4Tokenizer()

In [147]:
# Round-trip the reference string through GPT4Tokenizer; the printed text
# should equal the input.
ids3 = tokenizer3.encode("hello world!!!? (안녕하세요!) lol123 😉")
text3 = tokenizer3.decode(ids3)
print(text3)

hello world!!!? (안녕하세요!) lol123 😉


In [151]:
# Show the token ids GPT4Tokenizer produces for a string with a run of spaces.
# NOTE(review): "/n" is a literal slash followed by "n", not the newline escape "\n" - confirm this is intended.
ids3 = tokenizer3.encode("Hello/n    World")
text3 = tokenizer3.decode(ids3)  # round-trip result; computed but not displayed in this cell
print(ids3)

[9906, 9809, 262, 4435]


In [154]:
# Encode the same string with the RegexTokenizer trained in this notebook and
# show the ids and how many there are.
ids2 = tokenizer2.encode("Hello/n    World")
text2 = tokenizer2.decode(ids2)
# two separate statements: joining the print calls with a comma builds the
# tuple (None, None), which the notebook then echoes as the cell result
print(ids2)
print(len(ids2))

[72, 101, 301, 111, 47, 110, 32, 32, 32, 346, 258, 509]
12


(None, None)