## utils

In [1]:
import pandas as pd
import dill
import openai
import tiktoken
from tqdm import tqdm, trange

# Chat model used by call_api() for every request in this notebook.
api_model: str = 'gpt-4o'

# tiktoken encoding matching api_model; devide_list() uses it to count note
# tokens with the same tokenizer the model uses.
encoding = tiktoken.encoding_for_model(api_model)

def call_api(content):
    """
    Send one user message to the chat model (temperature 0) and return its reply.
    Uses the legacy module-level `openai.ChatCompletion` interface.
    :param content: prompt, str
    :return: text of the first choice in the response, str
    """
    response = openai.ChatCompletion.create(
        model=api_model,
        messages=[{"role": "user", "content": content}],
        temperature=0,
    )
    first_choice = response.get("choices")[0]
    return first_choice["message"]["content"]


def get_msg(content, max_retries=10, base_delay=1.0, max_delay=60.0):
    """
    Get a response from the OpenAI API, retrying failed calls with exponential backoff.
    :param content: prompt, str
    :param max_retries: number of retries after the first failure; None retries forever
    :param base_delay: seconds to wait before the first retry (doubles each retry)
    :param max_delay: upper bound in seconds on a single wait
    :return: model response, str
    :raises Exception: the last error from call_api() once max_retries is exhausted
    """
    import time

    # Iterative loop instead of recursion: the recursive version re-called the
    # API immediately (no backoff) and could only stop via RecursionError.
    attempt = 0
    while True:
        try:
            return call_api(content)
        except Exception as e:  # any API/network error is treated as retryable
            attempt += 1
            if max_retries is not None and attempt > max_retries:
                raise
            delay = min(max_delay, base_delay * 2 ** (attempt - 1))
            print(f'get_msg: API call failed ({e!r}); retry {attempt} in {delay:.1f}s')
            time.sleep(delay)


def simplify_note(note):
    """
    Ask the model to condense a (chunk of a) discharge summary into five sections:
    history of present illness, past medical history, allergies, medications on
    admission and discharge medications, following a fixed output template.
    :param note: discharge-summary text, str
    :return: model response, str
    """
    instruction = (
        "Please summarize specific sections from a patient's discharge summary: "
        "1. HISTORY OF PRESENT ILLNESS, 2. PAST MEDICAL HISTORY, 3. ALLERGIES, "
        "4. MEDICATIONS ON ADMISSION 5.DISCHARGE MEDICATIONS. "
        "Ignore other details while in hospital and focus only on these sections.\n"
    )
    brief = '(Language summary as short as possible)'
    drug_list = '(A series of drug names, separated by commas, remove dosage information. Maybe None.)'
    template_lines = [
        'output template:',
        'HISTORY OF PRESENT ILLNESS:', brief,
        'PAST MEDICAL HISTORY:', brief,
        'ALLERGIES:', '(A series of allergies names, separated by commas, does not require any other information)',
        'MEDICATIONS ON ADMISSION:', drug_list,
        'DISCHARGE MEDICATIONS:', drug_list,
    ]
    prompt = (
        instruction
        + ''.join(line + '\n' for line in template_lines)
        + 'Note:' + note + '\n'
        + 'Summarize result in five aspects in a concise paragraph without any other words:\n'
    )
    return get_msg(prompt)


def split_string(s, splitted_num):
    """
    Split `s` into about `splitted_num` consecutive pieces at natural boundaries.

    Cut targets are evenly spaced character offsets. At each target the cut is
    placed just after the first '.' or '\\n' at or after it. If a paragraph break
    ('\\n\\n') starts less than 200 characters from that point, the cut is placed
    after the first '\\n' of the break instead. Searches never start before the
    end of the previous piece, so pieces do not overlap.

    :param s: text to split, str
    :param splitted_num: desired number of pieces (values <= 1 mean "do not split")
    :return: list of non-empty str whose concatenation equals `s`. There are fewer
             than `splitted_num` pieces when no boundary is left. Returns `[s]`
             when `s` is empty.
    """
    if splitted_num <= 1 or not s:
        return [s]

    length = len(s)
    pieces = []
    start = 0
    for i in range(1, splitted_num):
        # Never search inside text already emitted (avoids duplicated pieces).
        index = max(i * length // splitted_num, start)
        # str.find returns -1 for "not found"; it must not win the min().
        found = [pos for pos in (s.find('.', index), s.find('\n', index)) if pos != -1]
        if not found:
            break  # no boundary left: the remainder becomes the last piece
        end = min(found)
        # Prefer a nearby paragraph break.
        paragraph_end = s.find('\n\n', index)
        if paragraph_end != -1 and abs(paragraph_end - end) < 200:
            end = paragraph_end
        pieces.append(s[start:end + 1])
        start = end + 1
    pieces.append(s[start:])

    # A cut at the very end of `s` leaves an empty remainder; drop it.
    return [piece for piece in pieces if piece]

def devide_list(origin_text_list, max_tokens=3800):
    """
    Split texts until every piece is shorter than `max_tokens` tokens.

    Oversized texts are cut with split_string() into roughly equal parts. The
    pass repeats until the list stops changing. The misspelled name is kept
    for existing callers.

    :param origin_text_list: texts to split, list of str
    :param max_tokens: exclusive upper bound on tokens per piece, counted with the
                       module-level tiktoken `encoding`; default 3800
    :return: list of str pieces in original order. A text that split_string()
             cannot cut any further is returned as-is, even if still too long.
    """
    current = list(origin_text_list)
    while True:
        new_text_list = []
        for text in current:
            n_tokens = len(encoding.encode(text))  # encode once per text
            if n_tokens < max_tokens:
                new_text_list.append(text)
            else:
                splitted_num = n_tokens // max_tokens + 1
                # Drop empty pieces. split_string() can emit '' when a cut falls
                # at the end of the text. Re-adding those on every pass keeps the
                # list changing forever, so the loop below would never end.
                new_text_list.extend(piece for piece in split_string(text, splitted_num) if piece)
        if new_text_list == current:
            return new_text_list
        current = new_text_list

def check_note(note):
    """
    Tell whether a simplified note contains the five template headings in order.

    Headings are matched case-insensitively at their first occurrence.
    :param note: model summary, str
    :return: True if every heading is present and they appear in template order, else False
    """
    headings = (
        'HISTORY OF PRESENT ILLNESS',
        'PAST MEDICAL HISTORY',
        'ALLERGIES',
        'MEDICATIONS ON ADMISSION',
        'DISCHARGE MEDICATIONS',
    )
    upper_note = note.upper()
    positions = [upper_note.find(heading) for heading in headings]
    if -1 in positions:
        return False
    return all(earlier <= later for earlier, later in zip(positions, positions[1:]))

def _simplify_until_valid(text, max_attempts=10):
    """
    Summarize one note chunk, retrying until check_note() accepts the result.
    :param text: note chunk, str
    :param max_attempts: maximum number of simplify_note() calls (at least one is made)
    :return: the first summary that passes check_note(), otherwise the last attempt
             (such notes are left for the manual-correction step)
    """
    note = simplify_note(text)
    for _ in range(max_attempts - 1):
        if check_note(note):
            break
        # One new call per failed attempt. Every result is checked before the
        # next call; none is discarded unchecked.
        note = simplify_note(text)
    return note


def generate_note(row, max_attempts=10):
    """
    Summarize the discharge summary of one admission.

    The note text is first cut into token-bounded chunks with devide_list().
    Each chunk is then summarized independently.
    :param row: one noteevents row (pandas Series) with HADM_ID and TEXT
    :param max_attempts: simplify_note() calls allowed per chunk before the last,
                         possibly malformed, summary is kept
    :return: (hadm_id, list of summaries, one per chunk)
    """
    hadm_id = row['HADM_ID']
    chunks = devide_list([row['TEXT']])
    return hadm_id, [_simplify_until_valid(chunk, max_attempts) for chunk in chunks]

In [2]:
import os

# Prefer the key from the OPENAI_API_KEY environment variable so no secret has
# to be written into the notebook; the placeholder is only a fallback.
openai.api_key = os.environ.get('OPENAI_API_KEY', 'sk-xx')

data4LLM_path = 'data_process/output/mimic-iii/data4LLM.csv'
noteevents_path = 'data_process/input/mimic-iii/NOTEEVENTS.csv'

filter_noteevents_path = 'data_process/output/mimic-iii/noteevents_filtered.pkl'
simplified_note_path = 'data_process/output/mimic-iii/note_simp.pkl'
note_p1_path = 'data_process/output/mimic-iii/history_of_present_illness.pkl'
note_p2_path = 'data_process/output/mimic-iii/past_medical_history.pkl'
note_p3_path = 'data_process/output/mimic-iii/allergies.pkl'
note_p4_path = 'data_process/output/mimic-iii/med_on_adm.pkl'
note_content_path = 'data_process/output/mimic-iii/note_content.pkl'

data4LLM_with_note_path = 'data_process/output/mimic-iii/data4LLM_with_note.csv'

## generate noteevents_filtered.pkl

In [None]:
import pandas as pd
import dill

noteevents = pd.read_csv(noteevents_path)

data4LLM = pd.read_csv(data4LLM_path)

# show the statistics of the data and noteevents
print('data4LLM shape:', data4LLM.shape)

# Keep only discharge-summary reports of admissions that appear in data4LLM.
print('noteevents shape before filtering:', noteevents.shape)
keep = (
    noteevents['HADM_ID'].isin(data4LLM['HADM_ID'])
    & (noteevents['CATEGORY'] == 'Discharge summary')
    & (noteevents['DESCRIPTION'] == 'Report')
)
# Sort once, after filtering. Sorting the whole NOTEEVENTS table first is
# wasted work: a multi-column sort_values is stable, so the result is the same.
noteevents = (
    noteevents[keep]
    .sort_values(by=['SUBJECT_ID', 'HADM_ID', 'CHARTDATE', 'CHARTTIME'])
    .reset_index(drop=True)
)
print('noteevents shape after filtering:', noteevents.shape)

In [None]:
# check whether all hadm_id in data4LLM are in noteevents
def check_hadm_id_in_noteevents(noteevents, data4LLM):
    """
    Print how many (and which) admissions of data4LLM have no row in noteevents.
    :param noteevents: DataFrame with a HADM_ID column
    :param data4LLM: DataFrame with a HADM_ID column
    :return: None; results are only printed
    """
    missing = set(data4LLM['HADM_ID'].tolist()) - set(noteevents['HADM_ID'].tolist())
    print('num_hadm_id_not_in_noteevents:', len(missing))
    print('hadm_id in data4LLM but not in noteevents:', missing)

# Report admissions of data4LLM that have no discharge-summary report left in the filtered noteevents
check_hadm_id_in_noteevents(noteevents, data4LLM)

In [None]:
# check whether all hadm_id appears only once in noteevents
def check_hadm_id_appear_only_once(noteevents):
    """
    Print the number of distinct HADM_IDs and those that occur more than once.
    :param noteevents: DataFrame with a HADM_ID column
    :return: list of HADM_IDs occurring more than once, in value_counts order
    """
    counts = noteevents['HADM_ID'].value_counts()
    repeated = counts.loc[counts.gt(1)].index.tolist()
    distinct = noteevents['HADM_ID'].nunique()
    print('nunique_hadm_id:', distinct, '\t', len(repeated), 'hadm_ids appear more than once', repeated)
    return repeated

# List HADM_IDs with more than one discharge-summary report. The returned list
# is not referenced again below; duplicates are dropped in the next cell.
print('check whether all hadm_id appears only once in noteevents...')
hadm_id_appear_more_than_once = check_hadm_id_appear_only_once(noteevents)

In [None]:
# For HADM_IDs that appear more than once, keep the first row, i.e. the earliest
# CHARTDATE/CHARTTIME given the sort applied in the filtering cell.
print('noteevents shape before filtering:', noteevents.shape)
# .copy() so the HADM_ID assignment below writes to an independent frame
# rather than to a possible view of `noteevents`.
noteevents_filtered = noteevents.drop_duplicates(subset=['HADM_ID'], keep='first').copy()
print('noteevents shape after filtering:', noteevents_filtered.shape)
_ = check_hadm_id_appear_only_once(noteevents_filtered)
# Convert HADM_ID to int (presumably read_csv parsed it as float; confirm).
noteevents_filtered['HADM_ID'] = noteevents_filtered['HADM_ID'].astype(int)

with open(filter_noteevents_path, 'wb') as f:
    dill.dump(noteevents_filtered, f)

## split and simplify the note

In [None]:
with open(filter_noteevents_path, 'rb') as f:
    note_data = dill.load(f)

# One row per summarized chunk; an admission split by devide_list() gets several rows.
result_pd = pd.DataFrame(columns=['HADM_ID', 'NOTE'])

for _, row in tqdm(note_data.iterrows(), total=len(note_data)):
    hadm_id, note_list = generate_note(row)

    for note in note_list:
        result_pd.loc[len(result_pd)] = [hadm_id, note]

    # Checkpoint after every admission so partial results are on disk if the run
    # is interrupted; `with` closes (and flushes) the file each time.
    with open(simplified_note_path, 'wb') as f:
        dill.dump(result_pd, f)

## Manually modify notes that still do not meet the format requirements

In [None]:
with open(simplified_note_path, 'rb') as f:
    process_note = dill.load(f)

# input() reads a single line, so type the five sections on one line.
for i, row in process_note.iterrows():
    if not check_note(row['NOTE']):
        print(row['NOTE'])
        print('*********************************************')
        corrected = input('Please input the correct note: ')
        if not check_note(corrected):
            # Still stored, but the section-extraction cells rely on these headings.
            print('Warning: the corrected note is missing a required heading or has them out of order.')
        process_note.at[i, 'NOTE'] = corrected

with open(simplified_note_path, 'wb') as f:
    dill.dump(process_note, f)

## history of present illness

In [None]:
import itertools
import re


def _extract_section(note_part, heading, next_heading):
    """
    Return the stripped body of section `heading` in one simplified note.

    The body starts right after "<heading>:" and ends at the first `next_heading`
    after it, or at the end of the note. Headings are matched case-insensitively
    with a regex rather than `note_part.upper().find()`: upper() can change the
    string length for some characters (e.g. 'ß' -> 'SS'), which misaligns the
    slice indices. Returns '' when `heading` is missing; notes typed in the
    manual-fix cell are not guaranteed to pass check_note().
    """
    head = re.search(re.escape(heading), note_part, flags=re.IGNORECASE)
    if head is None:
        return ''
    start = head.end() + 1  # skip the ':' expected right after the heading
    tail = re.compile(re.escape(next_heading), re.IGNORECASE).search(note_part, start)
    end = tail.start() if tail is not None else len(note_part)
    return note_part[start:end].strip()


with open(simplified_note_path, 'rb') as f:
    final_note = dill.load(f)

# Chunks of one admission sit in consecutive rows. Join the section of each
# chunk into one string, separated by '  +  '.
rows = []
for hadm_id, group in itertools.groupby(zip(final_note['HADM_ID'], final_note['NOTE']),
                                        key=lambda pair: pair[0]):
    sections = [_extract_section(note_part, 'HISTORY OF PRESENT ILLNESS', 'PAST MEDICAL HISTORY')
                for _, note_part in group]
    rows.append([hadm_id, ' ' + '  +  '.join(sections)])
result_pd_1 = pd.DataFrame(rows, columns=['HADM_ID', 'CON_NOTE'])

history_of_present_illness_pd = pd.DataFrame(columns=['HADM_ID', 'HISTORY OF PRESENT ILLNESS'])

for hadm_id, note in tqdm(zip(result_pd_1['HADM_ID'], result_pd_1['CON_NOTE']), total=len(result_pd_1)):
    prompt = '''
I'll provide you with an input containing the history of present illness for a patient. Your task is to:
1.Retain the descriptions of the patient's history of present illness before admission and on admission, while removing all descriptions after admission and at discharge.
2.Consolidate the text to produce a concise output.

input: ''' + note + '''

You only need to answer the refined results, no other explanation is needed!

output:
'''
    result = get_msg(prompt)

    history_of_present_illness_pd.loc[len(history_of_present_illness_pd)] = [hadm_id, result]

    # Checkpoint after every admission; `with` closes (and flushes) the file each time.
    with open(note_p1_path, 'wb') as f:
        dill.dump(history_of_present_illness_pd, f)

## past medical history

In [None]:
import itertools
import re


def _extract_section(note_part, heading, next_heading):
    """
    Return the stripped body of section `heading` in one simplified note.

    The body starts right after "<heading>:" and ends at the first `next_heading`
    after it, or at the end of the note. Headings are matched case-insensitively
    with a regex rather than `note_part.upper().find()`: upper() can change the
    string length for some characters (e.g. 'ß' -> 'SS'), which misaligns the
    slice indices. Returns '' when `heading` is missing; notes typed in the
    manual-fix cell are not guaranteed to pass check_note().
    """
    head = re.search(re.escape(heading), note_part, flags=re.IGNORECASE)
    if head is None:
        return ''
    start = head.end() + 1  # skip the ':' expected right after the heading
    tail = re.compile(re.escape(next_heading), re.IGNORECASE).search(note_part, start)
    end = tail.start() if tail is not None else len(note_part)
    return note_part[start:end].strip()


with open(simplified_note_path, 'rb') as f:
    final_note = dill.load(f)

# Chunks of one admission sit in consecutive rows. Join the section of each
# chunk into one string, separated by '  +  '.
rows = []
for hadm_id, group in itertools.groupby(zip(final_note['HADM_ID'], final_note['NOTE']),
                                        key=lambda pair: pair[0]):
    sections = [_extract_section(note_part, 'PAST MEDICAL HISTORY', 'ALLERGIES')
                for _, note_part in group]
    rows.append([hadm_id, ' ' + '  +  '.join(sections)])
result_pd_2 = pd.DataFrame(rows, columns=['HADM_ID', 'CON_NOTE'])

past_medical_history_pd = pd.DataFrame(columns=['HADM_ID', 'PAST MEDICAL HISTORY'])

for hadm_id, note in tqdm(zip(result_pd_2['HADM_ID'], result_pd_2['CON_NOTE']), total=len(result_pd_2)):
    prompt = '''
I'll provide you with input containing a patient's past medical history. I need you to consolidate the text and output a concise summary.

input: ''' + note + '''

You only need to answer the refined results, no other explanation is needed!

output:
'''
    result = get_msg(prompt)

    past_medical_history_pd.loc[len(past_medical_history_pd)] = [hadm_id, result]

    # Checkpoint after every admission; `with` closes (and flushes) the file each time.
    with open(note_p2_path, 'wb') as f:
        dill.dump(past_medical_history_pd, f)

## allergies

In [None]:
import itertools
import re


def _extract_section(note_part, heading, next_heading):
    """
    Return the stripped body of section `heading` in one simplified note.

    The body starts right after "<heading>:" and ends at the first `next_heading`
    after it, or at the end of the note. Headings are matched case-insensitively
    with a regex rather than `note_part.upper().find()`: upper() can change the
    string length for some characters (e.g. 'ß' -> 'SS'), which misaligns the
    slice indices. Returns '' when `heading` is missing; notes typed in the
    manual-fix cell are not guaranteed to pass check_note().
    """
    head = re.search(re.escape(heading), note_part, flags=re.IGNORECASE)
    if head is None:
        return ''
    start = head.end() + 1  # skip the ':' expected right after the heading
    tail = re.compile(re.escape(next_heading), re.IGNORECASE).search(note_part, start)
    end = tail.start() if tail is not None else len(note_part)
    return note_part[start:end].strip()


with open(simplified_note_path, 'rb') as f:
    final_note = dill.load(f)

# Chunks of one admission sit in consecutive rows. Join the section of each
# chunk into one string, separated by '  +  ', as the prompt below describes.
# NOTE(review): the first 'ALLERGIES' anywhere in the note is used, so a mention
# of the word inside an earlier section would shift the start of this section.
rows = []
for hadm_id, group in itertools.groupby(zip(final_note['HADM_ID'], final_note['NOTE']),
                                        key=lambda pair: pair[0]):
    sections = [_extract_section(note_part, 'ALLERGIES', 'MEDICATIONS ON ADMISSION')
                for _, note_part in group]
    rows.append([hadm_id, ' ' + '  +  '.join(sections)])
result_pd_3 = pd.DataFrame(rows, columns=['HADM_ID', 'CON_NOTE'])

allergies_pd = pd.DataFrame(columns=['HADM_ID', 'ALLERGIES'])

for hadm_id, note in tqdm(zip(result_pd_3['HADM_ID'], result_pd_3['CON_NOTE']), total=len(result_pd_3)):
    # NOTE(review): the prompt asks only for *drug* names, so non-drug allergies are not requested.
    prompt = '''
I'm going to give you an input, which is a bunch of text and some plus signs. I need you to extract all the drug names for me from each input, and output the corresponding list.

Here are some of the input and output sample:

input1:No Known Allergies to Drugs.  +  None mentioned.

output1:[]

input2:None mentioned.  +  The patient is allergic to cefazolin and penicillins.

output2:[cefazolin, penicillins]

Now you need to provide the corresponding output of input3, without any other words:

input3:''' + note + '''

You only need to output a list!

output3:
'''
    result = get_msg(prompt)

    allergies_pd.loc[len(allergies_pd)] = [hadm_id, result]

    # Checkpoint after every admission; `with` closes (and flushes) the file each time.
    with open(note_p3_path, 'wb') as f:
        dill.dump(allergies_pd, f)

## medications on admission (med_on_adm)

In [None]:
import itertools
import re


def _extract_section(note_part, heading, next_heading):
    """
    Return the stripped body of section `heading` in one simplified note.

    The body starts right after "<heading>:" and ends at the first `next_heading`
    after it, or at the end of the note. Headings are matched case-insensitively
    with a regex rather than `note_part.upper().find()`: upper() can change the
    string length for some characters (e.g. 'ß' -> 'SS'), which misaligns the
    slice indices. Returns '' when `heading` is missing; notes typed in the
    manual-fix cell are not guaranteed to pass check_note().
    """
    head = re.search(re.escape(heading), note_part, flags=re.IGNORECASE)
    if head is None:
        return ''
    start = head.end() + 1  # skip the ':' expected right after the heading
    tail = re.compile(re.escape(next_heading), re.IGNORECASE).search(note_part, start)
    end = tail.start() if tail is not None else len(note_part)
    return note_part[start:end].strip()


with open(simplified_note_path, 'rb') as f:
    final_note = dill.load(f)

# Chunks of one admission sit in consecutive rows. Join the section of each
# chunk into one string, separated by '  +  ', as the prompt below describes.
rows = []
for hadm_id, group in itertools.groupby(zip(final_note['HADM_ID'], final_note['NOTE']),
                                        key=lambda pair: pair[0]):
    sections = [_extract_section(note_part, 'MEDICATIONS ON ADMISSION', 'DISCHARGE MEDICATIONS')
                for _, note_part in group]
    rows.append([hadm_id, ' ' + '  +  '.join(sections)])
result_pd_4 = pd.DataFrame(rows, columns=['HADM_ID', 'CON_NOTE'])

med_on_adm_pd = pd.DataFrame(columns=['HADM_ID', 'MEDICATIONS ON ADMISSION'])

for hadm_id, note in tqdm(zip(result_pd_4['HADM_ID'], result_pd_4['CON_NOTE']), total=len(result_pd_4)):
    prompt = '''
I'm going to give you an input, which is a bunch of text and some plus signs. I need you to extract all the drug names for me from each input, and output the corresponding list.

Here are some of the input and output sample:

input1:None.  +   Nifedipine XL, Calcitriol, Lisinopril, Aspirin, Lasix, Glyburide, Clonidine, Zoloft, Simvastatin, Tums, Procrit, Lupron, Niferex.

output1:[Nifedipine XL, Calcitriol, Lisinopril, Aspirin, Lasix, Glyburide, Clonidine, Zoloft, Simvastatin, Tums, Procrit, Lupron, Niferex]

input2: The patient was taking Aspirin, Atovaquone, Levofloxacin  +  The patient was on multiple medications including Emtriva, Lisinoprol, Metoprolol, Stavudine.

output2:[Aspirin, Atovaquone, Levofloxacin, Emtriva, Lisinoprol, Metoprolol, Stavudine]

Now you need to provide the corresponding output of input3, without any other words:

input3:''' + note + '''

You only need to output a list!

output3:
'''
    result = get_msg(prompt)

    med_on_adm_pd.loc[len(med_on_adm_pd)] = [hadm_id, result]

    # Checkpoint after every admission; `with` closes (and flushes) the file each time.
    with open(note_p4_path, 'wb') as f:
        dill.dump(med_on_adm_pd, f)

## combine

In [None]:
def _load_pickle(path):
    """Load one dill pickle and close the file afterwards."""
    with open(path, 'rb') as f:
        return dill.load(f)


def _first_value_by_hadm(df):
    """
    Map HADM_ID -> value in df's second column, keeping the first row per HADM_ID.
    This matches the lookup `df[df.HADM_ID == hadm_id].iloc[0, 1]`.
    """
    lookup = {}
    for hadm_id, value in zip(df['HADM_ID'], df.iloc[:, 1]):
        lookup.setdefault(hadm_id, value)
    return lookup


history_of_present_illness_pd = _load_pickle(note_p1_path)
past_medical_history_pd = _load_pickle(note_p2_path)
allergies_pd = _load_pickle(note_p3_path)
med_on_adm_pd = _load_pickle(note_p4_path)

# Dict lookups replace a boolean-mask scan of each table per admission, which
# was O(n) per row and O(n^2) overall.
past_medical_history_by_hadm = _first_value_by_hadm(past_medical_history_pd)
allergies_by_hadm = _first_value_by_hadm(allergies_pd)
med_on_adm_by_hadm = _first_value_by_hadm(med_on_adm_pd)

rows = []
for hadm_id, history_note in zip(history_of_present_illness_pd['HADM_ID'],
                                 history_of_present_illness_pd.iloc[:, 1]):
    # KeyError here means a section table is missing this admission.
    note_content = ('History of present illness: ' + history_note
                    + ',\nPast medical history: ' + past_medical_history_by_hadm[hadm_id]
                    + ',\nAllergies: ' + allergies_by_hadm[hadm_id]
                    + ',\nMedications on admission: ' + med_on_adm_by_hadm[hadm_id])
    rows.append([hadm_id, note_content])
result_pd_5 = pd.DataFrame(rows, columns=['HADM_ID', 'NOTE_CONTENT'])

with open(note_content_path, 'wb') as f:
    dill.dump(result_pd_5, f)

## generate data4LLM_with_note.csv

In [16]:
original_data = pd.read_csv(data4LLM_path)
with open(note_content_path, 'rb') as f:
    note_content = dill.load(f)

# Map each HADM_ID to its NOTE_CONTENT, keeping the first row if an id repeats.
note_by_hadm = {}
for hadm_id, content in zip(note_content['HADM_ID'], note_content['NOTE_CONTENT']):
    note_by_hadm.setdefault(hadm_id, content)

# Keep the rows of original_data whose HADM_ID has a note, in their original
# order, and append that note as a new NOTE column. Selecting the rows directly
# keeps every column's dtype; rebuilding them from iterrows() does not,
# e.g. ints in an all-numeric row become floats.
has_note = [hadm_id in note_by_hadm for hadm_id in original_data['HADM_ID']]
data4LLM_with_note = original_data.loc[has_note].copy()
data4LLM_with_note['NOTE'] = [note_by_hadm[hadm_id] for hadm_id in data4LLM_with_note['HADM_ID']]

data4LLM_with_note.to_csv(data4LLM_with_note_path, index=False)

    
