In [1]:
# Read the whole corpus file into memory as one string (decoded as UTF-8).
with open("dataset/hindi_data.txt", 'r', encoding='utf-8') as corpus_file:
    hindi_text = corpus_file.read()

In [2]:
# UTF-8 encode the corpus text; the result is a bytes object.
tokens_hindi = bytes(hindi_text, 'utf-8')

In [3]:
# Convert every element to int and collect the values into a plain list.
tokens_hindi = [int(value) for value in tokens_hindi]

In [4]:
# Show the sequence length, the number of distinct values, and a 10-item preview.
total_count = len(tokens_hindi)
distinct_count = len(set(tokens_hindi))
print(total_count, distinct_count)
print(tokens_hindi[:10])

49513 77
[49, 46, 32, 224, 164, 182, 224, 164, 172, 224]


In [5]:
def get_stats(ids):
    """Count how often each adjacent pair of items occurs in ``ids``.

    Args:
        ids: A sliceable sequence (e.g. a list of ints or a bytes object).

    Returns:
        A dict mapping ``(left, right)`` tuples to their occurrence count.
        Keys appear in the order each pair is first seen.
    """
    pair_counts = {}
    for neighbours in zip(ids, ids[1:]):
        if neighbours in pair_counts:
            pair_counts[neighbours] += 1
        else:
            pair_counts[neighbours] = 1
    return pair_counts

def merge(ids, pair, idx):
    """Return a copy of ``ids`` with each occurrence of ``pair`` replaced by ``idx``.

    The scan runs left to right and matches do not overlap: after a match the
    scan resumes behind the two consumed items.

    Args:
        ids: Indexable sequence of items to scan.
        pair: Two-item sequence ``(left, right)`` to search for.
        idx: Value appended in place of every matched pair.

    Returns:
        A new list; ``ids`` itself is not modified.
    """
    merged = []
    pos = 0
    total = len(ids)
    while pos < total:
        # A match needs a right-hand neighbour, so the last item can never start one.
        if pos + 1 < total and ids[pos] == pair[0] and ids[pos + 1] == pair[1]:
            merged.append(idx)
            pos += 2
            continue
        merged.append(ids[pos])
        pos += 1
    return merged

In [6]:
# Train the byte-level BPE merge table on the corpus.
tokens = hindi_text.encode("utf-8")
vocab_size  = 2000
num_merges = vocab_size - 256  # new token ids start at 256, right after the raw byte values
ids = list(tokens)
print(len(ids))
merges = {}  # (left_id, right_id) -> new token id, in the order the merges are created
for i in range(num_merges):
    stats = get_stats(ids)
    if not stats:
        # Fewer than two ids remain, so there is no pair left to merge.
        # Without this guard max() raises ValueError on the empty dict.
        break
    # Most frequent adjacent pair; ties go to the pair seen first.
    pair = max(stats, key = stats.get)
    idx = 256 + i
    ids = merge(ids, pair, idx)
    merges[pair] = idx

49513


In [7]:
# Compare the raw byte count with the number of ids left after merging.
original_len = len(tokens)
compressed_len = len(ids)
print("Original token length:", original_len)
print("BPE ids length:", compressed_len)
print(f"Compression Ratio: {original_len/ compressed_len:.2f}X")

Original token length: 49513
BPE ids length: 5803
Compression Ratio: 8.53X


In [8]:
# Map every token id to the byte string it stands for.
vocab = {}
for byte_value in range(256):
    vocab[byte_value] = bytes([byte_value])
# Relies on dict insertion order: both halves of a merge must already be in
# vocab when the merged token is built.
for pair, idx in merges.items():
    left, right = pair
    vocab[idx] = vocab[left] + vocab[right]

In [9]:
def decode(ids):
    """Turn a sequence of token ids back into text.

    Each id is looked up in the module-level ``vocab`` mapping (id -> bytes),
    the byte strings are concatenated, and the result is decoded as UTF-8.
    Byte sequences that are not valid UTF-8 are replaced with U+FFFD instead
    of raising.

    Args:
        ids: Iterable of token ids that are keys of ``vocab``.

    Returns:
        The decoded string.

    Raises:
        KeyError: If an id is not present in ``vocab``.
    """
    pieces = [vocab[token_id] for token_id in ids]
    raw = b"".join(pieces)
    return raw.decode("utf-8", errors='replace')

In [10]:
def encode(text):
    """Encode ``text`` into token ids using the module-level ``merges`` table.

    The text is UTF-8 encoded into byte values. Then, repeatedly, the adjacent
    pair with the lowest merge id in ``merges`` is replaced by that id, until
    no adjacent pair has an entry in ``merges``.

    Args:
        text: The string to tokenize.

    Returns:
        A list of integer token ids (empty for an empty string).
    """

    def get_stats(ids):
        # Count every adjacent (left, right) pair in ids.
        tally = {}
        for neighbours in zip(ids, ids[1:]):
            if neighbours in tally:
                tally[neighbours] += 1
            else:
                tally[neighbours] = 1
        return tally

    def merge(ids, pair, idx):
        # Replace each non-overlapping, left-to-right occurrence of pair with idx.
        out = []
        pos = 0
        total = len(ids)
        while pos < total:
            if pos + 1 < total and ids[pos] == pair[0] and ids[pos + 1] == pair[1]:
                out.append(idx)
                pos += 2
                continue
            out.append(ids[pos])
            pos += 1
        return out

    def rank(candidate):
        # Pairs with no entry in merges rank last.
        return merges.get(candidate, float("inf"))

    tokens = list(text.encode("utf-8"))
    while len(tokens) >= 2:
        pair = min(get_stats(tokens), key=rank)
        if pair not in merges:
            # Even the best-ranked pair has no merge, so nothing more applies.
            break
        tokens = merge(tokens, pair, merges[pair])
    return tokens

In [11]:
# Encode a sample sentence with the trained tokenizer and show the resulting ids.
t1 = encode("साईं इतना दीजिये, जा में कुटुम समाय ।मैं भी भूखा न रहूँ, साधु ना भूखा जाय ॥ 8 ॥")
print(t1)

[300, 500, 679, 135, 290, 272, 704, 590, 337, 263, 848, 416, 390, 149, 469, 402, 755, 353, 32, 273, 295, 433, 173, 542, 294, 150, 420, 589, 662, 680, 758, 322, 642, 294, 150, 580, 353, 32, 1738, 56, 32, 264]


In [12]:
# Decode the ids back to text so the round trip can be checked by eye.
print(decode(t1))

साईं इतना दीजिये, जा में कुटुम समाय ।मैं भी भूखा न रहूँ, साधु ना भूखा जाय ॥ 8 ॥


In [14]:
import json
# JSON object keys must be strings, so each (left, right) pair becomes "left,right".
merges_json = {}
for (left, right), token_id in merges.items():
    merges_json[f"{left},{right}"] = token_id
print(merges_json)

{'224,164': 256, '224,165': 257, '32,256': 258, '256,190': 259, '256,176': 260, '257,128': 261, '259,256': 262, '257,135': 263, '257,165': 264, '257,139': 265, '256,191': 266, '10,256': 267, '130,256': 268, '257,129': 269, '257,141': 270, '261,258': 271, '256,168': 272, '257,164': 273, '263,258': 274, '259,258': 275, '266,256': 276, '260,256': 277, '264,267': 278, '260,258': 279, '256,185': 280, '130,258': 281, '256,149': 282, '256,175': 283, '269,256': 284, '265,258': 285, '257,136': 286, '258,149': 287, '259,260': 288, '264,257': 289, '256,164': 290, '256,178': 291, '256,268': 292, '273,267': 293, '257,268': 294, '256,174': 295, '273,258': 296, '265,256': 297, '261,260': 298, '44,258': 299, '256,184': 300, '270,283': 301, '272,258': 302, '263,256': 303, '170,270': 304, '262,151': 305, '258,172': 306, '256,172': 307, '262,168': 308, '261,256': 309, '258,174': 310, '257,166': 311, '258,151': 312, '262,268': 313, '258,184': 314, '46,267': 315, '174,298': 316, '304,277': 317, '258,170': 

In [15]:
# Write the string-keyed merge table to disk as JSON.
with open("model_files/tsai_hindi_bpe_tokens.json",'w') as merges_file:
    json.dump(merges_json, merges_file)

In [16]:
import pickle
# Pickle the id -> bytes vocabulary (bytes values are not JSON-serialisable).
# The path has no placeholders, so it is a plain string rather than an f-string.
with open("model_files/tsai_hindi_vocab.pkl", 'wb') as f:
    pickle.dump(vocab, f)

In [17]:
# Reload the merge table from JSON and turn each "left,right" key back into a tuple of ints.
with open("model_files/tsai_hindi_bpe_tokens.json", 'r') as merges_file:
    merges_json = json.load(merges_file)
merges = {}
for key, token_id in merges_json.items():
    merges[tuple(int(part) for part in key.split(','))] = token_id

In [18]:
# Reload the pickled vocabulary.
# NOTE: pickle.load can execute arbitrary code; only load files you trust.
with open("model_files/tsai_hindi_vocab.pkl", 'rb') as vocab_file:
    vocab = pickle.load(vocab_file)

In [23]:
# Encode a longer test sentence with the reloaded merge table and show the ids.
test_text = "ट्रस्ट के सदस्य भानु प्रकाश ने बताया कि टिकट के लिए 91 काउंटर खोले गए थे। काउंटर के पास 4 हजार से ज्यादा श्रद्धालु लाइन में खड़े थे। उन्हें बैरागी पट्टीडा पार्क में कतार लगाने को कहा गया। आगे जाने की होड़ में अफरा-तफरी मची और भागने के दौरान लोग एक-दूसरे पर चढ़ गए। हादसे में मल्लिका नामक महिला की मौके पर ही मौत हो गई।"
t1 = encode(test_text)
print(t1)

[321, 270, 430, 270, 321, 287, 464, 369, 300, 301, 364, 308, 269, 258, 317, 149, 891, 322, 525, 290, 353, 327, 266, 945, 618, 321, 287, 639, 276, 143, 32, 57, 49, 287, 262, 137, 292, 159, 279, 150, 560, 539, 256, 143, 258, 165, 665, 262, 137, 292, 159, 674, 408, 373, 32, 52, 331, 346, 440, 184, 509, 301, 480, 681, 270, 277, 166, 940, 320, 269, 387, 262, 135, 554, 390, 150, 257, 156, 274, 165, 434, 137, 1109, 390, 1029, 1352, 489, 321, 270, 321, 309, 161, 449, 288, 270, 282, 742, 149, 290, 440, 1067, 308, 393, 285, 603, 1094, 283, 495, 795, 509, 308, 393, 271, 398, 257, 156, 742, 133, 1044, 260, 259, 45, 290, 1044, 329, 174, 471, 271, 148, 279, 173, 305, 272, 393, 1883, 421, 260, 308, 387, 620, 258, 545, 45, 369, 294, 1321, 408, 279, 154, 257, 157, 312, 256, 143, 296, 185, 480, 300, 499, 390, 174, 291, 270, 291, 618, 420, 338, 282, 310, 280, 410, 327, 431, 421, 282, 408, 279, 185, 431, 421, 290, 1010, 151, 383, 273]


In [24]:
# Decode the ids with the reloaded vocabulary and print the recovered text.
print(decode(t1))

ट्रस्ट के सदस्य भानु प्रकाश ने बताया कि टिकट के लिए 91 काउंटर खोले गए थे। काउंटर के पास 4 हजार से ज्यादा श्रद्धालु लाइन में खड़े थे। उन्हें बैरागी पट्टीडा पार्क में कतार लगाने को कहा गया। आगे जाने की होड़ में अफरा-तफरी मची और भागने के दौरान लोग एक-दूसरे पर चढ़ गए। हादसे में मल्लिका नामक महिला की मौके पर ही मौत हो गई।


In [25]:
# Round-trip check: prints True when decoding the ids reproduces the input exactly.
print(test_text == decode(t1))

True
