# Analysis for Prompting Evals

In [None]:
import colorsys
import json
import os
from collections import defaultdict
from typing import Any

import matplotlib as mpl
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import pandas as pd
import pyrootutils
import regex as re
import seaborn as sns
import sklearn.metrics as sk_metrics
import statsmodels.api as sm
from IPython.display import display
from statsmodels.stats.outliers_influence import variance_inflation_factor

In [None]:
# Resolve the project root via the `.project-root` indicator file, starting
# the search at the current working directory (`os.path.abspath("")`).
PROJECT_ROOT = pyrootutils.find_root(
    search_from=os.path.abspath(""), indicator=".project-root"
)

## Plotting Setup

In [None]:
# Directory that the `plt.savefig` calls in this notebook write into.
FIGURES_DIR = PROJECT_ROOT / "notebooks" / "figures"


def darken(
    color: str
    | tuple[float, float, float]
    | dict[str, Any]
    | sns.palettes._ColorPalette,
    by: float = 0.2,
):
    """
    Darken a color, or every color of a palette, by reducing its HLS lightness.

    Args:
        color: A single color (a string understood by ``sns.color_palette``,
            or an RGB tuple), a dict mapping names to such colors, or a
            seaborn color palette.
        by: Fraction of the lightness to remove. Clamped to ``[0, 1]``.

    Returns:
        An RGB tuple for a single color, a dict with the same keys for a dict
        input, or a seaborn palette for a palette input.
    """

    def _darken_color(c: str | tuple[float, float, float], by: float):
        by = min(max(0, by), 1)
        pct_darken = 1 - by

        if isinstance(c, str):
            c = sns.color_palette([c])[0]

        # colorsys works on channels in [0, 1]. If any channel exceeds 1 the
        # color is taken to be on the 0-255 scale and is rescaled as a whole.
        # (The previous loop only rebound its loop variable and was a no-op.)
        if any(c_i > 1 for c_i in c[:3]):
            c = tuple(c_i / 255 for c_i in c[:3])

        hue, lightness, saturation = colorsys.rgb_to_hls(c[0], c[1], c[2])

        # Darken by scaling the lightness, then convert back to RGB
        return colorsys.hls_to_rgb(hue, lightness * pct_darken, saturation)

    if isinstance(color, dict):
        # If color is a dictionary, assume it's a palette
        # and darken each color in the palette
        return {k: _darken_color(v, by) for k, v in color.items()}
    elif isinstance(color, sns.palettes._ColorPalette):
        colors = [_darken_color(c, by) for c in color]
        return sns.palettes._ColorPalette(colors)
    else:
        return _darken_color(color, by)


# Colormap for heatmaps, correlations, and other -1 to 1 scales
CMAP_HEATMAP = "vlag"

# Colors used whenever hue encodes the sample type
PALETTE_SAMPLE_TYPE = dict(
    positive=darken("#ffcc66"),
    negative="#5c5cff",
    unknown="#000000",
)

# Base palettes for each model family, built once and indexed below.
# (Previously tried alternatives: "rocket_r" with 3 colors for the gpt-4.1
# family and "crest" with 2 colors for o4-mini / o3.)
_GPT_FAMILY_COLORS = sns.cubehelix_palette(start=0.5, rot=-0.5, n_colors=4)
_O_SERIES_COLORS = sns.color_palette("YlOrBr", n_colors=2)
_GEMMA_FAMILY_COLORS = sns.cubehelix_palette(n_colors=5)

# Colors used whenever hue encodes the model; every entry is darkened
PALETTE_MODEL = darken(
    {
        "gpt-4.1-nano": _GPT_FAMILY_COLORS[0],
        "gpt-4.1-mini": _GPT_FAMILY_COLORS[1],
        "gpt-4.1": _GPT_FAMILY_COLORS[2],
        "o4-mini": _O_SERIES_COLORS[0],
        "o3": _O_SERIES_COLORS[1],
        "gemma-3-1b": _GEMMA_FAMILY_COLORS[0],
        "gemma-3-4b": _GEMMA_FAMILY_COLORS[1],
        "gemma-3-12b": _GEMMA_FAMILY_COLORS[2],
        "gemma-3-27b": _GEMMA_FAMILY_COLORS[3],
    }
)

# Colors used whenever hue encodes the F1 averaging method
_SCORE_COLORS = sns.color_palette("terrain", n_colors=2, desat=0.8)
PALETTE_SCORE = {
    "Weighted F1": _SCORE_COLORS[0],
    "Macro F1": _SCORE_COLORS[1],
}

# Registry of the palettes above, searched by `filter_by_alpha`
PALETTES = dict(
    model=PALETTE_MODEL,
    sample_type=PALETTE_SAMPLE_TYPE,
    score=PALETTE_SCORE,
)

# Styling of the at-chance baseline
COLOR_AT_CHANCE = "#ff0000"  # Red
ALPHA_AT_CHANCE = 0.5

# Bar chart outline settings
BAR_EDGE_COLOR = "black"
BAR_EDGE_WIDTH = 0.8


def display_palette(palette: dict[str, Any] | sns.palettes._ColorPalette):
    """
    Show a palette inline.

    A seaborn palette is displayed as-is; a dict palette is converted to a
    seaborn palette built from its values (in dict order) first.
    """
    if not isinstance(palette, sns.palettes._ColorPalette):
        palette = sns.color_palette(list(palette.values()))
    display(palette)


def filter_by_alpha(
    keys: list[str],
    ax,
    palette: dict[str, Any] | None = None,
    alpha=0.3,
    highlight: str | list[str] | None = None,
):
    alphas = defaultdict(lambda: alpha)
    if highlight is not None:
        if isinstance(highlight, str):
            highlight = [highlight]
        for h in highlight:
            alphas[h] = 1

    if palette is None:
        # look through all the palettes in PALETTE; find one whose keys match
        # the keys passed in; if found, use that palette; otherwise, throw an error
        for palette_name, palette_dict in PALETTES.items():
            if all(k in palette_dict for k in keys):
                palette = palette_dict
                break
        else:
            raise ValueError(f"No matching palette found for keys: {keys}")

    face_colors = ax.collections[0].get_facecolors()
    face_colors[:, 3] = alpha
    for key in keys:
        key_color = palette[key]

        # Get indices of face_colors whose first 3 values match the model color
        indices = [
            i
            for i, color in enumerate(face_colors)
            if (color[0], color[1], color[2]) == key_color[:3]
        ]
        for i in indices:
            face_colors[i][3] = alphas[key]
    ax.collections[0].set_facecolor(face_colors)


def legend_format(
    ax: mpl.axes._axes.Axes | sns.FacetGrid,
    keys: list[str] | None = None,
    title: str | None = None,
    **kwargs,
):
    """
    Rebuild the legend of `ax` without a frame and move it next to the plot.

    Args:
        ax: Axes to format, or a FacetGrid whose ``.ax`` is formatted after
            its figure-level legend is removed.
        keys: When given, invisible spacer entries are inserted into the
            legend to separate groups of entries. Only key sets containing
            "gpt-4.1" are supported.
        title: Optional legend title.
        **kwargs: Forwarded to ``sns.move_legend``. ``loc`` defaults to
            "upper left" and ``bbox_to_anchor`` to ``(1, 1)``.

    Raises:
        ValueError: If `keys` is given but no spacing is defined for it.
    """
    if isinstance(ax, sns.FacetGrid):
        fg = ax
        ax = fg.ax

        # Drop the figure-level legend (if any); an axes-level one is built below
        if fg.legend is not None:
            fg.legend.remove()

    # Legend Formatting
    handles, labels = ax.get_legend_handles_labels()

    if keys is not None:
        if "gpt-4.1" in keys:
            spacing_locs = [3, 6]
        else:
            # Previously the exception was constructed but never raised, so
            # execution fell through to a NameError on `spacing_locs`.
            raise ValueError(f"Unsure how to space legend for {keys=}")

        # A fully transparent patch with an empty label acts as a blank row
        spacer = mpatches.Patch(alpha=0, linewidth=0)
        for sloc in spacing_locs:
            handles.insert(sloc, spacer)
            labels.insert(sloc, "")

    _ = ax.legend(handles, labels)
    _ = ax.get_legend().set_frame_on(False)

    if title is not None:
        _ = ax.get_legend().set_title(title)

    kwargs.setdefault("loc", "upper left")
    kwargs.setdefault("bbox_to_anchor", (1, 1))
    _ = sns.move_legend(ax, **kwargs)


# MARK: Printing, experimentation, etc
# Preview the "terrain" palette next to its darkened version.
display(sns.color_palette("terrain", n_colors=2, desat=0.8))
display(darken(sns.color_palette("terrain", n_colors=2, desat=0.8), by=0.2))

# Preview the model and sample-type palettes defined above.
display_palette(PALETTE_MODEL)
display_palette(PALETTE_SAMPLE_TYPE)

In [None]:
# Case-insensitive match for a standalone "yes" or "no" word, optionally
# surrounded by non-letter characters. The single capture group holds the
# word itself, so `findall` returns just the yes/no strings.
YES_RE = re.compile(r"[^a-zA-Z]*\b(yes|no)\b[^a-zA-Z]*", re.IGNORECASE)


def extract_content(choices_list: list) -> str:
    try:
        return choices_list[0]["message"]["content"]
    except Exception as e:
        print(choices_list)


def extract_prediction(response: str) -> str:
    """
    Convert a free-text response into a predicted sample type.

    The last "yes"/"no" found by ``YES_RE`` decides the label: "yes" maps to
    "positive" and "no" to "negative". A response with no match is "unknown".
    """
    answers = YES_RE.findall(response)
    if not answers:
        return "unknown"

    final_answer = answers[-1].lower()
    return "positive" if final_answer == "yes" else "negative"


# Sanity check: a "yes" wrapped in backticks should still be extracted
extract_prediction("```yes```")

## Load data from disk

In [None]:
# Root directory that is searched recursively for batch input/result files
GRAMMARS_DIR = PROJECT_ROOT / "data" / "grammars"

# Text files listing, one per line, the directory names to keep
small_subset_file = PROJECT_ROOT / "data" / "small_subset.txt"
large_subset_file = PROJECT_ROOT / "data" / "large_subset.txt"

# On-disk store of previously processed responses
response_df_file = PROJECT_ROOT / "data" / "response_df.feather"

# Check to see if the response_df_file exists
response_df = None
response_full_df = None
new_response_df = None
if os.path.exists(response_df_file):
    print(f"Loading `response_df` from {response_df_file}")
    response_df = pd.read_feather(
        response_df_file,
    )
else:
    print(f"No store found at {response_df_file}")

print("Loading input and results files")

inputs_file_pattern = "*_inputs.jsonl"
results_file_pattern = "*_results.jsonl"

# Captures the leading `batch_<id>` part of a file name
batch_id_re = re.compile(r"^(batch_\w+)_")

# Batch ids already present in the store; their files are skipped below
old_batches = []
if response_df is not None:
    old_batches = response_df["batch_id"].unique()

# find all input files in subdirectories of GRAMMARS_DIR
input_files = list(GRAMMARS_DIR.rglob(inputs_file_pattern))
results_files = list(GRAMMARS_DIR.rglob(results_file_pattern))

# filter input_files and results_files to only include those which contain a directory
# name that is present in the small_subset_file or large_subset_file
with open(small_subset_file, "r") as f:
    small_subset = set(f.read().strip().split("\n"))
with open(large_subset_file, "r") as f:
    large_subset = set(f.read().strip().split("\n"))

keep_files = small_subset.union(large_subset)

# Keep a file only if its parent directory is in the subsets, its name
# carries a batch id, and that batch id is not already in the store
input_files = [
    f
    for f in input_files
    if f.parent.name in keep_files
    and batch_id_re.search(f.name)
    and batch_id_re.search(f.name).group(1) not in old_batches
]
results_files = [
    f
    for f in results_files
    if f.parent.name in keep_files
    and batch_id_re.search(f.name)
    and batch_id_re.search(f.name).group(1) not in old_batches
]

print(
    f"Found {len(input_files)} new input files and {len(results_files)} new results files"
)

# Load the new inputs/results files, flatten their nested JSON records, and
# merge them into `response_full_df`. Skipped unless both kinds of file exist.
if input_files and results_files:
    inputs_dfs = []
    for f in input_files:
        i_df = pd.read_json(f, lines=True)
        # Round-trip through JSON so nested objects become plain dicts that
        # json_normalize can flatten into dotted column names
        i_json_struct = json.loads(i_df.to_json(orient="records"))
        i_flat_df = pd.json_normalize(i_json_struct)
        batch_id = batch_id_re.search(f.name).group(1)
        i_flat_df["batch_id"] = batch_id
        inputs_dfs.append(i_flat_df)
    inputs_df = pd.concat(inputs_dfs, ignore_index=True)

    del i_df, i_json_struct, i_flat_df, inputs_dfs

    results_dfs = []
    for f in results_files:
        r_df = pd.read_json(f, lines=True)
        r_json_struct = json.loads(r_df.to_json(orient="records"))
        r_flat_df = pd.json_normalize(r_json_struct)
        batch_id = batch_id_re.search(f.name).group(1)
        r_flat_df["batch_id"] = batch_id
        results_dfs.append(r_flat_df)
    results_df = pd.concat(results_dfs, ignore_index=True)

    del r_df, r_json_struct, r_flat_df, results_dfs

    # Merge inputs and results on the batch_id and custom_id
    response_full_df = results_df.merge(
        inputs_df[
            [
                "custom_id",
                "batch_id",
                "body.metadata.sample_type",  # ground-truth label for sample
                "body.metadata.sample",  # the sample itself
                "body.metadata.grammar_file",  # grammar file used
                "body.metadata.model",  # model used
                "body.metadata.n_shots",  # n_shots used
            ]
        ],
        on=["batch_id", "custom_id"],
    )

    del results_df, inputs_df

    # Give the metadata columns short names (a second, identical rename that
    # followed this one was a no-op and has been removed)
    response_full_df = response_full_df.rename(
        columns={
            "body.metadata.sample_type": "sample.type.ground_truth",
            "body.metadata.sample": "sample",
            "body.metadata.grammar_file": "grammar_file",
            "body.metadata.model": "model",
            "body.metadata.n_shots": "n_shots",
        }
    )

# When a store exists, drop freshly loaded rows whose batch is already in it
if (response_df is not None) and (response_full_df is not None):
    response_full_df = response_full_df[
        ~response_full_df["batch_id"].isin(response_df["batch_id"].unique())
    ]
    print(f"Found {len(response_full_df)} new responses to add to response_df")

    # NOTE(review): `new_batch_ids` is not used anywhere in the visible code
    new_batch_ids = response_full_df["batch_id"].unique()

# Turn the raw merged responses into the tidy `response_df` schema, append
# them to any previously stored responses, and write the result back to disk.
if (response_full_df is not None) and (len(response_full_df) > 0):
    print("Processing new responses")
    # Pull the text of the first choice out of each response
    response_full_df["model_response"] = response_full_df[
        "response.body.choices"
    ].apply(extract_content)

    new_response_df = response_full_df[
        [
            "sample",
            "sample.type.ground_truth",
            "model_response",
            "grammar_file",
            "model",
            "n_shots",
            "batch_id",
        ]
    ].copy()

    del response_full_df

    # drop columns with NA values
    # NOTE(review): axis=1 removes whole columns, not rows; if any selected
    # column contains a missing value, the lookups below fail — confirm intent
    new_response_df = new_response_df.dropna(axis=1)

    new_response_df["sample.type.predicted"] = new_response_df["model_response"].apply(
        extract_prediction
    )
    # Sample length is the number of space-separated tokens
    new_response_df["sample.length"] = new_response_df["sample"].apply(
        lambda s: len(str(s).split(" "))
    )
    new_response_df["correct"] = (
        new_response_df["sample.type.ground_truth"]
        == new_response_df["sample.type.predicted"]
    )
    new_response_df = new_response_df.dropna()
    # Values outside the listed categories become NaN in these categoricals
    new_response_df["n_shots"] = pd.Categorical(
        new_response_df["n_shots"],
        categories=["0", "2", "4", "8", "16", "32"],
        ordered=True,
    )
    new_response_df["sample.type.ground_truth"] = pd.Categorical(
        new_response_df["sample.type.ground_truth"],
        categories=["positive", "negative"],
        ordered=True,
    )
    new_response_df["sample.type.predicted"] = pd.Categorical(
        new_response_df["sample.type.predicted"],
        categories=["positive", "negative", "unknown"],
        ordered=True,
    )

    # Stored model names use "_" where the mapping below expects "/"
    new_response_df["model"] = new_response_df["model"].str.replace(
        "_", "/", regex=False
    )

    print(new_response_df["model"].unique())

    # Shorten gemma model names
    # (model names missing from this mapping become NaN)
    new_response_df["model"] = new_response_df["model"].map(
        {
            "gpt-4.1-nano": "gpt-4.1-nano",
            "gpt-4.1-mini": "gpt-4.1-mini",
            "gpt-4.1": "gpt-4.1",
            "o4-mini": "o4-mini",
            "o3": "o3",
            "google/gemma-3-1b-it": "gemma-3-1b",
            "google/gemma-3-4b-it": "gemma-3-4b",
            "google/gemma-3-12b-it": "gemma-3-12b",
            "google/gemma-3-27b-it": "gemma-3-27b",
        }
    )

    # Tag each model as "regular" or "thinking"
    new_response_df["model_type"] = pd.Categorical(
        new_response_df["model"].map(
            {
                "gpt-4.1-nano": "regular",
                "gpt-4.1-mini": "regular",
                "gpt-4.1": "regular",
                "o4-mini": "thinking",
                "o3": "thinking",
                "gemma-3-1b": "regular",
                "gemma-3-4b": "regular",
                "gemma-3-12b": "regular",
                "gemma-3-27b": "regular",
            }
        ),
        categories=["regular", "thinking"],
        ordered=True,
    )

    # Add all the new data to the response_df
    if response_df is None:
        response_df = new_response_df.copy()
    elif len(new_response_df) > 0:
        response_df = pd.concat([response_df, new_response_df], ignore_index=True)

    # Re-apply the ordered model categories after concatenation; this order
    # is the model order used by groupbys and plots
    response_df["model"] = pd.Categorical(
        response_df["model"],
        categories=[
            "gpt-4.1-nano",
            "gpt-4.1-mini",
            "gpt-4.1",
            "o4-mini",
            "o3",
            "gemma-3-1b",
            "gemma-3-4b",
            "gemma-3-12b",
            "gemma-3-27b",
        ],
        ordered=True,
    )

    print("Saving `response_df` to Feather store")
    # Only rows without missing values are written; the in-memory frame keeps all
    response_df.dropna().to_feather(response_df_file)

del new_response_df

# NOTE(review): fails if `response_df` is still None, i.e. there was no store
# on disk and no new files were processed
response_df.info()

In [None]:
# response_df[response_df["model"] != "gemma-3-4b"].dropna().to_feather(response_df_file)

In [None]:
# Check for any batch_ids which have an inputs file with no results file, or
# vice versa, among the newly discovered files.
results_batch_ids = {batch_id_re.search(f.name).group(1) for f in results_files}
input_batch_ids = {batch_id_re.search(f.name).group(1) for f in input_files}
missing_results = input_batch_ids - results_batch_ids
missing_inputs = results_batch_ids - input_batch_ids

print(f"Found {len(missing_results)} batches with inputs but no results")

for batch_id in missing_results:
    # find the input file for this batch_id
    input_file = [
        f for f in input_files if batch_id_re.search(f.name).group(1) == batch_id
    ]
    if len(input_file) > 0:
        print(f"Found input file for batch_id {batch_id}: {input_file[0]}")

# The reverse direction was previously looked up inside the loop above, where
# it could never match (those batch ids have no results file by construction).
print(f"Found {len(missing_inputs)} batches with results but no inputs")

for batch_id in missing_inputs:
    # find the results file for this batch_id
    results_file = [
        f for f in results_files if batch_id_re.search(f.name).group(1) == batch_id
    ]
    if len(results_file) > 0:
        print(f"Found results file for batch_id {batch_id}: {results_file[0]}")

In [None]:
# Show the columns of the tidy response table
response_df.columns

In [None]:
# Count of non-null samples per (grammar_file, model) pair
response_df.groupby(["grammar_file", "model"], observed=False)["sample"].count()

In [None]:
# Number of distinct grammar files per model that appears in the data
response_df.groupby(["model"], observed=True)["grammar_file"].nunique()

Load grammar and sample statistics, and annotate the F1 scores with those values.

In [None]:
grammar_stats_pattern = "grammar_stats.json"
samples_stats_pattern = "filtered_samples_stats.json"

# Collect the per-grammar stats files that live in one of the kept directories
grammar_stats_files = [
    f for f in GRAMMARS_DIR.rglob(grammar_stats_pattern) if f.parent.name in keep_files
]
samples_stats_files = [
    f for f in GRAMMARS_DIR.rglob(samples_stats_pattern) if f.parent.name in keep_files
]

# Parse each stats file into a dict tagged with its grammar directory name;
# files that are not valid JSON are reported and skipped
grammar_stats_dicts = []
for f in grammar_stats_files:
    try:
        g_dict = json.loads(f.read_text())
    except json.JSONDecodeError:
        print(f"Error reading {f}")
    else:
        g_dict["grammar_file"] = f.parent.name
        grammar_stats_dicts.append(g_dict)
grammar_stats_df = pd.DataFrame(grammar_stats_dicts)

samples_stats_dicts = []
for f in samples_stats_files:
    try:
        s_dict = json.loads(f.read_text())
    except json.JSONDecodeError:
        print(f"Error reading {f}")
    else:
        s_dict["grammar_file"] = f.parent.name
        samples_stats_dicts.append(s_dict)
samples_stats_df = pd.DataFrame(samples_stats_dicts)


def _f1_by_group(average: str, score_name: str) -> pd.DataFrame:
    """
    Compute one F1 score per (n_shots, model, model_type, grammar_file) group
    of `response_df`, using the given sklearn `average` mode, and return the
    scores as a flat frame whose score column is called `score_name`.
    """
    grouped = response_df.groupby(
        ["n_shots", "model", "model_type", "grammar_file"],
        observed=False,
    )
    scores = grouped.apply(
        lambda group: sk_metrics.f1_score(
            group["sample.type.ground_truth"],
            group["sample.type.predicted"],
            average=average,
        ),
        include_groups=False,
    )
    return scores.reset_index(name=score_name)


_grammar_stats_by_file = grammar_stats_df.set_index("grammar_file")
_samples_stats_by_file = samples_stats_df.set_index("grammar_file")

# Weighted F1 per group, annotated with grammar and sample statistics
f1_df = _f1_by_group("weighted", "weighted_f1_score")
f1_df = f1_df.join(_grammar_stats_by_file, on="grammar_file")
f1_df = f1_df.join(_samples_stats_by_file, on="grammar_file")

# Discard every column that contains a missing value
f1_df = f1_df.dropna(axis=1)

macro_f1_df = _f1_by_group("macro", "macro_f1_score")
micro_f1_df = _f1_by_group("micro", "micro_f1_score")

# Attach the macro and micro scores to the weighted-score table
_f1_keys = ["grammar_file", "n_shots", "model", "model_type"]
f1_df = f1_df.join(
    macro_f1_df[["macro_f1_score", *_f1_keys]].set_index(_f1_keys),
    on=_f1_keys,
)
f1_df = f1_df.join(
    micro_f1_df[["micro_f1_score", *_f1_keys]].set_index(_f1_keys),
    on=_f1_keys,
)

# Per-response table annotated with the same grammar and sample statistics
accuracy_df = response_df.join(_grammar_stats_by_file, on="grammar_file")
accuracy_df = accuracy_df.join(_samples_stats_by_file, on="grammar_file")

del grammar_stats_df, samples_stats_df, grammar_stats_dicts, samples_stats_dicts
del _grammar_stats_by_file, _samples_stats_by_file

f1_df.info()

In [None]:
# Summarize the columns and dtypes of the per-response table
accuracy_df.info()

In [None]:
# Display the per-group F1 table
f1_df

## Correlation Analysis

In [None]:
# Pairwise correlations between the two F1 scores and the grammar/sample
# statistics, computed over the rows of f1_df
corr_mat = f1_df[
    [
        "macro_f1_score",
        "weighted_f1_score",
        "n_terminals",
        "n_nonterminals",
        "n_lexical_productions",
        "n_nonlexical_productions",
        "compression_ratio",
        "mean_positive_depth",
        "median_positive_depth",
        "coverage",
    ]
].corr()

corr_mat

In [None]:
# Heatmap of the full correlation matrix on a fixed -1..1 color scale
ax = sns.heatmap(
    corr_mat,
    cmap=CMAP_HEATMAP,
    vmin=-1,
    vmax=1,
)

# Save the figure
plt.savefig(
    FIGURES_DIR / "grammar_stats_correlation_matrix.pdf",
    bbox_inches="tight",
)

In [None]:
# Keep only the two F1 columns, order rows by their correlation with macro F1,
# and drop the first two rows.
# NOTE(review): the dropped rows are presumably the F1 scores themselves,
# which relies on them sorting to the top — confirm
sliced_corr_mat = (
    corr_mat.iloc[:, 0:2].sort_values(by="macro_f1_score", ascending=False).iloc[2:]
)

ax = sns.heatmap(
    sliced_corr_mat,
    cmap=CMAP_HEATMAP,
    vmin=-1,
    vmax=1,
)

# Write each correlation value at the center of its cell
for c in range(len(sliced_corr_mat.columns)):
    for r in range(len(sliced_corr_mat)):
        ax.text(
            c + 0.5,
            r + 0.5,
            f"${sliced_corr_mat.iloc[r, c]:.2f}$",
            ha="center",
            va="center",
            color="black",
        )

_ = ax.set_title("Grammar Hyperparameter Correlations with F1 Scores", y=1.05)

plt.savefig(
    FIGURES_DIR / "grammar_stats_correlation_matrix_sliced.pdf",
    bbox_inches="tight",
)

In [None]:
# Pairwise correlations between per-response correctness, sample length, and
# the grammar/sample statistics, computed over the rows of accuracy_df
corr_mat_acc = accuracy_df[
    [
        "correct",
        "sample.length",
        "n_terminals",
        "n_nonterminals",
        "n_lexical_productions",
        "n_nonlexical_productions",
        "compression_ratio",
        "mean_positive_depth",
        "median_positive_depth",
        "coverage",
    ]
].corr()

corr_mat_acc

In [None]:
# Keep only the "correct" column, order rows by their correlation with it, and
# drop the self-correlation row. Dropping by label (rather than by the
# `.iloc[2:]` copied from the two-column F1 version) keeps the most strongly
# correlated predictor in the plot.
sliced_corr_mat_acc = (
    corr_mat_acc.iloc[:, 0:1]
    .sort_values(by="correct", ascending=False)
    .drop(index="correct")
)

ax = sns.heatmap(
    sliced_corr_mat_acc,
    cmap=CMAP_HEATMAP,
    vmin=-1,
    vmax=1,
)

# Write each correlation value at the center of its cell
for c in range(len(sliced_corr_mat_acc.columns)):
    for r in range(len(sliced_corr_mat_acc)):
        ax.text(
            c + 0.5,
            r + 0.5,
            f"${sliced_corr_mat_acc.iloc[r, c]:.2f}$",
            ha="center",
            va="center",
            color="black",
        )

# This plot correlates against per-response accuracy, not the F1 scores
_ = ax.set_title("Grammar Hyperparameter Correlations with Accuracy", y=1.05)

## Multivariate Regression

### Weighted F1 Score

In [None]:
# Columns left out of the regression frame
_EXCLUDED_FROM_REGRESSION = [
    "total_possible_samples",
    "uncompressed_size",
    "compressed_size",
    "grammar_file",
    "grammar_name",
    "n_shots",
]

# `drop` returns a new frame, so f1_df itself is left untouched
f1_stats_df = f1_df.drop(columns=_EXCLUDED_FROM_REGRESSION)

f1_stats_df.info()

In [None]:
# Regressors for the OLS fits below.
# NOTE: despite the name, these columns form the design matrix X (what
# statsmodels calls the exogenous variables); the F1 score is the response Y.
endogenous_vars = [
    "model",
    # "n_terminals",
    # "n_nonterminals",
    # "n_lexical_productions",
    "n_nonlexical_productions",
    "compression_ratio",
    "mean_positive_depth",
    "coverage",
]

X = f1_stats_df[endogenous_vars].copy()
X = pd.get_dummies(X, drop_first=True, dtype=int)  # one-hot encode categorical vars
X = sm.add_constant(X)  # add a constant term to the model

Y = f1_stats_df["weighted_f1_score"]

# Ordinary least squares fit of weighted F1 on the regressors
model = sm.OLS(Y, X).fit()
print(model.summary())

### Macro F1 Score

In [None]:
# Same design matrix X as the weighted-F1 fit, with macro F1 as the response
Y = f1_stats_df["macro_f1_score"]

model = sm.OLS(Y, X).fit()
print(model.summary())

### VIF

In [None]:
# Variance inflation factor of every column of the design matrix X, sorted
# from lowest to highest.
# NOTE(review): row 0 is dropped — presumably the constant term added by
# sm.add_constant; confirm it is the first column of X
vif_df = (
    pd.DataFrame(
        {
            "variable": X.columns,
            "VIF": [variance_inflation_factor(X.values, i) for i in range(X.shape[1])],
        }
    )
    .drop(index=0)
    .sort_values(by="VIF", ascending=True)
    .reset_index(drop=True)
)
# Bucket the VIFs: > 10 is "High", > 5 is "Moderate", anything else is "Low"
vif_df["Rating"] = vif_df["VIF"].apply(
    lambda x: "High" if x > 10 else "Moderate" if x > 5 else "Low"
)

vif_df

## Mean Accuracy/Score by Model and Sample Type

In [None]:
# Grouped bar chart of mean accuracy per model, split by ground-truth sample
# type, with standard-error bars
g = sns.catplot(
    data=accuracy_df,
    kind="bar",
    x="model",
    y="correct",
    hue="sample.type.ground_truth",
    palette=PALETTE_SAMPLE_TYPE,
    errorbar="se",
    height=3,
    aspect=2.8,
)

# Print each bar's height just below its top edge
for ax in g.axes.flat:
    for i, bar in enumerate(ax.containers):
        for rect in bar:
            height = rect.get_height()
            x_coord = rect.get_x() + rect.get_width() / 2.0

            ax.text(
                rect.get_x() + rect.get_width() / 2,
                height - 0.02,
                f"{height:.2f}",
                ha="center",
                va="top",
                fontsize=9,
                color="white",
                fontweight="bold",
            )

# Outline every bar
for ax in g.axes.flat:
    for bar in ax.patches:
        bar.set_edgecolor(BAR_EDGE_COLOR)
        bar.set_linewidth(BAR_EDGE_WIDTH)

# Number of distinct grammars per model
n_counts = accuracy_df.groupby("model", observed=False)["grammar_file"].nunique()
# Per-model mean of the per-sample-type mean accuracies
mean_accs = (
    accuracy_df.groupby(["model", "sample.type.ground_truth"], observed=False)[
        "correct"
    ]
    .mean()
    .groupby("model", observed=False)
    .mean()
)
# Standard error of `correct` over all of a model's responses
mean_errors = accuracy_df.groupby("model", observed=False)["correct"].sem()

for ax in g.axes.flat:
    for i, category in enumerate(n_counts.index):
        try:
            # containers[0] / containers[1] are read as the positive and
            # negative bars, following the variable names
            pos_height = ax.containers[0][i].get_height()
            neg_height = ax.containers[1][i].get_height()
            max_height = max(pos_height, neg_height)

            count = n_counts[category]
            mean_acc = mean_accs[category]
            ax.text(i, -0.15, f"n={count}", ha="center", va="top")
            # Put the mean label on the side of the taller bar
            ax.text(
                i + (0.1 if pos_height > neg_height else -0.1),
                mean_acc,
                f"{mean_acc:.2f}",
                ha="left" if pos_height > neg_height else "right",
                va="center",
                fontweight="bold",
                fontsize=9,
            )

            # add  black diamond at mean accuracy
            ax.plot(
                i,
                mean_acc,
                marker="D",
                color="black",
                markersize=5,
                linewidth=0,
                label="mean (± sem)" if i == 0 else "_nolegend_",
            )

            # add error bar
            ax.errorbar(
                i,
                mean_acc,
                yerr=mean_errors[category],
                fmt="o",
                color="black",
                markersize=0,
                capsize=5,
                label="_nolegend_",
            )
        except IndexError:
            # Categories without a bar at position i are skipped silently
            pass

# Dashed at-chance baseline, drawn behind the bars
_ = g.ax.axhline(
    y=0.5, color=COLOR_AT_CHANCE, alpha=ALPHA_AT_CHANCE, linestyle="--", zorder=0
)

legend_format(
    ax=g,
    loc="center left",
    bbox_to_anchor=(0, 0.97),
    columnspacing=0,
)

_ = g.ax.set_ylabel("Mean Accuracy")
_ = g.ax.set_xlabel(None)
_ = g.ax.set_ylim(0, 1)

plt.savefig(
    FIGURES_DIR / "accuracy_by_model.pdf",
    bbox_inches="tight",
)

In [None]:
# Reshape to long format: one row per (group, averaging method) with its score
f1_melted_df = f1_df.melt(
    id_vars=["n_shots", "model", "model_type", "grammar_file"],
    value_vars=["weighted_f1_score", "macro_f1_score", "micro_f1_score"],
    var_name="average",
    value_name="score",
)

# map the score column names to more readable labels
f1_melted_df["average"] = f1_melted_df["average"].map(
    {
        "weighted_f1_score": "Weighted F1",
        "macro_f1_score": "Macro F1",
        "micro_f1_score": "Micro F1",
    }
)

# Don't show micro f1 score
f1_melted_df = f1_melted_df[f1_melted_df["average"] != "Micro F1"]

# Grouped bar chart of F1 score per model, split by averaging method
g = sns.catplot(
    data=f1_melted_df,
    kind="bar",
    x="model",
    y="score",
    hue="average",
    palette=PALETTE_SCORE,
    height=3,
    aspect=3.5,
)

# Number of distinct grammars per model present in the data
n_counts = accuracy_df.groupby("model", observed=True)["grammar_file"].nunique()

# Dashed at-chance baseline, drawn behind the bars
_ = g.ax.axhline(
    y=0.5, color=COLOR_AT_CHANCE, alpha=ALPHA_AT_CHANCE, linestyle="--", zorder=0
)

# Annotate each model with its grammar count below the axis
for ax in g.axes.flat:
    for i, category in enumerate(n_counts.index):
        count = n_counts[category]
        ax.text(i, -0.15, f"n={count}", ha="center", va="top")

# Add score labels to bars
# NOTE(review): `score_labels` is not used below; labels come from bar heights
score_labels = f1_melted_df.groupby(["model", "average"], observed=True)["score"].mean()

for ax in g.axes.flat:
    for bar in ax.containers:
        for rect in bar:
            height = rect.get_height()
            x_coord = rect.get_x() + rect.get_width() / 2.0

            ax.text(
                rect.get_x() + rect.get_width() / 2,
                height - 0.02,
                f"{height:.2f}",
                ha="center",
                va="top",
                fontsize=9,
                color="white",
                fontweight="bold",
            )

# Outline every bar
for ax in g.axes.flat:
    for bar in ax.patches:
        bar.set_edgecolor(BAR_EDGE_COLOR)
        bar.set_linewidth(BAR_EDGE_WIDTH)

legend_format(
    ax=g,
    loc="center left",
    bbox_to_anchor=(0, 0.9),
    ncol=1,
)

_ = g.ax.set_ylabel("F1 Score")
_ = g.ax.set_xlabel(None)
_ = g.ax.set_ylim(0, 1)

plt.savefig(
    FIGURES_DIR / "f1_by_model.pdf",
    bbox_inches="tight",
)

## Macro F1 Score by Complexity

In [None]:
# Scatter of macro F1 against compression ratio, one marker style and color
# per model
fig = plt.figure(figsize=(5, 3.5))
ax = fig.add_subplot(111)

# Dashed at-chance baseline
_ = ax.axhline(
    y=0.5,
    color=COLOR_AT_CHANCE,
    alpha=ALPHA_AT_CHANCE,
    linestyle="--",
)

sns.scatterplot(
    data=f1_df,
    x="compression_ratio",
    y="macro_f1_score",
    style="model",
    hue="model",
    hue_order=f1_df["model"].unique(),
    style_order=f1_df["model"].unique(),
    palette=PALETTE_MODEL,
    ax=ax,
)

# Rebuild the legend with spacer rows between groups of entries
legend_format(
    keys=f1_df["model"].unique(),
    ax=ax,
)

_ = ax.set_ylim(-0.02, 1.02)
_ = ax.set_xlabel("gzip Compression Ratio")
_ = ax.set_ylabel("Macro F1 Score")

plt.savefig(
    FIGURES_DIR / "compression_ratio_vs_macro_f1.pdf",
    bbox_inches="tight",
)

In [None]:
# Scatter of macro F1 against the number of terminals (log-scaled x axis)
fig = plt.figure(figsize=(5, 3.5))
ax = fig.add_subplot(111)

# Dashed at-chance baseline
_ = ax.axhline(
    y=0.5,
    color=COLOR_AT_CHANCE,
    alpha=ALPHA_AT_CHANCE,
    linestyle="--",
)

sns.scatterplot(
    data=f1_df,
    x="n_terminals",
    y="macro_f1_score",
    style="model",
    hue="model",
    palette=PALETTE_MODEL,
    ax=ax,
)

# Rebuild the legend with spacer rows between groups of entries
legend_format(
    keys=f1_df["model"].unique(),
    ax=ax,
)

_ = ax.set_xscale("log")

_ = ax.set_ylim(-0.02, 1.02)
_ = ax.set_xlabel("# of Terminals")
_ = ax.set_ylabel("Macro F1 Score")

plt.savefig(
    FIGURES_DIR / "n_terminals_vs_macro_f1.pdf",
    bbox_inches="tight",
)

In [None]:
# Scatter of macro F1 against the mean_positive_depth statistic
fig = plt.figure(figsize=(5, 3.5))
ax = fig.add_subplot(111)

# Dashed at-chance baseline
_ = ax.axhline(
    y=0.5,
    color=COLOR_AT_CHANCE,
    alpha=ALPHA_AT_CHANCE,
    linestyle="--",
)

sns.scatterplot(
    data=f1_df,
    x="mean_positive_depth",
    y="macro_f1_score",
    style="model",
    hue="model",
    palette=PALETTES["model"],
    ax=ax,
)

# Rebuild the legend with spacer rows between groups of entries
legend_format(
    keys=f1_df["model"].unique(),
    ax=ax,
)

_ = ax.set_ylim(-0.02, 1.02)
_ = ax.set_xlabel("Mean Parse Depth")
_ = ax.set_ylabel("Macro F1 Score")

plt.savefig(
    FIGURES_DIR / "mean_parse_depth_vs_macro_f1.pdf",
    bbox_inches="tight",
)

In [None]:
# Scatter of macro F1 against coverage, with o3 and o4-mini emphasized
fig = plt.figure(figsize=(5, 3.5))
ax = fig.add_subplot(111)

# Dashed at-chance baseline
_ = ax.axhline(
    y=0.5,
    color=COLOR_AT_CHANCE,
    alpha=ALPHA_AT_CHANCE,
    linestyle="--",
)

sns.scatterplot(
    data=f1_df,
    x="coverage",
    y="macro_f1_score",
    style="model",
    hue="model",
    palette=PALETTES["model"],
    ax=ax,
)

# Fade every model's markers to alpha 0.2 except the highlighted ones
filter_by_alpha(
    keys=f1_df["model"].unique(),
    highlight=["o3", "o4-mini"],
    alpha=0.2,
    ax=ax,
)

# Rebuild the legend with spacer rows between groups of entries
legend_format(
    keys=f1_df["model"].unique(),
    ax=ax,
)

_ = ax.set_ylim(-0.02, 1.02)
_ = ax.set_xlabel("Coverage")
_ = ax.set_ylabel("Macro F1 Score")

plt.savefig(
    FIGURES_DIR / "coverage_vs_macro_f1.pdf",
    bbox_inches="tight",
)

In [None]:
# Scatter of macro F1 against the number of nonlexical productions
# (log-scaled x axis), with o3 and gemma-3-1b emphasized
fig = plt.figure(figsize=(5, 3.5))
ax = fig.add_subplot(111)

# Dashed at-chance baseline
_ = ax.axhline(
    y=0.5,
    color=COLOR_AT_CHANCE,
    alpha=ALPHA_AT_CHANCE,
    linestyle="--",
)

sns.scatterplot(
    data=f1_df,
    x="n_nonlexical_productions",
    y="macro_f1_score",
    style="model",
    hue="model",
    palette=PALETTES["model"],
    ax=ax,
)

# Fade every model's markers to alpha 0.2 except the highlighted ones
filter_by_alpha(
    keys=f1_df["model"].unique(),
    highlight=["o3", "gemma-3-1b"],
    alpha=0.2,
    ax=ax,
)

# Rebuild the legend with spacer rows between groups of entries
legend_format(
    keys=f1_df["model"].unique(),
    ax=ax,
)

_ = ax.set_xscale("log")

_ = ax.set_ylim(-0.02, 1.02)
_ = ax.set_xlabel("# of Nonlexical Productions  [log scale]")
_ = ax.set_ylabel("Macro F1 Score")

plt.savefig(
    FIGURES_DIR / "n_nonlexical_productions_vs_macro_f1.pdf",
    bbox_inches="tight",
)

In [None]:
# Mean macro F1 per model, restricted to rows with more than 100 nonlexical
# productions
(
    f1_df[f1_df.n_nonlexical_productions > 100]
    .groupby("model", observed=False)["macro_f1_score"]
    .mean()
)

In [None]:
# One small line plot per model: mean accuracy against sample length, with a
# separate line per ground-truth sample type and standard-error bands
g = (
    sns.relplot(
        data=accuracy_df,
        kind="line",
        x="sample.length",
        y="correct",
        hue="sample.type.ground_truth",
        palette=PALETTES["sample_type"],
        errorbar="se",
        legend=None,
        col="model",
        height=2,
        aspect=0.7,
    )
    .set_titles("{col_name}")
    .set_axis_labels("", "Mean Accuracy")
    .tight_layout(w_pad=0)
)

g.set(xticks=[0, 25, 50])

# Only the first panel carries an x-axis label
g.axes.flat[0].set_xlabel("Sample Length")

# Dashed at-chance baseline in every panel
for ax in g.axes.flat:
    _ = ax.axhline(
        y=0.5,
        color=COLOR_AT_CHANCE,
        alpha=ALPHA_AT_CHANCE,
        linestyle="--",
    )

# Add "positive" and "negative" labels around the final xy coords in the first plot
# (lines[0] is labelled "positive" and lines[1] "negative", in place of a legend)
for ax in [g.axes.flat[0]]:
    for i, label in enumerate(["positive", "negative"]):
        x_coord = ax.lines[i].get_xdata()[-1]
        y_coord = ax.lines[i].get_ydata()[-1] + (0.02 if label == "positive" else -0.15)
        ax.text(
            x_coord,
            y_coord,
            label,
            ha="right",
            va="bottom",
            fontweight="bold",
            fontsize=9,
            color=darken(PALETTES["sample_type"][label], by=0.2),
        )


plt.savefig(
    FIGURES_DIR / "accuracy_by_sample_length.pdf",
    bbox_inches="tight",
)

In [None]:
# Share of each predicted label within every (model, sample length) group
preds_df = (
    accuracy_df.groupby(["model", "sample.length"], observed=False)[
        "sample.type.predicted"
    ]
    .value_counts(normalize=True)
    .reset_index()
)


# One small line plot per model: predicted-label share against sample length
g = (
    sns.relplot(
        data=preds_df,
        kind="line",
        x="sample.length",
        y="proportion",
        hue="sample.type.predicted",
        palette=PALETTES["sample_type"],
        errorbar="se",
        legend=None,
        col="model",
        height=2,
        aspect=0.7,
    )
    .set_titles("{col_name}")
    .set_axis_labels("", "Predicted Sample Type")
    .tight_layout(w_pad=0)
)

g.set(xticks=[0, 25, 50])

# Only the first panel carries an x-axis label
g.axes.flat[0].set_xlabel("Sample Length")

# Dashed reference line at 0.5 in every panel
for ax in g.axes.flat:
    _ = ax.axhline(
        y=0.5,
        color=COLOR_AT_CHANCE,
        alpha=ALPHA_AT_CHANCE,
        linestyle="--",
    )

# Add "positive", "negative" and "unknown" labels around the final xy coords in the first plot
# (lines[0..2] are labelled in that order, in place of a legend)
for ax in [g.axes.flat[0]]:
    for i, label in enumerate(["positive", "negative", "unknown"]):
        x_coord = ax.lines[i].get_xdata()[-1]
        y_coord = ax.lines[i].get_ydata()[-1] + 0.02
        ax.text(
            x_coord,
            y_coord,
            label,
            ha="right",
            va="bottom",
            fontweight="bold",
            fontsize=9,
            color=darken(PALETTES["sample_type"][label], by=0.2),
        )

plt.savefig(
    FIGURES_DIR / "predicted_sample_type_by_sample_length.pdf",
    bbox_inches="tight",
)

In [None]:
# Single line plot of mean accuracy against sample length, one line per model
# with standard-error bands
fig = plt.figure(figsize=(5, 3.5))
ax = fig.add_subplot(111)

_ = sns.lineplot(
    data=accuracy_df,
    x="sample.length",
    y="correct",
    hue="model",
    style="model",
    palette=PALETTE_MODEL,
    errorbar="se",
    ax=ax,
)

# Dashed at-chance baseline
_ = ax.axhline(
    y=0.5,
    color=COLOR_AT_CHANCE,
    alpha=ALPHA_AT_CHANCE,
    linestyle="--",
)

# Rebuild the legend with spacer rows between groups of entries
legend_format(
    keys=f1_df["model"].unique(),
    ax=ax,
)

# Leave the lower y limit automatic; cap the upper limit just above 1
_ = ax.set_ylim(None, 1.02)
_ = ax.set_xlabel("Sample Length")
_ = ax.set_ylabel("Mean Accuracy")

plt.savefig(
    FIGURES_DIR / "accuracy_by_sample_length_by_model.pdf",
    bbox_inches="tight",
)

In [None]:
# Plot the entropy of model predictions by sample length. Entropy is computed over the "sample.type.predicted" column of accuracy_df, grouped by model and sample.length, i.e. over each group's distribution of predicted labels.

import numpy as np


def binary_entropy(series) -> float:
    """
    Return the binary (two-outcome) entropy, in bits, of a series of labels.

    The series is reduced to a two-outcome distribution: its most frequent
    value (``value_counts`` sorts counts in descending order, so this is
    ``counts[0]``) versus everything else. If the series holds more than two
    distinct values, all but the most frequent one are pooled together.

    Args:
        series: A pandas Series of labels. Missing values are ignored, since
            ``value_counts`` drops them by default.

    Returns:
        The entropy in bits, between 0.0 and 1.0. Returns NaN when the series
        has no non-missing observations, because the entropy of an empty
        sample is undefined.
    """
    counts = series.value_counts().values
    total = counts.sum()

    # Empty (or all-missing) input: there is no distribution to measure.
    # Guarding here avoids an IndexError on counts[0] for plain series and a
    # 0/0 division warning for categorical series with only zero counts.
    if total == 0:
        return float("nan")

    p_top = counts[0] / total
    # log2(0) is undefined; a degenerate distribution has zero entropy.
    if p_top == 0 or p_top == 1:
        return 0.0

    return float(-(p_top * np.log2(p_top) + (1 - p_top) * np.log2(1 - p_top)))


# One entropy value per (model, sample.length) group, computed by applying
# binary_entropy to the group's predicted sample types.
predictions_by_group = accuracy_df.groupby(
    ["model", "sample.length"], observed=False
)["sample.type.predicted"]
entropy_df = predictions_by_group.apply(binary_entropy).reset_index(name="entropy")


# Entropy as a function of sample length, one line per model.
fig, ax = plt.subplots(figsize=(5, 3.5))

_ = sns.lineplot(
    ax=ax,
    data=entropy_df,
    x="sample.length",
    y="entropy",
    hue="model",
    style="model",
    errorbar="se",
    palette=PALETTE_MODEL,
)

# NOTE(review): the legend keys are taken from f1_df rather than entropy_df —
# confirm that both frames contain the same set of models.
legend_format(keys=f1_df["model"].unique(), ax=ax)

# _ = ax.set_xscale("log")

_ = ax.set(
    xlabel="Sample Length",
    ylabel="Entropy of Model Predictions",
)

output_path = FIGURES_DIR / "entropy_by_sample_length_by_model.pdf"
plt.savefig(output_path, bbox_inches="tight")

In [None]:
# Show, for the "gemma-3-1b" rows only, the normalized frequency of each
# predicted sample type within every sample length.
(
    accuracy_df.loc[accuracy_df["model"] == "gemma-3-1b"]
    .groupby("sample.length", observed=False)["sample.type.predicted"]
    .value_counts(normalize=True)
    .reset_index()
)

## Histogram of Sample Lengths

In [None]:
# Histogram of sample lengths with unit-width bins, colored by ground-truth
# sample type.
fig, ax = plt.subplots(figsize=(6, 3))

sns.histplot(
    ax=ax,
    data=response_df,
    x="sample.length",
    hue="sample.type.ground_truth",
    palette=PALETTE_SAMPLE_TYPE,
    binwidth=1,
    alpha=0.8,
)

# Replace the default legend title and x-axis label with readable text.
_ = ax.get_legend().set_title("Sample type")
_ = ax.set_xlabel("Sample length")

output_path = FIGURES_DIR / "sample_length_histogram.pdf"
plt.savefig(output_path, bbox_inches="tight")

In [None]:
# Count the samples of each ground-truth type and convert the counts to
# proportions of the total.
sample_counts_df = (
    response_df.groupby(["sample.type.ground_truth"], observed=False)[
        "sample.type.ground_truth"
    ]
    .count()
    .reset_index(name="count")
)
total_samples = sample_counts_df["count"].sum()
sample_counts_df["proportion"] = sample_counts_df["count"] / total_samples

# One bar per ground-truth sample type, colored by that same type.
g = sns.catplot(
    data=sample_counts_df,
    kind="bar",
    x="sample.type.ground_truth",
    y="proportion",
    hue="sample.type.ground_truth",
    palette=PALETTE_SAMPLE_TYPE,
    height=3,
    aspect=0.8,
).set_axis_labels("", "Proportion of Samples")

for ax in g.axes.flat:
    # Write each bar's value in white, just inside the top of the bar.
    for container in ax.containers:
        for rect in container:
            height = rect.get_height()
            # Skip empty or placeholder bars (zero or NaN height): a label
            # positioned 0.02 below their top would land outside the axes.
            if not height > 0:
                continue
            x_coord = rect.get_x() + rect.get_width() / 2.0

            ax.text(
                x_coord,
                height - 0.02,
                f"{height:.2f}",
                ha="center",
                va="top",
                fontsize=9,
                color="white",
                fontweight="bold",
            )

    # Outline every bar with the shared edge style.
    for bar in ax.patches:
        bar.set_edgecolor(BAR_EDGE_COLOR)
        bar.set_linewidth(BAR_EDGE_WIDTH)

# Dashed horizontal reference line at y=0.5, styled with the at-chance constants.
_ = g.ax.axhline(
    y=0.5,
    color=COLOR_AT_CHANCE,
    alpha=ALPHA_AT_CHANCE,
    linestyle="--",
)

# Proportions live in [0, 1]; fix the axis to that full range.
_ = g.ax.set_ylim(0, 1)

plt.savefig(
    FIGURES_DIR / "sample_type_proportions.pdf",
    bbox_inches="tight",
)

In [None]:
# Responses of the "gemma-3-4b" model, ordered by grammar file. The
# "grammar_file" column is used only as the sort key and is not kept.
is_gemma_3_4b = response_df["model"] == "gemma-3-4b"
kept_columns = ["model_response", "sample.type.predicted", "sample.length", "correct"]
gemma_responses_df = (
    response_df[is_gemma_3_4b]
    .sort_values(by="grammar_file")[kept_columns]
)

In [None]:
# Print the model response at position 50 (0-based) among the rows whose
# predicted sample type is "unknown".
is_unknown = gemma_responses_df["sample.type.predicted"] == "unknown"
print(gemma_responses_df.loc[is_unknown, "model_response"].iloc[50])

In [None]:
# Mean length (via len) of the model responses.
# Recorded result with max_new_tokens=2048: np.float64(376.80102040816325)

gemma_responses_df["model_response"].apply(len).mean()