In [None]:
%load_ext autoreload
%autoreload 2
import os
import matplotlib.pyplot as plt
import seaborn as sns
from os.path import join
from tqdm import tqdm
import pandas as pd
import imodelsx.process_results
import sys
import datasets
import numpy as np
from copy import deepcopy
import clin.eval
import clin.verifiers.prune
import clin.verifiers.evidence
import clin.verifiers.omission
import clin.llm
import clin.parse
from collections import defaultdict
import openai
openai.api_key_path = '/home/chansingh/.OPENAI_KEY'

sys.path.append('../experiments/')
results_dir = '../results/'

r = imodelsx.process_results.get_results_df(results_dir, use_cached=True)

# get data for eval
dset = datasets.load_dataset('mitclinicalml/clinical-ie', 'medication_status')
df_val = pd.DataFrame.from_dict(dset['validation'])
df = pd.DataFrame.from_dict(dset['test'])
# df = pd.concat([val, test])
nums = np.arange(len(df)).tolist()
np.random.default_rng(seed=13).shuffle(nums)
dfe = df.iloc[nums]
n = len(dfe)

# add medication status eval
accs_cond, f1s_macro_cond = clin.eval.eval_medication_status(dfe, r)
r['acc_cond'] = accs_cond
r['f1_macro_cond'] = f1s_macro_cond

(
    r.groupby(['checkpoint', 'n_shots'])[['f1', 'recall', 'precision', 'acc_cond', 'f1_macro_cond']].mean()
    .style.format(precision=3).background_gradient(cmap='Blues')
)

## Test verifiers

In [None]:
# get original
row = r[(r.n_shots == 5) * (r.checkpoint == 'text-davinci-003')].iloc[0]
med_status_dict_list_orig = [clin.parse.parse_response_medication_list(row['resps'][i]) for i in range(n)]
llm = clin.llm.get_llm('text-davinci-003')

In [None]:
# apply prune verifier
pv = clin.verifiers.prune.PruneVerifier()
med_status_dict_list_pv = [pv(dfe.iloc[i]['snippet'], bulleted_str=row['resps'][i], llm=llm, verbose=False) for i in tqdm(range(n))]

# apply omission verifier
ov = clin.verifiers.omission.OmissionVerifier()
med_status_dict_list_ov = [ov(dfe.iloc[i]['snippet'], bulleted_str=row['resps'][i], llm=llm, verbose=False) for i in tqdm(range(n))]

# apply evidence verifier
ev = clin.verifiers.evidence.EvidenceVerifier(n_shots_neg=2, n_shots_pos=0)
# ev = clin.verifiers.evidence.EvidenceVerifier(n_shots_neg=0, n_shots_pos=2)
med_status_dict_list_ev = []
med_evidence_dict_list_ev = []
for i in tqdm(range(n)):
    med_status_dict, med_evidence_dict = ev(
        snippet=dfe.iloc[i]['snippet'], bulleted_str=row['resps'][i], llm=llm)
    med_status_dict_list_ev.append(med_status_dict)
    med_evidence_dict_list_ev.append(med_evidence_dict)
med_status_dict_list_ev_pruned = [
    {
        k: v for k, v in med_status_dict_list_ev[i].items()
        if not med_evidence_dict_list_ev[i][k] == 'no evidence'
        # and med_evidence_dict_list_ev[i][k] in dfe.iloc[i]['snippet'].lower()
    }
    for i in range(n)
]

In [64]:
# apply omission verifier then deduplication verifier
med_status_dict_list_ovs = []
med_status_dict_list_pvs = []

ov = clin.verifiers.omission.OmissionVerifier()
pv = clin.verifiers.prune.PruneVerifier()
for iter_num in range(2):
    if iter_num == 0:
        bulleted_str_list_pv_ = row['resps']
        bulleted_str_list_ov_ = row['resps']    

    med_status_dict_list_ov_ = [ov(dfe.iloc[i]['snippet'], bulleted_str=bulleted_str_list_pv_[i], llm=llm, lower=False) for i in tqdm(range(n))]
    bulleted_str_list_ov_ = [clin.parse.medication_dict_to_bullet_str(med_status_dict_list_ov_[i]) for i in tqdm(range(n))]

    med_status_dict_list_pv_ = [pv(dfe.iloc[i]['snippet'], bulleted_str=bulleted_str_list_ov_[i], llm=llm, lower=False) for i in tqdm(range(n))]
    bulleted_str_list_pv_ = [clin.parse.medication_dict_to_bullet_str(med_status_dict_list_pv_[i]) for i in tqdm(range(n))]
    
    # save
    med_status_dict_list_ovs.append(deepcopy(med_status_dict_list_ov_))
    med_status_dict_list_pvs.append(deepcopy(med_status_dict_list_pv_))


100%|██████████| 100/100 [00:00<00:00, 6383.25it/s]

100%|██████████| 100/100 [00:00<00:00, 606990.45it/s]

100%|██████████| 100/100 [00:00<00:00, 6517.14it/s]

100%|██████████| 100/100 [00:00<00:00, 607870.14it/s]

100%|██████████| 100/100 [00:00<00:00, 6423.62it/s]

100%|██████████| 100/100 [00:00<00:00, 645277.54it/s]

100%|██████████| 100/100 [00:00<00:00, 5794.68it/s]

100%|██████████| 100/100 [00:00<00:00, 683111.40it/s]


In [65]:
PREDS_DICT = {
    'original': med_status_dict_list_orig,
    'evidence': med_status_dict_list_ev,
    'evidence_pruned': med_status_dict_list_ev_pruned,
    'prune': med_status_dict_list_pv,
    'omission': med_status_dict_list_ov,
} | {f'ov{i}': med_status_dict_list_ovs[i] for i in range(len(med_status_dict_list_ovs))} \
    | {f'pv{i}': med_status_dict_list_pvs[i] for i in range(len(med_status_dict_list_pvs))}
mets_dict = defaultdict(list)
for k in PREDS_DICT.keys():
    mets_dict_single = clin.eval.calculate_metrics(PREDS_DICT[k], dfe, verbose=False)
    for k_met in mets_dict_single.keys():
        mets_dict[k_met].append(mets_dict_single[k_met])
df = pd.DataFrame.from_dict(mets_dict).round(3)[['f1', 'recall', 'precision']]
df.index = PREDS_DICT.keys()
df.style.format(precision=3).background_gradient(cmap='Blues')

Unnamed: 0,f1,recall,precision
original,0.919,0.921,0.918
evidence,0.916,0.915,0.917
evidence_pruned,0.917,0.915,0.92
prune,0.927,0.909,0.945
omission,0.906,0.944,0.87
ov0,0.906,0.944,0.87
ov1,0.903,0.944,0.865
pv0,0.938,0.929,0.946
pv1,0.941,0.938,0.944


## Print errors

In [67]:
mets = clin.eval.calculate_metrics(med_status_dict_list_pvs[0], dfe, verbose=True)

correct
correct
ret ['cidofovir', 'ditropan']
grt ['chemotherapy', 'cidofovir', 'ditropan']

ret ['levofloxacin', 'vancomycin', 'zosyn']
grt ['levofloxacin', 'vancomycin', 'vancomycin', 'zosyn']

correct
correct
correct
correct
correct
ret ['6 mp', 'b-12', 'birth control pill', 'dilaudid', 'hydrochlorothiazide', 'lisinopril', 'morphine', 'percocet', 'remicade', 'sulfa']
grt ['6 mp', 'b-12', 'dilaudid', 'hydrochlorothiazide', 'lisinopril', 'morphine', 'percocet', 'remicade', 'sulfa']

correct
ret ['aspirin', 'ibuprofen', 'naprosyn', 'plavix', 'tylenol']
grt ['aspirin', 'ibuprofen', 'naprosyn', 'plavix', 'tylenol', 'tylenol es']

ret ['progesterone shots']
grt ['progesterone']

correct
correct
correct
correct
correct
ret ['detrol-la', 'duragesic', 'oxycodone', 'urecholine']
grt ['detrol-la', 'duragesic', 'oxycodone ir', 'urecholine']

correct
correct
ret ['ace inhibitor', 'coumadin', 'plavix']
grt ['ace inhibitor', 'coumadin', 'coumadin', 'plavix']

correct
correct
correct
correct
ret ['

In [None]:
# m1 = med_status_dict_list_orig
# m1 = med_status_dict_list_ov
m1 = med_status_dict_list_pvs[0]
# m2 = med_status_dict_list_ov

for i in range(len(m1)):
    # print(dfe.iloc[i]['snippet'])
    if not m1[i] == m2[i]:
        print(m1[i])
        # print(m2[i])

        # print(med_evidence_dict_list_ev[i])
        print()
        print()