In [None]:
#Train Tokeniser and Save to File
import time
import pickle

In [None]:
#Hyperparameters
# Fraction of the UTF-8 encoded text used to train the tokeniser; the remainder becomes TestingData.
PercentTraining = 1
# Target vocabulary size for training: initial byte tokens plus learned merges.
VocabSize = 2500

In [None]:
#Load input text (Lorem Ipsum) into a single string, decoding the file as UTF-8
with open("LoremIpsum.txt", mode="r", encoding="utf-8") as input_file:
    text = input_file.read()

In [None]:
#Load input text (Shakespeare)
#with open("Tiny Shakespeare.txt", "r", encoding="utf-8") as f:
#    text = f.read()

In [None]:
#Load input text (XFM)
#with open("XFM.txt", "r", encoding="utf-8") as f:
#    text = f.read()

In [None]:
#Print first 1000 characters of the loaded text as a quick sanity check
print(text[0:1000])

In [None]:
#Decode string from tokens
def decode(tokens, initial_vocab=None, merge_table=None):
    """Decode a sequence of token ids back into a string.

    Merged tokens are expanded recursively, left child first, through the merge
    table until only initial-vocabulary byte values remain. Those bytes are then
    decoded as UTF-8, with undecodable sequences replaced by U+FFFD.

    Args:
        tokens: iterable of integer token ids. It is NOT modified.
        initial_vocab: byte values that are emitted directly. Defaults to the
            notebook-level ``Initialvocab``.
        merge_table: mapping ``(left, right) -> merged_id`` produced by
            training. Defaults to the notebook-level ``merges``.

    Returns:
        The decoded string.

    Tokens that are neither in ``initial_vocab`` nor a merged id in
    ``merge_table`` are skipped.
    """
    if initial_vocab is None:
        initial_vocab = Initialvocab
    if merge_table is None:
        merge_table = merges
    base_tokens = set(initial_vocab)
    # Inverse lookup merged_id -> (left, right); setdefault keeps the first pair
    # seen for an id, matching a linear search over merge_table.items().
    children = {}
    for pair, merged_id in merge_table.items():
        children.setdefault(merged_id, pair)

    byte_values = []
    # Explicit stack (top = next token to emit) so the caller's list is never
    # modified while merged tokens are expanded.
    stack = list(tokens)
    stack.reverse()
    while stack:
        token = stack.pop()
        if token in base_tokens:
            byte_values.append(token)
        elif token in children:
            left, right = children[token]
            # Push right first so the left child is expanded next.
            stack.append(right)
            stack.append(left)
    return bytes(byte_values).decode("utf-8", errors="replace")

In [None]:
#Encode text as UTF-8 bytes; each byte value (0-255) is one initial token.
tokenisedText = list(text.encode("utf-8"))
#Split into training and testing data. The split point is measured in bytes, not
#characters: multi-byte characters make tokenisedText longer than text, so using
#len(text) here would leave part of the data out of TrainingData even when PercentTraining is 1.
n = int(PercentTraining*len(tokenisedText))
TrainingData = tokenisedText[:n]
TestingData = tokenisedText[n:]
print("Initial length: " + str(len(TrainingData)))

In [None]:
def getPairFreqs(text):
    """Count occurrences of each pair of adjacent tokens.

    Args:
        text: sliceable sequence of token ids.

    Returns:
        dict mapping ``(left, right)`` to its count, keyed in order of first
        occurrence. Empty when ``text`` has fewer than two tokens.
    """
    counts = {}
    for left, right in zip(text, text[1:]):
        counts[(left, right)] = counts.get((left, right), 0) + 1
    return counts

In [None]:
def merge(text, pair, newChar):
    """Replace every non-overlapping occurrence of ``pair`` with ``newChar``.

    Occurrences are matched greedily from left to right, so ``[a, a, a]`` with
    pair ``(a, a)`` becomes ``[newChar, a]``.

    Args:
        text: sequence of token ids.
        pair: the two adjacent token ids to replace.
        newChar: token id that replaces each matched pair.

    Returns:
        A new list; ``text`` itself is not modified.
    """
    target = (pair[0], pair[1])
    merged = []
    pos = 0
    last = len(text) - 1
    while pos <= last:
        if pos < last and (text[pos], text[pos + 1]) == target:
            merged.append(newChar)
            pos += 2
            continue
        merged.append(text[pos])
        pos += 1
    return merged

In [None]:
#Train the tokeniser using BPE.
#NOTE: this cell reassigns TrainingData to the merged tokens, so re-running it without
#first re-running the encoding cell trains on already-merged data.
start_time = time.time()
Initialvocab = sorted(set(TrainingData))
#Remove byte values that do not decode as UTF-8 on their own (bytes that only occur inside
#multi-byte characters) from the initial vocabulary before training. These bytes still
#appear in TrainingData and can still take part in merges.
for word in Initialvocab[:]:
    try:
        bytes([word]).decode("utf-8")
    except UnicodeDecodeError:
        Initialvocab.remove(word)
vocab = Initialvocab.copy()
#New token ids must start above every id already present in TrainingData. max(vocab) alone
#is not enough: the bytes removed above are still in TrainingData, and a merged token could
#otherwise reuse one of their ids.
newChar = max(TrainingData) + 1
merges = {}
while len(vocab) < VocabSize:
    freqs = getPairFreqs(TrainingData)
    if not freqs:
        #Fewer than two tokens remain, so there is nothing left to merge.
        print("No pairs left to merge; stopping at vocabulary size " + str(len(vocab)))
        break
    #Most frequent adjacent pair (first encountered wins ties).
    topPair = max(freqs, key=freqs.get)
    TrainingData = merge(TrainingData, topPair, newChar)
    vocab.append(newChar)
    merges[topPair] = newChar
    newChar += 1
    if len(vocab) % 25 == 0:
        new_time = time.time()
        print(len(vocab))
        print(str(new_time - start_time) + " seconds elapsed")
        start_time = time.time()
#Token ids are not contiguous: e.g. the vocabulary may contain token "10" and then token "32",
#with no tokens indexed in between. The number of tokens in vocab is still VocabSize as defined
#above (unless training stopped early). However, the one-hot encoding is indexed by token id, so
#VocabSize is updated to one more than the highest token id in use. The unused ids add some
#redundant entries to our matrices. As VocabSize grows, this redundancy becomes a smaller
#fraction of VocabSize, because the merging process never introduces new gaps.
VocabSize = newChar
print("Tokenised length: " + str(len(TrainingData)))
#Pass a copy so that vocab cannot be altered by decode() before it is pickled below.
print(decode(list(vocab)))

# Persist the trained tokeniser state so the tokeniser only needs to be trained once.
# Each object is written to its own pickle file, in the order listed.
for pickle_path, pickled_obj in (
    ('initialvocab.pkl', Initialvocab),
    ('vocab.pkl', vocab),
    ('vocabsize.pkl', VocabSize),
    ('merges.pkl', merges),
    ('trainingdata.pkl', TrainingData),
):
    with open(pickle_path, 'wb') as out_file:
        pickle.dump(pickled_obj, out_file)