In [1]:
import json
import os
import sys
from typing import BinaryIO, Dict, List, Set, Tuple

import pandas as pd
from transformers import AutoTokenizer
sys.path.append('assignment1-basics')
from cs336_basics.tokenizer.train import find_chunk_boundaries, PAT

In [2]:
# Pretrained GPT-2 tokenizer, kept as a reference: the pre-tokenization section
# feeds it short strings to see how it handles leading whitespace.
gpt2_tokenizer = AutoTokenizer.from_pretrained('gpt2')

# 1. Initialize Vocab

In [3]:
# Base vocabulary: one entry per possible byte value, id == byte value.
# bytes([i]) is used instead of chr(i).encode('utf-8') because the latter
# produces a two-byte UTF-8 sequence for every i >= 128 (e.g. b'\xc2\x80'),
# which would not match the single-byte node values built with bytes([byte])
# during pre-tokenization.
vocab: Dict[int, bytes] = {i: bytes([i]) for i in range(256)}
vocab_size = 256

# For representing whitespace
print("Ġ".encode('utf-8'))
vocab[vocab_size]="Ġ".encode('utf-8')
vocab_size+=1

# Add Special Tokens
special_tokens = ['<|endoftext|>']

encoded_special_tokens = [
    x.encode('utf-8') for x in special_tokens
]

# Special tokens get the next free ids, in list order.
for tok in encoded_special_tokens:
    vocab[vocab_size]=tok
    vocab_size+=1

# Byte form of the document separator; used to find and trim chunk boundaries.
split_special_token = "<|endoftext|>".encode('utf-8')
print(len(split_special_token), split_special_token)

b'\xc4\xa0'
13 b'<|endoftext|>'


In [4]:
# Reverse lookup table: token bytes -> token id.
vocab_inv = dict(zip(vocab.values(), vocab.keys()))

In [5]:
vocab_inv

{b'\x00': 0,
 b'\x01': 1,
 b'\x02': 2,
 b'\x03': 3,
 b'\x04': 4,
 b'\x05': 5,
 b'\x06': 6,
 b'\x07': 7,
 b'\x08': 8,
 b'\t': 9,
 b'\n': 10,
 b'\x0b': 11,
 b'\x0c': 12,
 b'\r': 13,
 b'\x0e': 14,
 b'\x0f': 15,
 b'\x10': 16,
 b'\x11': 17,
 b'\x12': 18,
 b'\x13': 19,
 b'\x14': 20,
 b'\x15': 21,
 b'\x16': 22,
 b'\x17': 23,
 b'\x18': 24,
 b'\x19': 25,
 b'\x1a': 26,
 b'\x1b': 27,
 b'\x1c': 28,
 b'\x1d': 29,
 b'\x1e': 30,
 b'\x1f': 31,
 b' ': 32,
 b'!': 33,
 b'"': 34,
 b'#': 35,
 b'$': 36,
 b'%': 37,
 b'&': 38,
 b"'": 39,
 b'(': 40,
 b')': 41,
 b'*': 42,
 b'+': 43,
 b',': 44,
 b'-': 45,
 b'.': 46,
 b'/': 47,
 b'0': 48,
 b'1': 49,
 b'2': 50,
 b'3': 51,
 b'4': 52,
 b'5': 53,
 b'6': 54,
 b'7': 55,
 b'8': 56,
 b'9': 57,
 b':': 58,
 b';': 59,
 b'<': 60,
 b'=': 61,
 b'>': 62,
 b'?': 63,
 b'@': 64,
 b'A': 65,
 b'B': 66,
 b'C': 67,
 b'D': 68,
 b'E': 69,
 b'F': 70,
 b'G': 71,
 b'H': 72,
 b'I': 73,
 b'J': 74,
 b'K': 75,
 b'L': 76,
 b'M': 77,
 b'N': 78,
 b'O': 79,
 b'P': 80,
 b'Q': 81,
 b'R': 82,
 b'S': 

# 2. Load Data

In [6]:
# Locate chunk boundaries in the input file.
num_processes = 8
input_path = 'data/owt_valid.txt'

# The file is opened in binary mode; the resulting boundaries are later used
# as offsets for f.seek().
with open(input_path, 'rb') as f:
    boundaries = find_chunk_boundaries(f, num_processes, split_special_token)

In [7]:
boundaries

[0,
 36335216,
 72505172,
 108752143,
 145027268,
 181256470,
 217499287,
 253752435,
 289998753]

In [8]:
# Read each [start, end) byte range, trim the separator token from its edges
# and preview the decoded text.
#
# The token is stripped only when it is actually present at the edge of the
# bytes read, instead of assuming that every chunk after the first starts with
# it and that the file ends with it; otherwise 13 bytes of real text would be
# silently dropped whenever that assumption does not hold.
#
# After the loop, `chunk` holds the decoded text of the final chunk; the
# pre-tokenization cell below consumes it.
token_len = len(split_special_token)

with open(input_path, 'rb') as f:
    for b_i, (start, end) in enumerate(zip(boundaries[:-1], boundaries[1:]), start=1):
        f.seek(start)
        raw = f.read(end - start)

        # Leading separator (present when the boundary sits on a token).
        if raw.startswith(split_special_token):
            raw = raw[token_len:]
            start += token_len

        # Trailing separator (e.g. a file that ends with the token).
        if raw.endswith(split_special_token):
            raw = raw[:-token_len]
            end -= token_len

        # errors="ignore" drops any bytes that do not form valid UTF-8.
        chunk = raw.decode("utf-8", errors="ignore")
        print(b_i, start, end)
        print('begin',repr(chunk[:10]))
        print('end', repr(chunk[-15:]))

1 0 36335216
begin 'LOUISVILLE'
end 'evel above 4th.'
2 36335229 72505172
begin 'Story high'
end 'tion,” he said.'
3 72505185 108752143
begin '1705-hill-'
end 'Newswire posts:'
4 108752156 145027268
begin 'Get the bi'
end 'graphy [ edit ]'
5 145027281 181256470
begin 'Soos Goes '
end 'k my swan song.'
6 181256483 217499287
begin 'Monday, Au'
end 'elsh and Irish.'
7 217499300 253752435
begin 'There’s an'
end '9s&w=600&h=315]'
8 253752448 289998740
begin 'Address 68'
end 'ce on March 1."'


# 3. Handle Pretokenization

In [9]:
# Whitespace sanity checks.

# str.isspace() for space, newline and tab.
for ws in (" ", "\n", "\t"):
    print(ws.isspace())

# A whitespace character followed by 'h': total encoded length, then the value
# of the leading byte for each kind of whitespace.
print(len(" h".encode('utf-8')))
for label, sample in (("space", " h"), ("tab", "\th"), ("newline", "\nh")):
    print(label)
    print(sample.encode('utf-8')[0])

# Round trip: byte value -> character (space, tab, newline).
for code in (32, 9, 10):
    print(repr(bytes([code]).decode('utf-8')))

# How the reference GPT-2 tokenizer encodes 'H' with various prefixes.
for probe in ('H', ' H', '\nH', 'ĊH', '\tH'):
    print(gpt2_tokenizer([probe]))

True
True
True
2
space
32
tab
9
newline
10
' '
'\t'
'\n'
{'input_ids': [[39]], 'attention_mask': [[1]]}
{'input_ids': [[367]], 'attention_mask': [[1]]}
{'input_ids': [[198, 39]], 'attention_mask': [[1, 1]]}
{'input_ids': [[128, 232, 39]], 'attention_mask': [[1, 1, 1]]}
{'input_ids': [[197, 39]], 'attention_mask': [[1, 1]]}


In [10]:
# Marker characters that stand in for a leading space / newline, together with
# their UTF-8 encodings.
whitespace_token, newline_token = "Ġ", "Ċ"

whitespace_token_bytes = whitespace_token.encode('utf-8')
newline_token_bytes = newline_token.encode('utf-8')

print(whitespace_token_bytes)
print(newline_token_bytes)

b'\xc4\xa0'
b'\xc4\x8a'


In [11]:
# Scratch check: indexing a bytes object yields an int, not a length-1 bytes
# object, so a single byte must be re-wrapped with bytes([x]) before it can be
# concatenated with another bytes value.
x = 'hi'.encode('utf-8')[0]
print(type(newline_token_bytes))
print(type(x))
# Working way to combine the marker bytes with a single byte value:
# b''.join([newline_token_bytes, bytes([x])])
# Does not work: x is an int and cannot be indexed, and bytes([...]) only
# accepts integers, not a bytes object:
# bytes([newline_token_bytes, x[0]])

<class 'bytes'>
<class 'int'>


In [12]:
# Chunk Pretokenization
import regex as re

class TokenNode:
    """One element of a doubly linked token list.

    Attributes:
        val: The payload handed to the constructor (the notebook stores the
            token's bytes here).
        prev: Preceding node, or None until linked.
        next: Following node, or None until linked.
        is_next_connected: True while this node and ``next`` may be treated as
            an adjacent pair; set to False on the last node of a pre-token to
            mark the pre-tokenization boundary.
    """

    def __init__(self, val):
        self.val = val
        # A fresh node starts out unlinked on both sides.
        self.prev = self.next = None
        # Assume a link to the following node until a boundary is marked.
        self.is_next_connected = True

def add_node(byte_val, prev):
    """Create a TokenNode for ``byte_val`` and append it after ``prev``.

    Args:
        byte_val: Value stored on the new node.
        prev: Node to link behind, or a falsy value (e.g. None) to leave the
            new node unlinked.

    Returns:
        The newly created node.
    """
    fresh = TokenNode(byte_val)
    if not prev:
        # No predecessor: the node stands alone.
        return fresh

    # Wire up both directions of the link.
    fresh.prev = prev
    prev.next = fresh
    return fresh

# Build a single doubly linked list of byte tokens covering the whole chunk.
head = None   # first node of the list
prev = None   # most recently appended node (current tail)

# Outer: Pre-tokenized Tokens
for pre_tok in re.finditer(PAT, chunk):
    text = pre_tok.group()

    if text[0] in (' ', '\n'):
        # Determine the prefix byte token (space or newline)
        prefix_bytes = whitespace_token_bytes if text[0] == ' ' else newline_token_bytes
        rest = text[1:].encode('utf-8')

        # Fuse the prefix with the first byte of the remainder; rest[:1] is
        # empty when nothing follows, which leaves the prefix on its own.
        prev = add_node(prefix_bytes + rest[:1], prev)
        # The prefix node can be the very first node of the list. Previously
        # head was only assigned in the per-byte loop below, so leading
        # pre-tokens that began with a space/newline were left out of the list.
        if head is None:
            head = prev
        bytes_to_process = rest[1:]
    else:
        bytes_to_process = text.encode('utf-8')

    # Add remaining bytes as separate nodes
    for byte in bytes_to_process:
        prev = add_node(bytes([byte]), prev)
        if head is None:
            head = prev

    # The last node of a pre-token must not pair with the next pre-token.
    if prev:
        prev.is_next_connected = False

In [13]:
from collections import defaultdict

# Index every adjacent pair inside a pre-token: (left value, right value) maps
# to the set of left-hand nodes where that pair occurs.
pair_positions = defaultdict(set)

node = head
while node and node.next:
    # Nodes flagged as a pre-token boundary do not pair with their successor.
    if node.is_next_connected:
        pair_positions[(node.val, node.next.val)].add(node)
    node = node.next

In [14]:
# Occurrence count per pair = number of recorded positions for that pair.
pair_counts = dict(zip(pair_positions.keys(), map(len, pair_positions.values())))

# Most frequent pair; on a tie the pair encountered first wins.
max_count_pair = max(pair_counts, key=lambda pair: pair_counts[pair])
print(max_count_pair, pair_counts[max_count_pair], type(max_count_pair[0]))

(b'h', b'e') 516999 <class 'bytes'>


In [15]:
# A merged pair is an arbitrary byte sequence and need not be valid UTF-8 on
# its own (e.g. part of a multi-byte character), so decode leniently instead of
# letting a strict decode raise UnicodeDecodeError.
print(repr(b''.join(max_count_pair).decode('utf-8', errors='replace')))

'he'


# 4. Merge

In [16]:
# Merges performed so far, in order, each recorded as the (left, right) pair.
merges: List[Tuple[bytes, bytes]] = []

num_merges = 5

for merge_i in range(num_merges):
    if not pair_counts:
        # No pairs left to merge; max() on an empty dict would raise.
        break

    # NOTE(review): on equal counts max() returns the pair that was inserted
    # into pair_counts first -- confirm this is the intended tie-break rule.
    max_count_pair = max(pair_counts, key=pair_counts.get)
    # Add to merges
    merges.append(max_count_pair)

    # Add new vocab
    merged_val = b''.join(max_count_pair)
    vocab[vocab_size]=merged_val
    vocab_size+=1

    print("MERGE {} {}".format(merge_i, merged_val))

    # Iterate over a snapshot, because the live set shrinks while merging.
    live_positions = pair_positions[max_count_pair]
    for node_a in list(live_positions):
        # An occurrence that overlapped an earlier merge in this pass (e.g. the
        # middle of "a a a" for the pair (a, a)) has already been discarded
        # from the live set; merging it again would corrupt the list.
        if node_a not in live_positions:
            continue

        node_b = node_a.next
        left = node_a.prev
        right = node_b.next

        # 1. Merge Node: one new node replaces node_a and node_b.
        new_node = TokenNode(merged_val)
        new_node.prev = left
        new_node.next = right
        new_node.is_next_connected = node_b.is_next_connected

        # Keep the list entry point valid when the first node gets merged.
        if node_a is head:
            head = new_node

        # 2. Update Left
        if left is not None:
            if left.is_next_connected:
                # (left, a) no longer exists ...
                prev_pair = (left.val, node_a.val)
                pair_counts[prev_pair]-=1
                pair_positions[prev_pair].discard(left)

                # ... and is replaced by (left, merged).
                new_pair = (left.val, merged_val)
                pair_counts[new_pair] = pair_counts.get(new_pair, 0) + 1
                pair_positions[new_pair].add(left)
            left.next = new_node

        # 3. Update Right
        if right is not None:
            if node_b.is_next_connected:
                # (b, right) no longer exists ...
                prev_pair = (node_b.val, right.val)
                pair_counts[prev_pair]-=1
                pair_positions[prev_pair].discard(node_b)

                # ... and is replaced by (merged, right).
                new_pair = (merged_val, right.val)
                pair_counts[new_pair] = pair_counts.get(new_pair, 0) + 1
                pair_positions[new_pair].add(new_node)
            # Relink the back pointer even across a pre-token boundary;
            # otherwise `right` keeps pointing at the removed node_b and a
            # later merge starting at `right` would patch the wrong node.
            right.prev = new_node

    # Delete pair count, positions
    del pair_counts[max_count_pair]
    del pair_positions[max_count_pair]

MERGE 0 b'he'
MERGE 1 b'er'
MERGE 2 b'\xc4\xa0the'
MERGE 3 b'in'
MERGE 4 b'on'


In [17]:
b'\xc4\xa0'.decode('utf-8')

'Ġ'