In [42]:
def calculate_accuracy_per_prompt(df, prompt_col, pred_col, gt_col):
    """
    Calculate accuracy per prompt for a given DataFrame.

    A row counts as correct when its predicted label equals its ground-truth
    label; the accuracy of a prompt is the mean of that 0/1 indicator over
    all rows sharing the prompt value. The input DataFrame is neither
    modified nor copied.

    Parameters:
    df (pd.DataFrame): The input DataFrame.
    prompt_col (str or list of str): Column name(s) for the prompt to group by.
    pred_col (str): Column name for the predicted relevance label.
    gt_col (str): Column name for the ground truth relevance label.

    Returns:
    pd.DataFrame: One row per distinct prompt value (in sorted group order),
        holding the prompt column(s) followed by an 'accuracy' column.

    Raises:
    KeyError: If one of the named columns is not present in df.

    Note:
    The score column is always called 'accuracy', so prompt_col must not
    itself be named 'accuracy'.
    """
    # Keep the 0/1 indicator as a standalone Series rather than writing a
    # temporary 'correct' column into a copy of df. That avoids duplicating
    # the whole frame and cannot clobber an input column that happens to be
    # called 'correct'.
    is_correct = (df[pred_col] == df[gt_col]).astype(int)

    # Group by the column values themselves; a list of names becomes a list
    # of key Series so multi-column grouping keeps working.
    if isinstance(prompt_col, list):
        group_keys = [df[col] for col in prompt_col]
    else:
        group_keys = df[prompt_col]

    return (
        is_correct.groupby(group_keys)
        .mean()
        .rename('accuracy')
        .reset_index()
    )

### Load from DB

In [43]:
import pandas as pd
from pathlib import Path

from src.utils import load_yaml
from src.database import get_rows

# All paths are resolved relative to the parent of the working directory.
base_path = Path("../")

# Project settings come from the YAML file sitting directly under base_path.
config = load_yaml(base_path.joinpath("config.yaml"))

# The data directory and the results file inside it are both named by config
# entries; db_path is what the get_rows calls below receive.
data_path = base_path.joinpath(config["data_dir"])
db_path = data_path.joinpath(config["results_file"])

### Old: Covid 3-class

In [44]:
# df_covid_results_gpt_4o = pd.DataFrame(get_rows(
#     db_path=db_path,
#     file_name="covid.csv",
#     model="gpt-4o"
# ))
#
# df_covid_results_gpt_4o.to_csv(data_path / "covid_gpt_4o.csv", index=False)

In [45]:
# df_covid_results_sonnet_35 = pd.DataFrame(get_rows(
#     db_path=db_path,
#     file_name="covid.csv",
#     model="anthropic.claude-3-5-sonnet-20240620-v1:0"
# ))
#
# df_covid_results_sonnet_35.to_csv(data_path / "covid_sonnet_35.csv", index=False)

### Old: Touche 3-class

In [46]:
# df_touche_results_gpt_4o = pd.DataFrame(get_rows(
#     db_path=db_path,
#     file_name="touche.csv",
#     model="gpt-4o"
# ))
#
# df_touche_results_gpt_4o.to_csv(data_path / "touche_gpt_4o.csv", index=False)

In [47]:
# df_touche_results_sonnet_35 = pd.DataFrame(get_rows(
#     db_path=db_path,
#     file_name="touche.csv",
#     model="anthropic.claude-3-5-sonnet-20240620-v1:0"
# ))
#
# df_touche_results_sonnet_35.to_csv(data_path / "touche_sonnet_35.csv", index=False)

### Covid 2-class

In [48]:
# Build a DataFrame from whatever get_rows returns for the 2-class Covid file
# and the gpt-4o model.
df_covid_2_class_results_gpt_4o = pd.DataFrame(
    get_rows(db_path=db_path, file_name="covid_2_class.csv", model="gpt-4o")
)

# Keep a CSV snapshot of those rows in the data directory (index omitted).
df_covid_2_class_results_gpt_4o.to_csv(
    data_path / "covid_2_class_gpt_4o.csv",
    index=False,
)

# Share of rows per prompt where relevance_label equals relevance_label_gt.
accuracy_df = calculate_accuracy_per_prompt(
    df=df_covid_2_class_results_gpt_4o,
    prompt_col="prompt",
    pred_col="relevance_label",
    gt_col="relevance_label_gt",
)

print(accuracy_df)

                   prompt  accuracy
0   classify_2_class_long  0.760959
1  classify_2_class_short  0.741672


In [49]:
# Build a DataFrame from whatever get_rows returns for the 2-class Covid file
# and the Claude 3.5 Sonnet model id.
df_covid_2_class_results_sonnet_35 = pd.DataFrame(
    get_rows(
        db_path=db_path,
        file_name="covid_2_class.csv",
        model="anthropic.claude-3-5-sonnet-20240620-v1:0",
    )
)

# Keep a CSV snapshot of those rows in the data directory (index omitted).
df_covid_2_class_results_sonnet_35.to_csv(
    data_path / "covid_2_class_sonnet_35.csv",
    index=False,
)

# Share of rows per prompt where relevance_label equals relevance_label_gt.
accuracy_df = calculate_accuracy_per_prompt(
    df=df_covid_2_class_results_sonnet_35,
    prompt_col="prompt",
    pred_col="relevance_label",
    gt_col="relevance_label_gt",
)

print(accuracy_df)

                   prompt  accuracy
0   classify_2_class_long  0.791935
1  classify_2_class_short  0.720631


### Touche 2-class

In [50]:
# Build a DataFrame from whatever get_rows returns for the 2-class Touche file
# and the gpt-4o model.
df_touche_2_class_results_gpt_4o = pd.DataFrame(
    get_rows(db_path=db_path, file_name="touche_2_class.csv", model="gpt-4o")
)

# Keep a CSV snapshot of those rows in the data directory (index omitted).
df_touche_2_class_results_gpt_4o.to_csv(
    data_path / "touche_2_class_gpt_4o.csv",
    index=False,
)

# Share of rows per prompt where relevance_label equals relevance_label_gt.
accuracy_df = calculate_accuracy_per_prompt(
    df=df_touche_2_class_results_gpt_4o,
    prompt_col="prompt",
    pred_col="relevance_label",
    gt_col="relevance_label_gt",
)

print(accuracy_df)

                   prompt  accuracy
0   classify_2_class_long  0.849142
1  classify_2_class_short  0.822945


In [51]:
# Build a DataFrame from whatever get_rows returns for the 2-class Touche file
# and the Claude 3.5 Sonnet model id.
df_touche_2_class_results_sonnet_35 = pd.DataFrame(
    get_rows(
        db_path=db_path,
        file_name="touche_2_class.csv",
        model="anthropic.claude-3-5-sonnet-20240620-v1:0",
    )
)

# Keep a CSV snapshot of those rows in the data directory (index omitted).
df_touche_2_class_results_sonnet_35.to_csv(
    data_path / "touche_2_class_sonnet_35.csv",
    index=False,
)

# Share of rows per prompt where relevance_label equals relevance_label_gt.
accuracy_df = calculate_accuracy_per_prompt(
    df=df_touche_2_class_results_sonnet_35,
    prompt_col="prompt",
    pred_col="relevance_label",
    gt_col="relevance_label_gt",
)

print(accuracy_df)

                   prompt  accuracy
0   classify_2_class_long  0.854110
1  classify_2_class_short  0.842367
