# Response Classification

In [None]:
import colorsys
import json
import os
import pathlib
import re
from collections import defaultdict
from typing import Any

import matplotlib as mpl
import matplotlib.gridspec as gs
import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyrootutils
import regex as re
import seaborn as sns
import sklearn.metrics as sk_metrics
import statsmodels.api as sm
import statsmodels.formula.api as smf
from dotenv import load_dotenv
from IPython.display import display
from openai import OpenAI

from formal_gym import prompt as fg_prompt

In [None]:
# Locate the repository root (the directory containing `.project-root`),
# searching upward from the notebook's working directory.
_search_start = os.path.abspath("")
PROJECT_ROOT = pyrootutils.find_root(
    indicator=".project-root",
    search_from=_search_start,
)

# Load environment variables from the project's `.env` before building the
# API client, so the client can pick up credentials from the environment.
_dotenv_path = os.path.join(PROJECT_ROOT, ".env")
load_dotenv(_dotenv_path)

client = OpenAI()

## Plotting Setup

In [None]:
FIGURES_DIR = PROJECT_ROOT / "notebooks" / "figures"
# Figures are saved into this directory by later cells; create it up front so
# `savefig` does not fail on a fresh checkout.
FIGURES_DIR.mkdir(parents=True, exist_ok=True)

# Figure width (inches) used for every paper figure in this notebook.
PAPER_WIDTH_IN = 5.5

# rcParams overrides passed to `sns.plotting_context("paper", rc=rcs)`.
rcs = {
    "font.size": 10.0,
    "axes.labelsize": "small",
    "axes.titlesize": "small",
    "xtick.labelsize": "x-small",
    "ytick.labelsize": "x-small",
}


def darken(
    color: str
    | tuple[float, float, float]
    | dict[str, Any]
    | sns.palettes._ColorPalette,
    by: float = 0.2,
):
    """
    Darken a color, or every color in a palette, by reducing its lightness.

    Args:
        color: A single color (a string seaborn understands, or an RGB tuple
            with components either in [0, 1] or in [0, 255]), a ``dict``
            mapping keys to such colors, or a seaborn ``_ColorPalette``.
        by: Fraction by which the HLS lightness is reduced, clamped to
            [0, 1]. ``0`` leaves the color unchanged; ``1`` yields black.

    Returns:
        The same kind of container as ``color`` (``dict``, ``_ColorPalette``
        or a single tuple) holding darkened RGB tuples in [0, 1].
    """

    def _darken_color(c: str | tuple[float, float, float], by: float):
        by = min(max(0, by), 1)
        pct_darken = 1 - by

        if isinstance(c, str):
            c = sns.color_palette([c])[0]

        # Support 0-255 RGB tuples: if any component exceeds 1, treat the
        # whole tuple as 0-255 and rescale it to [0, 1].
        rgb = tuple(c[:3])
        if any(c_i > 1 for c_i in rgb):
            rgb = tuple(c_i / 255 for c_i in rgb)

        # Darken by scaling the lightness channel in HLS space, then convert
        # back to RGB.
        hue, lightness, saturation = colorsys.rgb_to_hls(*rgb)
        return colorsys.hls_to_rgb(hue, lightness * pct_darken, saturation)

    if isinstance(color, dict):
        # A dict is treated as a palette: darken each of its colors.
        return {k: _darken_color(v, by) for k, v in color.items()}
    elif isinstance(color, sns.palettes._ColorPalette):
        colors = [_darken_color(c, by) for c in color]
        return sns.palettes._ColorPalette(colors)
    else:
        return _darken_color(color, by)


# For heatmaps, correlations, -1 to 1 scales, etc
CMAP_HEATMAP = "vlag"

# For any plots where color differentiates sample type
PALETTE_SAMPLE_TYPE = {
    "positive": darken("#ffcc66"),
    "negative": "#5c5cff",
    "unknown": "#C41E3A",
}

# Source palettes for PALETTE_MODEL, each computed once and indexed per model.
_gpt41_colors = sns.cubehelix_palette(start=0.5, rot=-0.5, n_colors=4)
_o_series_colors = sns.color_palette("YlOrBr", n_colors=2)
_gemma_colors = sns.cubehelix_palette(n_colors=5)
_dsr1_colors = sns.color_palette("cool", n_colors=2)

# For any plots where color differentiates model (all entries darkened with
# `darken`'s default amount).
PALETTE_MODEL = darken(
    {
        "gpt-4.1-nano": _gpt41_colors[0],
        "gpt-4.1-mini": _gpt41_colors[1],
        "gpt-4.1": _gpt41_colors[2],
        "o4-mini": _o_series_colors[0],
        "o3": _o_series_colors[1],
        "gemma-3-1b": _gemma_colors[0],
        "gemma-3-4b": _gemma_colors[1],
        "gemma-3-12b": _gemma_colors[2],
        "gemma-3-27b": _gemma_colors[3],
        "DSR1-7B": _dsr1_colors[0],
    }
)

_terrain_colors = sns.color_palette("terrain", n_colors=2, desat=0.8)
PALETTE_SCORE = {
    "Weighted F1": _terrain_colors[0],
    "Macro F1": _terrain_colors[1],
}

_gist_earth_colors = sns.color_palette("gist_earth", n_colors=4)
PALETTE_STRAGETY = {
    "rule-based": "#942822",
    "heuristic": "#FFE7CE",
    "code": _gist_earth_colors[2],
    "unknown": _gist_earth_colors[3],
}
# Correctly spelled alias; the misspelled name is kept because other cells use it.
PALETTE_STRATEGY = PALETTE_STRAGETY

# Lookup table used by `filter_by_alpha` to find a palette by its keys.
PALETTES = {
    "model": PALETTE_MODEL,
    "sample_type": PALETTE_SAMPLE_TYPE,
    "score": PALETTE_SCORE,
    "strategy": PALETTE_STRAGETY,
}

# For marking at-chance baselines
COLOR_AT_CHANCE = "#ff0000"  # Red
ALPHA_AT_CHANCE = 0.5

# Bar Chart settings
BAR_EDGE_COLOR = "black"
BAR_EDGE_WIDTH = 0.8


def display_palette(palette: dict[str, Any] | sns.palettes._ColorPalette):
    """
    Render a palette swatch in the notebook output.

    A seaborn ``_ColorPalette`` is displayed directly. For a ``dict``, its
    values (the colors), in insertion order, are wrapped in a seaborn palette
    first.
    """
    swatch = (
        palette
        if isinstance(palette, sns.palettes._ColorPalette)
        else sns.color_palette(list(palette.values()))
    )
    display(swatch)


def filter_by_alpha(
    keys: list[str],
    ax,
    palette: dict[str, Any] | None = None,
    alpha=0.3,
    highlight: str | list[str] | None = None,
):
    """
    Fade the points of ``ax.collections[0]`` except for highlighted keys.

    Every face color in the first collection of ``ax`` is given alpha
    ``alpha``. Points whose RGB matches the palette color of a key in
    ``highlight`` are then set back to full opacity.

    Args:
        keys: Palette keys whose points should be matched.
        ax: Matplotlib axes; its first collection is modified in place.
        palette: Mapping from key to color (hex string or RGB(A) tuple). If
            ``None``, the first palette in ``PALETTES`` containing all
            ``keys`` is used.
        alpha: Alpha applied to non-highlighted points.
        highlight: Key or keys to keep fully opaque.

    Raises:
        ValueError: If ``palette`` is ``None`` and no palette in ``PALETTES``
            contains every key.
    """
    if isinstance(highlight, str):
        highlight = [highlight]
    highlighted = set(highlight or [])

    if palette is None:
        # Pick the first known palette whose keys cover all requested keys.
        palette = next(
            (p for p in PALETTES.values() if all(k in p for k in keys)), None
        )
        if palette is None:
            raise ValueError(f"No matching palette found for keys: {keys}")

    collection = ax.collections[0]
    face_colors = collection.get_facecolors()
    face_colors[:, 3] = alpha
    for key in keys:
        # Palette entries may be hex strings or RGB(A) tuples; normalize to
        # RGB floats so they are comparable with the collection's colors.
        key_rgb = np.asarray(mpl.colors.to_rgb(palette[key]))
        # Tolerant float comparison instead of exact tuple equality.
        matches = np.all(np.isclose(face_colors[:, :3], key_rgb), axis=1)
        face_colors[matches, 3] = 1.0 if key in highlighted else alpha
    collection.set_facecolor(face_colors)


def legend_format(
    ax: mpl.axes._axes.Axes | sns.FacetGrid,
    keys: list[str] | None = None,
    title: str | None = None,
    **kwargs,
):
    """
    Rebuild an axes legend without a frame and move it with ``sns.move_legend``.

    Args:
        ax: An Axes, or a ``FacetGrid`` whose single ``.ax`` is used (the
            grid's own figure-level legend is removed first).
        keys: If given, blank spacer entries are inserted into the legend.
            Only key lists containing ``"gpt-4.1"`` have a known layout.
        title: Optional legend title.
        **kwargs: Forwarded to ``sns.move_legend``. The defaults are
            ``loc="upper left"`` and ``bbox_to_anchor=(1, 1)``, which place
            the legend outside the right edge of the axes.

    Raises:
        ValueError: If ``keys`` is given but no spacing layout is known for it.
    """
    if isinstance(ax, sns.FacetGrid):
        fg = ax
        ax = fg.ax
        # Only remove the grid-level legend if one was actually drawn.
        if fg.legend is not None:
            fg.legend.remove()

    handles, labels = ax.get_legend_handles_labels()

    if keys is not None:
        if "gpt-4.1" in keys:
            spacing_locs = [3, 6]
        else:
            raise ValueError(f"Unsure how to space legend for {keys=}")

        # Insert invisible spacer entries (applied sequentially, so the second
        # position accounts for the first insertion).
        spacer = mpatches.Patch(alpha=0, linewidth=0)
        for sloc in spacing_locs:
            handles.insert(sloc, spacer)
            labels.insert(sloc, "")

    legend = ax.legend(handles, labels)
    legend.set_frame_on(False)

    if title is not None:
        legend.set_title(title)

    kwargs.setdefault("loc", "upper left")
    kwargs.setdefault("bbox_to_anchor", (1, 1))
    sns.move_legend(ax, **kwargs)


# MARK: Printing, experimentation, etc
# Preview the score palette before and after darkening, then the palettes
# used for models, sample types and strategies.
_terrain_preview = sns.color_palette("terrain", n_colors=2, desat=0.8)
display(_terrain_preview)
display(darken(_terrain_preview, by=0.2))

for _preview_palette in (PALETTE_MODEL, PALETTE_SAMPLE_TYPE, PALETTE_STRAGETY):
    display_palette(_preview_palette)

In [None]:
GRAMMARS_DIR = PROJECT_ROOT / "data" / "grammars"

response_df_file = PROJECT_ROOT / "data" / "response_df.feather"
f1_df_file = PROJECT_ROOT / "data" / "f1_df.feather"
acc_df_file = PROJECT_ROOT / "data" / "acc_df.feather"


def _read_feather_if_present(path, label: str):
    """Return the DataFrame stored at ``path`` (announcing it), or ``None`` if absent."""
    if not os.path.exists(path):
        return None
    print(f"Loading `{label}` from {path}")
    return pd.read_feather(path)


# Each frame stays None when its cached feather file does not exist.
response_df = _read_feather_if_present(response_df_file, "response_df")
f1_df = _read_feather_if_present(f1_df_file, "f1_df")
accuracy_df = _read_feather_if_present(acc_df_file, "accuracy_df")

## DSR1 Overthinking

In [None]:
# Rows produced by the DeepSeek-R1 distilled 7B model only (independent copy).
_is_dsr1 = accuracy_df["model"] == "DSR1-7B"
dsr1_acc_df = accuracy_df.loc[_is_dsr1].copy()
dsr1_acc_df.info()

In [None]:
# Count case-insensitive whole-word occurrences of "wait" in each response
# (presumably a proxy for second-guessing, per this section's "Overthinking" theme).
_wait_word = r"\bwait\b"
_responses = dsr1_acc_df["model_response"]
dsr1_acc_df["wait_count"] = _responses.str.count(_wait_word, flags=re.IGNORECASE)

In [None]:
fig_height = 0.8
# Shared y-range so the two panels' "wait" counts are directly comparable.
WAIT_COUNT_YLIM = (0, 20)

fig = plt.figure(figsize=(PAPER_WIDTH_IN, fig_height))
grid = fig.add_gridspec(1, 2, wspace=0.05, width_ratios=[1.3, 2])

with sns.plotting_context("paper", rc=rcs):
    # Left panel: "wait" count vs. sample length.
    ax = fig.add_subplot(grid[0, 0])
    sns.lineplot(
        dsr1_acc_df,
        x="sample.length",
        y="wait_count",
        color=PALETTE_MODEL["DSR1-7B"],
        ax=ax,
    )
    ax.set_xlabel("Task Complexity (Sample Length)", ha="left", x=0.0)
    ax.set_ylabel("“Wait” Count")
    ax.set_yticks([0, 5, 10, 15, 20])
    # set_yticks only expands the view; clamp it to match the right panel.
    ax.set_ylim(*WAIT_COUNT_YLIM)

    # Right panel: "wait" count vs. number of nonlexical productions (log x).
    ax = fig.add_subplot(grid[0, 1])
    sns.lineplot(
        dsr1_acc_df,
        x="n_nonlexical_productions",
        y="wait_count",
        color=PALETTE_MODEL["DSR1-7B"],
        ax=ax,
    )
    ax.set_ylabel(None)
    ax.set_xlabel(
        "Instruction Set Complexity (# of Nonlexical Productions)",
        ha="left",
        x=0.0,
    )
    ax.set_yticks([])
    ax.set_xscale("log")
    ax.set_xticks([10, 100])
    ax.set_xticklabels([10, 100])
    ax.set_ylim(*WAIT_COUNT_YLIM)

    fig.subplots_adjust(left=0, bottom=0, right=1, top=1)
    FIGURES_DIR.mkdir(parents=True, exist_ok=True)
    fig.savefig(
        FIGURES_DIR / "dsr1_wait_count.pdf",
        bbox_inches="tight",
    )

## GPT-4.1 Strategy Classification

In [None]:
GPT_MODELS = ["gpt-4.1-nano", "gpt-4.1-mini", "gpt-4.1"]

# Select rows and columns in one step and copy afterwards, so `gpt_df` owns
# its data and the category update below does not raise SettingWithCopyWarning.
gpt_df = accuracy_df.loc[
    accuracy_df["model"].isin(GPT_MODELS),
    [
        "model_response",
        "sample",
        "sample.type.ground_truth",
        "correct",
        "sample.length",
        "n_nonlexical_productions",
        "grammar_file",
        "model",
    ],
].copy()

# Drop the non-GPT categories so groupbys/plots over `model` only see the
# GPT-4.1 variants (category order of the remaining models is preserved).
models_to_remove = [
    m for m in accuracy_df["model"].cat.categories if m not in GPT_MODELS
]
gpt_df["model"] = gpt_df["model"].cat.remove_categories(models_to_remove)

gpt_df.info()

In [None]:
# Pick an arbitrary grammar (the 11th distinct grammar file) for inspection.
_distinct_grammars = gpt_df["grammar_file"].unique()
test_grammar = _distinct_grammars[10]
print(f"Test grammar: {test_grammar}")

In [None]:
def plot_gpt_acc(grammar_idx: int, fig_height: float = 0.8):
    """
    Plot each GPT-4.1 model's accuracy against sample length for one grammar.

    The grammar is the ``grammar_idx``-th entry of
    ``gpt_df["grammar_file"].unique()``. One panel is drawn per model present
    for that grammar, in a single row.

    Args:
        grammar_idx: Position of the grammar in
            ``gpt_df["grammar_file"].unique()``.
        fig_height: Figure height in inches. Defaults to 0.8, the single-row
            height used by the other figures in this notebook; passing it
            explicitly avoids depending on the mutable notebook global.

    Returns:
        The selected grammar file name.
    """
    grammar = gpt_df["grammar_file"].unique()[grammar_idx]

    # Mean of `correct` per (sample length, model) for the chosen grammar.
    acc_by_length_df = (
        gpt_df[gpt_df["grammar_file"] == grammar]
        .groupby(["sample.length", "model"], observed=True)["correct"]
        .mean()
        .reset_index()
    )
    models = acc_by_length_df["model"].unique()

    with sns.plotting_context("paper", rc=rcs):
        fig = plt.figure(figsize=(PAPER_WIDTH_IN, fig_height))
        # One column per model present for this grammar.
        grid = fig.add_gridspec(1, max(len(models), 1), wspace=0.05)

        for c, model in enumerate(models):
            ax = fig.add_subplot(grid[0, c])
            sns.lineplot(
                acc_by_length_df[acc_by_length_df["model"] == model],
                x="sample.length",
                y="correct",
                color=PALETTE_MODEL[model],
                ax=ax,
            )
            ax.set_ylim(0, 1)
            ax.set_title(model)
            ax.spines["top"].set_visible(False)
            ax.spines["right"].set_visible(False)

            # Only the left-most panel carries axis labels and y ticks.
            if c == 0:
                ax.set_xlabel("Task Complexity (Sample Length)", ha="left", x=0.0)
                ax.set_ylabel("Accuracy")
                ax.set_yticks([0, 1])
            else:
                ax.set_xlabel(None)
                ax.set_ylabel(None)
                ax.set_yticks([])

    return grammar

In [None]:
# Preview accuracy-by-length panels for grammar #195; the returned grammar
# name is echoed as the cell output.
plot_gpt_acc(195)

In [None]:
# Index into gpt_df["grammar_file"].unique() of the grammar studied below.
GRAMMAR_IDX = 192

# Plot its accuracy panels and keep the grammar file name for later cells.
grammar_name = plot_gpt_acc(GRAMMAR_IDX)
print(f"Grammar name: {grammar_name}")

In [None]:
# Rows for the selected grammar. `reset_index()` (without drop=True) keeps the
# original gpt_df row label in an "index" column.
_in_selected_grammar = gpt_df["grammar_file"].eq(grammar_name)
gpt_192_df = gpt_df.loc[_in_selected_grammar].reset_index()

gpt_192_df

In [None]:
def create_batch_classification_file(grammar_name: str, pct_per_length: float = 0.2):
    """
    Write an OpenAI batch file asking o4-mini to classify GPT-4.1 strategies.

    Responses in ``gpt_df`` for ``grammar_name`` are grouped by
    (sample length, ground-truth sample type, model), and a fraction
    ``pct_per_length`` of each group is sampled (``random_state=42``). Each
    sampled response becomes one chat-completion request. The request's
    metadata carries the responding model, sample, sample length, ground
    truth, correctness and grammar file.

    Args:
        grammar_name: Value of ``gpt_df["grammar_file"]`` to select. It is
            also the directory under ``data/grammars`` that the file goes to.
        pct_per_length: Fraction of each group to sample.

    Returns:
        The file name (not the full path) of the written JSONL file.
    """
    classification_prompt = """You will be presented with a completion from an LLM which was given a context-free grammar and a string of symbols drawn from that grammar's set of terminal symbols and asked to determine whether the string is generated by the grammar or not. Your job is to classify how the LLM attempted to solve the task by binning the completion strategy into one of the following categories:
  - `heuristic`: The LLM attempts to solve the task by using heuristics it surmises from the grammar, such as “if the string is long, it is likely generated by the grammar” or “the string only contains terminal symbols present in the grammar, so it’s likely a positive sample”. Count strategies as heuristic if they appeal to the existence of certain production rules but do not rigorously determine that no such derivation exists.
  - `rule-based`: The LLM attempts to solve the task by writing out the FULL DERIVATION of the sample from the grammar, or rigorously determining that no such derivation exists. Only count strategies as rule-based if the LLM doesn’t use any guesswork to reach its final conclusion.
  - `code`: The LLM attempts to solve the task by writing out a program or algorithm which it claims will solve the task. This includes writing out a program in a programming language, or writing out pseudocode.

You can write as much as you want in your answer, but please end your response with the name of the classification you think is most appropriate.

Here is the LLM's completion:

```
{completion}
```
"""

    grammar_df = gpt_df[gpt_df["grammar_file"] == grammar_name].reset_index(drop=True)

    subsampled_df = (
        grammar_df.groupby(
            ["sample.length", "sample.type.ground_truth", "model"], observed=True
        )[
            [
                "model_response",
                "sample.type.ground_truth",
                "correct",
                "sample.length",
                "grammar_file",
                "sample",
                "model",
            ]
        ]
        .apply(lambda x: x.sample(frac=pct_per_length, random_state=42))
        .reset_index(drop=True)
    )

    # str.format only parses the template, so braces inside a completion are
    # inserted verbatim.
    subsampled_df["classification_prompt"] = subsampled_df["model_response"].map(
        lambda x: classification_prompt.format(completion=x)
    )

    # custom_id is the row position within this file (0..n-1), so it is only
    # unique within a single batch file.
    subsampled_df["api_request"] = subsampled_df.apply(
        lambda row: fg_prompt.ChatCompletionResponse(
            user_prompt=row["classification_prompt"],
            metadata={
                "response_model": row["model"],
                "sample.length": row["sample.length"],
                "sample.type.ground_truth": row["sample.type.ground_truth"],
                "sample": row["sample"],
                "correct": row["correct"],
                "grammar_file": row["grammar_file"],
            },
        ).to_openai_batched_json(
            model="o4-mini",
            custom_id=str(row.name),
        ),
        axis=1,
    )

    classification_filename = f"{grammar_name}_gpt4.1_classification.jsonl"
    batch_jsonl_path = (
        PROJECT_ROOT / "data" / "grammars" / grammar_name / classification_filename
    )
    # Explicit UTF-8: responses may contain arbitrary Unicode, and the batch
    # API expects UTF-8 JSONL regardless of the platform's default encoding.
    with open(batch_jsonl_path, "w", encoding="utf-8") as f:
        f.writelines(f"{j}\n" for j in subsampled_df["api_request"])

    return classification_filename

In [None]:
# Seeded 20% sample of the distinct grammars; write one classification batch
# file per sampled grammar and remember its file name.
grammar_names = (
    gpt_df["grammar_file"].drop_duplicates().sample(frac=0.2, random_state=42)
)

grammars_data = [
    {
        "grammar_name": name,
        "classification_filename": create_batch_classification_file(name),
    }
    for name in grammar_names
]

In [None]:
def submit_classification_batch(
    batch_jsonl_path: str,
):
    """
    Upload a JSONL request file and start a chat-completions batch on it.

    Args:
        batch_jsonl_path: Path to the batch input file.

    Returns:
        ``{"batch": <created batch>, "file": <uploaded file object>}``.
    """
    # Close the request file as soon as the upload completes.
    with open(batch_jsonl_path, "rb") as batch_file:
        file_response = client.files.create(file=batch_file, purpose="batch")

    batch = client.batches.create(
        input_file_id=file_response.id,
        endpoint="/v1/chat/completions",
        completion_window="24h",
    )

    return {
        "batch": batch,
        "file": file_response,
    }

In [None]:
# for i, g in enumerate(grammars_data):
#     print(f"Submitting batch {i+1}/{len(grammars_data)}")

#     batch_jsonl_path = (
#         PROJECT_ROOT / "data" / "grammars" / g["grammar_name"] / g["classification_filename"]
#     )

#     response = submit_classification_batch(
#         batch_jsonl_path=batch_jsonl_path
#     )
#     grammars_data[i]["response"] = response
#     print(f"Batch ID: {response['batch'].id}")

In [None]:
# Copy each submitted batch's ID next to its grammar entry. Requires the
# "response" entries populated by the (currently commented-out) submission
# cell; otherwise this raises KeyError.
for entry in grammars_data:
    submitted_batch = entry["response"]["batch"]
    entry["batch_id"] = submitted_batch.id

grammars_data

In [None]:
def download_classification_batches(
    grammars_data: list[dict[str, Any]],
):
    """
    Download the results of finished classification batches to disk.

    For each entry, results are written to
    ``data/grammars/<grammar_name>/<input stem>-preds.jsonl``. Entries are
    skipped (with a printed message) when the results file already exists,
    when no ``batch_id`` was recorded, when the batch is not completed, or
    when a completed batch has no output file.

    Args:
        grammars_data: Dicts with ``grammar_name``, ``classification_filename``
            and (once submitted) ``batch_id``.
    """
    for g in grammars_data:
        input_filename = g["classification_filename"]
        preds_filename = input_filename.removesuffix(".jsonl") + "-preds.jsonl"
        output_path = (
            PROJECT_ROOT / "data" / "grammars" / g["grammar_name"] / preds_filename
        )
        if output_path.exists():
            print(f"File {output_path} already exists, skipping download.")
            continue

        batch_id = g.get("batch_id")
        if batch_id is None:
            print(f"No batch ID recorded for {g['grammar_name']}, skipping.")
            continue

        batch = client.batches.retrieve(batch_id)
        if batch.status != "completed":
            print(f"{batch_id} status: {batch.status}")
            continue

        # A batch can complete with every request failed; then there is no
        # output file, only an error file.
        if batch.output_file_id is None:
            print(
                f"{batch_id} completed without an output file "
                f"(error file: {batch.error_file_id})"
            )
            continue

        results = client.files.content(batch.output_file_id)
        with open(output_path, "w", encoding="utf-8") as f:
            f.write(results.text)
        print(f"Results saved to {output_path}")


download_classification_batches(grammars_data)

In [None]:
def _read_batch_jsonl_files(files: list[pathlib.Path]) -> pd.DataFrame:
    """Read batch JSONL files into one flattened frame tagged by grammar directory."""
    frames = []
    for path in files:
        with open(path, "r") as f:
            raw_df = pd.read_json(f, lines=True)
        # Round-trip through JSON so nested objects become dotted columns
        # (e.g. "body.metadata.sample" or "response.body.choices").
        records = json.loads(raw_df.to_json(orient="records"))
        frame = pd.json_normalize(records)
        # custom_ids restart at 0 in every batch file, so record which grammar
        # directory each row came from to use as an extra join key.
        frame["source_grammar_dir"] = path.parent.name
        frames.append(frame)
    if not frames:
        raise ValueError("No batch JSONL files found to load.")
    return pd.concat(frames, ignore_index=True)


def _first_choice_content(choices):
    """Return the first choice's message text, or None for failed requests."""
    if isinstance(choices, list) and choices:
        return choices[0]["message"]["content"]
    return None


def load_response_classifications() -> dict[str, pd.DataFrame]:
    """
    Load o4-mini strategy classifications of GPT-4.1 responses.

    Submitted requests (``*gpt4.1_classification.jsonl``) are paired with
    their batch results (``*gpt4.1_classification-preds.jsonl``) under
    ``data/grammars``. The pairing uses (custom_id, grammar directory). The
    classifier's last strategy mention becomes a ``strategy`` label; ``code``
    verdicts are folded into ``rule-based``.

    Returns:
        ``{"submissions": ..., "classifications": ...}``, both with rows
        containing missing values dropped. ``classifications`` has one row per
        classified response: custom_id, model, sample.length,
        sample.type.ground_truth, sample, correct, response, grammar_file,
        prompt, strategy.

    Raises:
        ValueError: If no submission or no result files are found.
    """
    grammars_dir = PROJECT_ROOT / "data" / "grammars"
    submission_df = _read_batch_jsonl_files(
        sorted(grammars_dir.rglob("*gpt4.1_classification.jsonl"))
    )
    response_df = _read_batch_jsonl_files(
        sorted(grammars_dir.rglob("*gpt4.1_classification-preds.jsonl"))
    )

    # Join on custom_id AND source directory: custom_id alone would pair
    # requests and results from different grammars.
    merged = pd.merge(
        submission_df,
        response_df,
        on=["custom_id", "source_grammar_dir"],
        suffixes=("_submission", "_response"),
    )

    classification_df = pd.DataFrame(
        {
            "custom_id": merged["custom_id"],
            "model": merged["body.metadata.response_model"],
            "sample.length": merged["body.metadata.sample.length"],
            "sample.type.ground_truth": merged[
                "body.metadata.sample.type.ground_truth"
            ],
            "sample": merged["body.metadata.sample"],
            "correct": merged["body.metadata.correct"],
            "response": merged["response.body.choices"].apply(_first_choice_content),
            "grammar_file": merged["body.metadata.grammar_file"],
            "prompt": merged["body.messages"].apply(lambda x: x[0]["content"]),
        }
    )

    # Leading word boundary so "code" is not found inside words like "decode".
    # Accept both an ASCII hyphen and U+2010 in "rule-based".
    strategy_regex = re.compile(
        r"\b(?:heuristic|rule[-\u2010]based|code)", flags=re.IGNORECASE
    )

    def _last_strategy_mention(text):
        # The prompt asks the classifier to end with its verdict, so the last
        # mention is taken; None (failed request) stays missing.
        if not isinstance(text, str):
            return None
        matches = strategy_regex.findall(text)
        return matches[-1] if matches else "unknown"

    raw_strategy = (
        classification_df["response"]
        .apply(_last_strategy_mention)
        .str.lower()
        .str.replace("\u2010", "-", regex=False)
    )
    classification_df["strategy"] = raw_strategy.map(
        {
            "heuristic": "heuristic",
            "rule-based": "rule-based",
            "code": "rule-based",
            "unknown": "unknown",
        }
    )

    return {
        "submissions": submission_df.dropna(),
        "classifications": classification_df.dropna(),
    }

In [None]:
# Load submitted requests and their parsed classifications.
df_dict = load_response_classifications()
submission_df, classification_df = (
    df_dict[key] for key in ("submissions", "classifications")
)

submission_df.info()

In [None]:
# Inspect the prompt text of the first submitted classification request.
submission_df["body.messages"].iloc[0][0]["content"]

In [None]:
# Spot-check one prompt: a short (< 6 symbols) positive sample that
# gpt-4.1-mini answered correctly and whose response was classified rule-based.
_short_positive = classification_df[
    (classification_df["sample.type.ground_truth"] == "positive")
    & (classification_df["sample.length"] < 6)
]
_rule_based_examples = (
    _short_positive.query("model == 'gpt-4.1-mini'")
    .query("correct == True")
    .query("strategy == 'rule-based'")
)
_rule_based_examples[
    ["sample.length", "sample", "response", "prompt", "grammar_file"]
].iloc[100]["prompt"]

In [None]:
# `dropna()` in the loader leaves gaps in the index. Reset it so the file is
# written with a plain RangeIndex: older pandas versions reject non-default
# indexes in to_feather, and the gapped labels carry no information.
classification_df.reset_index(drop=True).to_feather(
    PROJECT_ROOT / "data" / "gpt_classification_df.feather",
)

In [None]:
fig_height = 0.8

# Order panels by PALETTE_MODEL's key order (nano, mini, full) rather than by
# first appearance, which depends on the order the batch files were loaded.
# Any model missing from PALETTE_MODEL is appended alphabetically.
_present_models = set(classification_df["model"].unique())
strategy_models = [m for m in PALETTE_MODEL if m in _present_models] + sorted(
    _present_models - set(PALETTE_MODEL)
)

fig = plt.figure(figsize=(PAPER_WIDTH_IN, fig_height))
grid = fig.add_gridspec(1, max(len(strategy_models), 1), wspace=0.05)

with sns.plotting_context("paper", rc=rcs):
    for c, model in enumerate(strategy_models):
        ax = fig.add_subplot(grid[0, c])
        # Stacked proportion of each strategy per sample-length bin.
        sns.histplot(
            classification_df[classification_df["model"] == model],
            x="sample.length",
            hue="strategy",
            palette=PALETTE_STRAGETY,
            stat="proportion",
            multiple="fill",
            bins=50,
            ax=ax,
            alpha=0.8,
            linewidth=0,
            legend=False,
        )
        ax.set_title(model, fontsize=7)
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.set_xticks([1, 50])
        ax.get_xticklabels()[0].set_ha("left")
        ax.get_xticklabels()[-1].set_ha("right")

        if c == 0:
            ax.set_xlabel("Task Complexity (Sample Length)", ha="left", x=0.0)
            ax.set_ylabel("Proportion\nof Strategies")
            ax.set_yticks([0, 1])
        else:
            ax.set_xlabel(None)
            ax.set_ylabel(None)
            ax.set_yticks([])

        # In-plot labels replace a legend; they are placed on the first panel.
        if c == 0:
            ax.text(
                0.4,
                0.35,
                "rule-based",
                color=PALETTE_STRAGETY["rule-based"],
                ha="center",
                va="center",
                transform=ax.transAxes,
                fontsize=8,
                fontweight="bold",
            )

            ax.text(
                0.95,
                0.9,
                "heuristic",
                color="#ce8669",
                ha="right",
                va="top",
                transform=ax.transAxes,
                fontsize=8,
                fontweight="bold",
            )

    for o in fig.findobj():
        o.set_clip_on(False)
    fig.subplots_adjust(left=0, bottom=0, right=1, top=1)

    FIGURES_DIR.mkdir(parents=True, exist_ok=True)
    fig.savefig(
        FIGURES_DIR / "gpt_strategy.pdf",
        bbox_inches="tight",
    )

In [None]:
# Display the loaded classifications.
classification_df

In [None]:
# Classification prompt for the single-grammar batch. It is defined here
# because the identical template inside `create_batch_classification_file` is
# local to that function, so there is no module-level `classification_prompt`.
CLASSIFICATION_PROMPT_TEMPLATE = """You will be presented with a completion from an LLM which was given a context-free grammar and a string of symbols drawn from that grammar's set of terminal symbols and asked to determine whether the string is generated by the grammar or not. Your job is to classify how the LLM attempted to solve the task by binning the completion strategy into one of the following categories:
  - `heuristic`: The LLM attempts to solve the task by using heuristics it surmises from the grammar, such as “if the string is long, it is likely generated by the grammar” or “the string only contains terminal symbols present in the grammar, so it’s likely a positive sample”. Count strategies as heuristic if they appeal to the existence of certain production rules but do not rigorously determine that no such derivation exists.
  - `rule-based`: The LLM attempts to solve the task by writing out the FULL DERIVATION of the sample from the grammar, or rigorously determining that no such derivation exists. Only count strategies as rule-based if the LLM doesn’t use any guesswork to reach its final conclusion.
  - `code`: The LLM attempts to solve the task by writing out a program or algorithm which it claims will solve the task. This includes writing out a program in a programming language, or writing out pseudocode.

You can write as much as you want in your answer, but please end your response with the name of the classification you think is most appropriate.

Here is the LLM's completion:

```
{completion}
```
"""

gpt_192_df["classification_prompt"] = gpt_192_df["model_response"].map(
    lambda x: CLASSIFICATION_PROMPT_TEMPLATE.format(completion=x)
)

# custom_id is the gpt_192_df row label, used later to re-align results.
gpt_192_df["api_request"] = gpt_192_df.apply(
    lambda row: fg_prompt.ChatCompletionResponse(
        user_prompt=row["classification_prompt"],
        metadata={"response_model": row["model"]},
    ).to_openai_batched_json(
        model="o4-mini",
        custom_id=str(row.name),
    ),
    axis=1,
)


display(gpt_192_df.iloc[0]["api_request"])

In [None]:
classification_filename = f"{grammar_name}-classification_o4-mini.jsonl"
batch_jsonl_path = PROJECT_ROOT / "notebooks" / "data" / classification_filename
# Create the output directory on a fresh checkout; write UTF-8 explicitly.
batch_jsonl_path.parent.mkdir(parents=True, exist_ok=True)
with open(batch_jsonl_path, "w", encoding="utf-8") as f:
    f.writelines(f"{j}\n" for j in gpt_192_df["api_request"])

In [None]:
# Batch file path relative to the `notebooks` directory (presumably the
# notebook's working directory — confirm).
str(pathlib.Path("data") / classification_filename)

In [None]:
# Upload the request file, closing the handle once the upload finishes.
with open(batch_jsonl_path, "rb") as batch_file:
    file_response = client.files.create(file=batch_file, purpose="batch")

file_id = file_response.id

batch = client.batches.create(
    input_file_id=file_id, endpoint="/v1/chat/completions", completion_window="24h"
)

In [None]:
# batch_6822a0de128881908790210c526cf470


In [None]:
# Hard-coded ID of a previously submitted classification batch.
BATCH_ID = "batch_6822a0de128881908790210c526cf470"

batch = client.batches.retrieve(BATCH_ID)
# `output_file_id` is only set once the batch has finished successfully; fail
# with a clear message instead of passing None to `files.content`.
if batch.status != "completed" or batch.output_file_id is None:
    raise RuntimeError(
        f"Batch {batch.id} has no results yet (status: {batch.status}); "
        "re-run this cell once it has completed."
    )
output = client.files.content(batch.output_file_id)


output_file_path = (
    PROJECT_ROOT / "notebooks" / "data" / f"{batch.id}-classification.jsonl"
)
with open(output_file_path, "w", encoding="utf-8") as f:
    f.write(output.text)

In [None]:
# Open output_file_path and explode the jsonl file, then convert to a dataframe
with open(output_file_path, "r") as f:
    cat_df = pd.read_json(f, lines=True)
    cat_struct = json.loads(cat_df.to_json(orient="records"))
    cat_df = pd.json_normalize(cat_struct)

cat_df["model_response"] = cat_df["response.body.choices"].apply(
    lambda x: x[0]["message"]["content"]
)
cat_df.info()

In [None]:
# Classifier verdict = last whole-word mention of "heuristic", "rule-based" or
# "code" in its response ("unknown" if none). The Unicode hyphen U+2010 is
# also accepted in "rule-based" and normalized to "-", so the value matches
# the plain "rule-based" label.
strategy_regex = re.compile(
    r"\b(?:heuristic|rule[-\u2010]based|code)\b", flags=re.IGNORECASE
)

cat_df["strategy"] = (
    cat_df["model_response"]
    .apply(lambda x: strategy_regex.findall(x))
    .apply(lambda x: x[-1] if len(x) > 0 else "unknown")
    .str.lower()
    .str.replace("\u2010", "-", regex=False)
)

In [None]:
# Inspect the custom IDs returned with the batch results.
cat_df["custom_id"]

In [None]:
# Attach each classification to the response it classified. Batch output
# lines are not guaranteed to follow input order, so align on `custom_id`
# (set to the gpt_192_df row label when building the requests) instead of
# on row position.
_strategy_by_row = cat_df.set_index(cat_df["custom_id"].astype(int))["strategy"]
gpt_192_df["strategy"] = _strategy_by_row.reindex(gpt_192_df.index)
gpt_192_df["strategy"] = pd.Categorical(
    gpt_192_df["strategy"],
    categories=["heuristic", "rule-based", "code", "unknown"],
    ordered=True,
)

gpt_192_df

In [None]:
# Re-plot accuracy by sample length for the selected grammar.
plot_gpt_acc(GRAMMAR_IDX)

In [None]:
# Per model (columns): accuracy (top row) and the proportion of each response
# strategy (bottom row) against the number of nonlexical productions.
# NOTE(review): gpt_192_df holds a single grammar, so n_nonlexical_productions
# is presumably constant and each panel collapses to one x value — confirm
# whether "sample.length" was the intended x-axis.

fig_height = 2

# Models in order of first appearance; one grid column per model.
models_192 = list(gpt_192_df["model"].unique())

fig = plt.figure(figsize=(PAPER_WIDTH_IN, fig_height))
grid = fig.add_gridspec(2, max(len(models_192), 1), wspace=0.05)

acc_by_nlp_df = (
    gpt_192_df.groupby(["n_nonlexical_productions", "model"], observed=True)["correct"]
    .mean()
    .reset_index()
)

with sns.plotting_context("paper", rc=rcs):
    for r in range(2):
        for c, model in enumerate(models_192):
            ax = fig.add_subplot(grid[r, c])
            if r == 0:
                sns.lineplot(
                    acc_by_nlp_df[acc_by_nlp_df["model"] == model],
                    x="n_nonlexical_productions",
                    y="correct",
                    color=PALETTE_MODEL[model],
                    ax=ax,
                )
                ax.set_ylim(0, 1)
                ax.set_title(model)
                ax.spines["top"].set_visible(False)
                ax.spines["right"].set_visible(False)
                ax.set_xlabel(None)
                ax.set_xticks([])

                if c == 0:
                    ax.set_ylabel("Accuracy")
                    ax.set_yticks([0, 1])
                else:
                    ax.set_ylabel(None)
                    ax.set_yticks([])
            else:
                sns.histplot(
                    gpt_192_df[gpt_192_df["model"] == model],
                    x="n_nonlexical_productions",
                    hue="strategy",
                    # Same strategy colors as the other strategy figure.
                    palette=PALETTE_STRAGETY,
                    stat="proportion",
                    multiple="fill",
                    bins=50,
                    ax=ax,
                    alpha=0.8,
                    linewidth=0,
                )
                ax.set_title(None)
                ax.spines["top"].set_visible(False)
                ax.spines["right"].set_visible(False)

                if c == 0:
                    ax.set_xlabel("Grammar Complexity", ha="left", x=0.0)
                    ax.set_ylabel("Proportion\nof Strategies")
                    ax.set_yticks([0, 1])
                else:
                    ax.set_xlabel(None)
                    ax.set_ylabel(None)
                    ax.set_yticks([])

                # Only the right-most panel keeps its legend, moved outside.
                if c == len(models_192) - 1:
                    sns.move_legend(
                        ax,
                        loc="upper left",
                        bbox_to_anchor=(1, 1),
                        title="Strategy",
                        title_fontsize=8,
                        fontsize=8,
                    )
                else:
                    ax.get_legend().remove()

    for o in fig.findobj():
        o.set_clip_on(False)
    plt.subplots_adjust(left=0, bottom=0, right=1, top=1)

In [None]:
# Export up to 10 randomly chosen responses per strategy ("heuristic" and
# "rule-based"), in shuffled order, for manual inspection.
SUBSAMPLE_PER_STRATEGY = 10
_subsample_path = PROJECT_ROOT / "notebooks" / "data" / "gpt_192_subsample.csv"
_subsample_path.parent.mkdir(parents=True, exist_ok=True)

(
    gpt_192_df[gpt_192_df["strategy"].isin(["heuristic", "rule-based"])][
        ["model_response", "strategy"]
    ]
    .groupby(
        ["strategy"],
        observed=True,
    )
    # include_groups=True keeps `strategy` inside each group so it survives
    # the reset_index(drop=True) below. Cap the sample at the group size so a
    # strategy with fewer than 10 responses does not raise.
    .apply(
        lambda x: x.sample(min(SUBSAMPLE_PER_STRATEGY, len(x)), random_state=42),
        include_groups=True,
    )
    .reset_index(drop=True)
    # shuffle the rows of the dataframe
    .sample(frac=1, random_state=42)
    .reset_index(drop=True)
).to_csv(
    _subsample_path,
    index=False,
)

In [None]:
# Display the selected grammar's responses together with their strategy labels.
gpt_192_df