<a href="https://colab.research.google.com/github/iandsilva4/LearningML/blob/main/BytePairEncoding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [12]:
# https://mybinder.org/v2/gh/iandsilva4/LearningML/main?urlpath=voila%2Frender%2FBytePairEncoding.ipynb

import collections
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import ipywidgets as widgets

def byte_repr(b):
    """Return a display label for a single byte value.

    Printable ASCII bytes (32..126 inclusive) are shown as the Python repr of
    the corresponding character, e.g. ``'A'``; every other value is shown as
    a two-digit lowercase hex literal, e.g. ``0x0a``.
    """
    if not 32 <= b <= 126:
        return f"0x{b:02x}"
    return repr(chr(b))

def resolve_token(token, id_to_token):
    """Expand a token into the flat list of base values it stands for.

    Args:
        token: an int token id, or a tuple of tokens (any length).
        id_to_token: mapping from token id to the tuple it was built from.

    Returns:
        list: ints below 256 are returned as ``[token]``; ids of 256 or more
        that appear in ``id_to_token`` are expanded recursively; unknown ids
        and any other objects are returned unchanged as ``[token]``. Tuples are
        expanded element by element and concatenated, so 1-tuples such as the
        vocabulary keys ``(b,)`` and empty tuples are handled too (the previous
        version assumed exactly two elements and raised IndexError otherwise).
    """
    if isinstance(token, int):
        # Ids below 256 are raw bytes and never need expansion.
        if token >= 256 and token in id_to_token:
            return resolve_token(id_to_token[token], id_to_token)
        return [token]
    if isinstance(token, tuple):
        return [value for part in token for value in resolve_token(part, id_to_token)]
    return [token]

def get_pair_frequencies(seq):
    """Count every adjacent (left, right) pair in ``seq``.

    Overlapping pairs are all counted (``[7, 7, 7]`` yields ``(7, 7)`` twice).
    The returned Counter lists pairs in order of first appearance, which is
    what makes ``max`` tie-breaking deterministic for callers.
    """
    return collections.Counter(
        (seq[k - 1], seq[k]) for k in range(1, len(seq))
    )

def merge_pair(seq, pair, new_token):
    """Return a new list with each occurrence of ``pair`` replaced by ``new_token``.

    Matches are found greedily left to right and never overlap: in
    ``[1, 1, 1]`` merging ``(1, 1)`` yields ``[new_token, 1]``. ``seq`` itself is
    not modified. The pair is compared as a tuple, so ``pair`` should be a
    2-tuple.
    """
    merged = []
    last_index = len(seq) - 1
    consume_next = False
    for index, item in enumerate(seq):
        if consume_next:
            # This element was the right half of a pair already merged.
            consume_next = False
            continue
        if index < last_index and (item, seq[index + 1]) == pair:
            merged.append(new_token)
            consume_next = True
        else:
            merged.append(item)
    return merged

def build_bpe_steps(corpus, num_merges):
    """Run up to ``num_merges`` BPE merges over the UTF-8 bytes of ``corpus``.

    Each iteration counts adjacent pairs, picks the most frequent one (ties go
    to the pair seen first), replaces it by a fresh token id (256, 257, ...) and
    records a snapshot. Stops early once no pairs remain.

    Returns:
        list[dict]: one snapshot per merge with keys ``step`` (1-based),
        ``seq`` (token list after the merge), ``vocab`` (key tuple -> id),
        ``id_to_token`` (id -> key tuple), ``pairs`` (pair counts taken
        *before* the merge), ``merged`` (the chosen pair) and ``new_token``.
        All containers are copies, so later merges do not alter earlier
        snapshots.
    """
    tokens = list(corpus.encode('utf-8'))
    vocab = {(byte,): byte for byte in range(256)}
    id_to_token = {byte: (byte,) for byte in range(256)}
    history = []
    for offset in range(num_merges):
        counts = get_pair_frequencies(tokens)
        if not counts:
            break
        best_pair, _ = max(counts.items(), key=lambda entry: entry[1])
        new_id = 256 + offset
        tokens = merge_pair(tokens, best_pair, new_id)
        vocab[best_pair] = new_id
        id_to_token[new_id] = best_pair
        history.append({
            'step': offset + 1,
            'seq': list(tokens),
            'vocab': dict(vocab),
            'id_to_token': dict(id_to_token),
            'pairs': counts.copy(),
            'merged': best_pair,
            'new_token': new_id,
        })
    return history

def color_tokens(seq, id_to_token):
    """Render a token sequence as coloured HTML spans, one span per token.

    Each distinct token gets its own background colour: Tableau colours
    first, then CSS4 named colours, cycling if there are more tokens than
    colours. A token's text is its expanded bytes with non-printable bytes
    shown as '.'.

    Args:
        seq: list of token ids.
        id_to_token: mapping used by ``resolve_token`` to expand merged ids.

    Returns:
        widgets.HTML containing the spans.
    """
    from html import escape

    unique_tokens = list(dict.fromkeys(seq))
    palette = list(mcolors.TABLEAU_COLORS.values()) + list(mcolors.CSS4_COLORS.values())

    def readable_text_color(background):
        # The palette includes very dark entries ('black', 'navy', ...), on which
        # fixed black text is unreadable; pick black/white by perceived brightness.
        red, green, blue = mcolors.to_rgb(background)
        return 'black' if 0.299 * red + 0.587 * green + 0.114 * blue > 0.5 else 'white'

    styles = {}
    for i, tok in enumerate(unique_tokens):
        background = palette[i % len(palette)]
        styles[tok] = (background, readable_text_color(background))

    spans = []
    for tok in seq:
        chars = ''.join(chr(b) if 32 <= b <= 126 else '.' for b in resolve_token(tok, id_to_token))
        background, foreground = styles[tok]
        # Escape so corpus characters such as '<', '&' or quotes cannot break the markup.
        spans.append(
            f"<span style='background-color:{background};padding:2px 4px;margin:1px;"
            f"border-radius:3px;color:{foreground};font-weight:bold'>{escape(chars)}</span>"
        )
    return widgets.HTML(''.join(spans))

def vocab_table(vocab, id_to_token, highlight_token=None, merged_only=False):
    """Render the vocabulary as an HTML table, highest token id first.

    Args:
        vocab: mapping from key tuple (``(b,)`` for raw bytes, ``(left, right)``
            for merges) to token id.
        id_to_token: mapping used by ``resolve_token`` to expand merged ids.
        highlight_token: token id whose row gets a yellow background.
        merged_only: if True, skip ids below 256 (the raw byte tokens).

    Returns:
        widgets.HTML with the table, or an italic "No merged tokens yet."
        notice when no rows remain.
    """
    from html import escape

    rows = []
    for key, token_id in sorted(vocab.items(), key=lambda item: -item[1]):
        if merged_only and token_id < 256:
            continue
        if isinstance(key, tuple) and len(key) == 2:
            resolved = resolve_token(token_id, id_to_token)
            chars = ''.join(chr(b) if 32 <= b <= 126 else '.' for b in resolved)
            base = f"{resolved} ('{chars}')"
        else:
            b = key[0]
            base = f"({b}) ({byte_repr(b)})"
        style = "background-color: #ffe082;" if token_id == highlight_token else ""
        # Token text comes from the user's corpus; escape it so '<', '&', etc.
        # are shown literally instead of being parsed as HTML.
        rows.append(f"<tr style='{style}'><td>{token_id}</td><td>{escape(base)}</td></tr>")
    if not rows:
        return widgets.HTML("<i>No merged tokens yet.</i>")
    table = "<table><tr><th>Token ID</th><th>Token (as text)</th></tr>" + "".join(rows) + "</table>"
    return widgets.HTML(table)

def pair_freq_bar(pairs, id_to_token, highlight_pair=None, max_pairs=10):
    """Plot the most frequent adjacent pairs as a bar chart inside an Output widget.

    Args:
        pairs: mapping from (left, right) token pair to its count.
        id_to_token: mapping used by ``resolve_token`` to label tokens.
        highlight_pair: pair drawn in orange (e.g. the one just merged).
        max_pairs: how many of the most frequent pairs to show.

    Returns:
        widgets.Output holding the figure (empty if ``pairs`` is empty).
    """
    if not pairs:
        return widgets.Output()
    top_pairs = sorted(pairs.items(), key=lambda item: -item[1])[:max_pairs]

    def label_text(tok):
        text = ''.join(chr(x) if 32 <= x <= 126 else '.' for x in resolve_token(tok, id_to_token))
        # An unescaped '$' would switch matplotlib into mathtext parsing.
        return text.replace('$', r'\$')

    labels = []
    values = []
    colors = []
    for (a, b), freq in top_pairs:
        labels.append(f"'{label_text(a)}'+'{label_text(b)}'")
        values.append(freq)
        colors.append('#ff7043' if highlight_pair and (a, b) == highlight_pair else '#90caf9')

    out = widgets.Output()
    with out:
        # Keep a sensible minimum width so a chart with one or two bars is legible.
        width = max(4, min(12, 0.7 * len(labels)))
        fig, ax = plt.subplots(figsize=(width, 2.5))
        # Use numeric x positions: string x-values become categories and
        # identical labels (e.g. two different non-printable bytes both shown
        # as '.') would be merged into one bar, hiding data.
        positions = list(range(len(labels)))
        ax.bar(positions, values, color=colors)
        ax.set_xticks(positions)
        ax.set_xticklabels(labels, rotation=45, ha='right')
        ax.set_ylabel("Frequency")
        ax.set_title(f"Pair Frequencies (Top {len(labels)})")
        fig.tight_layout()
        plt.show()
    return out

def bpe_visualization_ui():
    """Build the interactive Byte Pair Encoding visualizer.

    The UI has a corpus text area, a merge-count input bounded to 1..100,
    Previous/Next buttons, a step label, an Auto-run toggle and a panel that
    renders the selected step. Changing the corpus or merge count recomputes
    every step with ``build_bpe_steps`` and rewinds to step 0.

    Returns:
        widgets.VBox: the root widget; display it (e.g. as the last
        expression of a notebook cell) to show the UI.
    """
    import asyncio
    from html import escape

    default_corpus = (
        "low lower lowest slowly slower slowest slowdown\n"
        "new newer newest newly newness\n"
        "fast faster fastest fasten fastened fastening\n"
        "run runner running rerun overrun outrun\n"
        "play player playing replay gameplay\n"
        "write writer writing rewrite overwritten"
    )
    corpus_box = widgets.Textarea(
        value=default_corpus,
        description='Corpus:',
        # Sync when editing finishes rather than on every keystroke: each sync
        # re-runs up to 100 merges over the whole corpus.
        continuous_update=False,
        layout=widgets.Layout(width='100%', height='100px')
    )
    # BoundedIntText enforces min/max; plain IntText ignores those arguments,
    # letting 0, negative or very large merge counts through.
    merge_input = widgets.BoundedIntText(
        value=5, min=1, max=100, step=1, description='Merges:'
    )
    vocab_size_label = widgets.Label(value=f"Target Vocab Size: {256 + merge_input.value}")
    auto_run_toggle = widgets.ToggleButton(value=False, description='Auto-run', icon='play')
    prev_button = widgets.Button(description='Previous', disabled=True)
    next_button = widgets.Button(description='Next', disabled=False)
    step_label = widgets.Label(value='Step: 0')
    main_vbox = widgets.VBox()
    ui = widgets.VBox([
        widgets.HTML("<h2>Byte Pair Encoding (BPE) Visualizer</h2>"),
        corpus_box,
        widgets.HBox([merge_input, vocab_size_label]),
        widgets.HBox([prev_button, next_button, step_label, auto_run_toggle]),
        main_vbox
    ])
    # 'auto_run_id' identifies the active auto-run; bumping it makes any
    # running auto-advance task stop at its next check.
    state = {'steps': [], 'step': 0, 'auto_run_id': 0}

    explanation_style = (
        "margin:10px 0 18px 0;padding:8px 12px;background:#f5f5f5;"
        "border-left:4px solid #2196f3;"
    )
    initial_vocab_markup = (
        "<div style='margin-top:8px;'>"
        "<b>Initial Vocabulary (truncated):</b><br>"
        "<span style='font-family:monospace;'>0: 0x00<br>...<br>255: 0xff</span>"
        "</div>"
    )

    def token_text(tok, id_to_token):
        """Return tok's bytes as printable ASCII ('.' otherwise), HTML-escaped."""
        raw = ''.join(chr(b) if 32 <= b <= 126 else '.' for b in resolve_token(tok, id_to_token))
        return escape(raw)

    def update_vocab_size_label(*args):
        vocab_size_label.value = f"Target Vocab Size: {256 + merge_input.value}"

    def update_buttons():
        prev_button.disabled = (state['step'] == 0)
        next_button.disabled = (state['step'] >= len(state['steps']))
        step_label.value = f"Step: {state['step']}"

    def update_steps(*args):
        state['steps'] = build_bpe_steps(corpus_box.value, merge_input.value)
        state['step'] = 0
        update_vocab_size_label()
        show_step()  # also refreshes the buttons and step label

    def show_step(change=None):
        step = state['step']
        if step == 0:
            seq = list(corpus_box.value.encode('utf-8'))
            id_to_token = {b: (b,) for b in range(256)}
            children = [
                widgets.HTML("<b>Step 0: Initial state (no merges yet)</b>"),
                color_tokens(seq, id_to_token),
                widgets.HTML(
                    f"<div style='{explanation_style}'>"
                    "<b>Explanation:</b> This is the initial byte-level tokenization. "
                    "Each byte of the UTF-8 encoded text (including spaces and newlines) is a "
                    "separate token, so a non-ASCII character occupies several tokens. "
                    "The table below shows the initial vocabulary of all 256 possible byte values, "
                    "truncated for clarity."
                    "</div>"
                ),
                widgets.HTML(initial_vocab_markup),
            ]
        else:
            s = state['steps'][step - 1]
            first, second = s['merged']
            merged_str = (
                f"Merged the most frequent pair <b>{s['merged']}</b> ("
                f"<span style='font-family:monospace;'>'{token_text(first, s['id_to_token'])}' + "
                f"'{token_text(second, s['id_to_token'])}'</span>) "
                f"into new token <b>{s['new_token']}</b>."
            )
            details = widgets.HBox([
                widgets.VBox([
                    vocab_table(s['vocab'], s['id_to_token'],
                                highlight_token=s['new_token'], merged_only=True),
                    widgets.HTML(initial_vocab_markup),
                ], layout=widgets.Layout(width='50%')),
                pair_freq_bar(s['pairs'], s['id_to_token'],
                              highlight_pair=s['merged'], max_pairs=10),
            ])
            children = [
                widgets.HTML(f"<b>Step {step}: Merge {s['merged']} into token {s['new_token']}</b>"),
                color_tokens(s['seq'], s['id_to_token']),
                widgets.HTML(
                    f"<div style='{explanation_style}'><b>Explanation:</b> {merged_str}</div>"
                ),
                details,
            ]
        main_vbox.children = children
        update_buttons()

    def on_prev(b):
        if state['step'] > 0:
            state['step'] -= 1
            show_step()

    def on_next(b):
        if state['step'] < len(state['steps']):
            state['step'] += 1
            show_step()

    async def auto_advance(run_id):
        # Awaiting between steps (instead of time.sleep in the callback) lets the
        # kernel keep processing widget events, so switching the toggle off works.
        while state['auto_run_id'] == run_id and state['step'] < len(state['steps']):
            await asyncio.sleep(0.7)
            if state['auto_run_id'] != run_id or state['step'] >= len(state['steps']):
                break
            state['step'] += 1
            show_step()
        if state['auto_run_id'] == run_id:
            auto_run_toggle.value = False

    def on_auto_run(change):
        # Any toggle change invalidates a previous run, so two runs never overlap.
        state['auto_run_id'] += 1
        if auto_run_toggle.value:
            asyncio.ensure_future(auto_advance(state['auto_run_id']))

    corpus_box.observe(update_steps, names='value')
    merge_input.observe(update_steps, names='value')
    prev_button.on_click(on_prev)
    next_button.on_click(on_next)
    auto_run_toggle.observe(on_auto_run, names='value')
    update_steps()
    return ui

# Build the visualizer. In a notebook or Voila the returned widget is displayed
# because this call is the cell's last expression; in a plain script the
# return value is discarded and nothing is shown.
bpe_visualization_ui()

VBox(children=(HTML(value='<h2>Byte Pair Encoding (BPE) Visualizer</h2>'), Textarea(value='low lower lowest sl…