In [11]:
import matplotlib.pyplot as plt
import numpy as np
import regex as re
import math


In [12]:
# Read the training corpus. The `with` block closes the handle even if the
# read fails, and the explicit encoding keeps decoding independent of the
# platform's locale default.
with open('text.txt', 'r', encoding='utf-8') as file:
    text = file.read()
# The base tokens are the raw UTF-8 byte values (0-255). Non-ASCII characters
# take several bytes, so there can be more tokens than characters.
tokens = list(text.encode('utf-8'))
print(f"Length of text: {len(text)}")
print(f"Length of tokens: {len(tokens)}")

Length of text: 23269
Length of tokens: 24531


In [13]:
def getbpc(bytes):
    """Count how often each adjacent pair of items occurs in a sequence.

    Args:
        bytes: Indexable sequence of items (the parameter name shadows the
            builtin ``bytes`` type only inside this function).

    Returns:
        dict mapping each ``(left, right)`` pair of neighbours to its number
        of occurrences, with keys in order of first appearance. Sequences
        with fewer than two items give an empty dict.
    """
    pair_counts = {}
    for pos in range(len(bytes) - 1):
        pair = (bytes[pos], bytes[pos + 1])
        pair_counts[pair] = pair_counts.get(pair, 0) + 1
    return pair_counts

def replace(tokens, pair, newid):
    """Return a new list in which every occurrence of ``pair`` is merged.

    The scan runs left to right and matches never overlap: once two
    neighbouring items have been replaced by ``newid``, scanning resumes
    after both of them. ``tokens`` itself is not modified.

    Args:
        tokens: Sequence of items to scan.
        pair: Two-item sequence ``(left, right)`` to look for.
        newid: Value appended in place of each matched pair.

    Returns:
        The merged sequence as a list.
    """
    last = len(tokens) - 1
    merged = []
    pos = 0
    while pos <= last:
        # The final item has no right-hand neighbour, so it is always copied.
        if pos < last and tokens[pos] == pair[0] and tokens[pos + 1] == pair[1]:
            merged.append(newid)
            pos += 2
        else:
            merged.append(tokens[pos])
            pos += 1
    return merged

In [14]:
def shrink(vs, tokens, show_stats = False):
    """Learn byte-pair merges on ``tokens`` and return the compressed sequence.

    The most frequent adjacent pair is repeatedly replaced by a fresh token
    id. Ids are handed out from 256 up to and including ``vs``, so at most
    ``vs - 255`` merges are learned. Training stops early when no adjacent
    pair is left (fewer than two tokens remain) instead of failing.

    Args:
        vs: Largest token id that may be assigned.
        tokens: Iterable of token ids; it is copied, never modified.
        show_stats: When true and at least one merge happened, draw three
            figures against the newest token id: tokens removed per merge,
            sequence length, and compression ratio.

    Returns:
        ``(compressed_tokens, merges)`` where ``compressed_tokens`` is a list
        and ``merges`` maps each merged ``(left, right)`` pair to its new id,
        in the order the merges were learned.
    """
    tokcnt = 255
    merges = {}
    tokentemp = list(tokens)
    original_len = len(tokentemp)
    # Per-merge statistics. They grow with the merges actually performed, so
    # an early stop leaves no placeholder zeros behind.
    tokcntstat, lenstat, diffstat = [], [], []
    while tokcnt < vs:
        bpc = getbpc(tokentemp)
        if not bpc:
            # Fewer than two tokens remain: there is nothing left to merge.
            break
        pair = max(bpc, key=bpc.get)
        tokcnt += 1
        merges[pair] = tokcnt
        toknew = replace(tokentemp, pair, tokcnt)
        if show_stats:
            tokcntstat.append(tokcnt)
            lenstat.append(len(toknew))
            diffstat.append(len(tokentemp) - len(toknew))
        tokentemp = toknew
    if show_stats and lenstat:
        # Tokens removed by each individual merge.
        _, ax = plt.subplots()
        ax.plot(tokcntstat, diffstat)
        ax.set(xlabel="Newest token id", ylabel="Tokens removed by merge",
               title="Savings per merge")
        # Sequence length after each merge.
        _, ax = plt.subplots()
        ax.plot(tokcntstat, lenstat)
        ax.set(xlabel="Newest token id", ylabel="Sequence length",
               title="Sequence length after each merge")
        # Every merge output holds at least one token, so this never divides by zero.
        compressionRatios = [original_len / length for length in lenstat]
        _, ax = plt.subplots()
        ax.plot(tokcntstat, compressionRatios)
        ax.set(xlabel="Newest token id", ylabel="Input length / sequence length",
               title="Compression ratio")
    return (tokentemp, merges)

In [15]:
# Train on the corpus, allowing merged token ids up to 300, without plots.
stokens, merges = shrink(vs=300, tokens=tokens, show_stats=False)

In [16]:
# Map every token id to the bytes it stands for: ids 0-255 are the single
# byte values, and each merged id is the concatenation of its two parts.
# This relies on `merges` iterating in the order the merges were inserted, so
# both parts of a pair are already in `vocab` when the pair is reached.
vocab = dict((byte_value, bytes((byte_value,))) for byte_value in range(256))
for pair, merged_id in merges.items():
    left, right = pair
    vocab[merged_id] = vocab[left] + vocab[right]

In [17]:
def decode(ids):
    """Turn an iterable of token ids back into text.

    Each id is looked up in the module-level ``vocab`` table, the resulting
    byte strings are concatenated, and the whole is decoded as UTF-8. Byte
    sequences that are not valid UTF-8 become U+FFFD replacement characters
    instead of raising. An id missing from ``vocab`` raises ``KeyError``.
    """
    pieces = [vocab[idx] for idx in ids]
    raw = b"".join(pieces)
    return raw.decode("utf-8", errors='replace')

In [18]:
def encode(text):
    """Encode ``text`` into token ids using the module-level ``merges`` table.

    The text is first split into its UTF-8 byte values. Then, as long as at
    least two tokens remain, the adjacent pair with the lowest merged id in
    ``merges`` (the earliest learned merge) is replaced by that id. Encoding
    finishes when no remaining pair appears in ``merges``.

    Args:
        text: String to encode.

    Returns:
        List of token ids; empty for an empty string.
    """
    tokens = list(text.encode("utf-8"))
    # Re-check the length every round: merging can shrink the sequence to a
    # single token, which leaves no pairs for min() to choose from.
    while len(tokens) >= 2:
        bpc = getbpc(tokens)
        # Pairs without a learned merge rank as infinity, so they are only
        # selected when no mergeable pair is left.
        pair = min(bpc, key=lambda p: merges.get(p, float("inf")))
        if pair not in merges:
            break
        tokens = replace(tokens, pair, merges[pair])
    return tokens

In [19]:
# Round-trip check: decoding the encoded text should print the input unchanged.
print(decode(ids=encode(text="Hello World")))

Hello World


In [41]:
# Pre-tokenization split pattern. Alternatives are tried left to right, so
# their order matters.
patterns = [
    r"""'s|'t|'re|'ve|'m|'ll|'d""", # Common contraction suffixes in English
    r""" ?\p{L}+""", # Optional space followed by a run of letters
    r""" ?\p{N}+""", # Optional space followed by a run of numeric characters
    r""" ?[^\s\p{L}\p{N}]+""", # Optional space followed by a run of symbols/punctuation
    r"""\s+(?!\S)""", # Whitespace run that is not followed by a non-whitespace character
    r"""\s+""", # Any remaining whitespace
]

# The scoped (?i:...) group makes every alternative case-insensitive, which
# is why an upper-case "'S" is split off like "'s".
fullpat = re.compile(r"(?i:"+r"|".join(patterns) + r")")
print(fullpat.findall("HELLO'S World123 how've      are you!!!?"))

['HELLO', "'S", ' World', '123', ' how', "'ve", '     ', ' are', ' you', '!!!?']
