In [29]:
%load_ext autoreload
%autoreload 2
import os
import matplotlib.pyplot as plt
import seaborn as sns
from os.path import join
from tqdm import tqdm
import pandas as pd
import joblib
import sys
import clin.llm
import datasets
import time
import openai
openai.api_key_path = '/home/chansingh/.OPENAI_KEY'
import numpy as np
from typing import List, Dict
import clin.prompts
from collections import defaultdict

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
# dataset: mitclinicalml/clinical-ie
# 3 splits here: 'medication_status', 'medication_attr', 'coreference
llm = clin.llm.get_llm('gpt-4-0314')

## List medications task

In [None]:
dset = datasets.load_dataset('mitclinicalml/clinical-ie', 'medication_status')
val = pd.DataFrame.from_dict(dset['validation'])
test = pd.DataFrame.from_dict(dset['test'])
df = pd.concat([val, test])

In [None]:
df.iloc[ex_num]

In [38]:
nums = np.arange(len(df)).tolist()
np.random.default_rng(seed=13).shuffle(nums)
n_shots = 5
resps = []
for i in tqdm(range(len(nums))):
    # print(i)
    if i - n_shots < 0:
        examples_nums_shot = nums[i - n_shots:] + nums[:i]
    else:
        examples_nums_shot = nums[i - n_shots: i]
    ex_num = nums[i]
    prompt = clin.prompts.get_multishot_prompt(df, examples_nums_shot, ex_num)

    response = None
    while response is None:
        try:
            messages = [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": prompt},
                # {"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."},
                # {"role": "user", "content": "Where was it played?"}
            ]
            response = llm(messages)
        except:
            time.sleep(1)
    resps.append(response['choices'][0]['message']['content'])

100%|██████████| 105/105 [00:00<00:00, 1136.95it/s]

cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!
cached!





In [39]:
dfe = df.iloc[nums]
dfe['resps'] = resps

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  dfe['resps'] = resps


In [40]:
dfe.head()

Unnamed: 0,index,snippet,active_medications,discontinued_medications,neither_medications,resps
12,17,_%#NAME#%_ is asplenic so gatifloxacin was ini...,"[""gatifloxacin"", ""acyclovir"", ""Bactrim""]","[""systemic steroids"", ""cyclosporin""]",[],"""gatifloxacin"" (active)\n- ""acyclovir"" (active..."
28,33,The patient was started on levothyroxine 0.15 ...,[],"[""levothyroxine""]",[],"""levothyroxine"" (discontinued)"
59,64,PHYSICAL EXAMINATION: VITAL SIGNS: Temperature...,"[""clindamycin""]","[""azithromycin"", ""ceftriaxone""]",[],"""azithromycin"" (discontinued)\n- ""clindamycin""..."
18,23,Her Lipitor has been discontinued after admiss...,"[""Zocor"", ""Haldol""]","[""Lipitor""]",[],"""Lipitor"" (discontinued)\n- ""Zocor"" (active)\n..."
34,39,2. Minimal fluid overload. 3. Acute on chronic...,"[""metoprol"", ""EC ASA""]","[""lisinopril"", ""lasix""]",[],"""NS 75 cc/h"" (active for 6 hours only)\n- ""las..."


In [41]:


def parse_response_list(s: str) -> Dict[str, str]:
    """
    "Gatifloxacin" (initiated)
    - "Acyclovir" (prophylactic therapy through day 100)
    - "Bactrim" (active for PCP prophylaxis)
    - "Systemic steroids" (weaned)
    - "Cyclosporin" (discontinued)

    -> 

    {
        'Gatifloxacin': 'initiated',
        'Acyclovir': 'prophylactic therapy through day 100',
        'Bactrim': 'active for PCP prophylaxis',
        'Systemic steroids': 'weaned',
        'Cyclosporin': 'discontinued',
    }
    """
    s = s.replace('- ', '')
    s_list = s.split('\n')
    med_status_dict = {}
    for i, s in enumerate(s_list):
        # find second occurence of "
        idx = s.find('"', s.find('"') + 1)
        medication = s[:idx].strip('"')
        status = s[idx + 1:].strip().strip('()')
        med_status_dict[medication] = status
    return med_status_dict

def eval_med_extraction(med_status_dict: Dict[str, str], df_row: pd.Series) -> List[bool]:
    """
    Given a dictionary of medication status, and a row of the dataframe,
    return precision and recall
    """
    meds_retrieved = list(med_status_dict.keys())
    meds_true = clin.prompts.str_to_list(df_row['active_medications']) + \
        clin.prompts.str_to_list(df_row['discontinued_medications']) + \
        clin.prompts.str_to_list(df_row['neither_medications'])

    # clean up
    meds_retrieved = [med.strip(' "').lower() for med in meds_retrieved]
    meds_true = [med.strip(' "').lower() for med in meds_true]
    print(sorted(meds_retrieved))
    print(sorted(meds_true))
    print()

    # compute precision and recall
    precision = len(set(meds_retrieved).intersection(set(meds_true))) / len(meds_retrieved)
    recall = len(set(meds_retrieved).intersection(set(meds_true))) / len(meds_true)
    return {
        'precision': precision,
        'recall': recall,
    }

mets_dict = defaultdict(list)
for i in range(len(dfe)):
    # print(i)
    mets = eval_med_extraction(parse_response_list(dfe.iloc[i]['resps']), dfe.iloc[i])
    for k in ['precision', 'recall']:
        mets_dict[k].append(mets[k])
# for resp in dfe['resps'][:3]:
    # print(resp, end='\n\n')
    # print(parse_response_list(resp))
    # eval_med_extraction(parse_response_list(resp), dfe.iloc[0])

['acyclovir', 'bactrim', 'cyclosporin', 'gatifloxacin', 'steroids']
['acyclovir', 'bactrim', 'cyclosporin', 'gatifloxacin', 'systemic steroids']

['levothyroxine']
['levothyroxine']

['azithromycin', 'ceftriaxone', 'clindamycin']
['azithromycin', 'ceftriaxone', 'clindamycin']

['haldol', 'lipitor', 'zocor']
['haldol', 'lipitor', 'zocor']

['ec asa 81 mg', 'lasix', 'lisinopril', 'metoprol', 'ns 75 cc/h']
['ec asa', 'lasix', 'lisinopril', 'metoprol']

['cozaar', 'darvocet', 'iv']
['cozaar', 'darvocet']

['levaquin', 'zithromax']
['levaquin', 'zithromax']

['fentanyl patch', 'gemzar chemotherapy', 'oxycontin ir', 'pca', 'percocet']
['fentanyl', 'gemzar', 'oxycontin ir', 'percocet']

['progesterone shots']
['progesterone']

['vancomycin', 'zosyn']
['vancomycin', 'vancomycin', 'zosyn']

['p.o. protonix', 'protonix drip', 'sliding scale']
['protonix']

['inderal-s', 'lactulose', 'rifampin']
['inderal-s', 'lactulose', 'rifampin']

['humalog 75/25 insulin', 'lantus insulin', 'metformin', 'mult

In [47]:
print(f'recall {np.mean(mets_dict["recall"]):.3f} precision {np.mean(mets_dict["precision"]):.3f}')

recall 0.877 precision 0.861
