# Token distribution analysis

In [None]:
import os

import aesthetics as aes
import matplotlib.pyplot as plt
import pandas as pd
import pyrootutils
import seaborn as sns
from aesthetics import *
from transformers import AutoTokenizer

In [None]:
# Resolve the repository root by searching upward from the notebook's current
# working directory for the `.project-root` marker file; data paths below are
# built relative to it.
PROJECT_ROOT = pyrootutils.find_root(
    indicator=".project-root",
    search_from=os.path.abspath(""),
)

In [None]:
# Pre-computed table of model responses, stored as Feather under <root>/data.
response_df_file = PROJECT_ROOT.joinpath("data", "response_df.feather")
response_df = pd.read_feather(response_df_file)

In [None]:
# Hugging Face checkpoints of the locally-run models, keyed by the short model
# name used in `response_df["model"]`. This one mapping drives both the row
# filter and the tokenizer lookup table below.
LOCAL_MODEL_CHECKPOINTS = {
    "gemma-3-1b": "google/gemma-3-1b-it",
    "gemma-3-4b": "google/gemma-3-4b-it",
    "DSR1-7B": "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
}

# Keep only responses from those models. Work on a copy so that columns added
# later are not written back into `response_df`.
is_local = response_df["model"].isin(list(LOCAL_MODEL_CHECKPOINTS))
local_df = response_df.loc[is_local].copy()
local_df.info()

# One tokenizer per local model, loaded in the same order as the mapping.
tokenizers = {
    short_name: AutoTokenizer.from_pretrained(checkpoint)
    for short_name, checkpoint in LOCAL_MODEL_CHECKPOINTS.items()
}

In [None]:
# Number of tokens in each response, measured with the tokenizer of the model
# that produced it. Iterating the two columns in lockstep avoids constructing
# a Series per row (as `apply(axis=1)` does). It also stays valid when
# `local_df` is empty: in that case `apply(axis=1)` returns a DataFrame and
# the column assignment raises.
local_df["tokens"] = [
    len(tokenizers[model].tokenize(response))
    for model, response in zip(local_df["model"], local_df["model_response"])
]

In [None]:
# Figure: one column per local model.
#   top row    - histogram of completion lengths (tokens)
#   bottom row - histogram of mean accuracy per grammar, coloured by the
#                `completed` flag computed below
# Only the bottom-right panel carries the hue legend.
# Import explicitly instead of relying on `mpl` leaking in via `aesthetics`.
from matplotlib.ticker import PercentFormatter

fig_height = 2.4
fig = plt.figure(figsize=(aes.PAPER_WIDTH_IN, fig_height))

model_names = list(tokenizers)

# Panel rows, top to bottom. "sample_length" (mean accuracy per
# `sample.length`) is implemented below but not drawn by default. Append it
# here, and increase `fig_height`, to add that row.
row_kinds = ["length", "grammar"]

grid = fig.add_gridspec(
    nrows=len(row_kinds),
    ncols=len(model_names),
    wspace=0.05,
    hspace=0.9,
)

# Longest response (in tokens) observed for each model.
max_tokens = local_df.groupby("model", observed=True)["tokens"].max().to_dict()

# Flag responses within 100 tokens of their model's longest response.
# NOTE(review): responses this close to the per-model maximum look like they
# ran into the generation limit (i.e. were cut off). Confirm that True here
# really means "completed" and that the flag is not inverted.
local_df["completed"] = (
    local_df["tokens"]
    >= local_df.groupby("model", observed=True)["tokens"].transform("max") - 100
)


def _percent_yaxis(ax):
    """Render proportion-valued (0-1) y ticks as whole-number percentages."""
    ax.yaxis.set_major_formatter(PercentFormatter(xmax=1, decimals=0, symbol="%"))


def _plot_length(ax, model_df):
    """Draw the histogram of completion lengths (`tokens`) for one model."""
    sns.histplot(
        data=model_df,
        x="tokens",
        color=aes.MODEL_COLOR,
        alpha=0.5,
        bins=15,
        stat="proportion",
        ax=ax,
        legend=False,
    )
    ax.set_xlabel("Completion length", ha="left", x=0.0)
    ax.set_ylabel("Completions")


def _plot_mean_accuracy(ax, model_df, group_col, binwidth, show_legend):
    """Draw the histogram of mean `correct` per (`completed`, `group_col`) group.

    Bars are coloured by `completed`. The hue legend is drawn only when
    `show_legend` is true. The x-limits are fixed to [-0.05, 1.05].
    """
    grouped = (
        model_df.groupby(["completed", group_col], observed=True)["correct"]
        .mean()
        .reset_index()
    )
    sns.histplot(
        data=grouped,
        x="correct",
        hue="completed",
        ax=ax,
        legend=show_legend,
        stat="proportion",
        binwidth=binwidth,
    )
    # NOTE(review): each histogram sample here is one (completed, group) mean,
    # not one completion, so the "Completions" y-label may be misleading.
    ax.set_ylabel("Completions")
    ax.set_xlim(-0.05, 1.05)


last_row = len(row_kinds) - 1
last_col = len(model_names) - 1

with sns.plotting_context("paper", rc=aes.rcs):
    for r, kind in enumerate(row_kinds):
        for c, model_name in enumerate(model_names):
            ax = fig.add_subplot(grid[r, c])
            model_df = local_df[local_df["model"] == model_name]
            # The single shared hue legend goes on the bottom-right panel.
            show_legend = r == last_row and c == last_col

            if kind == "length":
                _plot_length(ax, model_df)
            elif kind == "grammar":
                _plot_mean_accuracy(ax, model_df, "grammar_file", 0.05, show_legend)
                ax.set_xlabel("Mean accuracy, by grammar", ha="left", x=0.0)
            elif kind == "sample_length":
                _plot_mean_accuracy(ax, model_df, "sample.length", 0.02, show_legend)
                ax.set_xlabel(
                    "Mean accuracy, by example complexity (length)", ha="left", x=0.0
                )
                ax.set_xticks([0, 0.5, 1])
            else:
                raise ValueError(f"Unknown row kind: {kind!r}")

            if r == 0:
                ax.set_title(model_name)

            _percent_yaxis(ax)

            # Only the first column keeps y ticks and axis labels.
            if c > 0:
                ax.set_yticks([])
                ax.set_ylabel(None)
                ax.set_xlabel(None)

            # Guard: histplot draws no legend when there is nothing to plot.
            if show_legend and ax.get_legend() is not None:
                sns.move_legend(
                    ax,
                    "lower center",
                    bbox_to_anchor=(0.5, -1.2),
                    ncol=2,
                    title="Completed",
                )

    # Let the grid fill the whole figure. The gridspec's explicit wspace/hspace
    # (set above) take precedence over figure-level spacing, so only the
    # margins are adjusted here.
    fig.subplots_adjust(left=0, bottom=0, right=1, top=1)

In [None]:
# Mean accuracy for every (model, completed, grammar_file) combination that
# actually occurs, flattened back into ordinary columns. This relies on the
# `completed` column added by the figure cell.
grammar_group_keys = ["model", "completed", "grammar_file"]
mean_correct_by_grammar = local_df.groupby(grammar_group_keys, observed=True)[
    "correct"
].mean()
by_grammars_df = mean_correct_by_grammar.reset_index()

In [None]:
# Quick look: distribution of per-grammar mean accuracy for DSR1-7B only,
# coloured by the `completed` flag.
dsr1_by_grammar = by_grammars_df[by_grammars_df["model"] == "DSR1-7B"]
sns.histplot(data=dsr1_by_grammar, x="correct", hue="completed")