In [133]:
%load_ext autoreload
%autoreload 2
import os
import matplotlib.pyplot as plt
import seaborn as sns
from os.path import join
from tqdm import tqdm
import pandas as pd
import imodelsx.process_results
import sys
import datasets
import numpy as np
from copy import deepcopy
import clin.eval
import clin.verifiers.deduplicate
import clin.verifiers.evidence
import clin.verifiers.omission
import clin.llm
import clin.parse
from collections import defaultdict
import openai
# Location of the OpenAI key file. Honor the OPENAI_API_KEY_PATH environment
# variable when it is set; otherwise fall back to ~/.OPENAI_KEY in the current
# user's home directory rather than a hardcoded per-user absolute path, so the
# notebook runs unchanged on another account or machine.
openai.api_key_path = os.environ.get(
    'OPENAI_API_KEY_PATH', os.path.expanduser('~/.OPENAI_KEY'))

# add the experiments directory to the module search path
sys.path.append('../experiments/')
# directory handed to the results loader in the next cell
results_dir = '../results/'

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [134]:
# Load the experiment results table from results_dir (use_cached=True is
# forwarded to the loader).
r = imodelsx.process_results.get_results_df(results_dir, use_cached=True)

# get data for eval: the 'medication_status' config of the clinical-ie dataset
dset = datasets.load_dataset('mitclinicalml/clinical-ie', 'medication_status')
df_val = pd.DataFrame.from_dict(dset['validation'])  # not referenced in the cells visible here
df = pd.DataFrame.from_dict(dset['test'])
# df = pd.concat([val, test])
# Reorder the test rows with a fixed-seed (13) shuffle of their positions.
# NOTE(review): presumably this must reproduce the example order used when the
# responses stored in `r` were generated, since later cells pair
# row['resps'][i] with dfe.iloc[i] -- confirm against the experiment script.
nums = np.arange(len(df)).tolist()
np.random.default_rng(seed=13).shuffle(nums)
dfe = df.iloc[nums]

# add medication status eval: the two returned sequences become new columns on
# `r` (this modifies `r` in place)
accs_cond, f1s_macro_cond = clin.eval.eval_medication_status(dfe, r)
r['acc_cond'] = accs_cond
r['f1_macro_cond'] = f1s_macro_cond

# Display the mean of each metric per (checkpoint, n_shots) group, formatted to
# 3 decimals with a blue background gradient.
(
    r.groupby(['checkpoint', 'n_shots'])[['f1', 'recall', 'precision', 'acc_cond', 'f1_macro_cond']].mean()
    .style.format(precision=3).background_gradient(cmap='Blues')
)


100%|██████████| 6/6 [00:00<00:00, 2719.45it/s]


  0%|          | 0/2 [00:00<?, ?it/s]


100%|██████████| 6/6 [00:00<00:00, 396.11it/s]


Unnamed: 0_level_0,Unnamed: 1_level_0,f1,recall,precision,acc_cond,f1_macro_cond
checkpoint,n_shots,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
gpt-3.5-turbo,1,0.835,0.847,0.823,0.79,0.064
gpt-3.5-turbo,5,0.887,0.876,0.898,0.861,0.175
gpt-4-0314,1,0.83,0.9,0.771,0.79,0.056
gpt-4-0314,5,0.878,0.903,0.855,0.921,0.212
text-davinci-003,1,0.899,0.932,0.868,0.888,0.394
text-davinci-003,5,0.919,0.921,0.918,0.925,0.626


### Test verifiers

In [135]:
# Select the first 5-shot text-davinci-003 run; its stored responses are the
# baseline that the verifier cells below post-process.
is_five_shot = r.n_shots == 5
is_davinci = r.checkpoint == 'text-davinci-003'
row = r[is_five_shot & is_davinci].iloc[0]
n = len(dfe)

# Parse each of the first n stored responses with
# clin.parse.parse_response_medication_list.
med_status_dict_list_orig = []
for i in range(n):
    med_status_dict_list_orig.append(
        clin.parse.parse_response_medication_list(row['resps'][i]))

llm = clin.llm.get_llm('text-davinci-003')

In [136]:
# Run the evidence verifier on every example. Each call returns a pair:
# a medication->status dict and a medication->evidence dict.
ev = clin.verifiers.evidence.EvidenceVerifier(n_shots_neg=2, n_shots_pos=0)
ev_outputs = [
    ev(snippet=dfe.iloc[idx]['snippet'], bulleted_str=row['resps'][idx], llm=llm)
    for idx in tqdm(range(n))
]
med_status_dict_list_ev = [status_by_med for status_by_med, _ in ev_outputs]
med_evidence_dict_list_ev = [evidence_by_med for _, evidence_by_med in ev_outputs]

# Pruned variant: drop every medication whose evidence is the literal string
# 'no evidence'. (A stricter check, also requiring the evidence text to occur
# in the lowercased snippet, was left disabled in the original cell.)
med_status_dict_list_ev_pruned = []
for status_by_med, evidence_by_med in zip(med_status_dict_list_ev, med_evidence_dict_list_ev):
    kept = {}
    for med, status in status_by_med.items():
        if evidence_by_med[med] == 'no evidence':
            continue
        kept[med] = status
    med_status_dict_list_ev_pruned.append(kept)


100%|██████████| 100/100 [00:00<00:00, 6075.18it/s]


In [137]:
# Run the deduplication verifier once per example, feeding it the snippet and
# the original bulleted response of the selected run.
dv = clin.verifiers.deduplicate.DeduplicateVerifier()
med_status_dict_list_dv = []
for i in tqdm(range(n)):
    snippet_i = dfe.iloc[i]['snippet']
    med_status_dict_list_dv.append(
        dv(snippet_i, bulleted_str=row['resps'][i], llm=llm))


[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

In [None]:
# Run the omission verifier once per example, feeding it the snippet and the
# original bulleted response of the selected run.
ov = clin.verifiers.omission.OmissionVerifier()
med_status_dict_list_ov = []
for i in tqdm(range(n)):
    snippet_i = dfe.iloc[i]['snippet']
    med_status_dict_list_ov.append(
        ov(snippet_i, bulleted_str=row['resps'][i], llm=llm))

In [None]:
# apply omission verifier then deduplication verifier, repeated for several rounds
#
# Per-round results are collected in med_status_dict_list_ovs / _dvs (one entry
# per round). The per-round working lists use *_iter names on purpose: reusing
# med_status_dict_list_ov / med_status_dict_list_dv here would overwrite the
# outputs of the standalone omission / deduplication cells, so the 'omission'
# and 'deduplicate' rows of the metrics table would silently report this
# loop's last round instead.
N_VERIFIER_ROUNDS = 4  # number of omission -> deduplication passes
med_status_dict_list_ovs = []
med_status_dict_list_dvs = []

ov = clin.verifiers.omission.OmissionVerifier()
dv = clin.verifiers.deduplicate.DeduplicateVerifier()
for iter_num in range(N_VERIFIER_ROUNDS):
    # Round 0 starts from the stored responses; each later round starts from
    # the previous round's deduplicated bullet strings.
    if iter_num == 0:
        bulleted_str_list = row['resps']
    else:
        bulleted_str_list = bulleted_str_list_dv

    # omission pass, then convert the dicts back to bullet strings for the next verifier
    med_status_dict_list_ov_iter = [
        ov(dfe.iloc[i]['snippet'], bulleted_str=bulleted_str_list[i], llm=llm, lower=False)
        for i in tqdm(range(n), desc=f'round {iter_num}: omission')
    ]
    bulleted_str_list_ov = [
        clin.parse.medication_dict_to_bullet_str(med_status_dict_list_ov_iter[i])
        for i in tqdm(range(n), desc=f'round {iter_num}: omission -> str')
    ]

    # deduplication pass on the omission output
    med_status_dict_list_dv_iter = [
        dv(dfe.iloc[i]['snippet'], bulleted_str=bulleted_str_list_ov[i], llm=llm, lower=False)
        for i in tqdm(range(n), desc=f'round {iter_num}: deduplicate')
    ]
    bulleted_str_list_dv = [
        clin.parse.medication_dict_to_bullet_str(med_status_dict_list_dv_iter[i])
        for i in tqdm(range(n), desc=f'round {iter_num}: deduplicate -> str')
    ]

    # snapshot this round's predictions
    med_status_dict_list_ovs.append(deepcopy(med_status_dict_list_ov_iter))
    med_status_dict_list_dvs.append(deepcopy(med_status_dict_list_dv_iter))

In [132]:
# Collect every prediction variant under a display label: the five base
# variants first, then one 'ov{i}' and one 'dv{i}' entry per loop round.
PREDS_DICT = {
    'original': med_status_dict_list_orig,
    'evidence': med_status_dict_list_ev,
    'evidence_pruned': med_status_dict_list_ev_pruned,
    'deduplicate': med_status_dict_list_dv,
    'omission': med_status_dict_list_ov,
}
PREDS_DICT.update({f'ov{i}': preds for i, preds in enumerate(med_status_dict_list_ovs)})
PREDS_DICT.update({f'dv{i}': preds for i, preds in enumerate(med_status_dict_list_dvs)})

# Score each variant against dfe and accumulate one list per metric name,
# in the same order as the keys of PREDS_DICT.
mets_dict = defaultdict(list)
for preds in PREDS_DICT.values():
    variant_mets = clin.eval.calculate_metrics(preds, dfe, verbose=False)
    for met_name, met_val in variant_mets.items():
        mets_dict[met_name].append(met_val)

# One row per variant, one column per metric.
# Note: this rebinds `df`, which an earlier cell used for the test split.
df = pd.DataFrame.from_dict(mets_dict).round(3)
df.index = list(PREDS_DICT)
df.style.format(precision=3).background_gradient(cmap='Blues')

Unnamed: 0,recall,precision,f1
original,0.921,0.918,0.919
evidence,0.915,0.917,0.916
evidence_pruned,0.915,0.92,0.917
deduplicate,0.794,0.9,0.844
omission,0.929,0.852,0.889
ov0,0.935,0.846,0.888
dv0,0.794,0.9,0.844


### Print errors

In [129]:
mets = clin.eval.calculate_metrics(med_status_dict_list_ev_pruned, dfe, verbose=True)

correct
ret ['coumadin', 'oxycodone', 'percocet', 'vistaril']
grt ['coumadin', 'oxycodone', 'percocet', 'verapamil', 'vistaril']

ret ['cidofovir', 'ditropan']
grt ['chemotherapy', 'cidofovir', 'ditropan']

ret ['levofloxacin', 'vancomycin', 'zosyn']
grt ['levofloxacin', 'vancomycin', 'vancomycin', 'zosyn']

correct
correct
correct
correct
correct
ret ['6 mp', 'b-12', 'birth control pill', 'dilaudid', 'hydrochlorothiazide', 'lisinopril', 'percocet']
grt ['6 mp', 'b-12', 'dilaudid', 'hydrochlorothiazide', 'lisinopril', 'morphine', 'percocet', 'remicade', 'sulfa']

correct
ret ['aspirin', 'ibuprofen', 'naprosyn', 'plavix', 'tylenol']
grt ['aspirin', 'ibuprofen', 'naprosyn', 'plavix', 'tylenol', 'tylenol es']

ret ['progesterone shots']
grt ['progesterone']

correct
correct
correct
correct
correct
ret ['detrol-la', 'duragesic', 'oxycodone', 'urecholine']
grt ['detrol-la', 'duragesic', 'oxycodone ir', 'urecholine']

correct
correct
ret ['ace inhibitor', 'coumadin', 'plavix']
grt ['ace inhi