<a href="https://colab.research.google.com/github/iandsilva4/LearningML/blob/main/BytePairEncoding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [7]:
# https://mybinder.org/v2/gh/iandsilva4/LearningML/main?urlpath=voila%2Frender%2FBytePairEncoding.ipynb

import ipywidgets as widgets
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import collections

def color_tokens(seq):
    """Render a token sequence as a row of colored HTML "chips".

    Each element of ``seq`` gets its own background color, cycling through the
    matplotlib Tableau palette followed by the CSS4 named colors.

    Args:
        seq: Sequence of tokens. ``int`` tokens are shown as their ASCII
            character when printable (32..126) and as ``'.'`` otherwise;
            any other element is shown via ``str()``.

    Returns:
        An ``ipywidgets.HTML`` widget containing one ``<span>`` per token.
    """
    from html import escape

    palette = list(mcolors.TABLEAU_COLORS.values()) + list(mcolors.CSS4_COLORS.values())
    spans = []
    for i, token in enumerate(seq):
        background = palette[i % len(palette)]
        if isinstance(token, int):
            # chr() is only reached for printable ASCII, so it cannot raise.
            text = chr(token) if 32 <= token <= 126 else '.'
        else:
            text = str(token)
        # The palette includes very dark colors (e.g. CSS4 'black', 'navy');
        # pick white text on those so the label stays legible.
        r, g, b = mcolors.to_rgb(background)
        foreground = 'black' if 0.299 * r + 0.587 * g + 0.114 * b > 0.5 else 'white'
        # Escape the label: '<', '>' and '&' are printable ASCII and would
        # otherwise corrupt the generated markup.
        spans.append(
            f"<span style='background-color:{background};padding:2px 4px;margin:1px;"
            f"border-radius:3px;color:{foreground};font-weight:bold'>{escape(text)}</span>"
        )
    return widgets.HTML("".join(spans))

def get_pair_frequencies(seq):
    """Count every adjacent ``(left, right)`` pair in ``seq``.

    Overlapping pairs are all counted, so ``[a, a, a]`` yields ``{(a, a): 2}``.
    Keys appear in order of first occurrence. Sequences shorter than two
    elements give an empty Counter.

    Returns:
        A ``collections.Counter`` mapping each pair tuple to its count.
    """
    # Zipping the sequence against itself shifted by one walks every
    # neighbouring pair exactly once, left to right.
    return collections.Counter(zip(seq, seq[1:]))

def merge_pair(seq, pair, new_token):
    """Return a new list with each occurrence of ``pair`` replaced by ``new_token``.

    The scan runs left to right and matches are non-overlapping: once a pair
    is merged, both of its elements are consumed. For example, ``[a, a, a]``
    with ``(a, a)`` becomes ``[new_token, a]``. ``seq`` itself is not modified.
    """
    merged = []
    last = len(seq) - 1
    pos = 0
    while pos <= last:
        if pos < last and (seq[pos], seq[pos + 1]) == pair:
            merged.append(new_token)
            pos += 2  # skip both halves of the matched pair
            continue
        merged.append(seq[pos])
        pos += 1
    return merged

def build_bpe_steps(text, num_merges, min_frequency=1):
    """Run byte-level BPE on ``text`` and record every intermediate state.

    The text is UTF-8 encoded into byte ids (0-255). Each merge replaces the
    most frequent adjacent pair with a fresh id, starting at 256. Ties go to
    the pair that ``get_pair_frequencies`` yields first.

    Args:
        text: Corpus string to tokenize.
        num_merges: Maximum number of merges to perform (<= 0 means none).
        min_frequency: Stop early once the most frequent pair occurs fewer
            than this many times. The default of 1 keeps merging while any
            pair exists. A value of 2 merges only pairs that actually repeat.

    Returns:
        A list of step dicts. ``steps[0]`` is ``{'seq': [...]}`` (the raw byte
        ids). Every later step also has ``'pairs'`` (the pair counts the merge
        was chosen from), ``'merged'`` (the chosen pair) and ``'new_token'``
        (the id that replaced it).
    """
    seq = list(text.encode('utf-8'))
    steps = [{'seq': seq.copy()}]
    for new_token in range(256, 256 + num_merges):
        pairs = get_pair_frequencies(seq)
        if not pairs:
            break  # fewer than two tokens left: nothing to merge
        # max() keeps the first maximal item, so ties resolve by pair order.
        best_pair, best_count = max(pairs.items(), key=lambda item: item[1])
        if best_count < min_frequency:
            break
        seq = merge_pair(seq, best_pair, new_token)
        # `pairs` is rebuilt each iteration and never mutated, so no copy is needed.
        steps.append({'seq': seq.copy(), 'pairs': pairs, 'merged': best_pair, 'new_token': new_token})
    return steps

def pair_freq_bar(pairs, highlight_pair=None, max_pairs=10):
    """Plot the most frequent adjacent token pairs as a bar chart.

    Args:
        pairs: Mapping of ``(left, right)`` token-id pairs to counts.
        highlight_pair: Pair drawn in orange (e.g. the one chosen for
            merging). All other bars are light blue.
        max_pairs: Maximum number of bars to draw, highest counts first.

    Returns:
        An ``ipywidgets.Output`` holding the rendered figure, or an empty
        ``Output`` if ``pairs`` is empty.
    """
    if not pairs:
        return widgets.Output()

    def as_char(token):
        # Printable ASCII byte -> its character. Anything else, including
        # merged ids >= 256, becomes '.'.
        return chr(token) if 32 <= token <= 126 else '.'

    # Stable sort: equal counts keep the mapping's original order.
    top = sorted(pairs.items(), key=lambda item: -item[1])[:max_pairs]
    labels = [as_char(a) + as_char(b) for (a, b), _ in top]
    values = [count for _, count in top]
    colors = ['#ff7043' if pair == highlight_pair else '#90caf9' for pair, _ in top]

    out = widgets.Output()
    with out:
        # Keep a minimum width so charts with only one or two bars stay readable.
        fig, ax = plt.subplots(figsize=(min(12, max(4, 0.7 * len(labels))), 2.5))
        # Plot against numeric positions. Categorical string x-values would
        # stack distinct pairs that happen to share a label (e.g. '..').
        positions = list(range(len(labels)))
        ax.bar(positions, values, color=colors)
        ax.set_xticks(positions)
        ax.set_xticklabels(labels, rotation=45, ha='right')
        ax.set_ylabel("Frequency")
        ax.set_title(f"Pair Frequencies (Top {len(labels)})")
        fig.tight_layout()
        plt.show()
    return out

def _token_labels(steps):
    """Map every token id that can occur in ``steps`` to a printable label.

    Ids 0-255 stand for single bytes. Each step after the first adds its
    ``new_token`` as the concatenation of the two tokens in ``merged``.
    Printable ASCII bytes (32..126) are shown as themselves and every other
    byte as ``'.'``.
    """
    token_bytes = {byte: bytes([byte]) for byte in range(256)}
    for step in steps[1:]:
        left, right = step['merged']
        token_bytes[step['new_token']] = token_bytes[left] + token_bytes[right]
    return {
        token: ''.join(chr(b) if 32 <= b <= 126 else '.' for b in raw)
        for token, raw in token_bytes.items()
    }


def bpe_visualization_ui():
    """Build an interactive stepper that walks through BPE merges.

    Editing the corpus or the merge count rebuilds all steps with
    ``build_bpe_steps`` and rewinds to step 0. Previous/Next move between
    steps. Each step shows the token sequence as colored chips. After step 0,
    it also shows a bar chart of the pair counts that step's merge was
    chosen from, with the merged pair highlighted.

    Returns:
        The ``widgets.VBox`` containing all controls and output areas.
    """
    label = widgets.Label("BPE Visualizer Stepper")
    corpus_box = widgets.Textarea(value="hello world", description="Corpus:")
    # IntText has no min/max traits, so those bounds were silently dropped.
    # BoundedIntText actually keeps the value within 1..100.
    merges_input = widgets.BoundedIntText(value=5, min=1, max=100, step=1, description='Merges:')
    prev_button = widgets.Button(description="Previous", disabled=True)
    next_button = widgets.Button(description="Next", disabled=False)
    step_label = widgets.Label(value='Step: 0')
    color_tokens_out = widgets.Output()
    pair_freq_out = widgets.Output()
    # Mutable state shared by the nested callbacks below.
    state = {'steps': [], 'step': 0, 'labels': {}}

    def update_steps(*args):
        # Called once directly and afterwards as an `observe` handler
        # (which passes a change argument that is ignored here).
        state['steps'] = build_bpe_steps(corpus_box.value, merges_input.value)
        state['labels'] = _token_labels(state['steps'])
        state['step'] = 0
        update_ui()

    def update_ui():
        steps = state['steps']
        step = state['step']
        prev_button.disabled = (step == 0)
        next_button.disabled = (step >= len(steps) - 1)
        step_label.value = f"Step: {step}"
        with color_tokens_out:
            color_tokens_out.clear_output()
            # Pass text labels rather than raw ids so merged tokens (ids >= 256)
            # show the bytes they stand for instead of a bare '.'.
            labels = state['labels']
            display(color_tokens([labels[token] for token in steps[step]['seq']]))
        with pair_freq_out:
            pair_freq_out.clear_output()
            if step == 0:
                print("No pairs yet.")
            else:
                pairs = steps[step]['pairs']
                merged = steps[step]['merged']
                display(pair_freq_bar(pairs, highlight_pair=merged, max_pairs=10))

    def on_prev(b):
        if state['step'] > 0:
            state['step'] -= 1
            update_ui()

    def on_next(b):
        if state['step'] < len(state['steps']) - 1:
            state['step'] += 1
            update_ui()

    corpus_box.observe(update_steps, names='value')
    merges_input.observe(update_steps, names='value')
    prev_button.on_click(on_prev)
    next_button.on_click(on_next)
    update_steps()
    return widgets.VBox([
        label,
        corpus_box,
        merges_input,
        widgets.HBox([prev_button, next_button, step_label]),
        color_tokens_out,
        pair_freq_out
    ])

# Build the stepper widget. As the cell's final expression, the notebook displays it.
bpe_visualization_ui()

VBox(children=(Label(value='BPE Visualizer Stepper'), Textarea(value='hello world', description='Corpus:'), In…