In [1]:
def get_stats(ids, counts=None):
    """Count how often each adjacent pair of ids occurs.

    Args:
        ids: sequence of token ids.
        counts: optional dict to accumulate into (updated in place), so several
            sequences can share one tally; a fresh dict is used when None.
    Returns:
        The (possibly shared) dict mapping (left, right) tuples to occurrence counts,
        with keys in order of first appearance.
    """
    if counts is None:
        counts = {}
    for left, right in zip(ids, ids[1:]):
        counts[(left, right)] = counts.get((left, right), 0) + 1
    return counts

In [2]:
def merge(ids, pair, idx):
    """Replace each occurrence of `pair` in `ids` with the single token `idx`.

    The scan goes left to right and matches do not overlap, so [1, 1, 1] with
    pair (1, 1) becomes [idx, 1]. `ids` itself is not modified.

    Returns:
        A new list of ids.
    """
    merged = []
    pos = 0
    n = len(ids)
    while pos < n:
        is_match = pos + 1 < n and ids[pos] == pair[0] and ids[pos + 1] == pair[1]
        if is_match:
            merged.append(idx)
            pos += 2
            continue
        merged.append(ids[pos])
        pos += 1
    return merged

In [3]:
import unicodedata
from unicodedata import category

In [4]:
category('\n')

'Cc'

In [5]:
ch = '\n'
f"\\u{ord(ch):04x}"

'\\u000a'

In [6]:
ch

'\n'

In [7]:
def replace_control_characters(s: str) -> str:
    r"""Escape Unicode control-type characters so a string prints on one line.

    Every character whose Unicode general category starts with 'C' (e.g. 'Cc'
    for '\n') is replaced by a literal \uXXXX escape (lowercase hex, at least
    four digits); every other character is kept as-is.
    """
    return "".join(
        ch if unicodedata.category(ch)[0] != 'C' else f"\\u{ord(ch):04x}"
        for ch in s
    )

In [8]:
replace_control_characters("abcd\ne\r")

'abcd\\u000ae\\u000d'

In [9]:
def render_token(t: bytes) -> str:
    """Turn raw token bytes into a printable string.

    Invalid or partial UTF-8 sequences become U+FFFD (errors='replace'), then
    control characters are escaped via replace_control_characters.
    """
    return replace_control_characters(t.decode('utf-8', errors='replace'))

In [10]:
# Encoding individual characters
# Encoding the character 'A' (U+0041)
utf8_encoded_A = b'\x41'
print(utf8_encoded_A)  # Output: b'A'

# Encoding the Euro sign '€' (U+20AC)
utf8_encoded_euro = b'\xe2\x82\xac'
print(utf8_encoded_euro)  # Output: b'\xe2\x82\xac'

# Encoding the emoji '😊' (U+1F60A)
utf8_encoded_emoji = b'\xf0\x9f\x98\x8a'
print(utf8_encoded_emoji)  # Output: b'\xf0\x9f\x98\x8a'


# Encoding a string
# Encoding the string 'Hello, world!' in UTF-8
utf8_encoded_string = 'Hello, world!'.encode('utf-8')
print(utf8_encoded_string)  # Output: b'Hello, world!'

# Encoding a string with characters from multiple scripts
# Encoding the string '你好, world!' containing Chinese characters (U+4F60 U+597D)
utf8_encoded_multilingual_string = '你好, world!'.encode('utf-8')
print(utf8_encoded_multilingual_string)  # Output: b'\xe4\xbd\xa0\xe5\xa5\xbd, world!'

b'A'
b'\xe2\x82\xac'
b'\xf0\x9f\x98\x8a'
b'Hello, world!'
b'\xe4\xbd\xa0\xe5\xa5\xbd, world!'


In [11]:
list(utf8_encoded_multilingual_string)

[228, 189, 160, 229, 165, 189, 44, 32, 119, 111, 114, 108, 100, 33]

In [12]:
len(utf8_encoded_multilingual_string)

14

In [13]:
import math
math.pow(2, 23)

8388608.0

In [14]:
math.log(65535)

11.090339630053647

In [15]:
# understanding utf-8 encoding (the smiley 😊 is represented using 4 bytes, the registered sign ® using 2 and the euro sign € using 3)
# each number in the output list is a byte value in [0, 255]
a = '\n\r\r😊®€'
#  list(b"".join([a.encode('utf-8')]))
list(a.encode('utf-8'))

[10, 13, 13, 240, 159, 152, 138, 194, 174, 226, 130, 172]

In [16]:
replace_control_characters('\n\rabcd')

'\\u000a\\u000dabcd'

In [17]:
a.encode('utf-8')

b'\n\r\r\xf0\x9f\x98\x8a\xc2\xae\xe2\x82\xac'

In [18]:
render_token(a.encode('utf-8'))

'\\u000a\\u000d\\u000d😊®€'

In [19]:
a = {}
a[(1, 1)] = 2
a[(2, 2)] = 3
a[(3, 3)] = 4

In [20]:
for idx1, idx2 in a:
    print(idx1, idx2)

1 1
2 2
3 3


In [21]:
for i,j in a.items():
    print(i, j)

(1, 1) 2
(2, 2) 3
(3, 3) 4


In [22]:
from pathlib import Path
p = Path('abcd')
model_file = p.with_suffix('.model')
model_file

PosixPath('abcd.model')

In [23]:
# scratch check: model_file is Path('abcd').with_suffix('.model') from the cell above,
# so this creates/overwrites 'abcd.model' in the current working directory.
with open(model_file, 'w') as f:
    f.write('abcdefg')

In [24]:
class Tokenizer:
    """Base class for tokenizers.

    Holds the state shared by all tokenizers and implements persistence:
      - merges: (int, int) -> int, in the order the merges were created
      - pattern: regex split pattern (empty string when unused)
      - special_tokens: str -> int, e.g. {'<|endoftext|>': 100257}
      - vocab: int -> bytes, derived from the three fields above
    Subclasses implement train / encode / decode.
    """

    def __init__(self):
        # default: 256 byte-level tokens (one per possible byte value), no merges, no pattern
        self.merges = {}  # (int, int) -> int
        self.pattern = ""  # str
        self.special_tokens = {}  # str -> int, e.g. {'<|endoftext|>': 100257}
        self.vocab = self._build_vocab()  # int -> bytes

    def _build_vocab(self):
        """Build the id -> bytes table from the base bytes, the merges and the special tokens.

        Returns:
            dict mapping each token id to the byte string it decodes to.
        """
        # bytes([idx]) is the single byte with value idx; bytes(idx) would instead be
        # idx zero bytes (bytes(0) == b'', bytes(5) == b'\x00' * 5).
        vocab = {idx: bytes([idx]) for idx in range(256)}
        # Relies on dict insertion order: merges are stored in the order they were
        # created (train/load assign increasing ids), so both children of a merge are
        # already in vocab when that merge is visited.
        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        for special, idx in self.special_tokens.items():
            vocab[idx] = special.encode("utf-8")
        return vocab

    def train(self, text, vocab_size, verbose=False):
        """Learn merges from `text` until the vocab reaches `vocab_size` (subclass hook)."""
        raise NotImplementedError

    def encode(self, text):
        """Convert a string to a list of token ids (subclass hook)."""
        raise NotImplementedError

    def decode(self, ids):
        """Convert a list of token ids back to a string (subclass hook)."""
        raise NotImplementedError

    def save(self, file_prefix):
        """Write `<file_prefix>.model` and `<file_prefix>.vocab` (similar to sentencepiece).

        The .model file is what `load` reads back. The .vocab file is a lossy,
        human-readable listing meant only for inspection.

        Args:
            file_prefix: str or path-like. ".model" / ".vocab" are appended to it.
        """
        # Append the suffixes instead of using Path.with_suffix, which would replace an
        # existing suffix (e.g. "toy.v2" -> "toy.model") and break the naming contract.
        model_file = Path(f"{file_prefix}.model")
        # utf-8 explicitly: load() reads with utf-8, and special tokens may be non-ASCII
        with open(model_file, "w", encoding="utf-8") as f:
            # header: format version, then the split pattern
            f.write("minbpe v1\n")
            f.write(f"{self.pattern}\n")
            # special tokens: their count, then one "<token> <id>" line each
            f.write(f"{len(self.special_tokens)}\n")
            for special, idx in self.special_tokens.items():
                f.write(f"{special} {idx}\n")
            # merges: only the child pair is written. load() reassigns ids 256, 257, ...
            # in file order, so this relies on merges being stored in id order.
            for idx1, idx2 in self.merges:
                f.write(f"{idx1} {idx2}\n")

        vocab_file = Path(f"{file_prefix}.vocab")
        inverted_merges = {idx: pair for pair, idx in self.merges.items()}
        with open(vocab_file, "w", encoding="utf-8") as f:
            for idx, token in self.vocab.items():
                # render_token decodes with errors='replace', so a token holding a partial
                # UTF-8 sequence is shown with U+FFFD; this file cannot be decoded back.
                s = render_token(token)
                if idx in inverted_merges:
                    # merged token: show the two children it was built from
                    idx0, idx1 = inverted_merges[idx]
                    s0 = render_token(self.vocab[idx0])
                    s1 = render_token(self.vocab[idx1])
                    f.write(f"[{s0}][{s1}] -> [{s}] {idx}\n")
                else:
                    # base byte tokens and special tokens
                    f.write(f"[{s}] {idx}\n")

    def load(self, model_file):
        """Inverse of `save`, but only for the .model file.

        Replaces self.pattern, self.merges and self.special_tokens, then rebuilds self.vocab.

        Args:
            model_file: str or path-like whose name ends in ".model".
        Raises:
            AssertionError: wrong suffix or unexpected version header.
        """
        model_file = str(model_file)  # accept pathlib.Path as well as str
        assert model_file.endswith(".model")
        merges = {}
        special_tokens = {}
        idx = 256  # merged tokens get consecutive ids after the 256 byte tokens

        with open(model_file, "r", encoding="utf-8") as f:
            version = f.readline().strip()
            assert version == "minbpe v1"
            # drop only the line terminator: save() writes the pattern verbatim
            pattern = f.readline().rstrip("\n")
            num_special = int(f.readline().strip())
            for _ in range(num_special):
                # split on the last space so the token text itself may contain spaces
                special, special_idx = f.readline().rstrip("\n").rsplit(" ", 1)
                special_tokens[special] = int(special_idx)

            # remaining lines: one "<idx1> <idx2>" merge per line, in id order
            for line in f:
                if not line.strip():
                    continue  # tolerate blank lines (e.g. a trailing empty line)
                idx1, idx2 = map(int, line.split())
                merges[(idx1, idx2)] = idx
                idx += 1

        self.pattern = pattern
        self.merges = merges
        self.special_tokens = special_tokens
        self.vocab = self._build_vocab()

In [25]:
class BasicTokenizer(Tokenizer):
    """Byte-level BPE tokenizer that trains directly on the UTF-8 bytes of the text.

    The text is not split into chunks first, so merges may cross word and
    whitespace boundaries. Special tokens are not handled.
    """

    def __init__(self):
        super().__init__()

    def train(self, text, vocab_size, verbose=False):
        """Learn up to `vocab_size - 256` merges from `text`.

        Args:
            text: training string.
            vocab_size: target vocabulary size; must be >= 256.
            verbose: print each merge as it is learned.
        Training stops early once fewer than two ids remain (nothing left to pair),
        so the learned vocab can be smaller than requested.
        """
        assert vocab_size >= 256
        num_merges = vocab_size - 256
        ids = list(text.encode("utf-8"))  # raw bytes as ints in [0, 255]

        merges = {}  # (int, int) -> int
        vocab = {idx: bytes([idx]) for idx in range(256)}  # int -> bytes
        idx = 256
        for i in range(num_merges):
            stats = get_stats(ids)
            if not stats:
                # fewer than two ids: no pair to merge, and max() would raise on {}
                break
            top_pair = max(stats, key=stats.get)  # most frequent adjacent pair
            ids = merge(ids, top_pair, idx)
            merges[top_pair] = idx
            vocab[idx] = vocab[top_pair[0]] + vocab[top_pair[1]]
            if verbose:
                print(f"merge {i+1}/{num_merges}: {top_pair} -> {idx} {vocab[idx]} has {stats[top_pair]} occurences")
            idx += 1

        self.merges = merges  # used in encode()
        self.vocab = vocab  # used in decode()

    def decode(self, ids):
        """Converts ids to a string.

        Invalid UTF-8 in the joined bytes becomes U+FFFD (errors="replace").
        """
        text_bytes = b"".join(self.vocab[idx] for idx in ids)
        text = text_bytes.decode("utf-8", errors="replace")
        return text

    def encode(self, text):
        """Returns ids from text.

        Starts from the raw UTF-8 bytes and repeatedly applies the merge with the
        lowest id among the adjacent pairs present (the earliest-learned one, i.e.
        training order) until no adjacent pair has a learned merge.
        """
        ids = list(text.encode("utf-8"))
        while len(ids) >= 2:
            stats = get_stats(ids)
            # pairs without a merge rank as +inf, so they are only chosen when nothing merges
            top_pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            if top_pair not in self.merges:
                break
            ids = merge(ids, top_pair, self.merges[top_pair])

        return ids

In [26]:
tokenizer = BasicTokenizer()
tokenizer.train("How are you doing", 257)

In [27]:
tokenizer.decode(tokenizer.encode("abcd"))

'abcd'

In [28]:
import sys
sys.path

['/Users/htkumar/llms/tokenization/minbpe',
 '/Users/htkumar/anaconda3/envs/myenv/lib/python38.zip',
 '/Users/htkumar/anaconda3/envs/myenv/lib/python3.8',
 '/Users/htkumar/anaconda3/envs/myenv/lib/python3.8/lib-dynload',
 '',
 '/Users/htkumar/anaconda3/envs/myenv/lib/python3.8/site-packages']

In [29]:
# import os
# dirname = os.path.dirname(os.path.abspath(__file__))

In [30]:
# print(__file__)

In [31]:
# globals()

In [32]:
import regex as re

In [33]:
Tokenizer

__main__.Tokenizer

In [34]:
# Regex split patterns (named after GPT-2 / GPT-4). RegexTokenizer first cuts the text into
# chunks with one of these, then learns/applies BPE merges only inside each chunk.
# GPT2: contractions ('s 'd 'm 't 'll 've 're), optional space + letters, optional space +
# digits, optional space + other symbols, whitespace runs that leave the last space for the
# next token (\s+(?!\S)), and any remaining whitespace.
GPT2_SPLIT_PATTERN = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
# GPT4: case-insensitive contractions ((?i:...)), at most one leading non-letter/non-digit
# (non-newline) char before letters, digits in runs of 1-3 (\p{N}{1,3}), symbol runs with
# trailing newlines, whitespace ending in a newline, then the same whitespace rules as GPT2.
# ?+ and ++ are possessive quantifiers (supported by the `regex` module imported as re).
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""

In [35]:
compiled_pattern = re.compile(GPT4_SPLIT_PATTERN)

In [36]:
text = "How are you doing 123's        :) 😉"

In [37]:
gpt4_tokens = re.findall(compiled_pattern, text)
gpt4_tokens


['How', ' are', ' you', ' doing', ' ', '123', "'s", '       ', ' :)', ' 😉']

In [38]:
list(gpt4_tokens[9].encode('utf-8'))

[32, 240, 159, 152, 137]

In [39]:
ids = [list(ch.encode('utf-8')) for ch in gpt4_tokens]
ids

[[72, 111, 119],
 [32, 97, 114, 101],
 [32, 121, 111, 117],
 [32, 100, 111, 105, 110, 103],
 [32],
 [49, 50, 51],
 [39, 115],
 [32, 32, 32, 32, 32, 32, 32],
 [32, 58, 41],
 [32, 240, 159, 152, 137]]

In [40]:
# NOTE: bytes(idx) builds idx zero bytes (bytes(0) == b'', bytes(5) == b'\x00' * 5, as the
# cells below show), so this is NOT a byte -> token table; the train() methods use bytes([idx]).
merges = {} # (int, int) -> int
vocab = {idx: bytes(idx) for idx in range(256)}
vocab[0]

b''

In [41]:
bytes([128])

b'\x80'

In [42]:
bytes(5) # returns bytes object with null bytes

b'\x00\x00\x00\x00\x00'

In [43]:
# bytes??

In [44]:
num_merges = 3

In [46]:
# for i in range(num_merges):
#     # maintain global count of occurences of consecutive pairs
#     stats = {}
#     for chunk_id in ids:
#         get_stats(chunk_id, stats)
    
#     # find the pair which occurs most consecutively
#     max_pair = max(stats, key=stats.get)
#     new_id = 256 + i
#     ids = [merge(chunk_id, max_pair, new_id) for chunk_id in ids]
#     merges[max_pair] = new_id
#     vocab[new_id] = vocab[max_pair[0]] + vocab[max_pair[1]]
    

In [91]:
class RegexTokenizer(Tokenizer):
    """BPE tokenizer that splits text with a regex first, then merges bytes within each chunk.

    Merges never cross chunk boundaries. Registered special tokens can be matched
    verbatim by `encode` and mapped directly to their ids.
    """

    def __init__(self, pattern=None):
        """
        Args:
            pattern: regex used to split the text into chunks; defaults to GPT4_SPLIT_PATTERN.
        """
        super().__init__()  # also initializes special_tokens to {}
        self.pattern = GPT4_SPLIT_PATTERN if pattern is None else pattern
        self.compiled_pattern = re.compile(self.pattern)
        self.inverse_special_tokens = {}  # int -> str, used by decode()

    def register_special_tokens(self, special_tokens):
        """Set the special tokens (dict str -> int) and the matching id -> str lookup."""
        self.special_tokens = special_tokens
        self.inverse_special_tokens = {v: k for k, v in self.special_tokens.items()}

    def load(self, model_file):
        """Load a .model file via Tokenizer.load, then refresh regex-specific state.

        The base loader sets self.pattern and self.special_tokens but does not know
        about the compiled regex used by encode() or the id -> str lookup used by decode().
        """
        super().load(model_file)
        self.compiled_pattern = re.compile(self.pattern)
        self.inverse_special_tokens = {v: k for k, v in self.special_tokens.items()}

    def train(self, text, vocab_size, verbose=False):
        """Learn up to `vocab_size - 256` merges from `text`, counting pairs only within chunks.

        Args:
            text: training string.
            vocab_size: target vocabulary size; must be >= 256.
            verbose: print each merge as it is learned.
        Stops early when no chunk has two or more ids left (including empty text).
        """
        assert vocab_size >= 256
        num_merges = vocab_size - 256
        vocab = {idx: bytes([idx]) for idx in range(256)}  # int -> bytes

        # one list of byte values per regex chunk, so pairs are never counted across chunks
        text_chunks = re.findall(self.compiled_pattern, text)
        ids = [list(chunk.encode('utf-8')) for chunk in text_chunks]
        merges = {}  # (int, int) -> int

        for i in range(num_merges):
            # global count of adjacent pairs, accumulated over all chunks
            stats = {}
            for chunk_ids in ids:
                get_stats(chunk_ids, stats)
            if not stats:
                # nothing left to merge; max() would raise on an empty dict
                break

            # merge the most frequent pair everywhere it occurs
            max_pair = max(stats, key=stats.get)
            new_id = 256 + i
            ids = [merge(chunk_ids, max_pair, new_id) for chunk_ids in ids]
            merges[max_pair] = new_id
            vocab[new_id] = vocab[max_pair[0]] + vocab[max_pair[1]]

            if verbose:
                print(f"merge {i+1}/{num_merges}: {max_pair} -> {new_id} ({vocab[new_id]}) had {stats[max_pair]} occurences")

        self.merges = merges  # used in encode()
        self.vocab = vocab  # used in decode()

    def decode(self, ids):
        """Return the Python string for a list of token ids.

        Ids are looked up in self.vocab first, then among registered special tokens
        (register_special_tokens does not add them to self.vocab).
        Invalid UTF-8 in the joined bytes becomes U+FFFD (errors='replace').

        Raises:
            ValueError: an id is neither in the vocab nor a registered special token.
        """
        part_bytes = []
        for idx in ids:
            if idx in self.vocab:
                part_bytes.append(self.vocab[idx])
            elif idx in self.inverse_special_tokens:
                part_bytes.append(self.inverse_special_tokens[idx].encode("utf-8"))
            else:
                raise ValueError(f"Invalid token for decoding: {idx}")

        s = b"".join(part_bytes)
        return s.decode('utf-8', errors='replace')

    def _encode_chunk(self, text_bytes):
        """Apply learned merges to one chunk's bytes, earliest-learned (lowest id) merge first."""
        ids = list(text_bytes)
        while len(ids) >= 2:
            stats = get_stats(ids)
            # pairs without a merge rank as +inf, so they are only chosen when nothing merges
            top_pair = min(stats, key=lambda p: self.merges.get(p, float('inf')))
            if top_pair not in self.merges:
                break
            ids = merge(ids, top_pair, self.merges[top_pair])

        return ids

    def _encode_ordinary(self, text):
        """Encode text ignoring special tokens: split with the regex, encode each chunk."""
        text_chunks = re.findall(self.compiled_pattern, text)
        encoded_out = []
        for chunk in text_chunks:
            encoded_out.extend(self._encode_chunk(chunk.encode('utf-8')))

        return encoded_out

    def encode(self, text, allowed_special="none_raise"):
        """Encode text to token ids, optionally recognizing special tokens.

        Args:
            text: string to encode.
            allowed_special: which registered special tokens to recognize:
                "all"        - every registered special token
                "none"       - none; special-token text is encoded as ordinary text
                "none_raise" - none, and assert no special token appears in text
                               (noted as tiktoken's default behavior)
                set of str   - only the listed special tokens
        Returns:
            list of int token ids.
        Raises:
            AssertionError: "none_raise" and text contains a registered special token.
            ValueError: allowed_special is not one of the above.
        """
        if allowed_special == "all":
            special = self.special_tokens
        elif allowed_special == "none":
            special = {}
        elif allowed_special == "none_raise":
            special = {}
            assert all(token not in text for token in self.special_tokens), \
                "text contains a special token; pass allowed_special='all' or a set to encode it"
        elif isinstance(allowed_special, set):
            special = {k: v for k, v in self.special_tokens.items() if k in allowed_special}
        else:
            raise ValueError(f"allowed_special={allowed_special} not understood")

        if not special:
            # no special tokens to look for
            return self._encode_ordinary(text)

        # The capturing group makes re.split keep the matched special tokens in its output.
        # Longest tokens go first so a token that is a prefix of another cannot shadow it.
        alternatives = sorted(special, key=len, reverse=True)
        special_pattern = "(" + "|".join(re.escape(k) for k in alternatives) + ")"
        ids = []
        for chunk in re.split(special_pattern, text):
            if chunk in special:
                ids.append(special[chunk])
            elif chunk:
                ids.extend(self._encode_ordinary(chunk))

        return ids

In [48]:
# Next steps
# - understand the GPT-2 and GPT-4 split patterns more closely.
# - understand the GPT-4 chunking behavior and why training works on a list of integer lists (one per chunk).
# - understand how special tokens are used.

In [49]:
a = [[1, 2], [3, 4]]

In [50]:
b = []
for i in a:
    b.extend(i)
b

[1, 2, 3, 4]

In [74]:
special = {
    '<|endoftext|>': 100257,
    '<|endofprompt|>': 100258,
}

In [88]:
special_pattern = "(" + "|".join(re.escape(k) for k in special) + ")"
special_pattern

'(<\\|endoftext\\|>|<\\|endofprompt\\|>)'

In [86]:
p = re.escape('<|endoftext|>')

In [77]:
re.match(p, '<|endoftext|>')

<regex.Match object; span=(0, 13), match='<|endoftext|>'>

In [78]:
re.match('<|endoftext|>', '<|endoftext|>')

<regex.Match object; span=(0, 1), match='<'>

In [83]:
text = "def <|endoftext|>Last document!!! 👋<|endofprompt|> abcd"

In [89]:
special_chunks = re.split(special_pattern, text)
special_chunks

['def ', '<|endoftext|>', 'Last document!!! 👋', '<|endofprompt|>', ' abcd']

In [93]:
tokenizer = RegexTokenizer()
long_str = "How are you doing, How are you doing"
tokenizer.train(long_str, verbose=True, vocab_size=260)

merge 1/4: (72, 111) -> 256 (b'Ho') had 2 occurences
merge 2/4: (256, 119) -> 257 (b'How') had 2 occurences
merge 3/4: (32, 97) -> 258 (b' a') had 2 occurences
merge 4/4: (258, 114) -> 259 (b' ar') had 2 occurences


In [95]:
tokenizer.decode(tokenizer.encode("Hello World"))

'Hello World'