In [None]:
%load_ext autoreload
%autoreload 2
import os
import matplotlib.pyplot as plt
import seaborn as sns
from os.path import join
from tqdm import tqdm
import pandas as pd
import joblib
import sys
import clin.llm
import datasets
import time
import openai
# Location of the file holding the OpenAI API key. It can be overridden with the
# OPENAI_API_KEY_PATH environment variable; otherwise it defaults to ~/.OPENAI_KEY
# in the current user's home directory, so the notebook is not tied to one
# specific user's absolute path.
openai.api_key_path = os.environ.get(
    'OPENAI_API_KEY_PATH', os.path.expanduser('~/.OPENAI_KEY')
)
import numpy as np
from typing import List, Dict
import clin.prompts
import clin.eval
from collections import defaultdict

In [None]:
# Dataset: mitclinicalml/clinical-ie
# It offers 3 subsets, selected by name when loading: 'medication_status', 'medication_attr', 'coreference'
# Model handle used for every query below; it is called with a list of chat messages.
llm = clin.llm.get_llm('gpt-4-0314')

## List medications task

In [None]:
# Load the medication-status subset and stack its validation and test splits
# (validation rows first) into a single evaluation frame.
dset = datasets.load_dataset('mitclinicalml/clinical-ie', 'medication_status')
val, test = (
    pd.DataFrame.from_dict(dset[split_name])
    for split_name in ('validation', 'test')
)
df = pd.concat((val, test))

In [None]:
# Few-shot querying: walk through every row of `df` in a fixed shuffled order and
# prompt the model with the `n_shots` rows that precede it in that order (the
# first few rows borrow their missing predecessors from the end of the order).
# The text of each reply is collected in `resps`, aligned with `nums`.
nums = np.arange(len(df)).tolist()
np.random.default_rng(seed=13).shuffle(nums)  # fixed seed -> same order on every run
n_shots = 5
max_retries = 50  # consecutive failures tolerated per example before giving up
resps = []
for i in tqdm(range(len(nums))):
    if i - n_shots < 0:
        # Fewer than n_shots predecessors: wrap around to the tail of the order.
        examples_nums_shot = nums[i - n_shots:] + nums[:i]
    else:
        examples_nums_shot = nums[i - n_shots: i]
    ex_num = nums[i]
    prompt = clin.prompts.get_multishot_prompt(df, examples_nums_shot, ex_num)

    # The messages do not change between retries, so build them once per example.
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ]

    # Retry failed calls, pausing one second after each error. Only `Exception`
    # is caught so Ctrl-C / kernel interrupt still stops the cell, and the number
    # of attempts is capped so a permanent failure surfaces instead of looping
    # forever in silence.
    response, last_err = None, None
    for attempt in range(1, max_retries + 1):
        try:
            response = llm(messages)
        except Exception as err:
            last_err = err
            tqdm.write(f'example {ex_num}: attempt {attempt}/{max_retries} failed: {err!r}')
            time.sleep(1)
        if response is not None:
            break
    if response is None:
        raise RuntimeError(
            f'no response for example {ex_num} after {max_retries} attempts'
        ) from last_err
    resps.append(response['choices'][0]['message']['content'])

In [None]:
# Put the rows in the order the model was queried (`nums`) and attach the replies.
# `assign` builds a new frame, so no column is written into the result of an
# indexing operation (which triggers pandas' SettingWithCopyWarning) and the
# cell can be re-run safely.
dfe = df.iloc[nums].assign(resps=resps)

In [None]:
# Preview the first rows of the evaluation frame, including the new 'resps' column.
dfe.head()

In [59]:
# Score every response: parse it with clin.eval, evaluate it against its row,
# and collect the per-example precision and recall in `mets_dict`.
mets_dict = defaultdict(list)
for row_pos in range(len(dfe)):
    row = dfe.iloc[row_pos]
    meds_pred = clin.eval.parse_response_medication_list(row['resps'])
    mets = clin.eval.eval_med_extraction(meds_pred, row)
    for metric in ('precision', 'recall'):
        mets_dict[metric].append(mets[metric])

['acyclovir', 'bactrim', 'cyclosporin', 'gatifloxacin', 'steroids']
['acyclovir', 'bactrim', 'cyclosporin', 'gatifloxacin', 'systemic steroids']

['levothyroxine']
['levothyroxine']

['azithromycin', 'ceftriaxone', 'clindamycin']
['azithromycin', 'ceftriaxone', 'clindamycin']

['haldol', 'lipitor', 'zocor']
['haldol', 'lipitor', 'zocor']

['ec asa', 'lasix', 'lisinopril', 'metoprol', 'ns']
['ec asa', 'lasix', 'lisinopril', 'metoprol']

['cozaar', 'darvocet', 'iv']
['cozaar', 'darvocet']

['levaquin', 'zithromax']
['levaquin', 'zithromax']

['fentanyl patch', 'gemzar chemotherapy', 'oxycontin ir', 'pca', 'percocet']
['fentanyl', 'gemzar', 'oxycontin ir', 'percocet']

['progesterone shots']
['progesterone']

['vancomycin', 'zosyn']
['vancomycin', 'vancomycin', 'zosyn']

['p.o. protonix', 'protonix drip', 'sliding scale']
['protonix']

['inderal-s', 'lactulose', 'rifampin']
['inderal-s', 'lactulose', 'rifampin']

['humalog 75/25 insulin', 'lantus insulin', 'metformin', 'multivitamin table

In [None]:
# Report recall and precision averaged over all evaluated examples.
recall_mean = np.mean(mets_dict["recall"])
precision_mean = np.mean(mets_dict["precision"])
print(f'recall {recall_mean:.3f} precision {precision_mean:.3f}')