In [None]:
%load_ext autoreload
%autoreload 2
import os
import matplotlib.pyplot as plt
import seaborn as sns
from os.path import join
from tqdm import tqdm
import pandas as pd
import imodelsx.process_results
import sys
import datasets
import numpy as np
import clin.eval

# Make the sibling experiments directory importable. The membership check
# keeps this cell idempotent: re-running it no longer stacks duplicate
# entries onto sys.path.
experiments_dir = '../experiments/'
if experiments_dir not in sys.path:
    sys.path.append(experiments_dir)

# Directory passed to the results loader in the next cell.
results_dir = '../results/'

In [None]:
# Load the experiment results table from `results_dir`.
# NOTE(review): `use_cached=True` presumably reuses a previously built cache
# instead of re-reading every run -- confirm against imodelsx.
r = imodelsx.process_results.get_results_df(results_dir, use_cached=True)

# get data for eval
dset = datasets.load_dataset('mitclinicalml/clinical-ie', 'medication_status')
# NOTE(review): df_val is not referenced again in the cells visible here.
df_val = pd.DataFrame.from_dict(dset['validation'])
df = pd.DataFrame.from_dict(dset['test'])
# df = pd.concat([val, test])
# Reorder the test rows with a fixed-seed shuffle of their positions.
# Later cells pair row['resps'][i] with dfe.iloc[i] by position, so this
# order presumably has to match the one used when the results in `r` were
# generated -- do not change the seed or the shuffle call without confirming.
nums = np.arange(len(df)).tolist()
np.random.default_rng(seed=13).shuffle(nums)
dfe = df.iloc[nums]

# add medication status eval
# The two returned sequences are stored as new columns on `r` (mutated in
# place); presumably there is one value per row of `r` -- confirm.
accs_cond, f1s_macro_cond = clin.eval.eval_medication_status(dfe, r)
r['acc_cond'] = accs_cond
r['f1_macro_cond'] = f1s_macro_cond

In [None]:
# Average each metric over all rows that share a (checkpoint, n_shots)
# setting, then display the table with 3-decimal formatting and a 'Blues'
# background gradient. The styled table must stay the last expression so
# the notebook renders it.
metric_columns = ['f1', 'recall', 'precision', 'acc_cond', 'f1_macro_cond']
mean_metrics = r.groupby(['checkpoint', 'n_shots'])[metric_columns].mean()
mean_metrics.style.format(precision=3).background_gradient(cmap='Blues')

# Test verifiers

In [232]:
import clin.verifiers.evidence
import clin.llm
import clin.parse
from collections import defaultdict
import openai
# File that holds the OpenAI API key. Set the OPENAI_API_KEY_PATH environment
# variable to point somewhere else; the original author's location is kept as
# the fallback so existing setups behave exactly as before.
openai.api_key_path = os.environ.get(
    'OPENAI_API_KEY_PATH', '/home/chansingh/.OPENAI_KEY'
)

In [238]:
# Pick the first results row for the 5-shot text-davinci-003 run.
is_target_run = (r.n_shots == 5) & (r.checkpoint == 'text-davinci-003')
row = r[is_target_run].iloc[0]

llm = clin.llm.get_llm('text-davinci-003')
# NOTE(review): the comment that used to sit here ("original run was") was
# left unfinished, so the verifier settings of the original run are not
# recorded in this notebook.
ev = clin.verifiers.evidence.EvidenceVerifier(n_shots_neg=2, n_shots_pos=0)
resps = row['resps']

# One entry per response, in the same order as `resps`.
med_status_dicts, med_evidence_dicts = [], []
for idx in tqdm(range(len(resps))):
    # Response idx is paired with the idx-th row of the shuffled eval frame.
    verified = ev(
        snippet=dfe.iloc[idx]['snippet'],
        bulleted_str=resps[idx],
        llm=llm,
    )
    med_status_dicts.append(verified[0])
    med_evidence_dicts.append(verified[1])

100%|██████████| 100/100 [00:00<00:00, 8131.65it/s]


In [239]:
# For every example, keep only the medications whose verifier evidence is
# something other than the literal string 'no evidence'. The evidence dict is
# looked up by position and by key, so a missing entry still raises.
med_status_dict_pruned = [
    {
        med: status
        for med, status in status_dict.items()
        if med_evidence_dicts[idx][med] != 'no evidence'
    }
    for idx, status_dict in enumerate(med_status_dicts)
]

In [240]:
# Print the mean of every metric, first for the raw verifier output and then
# for the evidence-pruned predictions, with a blank line between the two
# reports. `mets_dict` ends up holding the metrics of the pruned predictions.
for position, status_dicts in enumerate((med_status_dicts, med_status_dict_pruned)):
    if position > 0:
        print()
    mets_dict = clin.eval.calculate_metrics(status_dicts, dfe)
    for metric_name, metric_values in mets_dict.items():
        print(metric_name, np.mean(metric_values))

precision 0.9245238095238095
recall 0.9244761904761905
f1 0.9179288489288492

precision 0.9278571428571429
recall 0.9244761904761905
f1 0.9199288489288491
