In [1]:
%load_ext autoreload
%autoreload 2
import os
import matplotlib.pyplot as plt
import seaborn as sns
from os.path import join
from tqdm import tqdm
import pandas as pd
import imodelsx.process_results
import sys
import datasets
import numpy as np
from copy import deepcopy
import clin.eval
import clin.verifiers.prune
import clin.verifiers.evidence
import clin.verifiers.omission
import clin.llm
import clin.parse
from collections import defaultdict
import openai
# File containing the OpenAI API key. Honour an OPENAI_API_KEY_PATH override
# from the environment so the notebook is not tied to one user's home
# directory; the original location remains the fallback.
openai.api_key_path = os.environ.get(
    'OPENAI_API_KEY_PATH', '/home/chansingh/.OPENAI_KEY')

# Put the experiments directory on the import path. Guarded so that re-running
# this cell does not keep appending duplicate entries.
if '../experiments/' not in sys.path:
    sys.path.append('../experiments/')
results_dir = '../results/'

In [2]:
# experiment results, read from results_dir (cached copy allowed)
r = imodelsx.process_results.get_results_df(results_dir, use_cached=True)

# evaluation data: 'medication_status' config of the clinical-ie dataset
dset = datasets.load_dataset('mitclinicalml/clinical-ie', 'medication_status')
df_val = pd.DataFrame.from_dict(dset['validation'])
df = pd.DataFrame.from_dict(dset['test'])

# dfe is the test split with its rows reordered by a fixed-seed shuffle
nums = list(range(len(df)))
np.random.default_rng(seed=13).shuffle(nums)
dfe = df.iloc[nums]

# medication-status evaluation: attach the two returned metrics to r as columns
accs_cond, f1s_macro_cond = clin.eval.eval_medication_status(dfe, r)
r['acc_cond'], r['f1_macro_cond'] = accs_cond, f1s_macro_cond

# mean of each metric per (checkpoint, n_shots), rendered as a shaded table
r.groupby(['checkpoint', 'n_shots'])[
    ['f1', 'recall', 'precision', 'acc_cond', 'f1_macro_cond']
].mean().style.format(precision=3).background_gradient(cmap='Blues')

  0%|          | 0/2 [00:00<?, ?it/s]

100%|██████████| 6/6 [00:00<00:00, 388.55it/s]


Unnamed: 0_level_0,Unnamed: 1_level_0,f1,recall,precision,acc_cond,f1_macro_cond
checkpoint,n_shots,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
gpt-3.5-turbo,1,0.835,0.847,0.823,0.79,0.064
gpt-3.5-turbo,5,0.887,0.876,0.898,0.861,0.175
gpt-4-0314,1,0.83,0.9,0.771,0.79,0.056
gpt-4-0314,5,0.878,0.903,0.855,0.921,0.212
text-davinci-003,1,0.899,0.932,0.868,0.888,0.394
text-davinci-003,5,0.919,0.921,0.918,0.925,0.626


## Test verifiers

In [3]:
# Baseline predictions: first result row for 5-shot text-davinci-003,
# with each of its first n responses parsed by parse_response_medication_list.
is_baseline_run = (r['n_shots'] == 5) & (r['checkpoint'] == 'text-davinci-003')
row = r[is_baseline_run].iloc[0]
n = len(dfe)
med_status_dict_list_orig = []
for i in range(n):
    med_status_dict_list_orig.append(
        clin.parse.parse_response_medication_list(row['resps'][i]))
llm = clin.llm.get_llm('text-davinci-003')

In [4]:
# Evidence verifier: for every example it returns a (status dict, evidence dict)
# pair, both keyed the same way.
ev = clin.verifiers.evidence.EvidenceVerifier(n_shots_neg=2, n_shots_pos=0)
# alternative configuration:
# ev = clin.verifiers.evidence.EvidenceVerifier(n_shots_neg=0, n_shots_pos=2)
med_status_dict_list_ev, med_evidence_dict_list_ev = [], []
for i in tqdm(range(n)):
    med_status_dict, med_evidence_dict = ev(
        snippet=dfe.iloc[i]['snippet'],
        bulleted_str=row['resps'][i],
        llm=llm,
    )
    med_status_dict_list_ev.append(med_status_dict)
    med_evidence_dict_list_ev.append(med_evidence_dict)

# Pruned variant: drop every entry whose evidence is exactly 'no evidence'.
# (A stricter check, requiring the evidence string to occur in the lowercased
# snippet, was tried and is currently disabled.)
med_status_dict_list_ev_pruned = [
    {
        med: status
        for med, status in status_dict.items()
        if evidence_dict[med] != 'no evidence'
    }
    for status_dict, evidence_dict in zip(
        med_status_dict_list_ev, med_evidence_dict_list_ev)
]

100%|██████████| 100/100 [00:00<00:00, 7506.85it/s]


In [11]:
# Prune verifier run on its own over the baseline responses, one example at a time.
pv = clin.verifiers.prune.PruneVerifier()
med_status_dict_list_pv = []
for i in tqdm(range(n)):
    med_status_dict_list_pv.append(
        pv(dfe.iloc[i]['snippet'], bulleted_str=row['resps'][i], llm=llm, verbose=False)
    )

100%|██████████| 100/100 [00:00<00:00, 7940.75it/s]


In [7]:
# Omission verifier run on its own over the baseline responses, one example at a time.
ov = clin.verifiers.omission.OmissionVerifier()
med_status_dict_list_ov = []
for i in tqdm(range(n)):
    med_status_dict_list_ov.append(
        ov(dfe.iloc[i]['snippet'], bulleted_str=row['resps'][i], llm=llm, verbose=False)
    )

100%|██████████| 100/100 [00:00<00:00, 5966.97it/s]


In [9]:
# Chain the verifiers: omission verifier first, then the prune verifier on its
# output. Each round's results are stored in med_status_dict_list_ovs /
# med_status_dict_list_pvs (one entry per round).
#
# The per-round results use loop-local names (ov_dicts_iter / pv_dicts_iter)
# on purpose: rebinding med_status_dict_list_ov / med_status_dict_list_pv here
# would overwrite the standalone-verifier results that the metrics table
# reports as 'omission' and 'prune', making that table depend on cell
# execution order.
med_status_dict_list_ovs = []
med_status_dict_list_pvs = []

ov = clin.verifiers.omission.OmissionVerifier()
pv = clin.verifiers.prune.PruneVerifier()
for iter_num in range(1):
    # round 0 starts from the baseline responses; later rounds start from the
    # previous round's pruned output
    if iter_num == 0:
        bulleted_str_list = row['resps']
    else:
        bulleted_str_list = bulleted_str_list_pv

    # omission step, then convert the dicts back to bulleted strings
    ov_dicts_iter = [
        ov(dfe.iloc[i]['snippet'], bulleted_str=bulleted_str_list[i], llm=llm, lower=False)
        for i in tqdm(range(n))
    ]
    bulleted_str_list_ov = [
        clin.parse.medication_dict_to_bullet_str(ov_dicts_iter[i])
        for i in tqdm(range(n))
    ]

    # prune step on the omission output, then convert back to bulleted strings
    pv_dicts_iter = [
        pv(dfe.iloc[i]['snippet'], bulleted_str=bulleted_str_list_ov[i], llm=llm, lower=False)
        for i in tqdm(range(n))
    ]
    bulleted_str_list_pv = [
        clin.parse.medication_dict_to_bullet_str(pv_dicts_iter[i])
        for i in tqdm(range(n))
    ]

    # deepcopy so stored rounds cannot be altered by later mutation
    med_status_dict_list_ovs.append(deepcopy(ov_dicts_iter))
    med_status_dict_list_pvs.append(deepcopy(pv_dicts_iter))

100%|██████████| 100/100 [00:00<00:00, 7656.63it/s]
100%|██████████| 100/100 [00:00<00:00, 432402.47it/s]
100%|██████████| 100/100 [00:00<00:00, 6331.79it/s]
100%|██████████| 100/100 [00:00<00:00, 586615.94it/s]


In [12]:
# Every prediction variant to score, keyed by the label shown in the table.
PREDS_DICT = {
    'original': med_status_dict_list_orig,
    'evidence': med_status_dict_list_ev,
    'evidence_pruned': med_status_dict_list_ev_pruned,
    'prune': med_status_dict_list_pv,
    'omission': med_status_dict_list_ov,
}
# per-round results of the chained run: all ov<i> entries first, then all pv<i>
PREDS_DICT.update(
    {f'ov{i}': preds for i, preds in enumerate(med_status_dict_list_ovs)})
PREDS_DICT.update(
    {f'pv{i}': preds for i, preds in enumerate(med_status_dict_list_pvs)})

# collect each metric into a column, one row per variant (in PREDS_DICT order)
mets_dict = defaultdict(list)
for k, preds in PREDS_DICT.items():
    mets_dict_single = clin.eval.calculate_metrics(preds, dfe, verbose=False)
    for k_met, val in mets_dict_single.items():
        mets_dict[k_met].append(val)

# NOTE: this rebinds `df`, which earlier held the test split; `dfe` is a
# separate object and is not affected.
df = pd.DataFrame.from_dict(mets_dict).round(3)
df = df[['f1', 'recall', 'precision']]
df.index = PREDS_DICT.keys()
df.style.format(precision=3).background_gradient(cmap='Blues')

Unnamed: 0,f1,recall,precision
original,0.919,0.921,0.918
evidence,0.916,0.915,0.917
evidence_pruned,0.917,0.915,0.92
prune,0.898,0.871,0.928
omission,0.812,0.95,0.708
ov0,0.812,0.95,0.708
pv0,0.895,0.912,0.878


## Print errors

In [None]:
# Verbose metrics for the omission-verifier predictions.
mets = clin.eval.calculate_metrics(med_status_dict_list_ov, dfe, verbose=True)

# Compare the baseline predictions (m1) against the omission-verifier
# predictions (m2). Previously both names pointed at the same list, so the
# comparison below was never true and nothing was ever printed.
m1 = med_status_dict_list_orig
m2 = med_status_dict_list_ov

for i in range(len(m1)):
    # print(dfe.iloc[i]['snippet'])
    if m1[i] != m2[i]:
        # baseline first, then the verifier output, so the pair can be diffed by eye
        print(m1[i])
        print(m2[i])

        # print(med_evidence_dict_list_ev[i])
        print()
        print()