In [2]:
from datasets import load_dataset
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
import torch
from torch import nn

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
dataset = load_dataset('notaphoenix/shakespeare_dataset')
dataset

DatasetDict({
    validation: Dataset({
        features: ['text', 'label'],
        num_rows: 429
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 1072
    })
    training: Dataset({
        features: ['text', 'label'],
        num_rows: 3859
    })
})

In [None]:
dataset['training'][0]


list('abcd') #-> ['a', 'b', 'c', 'd']

['a', 'b', 'c', 'd']

### How many characters/words are there in total in the train split of the dataset ?

In [None]:
n_chars = sum(len(elt['text']) for elt in dataset['training'])
n_words = sum(len(elt['text'].split(' ')) for elt in dataset['training'])
n_chars, n_words

(179380, 40229)

# 1. Byte-pair encoding

In [5]:
# For convenience we use the training data as a list of str:
corpus = [elt['text'] for elt in dataset['training']]
corpus[0]

'How well my comfort is revived by this !\n'

### 1.1 First step: build a list "tokenized_corpus" of same size as the corpus such that tokenized_corpus[i] is the list of characters in corpus[i] (just by splitting the str into its list of characters)
So tokenized_corpus starts by representing each corpus entry by a list of characters. As a reminder, Byte-pair encoding will iteratively refine it by merging the most frequent pairs of characters:
['a', 'b', 'c', 'a', 'b'] -> ['ab', 'c', 'ab']: this is the representation which we store in 'tokenized_corpus'.

In [8]:
#tokenized_corpus = [list(word) for word in corpus] # one-liner (optional)

tokenized_corpus = []
for elt in corpus:
    tokenized_corpus.append(list(elt))

tokenized_corpus[0]

['H',
 'o',
 'w',
 ' ',
 'w',
 'e',
 'l',
 'l',
 ' ',
 'm',
 'y',
 ' ',
 'c',
 'o',
 'm',
 'f',
 'o',
 'r',
 't',
 ' ',
 'i',
 's',
 ' ',
 'r',
 'e',
 'v',
 'i',
 'v',
 'e',
 'd',
 ' ',
 'b',
 'y',
 ' ',
 't',
 'h',
 'i',
 's',
 ' ',
 '!',
 '\n']

### 1.2 Now build a dictionary "vocab" mapping each unique character in the corpus to a unique integer between 0 and the total number of characters - 1

In [11]:
#vocab = {char: i for i, char in enumerate(set("".join(corpus)))} # one-liner (optional)

vocab = {}
for elt in tokenized_corpus: # iterate over corpus
    for char in elt: # iterate over characters of elt
        if char not in vocab: # add to dictionary
            vocab[char] = len(vocab) # index of character is current length of vocab !

vocab.items()

dict_items([('H', 0), ('o', 1), ('w', 2), (' ', 3), ('e', 4), ('l', 5), ('m', 6), ('y', 7), ('c', 8), ('f', 9), ('r', 10), ('t', 11), ('i', 12), ('s', 13), ('v', 14), ('d', 15), ('b', 16), ('h', 17), ('!', 18), ('\n', 19), ('B', 20), ('u', 21), ('I', 22), ('a', 23), ('k', 24), ('p', 25), ("'", 26), (',', 27), ('n', 28), ('.', 29), ('g', 30), ('L', 31), ('O', 32), ('G', 33), ('S', 34), ('A', 35), ('Y', 36), ('C', 37), ('R', 38), ('T', 39), ('W', 40), ('P', 41), ('N', 42), ('?', 43), ('D', 44), ('j', 45), ('q', 46), ('x', 47), ('-', 48), ('F', 49), ('M', 50), ('"', 51), ('V', 52), ('J', 53), ('E', 54), ('0', 55), ('U', 56), ('z', 57), (';', 58), (':', 59), ('K', 60), (')', 61), ('(', 62), ('Q', 63), ('Z', 64)])

In [13]:
 #Dictionary manipulations:
 d = {}
 d['a'] = 2 # allocation of new key
 if 'a' in d: # checking a key belongs
     print('do something')

d['a'] # acccecssing a value (error if 'a' not in d)

do something


2

### 1.3 Now implement a method which, given 'tokenized_corpus' computes a dictionary 'pairs' of which:
- the keys are pairs of elements appearing in the corpus (if corpus contains ['c', ..., 'a', 'b', ...] then the tuple ('a', 'b') should appear as key in pairs)
- the values are the number of times each pair appears in the corpus

(The right way is to build the pairs dictionary by iterating over the tokenized_corpus)

In [16]:
def get_pairs(tokenized_corpus):
    """Get the frequency of adjacent pairs in the words."""
    pairs = {}
    for word in tokenized_corpus:
        for i in range(len(word) - 1):
            if (word[i], word[i + 1]) not in pairs:
                pairs[(word[i], word[i + 1])] = 1
            else:
                pairs[(word[i], word[i + 1])] += 1
    return pairs

pairs = get_pairs(tokenized_corpus)
pairs.items()

dict_items([(('H', 'o'), 80), (('o', 'w'), 609), (('w', ' '), 515), ((' ', 'w'), 1722), (('w', 'e'), 461), (('e', 'l'), 650), (('l', 'l'), 1193), (('l', ' '), 1293), ((' ', 'm'), 2017), (('m', 'y'), 516), (('y', ' '), 2372), ((' ', 'c'), 972), (('c', 'o'), 490), (('o', 'm'), 739), (('m', 'f'), 16), (('f', 'o'), 506), (('o', 'r'), 1106), (('r', 't'), 337), (('t', ' '), 3782), ((' ', 'i'), 1329), (('i', 's'), 1235), (('s', ' '), 3386), ((' ', 'r'), 402), (('r', 'e'), 1626), (('e', 'v'), 188), (('v', 'i'), 179), (('i', 'v'), 244), (('v', 'e'), 1096), (('e', 'd'), 623), (('d', ' '), 2690), ((' ', 'b'), 1422), (('b', 'y'), 150), ((' ', 't'), 3839), (('t', 'h'), 3279), (('h', 'i'), 1123), ((' ', '!'), 402), (('!', '\n'), 379), (('B', 'u'), 105), (('u', 't'), 596), (('i', 'f'), 197), (('f', ' '), 776), ((' ', 'I'), 761), (('I', ' '), 909), ((' ', 's'), 2087), (('s', 'h'), 474), (('h', 'a'), 1557), (('a', 'k'), 279), (('k', 'e'), 442), (('e', ' '), 6284), (('i', 't'), 882), ((' ', 'u'), 275), 

### Now, given pairs, we can compute the most frequent pair of tokens via:

In [17]:
pair_to_merge = max(pairs, key=pairs.get)
pair_to_merge

('e', ' ')

### 1.4 Write a method which, given a pair_to_merge and the current tokenized_corpus and applies the merge operation to the tokenized_corpus
e.g. if the pair to merge is ('a', 'b') and some entry in tokenized_corpus contains ['c', ..., 'a', 'b'...], it should be mapped to ['c', ..., 'ab', ...]

In [None]:
def merge_pair(pair, tokenized_corpus):
    """Merge the most frequent pair in all tokenized_corpus."""
    new_words = []
    for word in tokenized_corpus:
        new_word = []
        i = 0
        while i < len(word):
            if i < len(word) - 1 and (word[i], word[i + 1]) == pair:
                new_word.append(word[i] + word[i + 1])  # Merge pair ("+" is str concatenation)
                i += 2  # Skip next character since it's merged
            else:
                new_word.append(word[i])
                i += 1
        new_words.append(new_word)
    return new_words

tokenized_corpus = merge_pair(pair_to_merge, tokenized_corpus)
tokenized_corpus[0]

['H', 'o', 'w', ' ', 'w', 'e', 'l', 'l', ' ', 'm', 'y', ' ', 'c', 'o', 'm', 'f', 'o', 'r', 't', ' ', 'i', 's', ' ', 'r', 'e', 'v', 'i', 'v', 'e', 'd', ' ', 'b', 'y', ' ', 't', 'h', 'i', 's', ' ', '!', '\n']


['Ho',
 'w',
 ' ',
 'w',
 'e',
 'l',
 'l',
 ' ',
 'm',
 'y',
 ' ',
 'c',
 'o',
 'm',
 'f',
 'o',
 'r',
 't',
 ' ',
 'i',
 's',
 ' ',
 'r',
 'e',
 'v',
 'i',
 'v',
 'e',
 'd',
 ' ',
 'b',
 'y',
 ' ',
 't',
 'h',
 'i',
 's',
 ' ',
 '!',
 '\n']

### 1.5 Now write the full byte-pair encoding algorithm which, given an initial corpus (list of str) returns its tokenized version as well as the vocabulary.
This function should take a parameter 'n_merge' indicating the number of merges we will do. 

What is the final length of the vocabulary ?

In [11]:
def byte_pair_encoding(corpus, num_merges: int = 10):
    """Perform BPE on a given corpus."""
    tokenized_corpus = [list(word) for word in corpus]  # Start with character tokens
    vocab = {char: i for i, char in enumerate(set("".join(corpus)))}  # Initial vocab
    
    for _ in range(num_merges):
        pairs = get_pairs(tokenized_corpus)
        if not pairs:
            break
        best_pair = max(pairs, key=pairs.get)
        tokenized_corpus = merge_pair(best_pair, tokenized_corpus)
        new_token = best_pair[0] + best_pair[1]
        vocab[new_token] = len(vocab)  # Assign new token an index
    
    return vocab, tokenized_corpus

vocab, tokenized_corpus = byte_pair_encoding(corpus, num_merges=500)

### 1.6 Now, given the obtained 'vocab' write a 'tokenize_to_str_list' method which applies the tokenization to an input string s, returns the list of tokens as a list of str
Verify your result on dataset['test'][0]

In [12]:
def tokenize_to_str_list(s: str, vocab):
    """Tokenize a given string based on the trained BPE vocabulary."""
    tokens = list(s)  # Start with character tokens
    
    while True:
        pairs = [(tokens[i], tokens[i + 1]) for i in range(len(tokens) - 1)]
        valid_pairs = [pair for pair in pairs if pair[0] + pair[1] in vocab]
        
        if not valid_pairs:
            break
        
        best_pair = max(valid_pairs, key=lambda p: vocab.get(p[0] + p[1], float('-inf')))
        merged_token = best_pair[0] + best_pair[1]
        
        new_tokens = []
        i = 0
        while i < len(tokens):
            if i < len(tokens) - 1 and (tokens[i], tokens[i + 1]) == best_pair:
                new_tokens.append(merged_token)
                i += 2
            else:
                new_tokens.append(tokens[i])
                i += 1
        
        tokens = new_tokens
    
    return tokens    

tokenize_to_str_list(dataset['test'][0]['text'], vocab)

['L',
 'ea',
 've ',
 'me ',
 'a',
 'lo',
 'ne',
 ' ',
 'for ',
 'a ',
 'min',
 'u',
 'te ',
 '.\n']

### 1.7 Now, given the obtained 'vocab' write a 'tokenize' method which applies the tokenization to an input string s, returns the list of integers corresponding to each symbol (the keys of vocab!)
Verify your result on dataset['test'][0]

In [13]:
def tokenize(s: str, vocab):
    tokens = tokenize_to_str_list(s, vocab)
    return [vocab[elt] for elt in tokens]

tokenize(dataset['test'][0]['text'], vocab)

[62, 81, 98, 88, 33, 165, 283, 59, 163, 100, 427, 51, 296, 70]

### => Yay, we are not able to map any input text to list of integers, not too long, not too short, of which representatens are not too rare, and which preserves (at least partially) the structure of the words
### 1.8 What are the current limitations of our implementation of BPE ?

- multiple redundant tokenizations e.g. dog! and dog?
- we need to handle unknown characters to not run into errors later on