In [None]:
%load_ext autoreload
%autoreload 2
from os.path import join
from tqdm import tqdm
import pandas as pd
import numpy as np
import clin.llm
import clin.parse
import openai
openai.api_key_path = '/home/chansingh/.OPENAI_KEY'
from typing import List
results_dir = '../results/'
from clin.config import PATH_REPO
from clin.modules.ebm import extract, omission, prune, evidence
import clin.eval.ebm
import clin.eval.eval
from clin.modules import ebm
import joblib
import imodelsx.process_results

# Load the experiment results found under `results_dir` (cache bypassed),
# then keep only the rows whose dataset is 'ebm'.
r = imodelsx.process_results.get_results_df(results_dir, use_cached=False)
r = r[r.dataset_name == 'ebm']

# Select the single run inspected by the rest of the notebook: 5 shots with
# the text-davinci-003 checkpoint. The two boolean masks are combined with
# `&` (element-wise AND), the idiomatic form of the previous `*` product.
is_selected_run = (r.n_shots == 5) & (r.checkpoint == 'text-davinci-003')
if not is_selected_run.any():
    # `.iloc[0]` on an empty frame raises an opaque "positional indexer is
    # out-of-bounds" IndexError; keep the exception type but say what is missing.
    raise IndexError(
        f"no 'ebm' result with n_shots == 5 and checkpoint == 'text-davinci-003' "
        f"found in {results_dir!r}"
    )
row = r[is_selected_run].iloc[0]

In [None]:
# show metrics
# Metric entries of `row` are the keys containing "___". Below, the part
# before "___" becomes the table's columns (f1 / precision / recall) and the
# part after it becomes the "Verifiers" index.
metric_keys = [k for k in row.keys() if "___" in k]
row_df = pd.DataFrame(pd.Series({k: row[k] for k in metric_keys}).round(3)).T

# Every column of row_df already contains "___", so take them all and replace
# the flat names with a two-level MultiIndex obtained by splitting on "___".
rc = row_df.copy()
rc.columns = pd.MultiIndex.from_tuples(
    [tuple(name.split("___")) for name in rc.columns]
)

# Long format (one line per metric/verifier pair; the value sits in column 0),
# then pivot so that verifiers are rows and metrics are columns.
rc = rc.T.reset_index().rename(columns={"level_0": "", "level_1": "Verifiers"})
rc = rc.pivot_table(index="Verifiers", columns="", values=0).round(3)

# Display names; the dict order also fixes the column / row order of the table.
cols = {"f1": "F1", "precision": "Precision", "recall": "Recall"}
rows = {
    "original": "Original",
    "ov": "Omission",
    "pv": "Prune",
    "ov_pv": "Omission + Prune",
    "ov_pv_ev": "Omission + Prune + Evidence",
}

table = rc[list(cols)].rename(columns=cols).loc[list(rows)].rename(index=rows)
table.style.format(precision=3).background_gradient(cmap="Blues")

## Visualize an example

In [None]:
# Load the cleaned EBM interventions and keep the first 100 documents.
df = joblib.load(
    join(PATH_REPO, 'data', 'ebm', 'ebm_interventions_cleaned.pkl')
).iloc[:100]

# Reorder the documents with a fixed-seed shuffle of their positions.
# NOTE(review): presumably this reproduces the example order used when the
# lists stored in `row` were generated — confirm against the experiment code.
nums = list(range(len(df)))
np.random.default_rng(seed=13).shuffle(nums)
dfe = df.iloc[nums]

# compare lists: l1 holds the (sorted) annotated interventions per document,
# l2 the lists stored in the selected results row under 'list_ov_pv'.
l1 = list(map(sorted, dfe["interventions"].values.tolist()))
l2 = row['list_ov_pv']
for idx, annotated in enumerate(l1):
    extracted = l2[idx]
    annotated_proc, extracted_proc = clin.eval.ebm.process_ebm_lists(
        annotated, extracted
    )
    # Only show documents with more than 3 distinct processed annotations.
    if len(set(annotated_proc)) <= 3:
        continue
    print(dfe.iloc[idx]['doc'])
    print(idx)
    print(sorted(annotated))
    print(sorted(extracted))
    print()

## Analyze errors

In [None]:
# Load the cleaned EBM interventions and keep the first 100 documents.
df = joblib.load(
    join(PATH_REPO, 'data', 'ebm', 'ebm_interventions_cleaned.pkl')
).iloc[:100]

# Reorder the documents with a fixed-seed shuffle of their positions.
# NOTE(review): presumably this reproduces the example order used when the
# lists stored in `row` were generated — confirm against the experiment code.
nums = list(range(len(df)))
np.random.default_rng(seed=13).shuffle(nums)
dfe = df.iloc[nums]

# compare lists: l1 holds the (sorted) annotated interventions per document,
# l2 the lists stored in the selected results row under 'list_ov_pv'.
l1 = list(map(sorted, dfe["interventions"].values.tolist()))
l2 = row['list_ov_pv']
for idx, annotated in enumerate(l1):
    extracted = l2[idx]
    annotated_proc, extracted_proc = clin.eval.ebm.process_ebm_lists(
        annotated, extracted
    )
    # Skip documents whose processed lists contain the same set of items;
    # what remains are the mismatches to inspect.
    if set(annotated_proc) == set(extracted_proc):
        continue
    print(dfe.iloc[idx]['doc'])
    print(idx)
    print(sorted(annotated))
    print(sorted(extracted))
    print()

In [None]:
# look at validation data
# Load the cleaned EBM interventions and keep everything after the first 100
# documents, then print each document followed by its interventions as bullets.
dfv = joblib.load(join(PATH_REPO, 'data', 'ebm', 'ebm_interventions_cleaned.pkl')).iloc[100:]
for i in range(len(dfv)):
    # Deliberately not named `row`: that name holds the selected results row
    # read by the metrics and comparison cells (`row[k]`, `row['list_ov_pv']`).
    # Overwriting it here would break those cells when they are re-run.
    val_row = dfv.iloc[i]
    print(val_row['doc'])
    print(clin.parse.list_to_bullet_str(val_row['interventions']))
    print()