# In [None]:
from collections import Counter

class BPETokenizer:
    """Minimal byte-pair-encoding (BPE) tokenizer operating on characters.

    Training splits the corpus on whitespace, starts from single-character
    symbols and repeatedly fuses the most frequent adjacent symbol pair until
    the vocabulary reaches ``vocab_size`` (or no pairs remain).

    Attributes:
        vocab_size: Target number of distinct symbols in the vocabulary.
        merges: Learned merge rules as ``(left, right)`` symbol pairs, in the
            order they were learned.
        token_to_id: Mapping from symbol string to integer id.
        id_to_token: Inverse mapping from integer id to symbol string.
    """

    def __init__(self, vocab_size=200):
        self.vocab_size = vocab_size
        self.merges = []
        self.token_to_id = {}
        self.id_to_token = {}

    def get_bigrams(self, word):
        """Return the adjacent ``(left, right)`` symbol pairs of *word*.

        Args:
            word: A sliceable sequence of symbols (a ``str`` or a list).

        Returns:
            A list of 2-tuples; empty when *word* has fewer than two symbols.
        """
        return list(zip(word, word[1:]))

    @staticmethod
    def _merge_pair(symbols, pair):
        """Return a copy of *symbols* with each occurrence of *pair* fused.

        The scan is greedy and left-to-right, so overlapping occurrences are
        resolved in favour of the leftmost one.
        """
        left, right = pair
        fused = left + right
        last = len(symbols) - 1
        merged = []
        i = 0
        while i < len(symbols):
            if i < last and symbols[i] == left and symbols[i + 1] == right:
                merged.append(fused)
                i += 2
            else:
                merged.append(symbols[i])
                i += 1
        return merged

    def train(self, corpus_path):
        """Learn merge rules and the vocabulary from a UTF-8 text file.

        Any previously learned merges and vocabulary are replaced, so the
        same instance can be retrained safely.

        Args:
            corpus_path: Path of the text file to learn from.

        Raises:
            OSError: If the file cannot be opened.
        """
        with open(corpus_path, "r", encoding="utf-8") as f:
            text = f.read()

        # Work on distinct words weighted by frequency rather than on every
        # occurrence. Counter keeps first-seen order, so pair counts and
        # tie-breaking match a scan over the full word list.
        word_counts = Counter(text.split())
        segmented = [(list(word), freq) for word, freq in word_counts.items()]

        vocab = {char for symbols, _ in segmented for char in symbols}

        # Build into a fresh list so rules from an earlier train() call do
        # not leak into this run.
        merges = []
        while len(vocab) < self.vocab_size:
            pair_counts = Counter()
            for symbols, freq in segmented:
                for pair in self.get_bigrams(symbols):
                    pair_counts[pair] += freq

            if not pair_counts:
                # Every word is already a single symbol; nothing left to merge.
                break

            # On ties, max() keeps the pair that was seen first.
            best_pair = max(pair_counts, key=pair_counts.get)
            merges.append(best_pair)
            # The fused string may already be in the vocabulary (built from a
            # different pair); the merge is still recorded in that case.
            vocab.add("".join(best_pair))

            segmented = [
                (self._merge_pair(symbols, best_pair), freq)
                for symbols, freq in segmented
            ]

        self.merges = merges
        # Ids are assigned in sorted symbol order so they are deterministic.
        self.token_to_id = {tok: idx for idx, tok in enumerate(sorted(vocab))}
        self.id_to_token = {idx: tok for tok, idx in self.token_to_id.items()}

    def encode(self, text):
        """Convert *text* into a list of token ids.

        The text is split into characters and every learned merge is applied
        once, in the order it was learned, mirroring how training segmented
        the corpus.

        Symbols that are not in the vocabulary are dropped. Training splits
        on whitespace, so whitespace is never in the vocabulary and
        ``decode(encode(text))`` does not restore it.

        Args:
            text: The string to encode.

        Returns:
            A list of integer token ids.
        """
        tokens = list(text)
        for pair in self.merges:
            if len(tokens) < 2:
                # Nothing left that could be merged.
                break
            tokens = self._merge_pair(tokens, pair)

        return [self.token_to_id[tok] for tok in tokens if tok in self.token_to_id]

    def decode(self, token_ids):
        """Concatenate the symbols for *token_ids* into a string.

        Args:
            token_ids: Iterable of integer ids produced by ``encode``.

        Returns:
            The joined symbol strings.

        Raises:
            KeyError: If an id is not present in ``id_to_token``.
        """
        return "".join(self.id_to_token[i] for i in token_ids)