In [None]:
%load_ext autoreload
%autoreload 2
from os.path import join
from tqdm import tqdm
import pandas as pd
import numpy as np
import clin.eval.med_status
import datasets
import clin.parse
import openai
# Location of the OpenAI key file. The OPENAI_API_KEY_PATH environment variable
# takes precedence so the notebook is not tied to one user's home directory;
# the original location is kept as the fallback so existing setups are unaffected.
import os
openai.api_key_path = os.environ.get('OPENAI_API_KEY_PATH', '/home/chansingh/.OPENAI_KEY')
results_dir = '../results/'
# results_dir = '../results_tmp/'
from clin.config import PATH_REPO
import imodelsx.process_results
# Load the experiment results found under results_dir into one DataFrame.
# NOTE(review): use_cached=False presumably forces a fresh read instead of a cached table — confirm against imodelsx.
r = imodelsx.process_results.get_results_df(results_dir, use_cached=False)
# Drop the two verifier bookkeeping columns if they are present.
r = r[[col for col in r.columns if col not in ('checkpoint_verify', 'role_verify')]]
# Keep only the runs where use_megaprompt is 0.
r = r[r.use_megaprompt == 0]

def viz_blues(df):
    """Return a Styler for `df` with 3-digit precision and a 'Blues' background gradient."""
    styler = df.style.format(precision=3)
    return styler.background_gradient(cmap='Blues')

In [None]:
# medication_status
# Restrict the results to the medication_status runs, load the matching
# test split of the clinical-ie dataset, and pass both to add_status_eval.
r_med = r.loc[r.dataset_name == 'medication_status']
df = pd.DataFrame.from_dict(
    datasets.load_dataset('mitclinicalml/clinical-ie', 'medication_status')['test']
)
r_med = clin.eval.med_status.add_status_eval(r_med, df)

# Average results

In [None]:
# medication_status
# Select the medication_status runs, run add_status_eval against the
# clinical-ie test split, then average the results over seeds.
r_med = r[r.dataset_name == "medication_status"]
df = pd.DataFrame.from_dict(
    datasets.load_dataset("mitclinicalml/clinical-ie", "medication_status")["test"]
)
r_med = clin.eval.med_status.add_status_eval(r_med, df)
r_med = imodelsx.process_results.average_over_seeds(
    r_med, experiment_filename=join(PATH_REPO, "experiments", "eval_model.py")
)

# compare values for a single row
# Boolean masks are combined with `&` (element-wise AND) rather than `*`.
row_df = r_med[
    (r_med.n_shots == 5) & (r_med.checkpoint == "text-davinci-003")
].reset_index()
rc = row_df[[c for c in row_df.columns if "___" in c]]
# Metric columns are named "<metric>___<verifier>"; split each name on "___"
# to obtain a two-level column index (metric, verifier).
rc = rc.set_axis(
    pd.MultiIndex.from_tuples([tuple(c.split("___")) for c in rc.columns]), axis=1
)
# Transpose so that each (metric, verifier) pair becomes a record, then pivot
# into a verifier x metric table. Only column 0 (the first selected row) is used.
rc = rc.T.reset_index()
rc = (
    rc.rename(columns={"level_0": "", "level_1": "Verifiers"})
    .pivot_table(index="Verifiers", columns="", values=0)
    .round(3)
)

# Raw metric name -> column header shown in the table.
cols = {
    "f1": "F1 (Med)",
    "precision": "Precision (Med)",
    "recall": "Recall (Med)",
    "status_f1_macro_cond": "F1 (Med status)",
}
# Raw verifier key -> row label shown in the table (also fixes the row order).
rows = {
    "original": "Original",
    "ov": "Omission",
    "pv": "Prune",
    "ov_pv": "Omission + Prune",
    "sv": "Omission + Prune + Evidence",
}
# The error rows carry the same keys with an "_err" suffix and reuse the labels,
# so rt_med and rt_med_sem end up with identical index and columns.
err_rows = {key + '_err': label for key, label in rows.items()}
rc_named = rc[list(cols)].rename(columns=cols)
rt_med = rc_named.loc[list(rows)].rename(index=rows)
rt_med_sem = rc_named.loc[list(err_rows)].rename(index=err_rows)
viz_blues(rt_med)

In [None]:
# ebm
# Select the ebm runs and average the results over seeds.
r_ebm = r[r.dataset_name == "ebm"]
r_ebm = imodelsx.process_results.average_over_seeds(
    r_ebm, experiment_filename=join(PATH_REPO, "experiments", "eval_model.py")
)

# compare values for a single row
# Boolean masks are combined with `&` (element-wise AND) rather than `*`.
# .iloc[0] takes the first matching run and raises IndexError if none matches.
row = r_ebm[(r_ebm.n_shots == 5) & (r_ebm.checkpoint == "text-davinci-003")].iloc[0]

# show metrics
# Keep only the "<metric>___<verifier>" entries of the row, as a one-row frame.
row_df = pd.DataFrame(
    pd.Series({k: row[k] for k in row.keys() if "___" in k}).round(3)
).T
# Split each name on "___" to obtain a two-level column index (metric, verifier).
rc = row_df.set_axis(
    pd.MultiIndex.from_tuples([tuple(c.split("___")) for c in row_df.columns]), axis=1
)
# Transpose so that each (metric, verifier) pair becomes a record, then pivot
# into a verifier x metric table.
rc = rc.T.reset_index()
rc = (
    rc.rename(columns={"level_0": "", "level_1": "Verifiers"})
    .pivot_table(index="Verifiers", columns="", values=0)
    .round(3)
)

# Raw metric name -> column header shown in the table.
cols = {
    "f1": "F1 (Arms)",
    "precision": "Precision (Arms)",
    "recall": "Recall (Arms)",
}
# Raw verifier key -> row label shown in the table (also fixes the row order).
rows = {
    "original": "Original",
    "ov": "Omission",
    "pv": "Prune",
    "ov_pv": "Omission + Prune",
    "ov_pv_ev": "Omission + Prune + Evidence",
}
# The error rows carry the same keys with an "_err" suffix and reuse the labels,
# so rt_ebm and rt_ebm_sem end up with identical index and columns.
err_rows = {key + '_err': label for key, label in rows.items()}
rc_named = rc[list(cols)].rename(columns=cols)
rt_ebm = rc_named.loc[list(rows)].rename(index=rows)
rt_ebm_sem = rc_named.loc[list(err_rows)].rename(index=err_rows)
viz_blues(rt_ebm)

# Visualize

In [None]:
# Put the ebm columns next to the medication columns (joined on the shared
# row labels) and leave out the medication-status F1 column before displaying.
rt = rt_med.join(rt_ebm).drop(columns='F1 (Med status)')
display(viz_blues(rt))

In [None]:
# add error bars
# Each cell becomes the string "<value>\err{<error>}". The strings are built in
# new frames, so rt_med / rt_ebm stay numeric and re-running this cell does not
# append the error a second time. The backslash is escaped explicitly because
# "\e" is not a valid Python escape sequence.
def _append_err(value_df, err_df):
    """Return a string frame pairing each entry of `value_df` with the matching entry of `err_df`.

    Both frames are expected to share the same index and columns; entries are
    matched by label.
    """
    return value_df.astype(str) + "\\err{" + err_df.astype(str) + "}"


rt_ebm_tex = _append_err(rt_ebm, rt_ebm_sem)
rt_med_tex = _append_err(rt_med, rt_med_sem)

# Same layout as the displayed table: medication columns, then ebm columns,
# without the medication-status F1 column.
rt = rt_med_tex.join(rt_ebm_tex).drop(columns='F1 (Med status)')
print(rt.style.format(precision=3).to_latex(hrules=True))