In [1]:
%load_ext autoreload
%autoreload 2
import os
import matplotlib.pyplot as plt
import seaborn as sns
from os.path import join
from tqdm import tqdm
import pandas as pd
import imodelsx.process_results
import sys
import datasets
import numpy as np
from copy import deepcopy
import clin.eval
import clin.llm
import clin.parse
from collections import defaultdict
import openai
# Location of the file holding the OpenAI API key. It can be overridden with the
# OPENAI_API_KEY_PATH environment variable; otherwise it defaults to
# ~/.OPENAI_KEY in the current user's home directory, so no machine-specific
# absolute path is baked into the notebook.
openai.api_key_path = os.environ.get(
    'OPENAI_API_KEY_PATH', os.path.expanduser('~/.OPENAI_KEY')
)

# make modules under ../experiments/ importable from this notebook
sys.path.append('../experiments/')
results_dir = '../results/'

# load the results dataframe from results_dir (use_cached=False)
r = imodelsx.process_results.get_results_df(results_dir, use_cached=False)

# get data for eval
dset = datasets.load_dataset('mitclinicalml/clinical-ie', 'medication_status')
df_val = pd.DataFrame.from_dict(dset['validation'])
df = pd.DataFrame.from_dict(dset['test'])
# df = pd.concat([val, test])

# Shuffle the test-split row positions with a fixed seed (13) and reorder the
# frame accordingly. NOTE(review): the stored result dicts are indexed
# positionally against dfe later in the notebook, so this ordering presumably
# has to match the one used when the experiments were run - do not change the
# seed or the shuffle call without confirming.
nums = np.arange(len(df)).tolist()
np.random.default_rng(seed=13).shuffle(nums)
dfe = df.iloc[nums]
n = len(dfe)

def add_status_eval(r):
    '''Add status eval by aggregating over all columns with dict_ in the name

    For each row of ``r``, the values stored in the ``dict_*`` columns
    (excluding any column starting with ``dict_evidence``) are passed to
    ``clin.eval.get_common_medications`` together with the module-level
    evaluation frame ``dfe``, and the result is fed to
    ``clin.eval.eval_medication_status``. Two columns are added per setting,
    where ``<setting>`` is the column name with ``dict_`` removed:
    ``status_acc_cond___<setting>`` and ``status_f1_macro_cond___<setting>``.

    Params
    ------
    r: pd.DataFrame
        Results frame, one row per run. It is left unmodified.

    Returns
    -------
    pd.DataFrame
        A copy of ``r`` with the status columns appended. When ``r`` has no
        rows, no columns are added.
    '''
    dict_columns = [
        k for k in r.keys()
        if k.startswith('dict_') and not k.startswith('dict_evidence')
    ]
    d = defaultdict(list)
    for i in range(r.shape[0]):
        row = r.iloc[i]
        med_status_dicts_list = [row[k] for k in dict_columns]
        common_meds_status_gt_dict = clin.eval.get_common_medications(
            med_status_dicts_list, dfe)
        accs_cond, f1s_macro_cond = clin.eval.eval_medication_status(
            med_status_dicts_list, common_meds_status_gt_dict)
        # the returned sequences are indexed in the same order as dict_columns
        for j, setting in enumerate(dict_columns):
            setting_name = setting.replace('dict_', '')
            d[f'status_acc_cond___{setting_name}'].append(accs_cond[j])
            d[f'status_f1_macro_cond___{setting_name}'].append(f1s_macro_cond[j])

    # Write into a copy so the caller's frame is not mutated as a side effect;
    # the function both mutated and returned its argument before.
    out = r.copy()
    for k, values in d.items():
        out[k] = values
    return out
# attach the per-setting medication-status metrics to the results frame
r = add_status_eval(r)


# metric column -> row label shown in the summary table
cols = {
    'f1___original': 'Medication extraction, original',
    'f1___ov_pv_ev': 'Medication extraction, self-verified',
    'status_f1_macro_cond___original': 'Medication status, original',
    'status_f1_macro_cond___sv': 'Medication status, self-verified',
}


def _css_if_self_verified(table_row, css):
    '''Return `css` for every cell of a row whose label mentions self-verified, else blanks.'''
    style = css if 'self-verified' in table_row.name else ''
    return [style for _ in table_row]


# mean of each metric per (checkpoint, n_shots), metrics as rows
_summary = (
    r.groupby(['checkpoint', 'n_shots'])[list(cols.keys())]
    .mean()
    .rename(columns=cols)
    .round(3)
    .T
)
rt = (
    _summary.style
    # darker background on the self-verified rows
    .apply(_css_if_self_verified, axis=1, css='background: #333')
    # bold text on the self-verified rows
    .apply(_css_if_self_verified, axis=1, css='font-weight: bold')
    .format(precision=3)
)
rt
# rt.style.format(precision=3).background_gradient(cmap='Blues')

100%|██████████| 4/4 [00:00<00:00, 1040.25it/s]


  0%|          | 0/2 [00:00<?, ?it/s]

100%|██████████| 7/7 [00:00<00:00, 500.91it/s]
100%|██████████| 7/7 [00:00<00:00, 418.07it/s]
100%|██████████| 7/7 [00:00<00:00, 443.45it/s]
100%|██████████| 7/7 [00:00<00:00, 491.60it/s]


checkpoint,gpt-4-0314,gpt-4-0314,text-davinci-003,text-davinci-003
n_shots,1,5,1,5
"Medication extraction, original",0.83,0.878,0.899,0.919
"Medication extraction, self-verified",0.881,0.906,0.923,0.936
"Medication status, original",0.04,0.167,0.398,0.66
"Medication status, self-verified",0.074,0.174,0.448,0.763


In [95]:
# compare values for a single row
# combine the boolean masks with `&` (element-wise AND) rather than `*`
row_mask = (r.n_shots == 5) & (r.checkpoint == 'text-davinci-003')
row_df = r[row_mask]

# Keep only the metric columns (named '<metric>___<setting>') of the first
# matching row, and reset its index so the value column after transposing is
# always labelled 0. Previously `values=0` below only worked when the selected
# row happened to carry index label 0 in `r`.
rc = row_df[[c for c in row_df.columns if '___' in c]].iloc[[0]].reset_index(drop=True)

# create multiindex columns by splitting on '___'
rc.columns = pd.MultiIndex.from_tuples([tuple(c.split('___')) for c in rc.columns])

# long format: one record per (metric, setting) with the score in column 0
rc = rc.T.reset_index()
rc = rc.rename(columns={
    'level_0': '',
    'level_1': 'Verifiers',
}).pivot_table(index='Verifiers', columns='', values=0).round(3)
rc.style.format(precision=3).background_gradient(cmap='gray')

Unnamed: 0_level_0,f1,precision,recall,status_acc_cond,status_f1_macro_cond
Verifiers,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
ev,0.918,0.918,0.918,0.888,0.66
original,0.919,0.918,0.921,0.888,0.66
ov,0.906,0.87,0.944,0.891,0.663
ov_pv,0.936,0.946,0.926,0.891,0.663
ov_pv_ev,0.936,0.949,0.924,0.891,0.663
pv,0.926,0.948,0.906,0.888,0.66
sv,0.936,0.949,0.924,0.897,0.763


In [96]:
# metric column of `rc` -> header shown in the table
cols = {
    'f1': 'F1',
    'precision': 'Precision',
    'recall': 'Recall',
    'status_f1_macro_cond': 'F1 (Medication Status)',
}
# verifier key in `rc` -> display label; dict order fixes the row order
rows = {
    'original': 'Original',
    'ov': 'Omission',
    'pv': 'Prune',
    'ov_pv': 'Omission + Prune',
    'sv': 'Omission + Prune + Evidence'
}

# select and relabel the metric columns, then the verifier rows
_ablation = rc[list(cols.keys())].rename(columns=cols)
_ablation = _ablation.loc[list(rows.keys())].rename(index=rows)
_ablation.style.format(precision=3).background_gradient(cmap='Blues')


Unnamed: 0_level_0,F1,Precision,Recall,F1 (Medication Status)
Verifiers,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
Original,0.919,0.918,0.921,0.66
Omission,0.906,0.87,0.944,0.663
Prune,0.926,0.948,0.906,0.66
Omission + Prune,0.936,0.946,0.926,0.663
Omission + Prune + Evidence,0.936,0.949,0.924,0.763


In [None]:
# print errs for med extraction
# mets = clin.eval.calculate_metrics(row['dict_ov_pv_ev'], dfe, verbose=True)

# print errors for status extraction
# accs_cond, f1s_macro_cond = clin.eval.eval_medication_status([med_status_dicts_list[-1]], dfe, verbose=True)

## Interpret an example

In [59]:
# single row to investigate: first run with 5 shots on text-davinci-003
is_target_run = (r.n_shots == 5) & (r.checkpoint == 'text-davinci-003')
row = r.loc[is_target_run].iloc[0]

In [60]:
# list the fields available on the selected row
row.index

Index(['dataset_name', 'seed', 'save_dir', 'checkpoint', 'checkpoint_verify',
       'n_shots', 'use_cache', 'save_dir_unique', 'extracted_strs',
       'recall___original', 'precision___original', 'f1___original',
       'dict_original', 'recall___ov', 'precision___ov', 'f1___ov', 'dict_ov',
       'recall___pv', 'precision___pv', 'f1___pv', 'dict_pv', 'recall___ev',
       'precision___ev', 'f1___ev', 'dict_ev', 'recall___ov_pv',
       'precision___ov_pv', 'f1___ov_pv', 'dict_ov_pv', 'recall___ov_pv_ev',
       'precision___ov_pv_ev', 'f1___ov_pv_ev', 'dict_ov_pv_ev', 'recall___sv',
       'precision___sv', 'f1___sv', 'dict_sv', 'dict_evidence_ov_pv_ev',
       'status_acc_cond___original', 'status_f1_macro_cond___original',
       'status_acc_cond___ov', 'status_f1_macro_cond___ov',
       'status_acc_cond___pv', 'status_f1_macro_cond___pv',
       'status_acc_cond___ev', 'status_f1_macro_cond___ev',
       'status_acc_cond___ov_pv', 'status_f1_macro_cond___ov_pv',
       'status_a

In [81]:
# position of the example to inspect (used for both dfe and the stored per-example results)
i = 11
gt = dfe.iloc[i]
med_status_dict = row['dict_sv'][i]
ev = row['dict_evidence_ov_pv_ev'][i]

# ground-truth record, predicted statuses, then the full snippet text
print(gt)
print(med_status_dict)
print('\n' + gt['snippet'])

# one entry per key of `ev`, with its stored value indented below it
# NOTE(review): assumes `ev` is a dict (key -> evidence string) - confirm
for med, evidence in ev.items():
    print(med)
    print('\t', evidence)
    print()

index                                                                      18
snippet                     Her aspirin (81 mg q.d.) is discontinued, and ...
active_medications                        ["Plavix", "Tylenol", "Tylenol ES"]
discontinued_medications                                          ["aspirin"]
neither_medications                                 ["ibuprofen", "Naprosyn"]
Name: 13, dtype: object
{'aspirin': 'discontinued', 'ibuprofen': 'neither', 'Naprosyn': 'neither', 'Tylenol': 'active', 'Plavix': 'active'}

Her aspirin (81 mg q.d.) is discontinued, and the patient is advised that she needs to avoid ibuprofen, Naprosyn, alcohol, caffeine, and chocolate. She is advised that Tylenol 325 mg or Tylenol ES (500 mg) is safe to take at 1 or 2 q.4-6h. p.r.n. for pain or fever. Discharge activity is without restriction. DISCHARGE MEDICATIONS: 1. Plavix 75 mg p.o. q.d.
aspirin
	 her aspirin (81 mg q.d.) is discontinued

ibuprofen
	 avoid ibuprofen

Naprosyn
	 avoid naprosyn

Tyl