In [2]:
%load_ext autoreload
%autoreload 2
import os
import matplotlib.pyplot as plt
import seaborn as sns
from os.path import join
from tqdm import tqdm
import pandas as pd
import imodelsx.process_results
import sys
import datasets
import numpy as np
from copy import deepcopy
import clin.eval
import clin.modules.prune
import clin.modules.evidence
import clin.modules.omission
import clin.modules.status
import clin.llm
import clin.parse
from collections import defaultdict
import openai
# OpenAI credentials: the key-file location can be overridden with the
# OPENAI_API_KEY_PATH environment variable rather than being pinned to one
# user's absolute home directory.
openai.api_key_path = os.environ.get('OPENAI_API_KEY_PATH', os.path.expanduser('~/.OPENAI_KEY'))

sys.path.append('../experiments/')
results_dir = '../results/'

# cached results of the extraction experiments (presumably one row per run — confirm)
r = imodelsx.process_results.get_results_df(results_dir, use_cached=True)

# get data for eval: medication-status task of mitclinicalml/clinical-ie.
# Only the test split is evaluated; to also include validation, use
# pd.concat([df_val, df]) instead of df below.
dset = datasets.load_dataset('mitclinicalml/clinical-ie', 'medication_status')
df_val = pd.DataFrame.from_dict(dset['validation'])
df = pd.DataFrame.from_dict(dset['test'])

# Fixed-seed shuffle of the test examples.
# NOTE(review): presumably this order must match the order in which the cached
# runs in `r` produced their 'resps' — do not change the seed or shuffle method
# without confirming.
nums = np.arange(len(df)).tolist()
np.random.default_rng(seed=13).shuffle(nums)
dfe = df.iloc[nums]
n = len(dfe)

# (
#     r.groupby(['checkpoint', 'n_shots'])[['f1', 'recall', 'precision']].mean()
#     .style.format(precision=3).background_gradient(cmap='Blues')
# )

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


  0%|          | 0/2 [00:00<?, ?it/s]

In [17]:
# Result columns named '<metric>___<verification>' hold per-verifier scores;
# reshape them into a (verification x metric) table using column 0 of the
# transposed frame, i.e. the first row of `r`.
metric_cols = [col for col in r.columns if '___' in col]
rc = r[metric_cols]
rc.columns = pd.MultiIndex.from_tuples([tuple(col.split('___')) for col in rc.columns])
rc = (
    rc.T
    .reset_index()
    .rename(columns={'level_0': 'metric', 'level_1': 'verification'})
    .pivot_table(index='verification', columns='metric', values=0)
    .round(3)
)

rc

metric,f1,precision,recall
verification,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
ev,0.918,0.918,0.918
original,0.919,0.918,0.921
ov,0.906,0.87,0.944
ov_pv,0.936,0.946,0.926
ov_pv_ev,0.938,0.949,0.926
pv,0.926,0.948,0.906


## Test verifiers on medication extraction

In [2]:
# Baseline: the cached 5-shot text-davinci-003 run, parsed into one
# med -> status dict per evaluation example.
CHECKPOINT = 'text-davinci-003'
runs_baseline = r[(r.n_shots == 5) & (r.checkpoint == CHECKPOINT)]
# fail with a readable message instead of a bare IndexError from .iloc[0]
assert len(runs_baseline) > 0, f'no cached 5-shot {CHECKPOINT} run found in results'
row = runs_baseline.iloc[0]
extracted_strs_orig = row['resps']
med_status_dict_list_orig = [clin.parse.parse_response_medication_list(extracted_strs_orig[i]) for i in range(n)]

# the same model is used as the LLM behind every verifier
llm_verify = clin.llm.get_llm(CHECKPOINT)

ov = clin.modules.omission.OmissionVerifier()
pv = clin.modules.prune.PruneVerifier()
ev = clin.modules.evidence.EvidenceVerifier(n_shots_neg=1, n_shots_pos=1)
sv = clin.modules.status.StatusVerifier()

In [None]:
# Run each verifier on its own against the original extractions, so the effect
# of every verifier can be measured in isolation.

# omission verifier: one med -> status dict per example
med_status_dict_list_ov = []
for i in tqdm(range(n)):
    med_status_dict_list_ov.append(
        ov(dfe.iloc[i]['snippet'], bulleted_str=extracted_strs_orig[i], llm=llm_verify, verbose=False)
    )

# prune verifier: one med -> status dict per example
med_status_dict_list_pv = []
for i in tqdm(range(n)):
    med_status_dict_list_pv.append(
        pv(dfe.iloc[i]['snippet'], bulleted_str=extracted_strs_orig[i], llm=llm_verify, verbose=False)
    )

# evidence verifier: each call yields a (status dict, evidence dict) pair
med_status_and_evidence = []
for i in tqdm(range(n)):
    med_status_and_evidence.append(
        ev(snippet=dfe.iloc[i]['snippet'], bulleted_str=extracted_strs_orig[i], llm=llm_verify)
    )
med_status_dict_list_ev = [pair[0] for pair in med_status_and_evidence]
med_evidence_dict_list_ev = [pair[1] for pair in med_status_and_evidence]

In [1]:
# Chain the verifiers (omission -> prune -> evidence). Each stage receives the
# bulleted list rendered from the previous stage's med -> status dicts.

# stage 1: omission
med_status_dict_list_ov_ = []
for i in tqdm(range(n)):
    med_status_dict_list_ov_.append(
        ov(dfe.iloc[i]['snippet'], bulleted_str=extracted_strs_orig[i], llm=llm_verify, lower=False)
    )
bulleted_str_list_ov_ = [clin.parse.medication_dict_to_bullet_str(status_dict) for status_dict in med_status_dict_list_ov_]

# stage 2: prune
med_status_dict_list_pv_ = []
for i in tqdm(range(n)):
    med_status_dict_list_pv_.append(
        pv(dfe.iloc[i]['snippet'], bulleted_str=bulleted_str_list_ov_[i], llm=llm_verify, lower=False)
    )
bulleted_str_list_pv_ = [clin.parse.medication_dict_to_bullet_str(status_dict) for status_dict in med_status_dict_list_pv_]

# stage 3: evidence -> (status dict, evidence dict) per example
med_status_and_evidence_ = []
for i in tqdm(range(n)):
    med_status_and_evidence_.append(
        ev(snippet=dfe.iloc[i]['snippet'], bulleted_str=bulleted_str_list_pv_[i], llm=llm_verify)
    )
med_status_dict_list_ev_ = [pair[0] for pair in med_status_and_evidence_]
med_evidence_dict_list_ev_ = [pair[1] for pair in med_status_and_evidence_]

NameError: name 'tqdm' is not defined

In [None]:
# Score every verifier configuration against the gold labels in `dfe`;
# the resulting table has one row per configuration.
med_status_results = {
    'original': med_status_dict_list_orig,
    'omission': med_status_dict_list_ov,
    'prune': med_status_dict_list_pv,
    'evidence': med_status_dict_list_ev,
    'omission + prune': med_status_dict_list_pv_,
    'omission + prune + evidence': med_status_dict_list_ev_,
}
mets_dict = defaultdict(list)
for config_name, preds in med_status_results.items():
    mets_single = clin.eval.calculate_metrics(preds, dfe, verbose=False)
    for met_name, met_value in mets_single.items():
        mets_dict[met_name].append(met_value)

# Use a new name so `df` (the raw test split loaded at the top) is not overwritten
# with a metrics table.
df_mets = pd.DataFrame(mets_dict, index=list(med_status_results))[['f1', 'recall', 'precision']].round(3)
df_mets.style.format(precision=3).background_gradient(cmap='Blues')

## Print errors

In [None]:
# Print per-example errors of the sequential omission + prune pipeline.
# NOTE(review): `med_status_dict_list_pvs` is not defined anywhere in this
# notebook (NameError); `med_status_dict_list_pv_` is assumed to be the intended
# input — confirm.
mets = clin.eval.calculate_metrics(med_status_dict_list_pv_, dfe, verbose=True)

## Test on medication status

In [None]:
# `d`: the med -> status dicts that the medication-status experiments below start from.
# NOTE(review): `med_status_dict_list_pvs` is not defined anywhere in this
# notebook (NameError); the omission + prune output is assumed to be intended,
# which also matches the evidence stage of the sequential pipeline — confirm.
d = med_status_dict_list_pv_

In [None]:
# Render each med -> status dict of `d` back into the bulleted-list text the verifiers consume.
bulleted_str_list_d = []
for i in tqdm(range(n)):
    bulleted_str_list_d.append(clin.parse.medication_dict_to_bullet_str(d[i]))

In [None]:
# Re-run the evidence verifier on the bulleted lists of `d` (2 negative /
# 0 positive shots) and drop every medication reported as 'no evidence'.
# Another configuration that was tried: n_shots_neg=0, n_shots_pos=2.
# A separate name is used so the `ev` built earlier (1/1 shots) is not replaced
# for the cells that use it.
ev_d = clin.modules.evidence.EvidenceVerifier(n_shots_neg=2, n_shots_pos=0)

# NOTE(review): these two lists reuse (and overwrite) the names filled by the
# independent evidence-verifier cell; re-running the metrics table after this
# cell changes its 'evidence' row.
med_status_dict_list_ev = []
med_evidence_dict_list_ev = []
for i in tqdm(range(n)):
    med_status_dict, med_evidence_dict = ev_d(
        snippet=dfe.iloc[i]['snippet'], bulleted_str=bulleted_str_list_d[i], llm=llm_verify)
    med_status_dict_list_ev.append(med_status_dict)
    med_evidence_dict_list_ev.append(med_evidence_dict)

# Keep only medications with supporting evidence.
# Optional stricter filter (disabled): also require the evidence text to appear
# in dfe.iloc[i]['snippet'].lower().
med_status_dict_list_ev_pruned = [
    {
        med: status for med, status in status_dict.items()
        if evidence_dict[med] != 'no evidence'
    }
    for status_dict, evidence_dict in zip(med_status_dict_list_ev, med_evidence_dict_list_ev)
]

In [None]:
# Debug loop: run the status verifier example by example, printing the index
# first so the failing example is identifiable if a call raises.
# Results are discarded.
for i in tqdm(range(n)):
    print(i)
    sv(dfe.iloc[i]['snippet'],
       med_status_dict=med_status_dict_list_ev[i],
       med_evidence_dict=med_evidence_dict_list_ev[i],
       llm=llm_verify)

In [None]:
# Run the status verifier on every example, using the statuses and evidence
# produced by the evidence-verifier pass over `d`.
med_status_dict_list_sv = [
    sv(dfe.iloc[i]['snippet'],
       med_status_dict=med_status_dict_list_ev[i],
       med_evidence_dict=med_evidence_dict_list_ev[i],
       llm=llm_verify)
    for i in tqdm(range(n))
]

In [None]:
# Medication-status scores for every saved run in `r`: parse each run's
# responses, score them, attach the scores as new columns of `r`, and
# summarize per (checkpoint, n_shots).
med_status_dicts_list = [
    [clin.parse.parse_response_medication_list(resps[j]) for j in range(n)]
    for resps in r['resps']
]
accs_cond, f1s_macro_cond = clin.eval.eval_medication_status(med_status_dicts_list, dfe)
r['acc_cond'] = accs_cond
r['f1_macro_cond'] = f1s_macro_cond

summary_cols = ['f1', 'recall', 'precision', 'acc_cond', 'f1_macro_cond']
r_summary = r.groupby(['checkpoint', 'n_shots'])[summary_cols].mean()
r_summary.style.format(precision=3).background_gradient(cmap='Blues')

In [None]:
# Medication-status scores for all saved runs in `r`, every verifier
# configuration, and the evidence-pruned version of `d`.
# NOTE(review): `PREDS_DICT` and `d_pruned` are not defined anywhere in this
# notebook (NameError); they are mapped to `med_status_results` and
# `med_status_dict_list_ev_pruned` — confirm this is the intended comparison.
med_status_dicts_list = [[clin.parse.parse_response_medication_list(r.iloc[i]['resps'][j]) for j in range(n)] for i in range(len(r))]
med_status_dicts_list += list(med_status_results.values()) + [med_status_dict_list_ev_pruned]
accs_cond, f1s_macro_cond = clin.eval.eval_medication_status(med_status_dicts_list, dfe)
idx = list((r.checkpoint + ' ' + r.n_shots.astype(str)).values) + list(med_status_results.keys()) + ['d (evidence-pruned)']
# row labels must line up one-to-one with the scored prediction lists
assert len(idx) == len(med_status_dicts_list), (len(idx), len(med_status_dicts_list))

In [None]:
# One row per run/configuration (labels from `idx`), one column per
# medication-status metric.
status_scores = pd.DataFrame(
    [accs_cond, f1s_macro_cond],
    index=['accs_cond', 'f1s_macro_cond'],
    columns=idx,
).T
status_scores.style.format(precision=3).background_gradient(cmap='Blues')