In [19]:
import tiktoken

In [20]:
#initialize byte pair encoding

tokenizer=tiktoken.get_encoding('gpt2')


In [21]:
#text to be encoded
text=(
    'hello do you like coffee?<|endoftext|> yes i like'
)
#call encode method which return ids of subword token
integers=tokenizer.encode(text,allowed_special={'<|endoftext|>'})
print(integers)

[31373, 466, 345, 588, 6891, 30, 50256, 3763, 1312, 588]


In [22]:
#now convert back token id back to text or decode

strings=tokenizer.decode(integers)
print(strings)

hello do you like coffee?<|endoftext|> yes i like


# lets implement BPE from scratch

In [23]:
vocab={}#maps ID to string oe character
str_to_id={}#maps string to id inverse of vocab
merges={}# maps (id1,id2) to new id eg id12


In [39]:
text='the man sat on the chair'
vocab_size=50
print(text)

the man sat on the chair


In [40]:
text=text.replace(' ','_')# replace space with _ eg hey you becomes hey_you
print(text)

the_man_sat_on_the_chair


In [41]:
chars=sorted(set(text)) #sorted unique characters
print(chars)

['_', 'a', 'c', 'e', 'h', 'i', 'm', 'n', 'o', 'r', 's', 't']


In [42]:
vocab={i:c for i,c in enumerate(chars)}#initialize vocabulary
str_to_id={c:i for i,c in enumerate(chars)}

print(vocab)

{0: '_', 1: 'a', 2: 'c', 3: 'e', 4: 'h', 5: 'i', 6: 'm', 7: 'n', 8: 'o', 9: 'r', 10: 's', 11: 't'}


In [None]:
token_ids=[str_to_id[i] for i in text] #token id in the order of how the characters comes in the text not sorted

print(token_ids)

[11, 4, 3, 0, 6, 1, 7, 0, 10, 1, 11, 0, 8, 7, 0, 11, 4, 3, 0, 2, 4, 1, 5, 9]


## find most frequent pairs

In [60]:
from collections import Counter

def find_most_frequent_pairs(token_ids):
    pairs=Counter(zip(token_ids,token_ids[1:])) #the zip function help us to do sliding window with size two read about zip()
    print(max(pairs))
    print(len(pairs))

    return max(pairs,key=pairs.get) if pairs else None
find_most_frequent_pairs(token_ids)

print(len(token_ids))
print(token_ids)

(11, 4)
19
24
[11, 4, 3, 0, 6, 1, 7, 0, 10, 1, 11, 0, 8, 7, 0, 11, 4, 3, 0, 2, 4, 1, 5, 9]


## replace all occurence of pair with new id

In [None]:
def replace_pair(token_ids,pair,new_id):
    res=[121]
    i=0
    while i<len(token_ids):
        #check that atleast two id exists
        if i < len(token_ids)-1 and (token_ids[i],token_ids[i+1])==pair:
            res.append(new_id)
            i+=2
        else:
            res.append(token_ids[i])
            i+=1
    return res
            

replace_pair(token_ids,(11, 4),77)


[77, 3, 0, 6, 1, 7, 0, 10, 1, 11, 0, 8, 7, 0, 77, 3, 0, 2, 4, 1, 5, 9]

In [None]:
def merge_frequent(token_ids):
    new_id=len(vocab)
    #merge until vocabulary size is reached

 #vocab size must be greater than len(vocab) initial without it doesnt make sense
    while new_id<vocab_size:
        pair=find_most_frequent_pairs(token_ids)
        if not pair:
            break
        merges[pair]=new_id
        token_ids=replace_pair(token_ids,pair,new_id)
        
        merged_str=vocab[pair[0]] + vocab[pair[1]]

        vocab[new_id]=merged_str
        

        str_to_id[merged_str]=new_id

        new_id+=1

    print(token_ids)



merge_frequent(token_ids)

Counter({(11, 4): 2, (4, 3): 2, (3, 0): 2, (7, 0): 2, (0, 6): 1, (6, 1): 1, (1, 7): 1, (0, 10): 1, (10, 1): 1, (1, 11): 1, (11, 0): 1, (0, 8): 1, (8, 7): 1, (0, 11): 1, (0, 2): 1, (2, 4): 1, (4, 1): 1, (1, 5): 1, (5, 9): 1})
Counter({(12, 3): 2, (3, 0): 2, (7, 0): 2, (0, 6): 1, (6, 1): 1, (1, 7): 1, (0, 10): 1, (10, 1): 1, (1, 11): 1, (11, 0): 1, (0, 8): 1, (8, 7): 1, (0, 12): 1, (0, 2): 1, (2, 4): 1, (4, 1): 1, (1, 5): 1, (5, 9): 1})
Counter({(13, 0): 2, (7, 0): 2, (0, 6): 1, (6, 1): 1, (1, 7): 1, (0, 10): 1, (10, 1): 1, (1, 11): 1, (11, 0): 1, (0, 8): 1, (8, 7): 1, (0, 13): 1, (0, 2): 1, (2, 4): 1, (4, 1): 1, (1, 5): 1, (5, 9): 1})
Counter({(7, 0): 2, (14, 6): 1, (6, 1): 1, (1, 7): 1, (0, 10): 1, (10, 1): 1, (1, 11): 1, (11, 0): 1, (0, 8): 1, (8, 7): 1, (0, 14): 1, (14, 2): 1, (2, 4): 1, (4, 1): 1, (1, 5): 1, (5, 9): 1})
Counter({(14, 6): 1, (6, 1): 1, (1, 15): 1, (15, 10): 1, (10, 1): 1, (1, 11): 1, (11, 0): 1, (0, 8): 1, (8, 15): 1, (15, 14): 1, (14, 2): 1, (2, 4): 1, (4, 1): 1, (1

In [47]:
print(merges)

{(11, 4): 12, (12, 3): 13, (13, 0): 14, (7, 0): 15, (14, 6): 16, (16, 1): 17, (17, 15): 18, (18, 10): 19, (19, 1): 20, (20, 11): 21, (21, 0): 22, (22, 8): 23, (23, 15): 24, (24, 14): 25, (25, 2): 26, (26, 4): 27, (27, 1): 28, (28, 5): 29, (29, 9): 30}


# encode

In [48]:
def encode(text):# str to id aka integer

    text=text.replace(' ','*')

    #convert to token id from already trained or created

    token_ids=[str_to_id[i] for i in text]
    #apply merge
    while len(token_ids)>1:
        #find the earliest merge,lowest id
        current_pair=None
        current_id=float('inf')
        for pair in zip(token_ids,token_ids[1:]):
            if pair in merges and merges[pair]<current_id:
                current_pair=pair
                current_id=merges[pair]
        if current_pair is None:#if current pair is not in the merge break the loop
            break
        token_ids=replace_pair(token_ids,current_pair,current_id)
    print(token_ids)
    return token_ids



            
            
encode(text)


print(vocab)

[30]
{0: '_', 1: 'a', 2: 'c', 3: 'e', 4: 'h', 5: 'i', 6: 'm', 7: 'n', 8: 'o', 9: 'r', 10: 's', 11: 't', 12: 'th', 13: 'the', 14: 'the_', 15: 'n_', 16: 'the_m', 17: 'the_ma', 18: 'the_man_', 19: 'the_man_s', 20: 'the_man_sa', 21: 'the_man_sat', 22: 'the_man_sat_', 23: 'the_man_sat_o', 24: 'the_man_sat_on_', 25: 'the_man_sat_on_the_', 26: 'the_man_sat_on_the_c', 27: 'the_man_sat_on_the_ch', 28: 'the_man_sat_on_the_cha', 29: 'the_man_sat_on_the_chai', 30: 'the_man_sat_on_the_chair'}


# decode

In [53]:
def decode(token_ids):

    text=''
    for id in token_ids:
        token=vocab[id]
        token=token.replace('_',' ')
        text+=token
    return text
decode([19, 1, 11, 0, 8, 15, 5,])

'the man sat on i'

## BPE From Scratch With OOP

In [80]:
class BPE_From_Scratch:
    def __init__(self):
        self.vocabulary={}
        self.token_to_id={}
        self.merges={}


    def frequent(self,token_ids):
        pair=Counter(zip(token_ids,token_ids[1:]))

        if pair:
            return max(pair)
        return None

        

    def replace(self,token_ids,pair,new_id):
        res=[]
        i=0

        while i<len(token_ids):
            if i<len(token_ids)-1 and (token_ids[i],token_ids[i+1])==pair:
                res.append(new_id)
                i+=2
            else:
                res.append(token_ids[i])
                i+=1
        return res

    def merge(self,token_ids,vocab_size):

        new_id=len(self.vocabulary)

        while new_id<vocab_size:
            pair=self.frequent(token_ids)
            

            if pair is None:
                break

            self.merges[pair]=new_id

            token_ids=self.replace(token_ids,pair,new_id)

            merged_char=self.vocabulary[pair[0]]+self.vocabulary[pair[1]]

            self.vocabulary[new_id]=merged_char
            self.token_to_id[merged_char]=new_id

            new_id+=1
        return self.vocabulary
        
    
    def train(self,text,vocab_size):
        text=text.replace(' ','<space>')
        chars=sorted(set(text))
        self.vocabulary={i:c for i,c in enumerate(chars)}
        self.token_to_id={c:i for i,c in enumerate(chars)}

        token_ids=[self.token_to_id[i] for i in text]

        generated_vocab=self.merge(token_ids,vocab_size)

        return generated_vocab



        

    def encode(self,text):
        pass
    def decode(self,token_ids):
        pass

In [81]:
bpe=BPE_From_Scratch()
bpe.train(text=text,vocab_size=20)

{0: '_',
 1: 'a',
 2: 'c',
 3: 'e',
 4: 'h',
 5: 'i',
 6: 'm',
 7: 'n',
 8: 'o',
 9: 'r',
 10: 's',
 11: 't',
 12: 'th',
 13: 'the',
 14: 'the_',
 15: 'the_m',
 16: 'the_ma',
 17: 'the_man',
 18: 'the_man_',
 19: 'the_man_s'}