Byte Pair Encoding (BPE) Implementation and Evaluation on NLTK Dataset

1. Implement BPE Algorithm

In [38]:
# importing libraries
from collections import defaultdict

def count_frequency(corpus):
  """
    Calculate the frequency of adjacent subword pairs in a text corpus.

    Args:
        corpus (dict): A dictionary where keys are strings of space-separated subwords 
                      (representing words or words) and values are their corresponding 
                      frequencies.

    Returns:
        defaultdict: A dictionary where keys are tuples of adjacent subwords and values are 
                     the summed frequencies of those subword pairs across the corpus.
                     
  """

  # creating a dictionary to store the frequency of subword pairs present in the corpus
  subword_pairs = defaultdict(int) 

  #iterating through the corpus to find subword and its corresponding frequency
  for words,freq in corpus.items():
    # extracting the subwords from the words by splitting with help of spaces
    subwords=words.split()
    for i in range(len(subwords)-1):
      # creating a tuple with adjacent subwords
      pair=(subwords[i],subwords[i+1])
      # counting the frequency of each pair
      subword_pairs[pair]+=freq
  return subword_pairs



def merge_subword(pair,corpus):
  """
    Merge the most frequent pair of subwords in the corpus.
    
    Args:
        pair (tuple): A tuple representing the pair of subwords to be merged.
        word_dictionary (dict): The input text corpus where keys are strings of space-separated subwords
                     and values are their frequencies.

    Returns:
        dict: The updated text corpus word dictionary after merging the most frequent pair.
    """
  updated_corpus={}
  space_seperated_pair = ' '.join(pair)
  merged_pair = ''.join(pair)

  for word in corpus:
    # replacing the separated pair with the merged pair in all words where this pair exists
    new_word = word.replace(space_seperated_pair,merged_pair)
    updated_corpus[new_word] = corpus[word]
  return updated_corpus



def get_processed_corpus(texts):
    """
    Given a list of strings, returns a dictionary of words mapping to their frequency 
    count in the data and a set of unique letters.

    Args:
        texts (list of str): A list of input texts, where each text is a string.

    Returns:
        tuple: A tuple containing:
               - dict: A dictionary where keys are words with space-separated characters and 
                       an end-of-word marker, and values are their frequencies.
               - set: A set of unique characters found in the corpus.
    """

    corpus = defaultdict(int)
    unique_letters = set()
    unique_letters.add('_')
    for text in texts:
      # text = text.lower()   # converting text to lowercase
      for word in text.split():
        for letter in word:
          unique_letters.add(letter)
        corpus[' '.join(list(word)) + ' _'] += 1
    return corpus, unique_letters



def byte_pair_encoding_algorithm(texts, num_merges):
    """
    Apply Byte Pair Encoding (BPE) to a list of texts to generate a subword vocabulary.
    Args:
        texts (list of str): A list of input texts, where each text is a string.
        num_merges (int): The number of merge operations to perform.
    Returns:
        tuple: A tuple containing:
               - dict: A dictionary where keys are words with merged subwords and values are their frequencies.
               - list: A list of tuples where each tuple contains a merged pair and its frequency.
    """
    vocabulary = {}
    index = 0
    corpus, unique_letters = get_processed_corpus(texts)

    # initializing the vocabulary with unique characters in the training corpus
    for letter in unique_letters:
        vocabulary[letter] = index
        index += 1

    merge_frequencies = []

    for _ in range(num_merges):
        subword_pairs = count_frequency(corpus)
        if not subword_pairs:
            break
        best_pair = max(subword_pairs, key=subword_pairs.get) # finding best pair with highest frequency
        merge_frequencies.append((best_pair, subword_pairs[best_pair])) 
        corpus = merge_subword(best_pair, corpus) # merging the best pair in corpus
        vocabulary[''.join(best_pair)] = index  # adding the merged pair to vocabulary
        index += 1
    return vocabulary, merge_frequencies


def encode(text, vocabulary):
    """
    Encode a given text using the BPE vocabulary.
    Args:
        text (str): The input text to be encoded.
        vocabulary (dict): The BPE vocabulary mapping subwords to their indices.
    Returns:
        list: A list of subword tokens representing the encoded text.
    """
    
    tokens = []
    sorted_vocabulary = sorted(vocabulary, key=vocabulary.get)
    for word in text.split():
        subwords = list(word) + ['_']  # adding end-of-word marker
        while True:
            # finding the best pair to merge
            pairs = [(subwords[i], subwords[i + 1]) for i in range(len(subwords) - 1)]
            best_pair = None
            for pair in pairs:
                if ''.join(pair) in sorted_vocabulary:
                    best_pair = pair
                    break
            if best_pair is None:
                break  # No more pairs to merge

            # merging the best pair
            merged_pair = ''.join(best_pair)
            new_subwords = []
            i = 0
            while i < len(subwords):
                if i < len(subwords) - 1 and (subwords[i], subwords[i + 1]) == best_pair:
                    new_subwords.append(merged_pair)
                    i += 2  # skipping the next subword since it's merged
                else:
                    new_subwords.append(subwords[i])
                    i += 1
            subwords = new_subwords
        tokens.extend(subwords)
    return tokens



def decode(tokens):
    """
    Decode a list of subword tokens back into the original text.

    Args:
        tokens (list): A list of subword tokens.

    Returns:
        str: The decoded original text.
    """
    text = ''.join(tokens)
    return text.replace('_', ' ')  # Remove end-of-word markers


def clean_bpe_tokens(bpe_tokens):
  # Remove the end-of-word markers and filter out tokens that are only underscores
  return [token.replace('_', '') for token in bpe_tokens if token != '_']


2. Train on NLTK Dataset

In [39]:
import nltk

# Download the NLTK Gutenberg Corpus
nltk.download('gutenberg')

# Import the Gutenberg Corpus module
from nltk.corpus import gutenberg

# Get the list of available books in the Gutenberg Corpus
book_list = gutenberg.fileids()

# # Print the list of available books
# print("Available Books:")
# for book in book_list:
#     print(book)

# Load books for training
train_book_1 = gutenberg.raw('austen-emma.txt')
train_book_2 = gutenberg.raw('blake-poems.txt')
train_book_3 = gutenberg.raw('shakespeare-hamlet.txt')

# # Display the first 500 characters of the selected book
# print("\nSample Text from 'shakespeare-hamlet.txt':")
# print(train_book_3[:500])

# text corpora for training
texts = [train_book_1,train_book_2,train_book_3]

# defining maximum number of merges
num_merges = 10000

# generating the vocabulary using BPE on training corpora
BPE_vocabulary, merge_frequencies = byte_pair_encoding_algorithm(texts,num_merges)

print(BPE_vocabulary)
print("Vocabulary length:", len(BPE_vocabulary))


# testing the BPE tokensization on a sample text
sample_text = "I have enjoyed this assignment."

sample_encoding = encode(sample_text, BPE_vocabulary)
print("Sample Tokensized text:", sample_encoding)

print("Length of tokenized text:", len(sample_encoding))

sample_decoding = decode(sample_encoding)
print("Decoded sample text:", sample_decoding)

cleaned_sample_encoding = clean_bpe_tokens(sample_encoding)
print("Cleaned encoding:", cleaned_sample_encoding)


[nltk_data] Downloading package gutenberg to /home/rahul/nltk_data...
[nltk_data]   Package gutenberg is already up-to-date!


{'o': 0, 'g': 1, ';': 2, '9': 3, '1': 4, 'L': 5, 'S': 6, 'R': 7, '_': 8, 'G': 9, 'z': 10, 'K': 11, ')': 12, 'l': 13, '5': 14, 'I': 15, 't': 16, 'y': 17, 'J': 18, 'x': 19, '!': 20, 'E': 21, 'P': 22, 'C': 23, 'k': 24, '-': 25, 'X': 26, 'A': 27, '?': 28, 'u': 29, 'c': 30, ']': 31, 'F': 32, 'T': 33, 'a': 34, '.': 35, 'r': 36, 'i': 37, '&': 38, 'B': 39, '3': 40, '[': 41, '(': 42, 'Z': 43, 'd': 44, '6': 45, 'V': 46, 'v': 47, 'h': 48, 'M': 49, 'H': 50, 'U': 51, '"': 52, '2': 53, 'm': 54, 'n': 55, ',': 56, '8': 57, '0': 58, '`': 59, 's': 60, 'Y': 61, "'": 62, 'W': 63, 'w': 64, 'f': 65, 'O': 66, 'Q': 67, 'p': 68, 'e': 69, 'b': 70, 'j': 71, 'N': 72, 'D': 73, '4': 74, '7': 75, ':': 76, 'q': 77, 'e_': 78, 't_': 79, 'th': 80, 'd_': 81, 'er': 82, 's_': 83, ',_': 84, 'in': 85, 'an': 86, 'y_': 87, 'ou': 88, 'o_': 89, '._': 90, 'on': 91, 'en': 92, 'or': 93, 'ar': 94, 'f_': 95, 'the_': 96, 'ing': 97, 'to_': 98, 'ha': 99, 'er_': 100, 'll': 101, 'and_': 102, 're': 103, 'of_': 104, 'ing_': 105, 'he_': 106,

3. Test on NLTK Dataset

In [40]:
# Load books for testing
test_book_1 = gutenberg.raw('carroll-alice.txt')
test_book_2 = gutenberg.raw('burgess-busterbrown.txt')
test_book_3 = gutenberg.raw('shakespeare-caesar.txt')

test_texts = [test_book_1,test_book_2,test_book_3]
# test_texts = [test_book_2,test_book_3]

bpe_tokenized_texts = [encode(text, BPE_vocabulary) for text in test_texts]

# print(BPE_tokenization)
print(len(bpe_tokenized_texts[0]),len(bpe_tokenized_texts[1]),len(bpe_tokenized_texts[2]))

57056 34286 45530


4. Tokenization comparison (BPE vs Word Tokenize)

In [41]:
import itertools
from collections import Counter
from nltk.tokenize import word_tokenize

# flattening the list of tokens generated by BPE
bpe_tokens = list(itertools.chain(*bpe_tokenized_texts))
# cleaning the tokens such as removing end of word
bpe_tokens = clean_bpe_tokens(bpe_tokens)
print("bpe tokens:", bpe_tokens)

# Tokenize the test dataset using NLTK's default tokenizer
nltk_tokenized_texts = [word_tokenize(text) for text in test_texts]
nltk_tokens = list(itertools.chain(*nltk_tokenized_texts))
print("nltk tokens:",nltk_tokens)

# Calculate token counts
nltk_counts = Counter(nltk_tokens)
bpe_counts = Counter(bpe_tokens)

# Calculate TP, FP and FN
true_positives = sum([min(nltk_counts[token], bpe_counts[token]) for token in set(nltk_tokens + bpe_tokens)])
false_positives = sum(bpe_counts.values()) - true_positives
false_negatives = sum(nltk_counts.values()) - true_positives

# Tokenization accuracy
total_nltk_counts = sum([nltk_counts[token] for token in set(nltk_tokens)])
tokenization_accuracy = (true_positives/total_nltk_counts) * 100
print(f"Tokenization accuracy: {tokenization_accuracy:.2f}%")

# Tokenization coverage
tokenization_coverage = (len(set(bpe_tokens).intersection(set(nltk_tokens)))/len(nltk_counts)) * 100
print(f"Tokenization coverage: {tokenization_coverage:.2f}%")

# Precision, Recall and F1-Score
precision = true_positives / (true_positives + false_positives)
recall = true_positives / (true_positives + false_negatives)
f1_score = 2 * (precision * recall) / (precision + recall)

print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 score: {f1_score:.4f}")

# Jaccard Similarity
jaccard_similarity = len(set(bpe_tokens).intersection(set(nltk_tokens)))/len(set(bpe_tokens).union(set(nltk_tokens)))
print(f"Jaccard Similarity: {jaccard_similarity:.4f}")


bpe tokens: ['[', 'Al', 'i', 'ce', "'s", 'Ad', 've', 'n', 'tur', 'es', 'in', 'W', 'on', 'de', 'r', 'la', 'n', 'd', 'by', 'Le', 'wi', 's', 'Car', 'ro', 'll', '18', '6', '5', ']', 'CHAPTER', 'I.', 'Do', 'w', 'n', 'the', 'R', 'ab', 'bi', 't-', 'Ho', 'le', 'Al', 'ice', 'was', 'beg', 'in', 'ning', 'to', 'get', 'very', 'ti', 're', 'd', 'of', 'si', 't', 'ti', 'n', 'g', 'by', 'her', 'si', 'st', 'er', 'on', 'the', 'bank,', 'and', 'of', 'ha', 'vi', 'n', 'g', 'not', 'hi', 'n', 'g', 'to', 'do:', 'once', 'or', 'tw', 'ice', 'she', 'had', 'pe', 'e', 'pe', 'd', 'into', 'the', 'bo', 'o', 'k', 'her', 'si', 'st', 'er', 'was', 'read', 'ing,', 'but', 'it', 'had', 'no', 'pi', 'ct', 'ur', 'es', 'or', 'conver', 'sa', 'tion', 's', 'in', 'it', ',', "'", 'and', 'wh', 'at', 'is', 'the', 'use', 'of', 'a', 'bo', 'o', 'k', ",'", 'thought', 'Al', 'ice', "'", 'wi', 'thou', 't', 'pi', 'ct', 'ur', 'es', 'or', 'conver', 'sa', 'tion', "?'", 'So', 'she', 'was', 'consi', 'de', 'ri', 'n', 'g', 'in', 'her', 'own', 'mi', 'n', 

References
1. https://stackoverflow.com/questions/50583254/explain-bpe-byte-pair-encoding-with-examples
2. https://github.com/karpathy/minbpe/tree/master/minbpe
3. https://www.geeksforgeeks.org/byte-pair-encoding-bpe-in-nlp/