# Accuracy analysis

In [None]:
import json
import os
import pathlib

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import pandas as pd
import pyrootutils
import regex as re
import seaborn as sns
import sklearn.metrics as sk_metrics

In [None]:
# Resolve the project root with pyrootutils: the search starts from the
# current working directory (``os.path.abspath("")``) and uses the
# ".project-root" indicator.
_search_start = os.path.abspath("")
PROJECT_ROOT = pyrootutils.find_root(
    search_from=_search_start,
    indicator=".project-root",
)

## Load data & model responses

In [None]:
# Case-insensitive match for "yes". Note that this is a plain substring match,
# so words that merely contain "yes" (e.g. "eyes") also match.
YES_RE = re.compile(r"yes", re.IGNORECASE)


def extract_content(choices_list: list) -> "str | None":
    """Return the message content of the first entry in ``choices_list``.

    Args:
        choices_list: A list of choice dicts, each holding a ``"message"``
            dict with a ``"content"`` entry.

    Returns:
        ``choices_list[0]["message"]["content"]``, or ``None`` when
        ``choices_list`` is missing (e.g. ``None``/NaN) or empty, so that
        such rows can be removed later with ``dropna()`` instead of raising.
    """
    if not isinstance(choices_list, (list, tuple)) or len(choices_list) == 0:
        return None
    return choices_list[0]["message"]["content"]


def extract_prediction(response: str) -> "str | None":
    """Classify a model response as ``"positive"`` or ``"negative"``.

    Only the last 20 characters of the response are searched.

    Args:
        response: The text of the model response.

    Returns:
        ``"positive"`` if ``YES_RE`` matches within the last 20 characters,
        ``"negative"`` otherwise, or ``None`` when ``response`` is not a
        string (e.g. a missing response).
    """
    if not isinstance(response, str):
        return None
    answer_tail = response[-20:]
    return "positive" if YES_RE.search(answer_tail) else "negative"

In [None]:
completions_path = PROJECT_ROOT / "data" / "completions"

# List all files matching "*_inputs.jsonl" / "*_results.jsonl" in
# completions_path. The lists are sorted because glob order is not guaranteed,
# and the order of the files determines the row order of the frames below.
inputs_files = sorted(completions_path.glob("*_inputs.jsonl"))
results_files = sorted(completions_path.glob("*_results.jsonl"))

# Grab the batch_id from the filename: the leading "batch_..." part, up to the
# last underscore of the initial run of word characters.
batch_id_re = re.compile(r"^(batch_\w+)_")


def _load_batch_files(paths: list) -> pd.DataFrame:
    """Read JSON-lines files into one flattened frame with a ``batch_id`` column.

    Args:
        paths: Paths of the ``.jsonl`` files to load.

    Returns:
        The concatenation (with a fresh index) of every file's flattened
        records; each row carries the ``batch_id`` parsed from its file name.

    Raises:
        ValueError: If ``paths`` is empty or a file name does not match
            ``batch_id_re``.
    """
    frames = []
    for path in paths:
        id_match = batch_id_re.search(path.name)
        if id_match is None:
            raise ValueError(
                f"Cannot extract a batch id from file name {path.name!r}"
            )
        raw_df = pd.read_json(path, lines=True)
        # Round-trip through JSON records before flattening (presumably so
        # nested values are plain dicts/lists that json_normalize can expand).
        records = json.loads(raw_df.to_json(orient="records"))
        flat_df = pd.json_normalize(records)
        flat_df["batch_id"] = id_match.group(1)
        frames.append(flat_df)
    if not frames:
        raise ValueError("No batch files to load")
    return pd.concat(frames, ignore_index=True)


inputs_df = _load_batch_files(inputs_files)
results_df = _load_batch_files(results_files)

# Input metadata columns to carry over, mapped to the shorter names used in
# the rest of the analysis.
metadata_renames = {
    "body.metadata.sample_type": "sample.type.ground_truth",  # ground-truth label
    "body.metadata.sample": "sample",  # the sample itself
    "body.metadata.grammar_file": "grammar_file",  # grammar file used
    "body.metadata.model": "model",  # model used
    "body.metadata.n_shots": "n_shots",  # n_shots used
}
join_keys = ["batch_id", "custom_id"]

# Merge inputs and results on the batch_id and custom_id, then rename the
# metadata columns.
inputs_subset = inputs_df[["custom_id", "batch_id", *metadata_renames]]
response_full_df = results_df.merge(inputs_subset, on=join_keys).rename(
    columns=metadata_renames
)

# Pull the text of the model's answer out of the response's choices.
response_choices = response_full_df["response.body.choices"]
response_full_df["model_response"] = response_choices.apply(extract_content)

# Keep only the columns needed for the accuracy analysis.
analysis_columns = [
    "sample",
    "sample.type.ground_truth",
    "model_response",
    "grammar_file",
    "model",
    "n_shots",
]
response_df = response_full_df[analysis_columns].copy()

# Predicted label, parsed from the model's response text.
response_df["sample.type.predicted"] = response_df["model_response"].apply(
    extract_prediction
)

# Sample length = number of single-space-separated pieces of the sample.
response_df["sample.length"] = response_df["sample"].apply(
    lambda s: len(str(s).split(" "))
)

response_df["correct"] = (
    response_df["sample.type.ground_truth"] == response_df["sample.type.predicted"]
)

# Drop rows with any missing value before converting to categoricals.
# NOTE(review): values outside the category lists below become NaN *after*
# this dropna and therefore stay in the frame — confirm that n_shots is always
# one of "0", "2", "4" (as strings).
response_df = response_df.dropna()

response_df["n_shots"] = pd.Categorical(
    response_df["n_shots"],
    categories=["0", "2", "4"],
    ordered=True,
)
for label_column in ("sample.type.ground_truth", "sample.type.predicted"):
    response_df[label_column] = pd.Categorical(
        response_df[label_column],
        categories=["positive", "negative"],
        ordered=True,
    )
response_df["model"] = pd.Categorical(
    response_df["model"],
)

# Replace grammar file names with short display names ("Grammar 1", ...).
# The names are sorted first so that the numbering is reproducible and does
# not depend on the order in which the rows happened to be loaded.
unique_grammars = sorted(response_df["grammar_file"].unique())
g_map_dict = {g: f"Grammar {i + 1}" for i, g in enumerate(unique_grammars)}
response_df["grammar_file"] = response_df["grammar_file"].map(g_map_dict)
response_df["grammar_file"] = pd.Categorical(
    response_df["grammar_file"],
    categories=list(g_map_dict.values()),
    ordered=True,
)

response_df.info()

In [None]:
# Number of samples for every (grammar, n_shots, ground-truth label) group.
# The grouping keys are categorical; observed=False is passed explicitly so
# that combinations with zero samples keep being listed regardless of the
# pandas default for `observed`.
(
    response_df.groupby(
        ["grammar_file", "n_shots", "sample.type.ground_truth"],
        observed=False,
    )[["sample"]].count()
)

## Plot sample-length distribution

In [None]:
# Histogram of sample lengths, coloured by ground-truth label, on a log
# count axis.
label_palette = {"positive": "orange", "negative": "purple"}

fig, ax = plt.subplots(figsize=(6, 3))

sns.histplot(
    data=response_df,
    x="sample.length",
    hue="sample.type.ground_truth",
    palette=label_palette,
    bins=25,
    ax=ax,
)

length_legend = ax.get_legend()
length_legend.set_title("Sample type")

ax.set_yscale("log")
ax.set_xlabel("Sample length")

Since some of the longer sample lengths only have a few samples, the variance of the
accuracy at those lengths will be very high. We solve this by throwing out all samples
whose length category does not contain at least `MIN_NUM_SAMPLES` (20) samples.

In [None]:
# Minimum number of samples a length category must contain to be kept.
MIN_NUM_SAMPLES = 20

# Number of samples per sample length.
samples_by_length = response_df.groupby("sample.length")["sample"].count()

# Lengths with at least MIN_NUM_SAMPLES samples. The comparison is inclusive
# so that a category holding exactly the minimum is kept, as the constant's
# name implies.
many_samples_lengths = samples_by_length[
    samples_by_length >= MIN_NUM_SAMPLES
].index.values

# Keep only the samples whose length is in one of those categories.
response_df = response_df[response_df["sample.length"].isin(many_samples_lengths)]

## Calculate accuracy metrics

In [None]:
# Overall accuracy across all remaining samples.
mean_accuracy = sk_metrics.accuracy_score(
    response_df["sample.type.ground_truth"], response_df["sample.type.predicted"]
)

# Class order of the confusion matrix, fixed explicitly: index 0 is
# "negative" and index 1 is "positive". Without `labels=` the order is
# inferred from the data, and the matrix shrinks to 1x1 if only one class is
# present, which would break the indexing below.
cm_labels = ["negative", "positive"]

# Confusion matrix normalised over the true labels (each row sums to 1), so
# the diagonal holds the per-class accuracy.
mean_cm = sk_metrics.confusion_matrix(
    response_df["sample.type.ground_truth"],
    response_df["sample.type.predicted"],
    labels=cm_labels,
    normalize="true",
)

negative_sample_acc = mean_cm[0][0]
positive_sample_acc = mean_cm[1][1]

In [None]:
# Heatmap of the overall confusion matrix, with the colour scale fixed to
# [0, 1].
class_tick_names = ["Negative", "Positive"]

ax = plt.subplot()

sns.heatmap(
    data=mean_cm,
    annot=True,
    vmin=0.0,
    vmax=1.0,
    cmap="coolwarm",
    ax=ax,
)

ax.set_xlabel("Predicted Label")
ax.set_ylabel("True Label")
ax.set_xticklabels(class_tick_names)
ax.set_yticklabels(class_tick_names)

## Plot accuracy by sample length & type

In [None]:
# Two stacked panels sharing the x-axis: a histogram of sample lengths (top)
# and the mean accuracy per sample length and ground-truth label (bottom).
fig = plt.figure(figsize=(6, 5))
gs = gridspec.GridSpec(2, 1, height_ratios=[1, 3])

ax0 = plt.subplot(gs[0])
ax1 = plt.subplot(gs[1], sharex=ax0)

# One histogram bin per distinct sample length.
n_bins = response_df["sample.length"].nunique()

sns.histplot(
    data=response_df,
    x="sample.length",
    ax=ax0,
    bins=n_bins,
    color="gray",
)

sns.lineplot(
    data=response_df,
    x="sample.length",
    y="correct",
    hue="sample.type.ground_truth",
    ax=ax1,
    style="sample.type.ground_truth",
    palette={"positive": "orange", "negative": "purple"},
    markers=["o", "o"],
    dashes=False,
    alpha=0.5,
    linewidth=2,
    err_style="bars",
)

ax0.set_yscale("log")
ax0.set_ylim(10, None)

ax1.set_ylabel("Mean accuracy")
ax1.set_xlabel("Sample length")

# Horizontal reference lines for the per-class and overall accuracy, each
# annotated with its value. The labels use the y-axis transform (x in axes
# fractions, y in data units) so they stay next to their line whatever the
# y-limits are; a pure axes transform only lines up when the y-range is
# exactly [0, 1].
label_transform = ax1.get_yaxis_transform()
reference_levels = [
    # (value, colour, vertical label offset in data units, vertical alignment)
    (positive_sample_acc, "orange", 0.01, "baseline"),
    (negative_sample_acc, "purple", -0.02, "top"),
    (mean_accuracy, "black", 0.025, "baseline"),
]
for level, level_color, y_offset, v_align in reference_levels:
    ax1.axhline(level, color=level_color, linestyle="--", linewidth=2)
    ax1.text(
        x=0.97,
        y=level + y_offset,
        s=f"{level:.2f}",
        color=level_color,
        transform=label_transform,
        horizontalalignment="right",
        verticalalignment=v_align,
        fontdict={"weight": "bold"},
    )

ax1.get_legend().set_title("Sample type")

# hide x-axis label and tick labels on the first subplot
ax0.set_xlabel("")
ax0.tick_params(axis="x", which="both", bottom=True, top=False, labelbottom=False)

## Few-shot analysis

In [None]:
# Mean accuracy against the number of shots: one thin line per
# (ground-truth label, grammar) pair, plus a thick black line over all data.
fig = plt.figure(figsize=(3, 3))
gs = gridspec.GridSpec(1, 1)

ax0 = plt.subplot(gs[0])

# Arguments common to both line plots.
shared_line_kwargs = dict(
    data=response_df,
    x="n_shots",
    y="correct",
    ax=ax0,
    errorbar=None,
)

sns.lineplot(
    hue="sample.type.ground_truth",
    style="grammar_file",
    palette={"positive": "orange", "negative": "purple"},
    markers=True,
    alpha=0.5,
    linewidth=2,
    **shared_line_kwargs,
)

sns.lineplot(
    color="black",
    linewidth=3,
    marker="o",
    **shared_line_kwargs,
)

# Rebuild the legend outside the axes. Entry 0 is retitled and entry 3 is
# blanked; these appear to be the section headers seaborn inserts for the hue
# and style variables — verify if the legend layout changes.
legend_handles, legend_labels = ax0.get_legend_handles_labels()
legend_labels[0] = "Sample Type"
legend_labels[3] = ""

ax0.legend(
    handles=legend_handles,
    labels=legend_labels,
    loc="upper left",
    bbox_to_anchor=(1, 1),
)

ax0.set_ylim(0, 1)
ax0.set_ylabel("Mean Accuracy")
ax0.set_xlabel("# of Shots")

In [None]:
# One confusion matrix per number of shots, shown side by side.
shot_counts = ["0", "2", "4"]

# Class order of every matrix, fixed explicitly so that index 0 is always
# "negative" and index 1 "positive" — matching the tick labels below — even
# if a subset happens to contain a single class.
shot_cm_labels = ["negative", "positive"]
shot_tick_names = ["Negative", "Positive"]

zero_shots_df, two_shots_df, four_shots_df = (
    response_df[response_df["n_shots"] == n_shots] for n_shots in shot_counts
)

# Confusion matrices normalised over the true labels (each row sums to 1).
mean_cm_0, mean_cm_2, mean_cm_4 = (
    sk_metrics.confusion_matrix(
        shots_df["sample.type.ground_truth"],
        shots_df["sample.type.predicted"],
        labels=shot_cm_labels,
        normalize="true",
    )
    for shots_df in (zero_shots_df, two_shots_df, four_shots_df)
)

fig = plt.figure(figsize=(10, 3))
gs = gridspec.GridSpec(1, 3)

ax0 = plt.subplot(gs[0])
ax1 = plt.subplot(gs[1], sharey=ax0)
ax2 = plt.subplot(gs[2], sharey=ax0)

panels = [
    (ax0, mean_cm_0, "0 Shot"),
    (ax1, mean_cm_2, "2 Shot"),
    (ax2, mean_cm_4, "4 Shot"),
]

for panel_ax, panel_cm, _ in panels:
    sns.heatmap(
        data=panel_cm,
        annot=True,
        ax=panel_ax,
        vmin=0.0,
        vmax=1.0,
        cmap="coolwarm",
        cbar=False,
    )

# The y-axis is shared, so it is labelled on the left-most panel only.
ax0.set_ylabel("True Label")
ax0.set_yticklabels(shot_tick_names)

for panel_ax, _, panel_title in panels:
    panel_ax.set_xlabel("Predicted Label")
    panel_ax.set_xticklabels(shot_tick_names)
    panel_ax.set_title(panel_title)

# Hide the y tick labels on the other two panels.
for panel_ax in (ax1, ax2):
    panel_ax.tick_params(
        axis="y", which="both", left=True, right=False, labelleft=False
    )

## Positive sample proportions

In [None]:
partitions_path = PROJECT_ROOT / "data" / "partitions"

# All .csv files in partitions_path. Sorted because glob order is not
# guaranteed and `.first()` below keeps whichever row was loaded first.
partitions_files = sorted(partitions_path.glob("*.csv"))

# Read all .csv files into a single dataframe, tagging each row with the
# display name of its grammar.
partitions_dfs = []
for f in partitions_files:
    # The grammar name is the part of the file stem between "counts_" and an
    # optional "_k=..." suffix.
    name_parts = f.stem.split("_k=")[0].split("counts_")
    if len(name_parts) < 2:
        raise ValueError(f"Unexpected partition file name: {f.name!r}")
    f_name = name_parts[1]
    g_name = g_map_dict[f_name]
    p_df = pd.read_csv(f)
    p_df["grammar_file"] = g_name
    partitions_dfs.append(p_df)

# Keep a single row per (grammar, sample length).
partitions_df = (
    pd.concat(partitions_dfs, ignore_index=True)
    .groupby(["grammar_file", "sample.length"])
    .first()
    .reset_index()
)

In [None]:
# Proportion of positive samples per sample length, one line per grammar.
fig = plt.figure(figsize=(6, 4))
gs = gridspec.GridSpec(1, 1)
ax0 = fig.add_subplot(gs[0])

grammar_column = "grammar_file"

sns.lineplot(
    data=partitions_df,
    x="sample.length",
    y="prop_positive_samples",
    hue=grammar_column,
    style=grammar_column,
    palette="Oranges",
    markers=True,
    linewidth=2,
    ax=ax0,
)

ax0.set_ylim(0, 1)
ax0.set_ylabel("Proportion of Strings in Grammar")
ax0.get_legend().set_title("")
ax0.set_xlabel("Sample Length")