In [341]:
from typing import List,Dict,Tuple

In [342]:
# Toy training corpus: four distinct words, each repeated a fixed number of
# times and separated by single spaces.
text = " ".join(
    ["low"] * 5 + ["lower"] * 2 + ["widest"] * 3 + ["newest"] * 6
)

In [343]:
def get_count(tokens: List[str], tokenized: bool = False):
    """Build a frequency table for the items in ``tokens``.

    Args:
        tokens: The strings to count.
        tokenized: When true, each string is encoded to bytes and keyed by a
            tuple of its individual one-byte ``bytes`` objects; otherwise the
            string itself is the key.

    Returns:
        A dict mapping each key to the number of times it was seen, in
        first-seen order.
    """
    counts = {}
    for word in tokens:
        if tokenized:
            raw = word.encode()
            # Slicing (rather than indexing) keeps every piece a ``bytes``
            # object instead of an ``int``.
            key = tuple(raw[pos : pos + 1] for pos in range(len(raw)))
        else:
            key = word
        counts[key] = counts.get(key, 0) + 1
    return counts

In [None]:
class BPETokenizer:
    """Byte-level byte-pair-encoding (BPE) vocabulary trainer.

    The vocabulary starts with the ``<|endoftext|>`` marker at id 0 followed
    by the 256 single-byte tokens at ids 1..256. ``fit`` appends one new
    token per merge step.
    """

    def __init__(self):
        # token bytes -> id, and id -> token bytes (the list index is the id).
        self.token_to_id: Dict[bytes, int] = {"<|endoftext|>".encode(): 0}
        self.id_to_token: List[bytes] = ["<|endoftext|>".encode()]

        for value in range(256):
            # bytes([value]) works on every Python 3 version, whereas
            # int.to_bytes(1) needs the default byteorder added in 3.11.
            self.add_token(bytes([value]))

    def add_token(self, token: bytes):
        """Append ``token`` to the vocabulary under the next free id."""
        self.token_to_id[token] = len(self.id_to_token)
        self.id_to_token.append(token)

    @staticmethod
    def _count_pairs(
        token_histogram: Dict[Tuple[bytes, ...], int],
    ) -> Dict[Tuple[bytes, bytes], int]:
        """Count adjacent symbol pairs, weighted by each word's frequency."""
        pair_hist: Dict[Tuple[bytes, bytes], int] = {}
        for symbols, count in token_histogram.items():
            for pair in zip(symbols, symbols[1:]):
                pair_hist[pair] = pair_hist.get(pair, 0) + count
        return pair_hist

    @staticmethod
    def _merge_pair(
        symbols: Tuple[bytes, ...], pair: Tuple[bytes, bytes]
    ) -> Tuple[bytes, ...]:
        """Replace each left-to-right occurrence of ``pair`` with one symbol."""
        merged: List[bytes] = []
        pos = 0
        size = len(symbols)
        while pos < size:
            # Match the exact pair, not just neighbours whose concatenation
            # happens to spell the same bytes.
            if pos + 1 < size and (symbols[pos], symbols[pos + 1]) == pair:
                merged.append(symbols[pos] + symbols[pos + 1])
                pos += 2
            else:
                merged.append(symbols[pos])
                pos += 1
        return tuple(merged)

    def fit(self, text: str, verbose: bool = False, merges: int = 1):
        """Learn up to ``merges`` BPE merges from ``text``.

        The text is split on single spaces; each word is broken into
        single-byte symbols by ``get_count``. On every step the most frequent
        adjacent pair is merged into a new token and added to the vocabulary.
        Ties are resolved in favour of the lexicographically greatest pair.
        Training stops early when no adjacent pair is left to merge.

        Args:
            text: Training corpus.
            verbose: Print the word histogram and each new token.
            merges: Maximum number of merge steps to perform.
        """
        tokens = text.split(" ")
        token_histogram: Dict[Tuple[bytes, ...], int] = get_count(
            tokens, tokenized=True
        )
        for _ in range(merges):
            pair_hist = self._count_pairs(token_histogram)
            if not pair_hist:
                # Every word is already at most one symbol long.
                break
            if verbose:
                print("token hist: ", token_histogram)

            best_pair = max(
                pair_hist.items(), key=lambda item: (item[1], item[0])
            )[0]
            new_token = best_pair[0] + best_pair[1]
            self.add_token(new_token)
            if verbose:
                print("new_token: ", new_token)

            # Apply the merge to every word; the pair counts are rebuilt from
            # this updated histogram on the next iteration.
            new_token_histogram: Dict[Tuple[bytes, ...], int] = {}
            for symbols, count in token_histogram.items():
                merged = self._merge_pair(symbols, best_pair)
                new_token_histogram[merged] = (
                    new_token_histogram.get(merged, 0) + count
                )
            token_histogram = new_token_histogram
        if verbose:
            print("token hist: ", token_histogram)

    def tranform(self):
        """Placeholder; not implemented yet (name kept as originally spelled)."""
        pass

In [345]:
# Fresh tokenizer whose vocabulary holds only the end-of-text marker and the
# 256 single-byte tokens.
tokenizer = BPETokenizer()

In [346]:
# Train six merges on the toy corpus, printing the histogram and each new
# token along the way.
tokenizer.fit(
    text,
    verbose=True,
    merges=6,
)

token hist:  {(b'l', b'o', b'w'): 5, (b'l', b'o', b'w', b'e', b'r'): 2, (b'w', b'i', b'd', b'e', b's', b't'): 3, (b'n', b'e', b'w', b'e', b's', b't'): 6}
new_token:  b'st'
3 3 (b'l', b'o', b'w') [b'l', b'o']
5 5 (b'l', b'o', b'w', b'e', b'r') [b'l', b'o', b'w', b'e']
7 6 (b'w', b'i', b'd', b'e', b's', b't') [b'w', b'i', b'd', b'e', b'st']
7 6 (b'n', b'e', b'w', b'e', b's', b't') [b'n', b'e', b'w', b'e', b'st']
token hist:  {(b'l', b'o', b'w'): 5, (b'l', b'o', b'w', b'e', b'r'): 2, (b'w', b'i', b'd', b'e', b'st'): 3, (b'n', b'e', b'w', b'e', b'st'): 6}
new_token:  b'est'
3 3 (b'l', b'o', b'w') [b'l', b'o']
5 5 (b'l', b'o', b'w', b'e', b'r') [b'l', b'o', b'w', b'e']
6 5 (b'w', b'i', b'd', b'e', b'st') [b'w', b'i', b'd', b'est']
6 5 (b'n', b'e', b'w', b'e', b'st') [b'n', b'e', b'w', b'est']
token hist:  {(b'l', b'o', b'w'): 5, (b'l', b'o', b'w', b'e', b'r'): 2, (b'w', b'i', b'd', b'est'): 3, (b'n', b'e', b'w', b'est'): 6}
new_token:  b'ow'
4 3 (b'l', b'o', b'w') [b'l', b'ow']
5 5 (b'l', b

In [347]:
# Split the sample corpus on single spaces into a list of words.
tokens = text.split(" ")

In [348]:
# Word-level frequency table: tokenized defaults to False, so the keys are the
# raw strings rather than byte tuples.
token_histogram = get_count(tokens)
token_histogram  # bare expression so the notebook displays the result

{'low': 5, 'lower': 2, 'widest': 3, 'newest': 6}