In [51]:
import pandas as pd
import json
import datetime
import collections
import numpy as np
from sklearn.model_selection import train_test_split

In [9]:
def processing_mimic3(file_adm, file_dxx, output_file):
    """Build per-patient visit records from the admissions and diagnoses CSVs.

    Parameters
    ----------
    file_adm : str
        Path to the admissions CSV. Columns read: SUBJECT_ID, HADM_ID,
        ADMITTIME, DISCHTIME, HOSPITAL_EXPIRE_FLAG.
    file_dxx : str
        Path to the diagnoses CSV. Columns read: SUBJECT_ID, HADM_ID, ICD9_CODE.
    output_file : str
        Path the resulting list is written to as JSON.

    Returns
    -------
    list of dict
        One entry per SUBJECT_ID found in the diagnoses file, in order of first
        appearance: ``{'pid': str, 'visits': [visit, ...]}``. Each visit is
        ``{'admsn_dt': 'YYYYMMDD', 'day_cnt': str, 'DXs': [code, ...], 'Death': flag}``.

    Raises
    ------
    IndexError
        If a (SUBJECT_ID, HADM_ID) pair from the diagnoses file has no row in
        the admissions file.
    """
    m_adm = pd.read_csv(file_adm, dtype={'HOSPITAL_EXPIRE_FLAG': object})
    m_dxx = pd.read_csv(file_dxx, dtype={'ICD9_CODE': object})

    time_format = "%Y-%m-%d %H:%M:%S"

    # Index admissions once by (SUBJECT_ID, HADM_ID) instead of re-scanning the
    # whole frame with a boolean mask for every admission. setdefault keeps the
    # first matching row, which is what `.values[0]` on the filtered frame did.
    adm_lookup = {}
    for adm_sub_id, adm_hadm, admit, disch, flag in zip(
            m_adm.SUBJECT_ID, m_adm.HADM_ID, m_adm.ADMITTIME,
            m_adm.DISCHTIME, m_adm.HOSPITAL_EXPIRE_FLAG):
        adm_lookup.setdefault((adm_sub_id, adm_hadm), (admit, disch, flag))

    patients = []  # store all preprocessed patients' data
    # sort=False keeps patients/admissions in order of first appearance, matching
    # the order `.unique()` would give. Rows with a missing key are not grouped.
    for sub_id, pat_dxx in m_dxx.groupby('SUBJECT_ID', sort=False):
        visits = []
        # Group by the scalar column name: a length-1 list key makes pandas
        # expect tuple group names.
        for hadm, codes in pat_dxx.groupby('HADM_ID', sort=False):
            if (sub_id, hadm) not in adm_lookup:
                raise IndexError(
                    "no admission record for SUBJECT_ID {} / HADM_ID {}".format(sub_id, hadm))
            admit, disch, death_flag = adm_lookup[(sub_id, hadm)]

            admsn_dt = datetime.datetime.strptime(admit, time_format)
            disch_dt = datetime.datetime.strptime(disch, time_format)
            delta = disch_dt - admsn_dt

            visits.append({
                'admsn_dt': admsn_dt.strftime("%Y%m%d"),
                # whole days between admission and discharge, plus one
                'day_cnt': str(delta.days + 1),
                # all non-missing diagnosis codes of the admission, in file order
                'DXs': codes['ICD9_CODE'].dropna().tolist(),
                'Death': death_flag,
            })

        patients.append({'pid': str(sub_id), 'visits': visits})

    with open(output_file, 'w') as outfile:
        json.dump(patients, outfile)

    return patients



In [36]:
def data_clip(patients_full, visit_threshold, code_min_freq, output_file):
    """Filter patients by visit count and map diagnosis codes to integer indices.

    The input is left untouched: patient and visit dicts are copied before the
    codes are replaced, so the function can be re-run on the same
    ``patients_full`` and give the same result.

    Parameters
    ----------
    patients_full : list of dict
        Patients as ``{'pid': ..., 'visits': [{'DXs': [code, ...], ...}, ...]}``.
    visit_threshold : int
        Minimum number of visits a patient needs to be kept.
    code_min_freq : int
        Minimum number of occurrences (over the kept patients) a code needs to
        get a dictionary index; rarer codes are dropped from the visits.
    output_file : str
        Path the clipped patients are written to as JSON.

    Returns
    -------
    tuple
        ``(patients, max_visits, max_len_visit, code_dictionary)`` where
        ``patients`` holds index-encoded visits, ``max_visits`` is the largest
        number of visits of any kept patient, ``max_len_visit`` the largest
        number of retained codes in any visit, and ``code_dictionary`` maps
        code -> index with ``'PAD'`` at index 0.
    """
    # Copy each kept patient and its visit dicts; the DXs lists themselves are
    # replaced (never edited) below, so the caller's data is not modified.
    patients = []
    for patient in patients_full:
        if len(patient['visits']) >= visit_threshold:
            clipped = dict(patient)
            clipped['visits'] = [dict(visit) for visit in patient['visits']]
            patients.append(clipped)

    total_visits = sum(len(patient['visits']) for patient in patients)
    code_counts = collections.Counter(
        dx for patient in patients for visit in patient['visits'] for dx in visit['DXs'])

    # most_common() orders by descending count, ties in first-seen order; the
    # dictionary indices below depend on that order.
    count = [[word, c] for word, c in code_counts.most_common() if c >= code_min_freq]
    words_count = sum(c for _, c in count)

    # Guard the averages so an empty selection reports 0.0 instead of raising.
    visits_per_patient = total_visits / len(patients) if patients else 0.0
    code_no_per_visit = words_count / total_visits if total_visits else 0.0
    print("visits per patient: {}".format(visits_per_patient))
    print("diagnosis code per visit: {}".format(code_no_per_visit))

    code_dictionary = {'PAD': 0}
    for word, _ in count:
        code_dictionary[word] = len(code_dictionary)

    max_visits = 0
    max_len_visit = 0
    for patient in patients:
        visits = patient['visits']
        max_visits = max(max_visits, len(visits))
        for visit in visits:
            # Codes below the frequency threshold have no index and are dropped.
            visit['DXs'] = [code_dictionary[dx] for dx in visit['DXs'] if dx in code_dictionary]
            max_len_visit = max(max_len_visit, len(visit['DXs']))

    with open(output_file, 'w') as fp:
        json.dump(patients, fp)
    return patients, max_visits, max_len_visit, code_dictionary


In [65]:
def data_process(patients, valid_visits_threshold, max_len_visit, code_dictionary, output_file):
    """Turn index-encoded patients into padded train/dev/test sets and save them.

    For every patient the visits are sorted by admission date and visits without
    any non-zero code are discarded. Patients with fewer than two remaining
    visits are skipped. The last visit is the prediction target; up to
    ``valid_visits_threshold`` visits immediately before it are the features.

    The ``patients`` argument is not modified: visit code lists are copied
    before being zero-padded.

    Parameters
    ----------
    patients : list of dict
        Patients whose visits carry integer code lists under ``'DXs'`` and an
        admission date string ``'admsn_dt'`` in ``YYYYMMDD`` form.
    valid_visits_threshold : int
        Number of feature visits per patient after truncation / zero-padding.
    max_len_visit : int
        Length every visit's code list is zero-padded to.
    code_dictionary : dict
        Code -> index mapping; only its length is used, as the width of the
        multi-hot diagnosis label.
    output_file : str
        Path prefix; ``'train.json'``, ``'dev.json'`` and ``'test.json'`` are
        appended to it.

    Returns
    -------
    tuple of dict
        ``(train, dev, test)``, each with keys ``context_codes``, ``labels_1``
        (multi-hot codes of the last visit), ``labels_2`` (1 if the last visit
        was admitted at most 30 days after the previous admission, else 0),
        ``pids`` and ``intervals`` (days since the first feature visit, plus
        one; 0 for padded visits). The split is 80/10/10 with a fixed seed.
    """
    date_format = "%Y%m%d"

    codes = []
    dx_labels = []
    re_labels = []
    pids = []
    intervals = []

    for patient in patients:
        # sorting visits by admission date (YYYYMMDD strings sort chronologically)
        sorted_visits = sorted(patient['visits'], key=lambda visit: visit['admsn_dt'])
        valid_visits = [v for v in sorted_visits if len(v['DXs']) > 0 and sum(v['DXs']) > 0]

        if len(valid_visits) < 2:
            continue

        no_visits = len(valid_visits)
        last_visit = valid_visits[no_visits - 1]
        second_last_visit = valid_visits[no_visits - 2]

        # use at most `valid_visits_threshold` visits right before the last one
        if no_visits > valid_visits_threshold + 1:
            feature_visits = valid_visits[no_visits - (valid_visits_threshold + 1):no_visits - 1]
        else:
            feature_visits = valid_visits[0:no_visits - 1]
        n_visits = len(feature_visits)

        ls_codes = []
        ls_intervals = []
        first_valid_visit_dt = datetime.datetime.strptime(feature_visits[0]['admsn_dt'], date_format)
        for visit in feature_visits:
            current_dt = datetime.datetime.strptime(visit['admsn_dt'], date_format)
            ls_intervals.append((current_dt - first_valid_visit_dt).days + 1)

            # code padding, done on a copy so the caller's visit is left as is
            visit_codes = list(visit['DXs'])
            if len(visit_codes) < max_len_visit:
                visit_codes.extend([0] * (max_len_visit - len(visit_codes)))
            ls_codes.append(visit_codes)

        # visit padding
        for _ in range(valid_visits_threshold - n_visits):
            ls_codes.append([0] * max_len_visit)
            ls_intervals.append(0)

        # readmission label: days between the last two admission dates
        last_dt = datetime.datetime.strptime(last_visit['admsn_dt'], date_format)
        second_last_dt = datetime.datetime.strptime(second_last_visit['admsn_dt'], date_format)
        adm_label = 1 if (last_dt - second_last_dt).days <= 30 else 0

        # multi-hot encoding of the codes of the last visit
        one_hot_labels = np.zeros(len(code_dictionary)).astype(int)
        for code in last_visit['DXs']:
            one_hot_labels[code] = 1

        codes.append(np.array(ls_codes, dtype=np.int32))
        dx_labels.append(one_hot_labels)
        re_labels.append(np.array([adm_label], dtype=np.int32))
        pids.append(patient['pid'])
        intervals.append(np.array(ls_intervals, dtype=np.int32))

    context_codes = np.array(codes, dtype=np.int32)
    intervals = np.array(intervals, dtype=np.int32)
    labels_1 = np.array(dx_labels, dtype=np.int32)
    labels_2 = np.array(re_labels, dtype=np.int32)
    pids = np.array(pids, dtype=np.int32)

    # 80% train, then the remaining 20% split evenly into dev and test
    train_context_codes, vt_context_codes, train_labels_1, vt_labels_1, \
    train_labels_2, vt_labels_2, train_pids, vt_pids, train_intervals, vt_intervals\
        = train_test_split(context_codes, labels_1, labels_2, pids, intervals, test_size=0.2, random_state=42)

    dev_context_codes, test_context_codes, dev_labels_1, test_labels_1, dev_labels_2, \
    test_labels_2, dev_pids, test_pids, dev_intervals, test_intervals \
        = train_test_split(vt_context_codes, vt_labels_1, vt_labels_2, vt_pids, vt_intervals,
                           test_size=0.5, random_state=42)

    train = {"context_codes": train_context_codes.tolist(), "labels_1": train_labels_1.tolist(), "labels_2": train_labels_2.tolist(), "pids": train_pids.tolist(), "intervals": train_intervals.tolist()}
    dev = {"context_codes": dev_context_codes.tolist(), "labels_1": dev_labels_1.tolist(), "labels_2": dev_labels_2.tolist(), "pids": dev_pids.tolist(), "intervals": dev_intervals.tolist()}
    test = {"context_codes": test_context_codes.tolist(), "labels_1": test_labels_1.tolist(), "labels_2": test_labels_2.tolist(), "pids": test_pids.tolist(), "intervals": test_intervals.tolist()}

    with open(output_file + 'train.json', 'w') as fp:
        json.dump(train, fp)
    with open(output_file + 'dev.json', 'w') as fp:
        json.dump(dev, fp)
    with open(output_file + 'test.json', 'w') as fp:
        json.dump(test, fp)

    return train, dev, test

In [44]:
# Input CSVs (admissions and diagnoses) and the JSON file that
# processing_mimic3 writes the per-patient records to.
file_adm = './data/ADMISSIONS.csv'
file_dxx = './data/DIAGNOSES_ICD.csv'
output_file = './data/patients_mimic3_dx.json'

# One record per patient found in the diagnoses file.
patients_full = processing_mimic3(file_adm, file_dxx, output_file)

In [45]:
# Number of patients returned by processing_mimic3.
len(patients_full)

46520

In [46]:
# Keep patients with at least `visit_threshold` visits and give a dictionary
# index only to codes that occur at least `code_min_freq` times.
visit_threshold = 2
code_min_freq = 5
# `output_file` is rebound here: the clipped patients go to a separate JSON file.
output_file = './data/patients_mimic3_dx_clip.json'
patients, max_visits, max_len_visit, code_dictionary = data_clip(patients_full, visit_threshold, code_min_freq, output_file)

visits per patient: 2.6526469417540137
diagnosis code per visit: 12.794127944780673


In [47]:
# Number of patients that passed the visit-count filter in data_clip.
len(patients)

7537

In [66]:
# Here `output_file` is a path prefix, not a file: data_process appends
# 'train.json', 'dev.json' and 'test.json' to it.
output_file = './data/'
# Number of feature visits per patient after truncation / zero-padding.
valid_visits_threshold = 10
train, dev, test = data_process(patients, valid_visits_threshold, max_len_visit, code_dictionary, output_file)

In [67]:
# Largest visit count per patient and largest code count per visit, as returned by data_clip.
print(max_visits, max_len_visit)

42 39


In [71]:
# Length of the first visit's code vector for training sample 5; data_process
# zero-pads visits to max_len_visit, so this should equal that value.
len(train["context_codes"][5][0])

39

In [72]:
# Dictionary size, including the 'PAD' entry at index 0.
len(code_dictionary)

2438

In [74]:
# Number of training samples (one per patient kept by data_process).
len(train["context_codes"])

5992