<a href="https://colab.research.google.com/github/davidyuan-mle/llm/blob/main/tokenization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch

In [13]:
# ord() returns a character's Unicode code point (an integer),
# so every character -- ASCII, Chinese, or emoji -- maps to a single number.
x = 'hello 你好 😊'
for code_point in map(ord, x):
  print(code_point)

104
101
108
108
111
32
20320
22909
32
128522


In [14]:
# encode('utf-8') converts the string into bytes using the UTF-8 encoding scheme.
# UTF-8 represents each character with 1 to 4 bytes: e.g. ASCII characters such as
# English letters take 1 byte, common Chinese characters take 3 bytes, and most emoji take 4 bytes.
# A byte is a group of 8 bits, so it represents a number from 0 to 255.
# Each bit is 0 or 1, marking whether its position is "off" or "on"; an "on" bit at
# position p contributes 2^p, and the byte's value is the sum of those contributions.
# Positions are numbered 7 down to 0 from left to right, so the rightmost bit is position 0.
# e.g. 00000000 = 0; 00000001 = 2^0 = 1; 01000001 = 2^6 + 2^0 = 65; 11111111 = 255

x = 'hello 你好 😊'
utf8_bytes = x.encode('utf-8')  # a bytes object; iterating it yields ints in 0..255
list(utf8_bytes)

[104,
 101,
 108,
 108,
 111,
 32,
 228,
 189,
 160,
 229,
 165,
 189,
 32,
 240,
 159,
 152,
 138]

In [15]:
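# For tokenization, different LLM families use different methods:
# GPT - byte-level BPE (byte pair encoding), implemented in OpenAI's tiktoken library
# Claude - proprietary; the exact scheme is not publicly documented (commonly described as a SentencePiece + BPE hybrid, unverified)
# Gemini - SentencePiece
# LLaMA - SentencePiece (BPE model) for LLaMA 1/2; LLaMA 3 moved to a tiktoken-based BPE
# (SentencePiece is a library that can train either BPE or unigram tokenizers, so these labels are not mutually exclusive.)

# So what is the BPE algorithm?
# Wikipedia explanation: https://en.wikipedia.org/wiki/Byte_pair_encoding
# The idea is to repeatedly find the most frequent pair of adjacent tokens and replace it with a new token
# (here new token ids start at 256, 257, ... because 0-255 are already used by raw bytes), until the target vocabulary size is reached.
# Example: 'aaabdaaabac'
# 1) 'aa' is the most frequent pair, so replace it with a new symbol (token), say Z: 'aaabdaaabac' -> 'ZabdZabac'
# 2) 'ab' is now (tied for) the most frequent pair, so replace it with Y: 'ZYdZYac'
# 3) 'ZY' is the most frequent pair, so replace it with X: 'XdXac'. Every remaining pair occurs only once, so we stop.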

In [19]:
# Sample text for training the toy BPE tokenizer. It appears to be the lead section of the
# English Wikipedia article "Chinese Wikipedia" (https://en.wikipedia.org/wiki/Chinese_Wikipedia).
# It mixes 1-byte ASCII with multi-byte UTF-8 characters (Chinese, accented pinyin).
# The missing spaces after some periods (e.g. "Foundation.The") are in the literal itself;
# they are left as-is so the recorded outputs below stay reproducible.
text = "The Chinese Wikipedia (traditional Chinese: 中文維基百科; simplified Chinese: 中文维基百科; pinyin: Zhōngwén Wéijī Bǎikē) is the written vernacular Chinese edition of Wikipedia. It was created on 11 May 2001.[1] It is one of multiple projects supported by the Wikimedia Foundation.The Chinese Wikipedia currently has 1,469,839 articles, 3,690,353 registered users, and 7,189 active editors, of whom 63 have administrative privileges.The Chinese Wikipedia has been blocked in mainland China since May 2015.[2] Nonetheless, the Chinese Wikipedia is still one of the top ten most active versions of Wikipedia by number of edits and number of editors,[3][4] due to contributions from users from Taiwan, Hong Kong, Macau, Singapore, Malaysia, and the Chinese diaspora.Taiwan and Hong Kong contribute most of the page views to the Chinese Wikipedia."

In [25]:
# The raw UTF-8 bytes of `text` become the initial token ids (each in 0..255).
encoded = text.encode('utf-8')
tokens = list(encoded)

# len(text) counts characters; len(tokens) counts bytes, which is larger
# because non-ASCII characters take several bytes each.
divider = '---------'
print(divider)
print(f"text examples: {text[:50]}")
print(f"tokens examples: {tokens[:50]}")
print(divider)
print(f"text length: {len(text)}")
print(f"token length: {len(tokens)}")

---------
text examples: The Chinese Wikipedia (traditional Chinese: 中文維基百科
tokens examples: [84, 104, 101, 32, 67, 104, 105, 110, 101, 115, 101, 32, 87, 105, 107, 105, 112, 101, 100, 105, 97, 32, 40, 116, 114, 97, 100, 105, 116, 105, 111, 110, 97, 108, 32, 67, 104, 105, 110, 101, 115, 101, 58, 32, 228, 184, 173, 230, 150, 135]
---------
text length: 831
token length: 861


In [29]:
# Example: the adjacent pairs among the first 10 tokens.
# zip stops at the shorter input, so 10 tokens yield 9 pairs.
first_ten = tokens[:10]
for left, right in zip(first_ten, first_ten[1:]):
  print((left, right))

(84, 104)
(104, 101)
(101, 32)
(32, 67)
(67, 104)
(104, 105)
(105, 110)
(110, 101)
(101, 115)


In [37]:
def countPairs(tokens):
  """Count how often each adjacent pair of tokens occurs.

  Args:
    tokens: a sliceable sequence of token ids (e.g. a list of ints).

  Returns:
    dict mapping (left, right) tuples to their occurrence count, with keys in
    order of first occurrence. Overlapping pairs are all counted, so
    [1, 1, 1] gives {(1, 1): 2}. Fewer than two tokens gives {}.
  """
  pair_counts = {}
  for left, right in zip(tokens, tokens[1:]):
    key = (left, right)
    if key in pair_counts:
      pair_counts[key] += 1
    else:
      pair_counts[key] = 1
  return pair_counts

# Count every adjacent (byte, byte) pair in the raw UTF-8 tokens of `text`.
counts = countPairs(tokens)
print(counts)

{(84, 104): 3, (104, 101): 11, (101, 32): 28, (32, 67): 10, (67, 104): 10, (104, 105): 10, (105, 110): 17, (110, 101): 12, (101, 115): 12, (115, 101): 11, (32, 87): 9, (87, 105): 8, (105, 107): 9, (107, 105): 8, (105, 112): 8, (112, 101): 7, (101, 100): 17, (100, 105): 14, (105, 97): 10, (97, 32): 7, (32, 40): 1, (40, 116): 1, (116, 114): 4, (114, 97): 3, (97, 100): 2, (105, 116): 6, (116, 105): 10, (105, 111): 5, (111, 110): 15, (110, 97): 3, (97, 108): 2, (108, 32): 2, (101, 58): 2, (58, 32): 3, (32, 228): 2, (228, 184): 2, (184, 173): 2, (173, 230): 2, (230, 150): 2, (150, 135): 2, (135, 231): 2, (231, 182): 1, (182, 173): 1, (173, 229): 1, (229, 159): 2, (159, 186): 2, (186, 231): 2, (231, 153): 2, (153, 190): 2, (190, 231): 2, (231, 167): 2, (167, 145): 2, (145, 59): 2, (59, 32): 2, (32, 115): 4, (115, 105): 4, (105, 109): 2, (109, 112): 1, (112, 108): 2, (108, 105): 1, (105, 102): 1, (102, 105): 1, (105, 101): 2, (100, 32): 10, (231, 187): 1, (187, 180): 1, (180, 229): 1, (32, 11

In [38]:
# Most frequent adjacent pair. max() over (pair, count) items compares the counts,
# and on ties it returns the first pair encountered (i.e. the one seen first in the text).
max(counts.items(), key=lambda item: item[1])[0]

(101, 32)

In [40]:
def mergePairs(tokens, pair, index):
  """Replace every non-overlapping occurrence of `pair` with the single token `index`.

  Scans left to right. After a match both tokens are consumed, so overlapping
  occurrences resolve greedily from the left (e.g. [1, 1, 1] with pair (1, 1)
  becomes [index, 1]).

  Args:
    tokens: sequence of token ids; it is not modified.
    pair: (left, right) tuple of token ids to merge.
    index: the new token id that replaces each matched pair.

  Returns:
    A new list of token ids.
  """
  merged = []
  pos = 0
  length = len(tokens)
  while pos < length:
    is_match = (
        pos + 1 < length
        and tokens[pos] == pair[0]
        and tokens[pos + 1] == pair[1]
    )
    if is_match:
      merged.append(index)
      pos += 2
    else:
      merged.append(tokens[pos])
      pos += 1
  return merged

In [42]:
# Merge every (b'e', b' ') pair -- bytes 101 and 32 -- into the first new token id.
example_pair, example_id = (101, 32), 256
print(mergePairs(tokens, example_pair, example_id))

[84, 104, 256, 67, 104, 105, 110, 101, 115, 256, 87, 105, 107, 105, 112, 101, 100, 105, 97, 32, 40, 116, 114, 97, 100, 105, 116, 105, 111, 110, 97, 108, 32, 67, 104, 105, 110, 101, 115, 101, 58, 32, 228, 184, 173, 230, 150, 135, 231, 182, 173, 229, 159, 186, 231, 153, 190, 231, 167, 145, 59, 32, 115, 105, 109, 112, 108, 105, 102, 105, 101, 100, 32, 67, 104, 105, 110, 101, 115, 101, 58, 32, 228, 184, 173, 230, 150, 135, 231, 187, 180, 229, 159, 186, 231, 153, 190, 231, 167, 145, 59, 32, 112, 105, 110, 121, 105, 110, 58, 32, 90, 104, 197, 141, 110, 103, 119, 195, 169, 110, 32, 87, 195, 169, 105, 106, 196, 171, 32, 66, 199, 142, 105, 107, 196, 147, 41, 32, 105, 115, 32, 116, 104, 256, 119, 114, 105, 116, 116, 101, 110, 32, 118, 101, 114, 110, 97, 99, 117, 108, 97, 114, 32, 67, 104, 105, 110, 101, 115, 256, 101, 100, 105, 116, 105, 111, 110, 32, 111, 102, 32, 87, 105, 107, 105, 112, 101, 100, 105, 97, 46, 32, 73, 116, 32, 119, 97, 115, 32, 99, 114, 101, 97, 116, 101, 100, 32, 111, 110, 32,

In [48]:
# Re-defined here (same behavior as the earlier cell's definition) so this training cell is self-contained.
def countPairs(tokens):
  """Count how often each adjacent pair of tokens occurs.

  Args:
    tokens: a sliceable sequence of token ids (e.g. a list of ints).

  Returns:
    dict mapping (left, right) tuples to their occurrence count, with keys in
    order of first occurrence. Overlapping pairs are all counted, so
    [1, 1, 1] gives {(1, 1): 2}. Fewer than two tokens gives {}.
  """
  pair_counts = {}
  for left, right in zip(tokens, tokens[1:]):
    key = (left, right)
    if key in pair_counts:
      pair_counts[key] += 1
    else:
      pair_counts[key] = 1
  return pair_counts

# Re-defined here (same behavior as the earlier cell's definition) so this training cell is self-contained.
def mergePairs(tokens, pair, index):
  """Replace every non-overlapping occurrence of `pair` with the single token `index`.

  Scans left to right. After a match both tokens are consumed, so overlapping
  occurrences resolve greedily from the left (e.g. [1, 1, 1] with pair (1, 1)
  becomes [index, 1]).

  Args:
    tokens: sequence of token ids; it is not modified.
    pair: (left, right) tuple of token ids to merge.
    index: the new token id that replaces each matched pair.

  Returns:
    A new list of token ids.
  """
  merged = []
  pos = 0
  length = len(tokens)
  while pos < length:
    is_match = (
        pos + 1 < length
        and tokens[pos] == pair[0]
        and tokens[pos + 1] == pair[1]
    )
    if is_match:
      merged.append(index)
      pos += 2
    else:
      merged.append(tokens[pos])
      pos += 1
  return merged

# Target vocabulary size: 256 raw byte values + (vocal_size - 256) learned merges.
# (The variable name is kept unchanged in case other cells use it.)
vocal_size = 280
num_merges = vocal_size - 256
tokens = list(text.encode('utf-8'))
tokens_copy = list(tokens)  # working copy that gets merged; `tokens` stays the raw bytes

merges = {}  # (int, int) -> int: learned pair -> new token id, in creation order
for i in range(num_merges):
  counts = countPairs(tokens_copy)
  if not counts:
    # Fewer than two tokens remain, so nothing is left to merge.
    # max() on an empty dict would raise ValueError, so stop explicitly instead.
    print(f"no pairs left to merge; stopping after {i} merges")
    break
  pair = max(counts, key=counts.get)  # most frequent pair; ties go to the pair seen first
  index = 256 + i  # ids 0-255 are raw bytes, so new tokens start at 256
  print(f"merging {pair} into a new token {index}")
  tokens_copy = mergePairs(tokens_copy, pair, index)
  merges[pair] = index

merging (101, 32) into a new token 256
merging (105, 110) into a new token 257
merging (101, 100) into a new token 258
merging (111, 110) into a new token 259
merging (101, 115) into a new token 260
merging (258, 105) into a new token 261
merging (115, 32) into a new token 262
merging (104, 256) into a new token 263
merging (67, 104) into a new token 264
merging (264, 257) into a new token 265
merging (116, 105) into a new token 266
merging (265, 260) into a new token 267
merging (105, 107) into a new token 268
merging (44, 32) into a new token 269
merging (87, 268) into a new token 270
merging (270, 105) into a new token 271
merging (261, 97) into a new token 272
merging (111, 102) into a new token 273
merging (273, 32) into a new token 274
merging (267, 256) into a new token 275
merging (271, 112) into a new token 276
merging (276, 272) into a new token 277
merging (116, 263) into a new token 278
merging (101, 114) into a new token 279


In [49]:
# Display the learned merge table: (token_a, token_b) -> new token id, in creation order.
merges

{(101, 32): 256,
 (105, 110): 257,
 (101, 100): 258,
 (111, 110): 259,
 (101, 115): 260,
 (258, 105): 261,
 (115, 32): 262,
 (104, 256): 263,
 (67, 104): 264,
 (264, 257): 265,
 (116, 105): 266,
 (265, 260): 267,
 (105, 107): 268,
 (44, 32): 269,
 (87, 268): 270,
 (270, 105): 271,
 (261, 97): 272,
 (111, 102): 273,
 (273, 32): 274,
 (267, 256): 275,
 (271, 112): 276,
 (276, 272): 277,
 (116, 263): 278,
 (101, 114): 279}

In [53]:
# Compare sequence length before (raw UTF-8 bytes) and after applying the learned merges.
raw_length = len(tokens)
merged_length = len(tokens_copy)
print(f"length of raw tokens: {raw_length}")
print(f"length of compressed tokens: {merged_length}")
print(f"compression ratio: {raw_length / merged_length:.2f}")

length of raw tokens: 861
length of compressed tokens: 606
compression ratio: 1.42


In [57]:
# Map each token id to the bytes it stands for.
# Ids 0..255 are the single raw bytes; each merged id is the concatenation of its two
# children's bytes. `merges` was filled in creation order, and a merge only ever combines
# raw bytes or previously created ids, so both children are already in `vocab` when
# a merged id is built.
vocab = dict(enumerate(bytes([b]) for b in range(256)))
for (p0, p1), idx in merges.items():
    vocab[idx] = b"".join((vocab[p0], vocab[p1]))

print(vocab)

{0: b'\x00', 1: b'\x01', 2: b'\x02', 3: b'\x03', 4: b'\x04', 5: b'\x05', 6: b'\x06', 7: b'\x07', 8: b'\x08', 9: b'\t', 10: b'\n', 11: b'\x0b', 12: b'\x0c', 13: b'\r', 14: b'\x0e', 15: b'\x0f', 16: b'\x10', 17: b'\x11', 18: b'\x12', 19: b'\x13', 20: b'\x14', 21: b'\x15', 22: b'\x16', 23: b'\x17', 24: b'\x18', 25: b'\x19', 26: b'\x1a', 27: b'\x1b', 28: b'\x1c', 29: b'\x1d', 30: b'\x1e', 31: b'\x1f', 32: b' ', 33: b'!', 34: b'"', 35: b'#', 36: b'$', 37: b'%', 38: b'&', 39: b"'", 40: b'(', 41: b')', 42: b'*', 43: b'+', 44: b',', 45: b'-', 46: b'.', 47: b'/', 48: b'0', 49: b'1', 50: b'2', 51: b'3', 52: b'4', 53: b'5', 54: b'6', 55: b'7', 56: b'8', 57: b'9', 58: b':', 59: b';', 60: b'<', 61: b'=', 62: b'>', 63: b'?', 64: b'@', 65: b'A', 66: b'B', 67: b'C', 68: b'D', 69: b'E', 70: b'F', 71: b'G', 72: b'H', 73: b'I', 74: b'J', 75: b'K', 76: b'L', 77: b'M', 78: b'N', 79: b'O', 80: b'P', 81: b'Q', 82: b'R', 83: b'S', 84: b'T', 85: b'U', 86: b'V', 87: b'W', 88: b'X', 89: b'Y', 90: b'Z', 91: b'[',