In [285]:
with open("data/TinyStoriesV2-GPT4-valid.txt", 'r') as f:
    corpus = f.read()[:3000]
print(corpus)

u don't have to be scared of the loud dog, I'll protect you". The mole felt so safe with the little girl. She was very kind and the mole soon came to trust her. He leaned against her and she kept him safe. The mole had found his best friend.
<|endoftext|>
Once upon a time, in a warm and sunny place, there was a big pit. A little boy named Tom liked to play near the pit. One day, Tom lost his red ball. He was very sad.
Tom asked his friend, Sam, to help him search for the ball. They looked high and low, but they could not find the ball. Tom said, "I think my ball fell into the pit."
Sam and Tom went close to the pit. They were scared, but they wanted to find the red ball. They looked into the pit, but it was too dark to see. Tom said, "We must go in and search for my ball."
They went into the pit to search. It was dark and scary. They could not find the ball. They tried to get out, but the pit was too deep. Tom and Sam were stuck in the pit. They called for help, but no one could hear t

In [286]:
import regex as re

def strip_of_special_tokens(corpus, special_tokens):
    """Strips of special tokens to avoid counting them as bytes"""

    # Escape | delimiter in special tokens
    for i in range(len(special_tokens)):
        if "|" in special_tokens[i]:
            special_tokens[i] = re.escape(special_tokens[i])

    # Join special tokens into a delim for a splitting pattern
    delim = "|".join(special_tokens)
    chunks = re.split(delim, corpus)
    # Remove empty chunks
    chunks = [ch for ch in chunks if ch.strip()]
    return chunks

special_tokens = ["<|endoftext|>", "<start>", "<end>"]

In [287]:
vocab_size = 262
num_of_merges = vocab_size - 256
vocab = {}

PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

# Pre-tokenization
def pretokenize(corpus, ptrn):
    """Pre-tokenizes on regex pattern"""

    counts = {}
    for t in corpus:
        for word in re.findall(ptrn, t):
            counts[word] = counts.get(word, 0) + 1
    return counts

In [288]:
def split_to_bytes(corpus):
    """Splits words by characters and counts frequency"""

    counts = {}
    # Count byte pairs
    for k, v in corpus.items():
        new_key = tuple([c for c in k])
        counts[new_key] = v

    # Sort by the highest frequency
    counts = dict(sorted(counts.items(), key=lambda x: x[1], reverse=True))
    return counts

In [None]:
from collections import defaultdict

def count_bytepairs(corpus, bp_to_counts=None, bp_to_words=None, mf_pair=None, merged_words=None):
    """Counts bytepair frequencies in the corpus
    If ran the first time (no counts provided) then count all byte pairs in the whole corpus
    If ran consecutively, remove the most frequent pair count as it's merged now
    and count byte pairs only on the merged words now"""

    # If counts are provided, remove the most frequent byte pair from the counts and then re-count only the merged words
    if bp_to_counts and bp_to_words:
        # Collect byte pairs to remove which overlap with pre-merge bytes
        pairs_to_remove = [mf_pair]
        pairs_to_remove.extend([p for p in bp_to_counts.keys() if p[1] == mf_pair[0] or p[0] == mf_pair[1]])

        # Remove the old byte pairs that were created on pre-merge
        for p in pairs_to_remove:
            bp_to_counts.pop(p)
            bp_to_words.pop(p)

        # Count only merged words as only these are updated
        for w in merged_words:
           for c1, c2 in zip(w, w[1:]):
                bp_to_counts[(c1, c2)] = bp_to_counts.get((c1, c2), 0) + corpus[w]
                bp_to_words[(c1, c2)].add(w) 
    
    # For the first time we need to count every single pair
    else:
        bp_to_counts = {}
        bp_to_words = defaultdict(set)

        # Count every single pair in the corpus
        for k, v in corpus.items():
            for c1, c2 in zip(k, k[1:]):
                bp_to_counts[(c1, c2)] = bp_to_counts.get((c1, c2), 0) + v
                bp_to_words[(c1, c2)].add(k)

    return bp_to_counts, bp_to_words

In [290]:
def get_mf_pair(counts, counts_to_words):
    """Takes the most frequent byte pair along with words with that pair and returns them"""

    # Get the max frequency
    maxf = counts[max(counts, key=counts.get)]

    # Get the candidates with the max frequency
    candidates = [k for k, v in counts.items() if v == maxf]
    # Pick the lexicographically greater pair
    pair = max(candidates)
    return pair, counts_to_words[pair]

In [291]:
def merge(corpus, merge_pair, merge_pair_words):
    """Merges the word in the corpus 
    by joining two bytes into one and
    re-assigning the key in the dictionary"""

    # Keeps track of merged words to update byte counting only on these as these only change
    merged_words = set()
    
    # Merging the keys 
    for w in merge_pair_words:
        new_k = []
        b = 0
        while b < len(w):
            if b + 1 < len(w):
                c1, c2 = w[b], w[b+1]
                if c1 + c2 == merge_pair:
                    new_k.append(merge_pair)
                    b += 2
                else:
                    new_k.append(c1)
                    b += 1
            else:
                new_k.append(w[b])
                b += 1
        # Add new merged word
        corpus[tuple(new_k)] = corpus[w]

        # Add to merged words to optimize byte pair counts as this only changes and the rest is still the same
        merged_words.add(tuple(new_k))
    
    # Pop the unmerged words
    for w in merge_pair_words:
        corpus.pop(w)
    
    return corpus, merged_words

In [292]:
def pair_to_bytes(pair):
    return tuple(b.encode("utf-8") for b in pair)

## Main merging loop

In [293]:
merges = []
mf_pair = None
merged_words = None

corpus = strip_of_special_tokens(corpus, special_tokens)
corpus = pretokenize(corpus, PAT)
corpus = split_to_bytes(corpus)

# Start merges
for i in range(num_of_merges):
    if i == 0:
        # Count bytepairs
        counts, counts_to_words = count_bytepairs(corpus)
    else:
       counts, counts_to_words = count_bytepairs(corpus, counts, counts_to_words, mf_pair, merged_words) 
    print(f"# of counts: {len(counts)}")
    # Get the most frequent pair
    mf_pair, mf_pair_words = get_mf_pair(counts, counts_to_words)
    # Add merge to merges
    pair_b = pair_to_bytes(mf_pair)
    merges.append(pair_b)
    # Add the merge to the vocab
    merge_b = "".join(mf_pair).encode("utf-8")
    vocab[256 + i] = merge_b
    # Apply the merge to the corpus
    corpus, merged_words = merge(corpus, "".join(mf_pair), mf_pair_words)

# of counts: 275
# of counts: 290


KeyError: (' ', 't', 'h', 'e')

In [None]:
vocab

{256: b'he', 257: b' t', 258: b' a', 259: b' s', 260: b' the', 261: b' w'}

In [None]:
merges

[(b'h', b'e'),
 (b' ', b't'),
 (b' ', b'a'),
 (b' ', b's'),
 (b' t', b'he'),
 (b' ', b'w')]