<a href="https://colab.research.google.com/github/iamthedoan/nn-zero-to-hero/blob/master/lectures/gpt_tokenizer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Building a GPT-4 Tokenizer**



In [88]:
!pip install tiktoken # added for colab



In [89]:
import tiktoken
enc = tiktoken.get_encoding("cl100k_base") # GPT-4 tokenizer
print(enc.encode("안녕하세요 👋 (hello in Korean!)"))
print(enc.decode(enc.encode("안녕하세요 👋 (hello in Korean!)")) == "안녕하세요 👋 (hello in Korean!)")
# match the above for your own tokenizer, and also implement a train() function

[31495, 230, 75265, 243, 92245, 62904, 233, 320, 15339, 304, 16526, 16715]
True
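To put the 12 token ids above in context, it can help to compare them with the raw UTF-8 bytes, which are where byte-level BPE starts. This is a minimal sketch using only the standard library; the token count of 12 is copied from the output above rather than recomputed, so tiktoken is not needed to run it.

```python
s = "안녕하세요 👋 (hello in Korean!)"
raw = s.encode("utf-8")      # byte-level BPE starts from these bytes

num_chars = len(s)           # Unicode code points
num_bytes = len(raw)         # one initial token per byte, before any merges
num_tokens = 12              # length of the cl100k_base output printed above

print(num_chars, num_bytes, num_tokens)
print(f"bytes per token: {num_bytes / num_tokens:.2f}")
```

The Korean characters take three bytes each and the emoji takes four, which is why the byte count is so much larger than the character count.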


### minbpe exercise

At this point you have everything you need to build your own GPT-4 tokenizer. This is the [exercise progression](https://github.com/karpathy/minbpe/blob/master/exercise.md) you may wish to follow. You'll note that it is part of the [minbpe](https://github.com/karpathy/minbpe) repo, which is the solution to that exercise, and is a cleaned up version of the code above.

In [72]:
!wget -O input.txt https://raw.githubusercontent.com/karpathy/minbpe/master/tests/taylorswift.txt

--2024-06-13 23:30:46--  https://raw.githubusercontent.com/karpathy/minbpe/master/tests/taylorswift.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.109.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 185768 (181K) [text/plain]
Saving to: ‘input.txt’


2024-06-13 23:30:47 (5.51 MB/s) - ‘input.txt’ saved [185768/185768]



In [66]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

print(text[:1000])

Copy paste of the Wikipedia article on Taylor Swift, as of Feb 16, 2024.
---

Main menu

WikipediaThe Free Encyclopedia

Search
Create account
Log in

Personal tools
Contents  hide
(Top)
Life and career
Toggle Life and career subsection
Artistry
Toggle Artistry subsection
Accolades and achievements
Cultural status
Toggle Cultural status subsection
Wealth
Toggle Wealth subsection
Discography
Filmography
Tours
See also
Footnotes
References
Toggle References subsection
External links
Taylor Swift

136 languages
Article
Talk
Read
View source
View history

Tools
 Featured article
Page semi-protected
From Wikipedia, the free encyclopedia
For the album, see Taylor Swift (album).
Taylor Swift
Portrait of Taylor Swift in a cocktail dress
Swift at the 2023 MTV Video Music Awards
Born	Taylor Alison Swift
December 13, 1989 (age 34)
West Reading, Pennsylvania, US
Occupations
Singer-songwriter producer director businesswoman actress
Years active	2004–present
Works
Albumssinglessongsvideosperformance

Step 1

Write the `BasicTokenizer` class, with the following three core functions:

*   `def train(self, text, vocab_size, verbose=False)`
*   `def encode(self, text)`
*   `def decode(self, ids)`

Train your tokenizer on whatever text you like and visualize the merged tokens. Do they look reasonable? One default test you may wish to use is the text file tests/taylorswift.txt.
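Before training on a full article, it may help to watch the loop run on a tiny string. The sketch below is a self-contained toy version of the train step described above, not the final class; the input string `"aaabdaaabac"` and the choice of three merges are assumptions made for illustration.

```python
def get_stats(ids):
    # count how often each consecutive pair of ids occurs
    counts = {}
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts

def merge(ids, pair, idx):
    # replace every non-overlapping occurrence of pair with idx, left to right
    out, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
            out.append(idx)
            i += 2
        else:
            out.append(ids[i])
            i += 1
    return out

ids = list("aaabdaaabac".encode("utf-8"))
vocab = {i: bytes([i]) for i in range(256)}
merges = {}
for step in range(3):
    stats = get_stats(ids)
    pair = max(stats, key=stats.get)   # most frequent pair; ties go to the first seen
    new_id = 256 + step
    ids = merge(ids, pair, new_id)
    merges[pair] = new_id
    vocab[new_id] = vocab[pair[0]] + vocab[pair[1]]
    print(pair, "->", new_id, vocab[new_id])

print(ids)  # [258, 100, 258, 97, 99]
```

Each new token is built from an earlier one (`aa`, then `aaa`, then `aaab`), which is the same chaining pattern that shows up in real merge tables.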

In [23]:
def get_stats(ids, counts=None):
  # count how often each consecutive pair occurs; pass an existing dict to accumulate into it
  counts = {} if counts is None else counts
  for pair in zip(ids,ids[1:]): # iterates through ids in a pairwise manner
    counts[pair] = counts.get(pair, 0) + 1
  return counts

def merge(ids, pair,idx):
    # idx is the id of the new merged token; it should be >= 256 because 0-255 are the raw byte values
    newids = []
    i = 0
    while i < len(ids): # iterate through all the ids
      # merge when this is not the last position and the next two ids match the pair
      if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
        newids.append(idx) # replace the pair with the new id
        i += 2
      else:
        # no match here (or this is the last id): copy the id through unchanged
        newids.append(ids[i])
        i += 1
    return newids
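One subtlety in these two helpers is that pair counting is overlapping while merging is not. The sketch below illustrates this on three identical bytes. It uses `collections.Counter` as a stand-in for `get_stats` and redefines `merge` so that the snippet runs on its own.

```python
from collections import Counter

def merge(ids, pair, idx):
    # same left-to-right, non-overlapping replacement as the cell above
    out, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i + 1] == pair[1]:
            out.append(idx)
            i += 2
        else:
            out.append(ids[i])
            i += 1
    return out

ids = list(b"aaa")                   # [97, 97, 97]
stats = Counter(zip(ids, ids[1:]))   # overlapping windows: (97, 97) is seen twice
merged = merge(ids, (97, 97), 256)   # but only one merge fits: [256, 97]

print(stats[(97, 97)], merged)
```

So the count reported for a pair is an upper bound on how many replacements a merge will actually make.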

In [9]:
class Tokenizer:
# base class for all the tokenizers

  def __init__(self):
    self.merges = {}
    self.pattern = ""
    self.vocab = self._build_vocab()

  def train(self, text, vocab_size, verbose=False):
    raise NotImplementedError

  def encode(self, text):
    raise NotImplementedError

  def decode(self, ids):
    raise NotImplementedError

  def _build_vocab(self):
    vocab = {idx: bytes([idx]) for idx in range(256)}
    for (p0, p1), idx in self.merges.items():
      vocab[idx] = vocab[p0] + vocab[p1]
    return vocab

In [39]:
class BasicTokenizer(Tokenizer):

  def __init__(self):
    super().__init__()


  def train(self, text, vocab_size, verbose=False):
      assert vocab_size >= 256
      num_merges = vocab_size - 256


      ids = list(text.encode("utf-8"))

      vocab = {idx: bytes([idx]) for idx in range(256)}
      merges = {}
      for i in range(num_merges):
        stats = get_stats(ids)
        if not stats:
          break # fewer than two ids left, so there is nothing more to merge
        pair = max(stats, key=stats.get) # get the most common pair
        idx = 256 + i
        ids = merge(ids, pair, idx)
        vocab[idx] = vocab[pair[0]] + vocab[pair[1]]

        merges[pair] = idx

        if verbose:
          print(f"merging {pair} into a new token {idx} with {stats[pair]} occurrences")

      # save class variables
      self.merges = merges # used in encode()
      self.vocab = vocab   # used in decode()


  def encode(self,text):
    # given a string, return list of integers (the tokens)
    tokens = list(text.encode("utf-8"))
    while len(tokens) >= 2: # need at least two tokens to form a pair
      stats = get_stats(tokens)
      # pick the pair with the lowest merge index, i.e. the merge learned earliest in training
      pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
      # subtle: if there are no more merges available, the key will
      # result in an inf for every single pair, and the min will be
      # just the first pair in the list, arbitrarily
      # we can detect this terminating case by a membership check
      if pair not in self.merges:
        break # nothing else can be merged
      idx = self.merges[pair]
      tokens = merge(tokens, pair, idx)
    return tokens

  def decode(self,ids):
    # given a list of integers in the range [0, vocab_size), return the decoded string
    tokens = b"".join(self.vocab[idx] for idx in ids)
    text = tokens.decode("utf-8", errors="replace")
    return text
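The `errors="replace"` argument in `decode` matters because a single token can hold only part of a multi-byte UTF-8 character. The following sketch shows the difference between the two error modes using only built-in `bytes.decode`; the choice of the character `안` is just an example.

```python
full = "안".encode("utf-8")   # one character, three bytes
partial = full[:1]            # what a lone byte-level token might contain

replaced = partial.decode("utf-8", errors="replace")
print(repr(replaced))         # the Unicode replacement character U+FFFD

try:
    partial.decode("utf-8")   # the default errors="strict" raises instead
    strict_failed = False
except UnicodeDecodeError:
    strict_failed = True
print(strict_failed)
```

With strict decoding, any token sequence that stops mid-character would crash `decode`, which is an unfriendly failure mode for model output.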



Testing the Function

In [20]:
text = "some lines a day"
brokenizer = BasicTokenizer()
brokenizer.decode(brokenizer.encode(text))

'some lines a day'

In [21]:
valtext = "Many common characters, including numerals, punctuation, and other symbols, are unified within the standard and are not treated as specific to any given writing system. Unicode encodes thousands of emoji, with the continued development thereof conducted by the Consortium as a part of the standard.[4] Moreover, the widespread adoption of Unicode was in large part responsible for the initial popularization of emoji outside of Japan. Unicode is ultimately capable of encoding more than 1.1 million characters."
valtext2 = brokenizer.decode(brokenizer.encode(valtext))
print(valtext2 == valtext)

True


In [24]:
brokenizer.train(valtext, 266)
brokenizer.merges

{(101, 32): 256,
 (116, 104): 257,
 (97, 114): 258,
 (97, 110): 259,
 (111, 110): 260,
 (105, 110): 261,
 (100, 32): 262,
 (32, 257): 263,
 (111, 102): 264,
 (99, 111): 265}
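The integer pairs are easier to judge once they are turned back into bytes. The sketch below copies the ten merges printed above and rebuilds the vocabulary the same way `_build_vocab` does, so it runs without the tokenizer class.

```python
# merges copied from the output above
merges = {(101, 32): 256, (116, 104): 257, (97, 114): 258, (97, 110): 259,
          (111, 110): 260, (105, 110): 261, (100, 32): 262, (32, 257): 263,
          (111, 102): 264, (99, 111): 265}

vocab = {i: bytes([i]) for i in range(256)}
# dicts keep insertion order, so a merge that uses token 257 (like (32, 257))
# is always processed after 257 itself has been defined
for (p0, p1), idx in merges.items():
    vocab[idx] = vocab[p0] + vocab[p1]

for idx in merges.values():
    print(idx, vocab[idx])
```

The learned tokens are common English fragments such as `b'e '`, `b'th'` and `b' th'`, which suggests the merges are reasonable for this text.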

Training with Taylor Swift wiki

In [34]:
num_merges = 300 - 256

In [40]:
tok = BasicTokenizer()
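# NOTE: an earlier test cell reassigned `text` to "some lines a day", and the merges
# printed below come from that short string; re-run the cell that reads input.txt
# first to train on the Taylor Swift article.
# A vocab_size much larger than the text supports runs out of pairs to merge.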
tok.train(text, 300)

In [41]:
tok.merges

{(115, 111): 256,
 (256, 109): 257,
 (257, 101): 258,
 (258, 32): 259,
 (259, 108): 260,
 (260, 105): 261,
 (261, 110): 262,
 (262, 101): 263,
 (263, 115): 264,
 (264, 32): 265,
 (265, 97): 266,
 (266, 32): 267,
 (267, 100): 268,
 (268, 97): 269,
 (269, 121): 270}

In [98]:
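for i in tok.merges.keys():
  # chr() is only meaningful for ids below 256; merged ids (>= 256) print as
  # unrelated Unicode characters, so look them up in tok.vocab to see their bytes
  print(chr(i[0]), chr(i[1]))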

e  
,  
d  
.  
r  
2 0
s  
i n
o n
r i
t  
t h
e Ă
ā ą
a n
a r
e Ą
y  
a l
ċ Ā
v Č
w i
e r
Ĉ  
ĕ f
R e
S Ę
o Ą
c h
č 1
o m
b Đ
  ē
a y
e n
o r
Ē  
e m
. 

ĉ e
ć g
č 2
t i
ġ l
" ă
l l
T ī
t ħ
Ħ  
t o
ă ę
Ĳ į
ĳ Ĕ
Į ě
e s
ĵ Ě
u s
r Ğ
ĥ ğ
) ă
A r
f Ĺ
Ļ "
Ď Ă
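Many rows above contain characters such as `Ă` or `Ā` because `chr()` maps a merged token id to whatever Unicode code point happens to share that number. A minimal sketch of the difference, using three merges copied from the `tok.merges` output shown earlier in the notebook:

```python
merges = {(115, 111): 256, (256, 109): 257, (257, 101): 258}

vocab = {i: bytes([i]) for i in range(256)}
for (p0, p1), idx in merges.items():
    vocab[idx] = vocab[p0] + vocab[p1]

for p0, p1 in merges:
    via_chr = chr(p0) + " " + chr(p1)     # misleading once an id is >= 256
    via_vocab = (vocab[p0], vocab[p1])    # the bytes each token stands for
    print(repr(via_chr), "vs", via_vocab)
```

Looking tokens up in `vocab` shows that token 256 is `b'so'`, whereas `chr(256)` is the unrelated letter `Ā`.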


Step 2

Convert your BasicTokenizer into a RegexTokenizer, which takes a regex pattern and splits the text exactly as GPT-4 would. Process the parts separately as before, then concatenate the results. Retrain your tokenizer and compare the results before and after. You should now see no tokens that cross categories (numbers, letters, punctuation, more than one whitespace). Use the GPT-4 pattern:

In [46]:
import regex as re # third-party "regex" package; the stdlib re module does not support \p{L}

In [83]:
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
GPT4_SPECIAL_TOKENS = {
    '<|endoftext|>': 100257,
    '<|fim_prefix|>': 100258,
    '<|fim_middle|>': 100259,
    '<|fim_suffix|>': 100260,
    '<|endofprompt|>': 100276
}


print(re.findall(GPT4_SPLIT_PATTERN, "Hello've world123 how's are you!!!?"))

['Hello', "'ve", ' world', '123', ' how', "'s", ' are', ' you', '!!!?']
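To see why splitting first prevents cross-category tokens, compare the pair counts with and without the chunk boundaries. The sketch below hard-codes the chunk list printed above, so it does not need the `regex` package, and uses `collections.Counter` in place of `get_stats`.

```python
from collections import Counter

# chunks copied from the findall output above
chunks = ['Hello', "'ve", ' world', '123', ' how', "'s", ' are', ' you', '!!!?']
joined = "".join(chunks)

def pair_counts(data: bytes) -> Counter:
    # overlapping consecutive byte pairs, like get_stats
    return Counter(zip(data, data[1:]))

unsplit = pair_counts(joined.encode("utf-8"))
split = Counter()
for chunk in chunks:
    split.update(pair_counts(chunk.encode("utf-8")))

cross = (ord("d"), ord("1"))   # 'd' ends " world", '1' starts "123"
print(unsplit[cross], split[cross])
print(sum(unsplit.values()) - sum(split.values()))  # pairs lost at chunk boundaries
```

A pair that never appears in the counts can never be chosen as a merge, so a token such as `d1` cannot be learned once the text is chunked.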


In [79]:
class RegexTokenizer(Tokenizer):


  def __init__(self, pattern=None):
    super().__init__()
    self.pattern = GPT4_SPLIT_PATTERN if pattern is None else pattern
    self.compiled_pattern = re.compile(self.pattern)
    self.special_tokens = {}
    self.inverse_special_tokens = {}


  def train(self, text, vocab_size, verbose=False):
    assert vocab_size >= 256
    num_merges = vocab_size - 256

    # split text into chunks
    txt_ch = re.findall(self.compiled_pattern, text)

    # process text
    ids = [list(ch.encode("utf-8")) for ch in txt_ch]

    vocab = {idx: bytes([idx]) for idx in range(256)}
    merges = {}
    for i in range(num_merges):
      # count pairs within each chunk only, so no pair spans a chunk boundary
      stats = {}
      for ch_ids in ids:
        get_stats(ch_ids, stats)
      if not stats:
        break # every chunk is a single token, so nothing is left to merge
      pair = max(stats, key=stats.get) # get the most common pair
      idx = 256 + i
      ids = [merge(ch_ids, pair, idx) for ch_ids in ids]
      vocab[idx] = vocab[pair[0]] + vocab[pair[1]]

      merges[pair] = idx

      if verbose:
        print(f"merging {pair} into a new token {idx} with {stats[pair]} occurrences")

    # save class variables
    self.merges = merges # used in encode()
    self.vocab = vocab   # used in decode()

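  def decode(self,ids):
    # given a list of integers (token ids), return the decoded string
    part_bytes = []
    for idx in ids:
      if idx in self.vocab:
        part_bytes.append(self.vocab[idx])
      elif idx in self.inverse_special_tokens:
        # special tokens are not in vocab, so emit their string form directly
        part_bytes.append(self.inverse_special_tokens[idx].encode("utf-8"))
      else:
        raise ValueError(f"invalid token id: {idx}")
    text = b"".join(part_bytes).decode("utf-8", errors="replace")
    return text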

  def register_special_tokens(self, special_tokens):
      # special_tokens is a dictionary of str -> int
      # example: {"<|endoftext|>": 100257}
      self.special_tokens = special_tokens
      self.inverse_special_tokens = {v: k for k, v in special_tokens.items()}

  # taken from minbpe

  def _encode_chunk(self, text_bytes):
    # return the token ids
    # let's begin. first, convert all bytes to integers in range 0..255
    ids = list(text_bytes)
    while len(ids) >= 2:
        # find the pair with the lowest merge index
        stats = get_stats(ids)
        pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
        # subtle: if there are no more merges available, the key will
        # result in an inf for every single pair, and the min will be
        # just the first pair in the list, arbitrarily
        # we can detect this terminating case by a membership check
        if pair not in self.merges:
            break # nothing else can be merged anymore
        # otherwise let's merge the best pair (lowest merge index)
        idx = self.merges[pair]
        ids = merge(ids, pair, idx)
    return ids

  def encode_ordinary(self, text):
      """Encoding that ignores any special tokens."""
      # split text into chunks of text by categories defined in regex pattern
      text_chunks = re.findall(self.compiled_pattern, text)
      # all chunks of text are encoded separately, then results are joined
      ids = []
      for chunk in text_chunks:
          chunk_bytes = chunk.encode("utf-8") # raw bytes
          chunk_ids = self._encode_chunk(chunk_bytes)
          ids.extend(chunk_ids)
      return ids

  def encode(self, text, allowed_special="none_raise"):
      """
      Unlike encode_ordinary, this function handles special tokens.
      allowed_special: can be "all"|"none"|"none_raise" or a custom set of special tokens
      if none_raise, then an error is raised if any special token is encountered in text
      this is the default tiktoken behavior right now as well
      any other behavior is either annoying, or a major footgun
      """
      # decode the user desire with respect to handling of special tokens
      special = None
      if allowed_special == "all":
          special = self.special_tokens
      elif allowed_special == "none":
          special = {}
      elif allowed_special == "none_raise":
          special = {}
          assert all(token not in text for token in self.special_tokens)
      elif isinstance(allowed_special, set):
          special = {k: v for k, v in self.special_tokens.items() if k in allowed_special}
      else:
          raise ValueError(f"allowed_special={allowed_special} not understood")
      if not special:
          # shortcut: if no special tokens, just use the ordinary encoding
          return self.encode_ordinary(text)
      # otherwise, we have to be careful with potential special tokens in text
      # we handle special tokens by splitting the text
      # based on the occurrence of any exact match with any of the special tokens
      # we can use re.split for this. note that surrounding the pattern with ()
      # makes it into a capturing group, so the special tokens will be included
      special_pattern = "(" + "|".join(re.escape(k) for k in special) + ")"
      special_chunks = re.split(special_pattern, text)
      # now all the special characters are separated from the rest of the text
      # all chunks of text are encoded separately, then results are joined
      ids = []
      for part in special_chunks:
          if part in special:
              # this is a special token, encode it separately as a special case
              ids.append(special[part])
          else:
              # this is an ordinary sequence, encode it normally
              ids.extend(self.encode_ordinary(part))
      return ids
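The special-token handling in `encode` relies on how `split` treats a capturing group. The sketch below isolates that step. It uses the standard library `re`, which should behave the same as `regex` for this call, and a raw-byte stand-in for `encode_ordinary` (an assumption, so that the snippet needs no trained merges).

```python
import re

special = {"<|endoftext|>": 100257}

# the surrounding () makes the split keep the matched special tokens in the output
pattern = "(" + "|".join(re.escape(k) for k in special) + ")"
parts = re.split(pattern, "<|endoftext|>hello world<|endoftext|>")
print(parts)  # ['', '<|endoftext|>', 'hello world', '<|endoftext|>', '']

ids = []
for part in parts:
    if part in special:
        ids.append(special[part])               # one id for the whole special token
    else:
        ids.extend(list(part.encode("utf-8")))  # stand-in for encode_ordinary(part)
print(ids)
```

The empty strings at either end are harmless because encoding an empty string adds no ids.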


In [80]:
regtok = RegexTokenizer()
regtok.train(text, 272)


In [81]:
regtok.merges

{(101, 114): 256,
 (50, 48): 257,
 (111, 114): 258,
 (105, 110): 259,
 (101, 100): 260,
 (32, 116): 261,
 (111, 110): 262,
 (104, 101): 263,
 (32, 83): 264,
 (97, 114): 265,
 (97, 110): 266,
 (32, 65): 267,
 (261, 263): 268,
 (97, 108): 269,
 (114, 105): 270,
 (118, 260): 271}

In [85]:
print(regtok.vocab)

{0: b'\x00', 1: b'\x01', 2: b'\x02', 3: b'\x03', 4: b'\x04', 5: b'\x05', 6: b'\x06', 7: b'\x07', 8: b'\x08', 9: b'\t', 10: b'\n', 11: b'\x0b', 12: b'\x0c', 13: b'\r', 14: b'\x0e', 15: b'\x0f', 16: b'\x10', 17: b'\x11', 18: b'\x12', 19: b'\x13', 20: b'\x14', 21: b'\x15', 22: b'\x16', 23: b'\x17', 24: b'\x18', 25: b'\x19', 26: b'\x1a', 27: b'\x1b', 28: b'\x1c', 29: b'\x1d', 30: b'\x1e', 31: b'\x1f', 32: b' ', 33: b'!', 34: b'"', 35: b'#', 36: b'$', 37: b'%', 38: b'&', 39: b"'", 40: b'(', 41: b')', 42: b'*', 43: b'+', 44: b',', 45: b'-', 46: b'.', 47: b'/', 48: b'0', 49: b'1', 50: b'2', 51: b'3', 52: b'4', 53: b'5', 54: b'6', 55: b'7', 56: b'8', 57: b'9', 58: b':', 59: b';', 60: b'<', 61: b'=', 62: b'>', 63: b'?', 64: b'@', 65: b'A', 66: b'B', 67: b'C', 68: b'D', 69: b'E', 70: b'F', 71: b'G', 72: b'H', 73: b'I', 74: b'J', 75: b'K', 76: b'L', 77: b'M', 78: b'N', 79: b'O', 80: b'P', 81: b'Q', 82: b'R', 83: b'S', 84: b'T', 85: b'U', 86: b'V', 87: b'W', 88: b'X', 89: b'Y', 90: b'Z', 91: b'[',

Step 3

You're now ready to load the merges from the GPT-4 tokenizer and show that your tokenizer produces identical results for both encode and decode, matching tiktoken.

In [91]:
def recover_merges(mergeable_ranks):
    # the `merges` are already the byte sequences in their merged state.
    # so we have to recover the original pairings. We can do this by doing
    # a small BPE training run on all the tokens, in their order.
    # also see https://github.com/openai/tiktoken/issues/60
    # also see https://github.com/karpathy/minbpe/issues/11#issuecomment-1950805306
    merges = {}
    for token, rank in mergeable_ranks.items():
        if len(token) == 1:
            continue # skip raw bytes
        pair = tuple(bpe(mergeable_ranks, token, max_rank=rank))
        assert len(pair) == 2
        # recover the integer ranks of the pair
        ix0 = mergeable_ranks[pair[0]]
        ix1 = mergeable_ranks[pair[1]]
        merges[(ix0, ix1)] = rank

    return merges


def bpe(mergeable_ranks, token, max_rank):
    # helper function used in recover_merges() to reconstruct the merge forest
    parts = [bytes([b]) for b in token]
    while True:
        min_idx = None
        min_rank = None
        for i, pair in enumerate(zip(parts[:-1], parts[1:])):
            rank = mergeable_ranks.get(pair[0] + pair[1])
            if rank is not None and (min_rank is None or rank < min_rank):
                min_idx = i
                min_rank = rank
        if min_rank is None or (max_rank is not None and min_rank >= max_rank):
            break
        assert min_idx is not None
        parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:]
    return parts
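The recovery trick is easier to follow on a toy rank table than on the full cl100k_base table. The sketch below restates the two functions in compact form and runs them on a hypothetical five-entry `toy_ranks` dict; the table is invented for illustration and is not taken from tiktoken.

```python
def bpe(ranks, token, max_rank):
    # re-run BPE on one token's bytes, using only merges ranked below max_rank
    parts = [bytes([b]) for b in token]
    while True:
        best = None  # (rank, position) of the lowest-ranked adjacent pair
        for i in range(len(parts) - 1):
            rank = ranks.get(parts[i] + parts[i + 1])
            if rank is not None and (best is None or rank < best[0]):
                best = (rank, i)
        if best is None or best[0] >= max_rank:
            return parts
        i = best[1]
        parts = parts[:i] + [parts[i] + parts[i + 1]] + parts[i + 2:]

def recover_merges(ranks):
    merges = {}
    for token, rank in ranks.items():
        if len(token) == 1:
            continue  # raw bytes have no parents
        left, right = bpe(ranks, token, max_rank=rank)
        merges[(ranks[left], ranks[right])] = rank
    return merges

# hypothetical rank table: three raw bytes, then "ab", then "abc"
toy_ranks = {b"a": 0, b"b": 1, b"c": 2, b"ab": 3, b"abc": 4}
recovered = recover_merges(toy_ranks)
print(recovered)  # {(0, 1): 3, (3, 2): 4}
```

Stopping just before the token's own rank leaves exactly the two parents that were merged to create it, so `b"abc"` is recovered as `b"ab"` plus `b"c"`.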

In [105]:
class GPT4Tokenizer(RegexTokenizer):
  def __init__(self):
    super().__init__(pattern=GPT4_SPLIT_PATTERN)

    # get official GPT4 tokenizer and its merges
    enc = tiktoken.get_encoding("cl100k_base")
    mergeable_ranks = enc._mergeable_ranks

    # merges from gpt4
    self.merges = recover_merges(mergeable_ranks)

    vocab = {idx: bytes([idx]) for idx in range(256)}
    # merges is a dict (int, int) --> int
    for (p0, p1), idx in self.merges.items():
      vocab[idx] = vocab[p0] + vocab[p1]
    self.vocab = vocab

    self.byte_shuffle = {i: enc._mergeable_ranks[bytes([i])] for i in range(256)}
    self.inverse_byte_shuffle = {v: k for k, v in self.byte_shuffle.items()}

    self.register_special_tokens(GPT4_SPECIAL_TOKENS)

  def decode(self, ids):
    # unpermute the bytes before decoding
    text_bytes = b"".join(self.vocab[idx] for idx in ids) # concatenate the (still shuffled) bytes of every token
    text_bytes = bytes(self.inverse_byte_shuffle[i] for i in text_bytes)
    text = text_bytes.decode("utf-8", errors="replace")
    return text

  def _encode_chunk(self, text_bytes):
    text_bytes = bytes(self.byte_shuffle[i] for i in text_bytes)
    return super()._encode_chunk(text_bytes)
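The `byte_shuffle` tables are needed because, as the class above implies, the rank of a single-byte token in `_mergeable_ranks` need not equal its byte value. The sketch below shows the round trip with a made-up permutation; the real mapping comes from tiktoken, and `255 - i` is only an assumption chosen to keep the example self-contained.

```python
# hypothetical permutation of the 256 byte values (the real one comes from tiktoken)
byte_shuffle = {i: 255 - i for i in range(256)}
inverse_byte_shuffle = {v: k for k, v in byte_shuffle.items()}

data = "hello 👋".encode("utf-8")

# encode side: permute the raw bytes before applying merges
shuffled = bytes(byte_shuffle[b] for b in data)
# decode side: undo the permutation before UTF-8 decoding
restored = bytes(inverse_byte_shuffle[b] for b in shuffled)

print(shuffled != data, restored.decode("utf-8"))
```

Because the mapping is a permutation, the inverse dict has all 256 entries and the round trip is lossless.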




In [99]:
# match this
enc = tiktoken.get_encoding("cl100k_base") # this is the GPT-4 tokenizer
ids = enc.encode("hello world!!!? (안녕하세요!) lol123 😉")
text = enc.decode(ids) # get the same text back

In [109]:
print(ids)
print(text)

[15339, 1917, 12340, 30, 320, 31495, 230, 75265, 243, 92245, 16715, 28509, 4513, 57037]
hello world!!!? (안녕하세요!) lol123 😉


In [106]:
tiktok = GPT4Tokenizer()
# tiktok.vocab
ids = tiktok.encode("hello world!!!? (안녕하세요!) lol123 😉")

In [111]:
print(ids)
print(tiktok.decode(ids))

[15339, 1917, 12340, 30, 320, 31495, 230, 75265, 243, 92245, 16715, 28509, 4513, 57037]
hello world!!!? (안녕하세요!) lol123 😉


In [114]:
print(enc.encode("<|endoftext|>hello world", allowed_special="all"))
print(tiktok.encode("<|endoftext|>hello world", allowed_special="all"))

[100257, 15339, 1917]
[100257, 15339, 1917]

