In [2]:
def merge(ids, pair, idx):
    new_ids = []
    i = 0
    while i < len(ids):
        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
            new_ids.append(idx)
            i += 2
        else:
            new_ids.append(ids[i])
            i += 1
    
    return new_ids

In [3]:
import unicodedata
from unicodedata import category

In [4]:
category('\n')

'Cc'

In [5]:
ch = '\n'
f"\\u{ord(ch):04x}"

'\\u000a'

In [6]:
ch

'\n'

In [7]:
def replace_control_characters(s: str) -> str:
    chars = []
    for ch in s:
        if unicodedata.category(ch)[0] != 'C':
            chars.append(ch)
        else:
            chars.append(f"\\u{ord(ch):04x}") # escape
    
    return "".join(chars)

In [8]:
replace_control_characters("abcd\ne\r")

'abcd\\u000ae\\u000d'

In [9]:
def render_token(t: bytes) -> str:
    s = t.decode('utf-8', errors='replace')
    return replace_control_characters(s)

In [10]:
# Encoding individual characters
# Encoding the character 'A' (U+0041)
utf8_encoded_A = b'\x41'
print(utf8_encoded_A)  # Output: b'A'

# Encoding the Euro sign '€' (U+20AC)
utf8_encoded_euro = b'\xe2\x82\xac'
print(utf8_encoded_euro)  # Output: b'\xe2\x82\xac'

# Encoding the emoji '😊' (U+1F60A)
utf8_encoded_emoji = b'\xf0\x9f\x98\x8a'
print(utf8_encoded_emoji)  # Output: b'\xf0\x9f\x98\x8a'


# Encoding a string
# Encoding the string 'Hello, world!' in UTF-8
utf8_encoded_string = 'Hello, world!'.encode('utf-8')
print(utf8_encoded_string)  # Output: b'Hello, world!'

# Encoding a string with characters from multiple scripts
# Encoding the string '你好, world!' containing Chinese characters (U+4F60 U+597D)
utf8_encoded_multilingual_string = '你好, world!'.encode('utf-8')
print(utf8_encoded_multilingual_string)  # Output: b'\xe4\xbd\xa0\xe5\xa5\xbd, world!'

b'A'
b'\xe2\x82\xac'
b'\xf0\x9f\x98\x8a'
b'Hello, world!'
b'\xe4\xbd\xa0\xe5\xa5\xbd, world!'


In [11]:
list(utf8_encoded_multilingual_string)

[228, 189, 160, 229, 165, 189, 44, 32, 119, 111, 114, 108, 100, 33]

In [12]:
len(utf8_encoded_multilingual_string)

14

In [13]:
import math
math.pow(2, 23)

8388608.0

In [14]:
math.log(65535)

11.090339630053647

In [15]:
# understanding utf-8 encoding (smiley is reprsented using 4 bytes, registered using 2 and euro using 3)
# each number in the output list is [0, 255]
a = '\n\r\r😊®€'
#  list(b"".join([a.encode('utf-8')]))
list(a.encode('utf-8'))

[10, 13, 13, 240, 159, 152, 138, 194, 174, 226, 130, 172]

In [16]:
replace_control_characters('\n\rabcd')

'\\u000a\\u000dabcd'

In [17]:
a.encode('utf-8')

b'\n\r\r\xf0\x9f\x98\x8a\xc2\xae\xe2\x82\xac'

In [18]:
render_token(a.encode('utf-8'))

'\\u000a\\u000d\\u000d😊®€'

In [19]:
a = {}
a[(1, 1)] = 2
a[(2, 2)] = 3
a[(3, 3)] = 4

In [20]:
for idx1, idx2 in a:
    print(idx1, idx2)

1 1
2 2
3 3


In [21]:
for i,j in a.items():
    print(i, j)

(1, 1) 2
(2, 2) 3
(3, 3) 4


In [22]:
from pathlib import Path
p = Path('abcd')
model_file = p.with_suffix('.model')
model_file

PosixPath('abcd.model')

In [23]:
with open(model_file, 'w') as f:
    f.write('abcdefg')

In [24]:
class Tokenizer:
    """Base class for tokenizer"""

    def __init__(self):
        # default vocab size is 256 (same as ascii chars), no merges, no patterns
        self.merges = {} # (int, int) -> int
        self.pattern = "" # str
        self.special_tokens = {} # str -> int eg. {'<|endoftext|>': 1}
        self.vocab = self._build_vocab() # int -> bytes
    
    def _build_vocab(self):
        vocab = {idx: bytes(idx) for idx in range(256)}
        # the fact that iteration order is same as order in which items are inserted is key here, otherwise we don't have vocab entries for previous merges
        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        for special, idx in self.special_tokens.items():
            vocab[idx] = special.encode('utf-8')
        
        return vocab
    
    def train(self, text, vocab_size, verbose=False):
        raise NotImplementedError
    
    def encode(self, text):
        raise NotImplementedError
    
    def decode(self, ids):
        raise NotImplementedError
    
    def save(self, file_prefix):
        """
        Saves two files: file_prefix.vocab and file_prefix.model
        Similar to sentencepiece
        - model file is used for model loading, vocab is just for human viz.
        """
        file = Path(file_prefix)
        model_file = file.with_suffix('.model')
        with open(model_file, 'w') as f:
            # write version, pattern and merges
            f.write('minbpe v1\n')
            f.write(f"{self.pattern}\n")
            # special tokens
            f.write(f"{len(self.special_tokens)}\n")
            for special, idx in self.special_tokens.items():
                f.write(f"{special} {idx}\n")
            
            # merges dict
            for idx1, idx2 in self.merges: # write only the ids of the merge
                f.write(f"{idx1} {idx2}\n")
        
        # write the vocab, for human viz
        # vocab file is different than actual vocab, file is lossy but self.vocab is good.
        vocab_file = file.with_suffix('.vocab')
        inverted_merges = {idx: pair for pair, idx in self.merges.items()}
        with open(vocab_file, 'w', encoding='utf-8') as f:
            for idx, token in self.vocab.items():
                # replaces some partial utf-8 seq into ? token, so this can't be decoded due to error = 'replace'
                s = render_token(token)
                if idx in inverted_merges:
                    idx0, idx1 = inverted_merges[idx]
                    s0 = render_token(self.vocab[idx0])
                    s1 = render_token(self.vocab[idx1])
                    f.write(f"[{s0}][{s1}] -> [{s}] {idx}\n") # we should be able to change this
                else:
                    # print the bytes and special characters, double check the special characters part.
                    f.write(f"[{s}] {idx}")
    
    def load(self, model_file):
        """Invert the functionality in save, but only for model file"""
        assert model_file.endswith(".model")
        # read the model file
        merges = {}
        special_tokens = {}
        idx = 256

        with open(model_file, 'r', encoding='utf-8') as f: # this is decoding, but understand this part more.
            version = f.readline().strip()
            assert version == "minbpe v1"
            self.pattern = f.readline().strip()
            num_special = int(f.readline().strip())
            for _ in range(num_special):
                special, special_idx = f.readline().strip().split()
                special_tokens[special] = int(special_idx)

            # read the merges
            for line in f:
                idx1, idx2 = map(int, line.split())
                merges[(idx1, idx2)] = idx
                idx += 1
        
        self.merges = merges
        self.special_tokens = special_tokens
        self.vocab = self._build_vocab()

In [25]:
class BasicTokenizer(Tokenizer):
    def __init__(self):
        super().__init__()
    
    def train(self, text, vocab_size, verbose=False):
        assert vocab_size >= 256
        num_merges = vocab_size - 256
        ids = list(text.encode('utf-8'))

        merges = {}
        vocab = {idx: bytes([idx]) for idx in range(256)}
        idx = 256
        for i in range(num_merges):
            stats = get_stats(ids)
            top_pair = max(stats, key=stats.get)
            ids = merge(ids, top_pair, idx)
            merges[top_pair] = idx
            vocab[idx] = vocab[top_pair[0]] + vocab[top_pair[1]]
            if verbose:
                print(f"merge {i+1}/{merges}: {top_pair} -> {idx} {vocab[idx]} has {stats[top_pair]} occurences")
            idx +=1
        
        self.merges = merges
        self.vocab = vocab
    
    def decode(self, ids):
        """Converts ids to a string"""
        text_bytes = b"".join(self.vocab[idx] for idx in ids)
        text = text_bytes.decode("utf-8", errors="replace")
        return text
    
    def encode(self, text):
        """Retums ids from text"""
        ids = list(text.encode('utf-8'))
        while len(ids) >= 2:
            # find the element in stats that has the smallest associated value in merges
            stats = get_stats(ids)
            top_pair = min(stats, key=lambda p: self.merges.get(p, float('inf')))
            if top_pair not in self.merges:
                break
            ids = merge(ids, top_pair, self.merges[top_pair])
        
        return ids

In [26]:
tokenizer = BasicTokenizer()
tokenizer.train("How are you doing", 257)

In [27]:
tokenizer.decode(tokenizer.encode("abcd"))

'abcd'

In [28]:
import sys
sys.path

['/Users/htkumar/llms/tokenization/minbpe',
 '/Users/htkumar/anaconda3/envs/myenv/lib/python38.zip',
 '/Users/htkumar/anaconda3/envs/myenv/lib/python3.8',
 '/Users/htkumar/anaconda3/envs/myenv/lib/python3.8/lib-dynload',
 '',
 '/Users/htkumar/anaconda3/envs/myenv/lib/python3.8/site-packages']

In [29]:
# import os
# dirname = os.path.dirname(os.path.abspath(__file__))

In [30]:
# print(__file__)

In [31]:
# globals()

In [32]:
import regex as re

In [33]:
Tokenizer

__main__.Tokenizer

In [34]:
GPT2_SPLIT_PATTERN = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""

In [35]:
compiled_pattern = re.compile(GPT4_SPLIT_PATTERN)

In [36]:
text = "How are you doing 123's        :) 😉"

In [37]:
gpt4_tokens = re.findall(compiled_pattern, text)
gpt4_tokens


['How', ' are', ' you', ' doing', ' ', '123', "'s", '       ', ' :)', ' 😉']

In [38]:
list(gpt4_tokens[9].encode('utf-8'))

[32, 240, 159, 152, 137]

In [39]:
ids = [list(ch.encode('utf-8')) for ch in gpt4_tokens]
ids

[[72, 111, 119],
 [32, 97, 114, 101],
 [32, 121, 111, 117],
 [32, 100, 111, 105, 110, 103],
 [32],
 [49, 50, 51],
 [39, 115],
 [32, 32, 32, 32, 32, 32, 32],
 [32, 58, 41],
 [32, 240, 159, 152, 137]]

In [40]:
merges = {} # (int, int) -> int
vocab = {idx: bytes(idx) for idx in range(256)}
vocab[0]

b''

In [41]:
bytes([128])

b'\x80'

In [42]:
bytes(5) # returns bytes object with null bytes

b'\x00\x00\x00\x00\x00'

In [43]:
# bytes??

In [44]:
num_merges = 3

In [45]:
# for i in range(num_merges):
#     # maintain global count of occurences of consecutive pairs
#     stats = {}
#     for chunk_id in ids:
#         get_stats(chunk_id, stats)
    
#     # find the pair which occurs most consecutively
#     max_pair = max(stats, key=stats.get)
#     new_id = 256 + i
#     ids = [merge(chunk_id, max_pair, new_id) for chunk_id in ids]
#     merge[max_pair] = new_id
#     vocab[new_id] = vocab[max_pair[0]] + vocab[max_pair[1]]
    

In [46]:
class RegexTokenizer(Tokenizer):
    def __init__(self, pattern=None):
        """
        - pattern to split the text by, default is gpt-4 pattern
        """
        super().__init__()
        self.pattern = GPT4_SPLIT_PATTERN if pattern is None else pattern
        self.compiled_pattern = re.compile(self.pattern)
        self.special_tokens = {} # dict from str -> int reprsenting special tokens
        self.inverse_special_tokens = {} # dict from int -> str reprsenting special tokens
    
    def register_special_tokens(self, special_tokens):
        # special tokens is a dict from str -> int
        self.special_tokens = special_tokens
        self.inverse_special_tokens = {v:k for k, v in self.special_tokens.items()}
    
    def train(self, text, vocab_size, verbose=False):
        assert vocab_size >= 256
        num_merges = vocab_size - 256
        vocab = {idx: bytes([idx]) for idx in range(256)}

        text_chunks = re.findall(self.compiled_pattern, text)
        ids = [list(chunk.encode('utf-8')) for chunk in text_chunks]
        merges = {}

        for i in range(num_merges):
            # maintain global count of occurences of consecutive pairs
            stats = {}
            for chunk_id in ids:
                get_stats(chunk_id, stats)
            
            # find the pair which occurs most consecutively
            max_pair = max(stats, key=stats.get)
            new_id = 256 + i
            ids = [merge(chunk_id, max_pair, new_id) for chunk_id in ids]
            merges[max_pair] = new_id
            vocab[new_id] = vocab[max_pair[0]] + vocab[max_pair[1]]

            # print stats
            if verbose:
                print(f"merge {i+1}/{num_merges}: {max_pair} -> {new_id} ({vocab[new_id]}) had {stats[max_pair]} occurences")
            
        self.merges = merges # used in encode()
        self.vocab = vocab # used in decode()
    
    def decode(self, ids):
        # return python string given list of integers
        part_bytes = []
        for idx in ids:
            if idx in self.vocab:
                part_bytes.append(self.vocab[idx])
            elif idx in self.inverse_special_tokens:  # understand this part thoroughly
                part_bytes.append(self.inverse_special_tokens[idx].encode("utf-8"))
            else:
                raise ValueError(f"Invalid token for decoding: {idx}")
        
        s = b"".join(part_bytes)
        return s.decode('utf-8', errors='replace')
    
    def _encode_chunk(self, text_bytes):
        ids = list(text_bytes)
        while len(ids) >= 2:
            # find the element in stats that has the smallest associated value in merges
            stats = get_stats(ids)
            top_pair = min(stats, key=lambda p: self.merges.get(p, float('inf')))
            if top_pair not in self.merges:
                break
            ids = merge(ids, top_pair, self.merges[top_pair])
        
        return ids
    
    def _encode_ordinary(self, text):
        "Encoding that ignores any special tokens"
        text_chunks = re.findall(self.compiled_pattern, text)
        text_bytes = [chunk.encode('utf-8') for chunk in text_chunks]
        encoded_out = []
        for text_byte in text_bytes:
            encoded_out.extend(self._encode_chunk(text_byte))
        
        return encoded_out
    
    def encode(self, text, allowed_special="none_raise"):
        """
        This function handles special tokens
        allowed_special: can be "all"|"none"|"none_raise"
        tiktoken default behavior is none_raise
        """
        special = None
        if allowed_special == "all":
            special = self.special_tokens
        elif allowed_special == "none":
            special = {}
        elif allowed_special == "none_raise":
            special = {}
            assert all(token not in text for token in self.special_tokens)
        elif isinstance(allowed_special, set):
            special = {k: v for k, v in self.special_tokens.items() if k in allowed_special}
        else:
            raise ValueError(f"allowed_special={allowed_special} not understood")
        
        if not special:
            # revert to encode ordinary
            return self._encode_ordinary(text)
        
        # else need to handle special characters
        # enclosing in parenthesis makes it into a capturing group so that special tokens are includes in the output from split 
        special_pattern = "(" + "|".join(re.escape(k) for k in special) + ")"
        special_chunks = re.split(special_pattern, text)
        ids = []
        for chunk in special_chunks:
            if chunk in special:
                ids.append(special[chunk])
            else:
                ids.extend(self._encode_ordinary(chunk))
        
        return ids

In [47]:
# Next steps
# understanding gpt-2 and gpt-4 patterns more closely.
# Understand the chunking behavior of gpt4 and why it needs lists of integers.
# understand the special characters usage

In [48]:
a = [[1, 2], [3, 4]]

In [49]:
b = []
for i in a:
    b.extend(i)
b

[1, 2, 3, 4]

In [50]:
special = {
    '<|endoftext|>': 100257,
    '<|endofprompt|>': 100258,
}

In [51]:
special_pattern = "(" + "|".join(re.escape(k) for k in special) + ")"
special_pattern

'(<\\|endoftext\\|>|<\\|endofprompt\\|>)'

In [52]:
p = re.escape('<|endoftext|>')

In [53]:
re.match(p, '<|endoftext|>')

<regex.Match object; span=(0, 13), match='<|endoftext|>'>

In [54]:
re.match('<|endoftext|>', '<|endoftext|>')

<regex.Match object; span=(0, 1), match='<'>

In [55]:
text = "def <|endoftext|>Last document!!! 👋<|endofprompt|> abcd"

In [56]:
special_chunks = re.split(special_pattern, text)
special_chunks

['def ', '<|endoftext|>', 'Last document!!! 👋', '<|endofprompt|>', ' abcd']

In [57]:
tokenizer = RegexTokenizer()
long_str = "How are you doing, How are you doing"
tokenizer.train(long_str, verbose=True, vocab_size=260)

merge 1/4: (72, 111) -> 256 (b'Ho') had 2 occurences
merge 2/4: (256, 119) -> 257 (b'How') had 2 occurences
merge 3/4: (32, 97) -> 258 (b' a') had 2 occurences
merge 4/4: (258, 114) -> 259 (b' ar') had 2 occurences


In [58]:
tokenizer.decode(tokenizer.encode("Hello World"))

'Hello World'

In [59]:
a = "<|endoftext|>Hello world this is one document"
a.strip()

'<|endoftext|>Hello world this is one document'

In [60]:
llama_text = """
<|endoftext|>The llama (/ˈlɑːmə/; Spanish pronunciation: [ˈʎama] or [ˈʝama]) (Lama glama) is a domesticated South American camelid, widely used as a meat and pack animal by Andean cultures since the pre-Columbian era.
Llamas are social animals and live with others as a herd. Their wool is soft and contains only a small amount of lanolin.[2] Llamas can learn simple tasks after a few repetitions. When using a pack, they can carry about 25 to 30% of their body weight for 8 to 13 km (5–8 miles).[3] The name llama (in the past also spelled "lama" or "glama") was adopted by European settlers from native Peruvians.[4]
The ancestors of llamas are thought to have originated from the Great Plains of North America about 40 million years ago, and subsequently migrated to South America about three million years ago during the Great American Interchange. By the end of the last ice age (10,000–12,000 years ago), camelids were extinct in North America.[3] As of 2007, there were over seven million llamas and alpacas in South America and over 158,000 llamas and 100,000 alpacas, descended from progenitors imported late in the 20th century, in the United States and Canada.[5]
<|fim_prefix|>In Aymara mythology, llamas are important beings. The Heavenly Llama is said to drink water from the ocean and urinates as it rains.[6] According to Aymara eschatology,<|fim_suffix|> where they come from at the end of time.[6]<|fim_middle|> llamas will return to the water springs and ponds<|endofprompt|>
""".strip()
llama_text

'<|endoftext|>The llama (/ˈlɑːmə/; Spanish pronunciation: [ˈʎama] or [ˈʝama]) (Lama glama) is a domesticated South American camelid, widely used as a meat and pack animal by Andean cultures since the pre-Columbian era.\nLlamas are social animals and live with others as a herd. Their wool is soft and contains only a small amount of lanolin.[2] Llamas can learn simple tasks after a few repetitions. When using a pack, they can carry about 25 to 30% of their body weight for 8 to 13 km (5–8 miles).[3] The name llama (in the past also spelled "lama" or "glama") was adopted by European settlers from native Peruvians.[4]\nThe ancestors of llamas are thought to have originated from the Great Plains of North America about 40 million years ago, and subsequently migrated to South America about three million years ago during the Great American Interchange. By the end of the last ice age (10,000–12,000 years ago), camelids were extinct in North America.[3] As of 2007, there were over seven million l

In [61]:
multi_line_str = """
This is a multi line str
in python
"""

In [62]:
multi_line_str

'\nThis is a multi line str\nin python\n'

### Implement gpt-4

In [64]:
import tiktoken

In [65]:
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""

In [66]:
GPT4_SPECIAL_TOKENS = {
    '<|endoftext|>': 100257,
    '<|fim_prefix|>': 100258,
    '<|fim_middle|>': 100259,
    '<|fim_suffix|>': 100260,
    '<|endofprompt|>': 100276
}

In [68]:
enc = tiktoken.get_encoding("cl100k_base")
enc

<Encoding 'cl100k_base'>

In [71]:
type(enc)

tiktoken.core.Encoding

In [77]:
mergeable_ranks = enc._mergeable_ranks
byte_len_counts = {}
for b, id in mergeable_ranks.items():
    if len(b) not in byte_len_counts:
        byte_len_counts[len(b)] = []
    byte_len_counts[len(b)].append(b)

In [81]:
len(mergeable_ranks)

100256

In [96]:
token = byte_len_counts[5][10]
max_rank = mergeable_ranks[token]
token

b'ystem'

In [97]:
parts = [bytes([b]) for b in token]
parts

[b'y', b's', b't', b'e', b'm']

In [98]:
parts = parts[:2] + [parts[2] + parts[3]] + parts[4:]
tuple(parts)

(b'y', b's', b'te', b'm')

In [95]:
tuple(['a', 'b', 'c'])

('a', 'b', 'c')

In [83]:
# mergeable_ranks

In [None]:
# merges is a dict from (id, id) -> id
# vocab is a dict from id -> byte repr

In [118]:
def bpe(mergeable_ranks, token, max_rank):
    parts = [bytes([b]) for b in token]
    while True:
        min_idx = None
        min_rank = None
        for i, pair in enumerate(zip(parts[:-1], parts[1:])):
            rank = mergeable_ranks.get(pair[0] + pair[1])
            # print(f"rank is {rank}, pair is {pair}")
            if rank is not None and (min_rank is None or rank < min_rank):
                min_idx = i
                min_rank = rank
        
        if min_rank is None or (max_rank is not None and min_rank >= max_rank):
            break

        assert min_idx is not None
        parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:]
        # print(f"min_rank is {min_rank}, min_idx is {min_idx}, parts is {parts}")
    return parts


In [113]:
token, max_rank

(b'ystem', 615)

In [114]:
bpe(mergeable_ranks, token, max_rank)

rank is 1065, pair is (b'y', b's')
rank is 267, pair is (b's', b't')
rank is 668, pair is (b't', b'e')
rank is 336, pair is (b'e', b'm')
min_rank is 267, min_idx is 1, parts is [b'y', b'st', b'e', b'm']
rank is 599, pair is (b'y', b'st')
rank is 5455, pair is (b'st', b'e')
rank is 336, pair is (b'e', b'm')
min_rank is 336, min_idx is 2, parts is [b'y', b'st', b'em']
rank is 599, pair is (b'y', b'st')
rank is 65188, pair is (b'st', b'em')
min_rank is 599, min_idx is 0, parts is [b'yst', b'em']
rank is 615, pair is (b'yst', b'em')


[b'yst', b'em']

In [119]:
def recover_merges(mergeable_ranks):
    """
    mergeable_ranks is a map from byte seq -> int.
    in this func, we are recovering the original pairings.
    We do BPE training on all the tokens to recover this.
    """
    merges = {}
    for token, rank in mergeable_ranks.items():
        if len(token) == 1:
            continue
        pair = tuple(bpe(mergeable_ranks, token, rank))
        assert len(pair) == 2
        # recover the ranks
        id1 = mergeable_ranks[pair[0]]
        id2 = mergeable_ranks[pair[1]]
        merges[(id1, id2)] = rank
    
    return merges

In [120]:
merges = recover_merges(mergeable_ranks)
len(merges)

100000

188

In [123]:
bytes_shuffle = {i: mergeable_ranks[bytes([i])] for i in range(256)}
# bytes_shuffle
inverse_bytes_shuffle = {v: k for k, v in bytes_shuffle.items()}

In [158]:
mergeable_ranks[bytes([0])], inverse_bytes_shuffle[0]

(188, 33)

In [188]:
class GPT4Tokenizer(RegexTokenizer):
    def __init__(self):
        super().__init__(pattern=GPT4_SPLIT_PATTERN)
        # get the official tokenizer
        enc = tiktoken.get_encoding('cl100k_base')
        mergeable_ranks = enc._mergeable_ranks
        # recover merges
        self.merges = recover_merges(mergeable_ranks)
        print(self.merges)
        # build vocab object
        self.vocab = {id: bytes([id]) for id in range(256)}
        for (p0, p1), rank in self.merges.items():
            self.vocab[rank] = self.vocab[p0] + self.vocab[p1]
        
        # continue with the tricky part
        # some individual token bytes are permuted which we need to deal with here.
        self.bytes_shuffle = {i: mergeable_ranks[bytes([i])] for i in range(256)}
        # bytes_shuffle
        self.inverse_bytes_shuffle = {v: k for k, v in bytes_shuffle.items()}
        self.register_special_tokens(GPT4_SPECIAL_TOKENS)
    
    def _encode_chunk(self, text_bytes):
        # we need to shuffle as merges uses the shuffled ids
        text_bytes = bytes(self.bytes_shuffle[b] for b in text_bytes)
        return super()._encode_chunk(text_bytes)
    
    def decode(self, ids):
        # TODO: how are special characters handled here, can we add this in regex tokenizer also.
        # TODO: test if we can inverse shuffle first and then decode, seems simpler as we convert to bytes once then.
        # text_bytes = b"".join(self.vocab[idx] for idx in ids)
        # text_bytes = bytes(self.inverse_bytes_shuffle[i] for i in text_bytes)
        # return text_bytes.decode('utf-8', errors='replace')
        text_ids = [self.inverse_bytes_shuffle[id] for id in ids]
        text_bytes = b"".join(self.vocab[idx] for idx in text_ids)
        return text_bytes.decode('utf-8', errors='replace')
    
    # pretrained tokenizer, can't be implemented.
    def train(self, text, vocab_size, verbose=False):
        raise NotImplementedError
    
    def save(self, file_prefix):
        raise NotImplementedError('GPT4Tokenizer can\'t be saved')
    
    def load(self, model_file):
        raise NotImplementedError()
    
    def save_vocab(self, vocab_file):
        # from .base import render_token
        vocab = {idx: bytes([inverse_bytes_shuffle[idx]]) for idx in range(256)}
        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        
        inverted_merges = {idx: pair for pair, idx in self.merges.items()}
        with open(vocab_file, 'w', encoding='utf-8') as f:
            for idx, token in self.vocab.items():
                # replaces some partial utf-8 seq into ? token, so this can't be decoded due to error = 'replace'
                s = render_token(token)
                if idx in inverted_merges:
                    idx0, idx1 = inverted_merges[idx]
                    s0 = render_token(vocab[idx0])
                    s1 = render_token(vocab[idx1])
                    f.write(f"[{s0}][{s1}] -> [{s}] {idx}\n") # we should be able to change this
                else:
                    # print the bytes and special characters, double check the special characters part.
                    f.write(f"[{s}] {idx}")

In [184]:
specials_string = """
<|endoftext|>Hello world this is one document
<|endoftext|>And this is another document
<|endoftext|><|fim_prefix|>And this one has<|fim_suffix|> tokens.<|fim_middle|> FIM
<|endoftext|>Last document!!! 👋<|endofprompt|>
""".strip()

In [189]:
tokenizer = GPT4Tokenizer()
gpt4_tokenizer_ids = tokenizer.encode(specials_string, allowed_special="all")
gpt4_tokenizer_ids
enc = tiktoken.get_encoding("cl100k_base")
tiktoken_ids = enc.encode(specials_string, allowed_special="all")
gpt4_tokenizer_ids == tiktoken_ids
# tokenizer.decode(gpt4_tokenizer_ids)

{(220, 220): 256, (256, 256): 257, (72, 77): 258, (220, 83): 259, (257, 257): 260, (68, 81): 261, (256, 220): 262, (78, 77): 263, (220, 64): 264, (81, 68): 265, (64, 83): 266, (82, 83): 267, (68, 77): 268, (78, 81): 269, (259, 71): 270, (198, 198): 271, (220, 66): 272, (75, 68): 273, (220, 82): 274, (72, 83): 275, (64, 77): 276, (64, 81): 277, (64, 75): 278, (270, 68): 279, (26, 198): 280, (220, 79): 281, (220, 69): 282, (78, 84): 283, (220, 28): 284, (72, 82): 285, (257, 262): 286, (258, 70): 287, (68, 82): 288, (220, 86): 289, (72, 263): 290, (68, 67): 291, (72, 66): 292, (220, 65): 293, (220, 67): 294, (68, 83): 295, (220, 76): 296, (220, 78): 297, (197, 197): 298, (81, 78): 299, (64, 82): 300, (68, 75): 301, (66, 83): 302, (77, 67): 303, (220, 258): 304, (220, 71): 305, (268, 83): 306, (72, 67): 307, (220, 77): 308, (64, 76): 309, (260, 262): 310, (259, 78): 311, (220, 265): 312, (12, 12): 313, (220, 90): 314, (297, 69): 315, (78, 76): 316, (8, 280): 317, (72, 76): 318, (201, 198):

True

In [191]:
# tokenizer.decode(tokenizer.encode("Hello world"))

In [142]:
# vocab
# inverse bytes shuffle is a map from actual id -> ideal id if there was no shuffle

In [192]:
# vocab1 = {idx: bytes([mergeable_ranks[bytes([idx])]]) for idx in range(256)}
# vocab1

In [151]:
vocab[0], bytes([33]), inverse_bytes_shuffle[0]

(b'!', b'!', 33)

In [150]:
inverse_bytes_shuffle[0], bytes_shuffle[33]

(33, 0)

In [129]:
vocab = {i: bytes([i]) for i in range(256)}
ids = b"".join(vocab[i] for i in range(10))
len(ids)

10

In [152]:
# for i in ids:
#     print(i)

In [137]:
mergeable_ranks[vocab[1]], vocab[1]

(189, b'\x01')

In [132]:
[bytes_shuffle[i] for i in ids]

[188, 189, 190, 191, 192, 193, 194, 195, 196, 197]

In [139]:
# bytes_shuffle

In [1]:
from src import GPT4Tokenizer
from src import render_token, replace_control_characters
import tiktoken

In [2]:
tokenizer = GPT4Tokenizer()

{(220, 220): 256, (256, 256): 257, (72, 77): 258, (220, 83): 259, (257, 257): 260, (68, 81): 261, (256, 220): 262, (78, 77): 263, (220, 64): 264, (81, 68): 265, (64, 83): 266, (82, 83): 267, (68, 77): 268, (78, 81): 269, (259, 71): 270, (198, 198): 271, (220, 66): 272, (75, 68): 273, (220, 82): 274, (72, 83): 275, (64, 77): 276, (64, 81): 277, (64, 75): 278, (270, 68): 279, (26, 198): 280, (220, 79): 281, (220, 69): 282, (78, 84): 283, (220, 28): 284, (72, 82): 285, (257, 262): 286, (258, 70): 287, (68, 82): 288, (220, 86): 289, (72, 263): 290, (68, 67): 291, (72, 66): 292, (220, 65): 293, (220, 67): 294, (68, 83): 295, (220, 76): 296, (220, 78): 297, (197, 197): 298, (81, 78): 299, (64, 82): 300, (68, 75): 301, (66, 83): 302, (77, 67): 303, (220, 258): 304, (220, 71): 305, (268, 83): 306, (72, 67): 307, (220, 77): 308, (64, 76): 309, (260, 262): 310, (259, 78): 311, (220, 265): 312, (12, 12): 313, (220, 90): 314, (297, 69): 315, (78, 76): 316, (8, 280): 317, (72, 76): 318, (201, 198):

In [3]:
enc = tiktoken.get_encoding("cl100k_base")
ids = enc.encode("How are you doing")
ids

[4438, 527, 499, 3815]

In [4]:
enc.decode(ids)

'How are you doing'

In [5]:
ids_gpt4 = tokenizer.encode("How are you doing"); ids_gpt4

[4438, 527, 499, 3815]

In [6]:
tokenizer.decode(ids_gpt4)

'How are you doing'

In [7]:
len(tokenizer.inverse_bytes_shuffle)

256

In [8]:
tokenizer.save_vocab('gpt4.vocab')

{0: b'!', 1: b'"', 2: b'#', 3: b'$', 4: b'%', 5: b'&', 6: b"'", 7: b'(', 8: b')', 9: b'*', 10: b'+', 11: b',', 12: b'-', 13: b'.', 14: b'/', 15: b'0', 16: b'1', 17: b'2', 18: b'3', 19: b'4', 20: b'5', 21: b'6', 22: b'7', 23: b'8', 24: b'9', 25: b':', 26: b';', 27: b'<', 28: b'=', 29: b'>', 30: b'?', 31: b'@', 32: b'A', 33: b'B', 34: b'C', 35: b'D', 36: b'E', 37: b'F', 38: b'G', 39: b'H', 40: b'I', 41: b'J', 42: b'K', 43: b'L', 44: b'M', 45: b'N', 46: b'O', 47: b'P', 48: b'Q', 49: b'R', 50: b'S', 51: b'T', 52: b'U', 53: b'V', 54: b'W', 55: b'X', 56: b'Y', 57: b'Z', 58: b'[', 59: b'\\', 60: b']', 61: b'^', 62: b'_', 63: b'`', 64: b'a', 65: b'b', 66: b'c', 67: b'd', 68: b'e', 69: b'f', 70: b'g', 71: b'h', 72: b'i', 73: b'j', 74: b'k', 75: b'l', 76: b'm', 77: b'n', 78: b'o', 79: b'p', 80: b'q', 81: b'r', 82: b's', 83: b't', 84: b'u', 85: b'v', 86: b'w', 87: b'x', 88: b'y', 89: b'z', 90: b'{', 91: b'|', 92: b'}', 93: b'~', 94: b'\xa1', 95: b'\xa2', 96: b'\xa3', 97: b'\xa4', 98: b'\xa5', 99:

In [9]:
render_token(b'!')

'!'

In [10]:
vocab = {idx: bytes([idx]) for idx in range(256)}
vocab

{0: b'\x00',
 1: b'\x01',
 2: b'\x02',
 3: b'\x03',
 4: b'\x04',
 5: b'\x05',
 6: b'\x06',
 7: b'\x07',
 8: b'\x08',
 9: b'\t',
 10: b'\n',
 11: b'\x0b',
 12: b'\x0c',
 13: b'\r',
 14: b'\x0e',
 15: b'\x0f',
 16: b'\x10',
 17: b'\x11',
 18: b'\x12',
 19: b'\x13',
 20: b'\x14',
 21: b'\x15',
 22: b'\x16',
 23: b'\x17',
 24: b'\x18',
 25: b'\x19',
 26: b'\x1a',
 27: b'\x1b',
 28: b'\x1c',
 29: b'\x1d',
 30: b'\x1e',
 31: b'\x1f',
 32: b' ',
 33: b'!',
 34: b'"',
 35: b'#',
 36: b'$',
 37: b'%',
 38: b'&',
 39: b"'",
 40: b'(',
 41: b')',
 42: b'*',
 43: b'+',
 44: b',',
 45: b'-',
 46: b'.',
 47: b'/',
 48: b'0',
 49: b'1',
 50: b'2',
 51: b'3',
 52: b'4',
 53: b'5',
 54: b'6',
 55: b'7',
 56: b'8',
 57: b'9',
 58: b':',
 59: b';',
 60: b'<',
 61: b'=',
 62: b'>',
 63: b'?',
 64: b'@',
 65: b'A',
 66: b'B',
 67: b'C',
 68: b'D',
 69: b'E',
 70: b'F',
 71: b'G',
 72: b'H',
 73: b'I',
 74: b'J',
 75: b'K',
 76: b'L',
 77: b'M',
 78: b'N',
 79: b'O',
 80: b'P',
 81: b'Q',
 82: b'R',
 83: b'

In [11]:
import math

In [24]:
-math.log(0.99)

0.01005033585350145

In [25]:
-math.log(0.01)

4.605170185988091

In [27]:
math.log(1.0)

0.0