In [1]:
import os, json
from datasets import load_dataset


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class Vocabulary:
    """Class to map codes from huggingface dataset to tokens in Llama 3-8B token"""

    def __init__(self):
        self.stoi = {}
        self.itos = {}
    
    def build_vocabulary(self, parquet_files, tokenizer_file="tokenizer.json"):
        '''
        creates the vocabulary from the Llama 3 tokenizer and hugging face dataset
        Args:
            tokenizer_file(str): file downloaded from Llama 3(8B) which contains the vocabulary for the model
            parquet_files(list): director with the dataset from hugging face in parquet format

        '''
        # Open the JSON file
        with open(tokenizer_file, 'r') as file:
            # Load the JSON data
            data = json.load(file)
        
        llama_stoi = data['model']['vocab']
        llama_itos = {value:key for key,value in llama_stoi.items()}

        #load hugging face data
        dataset = load_dataset('parquet', data_files=parquet_files)
        vocabulary = set()

        for sent in dataset["train"]["txt"]:
            for word in sent.split():
                vocabulary.add(word)
        
        self.itos = {int(value):llama_itos[int(value)] for value in vocabulary}
        self.stoi = {value:key for key,value in self.itos.items()}
    
    def save(self, file_path):
        with open(file_path, "w") as file:
            json.dump(self.itos, file)

In [3]:
os.chdir("..")

In [4]:
train_dir = [f"dataset/default/partial-train/000{i}.parquet" for i in range(10)]
dataset = load_dataset('parquet', data_files=train_dir)
txt = dataset["train"]["txt"]


In [5]:
type(txt)

list

## Applying BPE

using the following link for implimentation: https://huggingface.co/learn/nlp-course/en/chapter6/5

In [6]:
type(txt[0])

str

In [7]:
from collections import defaultdict

def add_special_character(corpus, special_char='▁'):
    modified_corpus = []
    for sentence in corpus:
        modified_sentence = ''
        words = []
        previous_char_is_space = False
        
        for char in sentence:
            if char == ' ':
                previous_char_is_space = True
                words.append(modified_sentence)
                modified_sentence = ''
            elif previous_char_is_space:
                modified_sentence += special_char + char
                previous_char_is_space = False
            else:
                modified_sentence += char
        
        modified_corpus.extend(words)
    
    return modified_corpus

# Example usage
corpus = [t for t in txt]
modified_corpus = add_special_character(corpus)
word_freqs = defaultdict(int)
for word in modified_corpus:
    word_freqs[word] += 1


In [8]:
modified_corpus[0:10]

['896',
 '▁2029',
 '▁935',
 '▁679',
 '▁1115',
 '▁3601',
 '▁3000',
 '▁222',
 '▁3446',
 '▁2218']

In [9]:
len(word_freqs)

4435

In [10]:
for i, item in enumerate(word_freqs.items()):
    print(item)
    if i == 7:
        break

('896', 5)
('▁2029', 4261)
('▁935', 6058)
('▁679', 8699)
('▁1115', 1982)
('▁3601', 639)
('▁3000', 635)
('▁222', 27691)


In [11]:
alphabet = set()

for word in word_freqs.keys():
    alphabet.add(word)
alphabet = sorted(alphabet)

print(len(alphabet))
print(alphabet[-10:-1])

4435
['▁990', '▁991', '▁992', '▁993', '▁994', '▁995', '▁996', '▁997', '▁998']


In [28]:
def preprocess_corpus(text, special_char='_'):
    output = []
    for sentence in text:
        modified_sentence = ''
        words = []
        full_sent = ""
        previous_char_is_space = False
        
        for char in sentence:
            if char == ' ':
                previous_char_is_space = True
                words.append(modified_sentence)
                full_sent += modified_sentence
                modified_sentence = ''
            elif previous_char_is_space:
                modified_sentence += special_char + char
                previous_char_is_space = False
            else:
                modified_sentence += char
        
        output.append((full_sent, words))
    
    return output

In [29]:
print(str(txt[0]))


896 2029 935 679 1115 3601 3000 222 3446 2218 3072 550 3652 665 2596 2809 3649 251 2610 2536 47 2852 2940 3353 3400 3336 325 2647 4076 3653 3253 58 3664 1424 1388 222 278 897 447 2355 2453 2531 2712 828 2895 2398 2908 901 2536 222 3686 2620 3254 3962 0 1448 222 863 3593 124 124 1048 1593 222 4086 2647 3236 1767 2800 697 514 3648 2337 1338 1114 340 3514 4076 2658 1954 3867 2300 251 317 7 1091 1768 1440 3167 672 1253 188 3544 2934 1368 479 3951 3387 514 2438 1262 3166 462 3530 333 2596 3808 2796 1920 794 263 2626 2596 1949 57 3990 3785 146 404 3731 479 3840 3840 3664 940 2550 4076 544 3465 3232 269 79 2159 3879 1734 3900 755 1756 818 800 1249 171 319 727 171 3698 3683 2596 3969 2431 1838 3969 126 2673 2596 4012 1010 2151 3437 417 2386 2712 3705 1838 3428 1168 1838 1527 3885 1952 2443 3997 3562 1667 3651 3981 2426 1494 1532 2426 3602 1855


In [30]:
preprocess_corpus(txt)[0]

('896_2029_935_679_1115_3601_3000_222_3446_2218_3072_550_3652_665_2596_2809_3649_251_2610_2536_47_2852_2940_3353_3400_3336_325_2647_4076_3653_3253_58_3664_1424_1388_222_278_897_447_2355_2453_2531_2712_828_2895_2398_2908_901_2536_222_3686_2620_3254_3962_0_1448_222_863_3593_124_124_1048_1593_222_4086_2647_3236_1767_2800_697_514_3648_2337_1338_1114_340_3514_4076_2658_1954_3867_2300_251_317_7_1091_1768_1440_3167_672_1253_188_3544_2934_1368_479_3951_3387_514_2438_1262_3166_462_3530_333_2596_3808_2796_1920_794_263_2626_2596_1949_57_3990_3785_146_404_3731_479_3840_3840_3664_940_2550_4076_544_3465_3232_269_79_2159_3879_1734_3900_755_1756_818_800_1249_171_319_727_171_3698_3683_2596_3969_2431_1838_3969_126_2673_2596_4012_1010_2151_3437_417_2386_2712_3705_1838_3428_1168_1838_1527_3885_1952_2443_3997_3562_1667_3651_3981_2426_1494_1532_2426_3602',
 ['896',
  '_2029',
  '_935',
  '_679',
  '_1115',
  '_3601',
  '_3000',
  '_222',
  '_3446',
  '_2218',
  '_3072',
  '_550',
  '_3652',
  '_665',
  '_25

In [31]:
for sent, word in preprocess_corpus(txt):
    print(sent)
    print(word)
    break

896_2029_935_679_1115_3601_3000_222_3446_2218_3072_550_3652_665_2596_2809_3649_251_2610_2536_47_2852_2940_3353_3400_3336_325_2647_4076_3653_3253_58_3664_1424_1388_222_278_897_447_2355_2453_2531_2712_828_2895_2398_2908_901_2536_222_3686_2620_3254_3962_0_1448_222_863_3593_124_124_1048_1593_222_4086_2647_3236_1767_2800_697_514_3648_2337_1338_1114_340_3514_4076_2658_1954_3867_2300_251_317_7_1091_1768_1440_3167_672_1253_188_3544_2934_1368_479_3951_3387_514_2438_1262_3166_462_3530_333_2596_3808_2796_1920_794_263_2626_2596_1949_57_3990_3785_146_404_3731_479_3840_3840_3664_940_2550_4076_544_3465_3232_269_79_2159_3879_1734_3900_755_1756_818_800_1249_171_319_727_171_3698_3683_2596_3969_2431_1838_3969_126_2673_2596_4012_1010_2151_3437_417_2386_2712_3705_1838_3428_1168_1838_1527_3885_1952_2443_3997_3562_1667_3651_3981_2426_1494_1532_2426_3602
['896', '_2029', '_935', '_679', '_1115', '_3601', '_3000', '_222', '_3446', '_2218', '_3072', '_550', '_3652', '_665', '_2596', '_2809', '_3649', '_251', '_

In [32]:
splits = {sentence: [c for c in word] for (sentence, word) in preprocess_corpus(txt)}

In [33]:
for i, item in enumerate(splits.items()):
    print(item)
    if i == 7:
        break

('896_2029_935_679_1115_3601_3000_222_3446_2218_3072_550_3652_665_2596_2809_3649_251_2610_2536_47_2852_2940_3353_3400_3336_325_2647_4076_3653_3253_58_3664_1424_1388_222_278_897_447_2355_2453_2531_2712_828_2895_2398_2908_901_2536_222_3686_2620_3254_3962_0_1448_222_863_3593_124_124_1048_1593_222_4086_2647_3236_1767_2800_697_514_3648_2337_1338_1114_340_3514_4076_2658_1954_3867_2300_251_317_7_1091_1768_1440_3167_672_1253_188_3544_2934_1368_479_3951_3387_514_2438_1262_3166_462_3530_333_2596_3808_2796_1920_794_263_2626_2596_1949_57_3990_3785_146_404_3731_479_3840_3840_3664_940_2550_4076_544_3465_3232_269_79_2159_3879_1734_3900_755_1756_818_800_1249_171_319_727_171_3698_3683_2596_3969_2431_1838_3969_126_2673_2596_4012_1010_2151_3437_417_2386_2712_3705_1838_3428_1168_1838_1527_3885_1952_2443_3997_3562_1667_3651_3981_2426_1494_1532_2426_3602', ['896', '_2029', '_935', '_679', '_1115', '_3601', '_3000', '_222', '_3446', '_2218', '_3072', '_550', '_3652', '_665', '_2596', '_2809', '_3649', '_251'

In [84]:
print(len(word_freqs), len(splits))

4435 4435


In [65]:
def compute_pair_freqs(splits, word_freqs=word_freqs):
    pair_freqs = defaultdict(int)
    for word, freq in word_freqs.items():
        split = splits[word]
        if len(split) == 1:
            continue
        for i in range(len(split) - 1):
            pair = (split[i], split[i + 1])
            pair_freqs[pair] += freq
    return pair_freqs

In [67]:
pair_freqs = compute_pair_freqs(splits)

In [70]:
for i, (key, val) in enumerate(pair_freqs.items()):
    print((key, val))
    if i == 7:
        break

(('8', '9'), 131938)
(('9', '6'), 170831)
(('▁', '2'), 2060412)
(('2', '0'), 312772)
(('0', '2'), 107433)
(('2', '9'), 233446)
(('▁', '9'), 250512)
(('9', '3'), 171191)


In [71]:
best_pair = ""
max_freq = None

for pair, freq in pair_freqs.items():
    if max_freq is None or max_freq < freq:
        best_pair = pair
        max_freq = freq

print(best_pair, max_freq)

('▁', '2') 2060412
