<a href="https://colab.research.google.com/github/nag07799/LLM_Journey/blob/main/Bytepairencoding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import re
from collections import defaultdict, Counter

class BytePairEncoding:
    """Learn byte-pair-encoding merges from a whitespace-separated corpus.

    The vocabulary maps each distinct word, written as space-separated
    symbols and terminated by the ``</w>`` end-of-word marker, to the number
    of times that word occurs in the text. Each merge step fuses the most
    frequent adjacent symbol pair into a single symbol.

    Attributes:
        text: The raw corpus passed to the constructor.
        vocab: Plain ``dict`` from symbolized word to its frequency.
        merges: Symbol pairs merged so far, in the order they were applied.
    """

    def __init__(self, text):
        """Build the initial character-level vocabulary from *text*."""
        self.text = text
        self.vocab = self._get_vocabulary()
        self.merges = []

    def _get_vocabulary(self):
        """Create the initial vocabulary from the input text.

        Every word is split into single characters joined by spaces, with
        ``</w>`` appended to mark the end of the word.

        Returns:
            A plain ``dict`` of symbolized word -> frequency, in order of
            first appearance. A plain dict (rather than a defaultdict) is
            used so that looking up an unknown word never inserts it.
        """
        return dict(
            Counter(' '.join(word) + ' </w>' for word in self.text.split())
        )

    def _get_stats(self):
        """Get the frequency of adjacent symbol pairs in the vocabulary.

        Returns:
            A mapping of ``(left, right)`` symbol pairs to the summed
            frequency of the words they occur in.
        """
        pairs = Counter()
        for word, freq in self.vocab.items():
            symbols = word.split()
            for pair in zip(symbols, symbols[1:]):
                pairs[pair] += freq
        return pairs

    def _merge_vocab(self, pair):
        """Merge every occurrence of *pair* in the vocabulary.

        Args:
            pair: ``(left, right)`` tuple of symbols to fuse into one symbol.
        """
        merged = ''.join(pair)
        bigram = re.escape(' '.join(pair))
        # The lookarounds make sure only whole symbols are matched, never a
        # suffix/prefix of a longer neighbouring symbol.
        pattern = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
        # A callable replacement is used so the merged symbol is inserted
        # verbatim; a replacement *string* would have its backslashes
        # interpreted as escapes by ``re`` and break on text containing '\'.
        self.vocab = {
            pattern.sub(lambda _match: merged, word): freq
            for word, freq in self.vocab.items()
        }

    def apply_bpe(self, num_merges, *, verbose=True):
        """Apply Byte Pair Encoding for a specified number of merges.

        Stops early once no adjacent symbol pair is left to merge. When
        several pairs share the highest frequency, the one encountered first
        while scanning the vocabulary wins.

        Args:
            num_merges: Maximum number of merge steps to perform.
            verbose: When true (the default), print the merged pair and the
                resulting vocabulary after every step.

        Returns:
            The vocabulary after merging (the same object as ``self.vocab``).
        """
        for i in range(num_merges):
            pairs = self._get_stats()
            if not pairs:
                break
            best = max(pairs, key=pairs.get)
            self._merge_vocab(best)
            self.merges.append(best)
            if verbose:
                print(f"Step {i + 1}: Merged {best} -> {self.vocab}")
        return self.vocab

    def get_final_vocabulary(self):
        """Return the final vocabulary after BPE."""
        return self.vocab

# Example usage: learn up to ten merges on a tiny corpus, then list the
# resulting vocabulary one entry per line.
text = "low lower lowest low low lower"
num_merges = 10
bpe = BytePairEncoding(text)
final_vocab = bpe.apply_bpe(num_merges)

print("\nFinal Vocabulary:")
for word in final_vocab:
    freq = final_vocab[word]
    print(f"{word}: {freq}")


Step 1: Merged ('l', 'o') -> {'lo w </w>': 3, 'lo w e r </w>': 2, 'lo w e s t </w>': 1}
Step 2: Merged ('lo', 'w') -> {'low </w>': 3, 'low e r </w>': 2, 'low e s t </w>': 1}
Step 3: Merged ('low', '</w>') -> {'low</w>': 3, 'low e r </w>': 2, 'low e s t </w>': 1}
Step 4: Merged ('low', 'e') -> {'low</w>': 3, 'lowe r </w>': 2, 'lowe s t </w>': 1}
Step 5: Merged ('lowe', 'r') -> {'low</w>': 3, 'lower </w>': 2, 'lowe s t </w>': 1}
Step 6: Merged ('lower', '</w>') -> {'low</w>': 3, 'lower</w>': 2, 'lowe s t </w>': 1}
Step 7: Merged ('lowe', 's') -> {'low</w>': 3, 'lower</w>': 2, 'lowes t </w>': 1}
Step 8: Merged ('lowes', 't') -> {'low</w>': 3, 'lower</w>': 2, 'lowest </w>': 1}
Step 9: Merged ('lowest', '</w>') -> {'low</w>': 3, 'lower</w>': 2, 'lowest</w>': 1}

Final Vocabulary:
low</w>: 3
lower</w>: 2
lowest</w>: 1
