In [4]:
import re
from collections import defaultdict

def get_vocab(data):
    """ Given a list of strings, returns a dictionary of words mapping to their frequency count in the data. """
    vocab = defaultdict(int)
    for line in data:
        for word in line.split():
            # We add ' </w>' to indicate the end of a word
            vocab[' '.join(list(word)) + ' </w>'] += 1
    return vocab

def get_stats(vocab):
    """ Given a vocabulary (dictionary mapping words to frequency counts), returns a dictionary of tuples representing the frequency count of pairs of characters in the vocabulary. """
    pairs = defaultdict(int)
    for word, freq in vocab.items():
        symbols = word.split()
        for i in range(len(symbols) - 1):
            pairs[symbols[i], symbols[i + 1]] += freq
    return pairs

def merge_vocab(pair, v_in):
    """ Merges the most frequent pair in the vocabulary. """
    v_out = {}
    bigram = re.escape(' '.join(pair))
    p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
    for word in v_in:
        w_out = p.sub(''.join(pair), word)
        v_out[w_out] = v_in[word]
    return v_out

def byte_pair_encoding(data, num_merges):
    """ Performs byte pair encoding on the input data. """
    vocab = get_vocab(data)
    for i in range(num_merges):
        pairs = get_stats(vocab)
        if not pairs:
            break
        best_pair = max(pairs, key=pairs.get)
        vocab = merge_vocab(best_pair, vocab)
    return vocab

# Example usage:
data = ["this is a test", "this is another test", "why is this a test"]
num_merges = 10
bpe_vocab = byte_pair_encoding(data, num_merges)

vocab = get_vocab(data)
pairs = get_stats(vocab)
print(f'Vocab:\n{vocab}')
print(f'Pairs:\n{pairs}')
print(f'BPE Vocab:\n{bpe_vocab}')

Vocab:
defaultdict(<class 'int'>, {'t h i s </w>': 3, 'i s </w>': 3, 'a </w>': 2, 't e s t </w>': 3, 'a n o t h e r </w>': 1, 'w h y </w>': 1})
Pairs:
defaultdict(<class 'int'>, {('t', 'h'): 4, ('h', 'i'): 3, ('i', 's'): 6, ('s', '</w>'): 6, ('a', '</w>'): 2, ('t', 'e'): 3, ('e', 's'): 3, ('s', 't'): 3, ('t', '</w>'): 3, ('a', 'n'): 1, ('n', 'o'): 1, ('o', 't'): 1, ('h', 'e'): 1, ('e', 'r'): 1, ('r', '</w>'): 1, ('w', 'h'): 1, ('h', 'y'): 1, ('y', '</w>'): 1})
BPE Vocab:
{'this</w>': 3, 'is</w>': 3, 'a</w>': 2, 'test</w>': 3, 'an o th e r </w>': 1, 'w h y </w>': 1}


In [5]:
import matplotlib.pyplot as plt

# Assuming you have a function `byte_pair_encoding` that returns the vocabulary
# and a function `validate_model` that returns a performance metric

num_merges_options = [1000, 5000, 10000, 15000, 20000]  # Example merge options
vocab_sizes = []
validation_scores = []

for num_merges in num_merges_options:
    vocab = byte_pair_encoding(data, num_merges)
    vocab_sizes.append(len(vocab))
    
    # Train your model here with the vocab and evaluate on the validation set
    score = validate_model(model, validation_data)
    validation_scores.append(score)

# Now plot the trends
plt.figure(figsize=(14, 7))

plt.subplot(1, 2, 1)
plt.plot(num_merges_options, vocab_sizes, marker='o')
plt.title('Vocabulary Size vs. Number of Merges')
plt.xlabel('Number of Merges')
plt.ylabel('Vocabulary Size')

plt.subplot(1, 2, 2)
plt.plot(num_merges_options, validation_scores, marker='o', color='r')
plt.title('Validation Score vs. Number of Merges')
plt.xlabel('Number of Merges')
plt.ylabel('Validation Score')

plt.tight_layout()
plt.show()

NameError: name 'validate_model' is not defined

In [2]:
text_bytes = 'this is a test'.encode('utf-8')

In [3]:
print(text_bytes)

b'this is a test'


In [4]:
list(text_bytes)

[116, 104, 105, 115, 32, 105, 115, 32, 97, 32, 116, 101, 115, 116]