In [2]:
import os

In [3]:
def load_training_data():
    possible_paths = [
        '../data/input.txt',
        './data/input.txt',
        'data/input.txt',
        'input.txt'
    ]
    for path in possible_paths:
        if os.path.exists(path):
            print(f"Loading data from: {path}")
            with open(path, 'r', encoding='utf-8') as f:
                return f.read()
    print("No input.txt found, using example text")
    return """
    Hello world! This is a simple example text for training an LSTM model.
    The model will learn to predict the next character in a sequence.
    """

text = load_training_data()

Loading data from: ./data/input.txt


In [4]:
tokens = text.encode('utf-8')
print(text[:20])

tokens = list(map(int, tokens))
print(tokens[:20])

print("length of text: ", len(text))
print("length of tokens: ", len(tokens))

First Citizen:
Befor
[70, 105, 114, 115, 116, 32, 67, 105, 116, 105, 122, 101, 110, 58, 10, 66, 101, 102, 111, 114]
length of text:  1115394
length of tokens:  1115394


In [5]:
def get_stats(tokens):
    counts = {}
    for pair in zip(tokens, tokens[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts

counts = get_stats(tokens)
print(sorted(((v, k) for k, v in counts.items()), reverse=True))

[(27643, (101, 32)), (23837, (32, 116)), (22739, (116, 104)), (18203, (104, 101)), (16508, (116, 32)), (15364, (115, 32)), (14165, (100, 32)), (14098, (44, 32)), (13541, (32, 97)), (12730, (111, 117)), (12287, (32, 115)), (11925, (32, 104)), (11771, (101, 114)), (10786, (32, 109)), (10606, (105, 110)), (10546, (32, 119)), (10516, (114, 32)), (10283, (121, 32)), (10197, (97, 110)), (9843, (114, 101)), (9673, (104, 97)), (9306, (110, 100)), (8762, (58, 10)), (8737, (110, 32)), (8463, (32, 98)), (8458, (111, 114)), (8339, (97, 116)), (8134, (111, 32)), (7568, (101, 110)), (7526, (105, 115)), (7410, (32, 111)), (7223, (10, 10)), (7166, (32, 105)), (7081, (97, 114)), (6991, (104, 105)), (6823, (115, 116)), (6676, (46, 10)), (6563, (32, 102)), (6478, (101, 115)), (6435, (111, 110)), (6357, (108, 108)), (6288, (101, 97)), (6135, (109, 101)), (6114, (105, 116)), (5986, (118, 101)), (5860, (116, 111)), (5507, (115, 101)), (5505, (44, 10)), (5478, (32, 100)), (5410, (108, 32)), (5404, (32, 99)),

In [6]:
top_pair = max(counts, key=counts.get)
print(top_pair)

(101, 32)


In [7]:
def merge(ids, pair, idx):
    newid = []
    i = 0
    while i < len(ids):
        if i < len(ids) - 1:
            get_pair = (ids[i], ids[i+1])
            if get_pair == pair:
                newid.append(idx)
                i += 2
            else:
                newid.append(ids[i])
                i += 1
        else:
            newid.append(ids[i])
            i += 1
    return newid

print(merge([1, 2, 3, 4, 3, 4], (4, 3), 99))

[1, 2, 3, 99, 4]


In [8]:
vocab_size = 1000
num_merges = vocab_size - 256
ids = list(tokens)

merges = {} # (int, int) -> int
for i in range(num_merges):
    stats = get_stats(ids)
    pair = max(stats, key=stats.get)
    idx = 256 + i
    merges[pair] = idx
    print(f"Merging {pair} into {idx}")
    ids = merge(ids, pair, idx)


Merging (101, 32) into 256
Merging (116, 104) into 257
Merging (116, 32) into 258
Merging (115, 32) into 259
Merging (100, 32) into 260
Merging (44, 32) into 261
Merging (111, 117) into 262
Merging (101, 114) into 263
Merging (105, 110) into 264
Merging (121, 32) into 265
Merging (97, 110) into 266
Merging (58, 10) into 267
Merging (111, 114) into 268
Merging (111, 32) into 269
Merging (101, 110) into 270
Merging (10, 10) into 271
Merging (97, 114) into 272
Merging (32, 257) into 273
Merging (111, 110) into 274
Merging (108, 108) into 275
Merging (104, 97) into 276
Merging (44, 10) into 277
Merging (46, 271) into 278
Merging (105, 259) into 279
Merging (101, 115) into 280
Merging (121, 262) into 281
Merging (32, 115) into 282
Merging (116, 269) into 283
Merging (266, 260) into 284
Merging (111, 119) into 285
Merging (101, 97) into 286
Merging (32, 109) into 287
Merging (32, 119) into 288
Merging (111, 102) into 289
Merging (32, 104) into 290
Merging (264, 103) into 291
Merging (111, 10

In [9]:
print("token length: ", len(tokens))
print("ids length: ", len(ids))
print("compression ratio: ", len(tokens) / len(ids))

token length:  1115394
ids length:  447069
compression ratio:  2.494903471276246


In [10]:
vocab = {idx: bytes([idx]) for idx in range(256)}

for (p0, p1), idx in merges.items():
    pair = vocab[p0] + vocab[p1]
    vocab[idx] = pair

print(vocab)

def decode(ids):
    tokens = b"".join(vocab[idx] for idx in ids)
    text = tokens.decode('utf-8', errors = 'replace')
    return text

print(decode([260]))

{0: b'\x00', 1: b'\x01', 2: b'\x02', 3: b'\x03', 4: b'\x04', 5: b'\x05', 6: b'\x06', 7: b'\x07', 8: b'\x08', 9: b'\t', 10: b'\n', 11: b'\x0b', 12: b'\x0c', 13: b'\r', 14: b'\x0e', 15: b'\x0f', 16: b'\x10', 17: b'\x11', 18: b'\x12', 19: b'\x13', 20: b'\x14', 21: b'\x15', 22: b'\x16', 23: b'\x17', 24: b'\x18', 25: b'\x19', 26: b'\x1a', 27: b'\x1b', 28: b'\x1c', 29: b'\x1d', 30: b'\x1e', 31: b'\x1f', 32: b' ', 33: b'!', 34: b'"', 35: b'#', 36: b'$', 37: b'%', 38: b'&', 39: b"'", 40: b'(', 41: b')', 42: b'*', 43: b'+', 44: b',', 45: b'-', 46: b'.', 47: b'/', 48: b'0', 49: b'1', 50: b'2', 51: b'3', 52: b'4', 53: b'5', 54: b'6', 55: b'7', 56: b'8', 57: b'9', 58: b':', 59: b';', 60: b'<', 61: b'=', 62: b'>', 63: b'?', 64: b'@', 65: b'A', 66: b'B', 67: b'C', 68: b'D', 69: b'E', 70: b'F', 71: b'G', 72: b'H', 73: b'I', 74: b'J', 75: b'K', 76: b'L', 77: b'M', 78: b'N', 79: b'O', 80: b'P', 81: b'Q', 82: b'R', 83: b'S', 84: b'T', 85: b'U', 86: b'V', 87: b'W', 88: b'X', 89: b'Y', 90: b'Z', 91: b'[',

In [11]:
def encode(text):
    tokens = list(text.encode("utf-8"))
    while True:
        stats = get_stats(tokens)
        best = min(stats, key = lambda p: stats.get(p, float('-inf')))
        if best not in merges:
            break
        idx = merges[best]
        tokens = merge(tokens, best, idx)
    return tokens

print(encode("Hello"))

[72, 101, 108, 108, 111]
