In [20]:
%load_ext autoreload
%autoreload 2
import os
import matplotlib.pyplot as plt
import seaborn as sns
from os.path import join
from tqdm import tqdm
import pandas as pd
import imodelsx.process_results
import sys
import datasets
import numpy as np
import clin.eval
import clin.verifiers.deduplicate
import clin.verifiers.evidence
import clin.llm
import clin.parse
from collections import defaultdict
import openai
# Location of the file holding the OpenAI API key. Set the OPENAI_KEY_PATH
# environment variable to point somewhere else; the fallback is the path this
# notebook was originally written against.
openai.api_key_path = os.environ.get('OPENAI_KEY_PATH', '/home/chansingh/.OPENAI_KEY')

# Put the experiments directory on the import path. The membership check keeps
# re-running this cell from appending the same entry again.
if '../experiments/' not in sys.path:
    sys.path.append('../experiments/')
# Directory passed to the results loader later in the notebook.
results_dir = '../results/'

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [46]:
# Experiment results, one row per run (loaded with use_cached=True).
r = imodelsx.process_results.get_results_df(results_dir, use_cached=True)

# Evaluation data: the 'medication_status' config of the clinical-ie dataset.
dset = datasets.load_dataset('mitclinicalml/clinical-ie', 'medication_status')
df_val = pd.DataFrame.from_dict(dset['validation'])
df = pd.DataFrame.from_dict(dset['test'])

# Reorder the test rows with a fixed-seed shuffle of their positions.
# NOTE(review): later cells pair dfe.iloc[i] with row['resps'][i], so this order
# presumably has to match the one used when the responses were generated -- confirm
# before touching the seed or the shuffle call.
nums = list(range(len(df)))
np.random.default_rng(seed=13).shuffle(nums)
dfe = df.iloc[nums]

# Store the medication-status metrics as two extra columns on the results frame.
accs_cond, f1s_macro_cond = clin.eval.eval_medication_status(dfe, r)
r['acc_cond'] = accs_cond
r['f1_macro_cond'] = f1s_macro_cond

# Mean of each metric per (checkpoint, n_shots), rendered as a shaded table.
metric_cols = ['f1', 'recall', 'precision', 'acc_cond', 'f1_macro_cond']
mean_metrics = r.groupby(['checkpoint', 'n_shots'])[metric_cols].mean()
mean_metrics.style.format(precision=3).background_gradient(cmap='Blues')

  0%|          | 0/2 [00:00<?, ?it/s]

100%|██████████| 6/6 [00:00<00:00, 391.11it/s]


Unnamed: 0_level_0,Unnamed: 1_level_0,f1,recall,precision,acc_cond,f1_macro_cond
checkpoint,n_shots,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
gpt-3.5-turbo,1,0.818,0.845,0.812,0.79,0.064
gpt-3.5-turbo,5,0.871,0.873,0.888,0.861,0.175
gpt-4-0314,1,0.82,0.891,0.787,0.79,0.056
gpt-4-0314,5,0.87,0.894,0.864,0.921,0.212
text-davinci-003,1,0.896,0.926,0.884,0.888,0.394
text-davinci-003,5,0.919,0.927,0.925,0.925,0.626


### Test verifiers

In [30]:
# Baseline predictions: take the first results row for text-davinci-003 with 5 shots.
# The two boolean masks are combined with `&` (element-wise AND), the operator
# pandas documents for boolean indexing, instead of arithmetic multiplication.
row = r[(r.n_shots == 5) & (r.checkpoint == 'text-davinci-003')].iloc[0]
n = len(dfe)
# Parse the run's first n raw responses, one per row of dfe.
med_status_dict_list_orig = [clin.parse.parse_response_medication_list(row['resps'][i]) for i in range(n)]
# LLM handle that the verifier cells pass along as `llm=`.
llm = clin.llm.get_llm('text-davinci-003')

In [21]:
# Evidence verifier: called once per snippet with the snippet text, the raw
# bulleted response and the LLM. Each call returns a pair of dicts; the second
# is indexed below with the keys of the first.
ev = clin.verifiers.evidence.EvidenceVerifier(n_shots_neg=2, n_shots_pos=0)
med_status_dict_list_ev = []
med_evidence_dict_list_ev = []
for idx in tqdm(range(n)):
    statuses, evidence = ev(
        snippet=dfe.iloc[idx]['snippet'],
        bulleted_str=row['resps'][idx],
        llm=llm,
    )
    med_status_dict_list_ev.append(statuses)
    med_evidence_dict_list_ev.append(evidence)

# Pruned variant: keep only the entries whose evidence is not the literal
# string 'no evidence'.
med_status_dict_list_ev_pruned = [
    {med: status for med, status in statuses.items() if evidence[med] != 'no evidence'}
    for statuses, evidence in zip(med_status_dict_list_ev, med_evidence_dict_list_ev)
]

100%|██████████| 100/100 [00:00<00:00, 8132.59it/s]


In [33]:
# Deduplication verifier: one call per snippet on the same raw responses.
# The recorded run of this cell took several minutes (about 2.7 s per item).
dv = clin.verifiers.deduplicate.DeduplicateVerifier()
med_status_dict_list_dv = []
for idx in tqdm(range(n)):
    deduped = dv(dfe.iloc[idx]['snippet'], bulleted_str=row['resps'][idx], llm=llm)
    med_status_dict_list_dv.append(deduped)

100%|██████████| 100/100 [04:28<00:00,  2.68s/it]


In [45]:
# Prediction sets to compare; the keys become the row labels of the table.
PREDS_DICT = {
    'original': med_status_dict_list_orig,
    'evidence': med_status_dict_list_ev,
    'evidence_pruned': med_status_dict_list_ev_pruned,
    'deduplicate': med_status_dict_list_dv,
}

# One list per metric name; entry j holds the mean of that metric for the j-th
# prediction set, in PREDS_DICT insertion order.
mets_dict = defaultdict(list)
for preds in PREDS_DICT.values():
    mets_single = clin.eval.calculate_metrics(preds, dfe, verbose=False)
    for met_name, met_values in mets_single.items():
        mets_dict[met_name].append(np.mean(met_values))

# NOTE: this rebinds `df` (previously the test split) to the metrics table.
df = pd.DataFrame.from_dict(mets_dict).round(3)
df.index = list(PREDS_DICT)
df.style.format(precision=3).background_gradient(cmap='Blues')

Unnamed: 0,precision,recall,f1
original,0.925,0.927,0.919
evidence,0.925,0.924,0.918
evidence_pruned,0.928,0.924,0.92
deduplicate,0.932,0.828,0.865
