<a href="https://colab.research.google.com/github/iandsilva4/LearningML/blob/main/BytePairEncoding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [16]:
# Byte Pair Encoding (BPE) Visualizer — Voila-safe 2-column version
#
# Works in: local Jupyter, Colab, Binder+Voila.
#
# Key changes:
#   • Escapes HTML to avoid DOM breakage on <, >, & bytes.
#   • Uses explicit 2-column flex layout (table fixed width; plot flexible).
#   • Explicit display(app) call for Voila reliability.
#   • Docstrings cleaned: no bare \x escapes (use \\xNN in text).
#
# Ian D'Silva / ChatGPT — 2025-07-17

import collections
import html
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import ipywidgets as widgets
from IPython.display import display

# ------------------------------------------------------------------
# Utilities
# ------------------------------------------------------------------

# Inclusive bounds of the visible ASCII range: space (32) through tilde (126).
PRINTABLE_LOW = 32
PRINTABLE_HIGH = 126


def is_printable_byte(b: int) -> bool:
    """Report whether `b` lies inside the visible ASCII range (bounds inclusive)."""
    return b >= PRINTABLE_LOW and b <= PRINTABLE_HIGH

def byte_repr(b: int) -> str:
    """
    Format one raw byte for display.

    Visible ASCII bytes come back as the character wrapped in single quotes,
    with HTML-significant characters replaced by their entities; every other
    byte comes back as a two-digit lowercase hex literal prefixed with ``0x``.
    """
    if not is_printable_byte(b):
        return f"0x{b:02x}"
    # html.escape leaves ordinary characters untouched, so it is safe to
    # apply to every printable character rather than special-casing a few.
    return "'" + html.escape(chr(b)) + "'"

def to_printable(bseq, show_hex_for_control=True, dot_for_control=False):
    """
    Convert an iterable of byte values into a display string.

    Visible ASCII bytes are emitted as-is, except for '<', '>' and '&',
    which are substituted exactly like non-printable bytes.

    Parameters
    ----------
    bseq : iterable[int]
        Byte values.
    show_hex_for_control : bool
        If True, a substituted byte becomes a literal hex escape: a
        backslash, the letter 'x', then two lowercase hex digits.
    dot_for_control : bool
        If True, a substituted byte becomes '.' (takes precedence over
        show_hex_for_control).

    When both flags are False, a substituted byte becomes a single space.
    """
    def render(value):
        # Short-circuit keeps chr() from being called on non-printable values.
        if is_printable_byte(value) and chr(value) not in "<>&":
            return chr(value)
        if dot_for_control:
            return '.'
        if show_hex_for_control:
            return f"\\x{value:02x}"
        return ' '

    return "".join([render(value) for value in bseq])

def resolve_token(token, id_to_token):
    """
    Expand a token into the flat, ordered list of base bytes it stands for.

    Parameters
    ----------
    token : int | tuple
        A token id, or a tuple of tokens (each itself an id or a tuple).
        Tuples of any length are accepted, including the 1-tuples ``(b,)``
        that `id_to_token` uses for base bytes.
    id_to_token : dict[int, tuple]
        Maps a merged token id (>= 256) to the tuple of its children.
        The mapping must be acyclic.

    Returns
    -------
    list
        Base byte values in left-to-right order. Ids below 256 are returned
        as themselves; ids >= 256 that are missing from `id_to_token`, and
        values that are neither int nor tuple, are passed through unchanged.
    """
    flat = []
    # Explicit stack instead of recursion: a long chain of nested merges
    # must not run into the interpreter's recursion limit.
    pending = [token]
    while pending:
        current = pending.pop()
        if isinstance(current, tuple):
            # Push children reversed so the left-most one is expanded first.
            pending.extend(reversed(current))
        elif isinstance(current, int) and current >= 256 and current in id_to_token:
            pending.append(id_to_token[current])
        else:
            flat.append(current)
    return flat

# ------------------------------------------------------------------
# Core BPE helpers
# ------------------------------------------------------------------

def get_pair_frequencies(seq):
    """
    Tally every adjacent (left, right) token pair in `seq`.

    Returns a Counter keyed by the pair tuple; keys appear in order of
    first occurrence. Sequences shorter than two tokens give an empty Counter.
    """
    # Zipping the sequence with itself shifted by one yields each neighbour pair.
    return collections.Counter(zip(seq, seq[1:]))

def merge_pair(seq, pair, new_token):
    """
    Return a new list in which each non-overlapping occurrence of `pair`
    (scanning left to right) is replaced by `new_token`.
    """
    merged = []
    pos = 0
    last = len(seq) - 1
    while pos <= last:
        # A match needs a right-hand neighbour, so the final index never matches.
        hit = pos < last and (seq[pos], seq[pos + 1]) == pair
        if hit:
            merged.append(new_token)
            pos += 2  # consume both halves of the pair
        else:
            merged.append(seq[pos])
            pos += 1
    return merged

def build_bpe_steps(corpus: str, num_merges: int):
    """
    Apply up to `num_merges` greedy BPE merges to the UTF-8 bytes of `corpus`,
    treated as one continuous sequence.

    Returns a list with one dict per merge performed. Each dict holds the
    1-based "step" number, snapshots of "seq", "vocab" and "id_to_token"
    taken after the merge, the "pairs" counts computed before the merge,
    the "merged" pair, and the "new_token" id assigned to it. The loop stops
    early once the sequence has no adjacent pairs left.
    """
    tokens = list(corpus.encode("utf-8"))

    # Base vocabulary: all 256 byte values, indexed in both directions.
    vocab = {}
    id_to_token = {}
    for byte in range(256):
        vocab[(byte,)] = byte
        id_to_token[byte] = (byte,)

    history = []
    for offset in range(num_merges):
        counts = get_pair_frequencies(tokens)
        if not counts:
            break

        # Merged ids are handed out consecutively starting at 256.
        new_id = 256 + offset
        # On ties, max() keeps the first pair encountered in the counter.
        best_pair = max(counts, key=counts.get)

        tokens = merge_pair(tokens, best_pair, new_id)
        vocab[best_pair] = new_id
        id_to_token[new_id] = best_pair

        history.append({
            "step": offset + 1,
            "seq": list(tokens),
            "vocab": dict(vocab),
            "id_to_token": dict(id_to_token),
            "pairs": counts.copy(),
            "merged": best_pair,
            "new_token": new_id,
        })

    return history

# ------------------------------------------------------------------
# Rendering widgets
# ------------------------------------------------------------------

def _readable_text_color(background):
    """Return 'black' or 'white', whichever contrasts better with `background`."""
    red, green, blue = mcolors.to_rgb(background)
    # Perceived brightness on a 0..1 scale (standard luma weights).
    brightness = 0.299 * red + 0.587 * green + 0.114 * blue
    return "black" if brightness >= 0.5 else "white"


def color_tokens(seq, id_to_token):
    """
    Render the token sequence as a row of colored inline spans.

    Parameters
    ----------
    seq : list
        Token ids in display order.
    id_to_token : dict
        Mapping handed to `resolve_token` to expand merged ids into bytes.

    Returns
    -------
    widgets.HTML
        One span per token. Every distinct token id gets its own background
        color, assigned in order of first appearance so colors are stable.
        The text color is chosen per background so tokens stay legible on
        the dark palette entries, and whitespace inside a token is preserved
        so leading/trailing spaces remain visible.
    """
    # preserve order of first appearance for stable colors
    unique_tokens = list(dict.fromkeys(seq))

    palette = (
        list(mcolors.TABLEAU_COLORS.values()) +
        list(mcolors.CSS4_COLORS.values())
    )

    # Tokens repeat many times; build the markup once per distinct token.
    span_for = {}
    for index, tok in enumerate(unique_tokens):
        background = palette[index % len(palette)]
        chars = to_printable(
            resolve_token(tok, id_to_token),
            show_hex_for_control=False,
            dot_for_control=True,
        )
        span_for[tok] = (
            f"<span style='display:inline-block;background-color:{background};"
            f"padding:2px 4px;margin:1px;border-radius:3px;"
            f"color:{_readable_text_color(background)};"
            # white-space:pre stops the browser from collapsing spaces.
            f"font-family:monospace;white-space:pre;'>"
            f"{html.escape(chars)}</span>"
        )

    return widgets.HTML("".join(span_for[tok] for tok in seq))


def vocab_table(vocab, id_to_token, highlight_token=None, merged_only=False):
    """
    Produce an HTML widget listing vocabulary entries in ascending token-id order.

    With merged_only=True, entries whose id is below 256 (the base bytes) are
    left out. The row whose id equals `highlight_token` gets a highlighted
    background. If no row qualifies, a short placeholder message is returned
    instead of a table.
    """
    body = []
    # Only the ids are displayed, so sort the values directly.
    for tid in sorted(vocab.values()):
        is_base = tid < 256
        if is_base and merged_only:
            continue

        flat = resolve_token(tid, id_to_token)
        if is_base:
            shown = byte_repr(flat[0])
        else:
            text = to_printable(flat, show_hex_for_control=True, dot_for_control=False)
            shown = "'" + html.escape(text) + "'"

        row_css = ""
        if tid == highlight_token:
            row_css = "background-color:#ffe082;"

        body.append(
            f"<tr style='{row_css}'>"
            f"<td style='padding:2px 6px;border-bottom:1px solid #ddd;'>{tid}</td>"
            f"<td style='padding:2px 6px;border-bottom:1px solid #ddd;font-family:monospace;'>{shown}</td>"
            f"</tr>"
        )

    if len(body) == 0:
        return widgets.HTML("<i>No merged tokens yet.</i>")

    parts = [
        "<div style='max-height:220px;overflow-y:auto;'>",
        "<table style='border-collapse:collapse;width:100%;font-size:13px;'>",
        "<tr>",
        "<th style='text-align:left;padding:2px 6px;border-bottom:2px solid #999;'>Token&nbsp;ID</th>",
        "<th style='text-align:left;padding:2px 6px;border-bottom:2px solid #999;'>Token&nbsp;(as&nbsp;text)</th>",
        "</tr>",
    ]
    parts.extend(body)
    parts.append("</table></div>")
    return widgets.HTML("".join(parts))


def pair_freq_bar(pairs, id_to_token, highlight_pair=None, max_pairs=10):
    """
    Draw the most frequent pairs as a Matplotlib bar chart inside an Output widget.

    Parameters
    ----------
    pairs : Mapping[tuple, int]
        Pair -> frequency, e.g. the Counter from `get_pair_frequencies`.
    id_to_token : dict
        Mapping handed to `resolve_token` to expand token ids for the labels.
    highlight_pair : tuple | None
        Pair whose bar is drawn in the accent color.
    max_pairs : int
        Maximum number of bars; the highest-frequency pairs are kept.

    Returns
    -------
    widgets.Output
        Holds the rendered figure, or is empty when there is nothing to plot.
    """
    out = widgets.Output()
    if not pairs:
        return out

    top_pairs = sorted(pairs.items(), key=lambda item: -item[1])[:max_pairs]
    if not top_pairs:
        # max_pairs <= 0 leaves nothing to draw; avoid a zero-width figure.
        return out

    def as_label(token):
        text = to_printable(
            resolve_token(token, id_to_token),
            show_hex_for_control=False,
            dot_for_control=True,
        )
        # An unescaped '$' can switch Matplotlib into mathtext mode.
        return text.replace("$", r"\$")

    labels = []
    values = []
    colors = []
    for (a, b), freq in top_pairs:
        labels.append(f"'{as_label(a)}'+'{as_label(b)}'")
        values.append(freq)
        colors.append('#ff7043' if highlight_pair and (a, b) == highlight_pair else '#90caf9')

    # Plot against integer positions: distinct pairs can share a label
    # (several bytes render as '.'), and string x-values would stack those
    # bars on a single category.
    positions = list(range(len(labels)))

    with out:
        # Floor the width so one or two bars still leave room for labels.
        fig, ax = plt.subplots(figsize=(min(12, max(3, 0.7 * len(labels))), 2.5))
        ax.bar(positions, values, color=colors)
        ax.set_xticks(positions)
        ax.set_xticklabels(labels, rotation=45, ha='right')
        ax.set_ylabel("Frequency")
        ax.set_title(f"Pair Frequencies (Top {max_pairs})")
        plt.tight_layout()
        plt.show()

    return out

# ------------------------------------------------------------------
# App builder
# ------------------------------------------------------------------

def bpe_visualization_ui():
    """
    Assemble the interactive BPE visualizer and return its root widget.

    The UI consists of a corpus text area, a bounded "Merges" input (1..100),
    Previous/Next buttons, an auto-run toggle, and a content area that shows
    either the initial byte-level tokenization (step 0) or the state after
    the selected merge. Editing the corpus or the merge count recomputes all
    steps and returns to step 0.

    Returns
    -------
    widgets.VBox
        The top-level container; the caller is responsible for displaying it.
    """
    import time  # used only for the auto-run delay; imported once, not per loop pass

    default_corpus = (
        "low lower lowest slowly slower slowest slowdown\n"
        "new newer newest newly newness\n"
        "fast faster fastest fasten fastened fastening\n"
        "run runner running rerun overrun outrun\n"
        "play player playing replay gameplay\n"
        "write writer writing rewrite overwritten"
    )

    corpus_box = widgets.Textarea(
        value=default_corpus,
        description='Corpus:',
        layout=widgets.Layout(width='100%', height='100px')
    )

    # BoundedIntText enforces min/max; plain IntText has no such traits, so
    # the bounds passed to it were never applied.
    merge_input = widgets.BoundedIntText(
        value=5, min=1, max=100, step=1, description='Merges:'
    )

    vocab_size_label = widgets.Label(value=f"Target Vocab Size: {256 + merge_input.value}")
    auto_run_toggle   = widgets.ToggleButton(value=False, description='Auto-run', icon='play')
    prev_button       = widgets.Button(description='Previous', disabled=True)
    next_button       = widgets.Button(description='Next', disabled=False)
    step_label        = widgets.Label(value='Step: 0')
    main_vbox         = widgets.VBox()

    # top-level container
    ui = widgets.VBox([
        widgets.HTML("<h2>Byte Pair Encoding (BPE) Visualizer</h2>"),
        corpus_box,
        widgets.HBox([merge_input, vocab_size_label]),
        widgets.HBox([prev_button, next_button, step_label, auto_run_toggle]),
        main_vbox
    ])

    # mutable state shared by the closures below:
    #   steps     -> list of step dicts from build_bpe_steps
    #   step      -> currently displayed step (0 = before any merge)
    #   auto_stop -> set to True to ask the auto-run loop to stop
    state = {'steps': [], 'step': 0, 'auto_stop': False}

    # --- internal helpers ----------------------------------------------------

    def base_vocab_reminder():
        """Static, truncated reminder of the 256-entry base vocabulary."""
        return widgets.HTML(
            "<div style='margin-top:8px;'>"
            "<b>Initial Vocabulary (truncated):</b><br>"
            "<span style='font-family:monospace;'>0: 0x00<br>...<br>255: 0xff</span>"
            "</div>"
        )

    def update_vocab_size_label(*_):
        vocab_size_label.value = f"Target Vocab Size: {256 + merge_input.value}"

    def update_steps(*_):
        """Recompute every merge step from the current inputs and show step 0."""
        state['steps'] = build_bpe_steps(corpus_box.value, merge_input.value)
        state['step'] = 0
        update_vocab_size_label()
        show_step()  # also refreshes the buttons and the step label

    def update_buttons():
        prev_button.disabled = (state['step'] == 0)
        next_button.disabled = (state['step'] >= len(state['steps']))
        step_label.value = f"Step: {state['step']}"

    def show_step(change=None):
        """Rebuild the content area for the currently selected step."""
        step = state['step']
        steps = state['steps']

        if step == 0:
            # Initial raw bytes view
            seq = list(corpus_box.value.encode('utf-8'))
            id_to_token = {b: (b,) for b in range(256)}

            explanation_html = widgets.HTML(
                "<div style='margin:10px 0 18px 0;padding:8px 12px;"
                "background:#f5f5f5;border-left:4px solid #2196f3;'>"
                "<b>Explanation:</b> This is the initial byte-level tokenization. "
                "Each character (including spaces and newlines) starts as its own byte token. "
                "Below we show that the base vocabulary always begins with all 256 possible bytes."
                "</div>"
            )

            widgets_to_show = [
                widgets.HTML("<b>Step 0: Initial state (no merges yet)</b>"),
                color_tokens(seq, id_to_token),
                explanation_html,
                base_vocab_reminder()
            ]

        else:
            # steps[] is 0-based while the displayed step number is 1-based
            s = steps[step-1]

            # Explanation text
            left_bytes  = resolve_token(s['merged'][0], s['id_to_token'])
            right_bytes = resolve_token(s['merged'][1], s['id_to_token'])
            left_txt  = html.escape(to_printable(left_bytes, show_hex_for_control=False, dot_for_control=True))
            right_txt = html.escape(to_printable(right_bytes, show_hex_for_control=False, dot_for_control=True))

            merged_str = (
                f"Merged the most frequent pair <b>{s['merged']}</b> "
                f"(<span style='font-family:monospace;'>'{left_txt}' + '{right_txt}'</span>) "
                f"into new token <b>{s['new_token']}</b>."
            )

            explanation_html = widgets.HTML(
                "<div style='margin:10px 0 18px 0;padding:8px 12px;"
                "background:#f5f5f5;border-left:4px solid #2196f3;'>"
                f"<b>Explanation:</b> {merged_str}</div>"
            )

            # Table of *newly created* vocab items
            merged_vocab_html = vocab_table(
                s['vocab'], s['id_to_token'],
                highlight_token=s['new_token'],
                merged_only=True
            )

            pair_freq_box = pair_freq_bar(
                s['pairs'], s['id_to_token'],
                highlight_pair=s['merged'], max_pairs=10
            )

            # --- 2-column layout (Voila safe) --------------------------------
            table_col = widgets.VBox(
                [merged_vocab_html, base_vocab_reminder()],
                layout=widgets.Layout(
                    width='320px',          # fixed width so it doesn't collapse
                    flex='0 0 320px',
                    overflow_y='auto',
                    align_items='stretch'
                )
            )
            plot_col = widgets.VBox(
                [pair_freq_box],
                layout=widgets.Layout(
                    flex='1 1 auto',
                    width='auto',
                    min_width='0px',        # allow shrink so table remains visible
                    align_items='stretch'
                )
            )
            hbox = widgets.HBox(
                [table_col, plot_col],
                layout=widgets.Layout(width='100%', align_items='flex-start')
            )

            widgets_to_show = [
                widgets.HTML(f"<b>Step {step}: Merge {s['merged']} into token {s['new_token']}</b>"),
                color_tokens(s['seq'], s['id_to_token']),
                explanation_html,
                hbox
            ]

        main_vbox.children = widgets_to_show
        update_buttons()

    # --- callbacks -----------------------------------------------------------

    def on_prev(_b):
        if state['step'] > 0:
            state['step'] -= 1
            show_step()

    def on_next(_b):
        if state['step'] < len(state['steps']):
            state['step'] += 1
            show_step()

    def on_auto_run(change):
        if auto_run_toggle.value:
            state['auto_stop'] = False
            # NOTE(review): this loop runs synchronously inside the callback.
            # The kernel presumably cannot handle the toggle-off message until
            # it returns, so 'auto_stop' may only take effect once the run has
            # finished — confirm, and consider a background task if stopping
            # mid-run matters.
            while state['step'] < len(state['steps']) and not state['auto_stop']:
                time.sleep(0.7)
                state['step'] += 1
                show_step()
            # Resetting the toggle re-enters this callback through the else branch.
            auto_run_toggle.value = False
        else:
            state['auto_stop'] = True

    # wire events (update_steps already refreshes the vocab-size label)
    corpus_box.observe(update_steps, names='value')
    merge_input.observe(update_steps, names='value')
    prev_button.on_click(on_prev)
    next_button.on_click(on_next)
    auto_run_toggle.observe(on_auto_run, names='value')

    # initial compute
    update_steps()

    return ui

# ------------------------------------------------------------------
# Instantiate + display (explicit display helps Voila)
# ------------------------------------------------------------------
# Build the widget tree once and keep a module-level handle to it in `app`.
app = bpe_visualization_ui()
# Render explicitly rather than relying on the cell's last-expression value.
display(app)


VBox(children=(HTML(value='<h2>Byte Pair Encoding (BPE) Visualizer</h2>'), Textarea(value='low lower lowest sl…