In [54]:
# my version of byte-pair encoding

from collections import defaultdict

def pf(l, return_d=False):
    """Count adjacent (left, right) pairs in ``l`` and pick the most frequent.

    Args:
        l: sliceable sequence of tokens; pairs are taken from ``zip(l, l[1:])``.
        return_d: if True, return the full pair-count table instead of a pair.

    Returns:
        With ``return_d`` set, a ``defaultdict(int)`` mapping pair -> count
        (empty when ``l`` has fewer than two elements). Otherwise the most
        frequent pair if it occurs more than once, else None. Ties go to the
        pair seen first, because ``max`` returns the first maximal key in
        insertion order.
    """
    d = defaultdict(int)
    for pair in zip(l, l[1:]):
        d[pair] += 1
    if return_d:
        return d
    # fewer than two elements -> no pairs; max() of an empty table would raise
    if not d:
        return None
    pair = max(d, key=d.get)
    return pair if d[pair] > 1 else None

def replace(l, pair, new_token):
    """Return a new list where each occurrence of ``pair`` in ``l`` becomes ``new_token``.

    The scan runs left to right and matches do not overlap: after a merge
    the scan resumes past both merged elements. ``l`` is not modified.
    """
    merged = []
    pos, size = 0, len(l)
    while pos < size:
        has_next = pos + 1 < size
        if has_next and (l[pos], l[pos + 1]) == pair:
            merged.append(new_token)
            pos += 2
        else:
            merged.append(l[pos])
            pos += 1
    return merged

def replace_back(code, token, pair):
    """Undo one merge: expand every occurrence of ``token`` in ``code`` into the items of ``pair``.

    Returns a new list; ``code`` itself is left untouched.
    """
    return [
        piece
        for tok in code
        for piece in (pair if tok == token else (tok,))
    ]

# byte-pair encoding
def bpe(text, max_size=300):
    """Learn byte-pair-encoding merges from ``text``.

    Starts from the UTF-8 bytes of ``text`` (token ids 0-255) and repeatedly
    merges the most frequent adjacent pair into a new id (256, 257, ...).
    Stops when no pair occurs more than once or ``max_size`` is reached.

    Args:
        text: training string.
        max_size: largest token id that may be assigned, so at most
            ``max_size - 255`` merges are learned.

    Returns:
        ``(code, curr_size, vocab)``:
        - ``code`` is ``text`` encoded with the learned merges.
        - ``curr_size`` is the highest token id assigned (255 when nothing
          was merged); it is an id, not a count.
        - ``vocab`` is a ``defaultdict(int)`` mapping each merged pair to its
          token id. Ids increase in merge order.
    """
    vocab = defaultdict(int)
    # convert text to its utf-8 byte values; these are token ids 0-255
    code = list(text.encode('utf-8'))
    curr_size = 255
    # with fewer than two bytes there is nothing to pair up; check here so
    # this does not rely on pf coping with an input that has no pairs
    pair = pf(code) if len(code) > 1 else None
    while pair is not None and curr_size < max_size:
        curr_size += 1
        vocab[pair] = curr_size
        code = replace(code, pair, curr_size)
        pair = pf(code)
    return code, curr_size, vocab

def encode(text, vocab):
    """Encode ``text`` into token ids by replaying the merges in ``vocab``.

    Args:
        text: string to encode (it is converted to UTF-8 bytes first).
        vocab: mapping (left, right) -> merged token id, as returned by ``bpe``.

    Returns:
        List of token ids; ids <= 255 are raw UTF-8 byte values.
    """
    l = list(text.encode('utf-8'))
    while True:
        # adjacent pairs currently present that have a learned merge
        candidates = {p for p in zip(l, l[1:]) if p in vocab}
        if not candidates:
            return l
        # bpe hands out increasing ids in merge order, so the smallest id is
        # the earliest-learned merge; replaying merges in that order (rather
        # than arbitrary set order) reproduces the training tokenization
        pair = min(candidates, key=vocab.get)
        l = replace(l, pair, vocab[pair])

def decode(code, vocab):
    """Turn token ids back into text by expanding merged tokens into bytes.

    Args:
        code: list of token ids; ids <= 255 are raw byte values.
        vocab: mapping (left, right) -> merged token id, as returned by ``bpe``.

    Returns:
        The decoded string. Invalid UTF-8 sequences become U+FFFD
        (``errors='replace'``). An empty ``code`` yields ``''``.

    Raises:
        KeyError: if ``code`` contains an id > 255 that is not a value in ``vocab``.
    """
    # inverse vocab: merged token id -> the (left, right) pair it replaced
    inverse = {token: pair for pair, token in vocab.items()}
    out = bytearray()
    for tok in code:
        # expand depth-first with an explicit stack (no recursion limit);
        # push right before left so the left half is emitted first
        stack = [tok]
        while stack:
            t = stack.pop()
            if t > 255:
                left, right = inverse[t]
                stack.append(right)
                stack.append(left)
            else:
                out.append(t)
    return out.decode('utf-8', errors='replace')


In [55]:
# read the training corpus; decode it explicitly as UTF-8 so the result does
# not depend on the platform's default locale encoding
with open('text.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# learn merges; the middle value returned by bpe is the highest token id assigned
code, vocab_size, vocab = bpe(text)

In [56]:
line = 'Renat is a cool guy with a lot of ice in his bucket.'
code = encode(line, vocab)
# compute the byte length once; nesting the same quote character inside an
# f-string replacement field is a SyntaxError before Python 3.12
original_len = len(line.encode('utf-8'))
print(f'Original length = {original_len}')
print(f'Encoded length = {len(code)}')
print(f'Compression ratio = {original_len / len(code)}')


Original length = 52
Encoded length = 38
Compression ratio = 1.368421052631579


In [57]:
# round-trip check: decoding the encoded tokens must reproduce the input line
decoded = decode(code, vocab)
assert decoded == line, 'BPE round trip failed'
decoded

'Renat is a cool guy with a lot of ice in his bucket.'