In [3]:
import random

class Tokenization:
    ''' Byte-level tokenizer whose vocabulary is stored as a prefix tree.

    Tokens are ``bytes`` strings. Ids are assigned in insertion order and id 0
    is reserved for the empty string ``b''``. ``add`` always inserts every
    prefix of a new token first, so every node of ``tree`` is itself a token
    and any walk down the tree can be mapped back to a token id.

    Attributes:
        rng: private ``random.Random`` driving the stochastic encoder.
        tree: nested dicts keyed by byte values (ints); after adding
            b'a', b'b', b'ab' it is ``{97: {98: {}}, 98: {}}``.
        encode_map: token bytes -> token id.
        decode_map: token id -> token bytes.
    '''

    def __init__(self, seed=None):
        self.rng = random.Random(seed)
        self.tree = {}
        self.encode_map = {b'': 0}
        self.decode_map = {0: b''}

    def __len__(self):
        ''' Number of tokens in the vocabulary, including the empty token. '''
        return len(self.encode_map)

    def __contains__(self, text):
        ''' True if ``text`` is already a token. '''
        return text in self.encode_map

    def add(self, text: bytes):
        ''' Add ``text`` and every non-empty prefix of it as tokens.

        Missing prefixes receive ids shortest-first, before ``text`` itself.
        Strings that are already tokens keep their existing ids, so adding an
        existing token is a no-op.

        Raises AssertionError if ``text`` is empty.
        '''
        assert text, repr(text)
        # Accept any bytes-like input; dict keys must be hashable bytes
        text = bytes(text)
        # Iterate instead of recursing so long strings cannot hit the
        # recursion limit; shortest prefix first keeps the id order stable.
        for end in range(1, len(text) + 1):
            prefix = text[:end]
            if prefix not in self.encode_map:
                token = len(self.encode_map)
                self.encode_map[prefix] = token
                self.decode_map[token] = prefix
        # Walk the trie along text, creating any missing nodes
        current = self.tree
        for byte in text:
            current = current.setdefault(byte, {})

    def encode_step(self, text: bytes, compression: float=0.9):
        ''' Encode a single step, returning the next token and remaining text.

        The first byte of ``text`` is always consumed, so every step makes
        progress. Each further byte is taken only if the longer string is still
        a token and ``self.rng.random() < compression``. With
        ``compression=1.0`` this is a greedy longest match; with
        ``compression=0.0`` it emits one single-byte token per step.

        Returns ``(token_id, remaining_text)``; for empty ``text`` returns
        ``(0, text)``.

        Raises ValueError if the first byte of ``text`` is not a token.
        '''
        assert 0.0 <= compression <= 1.0, repr(compression)
        if not text:
            return 0, text
        first = text[0]
        if first not in self.tree:
            raise ValueError(f'byte {bytes([first])!r} is not in the vocabulary')
        current = self.tree[first]
        length = 1
        # Extend the match while the tree allows it and the coin flip says so
        while (length < len(text) and text[length] in current
               and self.rng.random() < compression):
            current = current[text[length]]
            length += 1
        return self.encode_map[bytes(text[:length])], text[length:]

    def encode(self, text: bytes, compression: float=0.9):
        ''' Encode a string, returning a list of token ids.

        See ``encode_step`` for the meaning of ``compression``. The result is
        random unless ``compression`` is 0.0 or 1.0, but it always satisfies
        ``decode(encode(text)) == text``.
        '''
        assert 0.0 <= compression <= 1.0, repr(compression)
        tokens = []
        while text:
            token, text = self.encode_step(text, compression=compression)
            tokens.append(token)
        return tokens

    def decode(self, tokens: list):
        ''' Decode a list of token ids, returning the original bytes. '''
        return b''.join(self.decode_map[t] for t in tokens)


# Build a tiny vocabulary and show the resulting prefix tree
tok = Tokenization()
for piece in (b'a', b'b', b'ab'):
    tok.add(piece)
tok.tree

{97: {98: {}}, 98: {}}

In [9]:
from infoseq.token import Tokenization
from infoseq.bpe import get_pairs, bpe_from_text


# Grow the vocabulary from `text`: repeatedly ask get_pairs for candidate codes
# and add the shortest prefix of the first candidate that is not yet a token.
# NOTE: `Tokenization` here is the infoseq class imported above, which shadows
# the class defined in the earlier cell.
text = b'aaabdaaabac'
# basic(0) starts at 257 tokens (the recorded run below adds exactly 3),
# so this stops after 3 new tokens.
max_tokens = 257 + 3
compression = 1.0
seq_len = len(text)
num_seq = 1
seed = 0
verbose = True  # print the search trace shown in the recorded output

tok = Tokenization.basic(0)
while len(tok) < max_tokens:
    # Get pairs until we get something new
    for code in get_pairs(text, tok, compression=compression, seq_len=seq_len,
                          num_seq=num_seq, seed=seed):
        if verbose:
            print('code', code, code in tok)
        if code not in tok:
            break
    else:
        # Explicit raise rather than `assert False`: asserts are stripped under
        # `python -O`, which would leave this while-loop spinning forever.
        raise RuntimeError("Could not find new code")
    # Add the smallest prefix of `code` that is not yet a token
    for i in range(1, len(code) + 1):
        prefix = code[:i]
        if verbose:
            print('testing', prefix, prefix in tok)
        if prefix not in tok:
            if verbose:
                print('adding', prefix)
            tok.add(prefix)
            assert prefix in tok
            break
    else:
        raise RuntimeError(f"{code} ({len(code)}) : {code in tok}")


code b'aa' False
testing b'a' True
testing b'aa' False
adding b'aa'
code b'aaa' False
testing b'a' True
testing b'aa' True
testing b'aaa' False
adding b'aaa'
code b'aaab' False
testing b'a' True
testing b'aa' True
testing b'aaa' True
testing b'aaab' False
adding b'aaab'
