In [44]:
import regex as re

In [18]:
# Sample passage used as the training text for the byte-level tokenizer.
text = (
    "Predictive Processing is a process theory for the neocortex. It posits that the brain makes "
    "use of generative model in order to solve the inverse problem of perception, that is, inferring "
    "the causes of sensory observations"
)

# Iterating over a bytes object yields one integer (0..255) per UTF-8 byte.
tokens = list(text.encode("utf-8"))
print(text)
length = len(text)  # number of characters, not bytes
print(f"length: {length}")
print(tokens)
print(f"length: {len(tokens)}")
max_value = max(tokens)  # largest byte value present in the text
print(f"max value: {max_value}")

Predictive Processing is a process theory for the neocortex. It posits that the brain makes use of generative model in order to solve the inverse problem of perception, that is, inferring the causes of sensory observations
length: 222
[80, 114, 101, 100, 105, 99, 116, 105, 118, 101, 32, 80, 114, 111, 99, 101, 115, 115, 105, 110, 103, 32, 105, 115, 32, 97, 32, 112, 114, 111, 99, 101, 115, 115, 32, 116, 104, 101, 111, 114, 121, 32, 102, 111, 114, 32, 116, 104, 101, 32, 110, 101, 111, 99, 111, 114, 116, 101, 120, 46, 32, 73, 116, 32, 112, 111, 115, 105, 116, 115, 32, 116, 104, 97, 116, 32, 116, 104, 101, 32, 98, 114, 97, 105, 110, 32, 109, 97, 107, 101, 115, 32, 117, 115, 101, 32, 111, 102, 32, 103, 101, 110, 101, 114, 97, 116, 105, 118, 101, 32, 109, 111, 100, 101, 108, 32, 105, 110, 32, 111, 114, 100, 101, 114, 32, 116, 111, 32, 115, 111, 108, 118, 101, 32, 116, 104, 101, 32, 105, 110, 118, 101, 114, 115, 101, 32, 112, 114, 111, 98, 108, 101, 109, 32, 111, 102, 32, 112, 101, 114, 99, 10

In [30]:
def get_stats(ids):
    """Count how often each adjacent pair of elements occurs in `ids`.

    Args:
        ids: sequence of token ids (a list of ints in this notebook).

    Returns:
        dict mapping each `(left, right)` pair to the number of positions
        at which it occurs. Keys appear in order of first occurrence.
        Empty when `ids` has fewer than two elements.
    """
    pair_counts = {}
    for pos in range(len(ids) - 1):
        key = (ids[pos], ids[pos + 1])
        if key in pair_counts:
            pair_counts[key] += 1
        else:
            pair_counts[key] = 1
    return pair_counts

# Pair frequencies over the raw byte sequence.
stats = get_stats(tokens)

# Most frequent adjacent pair; on ties max() keeps the first one encountered.
top_pair = max(stats, key=lambda candidate: stats[candidate])

def merge(ids, pair, idx):
    """Replace every non-overlapping occurrence of `pair` in `ids` with `idx`.

    The scan runs left to right; once a pair is matched both of its elements
    are consumed, so in `[a, a, a]` with pair `(a, a)` only the first two
    elements are merged.

    Args:
        ids: sequence of token ids.
        pair: 2-tuple `(left, right)` to look for.
        idx: token id that replaces each matched pair.

    Returns:
        A new list; `ids` itself is not modified.
    """
    merged = []
    last = len(ids) - 1
    pos = 0
    while pos <= last:
        # A pair can only start at a position that has a right-hand neighbour.
        if pos < last and (ids[pos], ids[pos + 1]) == pair:
            merged.append(idx)
            pos += 2
            continue
        merged.append(ids[pos])
        pos += 1

    return merged

In [40]:
vocab_size = 276 # the desired final vocab size
num_merges = vocab_size - 256  # ids 0..255 are taken by the raw byte values
ids = list(tokens)  # copy, so the original byte sequence is left untouched

# Learn the merges: repeatedly replace the most frequent adjacent pair with a
# fresh token id. `merges` maps (left, right) -> new id, in the order learned.
merges = {}
for i in range(num_merges):
    stats = get_stats(ids)
    if not stats:
        # Fewer than two tokens remain, so there is nothing left to merge;
        # calling max() on an empty dict would raise ValueError.
        break
    top_pair = max(stats, key=stats.get)
    idx = 256 + i
    ids = merge(ids, top_pair, idx)
    merges[top_pair] = idx

# Base vocabulary: each id 0..255 maps to the single byte with that value.
vocab = {byte_value: bytes((byte_value,)) for byte_value in range(256)}
# Each merged token is the concatenation of its two parts. This relies on
# `merges` iterating in insertion order, so both parts are already in `vocab`.
for (p0, p1), idx in merges.items():
    vocab[idx] = b"".join((vocab[p0], vocab[p1]))

def decode(ids):
    """Turn a sequence of token ids back into a string.

    Each id is looked up in the module-level `vocab`, the resulting byte
    strings are concatenated, and the bytes are decoded as UTF-8. Byte
    sequences that are not valid UTF-8 are replaced rather than raising
    (`errors="replace"`). An id missing from `vocab` raises KeyError.
    """
    pieces = [vocab[token_id] for token_id in ids]
    return b"".join(pieces).decode("utf-8", errors="replace")

def encode(text):
    """Turn a string into a list of token ids.

    The text is first split into its UTF-8 bytes; then every pair in the
    module-level `merges` is applied with `merge`, in the order `merges`
    iterates (its insertion order).
    """
    token_ids = list(text.encode("utf-8"))
    for pair in merges:
        token_ids = merge(token_ids, pair, merges[pair])
    return token_ids

# Round-trip check: encoding a string and decoding the result should print the original text.
print(decode(encode("Hello world, my name is Joe and I am a neuroscientist!")))

Hello world, my name is Joe and I am a neuroscientist!


In [45]:
# Pre-tokenization pattern (named after GPT-2). It splits text into chunks:
# common English contraction suffixes, runs of letters, runs of digits, runs
# of other non-space characters (each optionally preceded by one space), and
# runs of whitespace. \p{L} / \p{N} require the third-party `regex` module,
# which this notebook imports as `re`.
gpt2pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
print(gpt2pat.findall("Hello world"))

['Hello', ' world']


In [42]:
# Print the text each learned token stands for, in the order the merges were learned.
for pair in merges:
    idx = merges[pair]
    print(decode([idx]))


e 
 t
in
 th
er
or
ti
es
ve 
ro
 p
 the 
of
tive 
roc
roces
rocess
ing
 i
 is


In [22]:
# Compare the raw byte count with the token count.
# NOTE(review): assumes `ids` still holds the merged training sequence — confirm cell order.
print("tokens length: ", len(tokens))
print("ids length: ", len(ids))
# A ratio above 1 means `ids` is shorter than the byte sequence.
print(f"compression ratio: {len(tokens) / len(ids):.2f}X")

tokens length:  222
ids length:  146
compression ratio: 1.52X
