In [1]:
import os
import dill
import jsonlines
import pandas as pd
from utils import *

In [2]:
class Voc(object):
    '''Define the vocabulary (token) dict'''

    def __init__(self):

        self.idx2word = {}
        self.word2idx = {}

    def add_sentence(self, sentence):
        '''add vocabulary to dict via a list of words'''
        for word in sentence:
            if word not in self.word2idx:
                self.idx2word[len(self.word2idx)] = word
                self.word2idx[word] = len(self.word2idx)

# create voc set
def create_str_token_mapping(df, vocabulary_file):
    diag_voc = Voc()
    med_voc = Voc()
    pro_voc = Voc()

    for index, row in df.iterrows():
        diag_voc.add_sentence(row["ICD9_CODE"])
        med_voc.add_sentence(row["ATC3"])
        pro_voc.add_sentence(row["PRO_CODE"])

    dill.dump(
        obj={"diag_voc": diag_voc, "med_voc": med_voc, "pro_voc": pro_voc},
        file=open(vocabulary_file, "wb"),
    )
    return diag_voc, med_voc, pro_voc

In [3]:
tokenizer = dill.load(open("./handled/voc_final.pkl", "rb"))

## Step 1:
Preprocess the raw MIMIC-III data as the original medication recommendation works

In [4]:
base_dir = ""   # base folder

## Some auxiliary info, such as DDI, ATC and ICD
RXCUI2atc4_file = os.path.join(base_dir, "./auxiliary/RXCUI2atc4.csv")
cid2atc6_file = os.path.join(base_dir, "./auxiliary/drug-atc.csv")
ndc2RXCUI_file = os.path.join(base_dir, "./auxiliary/ndc2RXCUI.txt")
ddi_file = os.path.join(base_dir, "./auxiliary/drug-DDI.csv")
drugbankinfo = os.path.join(base_dir, "./auxiliary/drugbank_drugs_info.csv")

In [5]:
med_file = os.path.join(base_dir, "./raw/PRESCRIPTIONS.csv")
diag_file = os.path.join(base_dir, "./raw/DIAGNOSES_ICD.csv")
procedure_file = (
    os.path.join(base_dir, "./raw/PROCEDURES_ICD.csv")
)

# input auxiliary files
med_structure_file = os.path.join(base_dir, "./handled/atc32SMILES.pkl")

# output files
ddi_adjacency_file = os.path.join(base_dir, "./handled/full/ddi_A_final.pkl")
ehr_adjacency_file = os.path.join(base_dir, "./handled/full/ehr_adj_final.pkl")
ehr_sequence_file = os.path.join(base_dir, "./handled/full/records_final.pkl")
vocabulary_file = os.path.join(base_dir, "./handled/full/voc_final.pkl")
ddi_mask_H_file = os.path.join(base_dir, "./handled/full/ddi_mask_H.pkl")
atc3toSMILES_file = os.path.join(base_dir, "./handled/full/atc3toSMILES.pkl")

In [None]:
# for med
med_pd = med_process(med_file)  # process the raw file
# med_pd_lg2 = process_visit_lg1(med_pd).reset_index(drop=True)   # remain the single-visit
med_pd_lg2 = process_visit_lg2(med_pd).reset_index(drop=True)   # filter out the patient has less 2 visits
med_pd = med_pd.merge(
    med_pd_lg2[["SUBJECT_ID"]], on="SUBJECT_ID", how="inner"
).reset_index(drop=True)

med_pd = codeMapping2atc4(med_pd, ndc2RXCUI_file, RXCUI2atc4_file)
med_pd = filter_300_most_med(med_pd)

# med to SMILES mapping
atc3toDrug = ATC3toDrug(med_pd)
druginfo = pd.read_csv(drugbankinfo)
atc3toSMILES = atc3toSMILES(atc3toDrug, druginfo)
dill.dump(atc3toSMILES, open(atc3toSMILES_file, "wb"))
med_pd = med_pd[med_pd.ATC3.isin(atc3toSMILES.keys())]
print("complete medication processing")

# for diagnosis
diag_pd = diag_process(diag_file)

print("complete diagnosis processing")

# for procedure
pro_pd = procedure_process(procedure_file)
# pro_pd = filter_1000_most_pro(pro_pd)

print("complete procedure processing")

# combine
data = combine_process(med_pd, diag_pd, pro_pd)
print("complete combining")

In [7]:
# statistics(data)

In [8]:
# # create vocab
# diag_voc, med_voc, pro_voc = create_str_token_mapping(data, vocabulary_file)
# print("obtain voc")

# # create ehr sequence data
# records = create_patient_record(data, diag_voc, med_voc, pro_voc, ehr_sequence_file)
# print("obtain ehr sequence data")

# # create ddi adj matrix
# ddi_adj = get_ddi_matrix(records, med_voc, ddi_file, cid2atc6_file, ehr_adjacency_file, ddi_adjacency_file)
# print("obtain ddi adj matrix")

# # get ddi_mask_H
# ddi_mask_H = get_ddi_mask(atc3toSMILES, med_voc)
# dill.dump(ddi_mask_H, open(ddi_mask_H_file, "wb"))

## Step 2: Get side info
Extract side information of patients from other csv

In [9]:
def get_side(source_df, side_df, side_columns, aligh_column):

    side_df = side_df[side_columns]
    source_df = pd.merge(source_df, side_df, how="left", on=aligh_column)

    return source_df

In [10]:
admission = pd.read_csv("./raw/ADMISSIONS.csv")
data = get_side(data, admission, 
                ["HADM_ID", "INSURANCE", "LANGUAGE", "RELIGION", "MARITAL_STATUS", "ETHNICITY", "DIAGNOSIS"],
                "HADM_ID"
                )

In [11]:
data.fillna(value="unknown", inplace=True)

## Step 3: Map ATC to drugname
Resolve the mapping. In the original preprocessed data, the drug is represented by ATC code, but we need the drugname for LLM.


In [12]:
RXCUI2atc4 = pd.read_csv(RXCUI2atc4_file)
RXCUI2atc4["NDC"] = RXCUI2atc4["NDC"].map(lambda x: x.replace("-", ""))
with open(ndc2RXCUI_file, "r") as f:
    ndc2RXCUI = eval(f.read())

In [13]:
RXCUI2ndc = dict(zip(ndc2RXCUI.values(), ndc2RXCUI.keys()))
RXCUI2atc4["RXCUI"] = RXCUI2atc4["RXCUI"].astype("str")
RXCUI2atc4["NDC"] = RXCUI2atc4["RXCUI"].map(RXCUI2ndc)
RXCUI2atc4.dropna(axis=0, how="any", inplace=True)
RXCUI2atc4.drop_duplicates(inplace=True)

In [14]:
RXCUI2atc4.shape, RXCUI2atc4.nunique()

((32732, 5),
 YEAR       73
 MONTH      12
 NDC      2037
 RXCUI    2037
 ATC4      445
 dtype: int64)

In [15]:
RXCUI2atc4.drop_duplicates(inplace=True)
RXCUI2atc4.shape

(32732, 5)

In [16]:
med_pd.head(5)

Unnamed: 0,SUBJECT_ID,HADM_ID,STARTDATE,DRUG,ATC3
0,17,161087,2135-05-09,Acetaminophen,N02B
1,17,194023,2134-12-27,Acetaminophen,N02B
2,21,111970,2135-02-06,Acetaminophen,N02B
3,23,152223,2153-09-03,Acetaminophen,N02B
4,36,122659,2131-05-15,Acetaminophen,N02B


In [17]:
med_pd = pd.read_csv(med_file, dtype={"NDC": "category"})
med_pd.head(5)

  med_pd = pd.read_csv(med_file, dtype={"NDC": "category"})


Unnamed: 0,ROW_ID,SUBJECT_ID,HADM_ID,ICUSTAY_ID,STARTDATE,ENDDATE,DRUG_TYPE,DRUG,DRUG_NAME_POE,DRUG_NAME_GENERIC,FORMULARY_DRUG_CD,GSN,NDC,PROD_STRENGTH,DOSE_VAL_RX,DOSE_UNIT_RX,FORM_VAL_DISP,FORM_UNIT_DISP,ROUTE
0,2214776,6,107064,,2175-06-11 00:00:00,2175-06-12 00:00:00,MAIN,Tacrolimus,Tacrolimus,Tacrolimus,TACR1,21796.0,469061711,1mg Capsule,2,mg,2,CAP,PO
1,2214775,6,107064,,2175-06-11 00:00:00,2175-06-12 00:00:00,MAIN,Warfarin,Warfarin,Warfarin,WARF5,6562.0,56017275,5mg Tablet,5,mg,1,TAB,PO
2,2215524,6,107064,,2175-06-11 00:00:00,2175-06-12 00:00:00,MAIN,Heparin Sodium,,,HEPAPREMIX,6522.0,338055002,"25,000 unit Premix Bag",25000,UNIT,1,BAG,IV
3,2216265,6,107064,,2175-06-11 00:00:00,2175-06-12 00:00:00,BASE,D5W,,,HEPBASE,,0,HEPARIN BASE,250,ml,250,ml,IV
4,2214773,6,107064,,2175-06-11 00:00:00,2175-06-12 00:00:00,MAIN,Furosemide,Furosemide,Furosemide,FURO20,8208.0,54829725,20mg Tablet,20,mg,1,TAB,PO


In [18]:
med_pd["NDC"].astype("str")
med_pd = pd.merge(med_pd, RXCUI2atc4, how="left", on="NDC")

In [19]:
atc2drug = pd.read_csv("./auxiliary/WHO ATC-DDD 2021-12-03.csv")
atc2drug["code_len"] = atc2drug["atc_code"].map(lambda x: len(x))
atc2drug = atc2drug[atc2drug["code_len"]==4]    # all levels are included. We only need the 4th level, i.e., ATC4
atc2drug.rename(columns={"atc_code": "ATC4"}, inplace=True)
atc2drug.drop(columns=["ddd", "uom", "adm_r", "note", "code_len"], axis=1, inplace=True)

In [20]:
atc2drug.head(10)

Unnamed: 0,ATC4,atc_name
2,A01A,STOMATOLOGICAL PREPARATIONS
46,A02A,ANTACIDS
79,A02B,DRUGS FOR PEPTIC ULCER AND GASTRO-OESOPHAGEAL ...
147,A02X,OTHER DRUGS FOR ACID RELATED DISORDERS
149,A03A,DRUGS FOR FUNCTIONAL GASTROINTESTINAL DISORDERS
219,A03B,"BELLADONNA AND DERIVATIVES, PLAIN"
235,A03C,ANTISPASMODICS IN COMBINATION WITH PSYCHOLEPTICS
255,A03D,ANTISPASMODICS IN COMBINATION WITH ANALGESICS
267,A03E,ANTISPASMODICS AND ANTICHOLINERGICS IN COMBINA...
270,A03F,PROPULSIVES


In [21]:
RXCUI2atc4["ATC4"] = RXCUI2atc4["ATC4"].map(lambda x: x[:4])

In [22]:
# all atc code in original data can be mapped to drugname by atc2drug.
# means that we use the same data as the traditional medication recommendation models.
pd.merge(RXCUI2atc4, atc2drug, on="ATC4", how="left")["atc_name"].isna().sum()

0

In [23]:
atc2drug["atc_name"] = atc2drug["atc_name"].map(lambda x: x.lower())

In [24]:
# get the atc2drug and drug2atc mapping dict
atc2drug_dict = dict(zip(atc2drug["ATC4"].values, atc2drug["atc_name"].values))
drug2atc_dict = dict(zip(atc2drug["atc_name"].values, atc2drug["ATC4"].values))

In [25]:
# import json
# json.dump({"atc2drug": atc2drug_dict, "drug2atc": drug2atc_dict}, open("./handled/full_atc2drug.json", "w"))

In [26]:
# get the diagnosis and procedure mapping dict, which both use the ICD. these mappings are in raw MIMIC dataset
icd2diag = pd.read_csv("./raw/D_ICD_DIAGNOSES.csv")
icd2diag_dict = dict(zip(icd2diag["ICD9_CODE"].astype(str).values, icd2diag["SHORT_TITLE"].values))

In [27]:
icd2proc = pd.read_csv("./raw/D_ICD_PROCEDURES.csv")
icd2proc_dict = dict(zip(icd2proc["ICD9_CODE"].astype(str).values, icd2proc["SHORT_TITLE"].values))

In [28]:
def decode(code_list, decoder):
    # decode a list of code into corresponding names
    miss_match = 0
    target_list = []
    for code in code_list:
        try:
            target_list.append(decoder[code])
        except:
            miss_match += 1
    
    #print(miss_match)

    return target_list

In [29]:
data["drug"] = data["ATC3"].map(lambda x: decode(x, atc2drug_dict))
data["diagnosis"] = data["ICD9_CODE"].map(lambda x: decode(x, icd2diag_dict))
data["procedure"] = data["PRO_CODE"].map(lambda x: decode(x, icd2proc_dict))

some miss matches occurs in diagnosis and procedures, but no for drug

In [30]:
data.iloc[1]["PRO_CODE"]

['3571', '3961', '8872']

In [31]:
def profile_tokenization(df, profile_columns):
    prof_dict = {"word2idx":{}, "idx2word": {}}
    for prof in profile_columns:
        prof_dict["idx2word"][prof] = dict(zip(range(df[prof].nunique()), df[prof].unique()))
        prof_dict["word2idx"][prof] = dict(zip(df[prof].unique(), range(df[prof].nunique())))
    return prof_dict

In [32]:
# profile_dict = profile_tokenization(data, ["INSURANCE", "LANGUAGE", "RELIGION", "MARITAL_STATUS", "ETHNICITY"])
# json.dump(profile_dict, open("./handled/full_profile_dict.json", "w"))

## Step 4: Construct Prompt
Design the prompt templates and construct the prompt

In [33]:
# prompt templates
main_template = "The patient has <VISIT_NUM> times ICU visits. \n <HISTORY> In this visit, he has diagnosis: <DIAGNOSIS>; procedures: <PROCEDURE>. Then, the patient should be prescribed: "
hist_template = "In <VISIT_NO> visit, the patient had diagnosis: <DIAGNOSIS>; procedures: <PROCEDURE>. The patient was prescribed drugs: <MEDICATION>. \n"

In [34]:
# add some patient's profiles
# main_template = "The patient's insurance type is <INSU>, language is <LANG>, religion is <RELIGION>, marital status is <MARITAL>, ethnicity is <ETHN>. The patient has <VISIT_NUM> times ICU visits. \n <HISTORY> In this visit, he has diagnosis: <DIAGNOSIS>; procedures: <PROCEDURE>. Then, the patient should be prescribed: "

In [35]:
def concat_str(str_list):
    # concat a list of drug / diagnosis / procedures
    target_str = ""
    for meta_str in str_list:
        target_str = target_str + meta_str + ", "
    target_str = target_str[:-2]    # remove the last comma

    return target_str

In [36]:
llm_data = []

for subject_id in data["SUBJECT_ID"].unique():
    item_df = data[data["SUBJECT_ID"] == subject_id]
    visit_num = item_df.shape[0] - 1
    patient = []

    profile = item_df.iloc[0]
    patient_str = main_template.replace("<INSU>", profile["INSURANCE"].lower())\
                               .replace("<LANG>", profile["LANGUAGE"].lower())\
                               .replace("<RELIGION>", profile["RELIGION"].lower())\
                               .replace("<MARITAL>", profile["MARITAL_STATUS"].lower())\
                               .replace("<ETHN>", profile["ETHNICITY"].lower())

    patient_profile = {"INSURANCE": profile["INSURANCE"], "LANGUAGE": profile["LANGUAGE"],
                       "RELIGION": profile["RELIGION"], "MARITAL_STATUS": profile["MARITAL_STATUS"],
                       "ETHNICITY": profile["ETHNICITY"]}

    # get each historical visit string
    for visit_no, (_, row) in enumerate(item_df.iterrows()):
        drug, diag, proc = concat_str(row["drug"]), concat_str(row["diagnosis"]), concat_str(row["procedure"])
        patient.append(hist_template.replace("<VISIT_NO>", str(visit_no+1))\
                                    .replace("<DIGNOSIS>", diag)\
                                    .replace("<PROCEDURE>", proc)\
                                    .replace("<MEDICATION>", drug))
    patient.pop()   # remove the ground truth record

    # filter out the patients with more than N times visits
    if len(patient) > 3:
        patient = patient[-3:]

    # concat all historical visit strings and get hist strings
    hist_str = ""
    for meta_hist in patient:
        hist_str += meta_hist
    
    patient_str = patient_str.replace("<VISIT_NUM>", str(visit_num))\
                             .replace("<HISTORY>", hist_str)\
                             .replace("<DIAGNOSIS>", diag)\
                             .replace("<PROCEDURE>", proc)
    
    drug_code = [str(x) for x in row["ATC3"]]

    hist = {"diagnosis": [], "procedure": [], "medication": []}
    for _, row in item_df.iterrows():
        hist["diagnosis"].append([str(x) for x in row["ICD9_CODE"]])
        hist["procedure"].append([str(x) for x in row["PRO_CODE"]])
        hist["medication"].append([str(x) for x in row["ATC3"]])
        
    llm_data.append({"input": patient_str, "target": drug, 
                     "subject_id": int(subject_id), "drug_code": drug_code,
                     "records": hist, "profile": patient_profile})
        

In [37]:
file_path = "./handled/"

def read_data(data_path):
    '''read data from jsonlines file'''
    data = []

    with jsonlines.open(file_path + data_path, "r") as f:
        for meta_data in f:
            data.append(meta_data)

    return data


def save_data(data_path, data):
    '''write all_data list to a new jsonl'''
    with jsonlines.open(file_path + data_path, "w") as w:
        for meta_data in data:
            w.write(meta_data)

In [38]:
# split the dataset: 8:1:1
train_split = int(len(llm_data) * 0.8)
val_split = int(len(llm_data) * 0.1)
train = llm_data[:train_split]
val = llm_data[train_split:train_split+val_split]
test = llm_data[train_split+val_split:]

In [44]:
# cut too long sequences
save_data("train_0105.json", train)
save_data("val_0105.json", val)
save_data("test_0105.json", test)