In [44]:
# Read the whole training corpus into memory as text. The encoding is given
# explicitly so the decoded text does not depend on the platform's default
# locale encoding.
with open("corpus.txt", 'r', encoding="utf-8") as f:
    corpus = f.read()
# print (rather than a bare expression) so newlines render as line breaks.
print(corpus)

low low low low low 
lower lower widest widest widest
newest newest newest newest newest newest 


In [45]:
# Target vocabulary size after training.
vocab_size = 262
# Number of merge steps to run. NOTE(review): 256 is presumably the size of the
# base byte alphabet, so each merge adds one token on top of it -- confirm.
num_of_merges = vocab_size - 256
# Learned tokens, in merge order. This must be a list: the training loop adds
# tokens with vocab.append(...), which a dict does not support.
vocab = []

# Pre-tokenization
def pretokenize(corpus):
    """Count how often each whitespace-delimited word occurs in ``corpus``.

    Args:
        corpus: Text to split. ``str.split()`` is called with no arguments, so
            any run of whitespace (including newlines) separates words.

    Returns:
        dict mapping each distinct word to its number of occurrences, with
        keys in order of first appearance.
    """
    frequencies = {}
    for token in corpus.split():
        if token in frequencies:
            frequencies[token] += 1
        else:
            frequencies[token] = 1
    return frequencies

# Rebinds `corpus` from the raw text to the word -> count dict. Because
# pretokenize calls corpus.split(), this cell cannot be run a second time
# without first reloading the text.
corpus = pretokenize(corpus)

In [47]:
# Show the distinct words found by pretokenize (the keys of the count dict).
corpus.keys()

dict_keys(['low', 'lower', 'widest', 'newest'])

In [38]:
def split_to_bytes(corpus):
    """Re-key a word-count mapping by each word's tuple of symbols.

    Note: despite the name, each word is split into its characters
    (``tuple(word)``); no byte encoding is performed.

    Args:
        corpus: Mapping of word -> count.

    Returns:
        dict mapping tuple-of-characters -> count, ordered by descending
        count. Words with equal counts keep their original relative order.
    """
    def by_count(item):
        return item[1]

    # Break every word into a tuple of single-character symbols.
    symbol_counts = {tuple(word): freq for word, freq in corpus.items()}

    # Most frequent words first; sorted() is stable, so ties stay in input order.
    ranked = sorted(symbol_counts.items(), key=by_count, reverse=True)
    return dict(ranked)

# Rebinds `corpus` again: keys change from words to tuples of characters,
# ordered by descending count. The bare expression displays the result.
corpus = split_to_bytes(corpus)
corpus

{('n', 'e', 'w', 'e', 's', 't'): 6,
 ('l', 'o', 'w'): 5,
 ('w', 'i', 'd', 'e', 's', 't'): 3,
 ('l', 'o', 'w', 'e', 'r'): 2}

{('n', 'e', 'w', 'e', 's', 't'): 6,
 ('l', 'o', 'w'): 5,
 ('w', 'i', 'd', 'e', 's', 't'): 3,
 ('l', 'o', 'w', 'e', 'r'): 2}

In [28]:
def count_bytepairs(corpus):
    """Tally adjacent symbol pairs across the corpus, weighted by word count.

    Args:
        corpus: Mapping of symbol tuple -> number of occurrences of that word.

    Returns:
        dict mapping each adjacent pair ``(left, right)`` to its total
        frequency. A pair contributes the word's count once for every
        position at which it occurs in that word.
    """
    pair_totals = {}
    for symbols, freq in corpus.items():
        # Walk every position that has a right-hand neighbour.
        for idx in range(len(symbols) - 1):
            pair = (symbols[idx], symbols[idx + 1])
            if pair in pair_totals:
                pair_totals[pair] += freq
            else:
                pair_totals[pair] = freq
    return pair_totals

In [29]:
def get_mf_pair(counts):
    """Return the most frequent pair, breaking ties by the greatest pair.

    Args:
        counts: Non-empty mapping of pair -> frequency.

    Returns:
        The key with the highest frequency. If several keys share that
        frequency, the one that compares greatest is returned (lexicographic
        order for tuples of strings).

    Raises:
        ValueError: If ``counts`` is empty.
    """
    top_frequency = max(counts.values())
    # All pairs that reach the top frequency compete; the greatest one wins.
    tied = [pair for pair in counts if counts[pair] == top_frequency]
    return max(tied)

In [30]:
# Display the current contents of `corpus`.
corpus

{('n', 'e', 'w', 'e', 's', 't'): 6,
 ('l', 'o', 'w'): 5,
 ('w', 'i', 'd', 'e', 's', 't'): 3,
 ('l', 'o', 'w', 'e', 'r'): 2}

In [31]:
def merge(corpus, merge_pair):
    """Apply one merge rule to every word in ``corpus``.

    Each word is scanned left to right. Wherever two adjacent symbols match
    ``merge_pair`` they are replaced by their concatenation and the scan
    resumes after them, so matches never overlap.

    Args:
        corpus: Mapping of symbol tuple -> count. Symbols are strings.
        merge_pair: Either the pair itself as a 2-item tuple/list
            ``(left, right)``, or the already-joined string ``left + right``.
            The tuple form matches those two symbols exactly. The string form
            compares only the concatenation, so it also merges any other
            adjacent split spelling the same string (``'abc'`` matches both
            ``('a', 'bc')`` and ``('ab', 'c')``); it is kept for backward
            compatibility.

    Returns:
        A new dict of merged symbol tuple -> count; ``corpus`` itself is not
        modified. If several words collapse to the same tuple, their counts
        are summed rather than overwritten.
    """
    if isinstance(merge_pair, str):
        # Legacy form: only the joined text is known.
        merged = merge_pair

        def matches(left, right):
            return left + right == merged
    else:
        # Exact form: both halves must match, not just their concatenation.
        target = tuple(merge_pair)
        merged = "".join(target)

        def matches(left, right):
            return (left, right) == target

    new_dict = {}
    for symbols, count in corpus.items():
        out = []
        i = 0
        last = len(symbols) - 1
        while i < len(symbols):
            # A pair can only start at a position that has a right neighbour.
            if i < last and matches(symbols[i], symbols[i + 1]):
                out.append(merged)
                i += 2  # skip both halves of the merged pair
            else:
                out.append(symbols[i])
                i += 1
        key = tuple(out)
        # Accumulate so that colliding keys do not lose counts.
        new_dict[key] = new_dict.get(key, 0) + count

    return new_dict

In [32]:
# Both result lists are reset here so that re-running this cell leaves them in
# step with each other. `vocab` has to be a list because tokens are appended.
merges = []
vocab = []

# Start merges
for _ in range(num_of_merges):
    # Count bytepairs
    counts = count_bytepairs(corpus)
    # Every word is already a single symbol: nothing is left to merge, and
    # get_mf_pair would raise on an empty mapping.
    if not counts:
        break
    # Get the most frequent pair
    pair = get_mf_pair(counts)
    # Add merge to merges
    merges.append(pair)
    # Add the merge to the vocab
    vocab.append("".join(pair))
    # Apply the merge to the corpus (rebinds `corpus` to the merged mapping)
    corpus = merge(corpus, "".join(pair))

In [33]:
# Show the tokens created by the merges, in the order they were learned.
vocab

['st', 'est', 'ow', 'low', 'west', 'ne']

In [34]:
# Show the learned merge rules as (left, right) pairs, in the order applied.
merges

[('s', 't'), ('e', 'st'), ('o', 'w'), ('l', 'ow'), ('w', 'est'), ('n', 'e')]