In [1]:
# Load the whole Hindi training corpus as one str (file decoded as UTF-8).
with open("hindi_data.txt", encoding='utf-8') as corpus_file:
    hindi_text = corpus_file.read()

In [2]:
tokens_hindi = bytes(hindi_text, 'utf-8')  # raw UTF-8 bytes of the corpus

In [3]:
# Iterating a bytes object already yields ints (0-255), so list() alone gives
# the byte ids; the previous map(int, ...) was a no-op.
tokens_hindi = list(tokens_hindi)

In [4]:
# Total number of byte ids vs. how many distinct byte values occur, then a
# peek at the first ten byte ids.
n_total = len(tokens_hindi)
n_distinct = len(set(tokens_hindi))
print(n_total, n_distinct)
print(tokens_hindi[:10])

49513 77
[49, 46, 32, 224, 164, 182, 224, 164, 172, 224]


In [5]:
def get_stats(ids):
    """Count how often each adjacent pair of ids occurs.

    Args:
        ids: a sequence of token ids (a list of ints, or bytes).

    Returns:
        dict mapping ``(left, right) -> count``. Keys appear in order of their
        first occurrence, which matters because ``max(stats, key=stats.get)``
        returns the first of several equally frequent pairs.
    """
    pair_counts = {}
    for position in range(len(ids) - 1):
        key = (ids[position], ids[position + 1])
        pair_counts[key] = pair_counts.get(key, 0) + 1
    return pair_counts

def merge(ids, pair, idx):
    """Replace each occurrence of an adjacent pair with a single new id.

    Scans ``ids`` left to right. Whenever ``ids[k], ids[k + 1]`` equal
    ``pair[0], pair[1]``, both are emitted as ``idx`` and position ``k + 1``
    is skipped. Matches are therefore non-overlapping and taken greedily from
    the left (e.g. pair (1, 1) in [1, 1, 1] gives [idx, 1]).

    Args:
        ids: sequence of token ids.
        pair: the (left, right) ids to merge.
        idx: the id that replaces each matched pair.

    Returns:
        A new list; ``ids`` itself is not modified.
    """
    merged = []
    skip_next = False
    for pos, current in enumerate(ids):
        if skip_next:
            # This element was consumed as the right half of a match.
            skip_next = False
            continue
        has_next = pos + 1 < len(ids)
        if has_next and current == pair[0] and ids[pos + 1] == pair[1]:
            merged.append(idx)
            skip_next = True
        else:
            merged.append(current)
    return merged

In [6]:
# Train a byte-level BPE: start from the corpus' raw UTF-8 bytes (ids 0-255)
# and repeatedly replace the most frequent adjacent pair with a fresh id until
# the vocabulary reaches `vocab_size` ids.
tokens = hindi_text.encode("utf-8")
vocab_size = 2700
num_merges = vocab_size - 256  # ids 0-255 are reserved for the single bytes
ids = list(tokens)
print(len(ids))
merges = {}  # (left_id, right_id) -> new id, in the order the merges were learned
for i in range(num_merges):
    stats = get_stats(ids)
    if not stats:
        # Fewer than two ids remain, so there is nothing left to merge
        # (max() would raise ValueError on the empty dict).
        print(f"stopping early after {i} merges: no adjacent pairs left")
        break
    # Most frequent pair; ties go to the pair that occurs first in `ids`.
    pair = max(stats, key=stats.get)
    idx = 256 + i
    ids = merge(ids, pair, idx)
    merges[pair] = idx

49513


In [8]:
# Compare the raw UTF-8 byte count with the BPE id count after training.
original_len = len(tokens)
bpe_len = len(ids)
print("Original token length:", original_len)
print("BPE ids length:", bpe_len)
print(f"Compression Ratio: {original_len/ bpe_len:.2f}X")

Original token length: 49513
BPE ids length: 4955
Compression Ratio: 9.99X


In [9]:
# id -> bytes lookup for decoding. Ids 0-255 map to the single byte itself;
# each learned id maps to the concatenation of its two parts. This relies on
# iterating `merges` in insertion (training) order, so both parts of every
# pair are already in the table when the pair is reached.
vocab = dict(zip(range(256), (bytes([b]) for b in range(256))))
for pair, new_id in merges.items():
    left_id, right_id = pair
    left_bytes = vocab[left_id]
    right_bytes = vocab[right_id]
    vocab[new_id] = left_bytes + right_bytes

In [10]:
def decode(ids, vocab_table=None):
    """Turn a list of BPE token ids back into text.

    Args:
        ids: iterable of token ids.
        vocab_table: optional mapping ``id -> bytes``. Defaults to the
            notebook-global ``vocab`` built from the learned merges, so
            ``decode(ids)`` behaves as before. Pass a table explicitly to avoid
            depending on notebook state (e.g. a table loaded from disk).

    Returns:
        str: the decoded text. Byte sequences that are not valid UTF-8 (which
        can happen for arbitrary id sequences) become U+FFFD instead of raising.

    Raises:
        KeyError: if an id is not present in the table.
    """
    if vocab_table is None:
        vocab_table = vocab
    raw = b"".join(vocab_table[idx] for idx in ids)
    return raw.decode("utf-8", errors="replace")

In [11]:
def encode(text, merges_table=None):
    """Encode text into BPE token ids using learned merge rules.

    The text is converted to its UTF-8 bytes. Then, repeatedly, the adjacent
    pair with the lowest merge id (the one learned earliest in training) is
    merged. This continues until no adjacent pair has a merge rule, or only
    one id remains.

    Args:
        text: the string to encode.
        merges_table: optional mapping ``(left_id, right_id) -> merged_id``.
            Defaults to the notebook-global ``merges``, so ``encode(text)``
            behaves as before. Pass a table explicitly (e.g. one reloaded from
            JSON) to avoid depending on notebook state. An empty dict is
            honoured and yields raw bytes.

    Returns:
        list[int]: the token ids.
    """
    if merges_table is None:
        merges_table = merges

    # Local copies of the training helpers keep encode() independent of the
    # earlier cell that defines them.
    def get_stats(ids):
        """Count adjacent (left, right) pairs in ids, in first-seen order."""
        counts = {}
        for pair in zip(ids, ids[1:]):
            counts[pair] = counts.get(pair, 0) + 1
        return counts

    def merge(ids, pair, idx):
        """Replace non-overlapping left-to-right occurrences of pair with idx."""
        newids = []
        i = 0
        while i < len(ids):
            if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
                newids.append(idx)
                i += 2
            else:
                newids.append(ids[i])
                i += 1
        return newids

    tokens = list(text.encode("utf-8"))
    while len(tokens) >= 2:
        stats = get_stats(tokens)
        # Rank present pairs by merge id; pairs without a rule rank last (inf).
        pair = min(stats, key=lambda p: merges_table.get(p, float("inf")))
        if pair not in merges_table:
            break  # no adjacent pair is mergeable -> encoding is final
        tokens = merge(tokens, pair, merges_table[pair])
    return tokens

In [12]:
# Encode a sample Hindi sentence with the trained merges and show its ids.
sample_text = "साईं इतना दीजिये, जा में कुटुम समाय ।मैं भी भूखा न रहूँ, साधु ना भूखा जाय ॥ 8 ॥"
t1 = encode(sample_text)
print(t1)

[300, 500, 679, 135, 290, 272, 704, 590, 337, 263, 848, 416, 390, 149, 469, 402, 755, 353, 32, 273, 295, 433, 173, 542, 294, 150, 420, 589, 662, 680, 758, 322, 642, 294, 150, 580, 353, 32, 1738, 56, 32, 264]


In [13]:
# Decode the sample ids back to text. Because t1 came from encode() on valid
# text, re-encoding the decoded string must reproduce exactly the same ids;
# assert it so a broken vocab/merges table fails loudly instead of silently.
decoded_t1 = decode(t1)
print(decoded_t1)
assert encode(decoded_t1) == t1, "BPE round trip failed: encode(decode(t1)) != t1"

साईं इतना दीजिये, जा में कुटुम समाय ।मैं भी भूखा न रहूँ, साधु ना भूखा जाय ॥ 8 ॥


In [None]:
# import json
# merges_json = {f"{k[0]},{k[1]}": v for k, v in merges.items()}
# print(merges_json)

In [39]:
# with open("tsai_hindi_bpe_tokens.json",'w') as f:
#     json.dump(merges_json,f)

In [40]:
# import pickle
# with open(f"tsai_hindi_vocab.pkl", 'wb') as f:
#     pickle.dump(vocab, f)

In [2]:
# import json
# import pickle

In [3]:
# with open("tsai_hindi_bpe_tokens.json", 'r') as f:
#     merges_json = json.load(f)
# merges = {tuple(map(int, k.split(','))): v for k, v in merges_json.items()}

In [4]:
# with open("tsai_hindi_vocab.pkl", 'rb') as f:
#     vocab = pickle.load(f)

In [None]:
# t1 = encode("साईं इतना दीजिये, जा में कुटुम समाय ।मैं भी भूखा न रहूँ, साधु ना भूखा जाय ॥ 8 ॥")
# print(t1)

In [None]:
# print(decode(t1))