In [None]:
%load_ext autoreload
%autoreload 2
import os
import matplotlib.pyplot as plt
import seaborn as sns
from os.path import join
from tqdm import tqdm
import pandas as pd
import joblib
import sys
import clin.llm
import datasets
import time
import openai
# Path to a file holding the OpenAI API key; override with the OPENAI_KEY_PATH
# environment variable instead of editing a hardcoded per-user absolute path.
openai.api_key_path = os.environ.get('OPENAI_KEY_PATH', '/home/chansingh/.OPENAI_KEY')
import numpy as np
from typing import List

In [None]:
# Dataset: mitclinicalml/clinical-ie (Hugging Face Hub).
# It has 3 configurations: 'medication_status', 'medication_attr', 'coreference';
# each configuration has its own splits (this notebook uses 'validation' and 'test').
MODEL_NAME = 'gpt-4-0314'  # chat model used for every query below
llm = clin.llm.get_llm(MODEL_NAME)

## List medications task

In [None]:
# Load the 'medication_status' configuration and pool its validation and test
# splits into one frame. concat keeps each split's own index labels, so labels
# can repeat; downstream code addresses rows positionally with .iloc.
dset = datasets.load_dataset('mitclinicalml/clinical-ie', 'medication_status')
val, test = (pd.DataFrame.from_dict(dset[split_name]) for split_name in ('validation', 'test'))
df = pd.concat([val, test])

def list_medications(row, seed: int = 13) -> str:
    """Render a row's gold medication labels as a bulleted answer string.

    Reads the 'active_medications', 'discontinued_medications' and
    'neither_medications' fields of `row`. Each is parsed as a bracketed,
    comma-separated string such as '[aspirin, lasix]'. This format is inferred
    from the parsing below; confirm it against the dataset.

    Emits one '- <med> (<status>)' line per medication. The lines are shuffled
    with a freshly seeded RNG, so the order does not simply follow the status
    grouping but is reproducible across calls.

    Args:
        row: mapping-like object (e.g. a pandas Series or dict) with the
            three fields above.
        seed: seed for the shuffle. The default 13 reproduces the original
            ordering.

    Returns:
        Newline-joined bullet lines, or '' when the row lists no medications.
    """
    def str_to_list(s: str) -> List[str]:
        # '[a, b]' -> ['a', 'b']; '[]' / '' -> [].
        # Blank items (e.g. from a trailing comma) are dropped so they never
        # turn into empty bullets.
        items = s.replace('[', '').replace(']', '').split(',')
        return [item.strip() for item in items if item.strip()]

    labeled = (
        [('active', med) for med in str_to_list(row['active_medications'])]
        + [('discontinued', med) for med in str_to_list(row['discontinued_medications'])]
        + [('neither', med) for med in str_to_list(row['neither_medications'])]
    )
    if not labeled:
        # Previously this produced a dangling '- ' bullet.
        return ''
    np.random.default_rng(seed=seed).shuffle(labeled)
    return '\n'.join(f'- {med} ({status})' for status, med in labeled)
# print(list_medications(row))

In [None]:
_MED_LIST_INSTRUCTION = ('Create a bulleted list of which medications are mentioned '
                         'and whether they are active, discontinued, or neither.')


def get_multishot_prompt(df, examples_nums_shot: List[int], ex_num: int,
                         instruction: str = _MED_LIST_INSTRUCTION) -> str:
    """Build a few-shot prompt for the medication-status listing task.

    Each demonstration is a patient note, the instruction, and its gold answer
    from `list_medications`. The prompt ends with the target note and the
    instruction, followed by a bare '-' that starts the model's bulleted answer.

    Args:
        df: DataFrame with a 'snippet' column, indexed positionally (.iloc).
        examples_nums_shot: positional indices of the demonstration rows, in
            prompt order.
        ex_num: positional index of the row to query.
        instruction: task sentence placed after every note. It is shared by
            the demonstrations and the query so the two cannot drift apart.

    Returns:
        The full prompt string.
    """
    def note_and_instruction(row) -> str:
        return f"Patient note: {row['snippet']}\n\n{instruction}\n\n"

    parts = [
        note_and_instruction(df.iloc[ex]) + list_medications(df.iloc[ex]) + '\n\n'
        for ex in examples_nums_shot
    ]
    parts.append(note_and_instruction(df.iloc[ex_num]) + '-')
    return ''.join(parts)

# """
# - "Kadian" (active)
# -"Dilaudid" (discontinued)
# -"Levaquin" (active)
# """

# Few-shot evaluation loop. Visit every example in a seeded random order.
# Each example is prompted with the `n_shots` examples that precede it in that
# order, wrapping around cyclically for the first few.
nums = np.arange(len(df)).tolist()
np.random.default_rng(seed=13).shuffle(nums)
n_shots = 1
max_retries = 5  # give up on one example after this many failed LLM calls
assert n_shots < len(nums), 'need more examples than shots so a target never demonstrates itself'

responses = []  # one entry per position in `nums`; None if every attempt failed
response = None
for i in tqdm(range(len(nums))):
    # The n_shots positions immediately before i, modulo len(nums), in order.
    examples_nums_shot = [nums[(i - k) % len(nums)] for k in range(n_shots, 0, -1)]
    ex_num = nums[i]
    prompt = get_multishot_prompt(df, examples_nums_shot, ex_num)
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ]

    # Retry both raised errors and None results, as the original loop did, but
    # boundedly. Catching Exception (not a bare except) lets KeyboardInterrupt
    # stop the run.
    response = None
    for attempt in range(max_retries):
        try:
            response = llm(messages)
        except Exception as e:
            tqdm.write(f'example {ex_num}: attempt {attempt + 1}/{max_retries} raised {e!r}')
        if response is not None:
            break
        time.sleep(2 ** attempt)  # exponential backoff, e.g. for rate limits
    else:
        tqdm.write(f'example {ex_num}: giving up after {max_retries} attempts')
    responses.append(response)

In [None]:
# Show the model's completion for the last processed example. The print stays
# inside the guard so a missing response never prints a stale or undefined
# `response_text`.
if response is not None:
    response_text = response['choices'][0]['message']['content']
    print(response_text)