# Analysis for Prompting Evals

In [None]:
import colorsys
import json
import os
from collections import defaultdict
from typing import Any

import matplotlib as mpl
import matplotlib.gridspec as gs
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyrootutils
import regex as re
import seaborn as sns
import sklearn.metrics as sk_metrics
import statsmodels.api as sm
import statsmodels.formula.api as smf
from IPython.display import display
from statsmodels.stats.outliers_influence import variance_inflation_factor

In [None]:
# Resolve the repository root with pyrootutils: search from the current
# working directory for the `.project-root` marker.
PROJECT_ROOT = pyrootutils.find_root(
    indicator=".project-root",
    search_from=os.path.abspath(""),
)

## Plotting Setup

In [None]:
FIGURES_DIR = PROJECT_ROOT / "notebooks" / "figures"
# plt.savefig() does not create missing parent directories, so make sure the
# output directory exists before any figure below is written into it.
FIGURES_DIR.mkdir(parents=True, exist_ok=True)

# Presumably the text width of the target paper, in inches (for figure sizing).
PAPER_WIDTH_IN = 5.5

# Matplotlib rc overrides for paper figures.
# NOTE(review): nothing in this notebook applies these (no rcParams.update /
# sns.set_theme(rc=rcs)) — confirm whether they should be applied.
rcs = {
    "font.size": 10.0,
    "axes.labelsize": "small",
    "axes.titlesize": "small",
    "xtick.labelsize": "x-small",
    "ytick.labelsize": "x-small",
}


def darken(
    color: str
    | tuple[float, float, float]
    | dict[str, Any]
    | sns.palettes._ColorPalette,
    by: float = 0.2,
):
    """
    Darken a color, or every color in a palette, by scaling its HLS lightness.

    Parameters
    ----------
    color:
        One of:
        - a single color: any string seaborn understands, or an RGB(A) tuple
          whose components are in [0, 1] or [0, 255];
        - a ``{name: color}`` dict;
        - a seaborn ``_ColorPalette``.
    by:
        Fraction by which lightness is reduced, clamped to [0, 1]. ``0`` leaves
        the color unchanged; ``1`` yields black.

    Returns
    -------
    The darkened color(s), in the same kind of container as ``color``:
    - an RGB tuple in [0, 1] for a single color;
    - a dict with the same keys for a dict;
    - a ``_ColorPalette`` for a palette.
    Any alpha component is discarded.
    """

    def _darken_color(c: str | tuple[float, float, float], by: float):
        by = min(max(0, by), 1)
        pct_darken = 1 - by

        if isinstance(c, str):
            c = sns.color_palette([c])[0]

        # colorsys expects components in [0, 1]; rescale 0-255 style input.
        rgb = tuple(c[:3])
        if any(v > 1 for v in rgb):
            rgb = tuple(v / 255 for v in rgb)

        hue, lightness, saturation = colorsys.rgb_to_hls(*rgb)
        # Darken by reducing only the lightness; hue and saturation are kept.
        return colorsys.hls_to_rgb(hue, lightness * pct_darken, saturation)

    if isinstance(color, dict):
        # If color is a dictionary, assume it's a palette
        # and darken each color in the palette
        return {k: _darken_color(v, by) for k, v in color.items()}
    elif isinstance(color, sns.palettes._ColorPalette):
        colors = [_darken_color(c, by) for c in color]
        return sns.palettes._ColorPalette(colors)
    else:
        return _darken_color(color, by)


# For heatmaps, correlations, -1 to 1 scales, etc
CMAP_HEATMAP = "vlag"

# For any plots where color differentiates sample type
PALETTE_SAMPLE_TYPE = {
    "positive": darken("#ffcc66"),
    "negative": "#5c5cff",
    "unknown": "#C41E3A",
}

# Source palettes, generated once each instead of once per dictionary entry.
_gpt_shades = sns.cubehelix_palette(start=0.5, rot=-0.5, n_colors=4)
_o_series_shades = sns.color_palette("YlOrBr", n_colors=2)
_gemma_shades = sns.cubehelix_palette(n_colors=5)
_dsr1_shades = sns.color_palette("cool", n_colors=2)
_terrain_shades = sns.color_palette("terrain", n_colors=2, desat=0.8)
_earth_shades = sns.color_palette("gist_earth", n_colors=4)

# For any plots where color differentiates model
PALETTE_MODEL = darken(
    {
        "gpt-4.1-nano": _gpt_shades[0],
        "gpt-4.1-mini": _gpt_shades[1],
        "gpt-4.1": _gpt_shades[2],
        "o4-mini": _o_series_shades[0],
        "o3": _o_series_shades[1],
        "gemma-3-1b": _gemma_shades[0],
        "gemma-3-4b": _gemma_shades[1],
        "gemma-3-12b": _gemma_shades[2],
        "gemma-3-27b": _gemma_shades[3],
        "DSR1-7B": _dsr1_shades[0],
    }
)

PALETTE_SCORE = {
    "Weighted F1": _terrain_shades[0],
    "Macro F1": _terrain_shades[1],
}

PALETTE_STRATEGY = {
    "rule-based": "#942822",
    "heuristic": "#FFE7CE",
    "code": _earth_shades[2],
    "unknown": _earth_shades[3],
}
# Backward-compatible alias for the original, misspelled name.
PALETTE_STRAGETY = PALETTE_STRATEGY

del (
    _gpt_shades,
    _o_series_shades,
    _gemma_shades,
    _dsr1_shades,
    _terrain_shades,
    _earth_shades,
)

# Registry used to look up a palette by the keys it contains.
PALETTES = {
    "model": PALETTE_MODEL,
    "sample_type": PALETTE_SAMPLE_TYPE,
    "score": PALETTE_SCORE,
    "strategy": PALETTE_STRATEGY,
}

MODEL_COLOR = "#4CA970"

# For marking at-chance baselines
COLOR_AT_CHANCE = "#ff0000"  # Red
ALPHA_AT_CHANCE = 0.5

# Bar Chart settings
BAR_EDGE_COLOR = "black"
BAR_EDGE_WIDTH = 0.8


def display_palette(palette: dict[str, Any] | sns.palettes._ColorPalette):
    """
    Render a palette as color swatches in the notebook.

    A seaborn palette is displayed as-is. A ``{name: color}`` dict is
    converted to a seaborn palette of its values (keys are not shown).
    """
    swatches = (
        palette
        if isinstance(palette, sns.palettes._ColorPalette)
        else sns.color_palette(list(palette.values()))
    )
    display(swatches)


def filter_by_alpha(
    keys: list[str],
    ax,
    palette: dict[str, Any] | None = None,
    alpha=0.3,
    highlight: str | list[str] | None = None,
):
    """
    Fade the faces of the first collection on ``ax``, keeping highlighted keys opaque.

    Every face in ``ax.collections[0]`` (presumably a scatter plot's markers)
    gets opacity ``alpha``. Faces whose RGB matches ``palette[key]`` for a key
    in both ``keys`` and ``highlight`` get opacity 1.

    Parameters
    ----------
    keys:
        Palette keys whose colors appear in the collection.
    ax:
        A matplotlib Axes with at least one collection.
    palette:
        ``{key: color}`` mapping; colors may be RGB(A) tuples or any
        matplotlib color string (e.g. hex). When None, the first entry of
        ``PALETTES`` containing every key in ``keys`` is used.
    alpha:
        Opacity applied to non-highlighted faces.
    highlight:
        Key or keys to keep fully opaque.

    Raises
    ------
    ValueError
        If ``palette`` is None and no palette in ``PALETTES`` contains all ``keys``.
    """
    if highlight is None:
        highlighted = set()
    elif isinstance(highlight, str):
        highlighted = {highlight}
    else:
        highlighted = set(highlight)

    if palette is None:
        # Find the registered palette whose keys cover everything requested.
        palette = next(
            (p for p in PALETTES.values() if all(k in p for k in keys)),
            None,
        )
        if palette is None:
            raise ValueError(f"No matching palette found for keys: {keys}")

    collection = ax.collections[0]
    face_colors = collection.get_facecolors()
    face_colors[:, 3] = alpha
    for key in keys:
        # Normalise hex strings / tuples to RGB floats, and compare with a
        # tolerance rather than exact float equality.
        key_rgb = np.asarray(mpl.colors.to_rgb(palette[key]))
        matches = np.all(np.isclose(face_colors[:, :3], key_rgb), axis=1)
        face_colors[matches, 3] = 1 if key in highlighted else alpha
    collection.set_facecolor(face_colors)


def legend_format(
    ax: mpl.axes._axes.Axes | sns.FacetGrid,
    keys: list[str] | None = None,
    title: str | None = None,
    **kwargs,
):
    """
    Rebuild an axes legend without a frame and move it via ``sns.move_legend``.

    Parameters
    ----------
    ax:
        The target Axes, or a single-axes FacetGrid. For a FacetGrid, its own
        legend (if any) is removed and the legend is rebuilt on ``fg.ax``.
    keys:
        If given, blank spacer entries are inserted into the legend to group
        it. Only key sets containing ``"gpt-4.1"`` are supported; they get
        spacers at positions 3 and 6.
    title:
        Optional legend title.
    **kwargs:
        Forwarded to ``sns.move_legend``. Defaults are ``loc="upper left"``
        and ``bbox_to_anchor=(1, 1)``.

    Raises
    ------
    ValueError
        If ``keys`` is given but no spacing layout is known for it.
    """
    if isinstance(ax, sns.FacetGrid):
        fg = ax
        ax = fg.ax

        # Drop the FacetGrid's figure-level legend so only the axes legend remains.
        if fg.legend is not None:
            fg.legend.remove()

    # Legend Formatting
    handles, labels = ax.get_legend_handles_labels()

    if keys is not None:
        if "gpt-4.1" in keys:
            # Presumably separates the gpt-4.1 family, the o-series, and the
            # remaining models (assumes legend entries follow PALETTE_MODEL order).
            spacing_locs = [3, 6]
        else:
            raise ValueError(f"Unsure how to space legend for {keys=}")

        spacer = mpatches.Patch(alpha=0, linewidth=0)
        for sloc in spacing_locs:
            handles.insert(sloc, spacer)
            labels.insert(sloc, "")

    ax.legend(handles, labels)
    legend = ax.get_legend()
    legend.set_frame_on(False)

    if title is not None:
        legend.set_title(title)

    kwargs.setdefault("loc", "upper left")
    kwargs.setdefault("bbox_to_anchor", (1, 1))
    sns.move_legend(ax, **kwargs)


# MARK: Printing, experimentation, etc
# Preview the F1-score palette source before and after darkening, then the
# model and sample-type palettes.
_terrain_preview = sns.color_palette("terrain", n_colors=2, desat=0.8)
display(_terrain_preview)
display(darken(_terrain_preview, by=0.2))

for _preview_palette in (PALETTE_MODEL, PALETTE_SAMPLE_TYPE):
    display_palette(_preview_palette)

del _terrain_preview, _preview_palette

In [None]:
# Show the sample-type -> color mapping as this cell's output.
PALETTE_SAMPLE_TYPE

In [None]:
YES_RE = re.compile(r"[^a-zA-Z]*\b(yes|no)\b[^a-zA-Z]*", re.IGNORECASE)


def extract_content(choices_list: list) -> str:
    try:
        return choices_list[0]["message"]["content"]
    except Exception as e:
        print(choices_list)


def extract_prediction(response: str) -> str:
    # get a list of all matches to YES_RE in `response`; take the last match
    # and check if it is a "yes" or "no" response

    matches = YES_RE.findall(response)
    if len(matches) == 0:
        return "unknown"
    else:
        last_match = matches[-1]
        if last_match.lower() == "yes":
            return "positive"
        else:
            return "negative"


extract_prediction("```yes```")

## Load data from disk

In [None]:
GRAMMARS_DIR = PROJECT_ROOT / "data" / "grammars"

small_subset_file = PROJECT_ROOT / "data" / "small_subset.txt"
large_subset_file = PROJECT_ROOT / "data" / "large_subset.txt"

response_df_file = PROJECT_ROOT / "data" / "response_df.feather"

# Load previously processed responses from the Feather cache, if present.
response_df = None
response_full_df = None
new_response_df = None
if os.path.exists(response_df_file):
    print(f"Loading `response_df` from {response_df_file}")
    response_df = pd.read_feather(response_df_file)
else:
    print(f"No store found at {response_df_file}")

print("Loading input and results files")

inputs_file_pattern = "*_inputs.jsonl"
results_file_pattern = "*_results.jsonl"

batch_id_re = re.compile(r"^(batch_\w+)_")

# Batch ids already in the cache. A set keeps the per-file membership tests
# below O(1) instead of scanning an array each time.
old_batches = set()
if response_df is not None:
    old_batches = set(response_df["batch_id"].unique())


def _batch_id_of(path) -> str | None:
    """Return the ``batch_...`` id at the start of ``path.name``, or None if absent."""
    match = batch_id_re.search(path.name)
    return match.group(1) if match else None


def _load_batch_jsonl(files) -> pd.DataFrame:
    """Read JSONL batch files into one frame, flatten nested fields, and tag rows with batch_id.

    Nested fields are flattened to dotted column names; ``batch_id`` comes
    from each file's name.
    """
    frames = []
    for path in files:
        raw = pd.read_json(path, lines=True)
        # Round-trip through JSON so json_normalize can flatten nested dicts.
        flat = pd.json_normalize(json.loads(raw.to_json(orient="records")))
        flat["batch_id"] = _batch_id_of(path)
        frames.append(flat)
    return pd.concat(frames, ignore_index=True)


# Only grammar directories listed in the small or large subset files are kept.
small_subset = set(small_subset_file.read_text().strip().split("\n"))
large_subset = set(large_subset_file.read_text().strip().split("\n"))
keep_files = small_subset.union(large_subset)


def _is_new_subset_file(path) -> bool:
    """True if ``path`` sits in a kept grammar directory and belongs to an uncached batch."""
    batch_id = _batch_id_of(path)
    return (
        path.parent.name in keep_files
        and batch_id is not None
        and batch_id not in old_batches
    )


# find all new input/results files in subdirectories of GRAMMARS_DIR
input_files = [
    f for f in GRAMMARS_DIR.rglob(inputs_file_pattern) if _is_new_subset_file(f)
]
results_files = [
    f for f in GRAMMARS_DIR.rglob(results_file_pattern) if _is_new_subset_file(f)
]

print(
    f"Found {len(input_files)} new input files and {len(results_files)} new results files"
)

if input_files and results_files:
    inputs_df = _load_batch_jsonl(input_files)
    results_df = _load_batch_jsonl(results_files)

    # Merge inputs and results on the batch_id and custom_id
    response_full_df = results_df.merge(
        inputs_df[
            [
                "custom_id",
                "batch_id",
                "body.metadata.sample_type",  # ground-truth label for sample
                "body.metadata.sample",  # the sample itself
                "body.metadata.grammar_file",  # grammar file used
                "body.metadata.model",  # model used
                "body.metadata.n_shots",  # n_shots used
            ]
        ],
        on=["batch_id", "custom_id"],
    )

    response_full_df = response_full_df.rename(
        columns={
            "body.metadata.sample_type": "sample.type.ground_truth",
            "body.metadata.sample": "sample",
            "body.metadata.grammar_file": "grammar_file",
            "body.metadata.model": "model",
            "body.metadata.n_shots": "n_shots",
            "response.body.usage.prompt_tokens": "prompt_tokens",
            "response.body.usage.completion_tokens": "completion_tokens",
            "response.body.usage.total_tokens": "total_tokens",
            "response.body.usage.completion_tokens_details.reasoning_tokens": "reasoning_tokens",
        }
    )

    for tok_col in [
        "prompt_tokens",
        "completion_tokens",
        "total_tokens",
        "reasoning_tokens",
    ]:
        if tok_col in response_full_df.columns:
            response_full_df[tok_col] = response_full_df[tok_col].fillna(0)
        else:
            # A batch may not report this usage field at all; record 0 instead
            # of failing with a KeyError.
            response_full_df[tok_col] = 0

if (response_df is not None) and (response_full_df is not None):
    # Defensive: never re-add rows from a batch that is already cached.
    response_full_df = response_full_df[
        ~response_full_df["batch_id"].isin(old_batches)
    ].copy()
    print(f"Found {len(response_full_df)} new responses to add to response_df")

if (response_full_df is not None) and (len(response_full_df) > 0):
    print("Processing new responses")
    response_full_df["model_response"] = response_full_df[
        "response.body.choices"
    ].apply(extract_content)

    new_response_df = response_full_df[
        [
            "sample",
            "sample.type.ground_truth",
            "model_response",
            "grammar_file",
            "model",
            "n_shots",
            "batch_id",
            "prompt_tokens",
            "completion_tokens",
            "total_tokens",
            "reasoning_tokens",
        ]
    ].copy()

    del response_full_df

    # Drop ROWS with missing values. A failed request leaves `model_response`
    # as None; dropping columns instead would remove columns that every step
    # below requires.
    new_response_df = new_response_df.dropna()

    new_response_df["sample.type.predicted"] = new_response_df["model_response"].apply(
        extract_prediction
    )
    new_response_df["sample.length"] = new_response_df["sample"].apply(
        lambda s: len(str(s).split(" "))
    )
    new_response_df["correct"] = (
        new_response_df["sample.type.ground_truth"]
        == new_response_df["sample.type.predicted"]
    )
    new_response_df = new_response_df.dropna()
    new_response_df["n_shots"] = pd.Categorical(
        # Cast so numeric shot counts also match the string categories.
        new_response_df["n_shots"].astype(str),
        categories=["0", "2", "4", "8", "16", "32"],
        ordered=True,
    )
    new_response_df["sample.type.ground_truth"] = pd.Categorical(
        new_response_df["sample.type.ground_truth"],
        categories=["positive", "negative"],
        ordered=True,
    )
    new_response_df["sample.type.predicted"] = pd.Categorical(
        new_response_df["sample.type.predicted"],
        categories=["positive", "negative", "unknown"],
        ordered=True,
    )

    new_response_df["model"] = new_response_df["model"].str.replace(
        "_", "/", regex=False
    )

    print(new_response_df["model"].unique())

    # Shorten model names. Names missing from this mapping become NaN and are
    # dropped when the store is saved below.
    model_short_names = {
        "gpt-4.1-nano": "gpt-4.1-nano",
        "gpt-4.1-mini": "gpt-4.1-mini",
        "gpt-4.1": "gpt-4.1",
        "o4-mini": "o4-mini",
        "o3": "o3",
        "google/gemma-3-1b-it": "gemma-3-1b",
        "google/gemma-3-4b-it": "gemma-3-4b",
        "google/gemma-3-12b-it": "gemma-3-12b",
        "google/gemma-3-27b-it": "gemma-3-27b",
        "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B": "DSR1-7B",
    }
    unmapped_models = set(new_response_df["model"].unique()) - set(model_short_names)
    if unmapped_models:
        print(f"Warning: no short name for models {sorted(unmapped_models)}")
    new_response_df["model"] = new_response_df["model"].map(model_short_names)

    new_response_df["model_type"] = pd.Categorical(
        new_response_df["model"].map(
            {
                "gpt-4.1-nano": "regular",
                "gpt-4.1-mini": "regular",
                "gpt-4.1": "regular",
                "o4-mini": "thinking",
                "o3": "thinking",
                "gemma-3-1b": "regular",
                "gemma-3-4b": "regular",
                "gemma-3-12b": "regular",
                "gemma-3-27b": "regular",
                "DSR1-7B": "thinking",
            }
        ),
        categories=["regular", "thinking"],
        ordered=True,
    )

    # Add all the new data to the response_df
    if response_df is None:
        response_df = new_response_df.copy()
    elif len(new_response_df) > 0:
        response_df = pd.concat([response_df, new_response_df], ignore_index=True)

    response_df["model"] = pd.Categorical(
        response_df["model"],
        categories=[
            "gpt-4.1-nano",
            "gpt-4.1-mini",
            "gpt-4.1",
            "o4-mini",
            "o3",
            "gemma-3-1b",
            "gemma-3-4b",
            "gemma-3-12b",
            "gemma-3-27b",
            "DSR1-7B",
        ],
        ordered=True,
    )

    print("Saving `response_df` to Feather store")
    response_df.dropna().to_feather(response_df_file)

del new_response_df

if response_df is None:
    print("No responses available: no Feather store and no new batch files found")
else:
    response_df.info()

In [None]:
# results_df.info()
# response_df["model"].unique()

In [None]:
# response_df[response_df["model"] != "gemma-3-1b"].dropna().to_feather(response_df_file)

In [None]:
# Cross-check the NEW batch files found above (cached batches were filtered
# out): report batch ids that have an inputs file but no results file, and
# batch ids that have a results file but no inputs file.
input_file_by_batch = {}
for f in input_files:
    input_file_by_batch.setdefault(batch_id_re.search(f.name).group(1), f)

results_file_by_batch = {}
for f in results_files:
    results_file_by_batch.setdefault(batch_id_re.search(f.name).group(1), f)

input_batch_ids = set(input_file_by_batch)
results_batch_ids = set(results_file_by_batch)
missing_results = input_batch_ids - results_batch_ids
missing_inputs = results_batch_ids - input_batch_ids

print(f"Found {len(missing_results)} batches with inputs but no results")
for batch_id in sorted(missing_results):
    print(f"Found input file for batch_id {batch_id}: {input_file_by_batch[batch_id]}")

print(f"Found {len(missing_inputs)} batches with results but no inputs")
for batch_id in sorted(missing_inputs):
    print(
        f"Found results file for batch_id {batch_id}: {results_file_by_batch[batch_id]}"
    )

In [None]:
# List the columns of the combined response table.
response_df.columns

In [None]:
# Responses per (grammar, model) pair; observed=False also lists model
# categories with no responses (count 0).
response_df.groupby(["grammar_file", "model"], observed=False)["sample"].count()

In [None]:
# Number of distinct grammars with responses, per observed model.
response_df.groupby(["model"], observed=True)["grammar_file"].nunique()

Load the per-grammar and per-sample statistics, compute F1 scores for each (n_shots, model, grammar) group, and annotate those scores with the statistics.

In [None]:
grammar_stats_pattern = "grammar_stats.json"
samples_stats_pattern = "filtered_samples_stats.json"


def _load_grammar_dir_stats(pattern: str) -> pd.DataFrame:
    """Collect a stats JSON file from each kept grammar directory into one frame.

    Looks for files named ``pattern`` under GRAMMARS_DIR. Each row is one
    file, with a ``grammar_file`` column holding the file's directory name.
    Files that fail to parse are reported and skipped.
    """
    records = []
    for path in GRAMMARS_DIR.rglob(pattern):
        if path.parent.name not in keep_files:
            continue
        try:
            record = json.loads(path.read_text())
        except json.JSONDecodeError:
            print(f"Error reading {path}")
            continue
        record["grammar_file"] = path.parent.name
        records.append(record)
    return pd.DataFrame(records)


grammar_stats_df = _load_grammar_dir_stats(grammar_stats_pattern)
samples_stats_df = _load_grammar_dir_stats(samples_stats_pattern)

F1_GROUP_KEYS = ["n_shots", "model", "model_type", "grammar_file"]


def _grouped_f1(average: str, name: str) -> pd.DataFrame:
    """One row per observed F1_GROUP_KEYS combination, with its F1 score in column ``name``.

    ``average`` is passed straight to ``sk_metrics.f1_score``.

    NOTE(review): predictions may include "unknown", which sklearn then treats
    as an extra class when averaging:
    - macro F1 averages it in (its F1 is 0 whenever it is predicted);
    - weighted F1 gives it zero weight, since it never occurs in the ground truth.
    Pass ``labels=["positive", "negative"]`` if that is not intended.
    """
    return (
        # observed=True: only score (n_shots, model, model_type, grammar)
        # combinations that actually have responses.
        response_df.groupby(F1_GROUP_KEYS, observed=True)
        .apply(
            lambda group: sk_metrics.f1_score(
                group["sample.type.ground_truth"],
                group["sample.type.predicted"],
                average=average,
            ),
            include_groups=False,
        )
        .reset_index(name=name)
    )


f1_df = (
    _grouped_f1("weighted", "weighted_f1_score")
    .join(grammar_stats_df.set_index("grammar_file"), on="grammar_file")
    .join(samples_stats_df.set_index("grammar_file"), on="grammar_file")
)

# Drop statistic columns that are missing for any grammar.
f1_df = f1_df.dropna(axis=1)

for average in ("macro", "micro"):
    score_name = f"{average}_f1_score"
    scores = _grouped_f1(average, score_name).set_index(F1_GROUP_KEYS)[score_name]
    f1_df = f1_df.join(scores, on=F1_GROUP_KEYS)

# Per-response table annotated with the same grammar / sample statistics.
accuracy_df = response_df.join(
    grammar_stats_df.set_index("grammar_file"),
    on="grammar_file",
).join(
    samples_stats_df.set_index("grammar_file"),
    on="grammar_file",
)

del grammar_stats_df, samples_stats_df

f1_df.info()

f1_df_file = PROJECT_ROOT / "data" / "f1_df.feather"
f1_df.to_feather(f1_df_file)

In [None]:
# Summarise the annotated per-response table and persist it next to f1_df.
acc_df_file = PROJECT_ROOT / "data" / "acc_df.feather"
accuracy_df.info()
accuracy_df.to_feather(acc_df_file)

In [None]:
# Display the per-group F1 table with its grammar / sample statistics.
f1_df

## Correlation Analysis

In [None]:
# Pairwise (Pearson) correlations between the two F1 scores and the grammar /
# sample statistics, across all F1 groups.
_corr_columns = [
    "macro_f1_score",
    "weighted_f1_score",
    "n_terminals",
    "n_nonterminals",
    "n_lexical_productions",
    "n_nonlexical_productions",
    "compression_ratio",
    "mean_positive_depth",
    "median_positive_depth",
    "coverage",
]
corr_mat = f1_df.loc[:, _corr_columns].corr()

corr_mat

In [None]:
# Full correlation matrix on a diverging -1..1 color scale, saved as a PDF.
ax = sns.heatmap(corr_mat, vmin=-1, vmax=1, cmap=CMAP_HEATMAP)

_corr_fig_path = FIGURES_DIR / "grammar_stats_correlation_matrix.pdf"
plt.savefig(_corr_fig_path, bbox_inches="tight")

In [None]:
# Correlation of each grammar statistic with the two F1 scores, sorted by the
# macro-F1 correlation. The score rows are dropped by label; positional slicing
# would silently drop a statistic if it outranked weighted F1.
_score_cols = ["macro_f1_score", "weighted_f1_score"]
sliced_corr_mat = (
    corr_mat.loc[:, _score_cols]
    .drop(index=_score_cols)
    .sort_values(by="macro_f1_score", ascending=False)
)

ax = sns.heatmap(
    sliced_corr_mat,
    cmap=CMAP_HEATMAP,
    vmin=-1,
    vmax=1,
)

# Write each coefficient at the centre of its cell.
for (r, c), value in np.ndenumerate(sliced_corr_mat.to_numpy()):
    ax.text(
        c + 0.5,
        r + 0.5,
        f"${value:.2f}$",
        ha="center",
        va="center",
        color="black",
    )

_ = ax.set_title("Grammar Hyperparameter Correlations with F1 Scores", y=1.05)

plt.savefig(
    FIGURES_DIR / "grammar_stats_correlation_matrix_sliced.pdf",
    bbox_inches="tight",
)

In [None]:
# Average accuracy, sample length and grammar statistics per grammar, then
# correlate those per-grammar means. Each column is listed once; a duplicate
# would produce a duplicated row/column in the correlation matrix.
corr_mat_acc = (
    accuracy_df[
        [
            "correct",
            "sample.length",
            "grammar_file",
            "n_terminals",
            "n_nonterminals",
            "n_lexical_productions",
            "n_nonlexical_productions",
            "compression_ratio",
            "mean_positive_depth",
            "median_positive_depth",
            "coverage",
        ]
    ]
    .groupby("grammar_file", observed=False)
    .mean()
    .corr()
)

corr_mat_acc

In [None]:
# Correlation of each grammar statistic with per-grammar accuracy, strongest
# first. The target itself ("correct") and the non-grammar "sample.length" are
# dropped by label. Positional slicing would instead drop whichever statistic
# happened to rank first.
sliced_corr_mat_acc = (
    corr_mat_acc.loc[:, ["correct"]]
    .drop(index=["correct", "sample.length"])
    .sort_values(by="correct", ascending=False)
)

ax = sns.heatmap(
    sliced_corr_mat_acc,
    cmap=CMAP_HEATMAP,
    vmin=-1,
    vmax=1,
)

# Write each coefficient at the centre of its cell.
for (r, c), value in np.ndenumerate(sliced_corr_mat_acc.to_numpy()):
    ax.text(
        c + 0.5,
        r + 0.5,
        f"${value:.2f}$",
        ha="center",
        va="center",
        color="black",
    )

_ = ax.set_title("Grammar Hyperparameter Correlations with Accuracy", y=1.05)

## Multivariate Regression

### Weighted F1 Score

In [None]:
# Regression dataset: f1_df without the identifier / size columns listed here.
_non_regressor_columns = [
    "total_possible_samples",
    "uncompressed_size",
    "compressed_size",
    "grammar_file",
    "grammar_name",
    "n_shots",
]
f1_stats_df = f1_df.drop(columns=_non_regressor_columns)

f1_stats_df.info()

In [None]:
# Regressors for the design matrix X, which the macro-F1 OLS and VIF cells
# below reuse. NOTE: despite the name, these are the exogenous (explanatory)
# variables.
endogenous_vars = [
    "model",
    "n_nonlexical_productions",
]

X = f1_stats_df[endogenous_vars].copy()
X = pd.get_dummies(X, drop_first=True, dtype=int)  # one-hot encode categorical vars
X = sm.add_constant(X)  # add a constant term to the model

Y = f1_stats_df["weighted_f1_score"]

# Weighted F1 regressed on grammar size, with a per-model slope and intercept.
model = smf.ols(
    "weighted_f1_score ~ n_nonlexical_productions * C(model)", data=f1_stats_df
).fit()
print(model.summary())

### Macro F1 Score

In [None]:
# Reuse the design matrix X (constant, model dummies, n_nonlexical_productions;
# no interaction terms) with macro F1 as the response.
Y = f1_stats_df["macro_f1_score"]
_macro_ols = sm.OLS(endog=Y, exog=X)
model = _macro_ols.fit()
print(model.summary())

In [None]:
# Models excluded from this accuracy regression.
_excluded_models = ["gemma-3-12b", "gemma-3-27b"]

# Mean accuracy per (model, grammar, sample length) cell, with the excluded
# models removed from both the rows and the category list.
acc_by_grammars_df = (
    accuracy_df.groupby(
        [
            "model",
            "grammar_file",
            "n_nonlexical_productions",
            "sample.length",
            "compression_ratio",
        ],
        observed=True,
    )["correct"]
    .mean()
    .reset_index(name="accuracy")
    .loc[lambda d: ~d.model.isin(_excluded_models)]
    .assign(
        model=lambda d: d["model"].cat.remove_categories(_excluded_models),
        accuracy_p=lambda d: d["accuracy"] * 100,
    )
    .rename(columns={"sample.length": "sample_length"})
)

# Log-scaled and mean-centred predictors.
centered_df = acc_by_grammars_df.assign(
    log_nlp=lambda d: np.log10(d["n_nonlexical_productions"]),
    log_sl=lambda d: np.log10(d["sample_length"]),
    log_nlp_c=lambda d: d["log_nlp"] - d["log_nlp"].mean(),
    sl_c=lambda d: d["sample_length"] - d["sample_length"].mean(),
)

model = smf.ols(
    formula="accuracy_p ~ log_nlp_c * C(model) + sl_c * C(model) + log_nlp_c * sl_c",
    data=centered_df,
).fit()
print(model.summary())

In [None]:
# Same idea with log-scaled, centred sample length. Note: despite the name,
# this is an ordinary least-squares fit.
centered_df = centered_df.assign(
    log_sl_c=lambda d: d["log_sl"] - d["log_sl"].mean()
)

_log_formula = "accuracy_p ~ (log_nlp_c + log_sl_c) * C(model) + (log_nlp_c * log_sl_c)"
mixed_model = smf.ols(formula=_log_formula, data=centered_df).fit()
print(mixed_model.summary())

In [None]:
# Mean accuracy per (model, grammar, sample length) cell, together with the
# grammar-size statistics. This rebinds acc_by_grammars_df; all models are kept.
acc_by_grammars_df = (
    accuracy_df.groupby(
        [
            "model",
            "grammar_file",
            "n_nonlexical_productions",
            "sample.length",
            "n_terminals",
            "n_nonterminals",
            "n_lexical_productions",
            "compression_ratio",
        ],
        observed=True,
    )["correct"]
    .mean()
    .reset_index(name="accuracy")
)

# Accuracy in percent, plus log10 of sample length and each grammar-size count.
acc_by_grammars_df["accuracy_p"] = acc_by_grammars_df["accuracy"] * 100
acc_by_grammars_df["log_sl"] = np.log10(acc_by_grammars_df["sample.length"])
acc_by_grammars_df["log_nlp"] = np.log10(acc_by_grammars_df["n_nonlexical_productions"])
acc_by_grammars_df["log_lp"] = np.log10(acc_by_grammars_df["n_lexical_productions"])
acc_by_grammars_df["log_nt"] = np.log10(acc_by_grammars_df["n_nonterminals"])
acc_by_grammars_df["log_t"] = np.log10(acc_by_grammars_df["n_terminals"])


# Per model, correlate accuracy_p with log_sl, log_nlp, log_lp, log_nt and log_t.
# The result is a model x statistic table for log_t, log_lp, log_nt and log_nlp.

corr_mat_acc = (
    acc_by_grammars_df[
        [
            "accuracy_p",
            "log_sl",
            "log_nlp",
            "log_lp",
            "log_nt",
            "log_t",
            "model",
        ]
    ]
    .groupby("model", observed=True)
    .corr()
    # Keep only the accuracy_p column of each per-model correlation matrix.
    .iloc[:, 0:1]
    # NOTE(review): this drops only the FIRST row overall (the first model's
    # self-correlation). Other models' accuracy_p rows survive, but that pivoted
    # column is removed below.
    .iloc[1:]
    .reset_index()
    .pivot(index="model", columns="level_1", values="accuracy_p")
    # Assumes "accuracy_p" sorts first among the pivoted columns.
    .iloc[:, 1:]
)[
    [
        "log_t",
        "log_lp",
        "log_nt",
        "log_nlp",
    ]
]

# NOTE(review): a bare expression mid-cell displays nothing; only a cell's last
# expression is shown.
corr_mat_acc

ax = sns.heatmap(
    corr_mat_acc,
    cmap=CMAP_HEATMAP,
    vmin=-1,
    vmax=1,
)

# Write each coefficient at the centre of its heatmap cell.
for c in range(len(corr_mat_acc.columns)):
    for r in range(len(corr_mat_acc)):
        ax.text(
            c + 0.5,
            r + 0.5,
            f"${corr_mat_acc.iloc[r, c]:.2f}$",
            ha="center",
            va="center",
            color="black",
        )

In [None]:
# log10 of each grammar-size count, added to f1_df.
f1_df["log_nlp"] = np.log10(f1_df["n_nonlexical_productions"])
f1_df["log_lp"] = np.log10(f1_df["n_lexical_productions"])
f1_df["log_nt"] = np.log10(f1_df["n_nonterminals"])
f1_df["log_t"] = np.log10(f1_df["n_terminals"])

# Per model, correlate macro F1 with each log grammar statistic. The result is
# a model x statistic table.
corr_mat_f1 = (
    f1_df[
        [
            "macro_f1_score",
            "log_nlp",
            "log_lp",
            "log_nt",
            "log_t",
            "model",
        ]
    ]
    .groupby("model", observed=True)
    .corr()
    .iloc[:, 0:1]
    .iloc[1:]
    .reset_index()
    .pivot(index="model", columns="level_1", values="macro_f1_score")
    # Assumes "macro_f1_score" sorts last among the pivoted columns.
    .iloc[:, :-1]
)[
    [
        "log_t",
        "log_lp",
        "log_nt",
        "log_nlp",
    ]
]

ax = sns.heatmap(
    corr_mat_f1,
    cmap=CMAP_HEATMAP,
    vmin=-1,
    vmax=1,
)

# Write each coefficient at the centre of its heatmap cell.
for c in range(len(corr_mat_f1.columns)):
    for r in range(len(corr_mat_f1)):
        ax.text(
            c + 0.5,
            r + 0.5,
            f"${corr_mat_f1.iloc[r, c]:.2f}$",
            ha="center",
            va="center",
            color="black",
        )

_ = ax.set_title("Grammar Hyperparameter Correlations with Macro F1", y=1.05)

In [None]:
# Inspect the per-model macro-F1 correlation table before selecting and
# reordering the log_* columns.
(
    f1_df[
        [
            "macro_f1_score",
            "log_nlp",
            "log_lp",
            "log_nt",
            "log_t",
            "model",
        ]
    ]
    .groupby("model", observed=True)
    .corr()
    .iloc[:, 0:1]
    # NOTE(review): drops only the first row overall, not every model's
    # self-correlation row.
    .iloc[1:]
    .reset_index()
    .pivot(index="model", columns="level_1", values="macro_f1_score")
    # Assumes "macro_f1_score" sorts last among the pivoted columns.
    .iloc[:, :-1]
)

In [None]:
# Logistic squashing of the centred sample length, then an OLS fit with
# per-model slopes for grammar size and the squashed length.
centered_df = centered_df.assign(
    sl_c_sig=lambda d: 1 / (1 + np.exp(-d["sl_c"]))
)

_sig_formula = "accuracy_p ~ log_nlp_c * C(model) + sl_c_sig * C(model)"
mixed_model = smf.ols(formula=_sig_formula, data=centered_df).fit()
print(mixed_model.summary())

In [None]:
# Class-balanced accuracy per (model, sample length). First take the mean
# accuracy within each ground-truth class, then average the class means.
# observed=True keeps only category combinations that actually occur.
rebalanced_df = (
    accuracy_df.groupby(
        ["model", "sample.length", "sample.type.ground_truth"], observed=True
    )["correct"]
    .mean()
    .reset_index()
    .groupby(["model", "sample.length"], observed=True)["correct"]
    .mean()
    .reset_index()
)

rebalanced_df["sl"] = rebalanced_df["sample.length"]
rebalanced_df["sl_c"] = rebalanced_df["sl"] - rebalanced_df["sl"].mean()
# Squash the CENTRED length. This matches the centered_df feature; the raw
# length (always >= 1) would map almost entirely close to 1.
rebalanced_df["sl_c_sig"] = 1 / (1 + np.exp(-rebalanced_df["sl_c"]))

sns.relplot(
    data=rebalanced_df, kind="line", x="sl_c_sig", y="correct", col="model", col_wrap=3
)

In [None]:
# Logistic (Binomial) GLM. The Binomial family needs a response in [0, 1], so
# use the accuracy proportion, not accuracy_p (percent, 0-100).
# NOTE(review): without per-row trial counts (e.g. var_weights/freq_weights),
# the standard errors treat each row as a single trial.
glm = smf.glm(
    formula="accuracy ~ log_nlp_c * C(model) + sl_c_sig * C(model)",
    data=centered_df,
    family=sm.families.Binomial(),  # logistic link by default
).fit()

print(glm.summary())

In [None]:
# Accuracy (percent) on the raw non-lexical production count, with a separate
# slope and intercept per model.
_nlp_formula = "accuracy_p ~ n_nonlexical_productions * C(model)"
model = smf.ols(formula=_nlp_formula, data=acc_by_grammars_df).fit()
print(model.summary())

In [None]:
# Peek at the first rows of the centred regression dataset.
centered_df.head()

### Variance Inflation Factor (VIF)

In [None]:
# Variance inflation factor of each column of the design matrix X.
_x_values = X.values
_vif_values = [
    variance_inflation_factor(_x_values, col_idx) for col_idx in range(X.shape[1])
]

vif_df = (
    pd.DataFrame({"variable": X.columns, "VIF": _vif_values})
    .iloc[1:]  # row 0 is the constant added by sm.add_constant
    .sort_values(by="VIF", ascending=True)
    .reset_index(drop=True)
)
# Usual rule of thumb: above 10 is high, above 5 moderate, otherwise low.
vif_df["Rating"] = np.select(
    [vif_df["VIF"] > 10, vif_df["VIF"] > 5],
    ["High", "Moderate"],
    default="Low",
)

vif_df

## Mean Accuracy/Score by Model and Sample Type

In [None]:
# Bar plot of accuracy per model, split by ground-truth sample type, annotated with
# bar heights, per-model grammar counts, and a class-balanced mean (+/- SEM).
g = sns.catplot(
    data=accuracy_df,
    kind="bar",
    x="model",
    y="correct",
    hue="sample.type.ground_truth",
    palette=PALETTE_SAMPLE_TYPE,
    errorbar="se",
    height=3,
    aspect=2.8,
)

# Write each bar's height (2 d.p.) in bold white just below the top of the bar.
for ax in g.axes.flat:
    for container in ax.containers:
        for rect in container:
            height = rect.get_height()
            ax.text(
                rect.get_x() + rect.get_width() / 2,
                height - 0.02,
                f"{height:.2f}",
                ha="center",
                va="top",
                fontsize=9,
                color="white",
                fontweight="bold",
            )

for ax in g.axes.flat:
    for bar in ax.patches:
        bar.set_edgecolor(BAR_EDGE_COLOR)
        bar.set_linewidth(BAR_EDGE_WIDTH)

# Per-model annotations. observed=False keeps every `model` category in category
# order, matching the x slots drawn by the catplot above.
n_counts = accuracy_df.groupby("model", observed=False)["grammar_file"].nunique()
# Class-balanced mean: average within each ground-truth class, then average the
# class means.
mean_accs = (
    accuracy_df.groupby(["model", "sample.type.ground_truth"], observed=False)[
        "correct"
    ]
    .mean()
    .groupby("model", observed=False)
    .mean()
)
# NOTE(review): this SEM is over all pooled rows of a model, while the diamond sits
# at the class-balanced mean above - confirm that pairing is intended.
mean_errors = accuracy_df.groupby("model", observed=False)["correct"].sem()

for ax in g.axes.flat:
    mean_in_legend = False
    for i, category in enumerate(n_counts.index):
        # Best-effort: skip slots where the catplot did not draw both hue bars.
        try:
            pos_height = ax.containers[0][i].get_height()
            neg_height = ax.containers[1][i].get_height()
        except IndexError:
            continue

        mean_acc = mean_accs[category]
        # Offset the mean label away from the taller of the two hue bars
        # (containers[0] is assumed to be the left bar).
        first_taller = pos_height > neg_height

        ax.text(i, -0.15, f"n={n_counts[category]}", ha="center", va="top")
        ax.text(
            i + (0.1 if first_taller else -0.1),
            mean_acc,
            f"{mean_acc:.2f}",
            ha="left" if first_taller else "right",
            va="center",
            fontweight="bold",
            fontsize=9,
        )

        # Black diamond at the mean. Only the first diamond drawn on this axis gets
        # a legend label, so the entry survives even if slot 0 was skipped.
        ax.plot(
            i,
            mean_acc,
            marker="D",
            color="black",
            markersize=5,
            linewidth=0,
            label="_nolegend_" if mean_in_legend else "mean (± sem)",
        )
        mean_in_legend = True

        # Error bar (SEM) around the mean diamond.
        ax.errorbar(
            i,
            mean_acc,
            yerr=mean_errors[category],
            fmt="o",
            color="black",
            markersize=0,
            capsize=5,
            label="_nolegend_",
        )

_ = g.ax.axhline(
    y=0.5, color=COLOR_AT_CHANCE, alpha=ALPHA_AT_CHANCE, linestyle="--", zorder=0
)

legend_format(
    ax=g,
    loc="center left",
    bbox_to_anchor=(0, 0.97),
    columnspacing=0,
)

_ = g.ax.set_ylabel("Mean Accuracy")
_ = g.ax.set_xlabel(None)
_ = g.ax.set_ylim(0, 1)

plt.savefig(
    FIGURES_DIR / "accuracy_by_model.pdf",
    bbox_inches="tight",
)

In [None]:
# Per-grammar accuracy: mean of `correct` for every (grammar, model, sample type).
acc_by_grammars_df = (
    accuracy_df.groupby(
        [
            "grammar_file",
            "model",
            "sample.type.ground_truth",
        ],
        observed=False,
    )["correct"]
    .mean()
    .reset_index(name="accuracy")
)

fig_height = 1.5

with sns.plotting_context("paper", font_scale=0.8):
    # Bars show the mean (+/- SE) of per-grammar accuracies per model and type.
    g = sns.catplot(
        data=acc_by_grammars_df,
        kind="bar",
        x="model",
        y="accuracy",
        hue="sample.type.ground_truth",
        palette=PALETTE_SAMPLE_TYPE,
        errorbar="se",
        height=fig_height,
        aspect=PAPER_WIDTH_IN / fig_height,
    )

    # Write each bar's height (2 d.p.) in bold white just below the top of the bar.
    for ax in g.axes.flat:
        for container in ax.containers:
            for rect in container:
                height = rect.get_height()
                ax.text(
                    rect.get_x() + rect.get_width() / 2,
                    height - 0.02,
                    f"{height:.2f}",
                    ha="center",
                    va="top",
                    fontsize=6,
                    color="white",
                    fontweight="bold",
                )

    for ax in g.axes.flat:
        for bar in ax.patches:
            bar.set_edgecolor(BAR_EDGE_COLOR)
            bar.set_linewidth(BAR_EDGE_WIDTH)

    # Per-model annotations; observed=False keeps every `model` category in
    # category order, matching the catplot's x slots. (No "n=" labels here.)
    # NOTE(review): these come from sample-level `accuracy_df`, while the bars show
    # per-grammar accuracies - confirm that overlay is intended.
    n_counts = accuracy_df.groupby("model", observed=False)["grammar_file"].nunique()
    mean_accs = (
        accuracy_df.groupby(["model", "sample.type.ground_truth"], observed=False)[
            "correct"
        ]
        .mean()
        .groupby("model", observed=False)
        .mean()
    )
    mean_errors = accuracy_df.groupby("model", observed=False)["correct"].sem()

    for ax in g.axes.flat:
        mean_in_legend = False
        for i, category in enumerate(n_counts.index):
            # Best-effort: skip slots where the catplot did not draw both hue bars.
            try:
                pos_height = ax.containers[0][i].get_height()
                neg_height = ax.containers[1][i].get_height()
            except IndexError:
                continue

            mean_acc = mean_accs[category]
            # Offset the mean label away from the taller of the two hue bars
            # (containers[0] is assumed to be the left bar).
            first_taller = pos_height > neg_height

            ax.text(
                i + (0.1 if first_taller else -0.1),
                mean_acc,
                f"{mean_acc:.2f}",
                ha="left" if first_taller else "right",
                va="center",
                fontweight="bold",
                fontsize=7,
            )

            # Black diamond at the mean; only the first one drawn is labelled so
            # the legend keeps its entry even if slot 0 was skipped.
            ax.plot(
                i,
                mean_acc,
                marker="D",
                color="black",
                markersize=5,
                linewidth=0,
                label="_nolegend_" if mean_in_legend else "mean (± sem)",
            )
            mean_in_legend = True

            # Error bar (SEM) around the mean diamond.
            ax.errorbar(
                i,
                mean_acc,
                yerr=mean_errors[category],
                fmt="o",
                color="black",
                markersize=0,
                capsize=5,
                label="_nolegend_",
            )

    _ = g.ax.axhline(
        y=0.5, color=COLOR_AT_CHANCE, alpha=ALPHA_AT_CHANCE, linestyle="--", zorder=0
    )

    legend_format(
        ax=g,
        loc="upper right",
        bbox_to_anchor=(1.15, 1),
        ncol=1,
    )

    _ = g.ax.set_ylabel("Mean Accuracy")
    _ = g.ax.set_xlabel(None)
    _ = g.ax.set_ylim(0, 1)
    _ = g.ax.tick_params(axis="x", rotation=10)

plt.savefig(
    FIGURES_DIR / "accuracy_by_model_per_grammar.pdf",
    bbox_inches="tight",
)

In [None]:
# Per-grammar accuracy: mean of `correct` for every (grammar, model, sample type).
acc_by_grammars_df = (
    accuracy_df.groupby(
        [
            "grammar_file",
            "model",
            "sample.type.ground_truth",
        ],
        observed=False,
    )["correct"]
    .mean()
    .reset_index(name="accuracy")
)

# Remove models with no grammar files from the categorical so the box plot does not
# reserve an empty x slot for them.
empty_models = (
    accuracy_df.groupby("model", observed=False)["grammar_file"].nunique() == 0
)
empty_models = empty_models[empty_models].index.to_list()

acc_by_grammars_df["model"] = acc_by_grammars_df["model"].cat.remove_categories(
    empty_models
)

fig_height = 1

with sns.plotting_context("paper", font_scale=1):
    # Box plot of per-grammar accuracy per model, split by ground-truth sample type.
    g = sns.catplot(
        data=acc_by_grammars_df,
        kind="box",
        x="model",
        y="accuracy",
        hue="sample.type.ground_truth",
        palette=PALETTE_SAMPLE_TYPE,
        showfliers=False,
        height=fig_height,
        aspect=PAPER_WIDTH_IN / fig_height,
    )

    # Class-balanced mean accuracy per model (mean of the per-class means).
    # NOTE(review): computed from sample-level `accuracy_df`, not from the
    # per-grammar values in the boxes - confirm which mean is intended.
    mean_accs = (
        accuracy_df.groupby(["model", "sample.type.ground_truth"], observed=False)[
            "correct"
        ]
        .mean()
        .groupby("model", observed=False)
        .mean()
    )

    # seaborn lays out x slots in the category order of the (pruned) `model`
    # column, so walk that order and look means up by name. Enumerating every
    # category of `accuracy_df` would shift each diamond after a removed empty
    # model one slot to the right.
    for ax in g.axes.flat:
        for i, category in enumerate(acc_by_grammars_df["model"].cat.categories):
            mean_acc = mean_accs[category]
            ax.plot(
                i,
                mean_acc,
                marker="D",
                color="black",
                markersize=6,
                linewidth=0,
                label="mean" if i == 0 else "_nolegend_",
            )

    for ax in g.axes.flat:
        for bar in ax.patches:
            bar.set_edgecolor(BAR_EDGE_COLOR)
            bar.set_linewidth(BAR_EDGE_WIDTH)

    _ = g.ax.axhline(
        y=0.5, color=COLOR_AT_CHANCE, alpha=ALPHA_AT_CHANCE, linestyle="--", zorder=0
    )

    legend_format(
        ax=g,
        loc="upper right",
        bbox_to_anchor=(1, 1),
        ncol=1,
    )

    _ = g.ax.set_ylabel("Mean Accuracy")
    _ = g.ax.set_xlabel(None)
    _ = g.ax.set_ylim(0, 1)

    # Let the mean diamonds and reference line draw outside the axes limits.
    for line in g.ax.lines:
        line.set_clip_on(False)

plt.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0, hspace=0)

plt.savefig(
    FIGURES_DIR / "accuracy_by_model_per_grammar_box.pdf",
    bbox_inches="tight",
    pad_inches=0,
)

In [None]:
# Per-model accuracy in percent (one decimal place): average `correct` within each
# (model, ground-truth type, sample length) cell, then report the mean and standard
# error of those cell means per model. Models with a NaN mean or SEM are dropped.
_cell_means = accuracy_df.groupby(
    ["model", "sample.type.ground_truth", "sample.length"], observed=False
)["correct"].mean()
_cell_means.groupby("model", observed=False).agg(["mean", "sem"]).dropna().round(3).mul(100)

In [None]:
# Per-model macro-F1 in percent (one decimal place): mean and standard error over
# the rows of f1_df. Models with a NaN mean or SEM are dropped.
_macro_f1_by_model = f1_df.groupby("model", observed=False)["macro_f1_score"]
_macro_f1_by_model.agg(["mean", "sem"]).dropna().round(3).mul(100)

In [None]:
# Faceted per-grammar accuracy figure: one subplot per model, each with boxes of
# per-grammar accuracy for each ground-truth sample type and a diamond at the
# model's class-balanced mean accuracy.
acc_by_grammars_df = (
    accuracy_df.groupby(
        [
            "grammar_file",
            "model",
            "sample.type.ground_truth",
        ],
        observed=False,
    )["correct"]
    .mean()
    .reset_index(name="accuracy")
)

# empty models
# Models with no grammar files are removed from the categorical so they get no subplot.
empty_models = (
    accuracy_df.groupby("model", observed=False)["grammar_file"].nunique() == 0
)
empty_models = empty_models[empty_models].index.to_list()

acc_by_grammars_df["model"] = acc_by_grammars_df["model"].cat.remove_categories(
    empty_models
)

# NOTE(review): n_counts is not used anywhere in this cell.
n_counts = accuracy_df.groupby("model", observed=False)["grammar_file"].nunique()
# Class-balanced mean accuracy per model: mean within each ground-truth class, then
# the mean of those class means. Looked up by model name below.
mean_accs = (
    accuracy_df.groupby(["model", "sample.type.ground_truth"], observed=False)[
        "correct"
    ]
    .mean()
    .groupby("model", observed=False)
    .mean()
)

fig_height = 1

fig = plt.figure(figsize=(PAPER_WIDTH_IN, fig_height))
# One column per remaining model, with no horizontal gap between the subplots.
grid = fig.add_gridspec(1, len(acc_by_grammars_df["model"].cat.categories), wspace=0.0)

with sns.plotting_context("paper", font_scale=1):
    for i, model in enumerate(acc_by_grammars_df["model"].cat.categories):
        ax = fig.add_subplot(grid[0, i])
        sns.boxplot(
            data=acc_by_grammars_df[acc_by_grammars_df["model"] == model],
            x="sample.type.ground_truth",
            y="accuracy",
            hue="sample.type.ground_truth",
            palette=PALETTE_SAMPLE_TYPE,
            # errorbar="se",
            showfliers=False,
            ax=ax,
            width=0.9,
        )

        # Set the title for each subplot
        ax.set_title(model)
        ax.set_ylim(0, 1)
        ax.set_xticks([])
        ax.set_xlabel(None)

        # Only the leftmost subplot keeps a left spine, y label and 0/1 ticks.
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.spines["bottom"].set_visible(False)
        if i != 0:
            ax.spines["left"].set_visible(False)
            ax.set_ylabel(None)
            ax.set_yticks([])
        else:
            ax.set_ylabel("Accuracy")
            ax.set_yticks([0, 1])

        # Diamond at the model's mean accuracy, at x=0.5 - presumably centred
        # between the two sample-type boxes at x=0 and x=1.
        mean_acc = mean_accs[model]
        ax.plot(
            0.5,
            mean_acc,
            marker="D",
            color="black",
            markersize=6,
            linewidth=0,
            label="mean" if i == 0 else "_nolegend_",
        )
        # Dashed chance-level reference line at 0.5, drawn behind the boxes.
        ax.axhline(
            y=0.5,
            color=COLOR_AT_CHANCE,
            alpha=ALPHA_AT_CHANCE,
            linestyle="--",
            zorder=0,
        )

    # Uniform edge colour/width on every box patch across all subplots.
    for ax in fig.axes:
        for bar in ax.patches:
            bar.set_edgecolor(BAR_EDGE_COLOR)
            bar.set_linewidth(BAR_EDGE_WIDTH)

    # legend_format(
    #     ax=g,
    #     loc="upper right",
    #     bbox_to_anchor=(1, 1),
    #     ncol=1,
    # )

    # _ = g.ax.set_ylabel("Mean Accuracy")
    # _ = g.ax.set_xlabel(None)
    # _ = g.ax.set_ylim(0, 1)

    # for line in g.ax.lines:
    #     line.set_clip_on(False)

    plt.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0, hspace=0)

    # NOTE(review): this writes the same path as the kind="box" catplot cell above
    # ("accuracy_by_model_per_grammar_box.pdf"), so whichever cell runs last wins -
    # confirm which figure is intended, or give one of them a distinct name.
    plt.savefig(
        FIGURES_DIR / "accuracy_by_model_per_grammar_box.pdf",
        bbox_inches="tight",
        pad_inches=0,
    )

In [None]:
# Long-form F1 table: one row per (n_shots, model, model_type, grammar, average).
f1_melted_df = f1_df.melt(
    id_vars=["n_shots", "model", "model_type", "grammar_file"],
    value_vars=["weighted_f1_score", "macro_f1_score", "micro_f1_score"],
    var_name="average",
    value_name="score",
)

# map score_type to a more readable name
f1_melted_df["average"] = f1_melted_df["average"].map(
    {
        "weighted_f1_score": "Weighted F1",
        "macro_f1_score": "Macro F1",
        "micro_f1_score": "Micro F1",
    }
)

# Show only Weighted and Macro F1; Micro F1 is dropped from this figure.
f1_melted_df = f1_melted_df[f1_melted_df["average"] != "Micro F1"]

g = sns.catplot(
    data=f1_melted_df,
    kind="bar",
    x="model",
    y="score",
    hue="average",
    palette=PALETTE_SCORE,
    height=3,
    aspect=3.5,
)

# Number of grammar files per model (only models that appear in accuracy_df).
n_counts = accuracy_df.groupby("model", observed=True)["grammar_file"].nunique()

_ = g.ax.axhline(
    y=0.5, color=COLOR_AT_CHANCE, alpha=ALPHA_AT_CHANCE, linestyle="--", zorder=0
)

# "n=<#grammars>" under each x slot. seaborn draws one slot per category of the
# categorical `model` column (including unused ones), so iterate in that order and
# look counts up by name; enumerating the observed-only n_counts would shift labels
# whenever a model has no data.
for ax in g.axes.flat:
    for i, category in enumerate(f1_melted_df["model"].cat.categories):
        count = n_counts.get(category, 0)
        ax.text(i, -0.15, f"n={count}", ha="center", va="top")

# Write each bar's height (2 d.p.) in bold white just below the top of the bar.
for ax in g.axes.flat:
    for container in ax.containers:
        for rect in container:
            height = rect.get_height()
            ax.text(
                rect.get_x() + rect.get_width() / 2,
                height - 0.02,
                f"{height:.2f}",
                ha="center",
                va="top",
                fontsize=9,
                color="white",
                fontweight="bold",
            )

for ax in g.axes.flat:
    for bar in ax.patches:
        bar.set_edgecolor(BAR_EDGE_COLOR)
        bar.set_linewidth(BAR_EDGE_WIDTH)

legend_format(
    ax=g,
    loc="center left",
    bbox_to_anchor=(0, 0.9),
    ncol=1,
)

_ = g.ax.set_ylabel("F1 Score")
_ = g.ax.set_xlabel(None)
_ = g.ax.set_ylim(0, 1)

plt.savefig(
    FIGURES_DIR / "f1_by_model.pdf",
    bbox_inches="tight",
)

In [None]:
# Long-form F1 table: one row per (n_shots, model, model_type, grammar, average).
f1_melted_df = f1_df.melt(
    id_vars=["n_shots", "model", "model_type", "grammar_file"],
    value_vars=["weighted_f1_score", "macro_f1_score", "micro_f1_score"],
    var_name="average",
    value_name="score",
)

# map score_type to a more readable name
f1_melted_df["average"] = f1_melted_df["average"].map(
    {
        "weighted_f1_score": "Weighted F1",
        "macro_f1_score": "Macro F1",
        "micro_f1_score": "Micro F1",
    }
)

# Remove models with no grammar files from the categorical so the box plot does not
# reserve an empty x slot for them.
empty_models = (
    f1_melted_df.groupby("model", observed=False)["grammar_file"].nunique() == 0
)
empty_models = empty_models[empty_models].index.to_list()

f1_melted_df["model"] = f1_melted_df["model"].cat.remove_categories(empty_models)

# Keep only the Macro F1 rows (Weighted and Micro F1 are not shown in this figure).
f1_melted_df = f1_melted_df[f1_melted_df["average"] == "Macro F1"]

fig_height = 1

with sns.plotting_context("paper", font_scale=1):
    g = sns.catplot(
        data=f1_melted_df,
        kind="box",
        x="model",
        y="score",
        hue="average",
        palette=PALETTE_SCORE,
        height=fig_height,
        aspect=PAPER_WIDTH_IN / fig_height,
    )

    _ = g.ax.axhline(
        y=0.5, color=COLOR_AT_CHANCE, alpha=ALPHA_AT_CHANCE, linestyle="--", zorder=0
    )

    legend_format(
        ax=g,
        loc="center left",
        bbox_to_anchor=(0, 0.9),
        ncol=1,
    )

    for ax in g.axes.flat:
        for bar in ax.patches:
            bar.set_edgecolor(BAR_EDGE_COLOR)
            bar.set_linewidth(BAR_EDGE_WIDTH)

    _ = g.ax.set_ylabel("F1 Score")
    _ = g.ax.set_xlabel(None)
    _ = g.ax.set_ylim(0, 1)

plt.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0, hspace=0)

plt.savefig(
    FIGURES_DIR / "f1_by_model_box.pdf",
    bbox_inches="tight",
)

## Macro F1 Score by Complexity

In [None]:
# Scatter of macro-F1 against gzip compression ratio, restricted to rows with
# coverage > 0.9; every model except gpt-4.1-nano is faded.
fig, ax = plt.subplots(figsize=(5, 3.5))

# Dashed chance-level reference line at 0.5.
_ = ax.axhline(0.5, color=COLOR_AT_CHANCE, alpha=ALPHA_AT_CHANCE, linestyle="--")

_model_order = f1_df["model"].unique()
sns.scatterplot(
    data=f1_df.loc[f1_df["coverage"] > 0.9],
    x="compression_ratio",
    y="macro_f1_score",
    hue="model",
    style="model",
    hue_order=_model_order,
    style_order=_model_order,
    palette=PALETTE_MODEL,
    ax=ax,
)

legend_format(keys=_model_order, ax=ax)
filter_by_alpha(keys=_model_order, highlight=["gpt-4.1-nano"], alpha=0.1, ax=ax)

ax.set(
    ylim=(-0.02, 1.02),
    xlabel="gzip Compression Ratio",
    ylabel="Macro F1 Score",
)

plt.savefig(FIGURES_DIR / "compression_ratio_vs_macro_f1.pdf", bbox_inches="tight")

In [None]:
# Scatter of macro-F1 against the number of terminals (log-scaled x axis).
fig, ax = plt.subplots(figsize=(5, 3.5))

# Dashed chance-level reference line at 0.5.
_ = ax.axhline(0.5, color=COLOR_AT_CHANCE, alpha=ALPHA_AT_CHANCE, linestyle="--")

sns.scatterplot(
    data=f1_df,
    x="n_terminals",
    y="macro_f1_score",
    hue="model",
    style="model",
    palette=PALETTE_MODEL,
    ax=ax,
)

legend_format(keys=f1_df["model"].unique(), ax=ax)

ax.set(
    xscale="log",
    ylim=(-0.02, 1.02),
    xlabel="# of Terminals",
    ylabel="Macro F1 Score",
)

plt.savefig(FIGURES_DIR / "n_terminals_vs_macro_f1.pdf", bbox_inches="tight")

In [None]:
# Scatter of macro-F1 against mean parse depth of the positive samples.
fig, ax = plt.subplots(figsize=(5, 3.5))

# Dashed chance-level reference line at 0.5.
_ = ax.axhline(0.5, color=COLOR_AT_CHANCE, alpha=ALPHA_AT_CHANCE, linestyle="--")

sns.scatterplot(
    data=f1_df,
    x="mean_positive_depth",
    y="macro_f1_score",
    hue="model",
    style="model",
    palette=PALETTES["model"],
    ax=ax,
)

legend_format(keys=f1_df["model"].unique(), ax=ax)

ax.set(ylim=(-0.02, 1.02), xlabel="Mean Parse Depth", ylabel="Macro F1 Score")

plt.savefig(FIGURES_DIR / "mean_parse_depth_vs_macro_f1.pdf", bbox_inches="tight")

In [None]:
# Scatter of macro-F1 against coverage; all models except o3 and o4-mini are faded.
fig, ax = plt.subplots(figsize=(5, 3.5))

# Dashed chance-level reference line at 0.5.
_ = ax.axhline(0.5, color=COLOR_AT_CHANCE, alpha=ALPHA_AT_CHANCE, linestyle="--")

sns.scatterplot(
    data=f1_df,
    x="coverage",
    y="macro_f1_score",
    hue="model",
    style="model",
    palette=PALETTES["model"],
    ax=ax,
)

filter_by_alpha(
    keys=f1_df["model"].unique(), highlight=["o3", "o4-mini"], alpha=0.2, ax=ax
)
legend_format(keys=f1_df["model"].unique(), ax=ax)

ax.set(ylim=(-0.02, 1.02), xlabel="Coverage", ylabel="Macro F1 Score")

plt.savefig(FIGURES_DIR / "coverage_vs_macro_f1.pdf", bbox_inches="tight")

In [None]:
# Scatter of macro-F1 against the number of non-lexical productions (log x axis);
# all models except o3 and o4-mini are faded.
fig, ax = plt.subplots(figsize=(5, 3.5))

# Dashed chance-level reference line at 0.5.
_ = ax.axhline(0.5, color=COLOR_AT_CHANCE, alpha=ALPHA_AT_CHANCE, linestyle="--")

sns.scatterplot(
    data=f1_df,
    x="n_nonlexical_productions",
    y="macro_f1_score",
    hue="model",
    style="model",
    palette=PALETTES["model"],
    ax=ax,
)

filter_by_alpha(
    keys=f1_df["model"].unique(), highlight=["o3", "o4-mini"], alpha=0.2, ax=ax
)
legend_format(keys=f1_df["model"].unique(), ax=ax)

ax.set(
    xscale="log",
    ylim=(-0.02, 1.02),
    xlabel="# of Nonlexical Productions  [log scale]",
    ylabel="Macro F1 Score",
)

plt.savefig(
    FIGURES_DIR / "n_nonlexical_productions_vs_macro_f1.pdf", bbox_inches="tight"
)

In [None]:
# Remove models with no grammar files so they don't get an (empty) facet.
empty_models = f1_df.groupby("model", observed=False)["grammar_file"].nunique() == 0
empty_models = empty_models[empty_models].index.to_list()

f1_small_df = f1_df.copy()

f1_small_df["model"] = f1_small_df["model"].cat.remove_categories(empty_models)

# Per-model OLS: macro_f1_score ~ const + log(n_nonlexical_productions).
# regs[model] -> {"const": ..., "n_nonlexical_productions": ..., "R^2": ...}
regs = {}
for model in f1_small_df["model"].cat.categories:
    model_df = f1_small_df[f1_small_df["model"] == model]
    X = np.log(model_df["n_nonlexical_productions"])
    X = sm.add_constant(X)  # add a constant term to the model
    Y = model_df["macro_f1_score"]
    lin_model = sm.OLS(Y, X).fit()
    regs[model] = lin_model.params.to_dict() | {"R^2": lin_model.rsquared}

fig_height = 1

# Shared x grid (the full data range) on which each model's fitted curve is drawn.
x_min = f1_small_df["n_nonlexical_productions"].min()
x_max = f1_small_df["n_nonlexical_productions"].max()
x = np.linspace(x_min, x_max, 10)

with sns.plotting_context("paper", font_scale=1):
    # One facet per model (facet order follows the `model` categories).
    g = sns.relplot(
        data=f1_small_df,
        kind="scatter",
        x="n_nonlexical_productions",
        y="macro_f1_score",
        style="model",
        hue="model",
        col="model",
        palette=PALETTES["model"],
        s=8,
        height=fig_height,
        aspect=PAPER_WIDTH_IN / fig_height / 6,
        legend=False,
        alpha=0.5,
    )

    g.set_titles("")

    for ax in g.axes.flat:
        ax.set_xscale("log")
        ax.set_ylim(0, 1)
        ax.set_xlabel(None)
        ax.set_ylabel("Macro F1 Score")

    for i, ax in enumerate(g.axes.flat):
        # Dashed chance-level reference line at 0.5, behind the points.
        _ = ax.axhline(
            y=0.5,
            color=COLOR_AT_CHANCE,
            alpha=ALPHA_AT_CHANCE,
            linestyle="--",
            zorder=0,
        )

        model = f1_small_df["model"].cat.categories[i]
        model_color = darken(PALETTES["model"][model], 0.3)

        # In-panel model name instead of a facet title.
        # NOTE(review): position is hard-coded (x=3.5 in data units; facets 0 and 5
        # put the label near the top) - revisit if the set of models changes.
        ax.text(
            3.5,
            0.1 if i not in [0, 5] else 0.85,
            str(model),
            ha="left",
            va="bottom",
            fontsize=7,
            fontweight="bold",
            color=model_color,
        )

        # Fitted log-linear regression curve for this model.
        y_pred = regs[model]["const"] + regs[model][
            "n_nonlexical_productions"
        ] * np.log(x)
        ax.plot(
            x,
            y_pred,
            color=model_color,
            linewidth=2,
            label="Regression Line",
        )

    # Single, left-aligned x label under the first facet.
    g.axes.flat[0].set_xlabel(
        "# of Nonlexical Productions  [log scale]",
        ha="left",
    )
    g.axes.flat[0].xaxis.set_label_coords(0, -0.32)

    for o in g.figure.findobj():
        o.set_clip_on(False)

plt.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0.2, hspace=0)

plt.savefig(
    FIGURES_DIR / "n_nonlexical_productions_vs_macro_f1_faceted.pdf",
    bbox_inches="tight",
)

In [None]:
# Display `regs`: per-model OLS parameters, as built by whichever regression cell ran last.
regs

In [None]:
# Remove models with no grammar files so they don't get an (empty) panel.
empty_models = (
    accuracy_df.groupby("model", observed=False)["grammar_file"].nunique() == 0
)
empty_models = empty_models[empty_models].index.to_list()

acc_small_df = accuracy_df.copy()

acc_small_df["model"] = acc_small_df["model"].cat.remove_categories(empty_models)

acc_small_df = acc_small_df[
    ["model", "grammar_file", "n_nonlexical_productions", "correct"]
]
acc_small_df["correct"] = acc_small_df["correct"].astype(float)

# Grammar-level accuracy: mean of `correct` per (grammar, model).
acc_small_df = (
    acc_small_df.groupby(
        ["grammar_file", "model", "n_nonlexical_productions"], observed=True
    )["correct"]
    .mean()
    .reset_index(name="accuracy")
)

# Per-model OLS: accuracy ~ const + log(n_nonlexical_productions).
# regs[model] -> {"const": ..., "n_nonlexical_productions": ..., "R^2": ...}
regs = {}
for model in acc_small_df["model"].cat.categories:
    model_df = acc_small_df[acc_small_df["model"] == model]
    X = np.log(model_df["n_nonlexical_productions"])
    X = sm.add_constant(X)  # add a constant term to the model
    Y = model_df["accuracy"]
    lin_model = sm.OLS(Y, X).fit()
    regs[model] = lin_model.params.to_dict() | {"R^2": lin_model.rsquared}

# Shared x grid for the regression curves: this cell's own data range. It used to
# come from f1_small_df, a frame built in a different cell.
x_min = acc_small_df["n_nonlexical_productions"].min()
x_max = acc_small_df["n_nonlexical_productions"].max()
x = np.linspace(x_min, x_max, 10)

fig_height = 0.8
fig = plt.figure(figsize=(PAPER_WIDTH_IN, fig_height))
grid = fig.add_gridspec(1, len(acc_small_df["model"].cat.categories), wspace=0.1)

with sns.plotting_context("paper", font_scale=1):
    for i, model in enumerate(acc_small_df["model"].cat.categories):
        ax = fig.add_subplot(grid[0, i])

        # Dashed chance-level reference line at 0.5.
        _ = ax.axhline(
            y=0.5,
            color=COLOR_AT_CHANCE,
            alpha=ALPHA_AT_CHANCE,
            linestyle="--",
        )

        sns.scatterplot(
            data=acc_small_df[acc_small_df["model"] == model],
            x="n_nonlexical_productions",
            y="accuracy",
            style="model",
            hue="model",
            palette=PALETTES["model"],
            ax=ax,
            s=8,
            legend=None,
        )

        # Fitted log-linear regression curve for this model.
        y_pred = regs[model]["const"] + regs[model][
            "n_nonlexical_productions"
        ] * np.log(x)
        ax.plot(
            x,
            y_pred,
            color=darken(PALETTES["model"][model], 0.3),
            linewidth=2,
            label="Regression Line",
        )
        ax.set_ylim(0, 1)

        ax.set_xlabel(None)
        ax.set_title(model, fontsize=7)
        ax.set_xscale("log")

        # Only the leftmost panel carries y ticks/label and the (left-aligned) x label.
        if i == 0:
            ax.set_ylabel("Accuracy")
            ax.set_xlabel("# of Nonlexical Productions  [log scale]", ha="left", x=0.0)
            ax.set_yticks([0, 1])
            ax.set_yticklabels([0, 1])
        else:
            ax.tick_params(axis="y", which="both", left=False)
            ax.tick_params(axis="y", which="both", labelleft=False)
            ax.set_ylabel(None)
            ax.spines["left"].set_edgecolor("grey")

        # Turn off top and right spines
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)

        for line in ax.lines:
            line.set_clip_on(False)

    for o in fig.findobj():
        o.set_clip_on(False)

    plt.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0.2, hspace=0)

    plt.savefig(
        FIGURES_DIR / "n_nonlexical_productions_vs_accuracy_faceted.pdf",
        bbox_inches="tight",
    )

In [None]:
# Flatten `regs` ({model: {param: value, ...}}) into one row per model, with the
# model name first followed by that model's fitted values.
regs_flattened = [{"model": name, **params} for name, params in regs.items()]
regs_df = pd.DataFrame(regs_flattened)
regs_df

In [None]:
# Remove models with no grammar files so they don't get an (empty) facet.
empty_models = f1_df.groupby("model", observed=False)["grammar_file"].nunique() == 0
empty_models = empty_models[empty_models].index.to_list()

f1_small_df = f1_df.copy()

f1_small_df["model"] = f1_small_df["model"].cat.remove_categories(empty_models)

# Per-model OLS: macro_f1_score ~ const + compression_ratio (linear, no log).
# regs[model] -> {"const": ..., "compression_ratio": ..., "R^2": ...}
regs = {}
for model in f1_small_df["model"].cat.categories:
    model_df = f1_small_df[f1_small_df["model"] == model]
    X = model_df["compression_ratio"].copy()
    X = sm.add_constant(X)  # add a constant term to the model
    Y = model_df["macro_f1_score"]
    lin_model = sm.OLS(Y, X).fit()
    regs[model] = lin_model.params.to_dict() | {"R^2": lin_model.rsquared}

fig_height = 1

# Shared x grid (the full data range) on which each model's fitted line is drawn.
x_min = f1_small_df["compression_ratio"].min()
x_max = f1_small_df["compression_ratio"].max()
x = np.linspace(x_min, x_max, 10)

with sns.plotting_context("paper", font_scale=1):
    # One facet per model (facet order follows the `model` categories).
    g = sns.relplot(
        data=f1_small_df,
        kind="scatter",
        x="compression_ratio",
        y="macro_f1_score",
        style="model",
        hue="model",
        col="model",
        palette=PALETTES["model"],
        s=8,
        height=fig_height,
        aspect=PAPER_WIDTH_IN / fig_height / 6,
        legend=False,
        alpha=0.5,
    )

    g.set_titles("")

    for ax in g.axes.flat:
        ax.set_ylim(0, 1)
        ax.set_xlabel(None)
        ax.set_ylabel("Macro F1 Score")

    for i, ax in enumerate(g.axes.flat):
        # Dashed chance-level reference line at 0.5, behind the points.
        _ = ax.axhline(
            y=0.5,
            color=COLOR_AT_CHANCE,
            alpha=ALPHA_AT_CHANCE,
            linestyle="--",
            zorder=0,
        )

        model = f1_small_df["model"].cat.categories[i]
        model_color = darken(PALETTES["model"][model], 0.3)

        # In-panel model name instead of a facet title.
        # NOTE(review): x=3.5 and the facet-0/5 special case match the log-scaled
        # n_nonlexical_productions figure; confirm x=3.5 falls inside the
        # compression_ratio range, otherwise the label lands outside the panel.
        ax.text(
            3.5,
            0.1 if i not in [0, 5] else 0.85,
            str(model),
            ha="left",
            va="bottom",
            fontsize=7,
            fontweight="bold",
            color=model_color,
        )

        # Fitted linear regression line for this model.
        y_pred = regs[model]["const"] + regs[model]["compression_ratio"] * x
        ax.plot(
            x,
            y_pred,
            color=model_color,
            linewidth=2,
            label="Regression Line",
        )

    # Single, left-aligned x label under the first facet.
    g.axes.flat[0].set_xlabel(
        "Compression Ratio",
        ha="left",
    )
    g.axes.flat[0].xaxis.set_label_coords(0, -0.32)

    for o in g.figure.findobj():
        o.set_clip_on(False)

plt.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0.2, hspace=0)

plt.savefig(
    FIGURES_DIR / "compression_ratio_vs_macro_f1_faceted.pdf",
    bbox_inches="tight",
)

In [None]:
# One panel per model: accuracy vs. sample length, with one line per ground-truth
# sample type plus a grey pooled ("all") line. Text labels go on the first panel only.

# Remove models with no grammar files so they don't get an (empty) panel.
empty_models = (
    accuracy_df.groupby("model", observed=False)["grammar_file"].nunique() == 0
)
empty_models = empty_models[empty_models].index.to_list()

accuracy_small_df = accuracy_df.copy()

accuracy_small_df["model"] = accuracy_small_df["model"].cat.remove_categories(
    empty_models
)

fig_height = 0.8

fig = plt.figure(figsize=(PAPER_WIDTH_IN, fig_height))
grid = fig.add_gridspec(1, len(accuracy_small_df["model"].cat.categories), wspace=0.1)

with sns.plotting_context("paper", font_scale=1):
    for i, model in enumerate(accuracy_small_df["model"].cat.categories):
        ax = fig.add_subplot(grid[0, i])

        # Dashed chance-level reference line at 0.5. It is created first, so it is
        # ax.lines[0] (the label code below relies on this ordering).
        _ = ax.axhline(
            y=0.5,
            color=COLOR_AT_CHANCE,
            alpha=ALPHA_AT_CHANCE,
            linestyle="--",
        )

        # One line per ground-truth sample type. Presumably ax.lines[1] and [2] are
        # in hue order ("positive", "negative") - TODO confirm.
        # NOTE(review): no errorbar is passed here, so seaborn's default band is used,
        # while the pooled line below uses "se" - confirm the mismatch is intended.
        sns.lineplot(
            data=accuracy_small_df[accuracy_small_df["model"] == model],
            x="sample.length",
            y="correct",
            hue="sample.type.ground_truth",
            palette=PALETTES["sample_type"],
            ax=ax,
            legend=None,
        )

        # Pooled accuracy across sample types (grey; expected to be ax.lines[3]).
        sns.lineplot(
            data=accuracy_small_df[accuracy_small_df.model == model],
            x="sample.length",
            y="correct",
            color="grey",
            errorbar="se",
            legend=None,
            ax=ax,
        )

        ax.set_ylim(0, 1)

        ax.set_xlabel(None)
        ax.set_title(model, fontsize=7)

        # Only the leftmost panel carries y ticks/label and the (left-aligned) x label.
        if i == 0:
            ax.set_ylabel("Accuracy")
            ax.set_xlabel("Sample Length", ha="left", x=0.0)
            ax.set_yticks([0, 1])
            ax.set_yticklabels([0, 1])
        else:
            ax.tick_params(axis="y", which="both", left=False)
            ax.tick_params(axis="y", which="both", labelleft=False)
            ax.set_ylabel(None)
            ax.spines["left"].set_edgecolor("grey")

        # Turn off top and right spines
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)

        for line in ax.lines:
            line.set_clip_on(False)

    # Direct line labels ("positive", "negative", "all") at the right end of each
    # line, on the first panel only. This indexes ax.lines by position, so it depends
    # on the artist creation order above.
    for ax in [fig.get_axes()[0]]:
        for i, label in enumerate(["positive", "negative"]):
            x_coord = ax.lines[i + 1].get_xdata()[-1]
            # Positive label sits just above its line end, negative label below.
            y_coord = ax.lines[i + 1].get_ydata()[-1] + (
                0.02 if label == "positive" else -0.15
            )
            ax.text(
                x_coord,
                y_coord,
                label,
                ha="right",
                va="bottom",
                fontweight="bold",
                fontsize=7,
                color=darken(PALETTES["sample_type"][label], by=0.2),
            )

        mean_x_coord = ax.lines[3].get_xdata()[-1]
        mean_y_coord = ax.lines[3].get_ydata()[-1] + 0.08
        ax.text(
            mean_x_coord,
            mean_y_coord,
            "all",
            ha="right",
            va="bottom",
            fontweight="bold",
            fontsize=7,
            color=darken("grey", by=0.2),
        )

    plt.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0.2, hspace=0)

    plt.savefig(
        FIGURES_DIR / "accuracy_by_sample_length.pdf",
        bbox_inches="tight",
    )

### Accuracy ~ Complexity

In [None]:
fig_height = 2
fig = plt.figure(figsize=(PAPER_WIDTH_IN, fig_height))

# Drop models with no evaluated grammar files so they get no subplot column.
empty_models = (
    accuracy_df.groupby("model", observed=False)["grammar_file"].nunique() == 0
)
empty_models = empty_models[empty_models].index.to_list()

accuracy_small_df = accuracy_df.copy()

accuracy_small_df["model"] = accuracy_small_df["model"].cat.remove_categories(
    empty_models
)

# Mean accuracy per grammar file, model and sample length (bottom row).
acc_by_sl_df = (
    accuracy_small_df.groupby(
        ["grammar_file", "model", "sample.length"],
        observed=True,
    )["correct"]
    .mean()
    .reset_index(name="accuracy")
)

# Mean accuracy per grammar file and model, keyed by grammar size (top row).
acc_by_nlp_df = (
    accuracy_small_df.groupby(
        ["grammar_file", "model", "n_nonlexical_productions"],
        observed=True,
    )["correct"]
    .mean()
    .reset_index(name="accuracy")
)

models = accuracy_small_df["model"].cat.categories
rows = [0, 1]

# Per-model OLS fit of accuracy ~ log(n_nonlexical_productions), computed on
# exactly the per-grammar points scattered in the top row. (This used to read
# `acc_small_df`, which this cell never builds, so the line depended on
# whatever an earlier cell had left in that variable.) Models with fewer than
# two distinct grammar sizes are skipped: add_constant would not add a
# "const" column for a constant regressor.
regs = {}
for model in models:
    model_df = acc_by_nlp_df[acc_by_nlp_df["model"] == model]
    if model_df["n_nonlexical_productions"].nunique() < 2:
        continue
    X = sm.add_constant(np.log(model_df["n_nonlexical_productions"]))
    lin_model = sm.OLS(model_df["accuracy"], X).fit()
    regs[model] = lin_model.params.to_dict() | {"R^2": lin_model.rsquared}

# Evaluate the regression lines over the plotted range of grammar sizes.
x_line = np.linspace(
    acc_by_nlp_df["n_nonlexical_productions"].min(),
    acc_by_nlp_df["n_nonlexical_productions"].max(),
    10,
)

grid = fig.add_gridspec(2, len(models), wspace=0.1, hspace=0.8)

with sns.plotting_context("paper", font_scale=1, rc=rcs):
    # Plot accuracy ~ n_nonlexical_productions in first row,
    # and accuracy ~ sample.length in second row

    for r in rows:
        for c, model in enumerate(models):
            ax = fig.add_subplot(grid[r, c])
            _ = ax.axhline(
                y=0.5,
                color=COLOR_AT_CHANCE,
                alpha=ALPHA_AT_CHANCE,
                linestyle="--",
            )

            if r == 0:
                # Plot accuracy ~ n_nonlexical_productions
                sns.scatterplot(
                    data=acc_by_nlp_df[acc_by_nlp_df["model"] == model],
                    x="n_nonlexical_productions",
                    y="accuracy",
                    style="model",
                    hue="model",
                    palette=PALETTES["model"],
                    ax=ax,
                    s=8,
                    legend=None,
                )

                # Regression line (absent for models that could not be fit).
                if model in regs:
                    y_pred = regs[model]["const"] + regs[model][
                        "n_nonlexical_productions"
                    ] * np.log(x_line)

                    ax.plot(
                        x_line,
                        y_pred,
                        color=darken(PALETTES["model"][model], 0.3),
                        linewidth=2,
                        label="Regression Line",
                    )

                ax.set_xscale("log")
                ax.set_title(model, fontsize=7)
                ax.set_ylim(0, 1)
                ax.set_xlim(1, 500)
                ax.set_xticks([1, 100])
                ax.set_xticklabels([1, 100])
                ax.spines["top"].set_visible(False)
                ax.spines["right"].set_visible(False)
                if c == 0:
                    ax.set_ylabel("Accuracy")
                    ax.set_xlabel(
                        "# of Nonlexical Productions  [log scale]", ha="left", x=0.0
                    )
                    ax.set_yticks([0, 1])
                else:
                    ax.set_ylabel(None)
                    ax.set_xlabel(None)
                    ax.set_yticks([])
                    ax.spines["left"].set_edgecolor("lightgrey")
                ax.get_xticklabels()[0].set_ha("left")
            else:
                # Plot accuracy ~ sample.length (mean ± SE across grammars)
                sns.lineplot(
                    data=acc_by_sl_df[acc_by_sl_df.model == model],
                    x="sample.length",
                    y="accuracy",
                    color=PALETTE_MODEL[model],
                    errorbar="se",
                    legend=None,
                    ax=ax,
                )

                ax.set_ylim(0, 1)
                ax.set_xlim(1, 50)
                ax.set_xticks([1, 50])
                ax.spines["top"].set_visible(False)
                ax.spines["right"].set_visible(False)

                if c == 0:
                    ax.set_ylabel("Accuracy")
                    ax.set_xlabel("Sample Length", ha="left", x=0.0)
                    ax.set_yticks([0, 1])
                else:
                    ax.set_ylabel(None)
                    ax.set_xlabel(None)
                    ax.set_yticks([])
                    ax.spines["left"].set_edgecolor("lightgrey")

                ax.get_xticklabels()[0].set_ha("left")
                ax.get_xticklabels()[-1].set_ha("right")

    # Let every artist draw outside its axes (lines touching the limits).
    for o in fig.findobj():
        o.set_clip_on(False)

    plt.subplots_adjust(left=0, bottom=0, right=1, top=1)

    plt.savefig(
        FIGURES_DIR / "accuracy_by_complexity.pdf",
        bbox_inches="tight",
    )

In [None]:
fig_height = 2
fig = plt.figure(figsize=(PAPER_WIDTH_IN, fig_height))

# Drop models with no evaluated grammar files so they get no subplot column.
empty_models = (
    accuracy_df.groupby("model", observed=False)["grammar_file"].nunique() == 0
)
empty_models = empty_models[empty_models].index.to_list()

accuracy_small_df = accuracy_df.copy()

accuracy_small_df["model"] = accuracy_small_df["model"].cat.remove_categories(
    empty_models
)

# Type-balanced accuracy per model and sample length: average within each
# ground-truth sample type first, then across types.
acc_by_sl_df = (
    accuracy_small_df.groupby(
        ["model", "sample.length", "sample.type.ground_truth"],
        observed=True,
    )["correct"]
    .mean()
    .reset_index()
    .groupby(["model", "sample.length"], observed=True)["correct"]
    .mean()
    .reset_index(name="accuracy")
)

# Type-balanced accuracy per grammar file, keyed by grammar size.
acc_by_nlp_df = (
    accuracy_small_df.groupby(
        [
            "grammar_file",
            "model",
            "n_nonlexical_productions",
            "sample.type.ground_truth",
        ],
        observed=True,
    )["correct"]
    .mean()
    .reset_index()
    .groupby(["model", "n_nonlexical_productions", "grammar_file"], observed=True)[
        "correct"
    ]
    .mean()
    .reset_index(name="accuracy")
)

models = accuracy_small_df["model"].cat.categories
rows = [0, 1]

# Per-model OLS fit of accuracy ~ log(n_nonlexical_productions) on the same
# type-balanced per-grammar points scattered in the top row. (This used to be
# fit on `acc_small_df`, a table not built in this cell and not type-balanced,
# so the drawn line need not match the plotted points.) Models with fewer than
# two distinct grammar sizes are skipped: add_constant would not add a
# "const" column for a constant regressor.
regs = {}
for model in models:
    model_df = acc_by_nlp_df[acc_by_nlp_df["model"] == model]
    if model_df["n_nonlexical_productions"].nunique() < 2:
        continue
    X = sm.add_constant(np.log(model_df["n_nonlexical_productions"]))
    lin_model = sm.OLS(model_df["accuracy"], X).fit()
    regs[model] = lin_model.params.to_dict() | {"R^2": lin_model.rsquared}

# Evaluate the regression lines over the plotted range of grammar sizes.
x_line = np.linspace(
    acc_by_nlp_df["n_nonlexical_productions"].min(),
    acc_by_nlp_df["n_nonlexical_productions"].max(),
    10,
)

grid = fig.add_gridspec(2, len(models), wspace=0.1, hspace=0.8)

# NOTE(review): MODEL_COLOR is used below but its local definition is
# commented out; it must be defined in an earlier cell — confirm.
# MODEL_COLOR = "#28b65f"


with sns.plotting_context("paper", font_scale=1, rc=rcs):
    # Plot accuracy ~ n_nonlexical_productions in first row,
    # and accuracy ~ sample.length in second row

    for r in rows:
        for c, model in enumerate(models):
            ax = fig.add_subplot(grid[r, c])
            ax.axhline(
                y=0.5,
                color=COLOR_AT_CHANCE,
                alpha=ALPHA_AT_CHANCE,
                linestyle="--",
            )

            ax.yaxis.set_major_formatter(mpl.ticker.PercentFormatter(1))

            if r == 0:
                # Plot accuracy ~ n_nonlexical_productions
                sns.scatterplot(
                    data=acc_by_nlp_df[acc_by_nlp_df["model"] == model],
                    x="n_nonlexical_productions",
                    y="accuracy",
                    color=MODEL_COLOR,
                    ax=ax,
                    s=8,
                    legend=None,
                )

                # Regression line (absent for models that could not be fit).
                if model in regs:
                    y_pred = regs[model]["const"] + regs[model][
                        "n_nonlexical_productions"
                    ] * np.log(x_line)

                    ax.plot(
                        x_line,
                        y_pred,
                        color=darken(MODEL_COLOR, 0.3),
                        linewidth=2,
                        label="Regression Line",
                    )

                ax.set_title(model, fontsize=7)
                ax.set_ylim(0, 1)
                ax.set_xlim(1, 500)

                # Log x-axis: show "1" at the left end and draw "500" by hand
                # at the right limit, aligned with the tick-label row.
                ax.set_xscale("log")
                ax.set_xticks([1, 10, 100])
                ax.set_xticklabels(["1", "", ""])

                first_tick = ax.xaxis.get_majorticklabels()[0]
                first_tick.set_ha("left")

                ax.text(
                    500,
                    first_tick.get_position()[1],
                    "500",
                    transform=first_tick.get_transform(),
                    ha="right",
                    va="top",
                    fontsize=7,
                )

                ax.spines["top"].set_visible(False)
                ax.spines["right"].set_visible(False)
                if c == 0:
                    ax.set_ylabel("Accuracy")
                    ax.set_xlabel(
                        "Instruction Set Complexity (# of Nonlexical Productions)  [log scale]",
                        ha="left",
                        x=0.0,
                    )
                    ax.set_yticks([0, 1])
                else:
                    ax.set_ylabel(None)
                    ax.set_xlabel(None)
                    ax.set_yticks([])
                    ax.spines["left"].set_edgecolor("grey")
            else:
                # Plot accuracy ~ sample.length
                sns.lineplot(
                    data=acc_by_sl_df[acc_by_sl_df.model == model],
                    x="sample.length",
                    y="accuracy",
                    color=MODEL_COLOR,
                    errorbar="se",
                    legend=None,
                    ax=ax,
                )

                ax.set_ylim(0, 1)
                ax.set_xlim(1, 50)
                ax.set_xticks([1, 50])
                ax.spines["top"].set_visible(False)
                ax.spines["right"].set_visible(False)

                if c == 0:
                    ax.set_ylabel("Accuracy")
                    ax.set_xlabel("Task Complexity (Example Length)", ha="left", x=0.0)
                    ax.set_yticks([0, 1])
                else:
                    ax.set_ylabel(None)
                    ax.set_xlabel(None)
                    ax.set_yticks([])
                    ax.spines["left"].set_edgecolor("grey")

                ax.get_xticklabels()[0].set_ha("left")
                ax.get_xticklabels()[-1].set_ha("right")

    # Let every artist draw outside its axes (lines touching the limits).
    for o in fig.findobj():
        o.set_clip_on(False)

    plt.subplots_adjust(left=0, bottom=0, right=1, top=1)

    plt.savefig(
        FIGURES_DIR / "accuracy_by_complexity_relabeled.pdf",
        bbox_inches="tight",
    )

In [None]:
# Mean accuracy for every (model, ground-truth sample type, sample length)
# combination; shown as this cell's output.
accuracy_small_df.groupby(
    ["model", "sample.type.ground_truth", "sample.length"], observed=True
).agg(correct=("correct", "mean")).reset_index()

In [None]:
# Monochrome variant of the accuracy ~ complexity figure: every model is drawn
# in PALETTE_MODEL["gpt-4.1"]. Top row: per-grammar accuracy vs. grammar size
# with a log-x regression line; bottom row: accuracy vs. sample length.
fig_height = 2
fig = plt.figure(figsize=(PAPER_WIDTH_IN, fig_height))

# Drop models with no evaluated grammar files so they get no subplot column.
empty_models = (
    accuracy_df.groupby("model", observed=False)["grammar_file"].nunique() == 0
)
empty_models = empty_models[empty_models].index.to_list()

accuracy_small_df = accuracy_df.copy()

accuracy_small_df["model"] = accuracy_small_df["model"].cat.remove_categories(
    empty_models
)

# Per-model OLS fit of accuracy ~ log(n_nonlexical_productions).
# NOTE(review): `acc_small_df` is not built in this cell; it must come from an
# earlier cell (presumably per-grammar accuracy with columns "model",
# "n_nonlexical_productions", "accuracy") — confirm it matches the models here.
regs = {}
for model in acc_small_df["model"].cat.categories:
    model_df = acc_small_df[acc_small_df["model"] == model]
    X = model_df["n_nonlexical_productions"].apply(lambda x: np.log(x)).copy()
    X = sm.add_constant(X)  # add a constant term to the model
    Y = model_df["accuracy"]
    lin_model = sm.OLS(Y, X).fit()
    regs[model] = lin_model.params.to_dict()
    regs[model] |= {"R^2": lin_model.rsquared}

models = accuracy_small_df["model"].cat.categories
rows = [0, 1]

grid = fig.add_gridspec(2, len(models), wspace=0.1, hspace=0.8)

# Same values as the module-level `rcs` from the plotting setup, re-bound here.
rcs = {
    "font.size": 10.0,
    "axes.labelsize": "small",
    "axes.titlesize": "small",
    "xtick.labelsize": "x-small",
    "ytick.labelsize": "x-small",
}

with sns.plotting_context("paper", font_scale=1, rc=rcs):
    # Plot accuracy ~ n_nonlexical_productions in first row,
    # and accuracy ~ sample.length in second row

    for r in rows:
        for c, model in enumerate(models):
            ax = fig.add_subplot(grid[r, c])
            # Dashed chance-level reference line.
            _ = ax.axhline(
                y=0.5,
                color=COLOR_AT_CHANCE,
                alpha=ALPHA_AT_CHANCE,
                linestyle="--",
            )

            if r == 0:
                # Plot accuracy ~ n_nonlexical_productions
                sns.scatterplot(
                    data=acc_small_df[acc_small_df["model"] == model],
                    x="n_nonlexical_productions",
                    y="accuracy",
                    color=PALETTE_MODEL["gpt-4.1"],
                    ax=ax,
                    s=8,
                    legend=None,
                    alpha=0.5,
                )

                # NOTE(review): the x-range comes from `f1_small_df`, another
                # table from an earlier cell, not from the scattered data.
                x_min = f1_small_df["n_nonlexical_productions"].min()
                x_max = f1_small_df["n_nonlexical_productions"].max()
                x = np.linspace(x_min, x_max, 10)

                # Get the regression line
                y_pred = regs[model]["const"] + regs[model][
                    "n_nonlexical_productions"
                ] * np.log(x)

                # Plot the regression line
                ax.plot(
                    x,
                    y_pred,
                    color=darken(PALETTE_MODEL["gpt-4.1"], 0.3),
                    linewidth=2,
                    label="Regression Line",
                )

                ax.set_xscale("log")
                ax.set_title(model, fontsize=7)
                ax.set_ylim(0, 1)
                ax.set_xlim(1, 500)
                ax.set_xticks([1, 100])
                ax.set_xticklabels([1, 100])
                # ax.xaxis.set_major_formatter(
                #     mpl.ticker.FuncFormatter(lambda val, pos: str(int(np.log10(val))))
                # )
                ax.spines["top"].set_visible(False)
                ax.spines["right"].set_visible(False)
                if c == 0:
                    ax.set_ylabel("Accuracy")
                    ax.set_xlabel(
                        "# of Nonlexical Productions  [log scale]", ha="left", x=0.0
                    )
                    ax.set_yticks([0, 1])
                else:
                    ax.set_ylabel(None)
                    ax.set_xlabel(None)
                    ax.set_yticks([])
                    ax.spines["left"].set_edgecolor("grey")
            else:
                # Bottom row: per-sample accuracy ~ sample length, mean ± SE.
                sns.lineplot(
                    data=accuracy_small_df[accuracy_small_df.model == model],
                    x="sample.length",
                    y="correct",
                    color=PALETTE_MODEL["gpt-4.1"],
                    errorbar="se",
                    legend=None,
                    ax=ax,
                )

                ax.set_ylim(0, 1)
                ax.set_xlim(1, 50)
                ax.set_xticks([1, 50])
                ax.spines["top"].set_visible(False)
                ax.spines["right"].set_visible(False)

                if c == 0:
                    ax.set_ylabel("Accuracy")
                    ax.set_xlabel("Sample Length", ha="left", x=0.0)
                    ax.set_yticks([0, 1])
                else:
                    ax.set_ylabel(None)
                    ax.set_xlabel(None)
                    ax.set_yticks([])
                    ax.spines["left"].set_edgecolor("grey")

                # Keep the end tick labels inside the axes' horizontal extent.
                ax.get_xticklabels()[0].set_ha("left")
                ax.get_xticklabels()[-1].set_ha("right")

    # Let every artist draw outside its axes (lines touching the limits).
    for o in fig.findobj():
        o.set_clip_on(False)

    plt.subplots_adjust(left=0, bottom=0, right=1, top=1)

    plt.savefig(
        FIGURES_DIR / "accuracy_by_complexity_monochrome.pdf",
        bbox_inches="tight",
    )

### Predictions ~ Sample Length

In [None]:
# Fraction of each predicted sample type, per model and sample length.
pred_groups = accuracy_df.groupby(["model", "sample.length"], observed=True)
preds_df = (
    pred_groups["sample.type.predicted"].value_counts(normalize=True).reset_index()
)

# Models that produced no predictions at all get no subplot.
n_pred_types = preds_df.groupby("model", observed=False)[
    "sample.type.predicted"
].nunique()
empty_models = n_pred_types[n_pred_types == 0].index.to_list()

preds_df["model"] = preds_df["model"].cat.remove_categories(empty_models)

fig_height = 0.8

fig = plt.figure(figsize=(PAPER_WIDTH_IN, fig_height))
grid = fig.add_gridspec(1, len(preds_df["model"].cat.categories), wspace=0.1)

with sns.plotting_context("paper", font_scale=1):
    for col, model in enumerate(preds_df["model"].cat.categories):
        ax = fig.add_subplot(grid[0, col])

        sns.lineplot(
            data=preds_df[preds_df.model == model],
            x="sample.length",
            y="proportion",
            hue="sample.type.predicted",
            palette=PALETTES["sample_type"],
            errorbar="se",
            legend=None,
            ax=ax,
        )

        ax.set_title(model, fontsize=7)
        ax.set_ylim(-0.03, 1)

        if col > 0:
            # Inner columns: hide the shared y-axis and grey out its spine.
            ax.tick_params(axis="y", which="both", left=False, labelleft=False)
            ax.set_ylabel(None)
            ax.set_xlabel(None)
            ax.spines["left"].set_edgecolor("grey")
        else:
            ax.set_ylabel("Predicted Type")
            ax.set_xlabel("Sample Length", ha="left", x=0.0)
            ax.set_yticks([0, 1])

        for side in ("top", "right"):
            ax.spines[side].set_visible(False)

        for line in ax.lines:
            line.set_clip_on(False)

    # Direct labels on the first subplot, placed just above each line's end.
    # NOTE(review): assumes seaborn drew the hue levels in the order
    # positive, negative, unknown — confirm the category order of
    # "sample.type.predicted".
    label_offsets = {"positive": 0.02, "negative": 0.1, "unknown": 0.02}
    ax = fig.get_axes()[0]
    for idx, (label, dy) in enumerate(label_offsets.items()):
        line = ax.lines[idx]
        ax.text(
            line.get_xdata()[-1],
            line.get_ydata()[-1] + dy,
            label,
            ha="right",
            va="bottom",
            fontweight="bold",
            fontsize=7,
            color=darken(PALETTES["sample_type"][label], by=0.2),
        )

    plt.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0.2, hspace=0)

    plt.savefig(
        FIGURES_DIR / "predicted_sample_type_by_sample_length.pdf",
        bbox_inches="tight",
    )

### Accuracy & Prediction Type ~ Sample Length, Type

In [None]:
# Fraction of each predicted sample type, per model and sample length.
preds_df = (
    accuracy_df.groupby(["model", "sample.length"], observed=True)[
        "sample.type.predicted"
    ]
    .value_counts(normalize=True)
    .reset_index()
)

# Drop models that produced no predictions so they get no subplot column.
empty_models = (
    preds_df.groupby("model", observed=False)["sample.type.predicted"].nunique() == 0
)
empty_models = empty_models[empty_models].index.to_list()

preds_df["model"] = preds_df["model"].cat.remove_categories(empty_models)

models = preds_df["model"].cat.categories

fig_height = 1.6

fig = plt.figure(figsize=(PAPER_WIDTH_IN, fig_height))
grid = fig.add_gridspec(
    2,
    len(models),
    wspace=0.1,
    hspace=0.2,
)

# Top row: accuracy per ground-truth type; bottom row: predicted-type shares.
# NOTE(review): `accuracy_small_df` is not built in this cell; it is whatever
# an earlier cell left behind — confirm it holds the intended models.
with sns.plotting_context("paper", font_scale=1, rc=rcs):
    for r in [0, 1]:
        for c, model in enumerate(models):
            ax = fig.add_subplot(grid[r, c])
            ax.yaxis.set_major_formatter(mpl.ticker.PercentFormatter(1))

            if r == 0:
                # Chance line first, so it is ax.lines[0] and the hue lines
                # follow at ax.lines[1:].
                ax.axhline(
                    y=0.5,
                    color=COLOR_AT_CHANCE,
                    alpha=ALPHA_AT_CHANCE,
                    linestyle="--",
                )

                sns.lineplot(
                    data=accuracy_small_df[accuracy_small_df.model == model],
                    x="sample.length",
                    y="correct",
                    hue="sample.type.ground_truth",
                    palette=PALETTES["sample_type"],
                    errorbar="se",
                    legend=None,
                    ax=ax,
                )

                ax.set_title(model, fontsize=7)
                ax.set_xlabel(None)

                if c == 0:
                    ax.set_ylabel("Accuracy")
                    ax.set_yticks([0, 1])

                    # Direct labels at the right end of each hue line.
                    # NOTE(review): assumes hue levels were drawn in the order
                    # positive, negative — confirm the category order.
                    for i, label in enumerate(["positive", "negative"]):
                        x_coord = ax.lines[i + 1].get_xdata()[-1]
                        y_coord = ax.lines[i + 1].get_ydata()[-1] + (
                            0.02 if label == "positive" else -0.05
                        )
                        ax.text(
                            x_coord,
                            y_coord,
                            label,
                            ha="right",
                            va="bottom" if label == "positive" else "top",
                            fontsize=7,
                            color=darken(PALETTES["sample_type"][label], by=0.2),
                        )
                else:
                    ax.tick_params(axis="y", which="both", left=False)
                    ax.tick_params(axis="y", which="both", labelleft=False)
                    ax.set_ylabel(None)
                    ax.spines["left"].set_edgecolor("grey")
            else:
                sns.lineplot(
                    data=preds_df[preds_df.model == model],
                    x="sample.length",
                    y="proportion",
                    hue="sample.type.predicted",
                    palette=PALETTES["sample_type"],
                    errorbar="se",
                    legend=None,
                    ax=ax,
                )

                ax.set_title(None)

                if c == 0:
                    ax.set_xlabel("Task Complexity (Example Length)", ha="left", x=0.0)
                    ax.set_ylabel("Predicted Type")
                    ax.set_yticks([0, 1])

                    # Direct labels just above the end of each hue line.
                    # NOTE(review): assumes the order positive, negative,
                    # unknown — confirm the category order.
                    for i, label in enumerate(["positive", "negative", "unknown"]):
                        x_coord = ax.lines[i].get_xdata()[-1]
                        y_coord = ax.lines[i].get_ydata()[-1] + (
                            0.15 if label == "negative" else 0.02
                        )
                        ax.text(
                            x_coord,
                            y_coord,
                            label,
                            ha="right",
                            va="bottom",
                            fontsize=7,
                            color=darken(PALETTES["sample_type"][label], by=0.2),
                        )
                else:
                    ax.set_xlabel(None)
                    ax.set_yticks([])
                    ax.set_ylabel(None)
                    ax.spines["left"].set_edgecolor("grey")

            # Settings shared by both rows.
            ax.set_ylim(-0.03, 1)
            ax.set_xlim(1, 50)
            ax.spines["top"].set_visible(False)
            ax.spines["right"].set_visible(False)

            if r == 0:
                # The top row shares its x-range with the row below and carries
                # no x ticks. A blanket set_xticks([1, 50]) for both rows used
                # to re-enable them here.
                ax.set_xticks([])
            else:
                ax.set_xticks([1, 50])
                ax.get_xticklabels()[0].set_ha("left")
                ax.get_xticklabels()[-1].set_ha("right")

    # Let every artist draw outside its axes (lines touching the limits).
    for o in fig.findobj():
        o.set_clip_on(False)

    plt.subplots_adjust(left=0, bottom=0, right=1, top=1)

    plt.savefig(
        FIGURES_DIR / "accuracy-predtype_by_sample_length-type.pdf",
        bbox_inches="tight",
    )

In [None]:
# Type-balanced accuracy per model and sample length: average within each
# ground-truth sample type first, then across the types.
mean_acc_df = (
    accuracy_small_df.groupby(
        ["model", "sample.length", "sample.type.ground_truth"], observed=True
    )["correct"]
    .mean()
    .reset_index()
    .groupby(["model", "sample.length"], observed=True)["correct"]
    .mean()
    .reset_index()
)

# One entry per figure row: (data to plot, x-label for the first column).
# Top row: raw per-sample accuracy; bottom row: the balanced accuracy above.
row_specs = (
    (accuracy_small_df, "Task Complexity (Sample Length)"),
    (mean_acc_df, "Sample Length"),
)
model_levels = mean_acc_df["model"].cat.categories

fig_height = 2
fig = plt.figure(figsize=(PAPER_WIDTH_IN, fig_height))
grid = fig.add_gridspec(2, len(model_levels), wspace=0.1, hspace=0.1)

with sns.plotting_context("paper", font_scale=1, rc=rcs):
    for r, (row_df, first_xlabel) in enumerate(row_specs):
        for i, model in enumerate(model_levels):
            ax = fig.add_subplot(grid[r, i])

            _ = ax.axhline(
                y=0.5,
                color=COLOR_AT_CHANCE,
                alpha=ALPHA_AT_CHANCE,
                linestyle="--",
            )

            sns.lineplot(
                data=row_df[row_df.model == model],
                x="sample.length",
                y="correct",
                color=PALETTE_MODEL[model],
                errorbar="se",
                legend=None,
                ax=ax,
            )

            # Only the top row is titled with the model name.
            if r == 0:
                ax.set_title(model, fontsize=7)
            else:
                ax.set_title(None)

            ax.set_ylim(0, 1)
            ax.set_xlim(1, 50)
            ax.set_xlabel(None)
            ax.set_xticks([1, 50])
            for side in ("top", "right"):
                ax.spines[side].set_visible(False)

            if i != 0:
                ax.tick_params(axis="y", which="both", left=False, labelleft=False)
                ax.set_ylabel(None)
            else:
                ax.set_ylabel("Accuracy")
                ax.set_yticks([0, 1])
                ax.set_xlabel(first_xlabel, ha="left", x=0.0)

    plt.subplots_adjust(left=0, bottom=0, right=1, top=1)

In [None]:
# Mean type-balanced accuracy per model (in %, 1 decimal), averaged over
# sample lengths. observed=True matches the other groupbys in this notebook
# and avoids pandas' FutureWarning about the categorical `observed` default.
mean_acc_df.groupby("model", observed=True)["correct"].mean().round(3) * 100

In [None]:
def binary_entropy(series: pd.Series) -> float:
    """
    Return the binary entropy (in bits) of the most frequent value in ``series``.

    The share ``p`` of the most common value is one outcome and every other
    value is the other outcome, giving
    ``H(p) = -p*log2(p) - (1 - p)*log2(1 - p)``. With exactly two distinct
    values this is the ordinary Shannon entropy; with more values (e.g. an
    extra "unknown" label) the minority values are pooled together.

    Returns ``0.0`` when all values agree, and ``nan`` when ``series`` has no
    values (e.g. empty groups produced by ``groupby(..., observed=False)``).
    """
    counts = series.value_counts().to_numpy()
    total = counts.sum()
    if total == 0:
        # Empty group: entropy is undefined. Indexing counts[0] used to raise
        # IndexError here (or compute 0/0 for categoricals).
        return np.nan
    # value_counts sorts descending, so counts[0] is the modal value's count.
    p_top = counts[0] / total
    if p_top == 0 or p_top == 1:
        return 0.0
    return -(p_top * np.log2(p_top) + (1 - p_top) * np.log2(1 - p_top))


# Binary entropy of each model's predicted labels at every sample length.
entropy_df = (
    accuracy_df.groupby(["model", "sample.length"], observed=False)[
        "sample.type.predicted"
    ]
    .apply(binary_entropy)
    .reset_index(name="entropy")
)

# Fraction of each predicted sample type, per model and sample length.
preds_df = (
    accuracy_df.groupby(["model", "sample.length"], observed=True)[
        "sample.type.predicted"
    ]
    .value_counts(normalize=True)
    .reset_index()
)

# Drop models that produced no predictions so they get no subplot column.
empty_models = (
    preds_df.groupby("model", observed=False)["sample.type.predicted"].nunique() == 0
)
empty_models = empty_models[empty_models].index.to_list()

preds_df["model"] = preds_df["model"].cat.remove_categories(empty_models)

models = preds_df["model"].cat.categories
rows = [0, 1]

fig_height = 1.5

fig = plt.figure(figsize=(PAPER_WIDTH_IN, fig_height))
grid = fig.add_gridspec(
    len(rows),
    len(models),
    wspace=0.1,
    hspace=0.2,
)

with sns.plotting_context("paper", font_scale=1, rc=rcs):
    for r in rows:
        for c, model in enumerate(models):
            ax = fig.add_subplot(grid[r, c])

            if r == 0:
                # Top row: predicted sample type ~ sample length.
                sns.lineplot(
                    data=preds_df[preds_df.model == model],
                    x="sample.length",
                    y="proportion",
                    hue="sample.type.predicted",
                    palette=PALETTES["sample_type"],
                    errorbar="se",
                    legend=None,
                    ax=ax,
                )

                ax.set_title(model)
                ax.set_ylim(-0.03, 1)
                ax.set_xlim(1, 50)
                ax.set_xticks([])
                ax.set_xlabel(None)
                ax.spines["top"].set_visible(False)
                ax.spines["right"].set_visible(False)

                if c == 0:
                    ax.set_ylabel("Predicted\nType")
                    ax.set_yticks([0, 1])
                else:
                    ax.tick_params(axis="y", which="both", left=False)
                    ax.tick_params(axis="y", which="both", labelleft=False)
                    ax.set_ylabel(None)
                    ax.spines["left"].set_edgecolor("lightgrey")
            else:
                # Bottom row: prediction entropy ~ sample length.
                # NOTE(review): this reuses the chance-line styling at 0.5,
                # which is not a chance level for entropy — confirm intent.
                ax.axhline(
                    y=0.5,
                    color=COLOR_AT_CHANCE,
                    alpha=ALPHA_AT_CHANCE,
                    linestyle="--",
                )
                sns.lineplot(
                    data=entropy_df[entropy_df.model == model],
                    x="sample.length",
                    y="entropy",
                    errorbar="se",
                    color="grey",
                    ax=ax,
                )

                ax.set_title(None)
                ax.set_ylim(-0.03, 1)
                ax.set_xlim(1, 50)
                ax.set_xticks([1, 50])
                ax.set_xlabel(None)
                ax.spines["top"].set_visible(False)
                ax.spines["right"].set_visible(False)

                if c == 0:
                    ax.set_ylabel("Prediction\nEntropy")
                    ax.set_xlabel("Sample Length", ha="left", x=0.0)
                    ax.set_yticks([0, 1])
                else:
                    ax.tick_params(axis="y", which="both", left=False)
                    ax.tick_params(axis="y", which="both", labelleft=False)
                    ax.set_ylabel(None)

                ax.get_xticklabels()[0].set_ha("left")
                ax.get_xticklabels()[-1].set_ha("right")

    # Finalize once, after every subplot exists. These steps used to be
    # indented inside the row loop, so the figure was saved twice, the first
    # time with only the top row drawn.
    for o in fig.findobj():
        o.set_clip_on(False)

    plt.subplots_adjust(left=0, bottom=0, right=1, top=1)

    plt.savefig(
        FIGURES_DIR / "predicted_type-entropy_by_length.pdf",
        bbox_inches="tight",
    )

In [None]:
# Accuracy ~ sample length for every model on one axes (mean ± standard error).
fig, ax = plt.subplots(figsize=(5, 3.5))

_ = sns.lineplot(
    data=accuracy_df,
    x="sample.length",
    y="correct",
    hue="model",
    style="model",
    palette=PALETTE_MODEL,
    errorbar="se",
    ax=ax,
)

# Dashed reference at chance-level accuracy.
_ = ax.axhline(y=0.5, color=COLOR_AT_CHANCE, alpha=ALPHA_AT_CHANCE, linestyle="--")

# NOTE(review): legend keys are taken from `f1_df`, not `accuracy_df`;
# presumably both contain the same models — confirm.
legend_format(keys=f1_df["model"].unique(), ax=ax)

_ = ax.set(ylim=(None, 1.02), xlabel="Sample Length", ylabel="Mean Accuracy")

plt.savefig(
    FIGURES_DIR / "accuracy_by_sample_length_by_model.pdf",
    bbox_inches="tight",
)

In [None]:
def binary_entropy(series) -> float:
    """
    Binary entropy (in bits) of the most frequent value in ``series``.

    The share ``p`` of the single most common value is treated as one outcome
    and every other value as the second outcome, giving
    ``-(p * log2(p) + (1 - p) * log2(1 - p))``. For a column with exactly two
    observed values this is the ordinary Shannon entropy; with more values the
    non-majority values are pooled together.

    Missing values are ignored (``value_counts`` drops them by default).

    Returns:
        ``nan`` when the series has no non-missing values, ``0.0`` when one
        value accounts for every observation, otherwise the entropy in [0, 1].
    """
    counts = series.value_counts().to_numpy()
    total = counts.sum()
    if total == 0:
        # Empty group (e.g. an unobserved category kept by ``observed=False``,
        # or a categorical whose counts are all zero): entropy is undefined.
        # Without this guard an empty array raises IndexError and all-zero
        # counts hit a 0/0 division.
        return float("nan")
    p = counts.max() / total
    if p >= 1:
        return 0.0
    return float(-(p * np.log2(p) + (1 - p) * np.log2(1 - p)))


# Entropy of the predicted sample type for every (model, sample length) pair.
entropy_df = (
    accuracy_df.groupby(["model", "sample.length"], observed=False)[
        "sample.type.predicted"
    ]
    .apply(binary_entropy)
    .reset_index(name="entropy")
)

# gemma-3-4b is left out of this figure.
entropy_df = entropy_df[entropy_df.model != "gemma-3-4b"]

# Models that actually have an entropy curve here. The legend is built from
# this table (not from another results table) so it cannot list gemma-3-4b,
# which was just filtered out, or categories that only carry NaN entropies.
plotted_models = entropy_df.dropna(subset=["entropy"])["model"].unique()

fig = plt.figure(figsize=(5, 3.5))
ax = fig.add_subplot(111)

_ = sns.lineplot(
    data=entropy_df,
    x="sample.length",
    y="entropy",
    hue="model",
    style="model",
    palette=PALETTE_MODEL,
    errorbar="se",
    ax=ax,
)

legend_format(
    keys=plotted_models,
    ax=ax,
)

_ = ax.set_xlabel("Sample Length")
_ = ax.set_ylabel("Entropy of Model Predictions")

plt.savefig(
    FIGURES_DIR / "entropy_by_sample_length_by_model.pdf",
    bbox_inches="tight",
)

In [None]:
def binary_entropy(series) -> float:
    """
    Binary entropy (in bits) of the most frequent value in ``series``.

    The share ``p`` of the single most common value is treated as one outcome
    and every other value as the second outcome, giving
    ``-(p * log2(p) + (1 - p) * log2(1 - p))``. For a column with exactly two
    observed values this is the ordinary Shannon entropy; with more values the
    non-majority values are pooled together.

    Missing values are ignored (``value_counts`` drops them by default).

    Returns:
        ``nan`` when the series has no non-missing values, ``0.0`` when one
        value accounts for every observation, otherwise the entropy in [0, 1].
    """
    counts = series.value_counts().to_numpy()
    total = counts.sum()
    if total == 0:
        # Empty group (e.g. an unobserved category kept by ``observed=False``,
        # or a categorical whose counts are all zero): entropy is undefined.
        # Without this guard an empty array raises IndexError and all-zero
        # counts hit a 0/0 division.
        return float("nan")
    p = counts.max() / total
    if p >= 1:
        return 0.0
    return float(-(p * np.log2(p) + (1 - p) * np.log2(1 - p)))


# Two-row summary figure with one column per model:
#   row 0 - proportion of each predicted sample type vs. sample length
#   row 1 - entropy of the predicted type vs. sample length
entropy_df = (
    accuracy_df.groupby(["model", "sample.length"], observed=False)[
        "sample.type.predicted"
    ]
    .apply(binary_entropy)
    .reset_index(name="entropy")
)

# Share of each predicted sample type within every (model, sample length) cell.
preds_df = (
    accuracy_df.groupby(["model", "sample.length"], observed=True)[
        "sample.type.predicted"
    ]
    .value_counts(normalize=True)
    .reset_index()
)

# Model categories with no predicted types in `preds_df` get no column.
empty_models = (
    preds_df.groupby("model", observed=False)["sample.type.predicted"].nunique() == 0
)
empty_models = empty_models[empty_models].index.to_list()

preds_df["model"] = preds_df["model"].cat.remove_categories(empty_models)

models = preds_df["model"].cat.categories
rows = [0, 1]

fig_height = 1.5

fig = plt.figure(figsize=(PAPER_WIDTH_IN, fig_height))
grid = fig.add_gridspec(
    len(rows),
    len(models),
    wspace=0.1,
    hspace=0.2,
)

with sns.plotting_context("paper", font_scale=1, rc=rcs):
    for r in rows:
        for c, model in enumerate(models):
            ax = fig.add_subplot(grid[r, c])

            if r == 0:
                # Plot sample.type.predicted ~ sample.length
                sns.lineplot(
                    data=preds_df[preds_df.model == model],
                    x="sample.length",
                    y="proportion",
                    hue="sample.type.predicted",
                    palette=PALETTES["sample_type"],
                    errorbar="se",
                    legend=None,
                    ax=ax,
                )

                ax.set_title(None)
                ax.set_ylim(-0.03, 1)
                ax.set_xlim(1, 50)
                ax.set_xticks([])  # x ticks are shown on the bottom row only
                ax.set_xlabel(None)
                # Turn off top and right spines
                ax.spines["top"].set_visible(False)
                ax.spines["right"].set_visible(False)

                if c == 0:
                    ax.set_ylabel("Predicted\nType")
                    ax.set_yticks([0, 1])
                else:
                    # Only the leftmost column carries a y axis.
                    ax.tick_params(axis="y", which="both", left=False)
                    ax.tick_params(axis="y", which="both", labelleft=False)
                    ax.set_ylabel(None)
                    ax.spines["left"].set_edgecolor("lightgrey")

            else:
                # Plot entropy ~ sample.length
                sns.lineplot(
                    data=entropy_df[entropy_df.model == model],
                    x="sample.length",
                    y="entropy",
                    errorbar="se",
                    color="grey",
                    ax=ax,
                )

                ax.set_title(None)
                ax.set_ylim(-0.03, 1)
                ax.set_xlim(1, 50)
                ax.set_xticks([1, 50])
                ax.set_xlabel(None)
                ax.spines["top"].set_visible(False)
                ax.spines["right"].set_visible(False)

                if c == 0:
                    ax.set_ylabel("Prediction\nEntropy")
                    ax.set_xlabel("Sample Length", ha="left", x=0.0)
                    ax.set_yticks([0, 1])
                else:
                    ax.tick_params(axis="y", which="both", left=False)
                    ax.tick_params(axis="y", which="both", labelleft=False)
                    ax.set_ylabel(None)

                # Align the two end tick labels with the panel edges.
                ax.get_xticklabels()[0].set_ha("left")
                ax.get_xticklabels()[-1].set_ha("right")

        # Allow artists at the axis limits to draw past the panel edges.
        for o in fig.findobj():
            o.set_clip_on(False)

        plt.subplots_adjust(left=0, bottom=0, right=1, top=1)

        plt.savefig(
            FIGURES_DIR / "predicted_type-entropy_by_length.pdf",
            bbox_inches="tight",
        )

In [None]:
# Entropy of each model's predicted sample type, pooled over all sample
# lengths, drawn as one bar per model.
entropy_df = (
    accuracy_df
    .groupby(["model"], observed=False)["sample.type.predicted"]
    .apply(binary_entropy)
    .reset_index(name="entropy")
)

fig_height = 1.2

with sns.plotting_context("paper"):
    g = sns.catplot(
        kind="bar",
        data=entropy_df,
        x="model",
        hue="model",  # colour each bar with its own model colour
        y="entropy",
        palette=PALETTE_MODEL,
        height=fig_height,
        aspect=PAPER_WIDTH_IN / fig_height,  # figure spans the full paper width
    )

    # Let the axes fill the whole figure canvas.
    plt.subplots_adjust(bottom=0, top=1, left=0, right=1, hspace=0, wspace=0.2)

In [None]:
# Distribution of gemma-3-1b's predicted sample types at each sample length.
(
    accuracy_df.loc[accuracy_df["model"].eq("gemma-3-1b")]
    .groupby("sample.length", observed=False)["sample.type.predicted"]
    .value_counts(normalize=True)
    .reset_index()
)

## Histogram of Sample Lengths

In [None]:
# Histogram of sample lengths (bins one unit wide), coloured by ground-truth
# sample type.
fig, ax = plt.subplots(figsize=(6, 3))

sns.histplot(
    ax=ax,
    data=response_df,
    x="sample.length",
    hue="sample.type.ground_truth",
    palette=PALETTE_SAMPLE_TYPE,
    binwidth=1,
    alpha=0.8,
)

_ = ax.set_xlabel("Sample length")
_ = ax.get_legend().set_title("Sample type")

plt.savefig(
    FIGURES_DIR / "sample_length_histogram.pdf",
    bbox_inches="tight",
)

In [None]:
# Share of samples per ground-truth sample type, as bars annotated with their
# value.
sample_counts_df = (
    response_df.groupby(["sample.type.ground_truth"], observed=False)[
        "sample.type.ground_truth"
    ]
    .count()
    .reset_index(name="count")
)
total_samples = sample_counts_df["count"].sum()
sample_counts_df["proportion"] = sample_counts_df["count"] / total_samples

g = sns.catplot(
    data=sample_counts_df,
    kind="bar",
    x="sample.type.ground_truth",
    y="proportion",
    hue="sample.type.ground_truth",
    palette=PALETTE_SAMPLE_TYPE,
    height=3,
    aspect=0.8,
).set_axis_labels("", "Proportion of Samples")

for ax in g.axes.flat:
    # Write each bar's proportion in bold white just below its top edge.
    for container in ax.containers:
        for rect in container:
            height = rect.get_height()
            x_center = rect.get_x() + rect.get_width() / 2

            ax.text(
                x_center,
                height - 0.02,
                f"{height:.2f}",
                ha="center",
                va="top",
                fontsize=9,
                color="white",
                fontweight="bold",
            )

    # Outline every bar.
    for bar in ax.patches:
        bar.set_edgecolor(BAR_EDGE_COLOR)
        bar.set_linewidth(BAR_EDGE_WIDTH)

_ = g.ax.axhline(
    y=0.5,
    color=COLOR_AT_CHANCE,
    alpha=ALPHA_AT_CHANCE,
    linestyle="--",
)

_ = g.ax.set_ylim(0, 1)

plt.savefig(
    FIGURES_DIR / "sample_type_proportions.pdf",
    bbox_inches="tight",
)

In [None]:
# Dataset overview: left panel = share of samples per ground-truth type,
# right panel = histogram of sample lengths by type (counts shown in thousands).
fig_height = 1.2

fig = plt.figure(figsize=(PAPER_WIDTH_IN, fig_height))
grid = fig.add_gridspec(1, 2, wspace=0.05, width_ratios=[1, 2.5])

with sns.plotting_context("paper", font_scale=1):
    for c in range(2):
        ax = fig.add_subplot(grid[0, c])
        if c == 0:
            sns.barplot(
                data=sample_counts_df,
                x="sample.type.ground_truth",
                y="proportion",
                hue="sample.type.ground_truth",
                palette=PALETTE_SAMPLE_TYPE,
                ax=ax,
                gap=-0.1,
                width=0.8,
            )
            ax.set_ylabel("Proportion")
            ax.set_ylim(0, 1)
            ax.set_yticks([0, 1])
            ax.set_xlabel("Sample Type")
            ax.spines["top"].set_visible(False)
            ax.spines["right"].set_visible(False)

            for bar in ax.patches:
                bar.set_edgecolor(BAR_EDGE_COLOR)
                bar.set_linewidth(BAR_EDGE_WIDTH)
        else:
            # One bar per integer sample length.
            sns.histplot(
                data=response_df,
                x="sample.length",
                ax=ax,
                discrete=True,
                stat="count",
                hue="sample.type.ground_truth",
                palette=PALETTE_SAMPLE_TYPE,
                alpha=0.8,
                legend=None,
            )

            # Move the y axis to the right-hand side of this panel.
            ax.yaxis.tick_right()
            ax.yaxis.set_label_position("right")
            ax.yaxis.set_ticks_position("right")
            ax.spines["top"].set_visible(False)
            ax.spines["left"].set_visible(False)
            ax.set_xlabel("Sample Length", ha="left", x=0.0)
            ax.set_xlim(0, 51)
            ax.set_xticks([1, 10, 20, 30, 40, 50])

            # Tick labels in thousands. "g" formatting keeps ticks that are not
            # whole thousands distinct (500 -> "0.5k", 2500 -> "2.5k"); int()
            # truncated them into duplicates such as "0k" / "2k". Whole
            # thousands render exactly as before ("0k", "1k", "10k").
            ax.yaxis.set_major_formatter(
                mpl.ticker.FuncFormatter(lambda x, pos: f"{x / 1000:g}k")
            )

            for bar in ax.patches:
                bar.set_edgecolor(BAR_EDGE_COLOR)
                bar.set_linewidth(BAR_EDGE_WIDTH)

    # Allow artists at the axis limits to draw past the panel edges.
    for o in fig.findobj():
        o.set_clip_on(False)

    plt.subplots_adjust(left=0, bottom=0, right=1, top=1)

    plt.savefig(
        FIGURES_DIR / "sample_stats.pdf",
        bbox_inches="tight",
    )

In [None]:
# Number of responses at each of the five largest sample lengths.
response_df["sample.length"].value_counts().sort_index().iloc[-5:]

In [None]:
# gemma-3-1b responses, restricted to the columns needed to inspect its
# predictions and ordered by grammar file. (The column subset is selected once;
# re-selecting the same columns after sorting was redundant.)
gemma_columns = [
    "model_response",
    "sample.type.predicted",
    "sample.length",
    "correct",
    "grammar_file",
]
gemma_responses_df = response_df.loc[
    response_df.model == "gemma-3-1b", gemma_columns
].sort_values(by="grammar_file")

In [None]:
# Display the gemma-3-1b responses selected in the previous cell.
gemma_responses_df

In [None]:
# gemma_responses_df.groupby("sample.type.predicted", observed=True).value_counts(normalize=True)

In [None]:
# Predicted-type distribution for gemma-3-1b on grammar_20250218222557 alone.
gemma_responses_df.loc[
    gemma_responses_df["grammar_file"].eq("grammar_20250218222557"),
    "sample.type.predicted",
].value_counts(normalize=True)

In [None]:
# Predicted-type distribution for gemma-3-1b on every grammar except
# grammar_20250218222557.
gemma_responses_df.loc[
    gemma_responses_df["grammar_file"].ne("grammar_20250218222557"),
    "sample.type.predicted",
].value_counts(normalize=True)

## Tokens vs Accuracy

In [None]:
# Accuracy vs. completion tokens: token counts are binned on a log2 grid
# (steps of 0.1 in log2 space) and accuracy is averaged per (bin, model).
binned_acc_df = accuracy_df[
    ["completion_tokens", "correct", "model", "sample.length"]
].copy()
# log2 is undefined at 0 tokens, so keep strictly positive counts only.
binned_acc_df = binned_acc_df[binned_acc_df.completion_tokens > 0]

binned_acc_df["completion_tokens_bin"] = 2 ** np.log2(
    binned_acc_df["completion_tokens"]
).round(1)
binned_acc_df["correct"] = binned_acc_df["correct"].astype(float)
binned_acc_df = (
    binned_acc_df.groupby(["completion_tokens_bin", "model"], observed=True)[
        ["correct"]
    ]
    .mean()
    .reset_index()
)

with sns.plotting_context("notebook"):
    g = sns.relplot(
        data=binned_acc_df,
        x="completion_tokens_bin",
        y="correct",
        hue="model",
        palette=PALETTE_MODEL,
    )

    # Replace seaborn's default legend with the notebook's legend helper.
    g.legend.remove()

    legend_format(
        keys=binned_acc_df["model"].unique(),
        ax=g.ax,
    )

    # Presumably fades every model except gemma-3-1b to alpha=0.1 — confirm
    # against `filter_by_alpha`.
    filter_by_alpha(
        keys=binned_acc_df["model"].unique(),
        highlight=["gemma-3-1b"],
        alpha=0.1,
        ax=g.ax,
    )

    g.ax.set_xscale("log")
    g.ax.set_xlabel("Completion Tokens  [binned, log scale]")
    g.ax.set_ylabel("Mean Accuracy")

    plt.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0.2, hspace=0)

In [None]:
# Display the per-(token bin, model) mean accuracy table from the previous cell.
binned_acc_df

In [None]:
# Small multiples: accuracy vs. completion tokens (log10 bins of width 0.1),
# one panel per model, omitting models without any grammar files and the two
# Gemma models.
empty_models = (
    accuracy_df.groupby("model", observed=False)["grammar_file"].nunique() == 0
)
empty_models = empty_models[empty_models].index.to_list()

binned_acc_df = accuracy_df[
    ["completion_tokens", "correct", "model", "sample.length"]
].copy()
# log10 is undefined at 0 tokens, so keep strictly positive counts only.
binned_acc_df = binned_acc_df[binned_acc_df.completion_tokens > 0]

binned_acc_df["completion_tokens_bin"] = 10 ** np.log10(
    binned_acc_df["completion_tokens"]
).round(1)
binned_acc_df["correct"] = binned_acc_df["correct"].astype(float)
binned_acc_df = (
    binned_acc_df.groupby(["completion_tokens_bin", "model"], observed=True)[
        ["correct"]
    ]
    .mean()
    .reset_index()
)

binned_acc_small_df = binned_acc_df.copy()

binned_acc_small_df["model"] = binned_acc_small_df["model"].cat.remove_categories(
    empty_models + ["gemma-3-1b", "gemma-3-4b"]
)

fig_height = 1.2

with sns.plotting_context("notebook"):
    g = sns.relplot(
        data=binned_acc_small_df,
        x="completion_tokens_bin",
        y="correct",
        hue="model",
        palette=PALETTE_MODEL,
        col="model",
        height=fig_height,
        aspect=PAPER_WIDTH_IN / fig_height / 3,
    )

    # relplot orders a categorical `col` by its categories, so the facets line
    # up one-to-one with them.
    for ax, model in zip(g.axes.flat, binned_acc_small_df["model"].cat.categories):
        # Model name in the lower-left corner, in a darker shade of its colour.
        ax.text(
            0.1,
            0.1,
            f"{model}",
            ha="left",
            va="bottom",
            transform=ax.transAxes,
            fontweight="bold",
            color=darken(PALETTES["model"][model], 0.3),
        )

        _ = ax.axhline(
            y=0.5,
            color=COLOR_AT_CHANCE,
            alpha=ALPHA_AT_CHANCE,
            linestyle="--",
        )

    g.set(xscale="log")
    g.set_titles("")
    g.set(ylabel="Mean Accuracy", xlabel=None)
    g.axes.flat[0].set_xlabel("Completion Tokens  [binned, log scale]", ha="left", x=0)

    g.legend.remove()

    plt.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0.1, hspace=0)

In [None]:
# Display the binned accuracy table from the previous cell.
#
# No column flattening is needed: that cell builds `binned_acc_df` with
# `[["correct"]].mean().reset_index()`, which yields single-level columns.
# If it is switched to `.agg([...])` (MultiIndex columns), flatten with:
#
# binned_acc_df.columns = [
#     "_".join(col).strip("_") for col in binned_acc_df.columns.values
# ]

binned_acc_df

In [None]:
# Mean completion tokens per (sample length, model); the SEM column is computed
# alongside but not plotted.
binned_sl_df = accuracy_df[
    ["completion_tokens", "correct", "model", "sample.length"]
].copy()

binned_sl_df = (
    binned_sl_df.groupby(["sample.length", "model"], observed=False)[
        ["completion_tokens"]
    ]
    .agg(["mean", "sem"])
    .reset_index()
)
# Flatten the (column, statistic) MultiIndex, e.g. -> "completion_tokens_mean".
binned_sl_df.columns = ["_".join(col).strip("_") for col in binned_sl_df.columns.values]

# Legend entries for the models plotted here. Unobserved categories kept by
# observed=False have NaN means and draw nothing, so they are dropped. The keys
# come from this table rather than another cell's, so the legend does not
# depend on cell execution order.
plotted_models = binned_sl_df.dropna(subset=["completion_tokens_mean"])[
    "model"
].unique()

with sns.plotting_context("notebook"):
    g = sns.relplot(
        data=binned_sl_df,
        x="sample.length",
        y="completion_tokens_mean",
        hue="model",
        palette=PALETTE_MODEL,
    )

    # Replace seaborn's default legend with the notebook's legend helper.
    g.legend.remove()

    legend_format(
        keys=plotted_models,
        ax=g.ax,
    )

    g.ax.set_yscale("log")

    g.ax.set_xlabel("Sample Length")
    g.ax.set_ylabel("Completion Tokens  [log scale]")

In [None]:
# Test-time compute figure, one column per model:
#   row 0 - completion tokens vs. number of nonlexical productions
#   row 1 - completion tokens vs. sample length
# Token counts are averaged within each ground-truth sample type first and then
# across types, so every type carries equal weight.
binned_sl_df = accuracy_df[
    [
        "completion_tokens",
        "correct",
        "model",
        "sample.length",
        "sample.type.ground_truth",
    ]
].copy()
binned_sl_df = (
    binned_sl_df.groupby(
        ["sample.length", "model", "sample.type.ground_truth"], observed=False
    )[["completion_tokens"]]
    .mean()
    .reset_index()
    .groupby(["sample.length", "model"], observed=False)[["completion_tokens"]]
    .mean()
    .reset_index()
)

binned_nlp_df = accuracy_df[
    [
        "completion_tokens",
        "correct",
        "model",
        "n_nonlexical_productions",
        "sample.type.ground_truth",
    ]
].copy()
binned_nlp_df = (
    binned_nlp_df.groupby(
        ["n_nonlexical_productions", "model", "sample.type.ground_truth"],
        observed=False,
    )[["completion_tokens"]]
    .mean()
    .reset_index()
    .groupby(["n_nonlexical_productions", "model"], observed=False)[
        ["completion_tokens"]
    ]
    .mean()
    .reset_index()
)

# Models whose completion-token values are all missing.
empty_models = (
    binned_sl_df.groupby("model", observed=False)["completion_tokens"].nunique() == 0
)
empty_models = empty_models[empty_models].index.to_list()

# Everything left out of this figure: empty models plus explicit exclusions.
excluded_models = empty_models + ["gemma-3-4b", "gemma-3-1b", "DSR1-7B"]

binned_sl_small_df = binned_sl_df.copy()
binned_sl_small_df["model"] = binned_sl_small_df["model"].cat.remove_categories(
    excluded_models
)

binned_nlp_small_df = binned_nlp_df.copy()
binned_nlp_small_df["model"] = binned_nlp_small_df["model"].cat.remove_categories(
    excluded_models
)

# One grid column per remaining model. A hard-coded column count would raise
# IndexError as soon as more models remain.
ttc_models = binned_sl_small_df["model"].cat.categories

fig_height = 2
fig = plt.figure(figsize=(PAPER_WIDTH_IN, fig_height))
grid = fig.add_gridspec(2, len(ttc_models), wspace=0.05, hspace=0.8)

# Shared y-limits per row; zero-token rows are ignored for the lower limit
# because the y axis is log-scaled.
row_0_ymin = binned_nlp_small_df.query("completion_tokens > 0")[
    "completion_tokens"
].min()
row_0_ymax = binned_nlp_small_df["completion_tokens"].max()
row_1_ymin = binned_sl_small_df.query("completion_tokens > 0")[
    "completion_tokens"
].min()
row_1_ymax = binned_sl_small_df["completion_tokens"].max()

with sns.plotting_context("notebook", rc=rcs):
    for r in range(2):
        for c, model in enumerate(ttc_models):
            ax = fig.add_subplot(grid[r, c])
            if r == 0:
                sns.scatterplot(
                    data=binned_nlp_small_df[binned_nlp_small_df.model == model],
                    x="n_nonlexical_productions",
                    y="completion_tokens",
                    color=PALETTE_MODEL[model],
                    s=8,
                    alpha=0.8,
                    legend=None,
                    ax=ax,
                )

                ax.set_title(model, fontsize=7)
                ax.set_xlabel(None)
                ax.set_xscale("log")
                ax.set_yscale("log")
                ax.set_ylim(row_0_ymin, row_0_ymax)
                ax.spines["top"].set_visible(False)
                ax.spines["right"].set_visible(False)

                # format x-axis ticks like 10, 100 instead of 10^1, 10^2
                ax.xaxis.set_major_formatter(
                    mpl.ticker.FuncFormatter(lambda x, pos: f"{int(x):,}")
                )

                if c == 0:
                    # Single y label placed low enough to span both rows.
                    ax.set_ylabel(
                        "Test-time Compute (completion tokens)", va="bottom", y=-0.5
                    )
                    ax.set_xlabel(
                        "Instruction Set Complexity (# of Nonlexical Productions)",
                        ha="left",
                        x=0.0,
                    )
                else:
                    ax.tick_params(axis="y", which="both", left=False)
                    ax.tick_params(axis="y", which="both", labelleft=False)
                    ax.set_ylabel(None)
                    ax.spines["left"].set_edgecolor("lightgrey")

                ax.get_xticklabels()[0].set_ha("left")
                ax.get_xticklabels()[-1].set_ha("right")

            else:
                sns.lineplot(
                    data=binned_sl_small_df[binned_sl_small_df.model == model],
                    x="sample.length",
                    y="completion_tokens",
                    color=PALETTE_MODEL[model],
                    errorbar="se",
                    legend=None,
                    ax=ax,
                )

                ax.set_title(None)
                ax.set_xlim(1, 50)
                ax.set_xlabel(None)
                ax.set_xticks([1, 10, 20, 30, 40, 50])
                ax.set_yscale("log")
                ax.set_ylim(row_1_ymin, row_1_ymax)
                ax.spines["top"].set_visible(False)
                ax.spines["right"].set_visible(False)

                if c == 0:
                    ax.set_ylabel(None)
                    ax.set_xlabel("Task Complexity (Sample length)", ha="left", x=0.0)
                else:
                    ax.tick_params(axis="y", which="both", left=False)
                    ax.tick_params(axis="y", which="both", labelleft=False)
                    ax.set_ylabel(None)
                    ax.spines["left"].set_edgecolor("lightgrey")

                ax.get_xticklabels()[0].set_ha("left")
                ax.get_xticklabels()[-1].set_ha("right")

    # Allow artists at the axis limits to draw past the panel edges.
    for o in fig.findobj():
        o.set_clip_on(False)

    plt.subplots_adjust(left=0, bottom=0, right=1, top=1)

    plt.savefig(
        FIGURES_DIR / "ttc.pdf",
        bbox_inches="tight",
    )

In [None]:
# Display the per-(sample length, model) mean completion-token table built in
# the previous cell.
binned_sl_df

In [None]:
# Display the models whose completion-token values were all missing; the
# previous cell removes them from the figure.
empty_models

In [None]:
# Exploratory: o3 accuracy vs. completion tokens (rounded to the nearest 100),
# one panel per sample-length bin, coloured by ground-truth sample type.
o3_tokens_acc_df = accuracy_df[accuracy_df.model == "o3"][
    ["completion_tokens", "correct", "sample.length", "sample.type.ground_truth"]
].copy()
o3_tokens_acc_df["completion_tokens_bin"] = o3_tokens_acc_df["completion_tokens"].round(
    -2
)

# Sample lengths are in range 1--50; bin them into 5 bins. Each label names the
# lengths its bin actually holds (x < 11 -> "1-10", x < 21 -> "11-20", ...), so
# neighbouring labels do not overlap at the boundaries.
length_bins = ["1-10", "11-20", "21-30", "31-40", "41-50"]
o3_tokens_acc_df["sample.length_bin"] = o3_tokens_acc_df["sample.length"].map(
    lambda x: "1-10"
    if x < 11
    else "11-20"
    if x < 21
    else "21-30"
    if x < 31
    else "31-40"
    if x < 41
    else "41-50"
)

o3_tokens_acc_df = (
    o3_tokens_acc_df.groupby(
        ["completion_tokens_bin", "sample.length_bin", "sample.type.ground_truth"],
        observed=False,
    )[["correct"]]
    .mean()
    .reset_index()
)

fig_height = 1.2

with sns.plotting_context("notebook"):
    g = sns.relplot(
        data=o3_tokens_acc_df,
        y="correct",
        x="completion_tokens_bin",
        hue="sample.type.ground_truth",
        palette=PALETTE_SAMPLE_TYPE,
        col="sample.length_bin",
        # For string columns seaborn otherwise orders facets by first
        # appearance in the data, which need not be ascending.
        col_order=length_bins,
        height=fig_height,
        aspect=PAPER_WIDTH_IN / fig_height / 3,
        alpha=0.7,
        # NOTE(review): `size` is seaborn's size *semantic*, so a scalar is
        # mapped as a constant size variable rather than used as a marker area;
        # check whether `s=8` was intended.
        size=8,
        legend=None,
    )

    # Label each panel with the bin it actually shows.
    for ax, length_bin in zip(g.axes.flat, length_bins):
        ax.text(
            0.1,
            0.1,
            length_bin,
            ha="left",
            va="bottom",
            transform=ax.transAxes,
            fontweight="bold",
        )

        _ = ax.axhline(
            y=0.5,
            color=COLOR_AT_CHANCE,
            alpha=ALPHA_AT_CHANCE,
            linestyle="--",
        )

    g.set(xscale="log")
    g.set(xlabel=None)
    g.set(ylabel="Mean Accuracy")
    g.set_titles("")

    g.axes.flat[0].set_xlabel(
        "Completion Tokens  [binned, log scale]",
        ha="left",
    )

    g.figure.suptitle(
        "O3 Model Accuracy by Sample Length and Completion Tokens",
        y=1.25,
    )

    plt.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0.1, hspace=0)

In [None]:
# Exploratory: o3 accuracy vs. completion tokens (rounded to the nearest 100),
# one panel per sample-length bin, all sample types pooled.
o3_tokens_acc_df = accuracy_df[accuracy_df.model == "o3"][
    ["completion_tokens", "correct", "sample.length"]
].copy()
o3_tokens_acc_df["completion_tokens_bin"] = o3_tokens_acc_df["completion_tokens"].round(
    -2
)

# Sample lengths are in range 1--50; bin them into 5 bins. Each label names the
# lengths its bin actually holds (x < 11 -> "1-10", x < 21 -> "11-20", ...), so
# neighbouring labels do not overlap at the boundaries.
length_bins = ["1-10", "11-20", "21-30", "31-40", "41-50"]
o3_tokens_acc_df["sample.length_bin"] = o3_tokens_acc_df["sample.length"].map(
    lambda x: "1-10"
    if x < 11
    else "11-20"
    if x < 21
    else "21-30"
    if x < 31
    else "31-40"
    if x < 41
    else "41-50"
)

o3_tokens_acc_df = (
    o3_tokens_acc_df.groupby(
        ["completion_tokens_bin", "sample.length_bin"], observed=False
    )[["correct"]]
    .mean()
    .reset_index()
)

fig_height = 1.2

with sns.plotting_context("notebook"):
    g = sns.relplot(
        data=o3_tokens_acc_df,
        y="correct",
        x="completion_tokens_bin",
        hue="sample.length_bin",
        # Fix both orders so panel positions and colours follow ascending
        # length rather than first appearance in the data.
        hue_order=length_bins,
        col="sample.length_bin",
        col_order=length_bins,
        height=fig_height,
        aspect=PAPER_WIDTH_IN / fig_height / 3,
        alpha=0.7,
        # NOTE(review): `size` is seaborn's size *semantic*, so a scalar is
        # mapped as a constant size variable rather than used as a marker area;
        # check whether `s=8` was intended.
        size=8,
        legend=None,
    )

    # Label each panel with the bin it actually shows.
    for ax, length_bin in zip(g.axes.flat, length_bins):
        ax.text(
            0.1,
            0.1,
            length_bin,
            ha="left",
            va="bottom",
            transform=ax.transAxes,
            fontweight="bold",
        )

        _ = ax.axhline(
            y=0.5,
            color=COLOR_AT_CHANCE,
            alpha=ALPHA_AT_CHANCE,
            linestyle="--",
        )

    g.set(xscale="log")
    g.set(xlabel=None)
    g.set(ylabel="Mean Accuracy")
    g.set_titles("")

    g.axes.flat[0].set_xlabel(
        "Completion Tokens  [binned, log scale]",
        ha="left",
    )

    g.figure.suptitle(
        "O3 Model Accuracy by Sample Length and Completion Tokens",
        y=1.25,
    )

    plt.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0.1, hspace=0)

In [None]:
# o3 accuracy vs. completion tokens (log2 bins of width 0.1), one column per
# sample-length bin. Top row: all samples pooled. Bottom row: split by
# ground-truth sample type. Points are per-bin mean accuracy; lines are LOWESS
# fits.
fig = plt.figure(figsize=(PAPER_WIDTH_IN, 1.6))
grid = gs.GridSpec(nrows=2, ncols=5, figure=fig, hspace=0.15, wspace=0.1)

o3_tokens_acc_df = accuracy_df[accuracy_df.model == "o3"][
    ["completion_tokens", "correct", "sample.length", "sample.type.ground_truth"]
].copy()
o3_tokens_acc_df["completion_tokens_bin"] = (
    o3_tokens_acc_df["completion_tokens"]
    .map(lambda x: np.log2(x))
    .round(1)
    .map(lambda x: 2**x)
)

# Sample lengths are in range 1--50; bin them into 5 bins
o3_tokens_acc_df["sample.length_bin"] = o3_tokens_acc_df["sample.length"].map(
    lambda x: "1–10 symbols"
    if x < 11
    else "11–20"
    if x < 21
    else "21–30"
    if x < 31
    else "31–40"
    if x < 41
    else "41–50"
)

o3_tokens_acc_typed_df = (
    o3_tokens_acc_df.groupby(
        ["completion_tokens_bin", "sample.length_bin", "sample.type.ground_truth"],
        observed=False,
    )[["correct"]]
    .mean()
    .reset_index()
)

o3_tokens_acc_untyped_df = (
    o3_tokens_acc_df.groupby(
        ["completion_tokens_bin", "sample.length_bin"],
        observed=False,
    )[["correct"]]
    .mean()
    .reset_index()
)

# Panel order, left to right. Listed explicitly because `.unique()` follows
# the raw row order of `accuracy_df`, and a plain string sort would put "11–20"
# before "1–10 symbols".
sl_bins = ["1–10 symbols", "11–20", "21–30", "31–40", "41–50"]
min_x = o3_tokens_acc_df["completion_tokens_bin"].min()
max_x = o3_tokens_acc_df["completion_tokens_bin"].max()

with sns.plotting_context("paper", font_scale=1):
    for row in range(2):
        for col in range(5):
            ax = fig.add_subplot(grid[row, col])

            if row == 0:
                # Top row: all sample types pooled.
                data = o3_tokens_acc_untyped_df[
                    o3_tokens_acc_untyped_df["sample.length_bin"] == sl_bins[col]
                ]
                sns.scatterplot(
                    data=data,
                    x="completion_tokens_bin",
                    y="correct",
                    ax=ax,
                    # NOTE(review): `size` is seaborn's size *semantic*, not a
                    # marker area; check whether `s=8` was intended.
                    size=8,
                    color="grey",
                    legend=None,
                )
                sns.regplot(
                    data=data,
                    x="completion_tokens_bin",
                    y="correct",
                    ax=ax,
                    scatter=False,
                    lowess=True,
                    color="grey",
                    line_kws={"color": "black", "alpha": 0.5},
                )
            else:
                # Bottom row: one colour and one LOWESS line per sample type.
                data = o3_tokens_acc_typed_df[
                    o3_tokens_acc_typed_df["sample.length_bin"] == sl_bins[col]
                ]
                sns.scatterplot(
                    data=data,
                    x="completion_tokens_bin",
                    y="correct",
                    hue="sample.type.ground_truth",
                    palette=PALETTE_SAMPLE_TYPE,
                    legend=None,
                    # NOTE(review): see the `size` note in the top row.
                    size=8,
                    ax=ax,
                    alpha=0.8,
                )
                for stype in ["positive", "negative"]:
                    sns.regplot(
                        data=data[data["sample.type.ground_truth"] == stype],
                        x="completion_tokens_bin",
                        y="correct",
                        ax=ax,
                        scatter=False,
                        lowess=True,
                        color=darken(PALETTE_SAMPLE_TYPE[stype], 0.1),
                    )

            _ = ax.axhline(
                y=0.5,
                color=COLOR_AT_CHANCE,
                alpha=ALPHA_AT_CHANCE,
                linestyle="--",
            )

            # Set the x-axis to log scale
            ax.set_xscale("log")
            ax.set_ylim(-0.02, 1.02)
            ax.set_xlim(min_x, max_x)

            # Axis labels are added once, after the loop.
            ax.set_xlabel(None)
            ax.set_ylabel(None)

            # turn off the ticks
            ax.tick_params(axis="x", which="both", bottom=False)
            ax.tick_params(axis="y", which="both", left=False)

            # turn off axis tick labels
            ax.tick_params(axis="x", which="both", labelbottom=False)
            ax.tick_params(axis="y", which="both", labelleft=False)

            # Turn off top and right spines
            ax.spines["top"].set_visible(False)
            ax.spines["right"].set_visible(False)

            if col > 0:
                ax.spines["left"].set_edgecolor("grey")
            else:
                ax.set_yticks([0, 1])
            if row == 0:
                ax.spines["bottom"].set_edgecolor("grey")

    # Axes 0-4 are the top row, 5-9 the bottom row (creation order).
    fig.get_axes()[5].set_xlabel(
        "o3 Completion Tokens  [binned, log scale]",
        ha="left",
        x=0,
    )
    fig.get_axes()[0].set_ylabel(
        "Mean Accuracy",
        va="bottom",
        y=-0.08,
    )

    for i, ax in enumerate(fig.get_axes()[0:5]):
        ax.set_title(sl_bins[i])

    for ax in fig.get_axes()[5:]:
        ax.tick_params(axis="x", which="both", bottom=True, labelbottom=True)

    fig.get_axes()[0].tick_params(
        axis="y",
        which="both",
        left=True,
        labelleft=True,
        labelsize=10,
    )
    fig.get_axes()[5].tick_params(
        axis="y",
        which="both",
        left=True,
        labelleft=True,
        labelsize=10,
    )

    # Add sample type labels
    fig.get_axes()[0].text(
        0.1,
        0.1,
        "all samples",
        ha="left",
        va="bottom",
        color=darken("grey"),
        fontweight="bold",
        transform=fig.get_axes()[0].transAxes,
        fontsize=7,
    )
    fig.get_axes()[5].text(
        0.85,
        0.1,
        "positive",
        ha="right",
        va="bottom",
        color=PALETTE_SAMPLE_TYPE["positive"],
        fontweight="bold",
        transform=fig.get_axes()[5].transAxes,
        fontsize=7,
    )
    fig.get_axes()[5].text(
        0.05,
        0.7,
        "negative",
        ha="left",
        va="bottom",
        color=PALETTE_SAMPLE_TYPE["negative"],
        fontweight="bold",
        transform=fig.get_axes()[5].transAxes,
        fontsize=7,
    )

    # Allow artists at the axis limits to draw past the panel edges.
    for o in fig.findobj():
        o.set_clip_on(False)

    plt.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0, hspace=0)

    plt.savefig(
        FIGURES_DIR / "o3_accuracy_by_sample_length_and_completion_tokens.pdf",
        bbox_inches="tight",
    )

In [None]:
# Peek at the first five rows of the o3 completion-token / accuracy frame.
o3_tokens_acc_df.head(n=5)

In [None]:
# o3 mean accuracy vs. (binned) completion tokens on a 5x5 grid: columns are
# sample-length bins, rows are bins of the grammar's nonlexical-production count.
fig = plt.figure(figsize=(PAPER_WIDTH_IN, 3))
grid = gs.GridSpec(nrows=5, ncols=5, figure=fig, hspace=0.15, wspace=0.1)

o3_tokens_acc_df = accuracy_df[accuracy_df.model == "o3"][
    [
        "completion_tokens",
        "correct",
        "sample.length",
        "sample.type.ground_truth",
        "n_nonlexical_productions",
    ]
].copy()
# Bin completion tokens by rounding log2 and mapping back, i.e. to the nearest
# power of two.
o3_tokens_acc_df["completion_tokens_bin"] = (
    o3_tokens_acc_df["completion_tokens"]
    .map(lambda x: np.log2(x))
    .round(0)
    .map(lambda x: 2**x)
)

# Sample lengths are in range 1--50; bin them into 5 bins
o3_tokens_acc_df["sample.length_bin"] = o3_tokens_acc_df["sample.length"].map(
    lambda x: "1–10 symbols"
    if x < 11
    else "11–20"
    if x < 21
    else "21–30"
    if x < 31
    else "31–40"
    if x < 41
    else "41–50"
)

# Round production counts to the nearest hundred. Values not among the listed
# categories become NaN under astype(CategoricalDtype) and drop out of the grid.
# NOTE(review): counts that round to 500 would be lost here — confirm the range.
o3_tokens_acc_df["n_nonlex_bin"] = (
    o3_tokens_acc_df["n_nonlexical_productions"]
    .round(-2)
    .astype(
        pd.CategoricalDtype(
            categories=[0, 100, 200, 300, 400],
            ordered=True,
        )
    )
)

o3_tokens_acc_untyped_df = (
    o3_tokens_acc_df.groupby(
        ["completion_tokens_bin", "sample.length_bin", "n_nonlex_bin"],
        observed=False,
    )[["correct"]]
    .mean()
    .reset_index()
)

# Column order is spelled out explicitly: `.unique()` returns bins in order of
# first appearance, which is not guaranteed to be by length (and a lexical sort
# would put "11–20" before "1–10 symbols").
sl_bins = ["1–10 symbols", "11–20", "21–30", "31–40", "41–50"]
nlex_bins = o3_tokens_acc_df["n_nonlex_bin"].cat.categories
min_x = o3_tokens_acc_df["completion_tokens_bin"].min()
max_x = o3_tokens_acc_df["completion_tokens_bin"].max()

with sns.plotting_context("paper", font_scale=1):
    for row in range(5):
        for col in range(5):
            ax = fig.add_subplot(grid[row, col])

            data = o3_tokens_acc_untyped_df[
                (o3_tokens_acc_untyped_df["sample.length_bin"] == sl_bins[col])
                & (o3_tokens_acc_untyped_df["n_nonlex_bin"] == nlex_bins[row])
            ]
            sns.scatterplot(
                data=data,
                x="completion_tokens_bin",
                y="correct",
                ax=ax,
                # `s` is the marker size; `size=` would be a seaborn semantic
                # (grouping) variable rather than a size in points.
                s=8,
                color="grey",
                legend=None,
            )
            sns.regplot(
                data=data,
                x="completion_tokens_bin",
                y="correct",
                ax=ax,
                scatter=False,
                lowess=True,
                color="grey",
                line_kws={"color": "black", "alpha": 0.5},
            )

            _ = ax.axhline(
                y=0.5,
                color=COLOR_AT_CHANCE,
                alpha=ALPHA_AT_CHANCE,
                linestyle="--",
            )

            # Set the x-axis to log scale
            ax.set_xscale("log")
            ax.set_ylim(-0.02, 1.02)
            ax.set_xlim(min_x, max_x)

            ax.set_xlabel(None)
            ax.set_ylabel(None)

            # turn off the ticks
            ax.tick_params(axis="x", which="both", bottom=False)
            ax.tick_params(axis="y", which="both", left=False)

            # turn off axis tick labels
            ax.tick_params(axis="x", which="both", labelbottom=False)
            ax.tick_params(axis="y", which="both", labelleft=False)

            # Row label (production-count bin) on the right-hand side.
            if col == 4:
                ax.set_ylabel(f"{nlex_bins[row]}")
                ax.yaxis.set_label_position("right")

            # Turn off top and right spines
            ax.spines["top"].set_visible(False)
            ax.spines["right"].set_visible(False)

            if col > 0:
                ax.spines["left"].set_edgecolor("grey")
            else:
                ax.set_yticks([0, 1])
            if row == 0:
                ax.spines["bottom"].set_edgecolor("grey")

    # Axes were added row-major (index = row * 5 + col), so the bottom row is
    # [-5:], the left column is [::5], and the top row is [:5].
    o3_axes = fig.get_axes()

    # Shared x-axis label on the bottom-left panel.
    o3_axes[-5].set_xlabel(
        "o3 Completion Tokens  [binned, log scale]",
        ha="left",
        x=0,
    )
    # Shared y-axis label on the middle panel of the left column.
    o3_axes[10].set_ylabel("Mean Accuracy")

    for ax, title in zip(o3_axes[:5], sl_bins):
        ax.set_title(title)

    # x tick marks/labels only on the bottom row.
    for ax in o3_axes[-5:]:
        ax.tick_params(axis="x", which="both", bottom=True, labelbottom=True)

    # y tick marks/labels only on the left column.
    for ax in o3_axes[::5]:
        ax.tick_params(
            axis="y",
            which="both",
            left=True,
            labelleft=True,
            labelsize=10,
        )

    for o in fig.findobj():
        o.set_clip_on(False)

    plt.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0, hspace=0)

    # plt.savefig(
    #     FIGURES_DIR / "o3_accuracy_by_sample_length_and_completion_tokens.pdf",
    #     bbox_inches="tight",
    # )

In [None]:
# Distinct production-count bins present in the o3 frame.
o3_tokens_acc_df.n_nonlex_bin.unique()

In [None]:
# Positive samples of one specific grammar that DSR1-7B got right,
# ordered from shortest to longest.
(
    accuracy_df.loc[
        accuracy_df["sample.type.ground_truth"] == "positive",
        [
            "grammar_file",
            "correct",
            "sample",
            "sample.length",
            "sample.type.ground_truth",
            "model",
        ],
    ]
    .query(
        "model == 'DSR1-7B'"
        " and correct == True"
        " and grammar_file == 'grammar_20250402155408_676876'"
    )
    .sort_values(by="sample.length")
)

## Rank Ordering

In [None]:
# Spearman rank correlations between models' per-grammar mean accuracies, one
# panel per bin of nonlexical-production counts. Bins include their upper edge
# so the plotted data match the panel titles (1–100, 101–200, ..., 401–500).
n_lexprod_bins = [1, 100, 200, 300, 400, 500]

fig = plt.figure(figsize=(PAPER_WIDTH_IN + 1, 1.25))
grid = gs.GridSpec(nrows=1, ncols=len(n_lexprod_bins) - 1, figure=fig, wspace=0.1)

for i, (lo, hi) in enumerate(zip(n_lexprod_bins, n_lexprod_bins[1:])):
    ax = fig.add_subplot(grid[0, i])

    # The first bin keeps its lower edge; later bins start strictly above the
    # previous bin's (inclusive) upper edge.
    lower_op = ">=" if i == 0 else ">"

    ordering_corr = (
        accuracy_df[
            [
                "model",
                "grammar_file",
                "correct",
                "sample.length",
                "sample.type.ground_truth",
                "n_nonlexical_productions",
            ]
        ]
        .query("model != 'gemma-3-12b'")
        .query("model != 'gemma-3-27b'")
        .query(
            f"n_nonlexical_productions {lower_op} @lo"
            " and n_nonlexical_productions <= @hi"
        )
        .groupby(
            ["model", "grammar_file", "sample.type.ground_truth", "sample.length"],
            observed=False,
        )[["correct"]]
        .mean()
        .reset_index()
        .groupby(["grammar_file", "model"], observed=False)[["correct"]]
        .mean()
        .unstack("model")
        .dropna(axis=1, how="all")
        .corr(method="spearman")
        .droplevel(0, axis=1)
        .droplevel(0, axis=0)
    )

    # Show the lower triangle (including the diagonal) only.
    mask = np.tril(np.ones_like(ordering_corr, dtype=bool))

    sns.heatmap(
        data=ordering_corr,
        mask=~mask,
        cmap=CMAP_HEATMAP,
        annot=False,
        linewidths=0.5,
        linecolor="white",
        xticklabels=False,
        yticklabels=True if i == 0 else False,
        vmin=-1,
        vmax=1,
        ax=ax,
        cbar=False,
        square=True,
    )

    ax.set_xlabel(None)
    ax.set_ylabel(None)
    ax.tick_params(axis="y", labelsize=7)
    ax.tick_params(axis="x", labelsize=7)
    plt.xticks(rotation=45, ha="right", va="top")
    if i == 0:
        ax.set_title(f"{lo}–{hi} productions", fontsize=8)
    else:
        ax.set_title(f"{lo+1}–{hi}", fontsize=8)

# Lay out and save once, after every panel has been drawn.
plt.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0, hspace=0)

plt.savefig(
    FIGURES_DIR / "grammar_rank_corrs.pdf",
    bbox_inches="tight",
)

In [None]:
# Spearman rank correlations between models' per-sample mean accuracies, one
# panel per sample-length bin. Bins include their upper edge so the plotted data
# match the panel titles (1–10, 11–20, ..., 41–50); length-50 samples are kept.
sl_bins = [1, 10, 20, 30, 40, 50]

fig = plt.figure(figsize=(PAPER_WIDTH_IN + 1, 1.25))
# One grid column per bin (len(edges) - 1), matching the number of panels.
grid = gs.GridSpec(nrows=1, ncols=len(sl_bins) - 1, figure=fig, wspace=0.1)

for i, (lo, hi) in enumerate(zip(sl_bins, sl_bins[1:])):
    ax = fig.add_subplot(grid[0, i])

    # The first bin keeps its lower edge; later bins start strictly above the
    # previous bin's (inclusive) upper edge.
    lower_op = ">=" if i == 0 else ">"

    sample_ordering_corr = (
        accuracy_df.rename({"sample.length": "sl"}, axis=1)[
            [
                "correct",
                "sample",
                "model",
                "sl",
            ]
        ]
        .query(f"sl {lower_op} @lo and sl <= @hi")
        .pivot_table(
            index="sample",
            columns="model",
            values="correct",
            aggfunc="mean",
            fill_value=np.nan,
            observed=False,
        )
        .dropna(axis=1, how="all")
        .corr(method="spearman")
    )

    # Hide everything above the diagonal.
    mask = ~np.tril(np.ones_like(sample_ordering_corr, dtype=bool))

    sns.heatmap(
        sample_ordering_corr,
        mask=mask,
        cmap=CMAP_HEATMAP,
        annot=False,
        linewidths=0.5,
        linecolor="white",
        xticklabels=False,
        yticklabels=True if i == 0 else False,
        vmin=-1,
        vmax=1,
        ax=ax,
        cbar=False,
        square=True,
    )

    ax.set_xlabel(None)
    ax.set_ylabel(None)
    ax.tick_params(axis="y", labelsize=7)
    if i == 0:
        ax.set_title(f"{lo}–{hi} symbols", fontsize=8)
    else:
        ax.set_title(f"{lo+1}–{hi}", fontsize=8)

# Lay out and save once, after every panel has been drawn.
plt.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0, hspace=0)

plt.savefig(
    FIGURES_DIR / "sample_rank_corrs.pdf",
    bbox_inches="tight",
)

In [None]:
# Compact figure: per-grammar rank correlations for two production-count bins
# (101–200, 201–300) next to per-sample rank correlations for two length bins
# (1–10, 11–20), with a shared colorbar. Query bounds include the upper edge so
# the data match the panel titles.
n_lexprod_bins = [100, 200, 300]
sl_bins = [1, 10, 20]

fig = plt.figure(figsize=(PAPER_WIDTH_IN + 2, 1.6))
grid = gs.GridSpec(
    nrows=1, ncols=6, figure=fig, wspace=0.1, width_ratios=[1, 1, 0.2, 1, 1, 0.1]
)

with sns.plotting_context("paper", font_scale=1, rc=rcs):
    cax = fig.add_subplot(grid[0, 5])

    for i, (lo, hi) in enumerate(zip(n_lexprod_bins, n_lexprod_bins[1:])):
        ax = fig.add_subplot(grid[0, i])

        # Titles read "{lo+1}–{hi}" for every bin, so each bin is (lo, hi].
        ordering_corr = (
            accuracy_df[
                [
                    "model",
                    "grammar_file",
                    "correct",
                    "sample.length",
                    "sample.type.ground_truth",
                    "n_nonlexical_productions",
                ]
            ]
            .query("model != 'gemma-3-12b'")
            .query("model != 'gemma-3-27b'")
            .query("n_nonlexical_productions > @lo and n_nonlexical_productions <= @hi")
            .groupby(
                ["model", "grammar_file", "sample.type.ground_truth", "sample.length"],
                observed=False,
            )[["correct"]]
            .mean()
            .reset_index()
            .groupby(["grammar_file", "model"], observed=False)[["correct"]]
            .mean()
            .unstack("model")
            .dropna(axis=1, how="all")
            .corr(method="spearman")
            .droplevel(0, axis=1)
            .droplevel(0, axis=0)
        )

        # Hide the upper triangle including the diagonal.
        mask = np.triu(np.ones_like(ordering_corr, dtype=bool))

        sns.heatmap(
            data=ordering_corr,
            mask=mask,
            cmap=CMAP_HEATMAP,
            annot=False,
            # fmt=".2f",
            linewidths=0.5,
            linecolor="white",
            xticklabels=True,
            yticklabels=True if i == 0 else False,
            vmin=-1,
            vmax=1,
            ax=ax,
            cbar=False,
        )

        ax.set_xlabel(None)
        ax.set_ylabel(None)
        ax.tick_params(axis="y", labelsize=7)
        ax.tick_params(axis="x", labelsize=7)
        plt.xticks(rotation=45, ha="right", va="top")
        if i == 0:
            ax.set_title(f"{lo+1}–{hi} productions", fontsize=8, ha="left", x=0)
        else:
            ax.set_title(f"{lo+1}–{hi}", fontsize=8, ha="left", x=0)

    for i, (lo, hi) in enumerate(zip(sl_bins, sl_bins[1:])):
        ax = fig.add_subplot(grid[0, i + 3])

        # The first bin keeps its lower edge ("1–10"); later bins are
        # (lo, hi] ("11–20"), matching the titles below.
        lower_op = ">=" if i == 0 else ">"

        sample_ordering_corr = (
            accuracy_df.rename({"sample.length": "sl"}, axis=1)[
                [
                    "correct",
                    "sample",
                    "model",
                    "sl",
                ]
            ]
            .query(f"sl {lower_op} @lo and sl <= @hi")
            .pivot_table(
                index="sample",
                columns="model",
                values="correct",
                aggfunc="mean",
                fill_value=np.nan,
                observed=False,
            )
            .dropna(axis=1, how="all")
            .corr(method="spearman")
        )

        mask = np.triu(np.ones_like(sample_ordering_corr, dtype=bool))

        sns.heatmap(
            sample_ordering_corr,
            mask=mask,
            cmap=CMAP_HEATMAP,
            annot=False,
            linewidths=0.5,
            linecolor="white",
            xticklabels=True,
            yticklabels=False,
            vmin=-1,
            vmax=1,
            ax=ax,
            # Only the first per-sample panel draws the shared colorbar.
            cbar=True if i == 0 else False,
            cbar_ax=cax if i == 0 else None,
            cbar_kws={"ticks": [-1, -0.5, 0, 0.5, 1]} if i == 0 else None,
        )

        ax.set_xlabel(None)
        ax.set_ylabel(None)
        ax.tick_params(axis="x", labelsize=7)
        plt.xticks(rotation=45, ha="right", va="top")
        if i == 0:
            ax.set_title(f"{lo}–{hi} symbols", fontsize=8, ha="left", x=0)
        else:
            ax.set_title(f"{lo+1}–{hi}", fontsize=8, ha="left", x=0)

    # # draw a horizontal line over the first two subplots
    # pos1 = fig.axes[0].get_position()
    # pos2 = fig.axes[1].get_position()
    # pos3 = fig.axes[2].get_position()
    # pos4 = fig.axes[3].get_position()
    # y = pos1.y1 + 0.3  # tweak 0.01 up/down as needed

    # # draw a horizontal line from the left edge of ax1 to the right edge of ax2
    # line1 = mpl.lines.Line2D(
    #     [pos1.x0 - 0.13, pos2.x1 - 0.04],
    #     [y, y],
    #     transform=fig.transFigure,
    #     color='grey',
    #     linewidth=1
    # )
    # fig.add_artist(line1)
    # line2 = mpl.lines.Line2D(
    #     [pos3.x0 - 0.02, pos4.x1 + 0.08],
    #     [y, y],
    #     transform=fig.transFigure,
    #     color='grey',
    #     linewidth=1
    # )
    # fig.add_artist(line2)

    # # add text above each line
    # fig.text(
    #     pos1.x0 - 0.13,
    #     y + 0.02,
    #     "Per-grammar Correlations",
    #     ha="left",
    #     va="bottom",
    #     fontsize=8,
    #     fontweight="bold",
    # )

    # fig.text(
    #     pos3.x0 - 0.02,
    #     y + 0.02,
    #     "Per-sample Correlations",
    #     ha="left",
    #     va="bottom",
    #     fontsize=8,
    #     fontweight="bold",
    # )

    plt.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0, hspace=0)

    plt.savefig(
        FIGURES_DIR / "grammar-sample_rank_corrs_small.pdf",
        bbox_inches="tight",
    )

In [None]:
# Model-by-model Spearman correlation of per-sample mean accuracy for samples
# of length 11–20; all-NaN model columns are deliberately kept here.
(
    accuracy_df.rename(columns={"sample.length": "sl"})
    .loc[:, ["correct", "sample", "model", "sl"]]
    .query("11 <= sl < 21")
    .pivot_table(
        index="sample",
        columns="model",
        values="correct",
        aggfunc="mean",
        fill_value=np.nan,
        observed=False,
    )
    .corr(method="spearman")
)

## Strategy Classification

In [None]:
# Load the strategy-classification results from the project's data directory
# (later used via its `model`, `sample.length` and `strategy` columns).
classification_df = pd.read_feather(
    PROJECT_ROOT.joinpath("data", "gpt_classification_df.feather")
)

In [None]:
# Two-row figure per model: top row is relative test-time compute (mean
# completion tokens, normalised by each model's maximum) vs. sample length; the
# bottom row shows the proportion of classified strategies for the gpt-4.1
# models and crossed-out placeholder panels for o4-mini/o3.
TTC_MODELS = ["gpt-4.1-nano", "gpt-4.1-mini", "gpt-4.1", "o4-mini", "o3"]


fig_height = 1.6
fig = plt.figure(figsize=(PAPER_WIDTH_IN, fig_height))
grid = fig.add_gridspec(
    2,
    len(TTC_MODELS),
    wspace=0.05,
    hspace=0.2,
    # width_ratios=[1, 1, 1, 0.6, 0.6]
)

binned_sl_df = accuracy_df[
    [
        "completion_tokens",
        "correct",
        "model",
        "sample.length",
        "sample.type.ground_truth",
    ]
].copy()
# Average over sample types first so positive and negative samples weigh
# equally, then average per (length, model).
binned_sl_df = (
    binned_sl_df.groupby(
        ["sample.length", "model", "sample.type.ground_truth"], observed=False
    )[["completion_tokens"]]
    .mean()
    .reset_index()
    .groupby(["sample.length", "model"], observed=False)[["completion_tokens"]]
    .mean()
    .reset_index()
)

binned_sl_df["relative_ttc"] = binned_sl_df.groupby("model", observed=True)[
    "completion_tokens"
].transform(lambda x: x / x.max())

min_relttc = binned_sl_df["relative_ttc"].min()

with sns.plotting_context("paper", rc=rcs):
    for r in [0, 1]:
        for c, model in enumerate(TTC_MODELS):
            ax = fig.add_subplot(grid[r, c])

            # Sample length at which this model's relative TTC peaks.
            peak_ttc_xval = (
                binned_sl_df[binned_sl_df["model"] == model]
                .sort_values(by="relative_ttc", ascending=False)
                .iloc[0]["sample.length"]
            )

            if r == 0 or c < 3:
                ax.axvline(
                    x=peak_ttc_xval,
                    color=MODEL_COLOR,
                    linestyle="--",
                    alpha=0.8,
                    zorder=5,
                )

            ax.yaxis.set_major_formatter(mpl.ticker.PercentFormatter(1))

            if r == 0:
                sns.lineplot(
                    data=binned_sl_df[binned_sl_df.model == model],
                    x="sample.length",
                    y="relative_ttc",
                    color=MODEL_COLOR,
                    errorbar="se",
                    ax=ax,
                )
                ax.set_title(model, fontsize=7)
                ax.set_ylim(0, 1)
                ax.set_xlabel(None)
                ax.set_xlim(1, 50)
                ax.set_xticks([1, peak_ttc_xval, 50])
                ax.set_xticklabels([])
                ax.spines["top"].set_visible(False)
                ax.spines["right"].set_visible(False)

                if c == 0:
                    ax.set_yticks([0, 1])
                    ax.set_ylabel("Relative TTC")
                else:
                    ax.set_yticks([])
                    ax.set_ylabel(None)
            else:
                if c > 2:
                    # Placeholder panel (no strategy data): grey box with a
                    # cross through it.
                    ax.set_xlabel(None)
                    ax.set_ylabel(None)
                    ax.set_yticks([])
                    ax.spines["top"].set_visible(False)
                    ax.spines["right"].set_visible(False)
                    ax.set_facecolor("#eeeeee")

                    ax.set_xticks([1, peak_ttc_xval, 50])
                    ax.set_xlim(1, 50)
                    ax.get_xticklabels()[0].set_ha("left")
                    ax.get_xticklabels()[1].set_ha("left")
                    ax.get_xticklabels()[-1].set_ha("right")

                    # draw lines connecting the opposite corners
                    ax.plot(
                        [0, 1],
                        [0, 1],
                        transform=ax.transAxes,
                        color="black",
                        alpha=0.2,
                        linewidth=1,
                    )
                    ax.plot(
                        [0, 1],
                        [1, 0],
                        transform=ax.transAxes,
                        color="black",
                        alpha=0.2,
                        linewidth=1,
                    )
                else:
                    sns.histplot(
                        classification_df[classification_df["model"] == model],
                        x="sample.length",
                        hue="strategy",
                        palette=PALETTE_STRAGETY,
                        stat="proportion",
                        multiple="fill",
                        bins=50,
                        ax=ax,
                        alpha=0.8,
                        linewidth=0,
                        legend=False,
                    )
                    ax.set_title(None)
                    ax.spines["top"].set_visible(False)
                    ax.spines["right"].set_visible(False)
                    ax.set_xticks([1, peak_ttc_xval, 50])
                    ax.set_xlim(1, 50)
                    ax.get_xticklabels()[0].set_ha("left")
                    ax.get_xticklabels()[1].set_ha("left")
                    ax.get_xticklabels()[-1].set_ha("right")

                    if c == 0:
                        ax.set_xlabel(
                            "Task Complexity (Example Length)", ha="left", x=0.0
                        )
                        ax.set_ylabel("Proportion\nof Strategies")
                        ax.set_yticks([0, 1])
                        # In-panel strategy labels instead of a legend.
                        ax.text(
                            0.2,
                            0.35,
                            "rule-based",
                            color=PALETTE_STRAGETY["rule-based"],
                            ha="left",
                            va="center",
                            transform=ax.transAxes,
                            fontsize=8,
                            fontweight="bold",
                        )

                        ax.text(
                            0.95,
                            0.9,
                            "heuristic",
                            color="#ce8669",
                            ha="right",
                            va="top",
                            transform=ax.transAxes,
                            fontsize=8,
                            fontweight="bold",
                        )
                    else:
                        ax.set_xlabel(None)
                        ax.set_ylabel(None)
                        ax.set_yticks([])

    for o in fig.findobj():
        o.set_clip_on(False)

    plt.subplots_adjust(left=0, bottom=0, right=1, top=1)

    plt.savefig(
        FIGURES_DIR / "gpt_strategy.pdf",
        bbox_inches="tight",
    )

In [None]:
# One panel per model: mean completion tokens (TTC) vs. sample length.
binned_sl_df = accuracy_df[
    [
        "completion_tokens",
        "correct",
        "model",
        "sample.length",
        "sample.type.ground_truth",
    ]
].copy()
# Average over sample types first so positive and negative samples weigh
# equally, then average per (length, model).
binned_sl_df = (
    binned_sl_df.groupby(
        ["sample.length", "model", "sample.type.ground_truth"], observed=False
    )[["completion_tokens"]]
    .mean()
    .reset_index()
    .groupby(["sample.length", "model"], observed=False)[["completion_tokens"]]
    .mean()
    .reset_index()
)

binned_sl_df["relative_ttc"] = binned_sl_df.groupby("model", observed=True)[
    "completion_tokens"
].transform(lambda x: x / x.max())

min_relttc = binned_sl_df["relative_ttc"].min()

TTC_MODELS = ["gpt-4.1-nano", "gpt-4.1-mini", "gpt-4.1", "o4-mini", "o3"]

fig_height = 0.8
fig = plt.figure(figsize=(PAPER_WIDTH_IN, fig_height))
grid = fig.add_gridspec(1, len(TTC_MODELS), wspace=0.05)

with sns.plotting_context("paper", rc=rcs):
    for c, model in enumerate(TTC_MODELS):
        ax = fig.add_subplot(grid[0, c])

        sns.lineplot(
            data=binned_sl_df[binned_sl_df.model == model],
            x="sample.length",
            y="completion_tokens",
            color=MODEL_COLOR,
            errorbar="se",
            ax=ax,
        )
        ax.set_title(model, fontsize=7)
        # NOTE(review): fixed upper limit; TTC above 10k tokens would be drawn
        # outside the axes because clipping is disabled below.
        ax.set_ylim(0, 10_000)
        ax.set_xlabel(None)
        ax.set_xlim(1, 50)
        ax.set_xticks([1, 50])
        ax.set_xticklabels([1, 50])
        # Align the end labels inward: "1" to the left, "50" to the right.
        ax.get_xticklabels()[0].set_ha("left")
        ax.get_xticklabels()[-1].set_ha("right")

        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)

        if c == 0:
            ax.set_ylabel("TTC")
            ax.set_xlabel("Task Complexity (Example Length)", ha="left", x=0.0)
        else:
            ax.set_yticks([])
            ax.set_ylabel(None)

    for o in fig.findobj():
        o.set_clip_on(False)

    plt.subplots_adjust(left=0, bottom=0, right=1, top=1)

    plt.savefig(
        FIGURES_DIR / "ttc.pdf",
        bbox_inches="tight",
    )

In [None]:
# Responses from gpt-4.1-mini that were judged correct.
response_df.query("model == 'gpt-4.1-mini' and correct == True")

In [None]:
# Summary statistics of the grammar-size columns.
grammar_size_cols = [
    "n_nonlexical_productions",
    "n_lexical_productions",
    "n_terminals",
    "n_nonterminals",
]
accuracy_df.loc[:, grammar_size_cols].describe()