# ETL & Data Preprocessing pipeline

In this notebook, I'll mainly cover using pandas for ETL.

In [1]:
import pandas as pd
pd.set_option('display.max_columns', None)

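We read the diagnoses and procedures files into two pandas dataframes

In [2]:
# Note: with the default dtype inference, an all-numeric ICD9_CODE column is
# parsed as integers, which drops any leading zeros; passing
# dtype={'ICD9_CODE': str} to read_csv would keep the codes as text.
icd_diagnoses = pd.read_csv('DIAGNOSES_ICD.csv')
icd_procedures = pd.read_csv('PROCEDURES_ICD.csv')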

Take a look at the data

In [3]:
icd_diagnoses.head()

Unnamed: 0,ROW_ID,SUBJECT_ID,HADM_ID,SEQ_NUM,ICD9_CODE
0,1297,109,172335,1.0,40301
1,1298,109,172335,2.0,486
2,1299,109,172335,3.0,58281
3,1300,109,172335,4.0,5855
4,1301,109,172335,5.0,4254


In [4]:
icd_procedures.head()

Unnamed: 0,ROW_ID,SUBJECT_ID,HADM_ID,SEQ_NUM,ICD9_CODE
0,944,62641,154460,3,3404
1,945,2592,130856,1,9671
2,946,2592,130856,2,3893
3,947,55357,119355,1,9672
4,948,55357,119355,2,331


We'll keep only the 50 most frequent codes for further analysis. This reduces dimensionality and class imbalance and, of course, cuts down our model training times.

In [5]:
procedures_top_50 = icd_procedures['ICD9_CODE'].value_counts()[:50].index.tolist()
diagnoses_top_50 = icd_diagnoses['ICD9_CODE'].value_counts()[:50].index.tolist()
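
As a rough illustration of what `value_counts()[:50]` computes, the same ranking can be sketched in plain Python with `collections.Counter`. The toy codes and the helper name `top_n_codes` are made up for this example, and ties may be ordered differently than in pandas.

```python
from collections import Counter

def top_n_codes(codes, n):
    """Return the n most frequent codes, most frequent first (hypothetical helper)."""
    return [code for code, _ in Counter(codes).most_common(n)]

# Toy data, not taken from MIMIC-III
toy_codes = ['4019', '4019', '4019', '3893', '3893', '486']
print(top_n_codes(toy_codes, 2))
```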

I took this reformat utility function from Mullenbach et al. (2018). It puts periods in the right place, since the MIMIC-III data files exclude them. This step is important.

In [6]:
# Reference: https://github.com/jamesmullenbach/caml-mimic/blob/master
def reformat(code, is_diag):
    """
        Put a period in the right place because the MIMIC-3 data files exclude them.
        Generally, procedure codes have dots after the first two digits,
        while diagnosis codes have dots after the first three digits.
    """
    code = ''.join(code.split('.'))
    if is_diag:
        if code.startswith('E'):
            if len(code) > 4:
                code = code[:4] + '.' + code[4:]
        else:
            if len(code) > 3:
                code = code[:3] + '.' + code[3:]
    else:
        code = code[:2] + '.' + code[2:]
    return code
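
The rules in the docstring can be checked on a few hand-picked codes. This is only a sanity-check sketch: the function is repeated (slightly condensed) so the snippet runs on its own, and the sample inputs are chosen for illustration rather than taken from the data files.

```python
def reformat(code, is_diag):
    # Same rules as the function above, repeated so this snippet is self-contained
    code = ''.join(code.split('.'))
    if is_diag:
        if code.startswith('E'):
            if len(code) > 4:
                code = code[:4] + '.' + code[4:]
        elif len(code) > 3:
            code = code[:3] + '.' + code[3:]
    else:
        code = code[:2] + '.' + code[2:]
    return code

# (raw code, is_diag) -> expected formatted code
examples = {
    ('40301', True): '403.01',   # diagnosis: dot after the first three digits
    ('486', True): '486',        # three-character diagnosis: left unchanged
    ('V053', True): 'V05.3',     # V-codes follow the three-character rule
    ('E8497', True): 'E849.7',   # E-codes: dot after the first four characters
    ('3893', False): '38.93',    # procedure: dot after the first two digits
}

for (raw, is_diag), expected in examples.items():
    assert reformat(raw, is_diag) == expected
print('all examples formatted as expected')
```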

We reformat the codes before proceeding to combine the diagnoses and procedures dataframes

In [7]:
icd_diagnoses['tmp_code'] = icd_diagnoses['ICD9_CODE'].apply(lambda x: str(reformat(str(x), True)))
icd_procedures['tmp_code'] = icd_procedures['ICD9_CODE'].apply(lambda x: str(reformat(str(x), False)))

We combine the diagnoses and procedures into a single dataframe

In [8]:
combined_codes = pd.concat([icd_diagnoses, icd_procedures], axis=0)

We extract only the top 50 most frequent codes 

In [9]:
combined_codes_top_50_codes = combined_codes['tmp_code'].value_counts()[:50].index.tolist()

Take a look at the top 50 most frequent codes

In [10]:
combined_codes_top_50_codes

['401.9',
 '38.93',
 '428.0',
 '427.31',
 '414.01',
 '96.04',
 '96.6',
 '584.9',
 '96.71',
 '250.00',
 '272.4',
 '518.81',
 '99.04',
 '39.61',
 '599.0',
 '530.81',
 '96.72',
 '272.0',
 '99.55',
 'V05.3',
 'V29.0',
 '285.9',
 '88.56',
 '244.9',
 '486',
 '38.91',
 '285.1',
 '276.2',
 '496',
 '36.15',
 '99.15',
 '995.92',
 'V58.61',
 '038.9',
 '507.0',
 'V30.00',
 '88.72',
 '585.9',
 '311',
 '403.90',
 '305.1',
 '37.22',
 '412',
 '33.24',
 '39.95',
 '287.5',
 'V45.81',
 '410.71',
 '276.1',
 '424.0']

We peek at the top of our combined dataframe using .head()

In [12]:
combined_codes.head()

Unnamed: 0,ROW_ID,SUBJECT_ID,HADM_ID,SEQ_NUM,ICD9_CODE,tmp_code
0,1297,109,172335,1.0,40301,403.01
1,1298,109,172335,2.0,486,486.0
2,1299,109,172335,3.0,58281,582.81
3,1300,109,172335,4.0,5855,585.5
4,1301,109,172335,5.0,4254,425.4


Since we can have multiple medical codes for a single admission, we combine the codes together delimited by semi-colons.

In [13]:
combined_codes['full_code'] = combined_codes.groupby(['HADM_ID'])['tmp_code'].transform(lambda x: ';'.join(x))
combined_codes.drop_duplicates(subset='full_code', inplace=True)
combined_codes.head()

Unnamed: 0,ROW_ID,SUBJECT_ID,HADM_ID,SEQ_NUM,ICD9_CODE,tmp_code,full_code
0,1297,109,172335,1.0,40301,403.01,403.01;486;582.81;585.5;425.4;276.2;710.0;276....
14,1311,109,173633,1.0,40301,403.01,403.01;585.6;583.81;710.0;558.9;287.5;285.21;4...
28,1488,112,174105,1.0,53100,531.0,531.00;410.71;285.9;414.01;725;44.43;99.04;45.13
33,1493,113,109976,1.0,1915,191.5,191.5;331.4;530.81;15.9;23.9
36,1496,114,178393,1.0,41401,414.01,414.01;411.1;482.83;285.9;272.0;305.1;36.12;36...
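
For illustration, the per-admission join can be sketched in plain Python. The rows below are toy data and `join_codes_by_admission` is a hypothetical helper, not part of the notebook. The sketch keys the result on `HADM_ID`, so two admissions that happen to share an identical code string remain separate entries, whereas deduplicating on the joined string itself would keep only one of them.

```python
from collections import defaultdict

def join_codes_by_admission(rows):
    """Map each HADM_ID to its codes joined by ';', preserving row order."""
    grouped = defaultdict(list)
    for hadm_id, code in rows:
        grouped[hadm_id].append(code)
    return {hadm_id: ';'.join(codes) for hadm_id, codes in grouped.items()}

# Toy (HADM_ID, tmp_code) pairs, not taken from MIMIC-III
toy_rows = [(1, '401.9'), (1, '38.93'), (2, '486'), (3, '401.9'), (3, '38.93')]
full_codes = join_codes_by_admission(toy_rows)
print(full_codes)

# Admissions 1 and 3 share the same code string but are still distinct keys
print(len(full_codes), len(set(full_codes.values())))
```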


We'll need to pre-process the target labels/codes into a multi-hot encoding format for BERT. You do not need to do this for LR/CNN/BiGRU/CAML since it is incorporated into the training code.

In [14]:
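for code in combined_codes_top_50_codes:
    # Note: `code in x` is a substring test on the joined string, so a short
    # code such as '96.6' also matches inside a longer one such as '996.62'.
    combined_codes[code] = combined_codes['full_code'].apply(
        lambda x: code in x)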
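
A minimal sketch of exact matching on the `;`-separated codes, for comparison with a substring test on the raw string. The helper name `multi_hot` and the toy inputs are assumptions for illustration only; this is not the code that produced the outputs in this notebook.

```python
def multi_hot(full_code, top_codes):
    """Return one boolean per top code, True only on an exact code match."""
    present = set(full_code.split(';'))
    return [code in present for code in top_codes]

top_codes = ['401.9', '96.6', '486']   # toy subset of the label set
full_code = '996.62;401.9'             # toy code string for one admission

print('96.6' in full_code)              # substring test: True
print(multi_hot(full_code, top_codes))  # exact match: [True, False, False]
```

Splitting first and testing membership in a set is also how `check_code_top_50` later in this notebook decides whether a code counts, so the two steps would then agree on what "contains a top-50 code" means.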

The columns are nicely generated

In [15]:
combined_codes.head()

Unnamed: 0,ROW_ID,SUBJECT_ID,HADM_ID,SEQ_NUM,ICD9_CODE,tmp_code,full_code,401.9,38.93,428.0,427.31,414.01,96.04,96.6,584.9,96.71,250.00,272.4,518.81,99.04,39.61,599.0,530.81,96.72,272.0,99.55,V05.3,V29.0,285.9,88.56,244.9,486,38.91,285.1,276.2,496,36.15,99.15,995.92,V58.61,038.9,507.0,V30.00,88.72,585.9,311,403.90,305.1,37.22,412,33.24,39.95,287.5,V45.81,410.71,276.1,424.0
0,1297,109,172335,1.0,40301,403.01,403.01;486;582.81;585.5;425.4;276.2;710.0;276....,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,True,False,False,True,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,True,False,False,False,False
14,1311,109,173633,1.0,40301,403.01,403.01;585.6;583.81;710.0;558.9;287.5;285.21;4...,False,True,False,False,False,False,False,False,False,False,False,False,True,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,True,True,False,False,False,False
28,1488,112,174105,1.0,53100,531.0,531.00;410.71;285.9;414.01;725;44.43;99.04;45.13,False,False,False,False,True,False,False,False,False,False,False,False,True,False,False,False,False,False,False,False,False,True,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,True,False,False
33,1493,113,109976,1.0,1915,191.5,191.5;331.4;530.81;15.9;23.9,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,True,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False
36,1496,114,178393,1.0,41401,414.01,414.01;411.1;482.83;285.9;272.0;305.1;36.12;36...,False,False,False,False,True,False,False,False,False,False,False,False,False,True,False,False,False,True,False,False,False,True,False,False,False,False,False,False,False,True,False,False,False,False,False,False,False,False,False,False,True,False,False,False,False,False,False,False,False,False


We check the number of occurrences of each of the top-50 codes in our dataset

In [16]:
for code in combined_codes_top_50_codes:
    print(combined_codes[code].value_counts())

False    35631
True     20654
Name: 401.9, dtype: int64
False    43153
True     13132
Name: 38.93, dtype: int64
False    43174
True     13111
Name: 428.0, dtype: int64
False    43399
True     12886
Name: 427.31, dtype: int64
False    43877
True     12408
Name: 414.01, dtype: int64
False    46363
True      9922
Name: 96.04, dtype: int64
False    45895
True     10390
Name: 96.6, dtype: int64
False    47167
True      9118
Name: 584.9, dtype: int64
False    47341
True      8944
Name: 96.71, dtype: int64
False    47235
True      9050
Name: 250.00, dtype: int64
False    47597
True      8688
Name: 272.4, dtype: int64
False    48788
True      7497
Name: 518.81, dtype: int64
False    49087
True      7198
Name: 99.04, dtype: int64
False    49506
True      6779
Name: 39.61, dtype: int64
False    49731
True      6554
Name: 599.0, dtype: int64
False    49963
True      6322
Name: 530.81, dtype: int64
False    49925
True      6360
Name: 96.72, dtype: int64
False    50372
True      5913
Name: 272.0, d

The utility function below counts how many of the top-50 codes appear in an admission's code string

In [17]:
def check_code_top_50(codes):
    """Count the codes in a ';'-separated string that are among the top 50.

    The count is returned as a string, so later comparisons use e.g. '0'.
    """
    top_code_count = 0
    for code in codes.split(';'):
        if code in combined_codes_top_50_codes:
            top_code_count = top_code_count + 1
    return str(top_code_count)

We store this count for each row in a new top_codes column

In [18]:
combined_codes['top_codes'] = combined_codes['full_code'].apply(check_code_top_50)

We inspect the dataframe visually for errors

In [19]:
combined_codes.head()

Unnamed: 0,ROW_ID,SUBJECT_ID,HADM_ID,SEQ_NUM,ICD9_CODE,tmp_code,full_code,401.9,38.93,428.0,427.31,414.01,96.04,96.6,584.9,96.71,250.00,272.4,518.81,99.04,39.61,599.0,530.81,96.72,272.0,99.55,V05.3,V29.0,285.9,88.56,244.9,486,38.91,285.1,276.2,496,36.15,99.15,995.92,V58.61,038.9,507.0,V30.00,88.72,585.9,311,403.90,305.1,37.22,412,33.24,39.95,287.5,V45.81,410.71,276.1,424.0,top_codes
0,1297,109,172335,1.0,40301,403.01,403.01;486;582.81;585.5;425.4;276.2;710.0;276....,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,True,False,False,True,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,True,False,False,False,False,3
14,1311,109,173633,1.0,40301,403.01,403.01;585.6;583.81;710.0;558.9;287.5;285.21;4...,False,True,False,False,False,False,False,False,False,False,False,False,True,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,True,True,False,False,False,False,4
28,1488,112,174105,1.0,53100,531.0,531.00;410.71;285.9;414.01;725;44.43;99.04;45.13,False,False,False,False,True,False,False,False,False,False,False,False,True,False,False,False,False,False,False,False,False,True,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,True,False,False,4
33,1493,113,109976,1.0,1915,191.5,191.5;331.4;530.81;15.9;23.9,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,True,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,1
36,1496,114,178393,1.0,41401,414.01,414.01;411.1;482.83;285.9;272.0;305.1;36.12;36...,False,False,False,False,True,False,False,False,False,False,False,False,False,True,False,False,False,True,False,False,False,True,False,False,False,False,False,False,False,True,False,False,False,False,False,False,False,False,False,False,True,False,False,False,False,False,False,False,False,False,6


We check the distribution of the number of top-50 codes per entry

In [20]:
combined_codes['top_codes'].value_counts()

4     7053
3     7050
5     6290
2     5882
6     5603
7     4798
1     3888
8     3790
9     2851
0     2635
10    2040
11    1490
12    1023
13     700
14     449
15     275
16     174
17     120
18      69
19      41
20      28
21      15
22      11
24       4
23       3
26       1
25       1
27       1
Name: top_codes, dtype: int64

We remove entries where the codes do not contain any of the top-50 codes

In [21]:
combined_codes = combined_codes[combined_codes['top_codes']!='0']

We have a total of 53,650 admissions.

In [22]:
combined_codes.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 53650 entries, 0 to 651032
Data columns (total 58 columns):
 #   Column      Non-Null Count  Dtype  
---  ------      --------------  -----  
 0   ROW_ID      53650 non-null  int64  
 1   SUBJECT_ID  53650 non-null  int64  
 2   HADM_ID     53650 non-null  int64  
 3   SEQ_NUM     53650 non-null  float64
 4   ICD9_CODE   53650 non-null  object 
 5   tmp_code    53650 non-null  object 
 6   full_code   53650 non-null  object 
 7   401.9       53650 non-null  bool   
 8   38.93       53650 non-null  bool   
 9   428.0       53650 non-null  bool   
 10  427.31      53650 non-null  bool   
 11  414.01      53650 non-null  bool   
 12  96.04       53650 non-null  bool   
 13  96.6        53650 non-null  bool   
 14  584.9       53650 non-null  bool   
 15  96.71       53650 non-null  bool   
 16  250.00      53650 non-null  bool   
 17  272.4       53650 non-null  bool   
 18  518.81      53650 non-null  bool   
 19  99.04       53650 non-nu

Next, we proceed to extract the discharge summaries from the MIMIC-III dataset.

In [23]:
clinical_notes_df = pd.read_csv('NOTEEVENTS.csv')
clinical_notes_df = clinical_notes_df[clinical_notes_df['CATEGORY']=='Discharge summary']


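We use the RegexpTokenizer from NLTK to preprocess our discharge summaries: we keep only runs of word characters (letters, digits and underscores), lowercase them, and drop tokens that are purely numeric. Note that stop words are not removed in this step.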

In [24]:
from nltk.tokenize import RegexpTokenizer

#retain only alphanumeric
tokenizer = RegexpTokenizer(r'\w+')
clinical_notes_df['TEXT_prop']= clinical_notes_df['TEXT'].apply(lambda note:[t.lower() for t in tokenizer.tokenize(note) if not t.isnumeric()])
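
To make the effect of this step concrete, here is a small sketch that uses the standard library's `re.findall` with the same `\w+` pattern, on the assumption that it splits text the same way as the tokenizer above for this pattern. The helper name `preprocess` and the sample string (written in the style of a de-identified note) are made up for illustration.

```python
import re

def preprocess(note):
    """Lowercase the word tokens of a note and drop purely numeric tokens."""
    tokens = re.findall(r'\w+', note)
    return [t.lower() for t in tokens if not t.isnumeric()]

sample = 'Admission Date: [**2141-9-18**] Date of Birth: [**2070-1-1**] Sex: M'
print(preprocess(sample))
```

Punctuation, brackets and the all-digit date parts disappear, while short function words such as `of` are kept.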

We create a column holding the length (number of tokens) of each discharge summary.

In [25]:
clinical_notes_df['TEXT_prop_len']= clinical_notes_df['TEXT_prop'].apply(lambda x: len(x))

Now, we combine our clinical notes/discharge summaries with our medical code using a left join on the admission ID. We drop rows that do not contain any discharge summaries and proceed to check on our dataframe with a .head().

In [26]:
codes_with_notes = pd.merge(combined_codes, clinical_notes_df, on='HADM_ID', how='left')
codes_with_notes = codes_with_notes[codes_with_notes['TEXT_prop'].notna()]
codes_with_notes.head()

Unnamed: 0,ROW_ID_x,SUBJECT_ID_x,HADM_ID,SEQ_NUM,ICD9_CODE,tmp_code,full_code,401.9,38.93,428.0,427.31,414.01,96.04,96.6,584.9,96.71,250.00,272.4,518.81,99.04,39.61,599.0,530.81,96.72,272.0,99.55,V05.3,V29.0,285.9,88.56,244.9,486,38.91,285.1,276.2,496,36.15,99.15,995.92,V58.61,038.9,507.0,V30.00,88.72,585.9,311,403.90,305.1,37.22,412,33.24,39.95,287.5,V45.81,410.71,276.1,424.0,top_codes,ROW_ID_y,SUBJECT_ID_y,CHARTDATE,CHARTTIME,STORETIME,CATEGORY,DESCRIPTION,CGID,ISERROR,TEXT,TEXT_prop,TEXT_prop_len
0,1297,109,172335,1.0,40301,403.01,403.01;486;582.81;585.5;425.4;276.2;710.0;276....,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,True,False,False,True,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,True,False,False,False,False,3,14797.0,109.0,2141-09-24,,,Discharge summary,Report,,,Admission Date: [**2141-9-18**] ...,"[admission, date, discharge, date, date, of, b...",2614.0
1,1311,109,173633,1.0,40301,403.01,403.01;585.6;583.81;710.0;558.9;287.5;285.21;4...,False,True,False,False,False,False,False,False,False,False,False,False,True,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,True,True,False,False,False,False,4,14801.0,109.0,2141-12-14,,,Discharge summary,Report,,,Admission Date: [**2141-12-8**] ...,"[admission, date, discharge, date, date, of, b...",1936.0
2,1488,112,174105,1.0,53100,531.0,531.00;410.71;285.9;414.01;725;44.43;99.04;45.13,False,False,False,False,True,False,False,False,False,False,False,False,True,False,False,False,False,False,False,False,False,True,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,True,False,False,4,54002.0,112.0,2194-06-18,,,Discharge summary,Report,,,Admission Date: [**2194-6-13**] Dischar...,"[admission, date, discharge, date, service, hi...",969.0
3,1493,113,109976,1.0,1915,191.5,191.5;331.4;530.81;15.9;23.9,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,True,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,1,10256.0,113.0,2140-12-27,,,Discharge summary,Report,,,Admission Date: [**2140-12-12**] Discha...,"[admission, date, discharge, date, date, of, b...",539.0
4,1496,114,178393,1.0,41401,414.01,414.01;411.1;482.83;285.9;272.0;305.1;36.12;36...,False,False,False,False,True,False,False,False,False,False,False,False,False,True,False,False,False,True,False,False,False,True,False,False,False,False,False,False,False,True,False,False,False,False,False,False,False,False,False,False,True,False,False,False,False,False,False,False,False,False,6,10754.0,114.0,2146-09-03,,,Discharge summary,Report,,,Admission Date: [**2146-8-29**] Dischar...,"[admission, date, discharge, date, date, of, b...",501.0


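There are 57,020 entries in our preprocessed dataset. This is more than the 53,650 rows we had before the merge, which means some admissions matched more than one discharge-summary note.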

In [27]:
codes_with_notes.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 57020 entries, 0 to 60369
Data columns (total 70 columns):
 #   Column         Non-Null Count  Dtype  
---  ------         --------------  -----  
 0   ROW_ID_x       57020 non-null  int64  
 1   SUBJECT_ID_x   57020 non-null  int64  
 2   HADM_ID        57020 non-null  int64  
 3   SEQ_NUM        57020 non-null  float64
 4   ICD9_CODE      57020 non-null  object 
 5   tmp_code       57020 non-null  object 
 6   full_code      57020 non-null  object 
 7   401.9          57020 non-null  bool   
 8   38.93          57020 non-null  bool   
 9   428.0          57020 non-null  bool   
 10  427.31         57020 non-null  bool   
 11  414.01         57020 non-null  bool   
 12  96.04          57020 non-null  bool   
 13  96.6           57020 non-null  bool   
 14  584.9          57020 non-null  bool   
 15  96.71          57020 non-null  bool   
 16  250.00         57020 non-null  bool   
 17  272.4          57020 non-null  bool   
 18  518.81

y holds the target labels for our BERT model. Columns 7 to 56 are the 50 multi-hot code columns, so they are already in the correct format.

In [28]:
y = codes_with_notes.iloc[:,7:57]#.reset_index().drop(columns='index')

We print the labels, y to perform a visual check.

In [29]:
y

Unnamed: 0,401.9,38.93,428.0,427.31,414.01,96.04,96.6,584.9,96.71,250.00,272.4,518.81,99.04,39.61,599.0,530.81,96.72,272.0,99.55,V05.3,V29.0,285.9,88.56,244.9,486,38.91,285.1,276.2,496,36.15,99.15,995.92,V58.61,038.9,507.0,V30.00,88.72,585.9,311,403.90,305.1,37.22,412,33.24,39.95,287.5,V45.81,410.71,276.1,424.0
0,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,True,False,False,True,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,True,False,False,False,False
1,False,True,False,False,False,False,False,False,False,False,False,False,True,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,True,True,False,False,False,False
2,False,False,False,False,True,False,False,False,False,False,False,False,True,False,False,False,False,False,False,False,False,True,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,True,False,False
3,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,True,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False
4,False,False,False,False,True,False,False,False,False,False,False,False,False,True,False,False,False,True,False,False,False,True,False,False,False,False,False,False,False,True,False,False,False,False,False,False,False,False,False,False,True,False,False,False,False,False,False,False,False,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
60365,True,False,False,False,False,False,False,False,False,False,False,False,False,False,False,True,False,True,False,False,False,False,False,False,False,True,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False
60366,True,False,True,True,False,False,False,False,False,False,False,False,False,False,False,True,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,True,False,False,False,False,False,False,False,False,False
60367,True,False,True,True,False,False,True,False,False,True,True,False,False,False,True,False,True,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,True,False,False,True,False,False,False,False,False,False
60368,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,True,False,False,False,False,False,False,False,False,False,False,False


We join the tokenized text from earlier to form X. 

In [30]:
X = codes_with_notes['TEXT_prop'].apply(lambda x: ' '.join(x))#.reset_index().drop(columns='index')

We perform a visual check.

In [31]:
X

0        admission date discharge date date of birth se...
1        admission date discharge date date of birth se...
2        admission date discharge date service history ...
3        admission date discharge date date of birth se...
4        admission date discharge date date of birth se...
                               ...                        
60365    admission date discharge date date of birth se...
60366    admission date discharge date date of birth se...
60367    admission date discharge date date of birth se...
60368    admission date discharge date date of birth se...
60369    admission date discharge date date of birth se...
Name: TEXT_prop, Length: 57020, dtype: object

We perform an 80-20 train-validation split on X (the inputs) and y (the labels).

In [32]:
from sklearn.model_selection import train_test_split

X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

We visually inspect the dataframe.

In [33]:
X_train.head()

8153     admission date discharge date date of birth se...
51131    admission date discharge date date of birth se...
55065    admission date discharge date date of birth se...
7368     admission date discharge date date of birth se...
29489    admission date discharge date service cardioth...
Name: TEXT_prop, dtype: object

We check the first 3 entries of the clinical text for errors.

In [34]:
X_train.iloc[0]

'admission date discharge date date of birth sex m service cardiac icu history of present illness mr known lastname is a year old gentleman with a complicated past medical history who was transferred from an outside hospital for persistent respiratory failure and pulmonary edema in addition he had staph aureus bacteremia and possible aortic valve endocarditis during his hospitalization he underwent cardiac catheterization twice with stent placement in multiple coronary arteries and was aggressively diuresed a transesopageal echocardiogram demonstrated no evidence of endocarditis he was continued on iv antibiotics for bacteremia of unclear source however his pulmonary edema failed to resolve and on his family decided that they would like to make the patient dnr to withdraw all supportive care except for comfort measures on the ventilator was changed to cpap and the patient was administered morphine sulfate intravenously and he passed away around p m on from respiratory failure consent f

Looks good!

In [35]:
X_train.iloc[1]

'admission date discharge date date of birth sex m service medicine allergies patient recorded as having no known allergies to drugs attending first name3 lf chief complaint loss of consciousness major surgical or invasive procedure none history of present illness initial history and physical is as per the hospital unit name resident year old man with history of etoh abuse and likely withdrawal seizures but on phenytoin in the past presented with a witnessed seizure the patient had a minute long tonic clonic seizure with loss of consciousness while walking down a street near his house denies any urine or stool incontinence woke up in the ambulance patient is not sure if he has had withdrawal seizures in the past he last drank a few days ago on presentation to the ed t hr sbp rr o2 sat ra serum etoh level was negative he received a banana bag lorazepam mg iv x on arrival to the icu the patient was oriented x with t now sbp in the 180s hr 80s ros the patient reports some nonproductive co

Success!

In [36]:
X_train.iloc[2]

'admission date discharge date date of birth sex m service neurology allergies diphenhydramine bee pollen phenytoin attending first name3 lf chief complaint seizure major surgical or invasive procedure intubated extubated history of present illness mr known lastname is intubated and sedated history obtained from transfer records review of omr and speaking with family mr known lastname is a year old man with pmh notable for right parietal lesion initially thought to be stroke vs low grade glioma has been stable on multiple repeat mris with the most recent being followed by dr last name stitle in hospital clinic and seizures unclear if post traumatic s p motorcycle accident or if seizure precipitated the accident he was previously on aeds but this was recently stopped who was transferred from osh with iph and seizure per transfer report he was having headaches beginning last night this morning at work around am his co workers thought he had a deer in headlights look and was confused he w

There are 45,616 rows in the training dataset.

In [37]:
pd.DataFrame(X_train).info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 45616 entries, 8153 to 59758
Data columns (total 1 columns):
 #   Column     Non-Null Count  Dtype 
---  ------     --------------  ----- 
 0   TEXT_prop  45616 non-null  object
dtypes: object(1)
memory usage: 712.8+ KB


There are 11,404 rows in the validation dataset.

In [38]:
pd.DataFrame(X_val).info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 11404 entries, 28602 to 57644
Data columns (total 1 columns):
 #   Column     Non-Null Count  Dtype 
---  ------     --------------  ----- 
 0   TEXT_prop  11404 non-null  object
dtypes: object(1)
memory usage: 178.2+ KB


Two utility functions, process_X and process_y, generate and split the inputs for our 5 different BERT models. Each BERT model is in charge of one 500-word section of the notes.

In [39]:
# 0-based index of the 500-word section to keep (4 = slice [2000:2500])
INDEX_SPECIFIED = 4

# Row ranges of that section in the melted frames; 45616 and 11404 are the
# train and validation row counts reported above
INDEX_L = int(45616*INDEX_SPECIFIED)
INDEX_R = int(45616*(INDEX_SPECIFIED+1))

INDEX_L_VAL = int(11404*INDEX_SPECIFIED)
INDEX_R_VAL = int(11404*(INDEX_SPECIFIED+1))

SAVE_MODEL_NAME = 'model_fifth500only1epoch.bert'

def process_X(X_tmp, is_train=True):
    X_tmp = pd.DataFrame(X_tmp)
    # One column per consecutive 500-word section of each note
    X_tmp['bert_0_500'] = X_tmp['TEXT_prop'].apply(lambda x: ' '.join(x.split(' ')[0:500]))
    X_tmp['bert_500_1000'] = X_tmp['TEXT_prop'].apply(lambda x: ' '.join(x.split(' ')[500:1000]))
    X_tmp['bert_1000_1500'] = X_tmp['TEXT_prop'].apply(lambda x: ' '.join(x.split(' ')[1000:1500]))
    X_tmp['bert_1500_2000'] = X_tmp['TEXT_prop'].apply(lambda x: ' '.join(x.split(' ')[1500:2000]))
    X_tmp['bert_2000_2500'] = X_tmp['TEXT_prop'].apply(lambda x: ' '.join(x.split(' ')[2000:2500]))
    X_tmp['len'] = X_tmp['TEXT_prop'].apply(lambda x: len(x.split(' ')))

    # Stack the five section columns on top of each other (original index kept),
    # so rows k*N to (k+1)*N hold section k for all N notes
    X_train_tmp = X_tmp.drop(columns='TEXT_prop').copy()
    X_train_unpivoted = X_train_tmp.melt(id_vars=['len'], var_name='text', value_name='TEXT', ignore_index=False)
    X_train_unpivoted.drop(columns=['len','text'],inplace=True)
    # Keep only the rows of the selected section
    if is_train:
        return X_train_unpivoted.iloc[INDEX_L:INDEX_R,:]
    else:
        return X_train_unpivoted.iloc[INDEX_L_VAL:INDEX_R_VAL,:]

X_train = process_X(X_train)
X_val = process_X(X_val, is_train=False)

#from numpy import savetxt

#X_train.to_csv('X_train_0_500.csv')
#X_val.to_csv('X_val_0_500.csv')

X_train.head()

Unnamed: 0,TEXT
8153,
51131,
55065,
7368,
29489,
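
The empty `TEXT` values above are consistent with notes shorter than 2,000 words when the fifth section is selected, since slicing past the end of a list yields nothing. A plain-Python sketch of the sectioning, using a hypothetical `split_sections` helper and a tiny section size so the result is easy to read:

```python
def split_sections(text, size=500, n_sections=5):
    """Split a space-separated text into consecutive sections of `size` words."""
    words = text.split(' ')
    return [' '.join(words[i * size:(i + 1) * size]) for i in range(n_sections)]

# Toy note of 7 words, split into sections of 3 words
sections = split_sections('a b c d e f g', size=3, n_sections=5)
print(sections)
```

Here the last two sections are empty strings, just as the fifth 500-word section is empty for any note with at most 2,000 words.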


Just to be sure, we confirm that there are still 45,616 rows in the training dataset.

In [41]:
X_train.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 45616 entries, 8153 to 59758
Data columns (total 1 columns):
 #   Column  Non-Null Count  Dtype 
---  ------  --------------  ----- 
 0   TEXT    45616 non-null  object
dtypes: object(1)
memory usage: 712.8+ KB


We print the target labels to visually inspect for errors.

In [56]:
y_train

Unnamed: 0,401.9,38.93,428.0,427.31,414.01,96.04,96.6,584.9,96.71,250.00,272.4,518.81,99.04,39.61,599.0,530.81,96.72,272.0,99.55,V05.3,V29.0,285.9,88.56,244.9,486,38.91,285.1,276.2,496,36.15,99.15,995.92,V58.61,038.9,507.0,V30.00,88.72,585.9,311,403.90,305.1,37.22,412,33.24,39.95,287.5,V45.81,410.71,276.1,424.0
8153,False,False,True,False,True,False,False,True,False,False,False,True,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,True,False,False,False,False,False,False,False,False,False,False,False,False,True,False,False,True,False,False,True,False,False
51131,True,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False
55065,True,False,False,False,True,False,False,False,True,False,True,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False
7368,True,False,False,False,False,False,True,False,False,False,False,False,False,False,False,False,False,True,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False
29489,True,False,False,False,True,False,False,False,False,False,False,False,True,True,False,False,False,True,False,False,False,False,False,False,False,False,False,False,False,True,False,False,False,False,False,False,False,False,False,False,False,False,True,False,False,False,False,False,False,False
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
57609,True,False,False,True,True,False,False,False,False,False,True,False,False,True,False,False,False,False,False,False,False,False,True,False,False,False,False,False,False,True,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False
40947,True,False,True,False,True,False,False,False,False,False,True,False,False,False,False,False,False,False,False,False,False,False,True,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,True,False,False,False,False,True,False,False,False,False
936,True,False,True,False,False,False,False,False,False,False,True,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,True,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False
17109,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,True,True,True,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False,False


This utility function generates the target labels for our five BERT models. It stacks five copies of the label frame and slices out the rows for the current run, so all five models are fed the same target labels.

In [57]:
def process_y(y_tmp, is_train=True):
    # Stack five copies of the labels and renumber the rows from 0.
    y_stacked = pd.concat([y_tmp]*5, axis=0, ignore_index=True)
    # Slice out the rows for the current run using the index bounds set earlier.
    if is_train:
        return y_stacked.iloc[INDEX_L:INDEX_R, :]
    return y_stacked.iloc[INDEX_L_VAL:INDEX_R_VAL, :]

y_train = process_y(y_train)
y_val = process_y(y_val, is_train=False)

#y_train.to_csv('y_train_0_500.csv')
#y_val.to_csv('y_val_0_500.csv')
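
To make the slicing concrete, here is a minimal sketch that repeats the same steps on a toy frame. The toy columns and the bounds `index_l` and `index_r` are hypothetical stand-ins for the notebook's `INDEX_L` and `INDEX_R`, chosen here to select the fifth copy.

```python
import pandas as pd

# Toy label frame with 3 rows (hypothetical; the real frame has 50 ICD-9 columns).
y_toy = pd.DataFrame({"401.9": [True, False, True],
                      "38.93": [False, False, True]})

# Stack five copies end to end; ignore_index=True renumbers the rows 0..14.
stacked = pd.concat([y_toy] * 5, axis=0, ignore_index=True)

# Hypothetical bounds that select the fifth copy (rows 12..14).
index_l, index_r = 12, 15
y_slice = stacked.iloc[index_l:index_r, :]

print(stacked.shape)           # (15, 2)
print(y_slice.index.tolist())  # [12, 13, 14]

# A slice aligned to a copy boundary holds exactly the original labels.
print(y_slice.reset_index(drop=True).equals(y_toy))  # True
```

Note that the slice keeps the row numbers of the stacked frame rather than restarting at 0. In the run shown in this notebook, the first training row is labelled 182464, which is 4 × 45,616, so the slice appears to correspond to the fifth copy of the labels.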

We print the first entry and check for errors.

In [58]:
y_train.iloc[0,:]

401.9     False
38.93     False
428.0      True
427.31    False
414.01     True
96.04     False
96.6      False
584.9      True
96.71     False
250.00    False
272.4     False
518.81     True
99.04     False
39.61     False
599.0     False
530.81    False
96.72     False
272.0     False
99.55     False
V05.3     False
V29.0     False
285.9     False
88.56     False
244.9     False
486       False
38.91     False
285.1     False
276.2     False
496        True
36.15     False
99.15     False
995.92    False
V58.61    False
038.9     False
507.0     False
V30.00    False
88.72     False
585.9     False
311       False
403.90    False
305.1     False
37.22      True
412       False
33.24     False
39.95      True
287.5     False
V45.81    False
410.71     True
276.1     False
424.0     False
Name: 182464, dtype: bool
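
Scanning 50 boolean values by eye is error-prone, so it can help to list only the codes that are set. The sketch below does this on a toy row; `row` is a hypothetical stand-in for `y_train.iloc[0,:]`, and `positive_codes` is a name introduced here for illustration.

```python
import pandas as pd

# Toy boolean row standing in for one row of the label frame.
row = pd.Series({"401.9": False, "428.0": True, "414.01": True, "96.04": False})

# Boolean indexing keeps only the entries that are True.
positive_codes = row[row].index.tolist()
print(positive_codes)  # ['428.0', '414.01']
```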

There are 45,616 rows in the training dataset.

In [59]:
y_train.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 45616 entries, 182464 to 228079
Data columns (total 50 columns):
 #   Column  Non-Null Count  Dtype
---  ------  --------------  -----
 0   401.9   45616 non-null  bool 
 1   38.93   45616 non-null  bool 
 2   428.0   45616 non-null  bool 
 3   427.31  45616 non-null  bool 
 4   414.01  45616 non-null  bool 
 5   96.04   45616 non-null  bool 
 6   96.6    45616 non-null  bool 
 7   584.9   45616 non-null  bool 
 8   96.71   45616 non-null  bool 
 9   250.00  45616 non-null  bool 
 10  272.4   45616 non-null  bool 
 11  518.81  45616 non-null  bool 
 12  99.04   45616 non-null  bool 
 13  39.61   45616 non-null  bool 
 14  599.0   45616 non-null  bool 
 15  530.81  45616 non-null  bool 
 16  96.72   45616 non-null  bool 
 17  272.0   45616 non-null  bool 
 18  99.55   45616 non-null  bool 
 19  V05.3   45616 non-null  bool 
 20  V29.0   45616 non-null  bool 
 21  285.9   45616 non-null  bool 
 22  88.56   45616 non-null  bool 
 23  244.9

There are 11,404 rows in the validation dataset.

In [60]:
y_val.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 11404 entries, 45616 to 57019
Data columns (total 50 columns):
 #   Column  Non-Null Count  Dtype
---  ------  --------------  -----
 0   401.9   11404 non-null  bool 
 1   38.93   11404 non-null  bool 
 2   428.0   11404 non-null  bool 
 3   427.31  11404 non-null  bool 
 4   414.01  11404 non-null  bool 
 5   96.04   11404 non-null  bool 
 6   96.6    11404 non-null  bool 
 7   584.9   11404 non-null  bool 
 8   96.71   11404 non-null  bool 
 9   250.00  11404 non-null  bool 
 10  272.4   11404 non-null  bool 
 11  518.81  11404 non-null  bool 
 12  99.04   11404 non-null  bool 
 13  39.61   11404 non-null  bool 
 14  599.0   11404 non-null  bool 
 15  530.81  11404 non-null  bool 
 16  96.72   11404 non-null  bool 
 17  272.0   11404 non-null  bool 
 18  99.55   11404 non-null  bool 
 19  V05.3   11404 non-null  bool 
 20  V29.0   11404 non-null  bool 
 21  285.9   11404 non-null  bool 
 22  88.56   11404 non-null  bool 
 23  244.9  

# Done!

Note that you'll need to re-run the notebook and regenerate the inputs for each BERT model, changing INDEX_SPECIFIED on each run. It is also a good idea to save the x inputs and y labels with to_csv(), using file names that distinguish the runs.
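
One possible way to keep the per-run files apart is sketched below. The helper `save_labels`, its `split` and `model_idx` parameters, and the file-name pattern are all assumptions made for illustration. They are not part of the original pipeline.

```python
import os
import tempfile

import pandas as pd


def save_labels(y, split, model_idx, out_dir):
    """Write a label frame to <out_dir>/y_<split>_model<model_idx>.csv (hypothetical naming)."""
    path = os.path.join(out_dir, f"y_{split}_model{model_idx}.csv")
    y.to_csv(path)  # the row index is written too, so rows can be traced back to their slice
    return path


# Toy label frame standing in for y_train.
y_demo = pd.DataFrame({"401.9": [True, False], "38.93": [False, True]})

out_dir = tempfile.mkdtemp()
path = save_labels(y_demo, "train", 0, out_dir)

# Reading the file back with index_col=0 restores the saved row index.
restored = pd.read_csv(path, index_col=0)
print(os.path.basename(path))  # y_train_model0.csv
print(restored.shape)          # (2, 2)
```

Putting the split and the model number in the file name means a later run cannot silently overwrite the labels of an earlier one.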