# Accuracy analysis

In [None]:
import json
import pathlib

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import pandas as pd
import regex as re
import seaborn as sns
import sklearn.metrics as sk_metrics

## Load data & model responses

In [None]:
# Load the batch API responses and the positive/negative sample files.

response_path = pathlib.Path("../data/batch_file_response.jsonl")
response_df = pd.read_json(response_path, lines=True)
# Convert the frame back to plain JSON records so `json_normalize` can flatten
# nested fields into dotted columns (e.g. `response.body.choices`).
response_json_struct = json.loads(response_df.to_json(orient="records"))
response_df_flat = pd.json_normalize(response_json_struct)

pos_samples_path = pathlib.Path("../data/sample_trim_20241018115328_positive_300.txt")
neg_samples_path = pathlib.Path("../data/sample_trim_20241018115328_negative_300.txt")

# One sample per line, read as a single tab-separated column.
# NOTE(review): read_csv applies its default quote handling, so a line that
# contains `"` may be parsed unexpectedly — confirm this matches how the batch
# requests were generated before changing it.
pos_samples_df = pd.read_csv(pos_samples_path, sep="\t", header=None, names=["sample"])
neg_samples_df = pd.read_csv(neg_samples_path, sep="\t", header=None, names=["sample"])

pos_samples_df["sample_type"] = "Positive"
neg_samples_df["sample_type"] = "Negative"

# Positives first, then negatives, re-indexed 0..n-1.
samples_df = pd.concat([pos_samples_df, neg_samples_df], ignore_index=True)
# The id is derived from the row position; presumably the batch requests were
# built in the same order (positives, then negatives) — TODO confirm.
samples_df["custom_id"] = "request-" + samples_df.index.astype(str)
# Word count. split() with no argument collapses runs of whitespace, so
# repeated spaces do not add empty "words" to the length.
samples_df["sample_length"] = samples_df["sample"].apply(lambda text: len(text.split()))


# Extract response and prediction

# Case-insensitive "yes" as a whole word, so "eyes" or "yesterday" do not count.
YES_RE = re.compile(r"\byes\b", re.IGNORECASE)

def extract_content(choices_list: list) -> str:
    """Return the message text of the first choice in a chat-completion response.

    Raises TypeError/IndexError/KeyError when `choices_list` is not a non-empty
    list of choice dicts (presumably NaN for a failed request — verify), so a
    broken response surfaces instead of being scored.
    """
    return choices_list[0]["message"]["content"]

def extract_prediction(response: str, tail_chars: int = 20) -> str:
    """Label a model response "Positive" if it ends with a standalone "yes".

    Looks for the word "yes" (any case, whole word) starting within the last
    `tail_chars` characters of `response`. Anything else, including an empty
    response, is labeled "Negative".

    Args:
        response: Model output text.
        tail_chars: Size of the trailing window that is searched (default 20).

    Returns:
        "Positive" or "Negative".
    """
    start = max(len(response) - tail_chars, 0)
    # Search from `start` instead of slicing, so `\b` can see the character
    # just before the window and a word cut in half by it ("e|yes") is not matched.
    return "Positive" if YES_RE.search(response, start) else "Negative"


# Pull the assistant text out of each response and map it to a label.
response_df_flat["response_content"] = (
    response_df_flat["response.body.choices"]
        .apply(extract_content)
)
response_df_flat["response_prediction"] = (
    response_df_flat["response_content"]
        .apply(extract_prediction)
)

# Join each response to its source sample on `custom_id`.
# validate="one_to_one" raises if either side has duplicate ids, which would
# otherwise silently duplicate rows and skew every accuracy figure below.
merged_df = response_df_flat.merge(
    samples_df,
    on="custom_id",
    how="left",
    validate="one_to_one",
)

# A left join leaves `sample_type` NaN for responses with no matching sample.
# Fail loudly instead of silently scoring them as incorrect.
unmatched = merged_df["sample_type"].isna()
if unmatched.any():
    raise ValueError(
        f"{int(unmatched.sum())} response(s) have no matching sample, e.g. "
        f"{merged_df.loc[unmatched, 'custom_id'].tolist()[:5]}"
    )

merged_df["correct"] = merged_df["sample_type"] == merged_df["response_prediction"]

## Plot sample-length distribution

In [None]:
# Distribution of sample lengths (word counts) over every merged response row.
sns.histplot(merged_df, x="sample_length")

## Calculate accuracy metrics

In [None]:
# Overall accuracy plus per-class accuracy from a row-normalized confusion matrix.
# Passing `labels` pins the row/column order (0 = "Negative", 1 = "Positive")
# and keeps the matrix 2x2 even if a class is absent from the data, so the
# indexing below cannot silently refer to the wrong class or go out of range.
cm_labels = ["Negative", "Positive"]

mean_accuracy = sk_metrics.accuracy_score(
    merged_df["sample_type"],
    merged_df["response_prediction"],
)

mean_cm = sk_metrics.confusion_matrix(
    merged_df["sample_type"],
    merged_df["response_prediction"],
    labels=cm_labels,
    normalize="true",
)

# With normalize="true" each row sums to 1, so the diagonal is the fraction
# of each true class that was predicted correctly.
negative_sample_acc = mean_cm[0, 0]
positive_sample_acc = mean_cm[1, 1]

In [None]:
# Heatmap of the row-normalized confusion matrix (rows: true, columns: predicted).
# plt.subplots() always starts a new figure, so this never draws on top of a
# figure left current by an earlier cell.
fig, ax = plt.subplots()

sns.heatmap(
    data=mean_cm,
    annot=True,
    ax=ax,
    vmin=0.0,
    vmax=1.0,
    cmap="coolwarm",
    # Must match the confusion-matrix label order (Negative, then Positive).
    xticklabels=["Negative", "Positive"],
    yticklabels=["Negative", "Positive"],
)

ax.set_xlabel("Predicted Label")
ax.set_ylabel("True Label")

## Plot accuracy by sample length & type

In [None]:
# Two stacked panels sharing the x-axis: a sample-length histogram on top and
# per-length mean accuracy (split by sample type) below, with dashed reference
# lines for the per-class and overall accuracies.
fig = plt.figure(figsize=(6, 5))
gs = gridspec.GridSpec(2, 1, height_ratios=[1, 3], figure=fig)

ax0 = fig.add_subplot(gs[0])
ax1 = fig.add_subplot(gs[1], sharex=ax0)

sns.histplot(
    data=merged_df,
    x="sample_length",
    ax=ax0,
    bins=30,
    color="gray",
)
ax0.set_yscale("log")

# `correct` is boolean, so its mean at each length is the accuracy there.
sns.lineplot(
    data=merged_df,
    x="sample_length",
    y="correct",
    hue="sample_type",
    ax=ax1,
    style="sample_type",
    palette={"Positive": "orange", "Negative": "purple"},
    markers=['o', 'o'],
    dashes=False,
    alpha=0.7,
)

ax1.set_ylabel("Mean accuracy")
ax1.set_xlabel("Sample length")


def _draw_reference_accuracy(ax, value, color, label_offset):
    """Draw a dashed horizontal line at `value` and label it near the right edge.

    The label uses the axes' y-axis transform: x is in axes coordinates
    (0.97 = just inside the right edge) and y is in data coordinates. That
    places the text just above the line it labels, whatever the y-limits are.
    """
    ax.axhline(value, color=color, linestyle="--")
    ax.text(
        x=0.97,
        y=value + label_offset,
        s=f"{value:.2f}",
        color=color,
        transform=ax.get_yaxis_transform(),
        horizontalalignment="right",
    )


# Small per-label vertical nudges (accuracy units), presumably chosen to keep
# the three labels from overlapping; tune if they collide for other data.
_draw_reference_accuracy(ax1, positive_sample_acc, "orange", 0.01)
_draw_reference_accuracy(ax1, negative_sample_acc, "purple", 0.04)
_draw_reference_accuracy(ax1, mean_accuracy, "black", 0.025)

ax1.get_legend().set_title("Sample type")

# Hide the top panel's x label and tick labels; the shared axis is labeled below.
ax0.set_xlabel("")
ax0.tick_params(axis='x', which='both', bottom=True, top=False, labelbottom=False)