In [None]:
%load_ext autoreload
%autoreload 2

import pandas as pd
import eval_extraction
import openai
import extraction
import numpy as np
import matplotlib.pyplot as plt
import os
from pathlib import Path

# OpenAI credentials: prefer the OPENAI_API_KEY environment variable and only fall
# back to a key file (default ~/.OPENAI_KEY, overridable via OPENAI_KEY_PATH), so the
# notebook no longer depends on one user's absolute home directory.
# Path.read_text opens *and closes* the file, unlike a bare open(...).read().
OPENAI_KEY_PATH = Path(
    os.environ.get("OPENAI_KEY_PATH", Path.home() / ".OPENAI_KEY")
).expanduser()
openai.api_key = os.environ.get("OPENAI_API_KEY") or OPENAI_KEY_PATH.read_text().strip()

# load the data (paper text in `paper___raw_text`, human labels in `participants___*`)
# NOTE: pickle can execute arbitrary code -- only load this trusted local file.
df = pd.read_pickle("../data/data_clean.pkl")

# LLM used for extraction (alternative: "gpt-3.5-turbo-0613")
checkpoint = 'gpt-4-0613'

# Make predictions

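Columns produced by the LLM extraction (`extraction.extract_nums_df`):
- Sex counts: `num_male`, `num_female`, `num_total`, with evidence spans `num_male_evidence_span`, `num_female_evidence_span`, `num_total_evidence_span`
- Race/ethnicity counts: `num_white`, `num_black`, `num_latino`, `num_asian`, with a single shared `race_evidence_span`

Human-labeled target columns they are compared against:
- `participants___male`, `participants___female`, `participants___total`
- `participants___white`, `participants___black`, `participants___latino`, `participants___asian`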

In [None]:
# Human-labeled participant-count columns the LLM extraction is scored against.
gt_cols = [
    "participants___male",
    "participants___female",
    "participants___total",
    "participants___white",
    "participants___black",
    "participants___latino",
    "participants___asian",
]

# Only send papers to the LLM when they have raw text AND at least one positive label.
has_text = df["paper___raw_text"].notna()
has_positive_label = df[gt_cols].gt(0).any(axis=1)
idxs = has_text & has_positive_label
texts = df.loc[idxs, "paper___raw_text"].tolist()

# Per-model value passed as `subset_len_tokens` to the extractor
# (presumably the token length of the paper-text subset sent per request -- confirm
# in extraction.extract_nums_df). An unknown checkpoint raises KeyError.
subset_len_tokens_by_checkpoint = {"gpt-4-0613": 4750, "gpt-3.5-turbo-0613": 3000}
extractions = extraction.extract_nums_df(
    texts,
    verbose=False,
    checkpoint=checkpoint,
    subset_len_tokens=subset_len_tokens_by_checkpoint[checkpoint],
)

# Write each extracted column back onto the selected rows of df (positional alignment).
for col_name, col_values in extractions.items():
    df.loc[idxs, col_name] = col_values.values

### Evaluate
An extracted number counts as correct when it is within 1 of the human-labeled value. Predictions that cannot be parsed as numbers are set to NaN before scoring.

In [None]:
import os

# Map each LLM-extracted column to the human-labeled column it is scored against.
preds_col_to_gt_col_dict = {
    "num_male": "participants___male",
    "num_female": "participants___female",
    "num_total": "participants___total",
    "num_white": "participants___white",
    "num_black": "participants___black",
    "num_asian": "participants___asian",
    "num_latino": "participants___latino",
}

# Pass 1: blank out predictions that eval_extraction.str_is_parsable rejects.
# This must finish for EVERY column -- in particular num_total -- before pass 2,
# because pass 2 reads df["num_total"] for all other columns. (Previously num_male
# and num_female were converted before num_total had been cleaned.)
for pred_col in preds_col_to_gt_col_dict:
    unparsable = ~df[pred_col].apply(eval_extraction.str_is_parsable)
    df.loc[unparsable, pred_col] = np.nan

# Pass 2: rewrite every non-total column using the (now cleaned) num_total
# (presumably turning percentage answers into counts -- see eval_extraction).
for pred_col in preds_col_to_gt_col_dict:
    if pred_col != "num_total":
        df[pred_col] = eval_extraction.convert_percentages_when_total_is_known(
            df[pred_col], df["num_total"]
        )

# Score predictions against labels (within-1 match) and save per-checkpoint metrics.
mets = eval_extraction.compute_metrics_within_1(
    df, preds_col_to_gt_col_dict=preds_col_to_gt_col_dict
)
os.makedirs("../results/llm", exist_ok=True)  # to_pickle does not create parent dirs
mets.to_pickle(f'../results/llm/extraction_{checkpoint}.pkl')

### Merge results
Load the saved GPT-3.5 and GPT-4 metrics and compare them side by side, one row per target.

In [3]:
# Load the saved metrics for each model, keyed by checkpoint name.
# A comprehension variable is used on purpose: the old `for checkpoint in ...` loop
# overwrote the global `checkpoint` used by the prediction/evaluation cells, silently
# leaving it set to the last model after this cell ran.
mets_dict = {
    ckpt: pd.read_pickle(f'../results/llm/extraction_{ckpt}.pkl')
    for ckpt in ['gpt-3.5-turbo-0613', 'gpt-4-0613']
}

In [31]:
# One row per target with GPT-3.5 (`_3.5`) and GPT-4 (`_4`) metrics side by side.
# The raw `n_correct*` counts and GPT-4's `n_gt`/`n_pred` columns are dropped to keep
# the table compact.
results = pd.merge(
    mets_dict["gpt-3.5-turbo-0613"],
    mets_dict["gpt-4-0613"],
    on='target',
    suffixes=("_3.5", "_4"),
).pipe(
    lambda merged: merged.drop(
        columns=[
            col
            for col in merged.columns
            if 'n_correct' in col or col in ('n_gt_4', 'n_pred_4')
        ]
    )
)

In [32]:
# Targets ranked by GPT-4 recall, best first.
results.sort_values('recall_4', ascending=False)

Unnamed: 0,target,n_gt_3.5,n_pred_3.5,recall_3.5,precision_3.5,recall_4,precision_4
2,participants___total,459,363,0.51,0.65,0.67,0.81
0,participants___male,374,246,0.24,0.36,0.51,0.84
1,participants___female,379,247,0.22,0.34,0.51,0.84
6,participants___asian,37,1,0.0,0.0,0.46,0.81
5,participants___latino,54,2,0.04,1.0,0.33,0.69
4,participants___black,78,53,0.14,0.21,0.32,0.6
3,participants___white,93,63,0.11,0.16,0.28,0.53
