# testing how bpe encoding could work

In [16]:
import logging
import yaml
import logging.config

with open("logging.yaml", "rt") as f:
    config = yaml.safe_load(f.read())
    logging.config.dictConfig(config)

logger = logging.getLogger("llms")
#logger.setLevel(logging.INFO)
logger.info("hello")

2025-08-28 11:58:28,336 - llms - INFO - hello


In [None]:
from collections import Counter, deque

class MyBpe:
    def __init__(self):
        # Maps token_id to token_str (e.g., {11246: "some"})
        self.vocab = {}
        # Maps token_str to token_id (e.g., {"some": 11246})
        self.inverse_vocab = {}
        # Dictionary of BPE merges: {(token_id1, token_id2): merged_token_id}
        self.bpe_merges = {}

        # For the official OpenAI GPT-2 merges, use a rank dict:
        #  of form {(string_A, string_B): rank}, where lower rank = higher priority
        self.bpe_ranks = {}

    def train(self, text, vocab_size, allowed_special={"<|endoftext|>"}):
        processed_text = []
        for i, char in enumerate(text):
            if char == " " and i != 0:
                processed_text.append("G")
            if char != " ":
                processed_text.append(char)
        processed_text = "".join(processed_text)
    
        print(f"processed_text: {processed_text}")

        unique_chars = [chr(i) for i in range(256)]

        unique_chars.extend(
            char for char in sorted(set(processed_text))
            if char not in unique_chars
        )
        if "Ġ" not in unique_chars:
            unique_chars.append("Ġ")

        self.vocab = {i: char for i, char in enumerate(unique_chars)}
        self.inverse_vocab = {char: i for i, char in self.vocab.items()}

        print(f"vocab: {list(self.vocab.items())[:3]}")
        print(f"inverse_vocab: {list(self.inverse_vocab.items())[:3]}")

        token_ids = [self.inverse_vocab[char] for char in processed_text]

        print(f"token_ids: {token_ids}")
        print(f"decoded tokens: {[self.vocab[idx] for idx in token_ids]}")

        # BPE steps 1-3: Repeatedly find and replace frequent pairs
        for new_id in range(len(self.vocab), vocab_size):
            pair_id = self.find_freq_pair(token_ids, mode="most")
            print(f"pair_ids: {pair_id}")
            if pair_id is None:
                break
            token_ids = self.replace_pair(token_ids, pair_id, new_id)
            self.bpe_merges[pair_id] = new_id

    @staticmethod
    def find_freq_pair(token_ids, mode="most"):
        pairs = Counter(zip(token_ids, token_ids[1:]))
        print(f"pairs: {pairs}")
        if not pairs:
            return None

        if mode == "most":
            return max(pairs.items(), key=lambda x: x[1])[0]
        elif mode == "least":
            return min(pairs.items(), key=lambda x: x[1])[0]
        else:
            raise ValueError("Invalid mode. Choose 'most' or 'least'.")

    @staticmethod
    def replace_pair(token_ids, pair_id, new_id):
        dq = deque(token_ids)
        replaced = []

        while dq:
            current = dq.popleft()
            if dq and (current, dq[0]) == pair_id:
                replaced.append(new_id)
                # Remove the 2nd token of the pair, 1st was already removed
                dq.popleft()
            else:
                replaced.append(current)

        return replaced

In [50]:
mpbe = MyBpe()
mpbe.train("the fox jumped over the fenced", 258)

processed_text: theGfoxGjumpedGoverGtheGfenced
vocab: [(0, '\x00'), (1, '\x01'), (2, '\x02')]
inverse_vocab: [('\x00', 0), ('\x01', 1), ('\x02', 2)]
token_ids: [116, 104, 101, 71, 102, 111, 120, 71, 106, 117, 109, 112, 101, 100, 71, 111, 118, 101, 114, 71, 116, 104, 101, 71, 102, 101, 110, 99, 101, 100]
decoded tokens: ['t', 'h', 'e', 'G']
pairs: Counter({(116, 104): 2, (104, 101): 2, (101, 71): 2, (71, 102): 2, (101, 100): 2, (102, 111): 1, (111, 120): 1, (120, 71): 1, (71, 106): 1, (106, 117): 1, (117, 109): 1, (109, 112): 1, (112, 101): 1, (100, 71): 1, (71, 111): 1, (111, 118): 1, (118, 101): 1, (101, 114): 1, (114, 71): 1, (71, 116): 1, (102, 101): 1, (101, 110): 1, (110, 99): 1, (99, 101): 1})
pair_ids: (116, 104)
