In [1]:
from collections import defaultdict, Counter
import re
import json
import unicodedata
from typing import List, Tuple, Dict

In [2]:
def normalize_text(text: str) -> str:
    """Return *text* in Unicode NFC form with surrounding whitespace removed.

    NFC composes base characters with their combining marks, so the same
    accented letter always maps to the same code point sequence.
    """
    composed = unicodedata.normalize('NFC', text)
    return composed.strip()


# Very simple pre-tokenizer: isolate common punctuation marks from words.
_PUNCT_RE = re.compile(r"([\.,!\?:;()\[\]\"'“”„—–])")


def pre_tokenize(text: str) -> List[str]:
    """Split *text* into word and punctuation tokens.

    The text is normalized first, then every punctuation mark matched by
    ``_PUNCT_RE`` is padded with spaces so it becomes a token of its own.

    Args:
        text: raw input text.

    Returns:
        The tokens in order of appearance. The list never contains empty
        strings; it is empty when *text* has no non-whitespace characters.
    """
    text = normalize_text(text)
    # Pad each punctuation mark with spaces so it is separated from words.
    text = _PUNCT_RE.sub(r" \1 ", text)
    # split() with no argument splits on runs of whitespace and discards the
    # empty strings that split(' ') would produce for leading/trailing
    # padding (e.g. "An." -> "An . ") or for empty input.
    return text.split()

def get_vocabulary_from_corpus(corpus_lines: List[str]) -> Dict[Tuple[str, ...], int]:
    """Count how often each pre-tokenized word occurs in the corpus.

    Every word is stored as a tuple of its individual characters followed by
    the end-of-word marker ``'</w>'``.

    Args:
        corpus_lines: raw text lines.

    Returns:
        Mapping from symbol tuple to the number of occurrences.
    """
    word_counts = defaultdict(int)
    for raw_line in corpus_lines:
        for piece in pre_tokenize(raw_line):
            if piece:
                # One symbol per character, terminated by the word marker.
                word_counts[(*piece, '</w>')] += 1
    return word_counts

def get_stats(vocab: Dict[Tuple[str, ...], int]) -> Dict[Tuple[str, str], int]:
    """Sum the frequency of every adjacent symbol pair across the vocabulary.

    Args:
        vocab: mapping from symbol tuple to word frequency.

    Returns:
        Mapping from (left, right) symbol pair to its total weighted count.
        Pairs are inserted in the order they are first encountered.
    """
    pair_counts = defaultdict(int)
    for word, count in vocab.items():
        sequence = tuple(word)
        # Zipping the sequence with itself shifted by one yields neighbours.
        for left, right in zip(sequence, sequence[1:]):
            pair_counts[(left, right)] += count
    return pair_counts

def merge_vocab(pair: Tuple[str, str], vocab: Dict[Tuple[str, ...], int]) -> Dict[Tuple[str, ...], int]:
    """Merge all occurrences of the given pair in the vocab and return new vocab.

    Each word is scanned left to right; wherever the two symbols of *pair*
    appear next to each other they are replaced by their concatenation.
    Occurrences do not overlap, so ('a', 'a') applied to ('a', 'a', 'a')
    gives ('aa', 'a').

    The merge works directly on the symbol tuples rather than on a
    space-joined string with a regex substitution, so symbols containing
    backslashes or whitespace are handled literally.

    Args:
        pair: the two adjacent symbols to fuse.
        vocab: mapping from symbol tuple to word frequency.

    Returns:
        A new mapping with the pair merged. Words that become identical
        after merging have their frequencies summed. *vocab* is not modified.
    """
    first, second = pair
    fused = first + second
    merged_vocab: Dict[Tuple[str, ...], int] = {}
    for word, freq in vocab.items():
        new_symbols = []
        last_index = len(word) - 1
        i = 0
        while i < len(word):
            if i < last_index and word[i] == first and word[i + 1] == second:
                new_symbols.append(fused)
                i += 2  # skip both halves of the merged pair
            else:
                new_symbols.append(word[i])
                i += 1
        new_word = tuple(new_symbols)
        merged_vocab[new_word] = merged_vocab.get(new_word, 0) + freq
    return merged_vocab




def train_bpe(corpus_lines: List[str], num_merges: int = 10000, min_freq: int = 2) -> Tuple[List[Tuple[str, str]], Dict[Tuple[str, ...], int]]:
    """Learn BPE merge operations from raw corpus lines.

    Starting from a character-level vocabulary, the most frequent adjacent
    symbol pair is merged repeatedly. Training ends after ``num_merges``
    merges, when no pairs are left, or when the most frequent pair occurs
    fewer than ``min_freq`` times, whichever comes first.

    Args:
        corpus_lines: list of raw text lines.
        num_merges: upper bound on the number of merge operations.
        min_freq: minimum frequency a pair needs to be merged.

    Returns:
        merges: ordered list of merged pairs (tuple of two symbols).
        vocab: final vocabulary mapping (tuple(symbols) -> frequency).
    """
    vocab = get_vocabulary_from_corpus(corpus_lines)
    merges: List[Tuple[str, str]] = []

    for _ in range(num_merges):
        pair_counts = get_stats(vocab)
        if not pair_counts:
            break
        # max() returns the first maximal key, so ties go to the pair that
        # was encountered first while collecting the statistics.
        top_pair = max(pair_counts, key=pair_counts.get)
        if pair_counts[top_pair] < min_freq:
            break
        merges.append(top_pair)
        vocab = merge_vocab(top_pair, vocab)

    return merges, vocab


# ----------------------------- Encoding / Tokenizing -----------------------------


def build_merge_lookup(merges: List[Tuple[str, str]]) -> Dict[Tuple[str, str], int]:
    """Map each merge pair to its position in *merges*.

    If a pair appears more than once, the position of its last occurrence
    is kept.
    """
    lookup = {}
    for position, pair in enumerate(merges):
        lookup[pair] = position
    return lookup




def encode_word(word: str, merges: List[Tuple[str, str]]) -> List[str]:
    """Encode a single word into BPE tokens using the learned merges.

    The word starts as its individual characters plus the end-of-word marker
    ``'</w>'``. At every step, among all adjacent symbol pairs that are known
    merges, the one learned earliest (lowest index in *merges*) is applied to
    all of its occurrences. This repeats until no adjacent pair is a known
    merge, which reproduces the order in which merges were applied during
    training.

    Args:
        word: a single pre-tokenized word (no marker appended).
        merges: merge pairs in the order they were learned.

    Returns:
        The list of BPE tokens. A standalone ``'</w>'`` left at the end is
        dropped; a marker that was merged into a token (e.g. ``'n</w>'``)
        stays part of that token. An empty *word* yields an empty list.
    """
    symbols = list(word) + ['</w>']

    # Rank of a pair = index at which it was learned; lower means earlier.
    # setdefault keeps the earliest index if a pair is listed more than once.
    ranks: Dict[Tuple[str, str], int] = {}
    for index, pair in enumerate(merges):
        ranks.setdefault(tuple(pair), index)

    while len(symbols) > 1:
        # Pick the applicable pair that was learned first.
        best = min(
            (p for p in zip(symbols, symbols[1:]) if p in ranks),
            key=ranks.__getitem__,
            default=None,
        )
        if best is None:
            break
        first, second = best
        merged_symbols = []
        i = 0
        while i < len(symbols):
            if i < len(symbols) - 1 and symbols[i] == first and symbols[i + 1] == second:
                merged_symbols.append(first + second)
                i += 2  # consume both halves of the pair
            else:
                merged_symbols.append(symbols[i])
                i += 1
        symbols = merged_symbols

    # Drop the end-of-word marker only if it survived as a separate symbol.
    if symbols and symbols[-1] == '</w>':
        symbols = symbols[:-1]
    return symbols




def encode(text: str, merges: List[Tuple[str, str]]) -> List[str]:
    """Encode a full text into one flat list of BPE tokens.

    The text is pre-tokenized, each non-empty token is encoded on its own,
    and a single ``' '`` token is placed between the encodings of
    consecutive tokens (never at the very end).
    """
    words = [piece for piece in pre_tokenize(text) if piece != '']
    flat: List[str] = []
    for position, piece in enumerate(words):
        if position:
            # Separator between this token's encoding and the previous one.
            flat.append(' ')
        flat.extend(encode_word(piece, merges))
    return flat


# ----------------------------- Persistence -----------------------------


def save_merges(merges: List[Tuple[str, str]], path: str):
    """Write the merge list to *path* as UTF-8 JSON.

    Each pair is stored as a two-element JSON array; non-ASCII characters
    are written as-is and the output is indented by two spaces.
    """
    with open(path, 'w', encoding='utf-8') as handle:
        rows = list(map(list, merges))
        json.dump(rows, handle, ensure_ascii=False, indent=2)




def load_merges(path: str) -> List[Tuple[str, str]]:
    """Read a merge list from the UTF-8 JSON file at *path*.

    Every top-level entry of the decoded JSON is converted to a tuple, and
    the original order is preserved.
    """
    with open(path, 'r', encoding='utf-8') as handle:
        entries = json.load(handle)
    return list(map(tuple, entries))


In [3]:
# Tiny Vietnamese demonstration corpus; each string is one corpus line.
sample = [
    'Xin chào! Tôi tên là An.',
    'Hôm nay trời đẹp và tôi thích cà phê.',
    'Bạn có muốn đi ăn phở không?',
    'Tôi thích lập trình bằng Python. Python rất mạnh.'
]


print('Training BPE on a tiny sample corpus (for demonstration)...')
# num_merges is only an upper bound: train_bpe stops early once the most
# frequent pair occurs fewer than min_freq times.
merges, vocab = train_bpe(sample, num_merges=20000, min_freq=2)
print(f'Learned {len(merges)} merges (showing first 30):')
# Print at most the first 30 merges, numbered from 1.
for i, m in enumerate(merges[:30]):
    print(f'{i+1:3d}. {m[0]} + {m[1]}')


# Encode a sentence with the merges that were just learned.
test_sent = 'Tôi thích cà phê và Python.'
toks = encode(test_sent, merges)
print('\nOriginal:', test_sent)
print('BPE tokens:', toks)


# Save merges for later reuse
# NOTE: save_merges opens the path for writing without creating directories,
# so the 'bpe_tokenizer' directory must already exist.
save_merges(merges, 'bpe_tokenizer/bpe_merges.json')
print('\nSaved merges to bpe_tokenizer/bpe_merges.json')

Training BPE on a tiny sample corpus (for demonstration)...
Learned 21 merges (showing first 30):
  1. n + </w>
  2. i + </w>
  3. . + </w>
  4. t + h
  5. h + </w>
  6. ô + i</w>
  7. à + </w>
  8. T + ôi</w>
  9. t + r
 10. p + </w>
 11. th + í
 12. thí + c
 13. thíc + h</w>
 14. p + h
 15. n + g
 16. ng + </w>
 17. n + h</w>
 18. P + y
 19. Py + th
 20. Pyth + o
 21. Pytho + n</w>

Original: Tôi thích cà phê và Python.
BPE tokens: ['Tôi</w>', ' ', 'thích</w>', ' ', 'c', 'à</w>', ' ', 'ph', 'ê', ' ', 'v', 'à</w>', ' ', 'Python</w>', ' ', '.</w>']

Saved merges to bpe_tokenizer/bpe_merges.json
