In [None]:
import json
import os
import re

import dotenv
import matplotlib.pyplot as plt
import openai
import pandas as pd
import pyrootutils
import seaborn as sns

In [None]:
# Locate the project root via pyrootutils, starting from the current working
# directory and using the ".project-root" indicator.
PROJECT_ROOT = pyrootutils.find_root(
    search_from=os.path.abspath(""), indicator=".project-root"
)
GRAMMARS_PATH = PROJECT_ROOT / "data" / "scfg_grammars"

# Sort the discovered files: directory traversal order depends on the
# filesystem, so without sorting the row order of the concatenated frames
# (and therefore which duplicate row is kept later) can differ between runs
# and machines.
results_files = sorted(GRAMMARS_PATH.rglob("batch_*.jsonl"))
inputs_files = sorted(GRAMMARS_PATH.rglob("scfg_*.jsonl"))

# Load environment variables from the project's .env file.
dotenv.load_dotenv(PROJECT_ROOT / ".env")

In [None]:
# Client used further down to retrieve batch and input-file metadata.
# NOTE(review): no credentials are passed explicitly -- presumably they come
# from the environment populated by dotenv.load_dotenv above; confirm.
openai_client = openai.OpenAI()

In [None]:
res_dfs = []
inputs_dfs = []

# Flatten every batch output file and lift the fields of interest out of the
# nested response payload into short, top-level column names.
for results_path in results_files:
    raw = pd.read_json(results_path, lines=True)
    # Round-trip through JSON so json_normalize receives plain records and can
    # expand nested objects into dotted column names.
    records = json.loads(raw.to_json(orient="records"))
    flat = pd.json_normalize(records)

    flat = flat.assign(
        # message content of the first returned choice
        model_response=flat["response.body.choices"].apply(
            lambda choices: choices[0]["message"]["content"]
        ),
        prompt_tokens=flat["response.body.usage.prompt_tokens"],
        completion_tokens=flat["response.body.usage.completion_tokens"],
        total_tokens=flat["response.body.usage.total_tokens"],
        model=flat["response.body.model"],
        # file name up to the "_output.jsonl" suffix
        batch_id=results_path.name.split("_output.jsonl")[0],
    )

    res_dfs.append(flat)

res_df = pd.concat(res_dfs, ignore_index=True)
batch_ids = res_df["batch_id"].unique()

# Look up, once per batch id, the name of the input file the batch was created
# from (one batch lookup and one file lookup per id).
input_file_by_batch = {}
for bid in batch_ids:
    batch = openai_client.batches.retrieve(bid)
    input_file = openai_client.files.retrieve(batch.input_file_id)
    input_file_by_batch[bid] = input_file.filename

res_df["input_file"] = res_df["batch_id"].map(input_file_by_batch)

# Derive the grammar name a single time, after every row has its input file
# (previously this ran inside the loop, once per batch, over the whole frame).
# The grammar name is the first two "_"-separated tokens of the file name.
res_df["grammar_name"] = res_df["input_file"].apply(
    lambda x: "_".join(str(x).split("_")[0:2])
)

inputs_dfs = []

# Short column name -> flattened request-metadata column it is copied from.
metadata_columns = {
    "grammar_name": "body.metadata.grammar_file",
    "lhs": "body.metadata.lhs",
    "rhs": "body.metadata.rhs",
}

# Flatten each request file and expose its metadata under the short names.
for inputs_path in inputs_files:
    records = json.loads(
        pd.read_json(inputs_path, lines=True).to_json(orient="records")
    )
    flat = pd.json_normalize(records)
    flat = flat.assign(
        **{short: flat[source] for short, source in metadata_columns.items()}
    )
    inputs_dfs.append(flat)

inputs_df = pd.concat(inputs_dfs, ignore_index=True)


# Attach each request's columns to its response row; rows are matched on the
# "grammar_name" and "custom_id" columns.
res_df = res_df.merge(inputs_df, on=["grammar_name", "custom_id"])

In [None]:
# Preview the first rows of the flattened request table.
inputs_df.head()

## Extract answer

In [None]:
# Captures the text that follows "Final Answer: " on the same line. The lazy
# group stops at the first newline (or at the end of the string), so a match
# never extends past the line that contains the marker.
answer_re = re.compile(r"Final Answer: (.*?)(?:\n|$)", re.DOTALL)


def extract_answer(model_response):
    """Return the cleaned text of the last "Final Answer: ..." line.

    Args:
        model_response: Raw model output. Non-string values (for example
            ``None`` or NaN when a response carries no text) are treated as
            containing no answer instead of raising ``TypeError``.

    Returns:
        The last matched answer with every character other than ASCII letters
        and whitespace removed and surrounding whitespace stripped, or
        ``None`` when the input is not a string or has no "Final Answer: "
        marker.
    """
    if not isinstance(model_response, str):
        return None

    matches = answer_re.findall(model_response)
    if not matches:
        return None

    # Only the last occurrence counts; earlier ones are treated as drafts.
    cleaned = re.sub(r"[^a-zA-Z\s]", "", matches[-1])
    return cleaned.strip()

In [None]:
# Keep one row per (custom_id, batch_id), extract the model's answer, and
# discard rows that lack either an extracted answer or a reference "rhs".
res_df = (
    res_df.drop_duplicates(subset=["custom_id", "batch_id"])
    .assign(model_answer=lambda frame: frame["model_response"].apply(extract_answer))
    .dropna(subset=["model_answer", "rhs"])
    .reset_index(drop=True)
)

# Exact match: the answer string equals the reference string.
res_df["exact_match"] = res_df["model_answer"].eq(res_df["rhs"])

# Bag-of-words match: same whitespace-separated tokens, order ignored.
res_df["bow_match"] = res_df.apply(
    lambda row: sorted(row["model_answer"].split()) == sorted(row["rhs"].split()),
    axis=1,
)

# Input length measured in whitespace-separated tokens of "lhs".
res_df["lhs_length"] = res_df["lhs"].map(lambda text: len(text.split()))


# Long format: one row per (response, metric) pair.
metrics_df = pd.melt(
    res_df,
    id_vars=["model", "lhs_length", "custom_id"],
    value_vars=["exact_match", "bow_match"],
    var_name="match_type",
    value_name="match_value",
)

# Human-readable metric labels for the match_type column.
match_type_labels = {"exact_match": "Exact Match", "bow_match": "Bag of Words"}
metrics_df["match_type"] = metrics_df["match_type"].replace(match_type_labels)

In [None]:
# Number of rows with a non-null custom_id per (model, lhs_length) group.
(
    res_df
    .groupby(["model", "lhs_length"])["custom_id"]
    .count()
)

In [None]:
# Mean of the boolean match columns, i.e. the fraction of matching rows,
# per (model, lhs_length) group.
(
    res_df
    .groupby(["model", "lhs_length"])[["exact_match", "bow_match"]]
    .mean()
)

In [None]:
# One panel per model in a single row. The panel count follows the data
# instead of being fixed at three, so a different number of models neither
# raises IndexError nor leaves empty axes behind.
models = metrics_df["model"].unique()
n_panels = max(len(models), 1)

# Two inches of width per panel keeps the 6 x 2.5 figure size for three models.
fig = plt.figure(figsize=(2 * n_panels, 2.5), layout="constrained")
gs = fig.add_gridspec(1, n_panels)
axes = [fig.add_subplot(gs[0, col]) for col in range(n_panels)]


for i, model in enumerate(models):
    ax = axes[i]
    model_df = metrics_df[metrics_df["model"] == model]
    sns.lineplot(
        data=model_df,
        x="lhs_length",
        y="match_value",
        hue="match_type",
        ax=ax,
        marker="o",
    )

    ax.set_ylim(-0.05, 1.05)

    # Panel title: the model identifier truncated at the first "-2".
    # NOTE(review): presumably this drops a date suffix such as "-2024-..."; confirm.
    model_name = model.split("-2")[0]

    # format y-axis ticks as percentages
    ax.set_yticks([0, 0.25, 0.5, 0.75, 1.0])
    ax.set_yticklabels(["0%", "25%", "50%", "75%", "100%"])
    ax.set_title(model_name, fontsize=10)
    ax.set_ylabel("Accuracy")
    ax.set_xlabel("Sentence Length")

    if i > 0:
        # Only the first panel carries the legend and the y-axis annotations.
        legend = ax.get_legend()
        if legend is not None:  # get_legend() returns None when no legend exists
            legend.remove()
        ax.set_ylabel(None)
        ax.set_yticks([])
    else:
        ax.legend(title=None, loc="upper left", fontsize=9)

In [None]:
# Bag-of-words accuracy against input length, one line per model, drawn
# without error bands. `errorbar=None` is the replacement for the deprecated
# `ci=None` argument.
bow_ax = sns.lineplot(
    data=res_df,
    x="lhs_length",
    y="bow_match",
    hue="model",
    marker="o",
    errorbar=None,
    palette="tab10",
)
# Label the axes so the figure can be read on its own; the trailing semicolon
# suppresses the return-value repr in the cell output.
bow_ax.set(xlabel="Sentence Length", ylabel="Bag of Words Accuracy");