# Extract Word Embeddings from BERT

In [1]:
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModel
import re
import string
import numpy as np
from nltk.corpus import stopwords 
from nltk.stem import WordNetLemmatizer
from nltk.stem import LancasterStemmer
from simpletransformers.classification import MultiLabelClassificationModel
import logging
import custom_sentence_tokenizer
import matplotlib.pyplot as plt
from scipy import stats
from ast import literal_eval
import pickle
import time

In [2]:
#imp.reload(custom_sentence_tokenizer)

In [3]:
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_Discharge_Summary_BERT")
model = AutoModel.from_pretrained("emilyalsentzer/Bio_Discharge_Summary_BERT")

## 1. Read and preprocess Discharge summaries

In [4]:
NOTEEVENTS_DF = pd.read_csv('C:/Users/kfpj179/Desktop/Final Project/data/NOTEEVENTS.csv', 
                            sep=',', header = 0,
                            usecols = lambda column : column not in ["CHARTTIME" , "STORETIME"])

#### 1.1 Concat all discharge summaries at admission level

In [5]:
NOTEEVENTS_DISCHARGE_DF = NOTEEVENTS_DF[(NOTEEVENTS_DF['CATEGORY']=='Discharge summary') & 
                                       (NOTEEVENTS_DF['ISERROR'].isnull())][['HADM_ID','CATEGORY','TEXT']]
NOTEEVENTS_DISCHARGE_DF['HADM_ID'].count()

59652

In [6]:
NOTEEVENTS_DISCHARGE_COMB_DF = NOTEEVENTS_DISCHARGE_DF.astype({'HADM_ID': 'str'}).copy()
NOTEEVENTS_DISCHARGE_COMB_DF['TEXT'] = NOTEEVENTS_DISCHARGE_COMB_DF.groupby(['HADM_ID','CATEGORY'])['TEXT']\
                            .transform(lambda x: ','.join(x))

In [7]:
NOTEEVENTS_DISCHARGE_COMB_DF = NOTEEVENTS_DISCHARGE_COMB_DF.drop_duplicates()

In [8]:
NOTEEVENTS_DISCHARGE_COMB_DF['HADM_ID'] = NOTEEVENTS_DISCHARGE_COMB_DF['HADM_ID'].transform(lambda x:x[:6])
NOTEEVENTS_DISCHARGE_COMB_DF.head()

Unnamed: 0,HADM_ID,CATEGORY,TEXT
0,167853,Discharge summary,Admission Date: [**2151-7-16**] Dischar...
1,107527,Discharge summary,Admission Date: [**2118-6-2**] Discharg...
2,167118,Discharge summary,Admission Date: [**2119-5-4**] D...
3,196489,Discharge summary,Admission Date: [**2124-7-21**] ...
4,135453,Discharge summary,Admission Date: [**2162-3-3**] D...


In [9]:
NOTEEVENTS_DISCHARGE_COMB_DF['HADM_ID'].count()
NOTEEVENTS_DF = None
NOTEEVENTS_DISCHARGE_DF = None

## 2 Preprocessing Layer

In [10]:
#NOTEEVENTS_DISCHARGE_NLP_DF = NOTEEVENTS_DISCHARGE_COMB_DF.sample(n=10, random_state=1).copy()
NOTEEVENTS_DISCHARGE_NLP_DF = NOTEEVENTS_DISCHARGE_COMB_DF.copy()
NOTEEVENTS_DISCHARGE_NLP_DF.head(5)

Unnamed: 0,HADM_ID,CATEGORY,TEXT
27902,194356,Discharge summary,Admission Date: [**2186-4-11**] ...
18286,175191,Discharge summary,Admission Date: [**2118-12-26**] Discharg...
51539,147153,Discharge summary,Admission Date: [**2194-6-27**] ...
55968,137467,Discharge summary,Admission Date: [**2150-4-11**] ...
44874,139922,Discharge summary,Admission Date: [**2181-3-8**] D...


In [11]:
NOTEEVENTS_DISCHARGE_COMB_DF = None

### 2.1 Replace medical abbrivations

In [13]:
    replace_LIST = [
                     ['dr\.','']
                    ,['DR\.','']
                    ,['m\.d\.','']
                    ,['M\.D\.','']
                    ,['p\.o', 'orally']
                    ,['P\.O', 'orally']
                    ,['q\.d\.', 'once a day']
                    ,['Q\.D\.', 'once a day']
                    ,['I\.M\.', 'intramuscularly']
                    ,['i\.m\.', 'intramuscularly']
                    ,['b\.i\.d\.', 'twice a day']
                    ,['B\.I\.D\.', 'twice a day']
                    ,['Subq\.', 'subcutaneous']
                    ,['SUBQ\.', 'subcutaneous']
                    ,['t\.i\.d\.', 'three times a day']
                    ,['T\.I\.D\.', 'three times a day']
                    ,['q\.i\.d\.', 'four times a day']
                    ,['Q\.I\.D\.', 'four times a day']
                    ,['I\.V\.', 'intravenous']
                    ,['i\.v\.', 'intravenous']
                    ,['q\.h\.s\.', 'before bed']
                    ,['Q\.H\.S\.', 'before bed']
                    ,['O\.D\.', 'in the right eye']
                    ,['o\.d\.', 'in the right eye']
                    ,['5X', 'a day five times a day']
                    ,['5x', 'a day five times a day']
                    ,['O\.S\.', 'in the left eye']
                    ,['o\.s\.', 'in the left eye']
                    ,['q\.4h', 'every four hours']
                    ,['Q\.4H', 'every four hours']
                    ,['O\.U\.', 'in both eyes']
                    ,['o\.u\.', 'in both eyes']
                    ,['q\.6h', 'every six hours']
                    ,['Q\.6H', 'every six hours']
                    ,['q\.o\.d\.', 'every other day']
                    ,['Q\.O\.D\.', 'every other day']
                    ,['prn\.', 'as needed']
                    ,['PRN\.', 'as needed']
                    ,['[0-9]+\.','']
                    ,[r'\[\*.+\*\]','']
                    ]

    def preprocess_re_sub(x):
        processed_text = x
        for find,replace in replace_LIST:
            processed_text=re.sub(find,replace,processed_text)
        return processed_text

In [14]:
NOTEEVENTS_DISCHARGE_NLP_DF['TEXT'] = NOTEEVENTS_DISCHARGE_NLP_DF['TEXT'].transform(lambda x:preprocess_re_sub(x))
NOTEEVENTS_DISCHARGE_NLP_DF['ID'] = NOTEEVENTS_DISCHARGE_NLP_DF.index

In [15]:
NOTEEVENTS_DISCHARGE_NLP_DF.head()

Unnamed: 0,HADM_ID,CATEGORY,TEXT,ID
27902,194356,Discharge summary,Admission Date: \n\nDate of Birth: ...,27902
18286,175191,Discharge summary,Admission Date: \n\nDate of Birth: Sex...,18286
51539,147153,Discharge summary,Admission Date: \n\nDate of Birth: ...,51539
55968,137467,Discharge summary,Admission Date: \n\nDate of Birth: ...,55968
44874,139922,Discharge summary,Admission Date: \n\nDate of Birth: ...,44874


### 2.2 Sentence tokenizer

In [16]:
NOTEEVENTS_DISCHARGE_NLP_DF['PREPROC_TEXT'] = NOTEEVENTS_DISCHARGE_NLP_DF['TEXT']\
                                    .transform(lambda x:custom_sentence_tokenizer.custom_sentence_tokenizer(x 
                                                        , testing = False, verbose = False))

In [17]:
def combine_sent(x):
    combined_sent = ''
    combined_sent_list = []
    for i in range(len(x)):
        sent = x[i]
        if sent!='.':
            sent = re.sub('^\s+|\n|\r',' ',
            re.sub('\s\s|\t|\.|\,||admission date:|discharge date:|date of birth:|addendum:|--|__|==','',
                   sent.lower())).strip()
            sent_len = len(sent.split(' '))
            combined_sent_len = len(combined_sent.split(' '))
            
            if i == 0:
                combined_sent = sent
                
            else:
                # when len of sentence + combined sent < 92, combine the existin combined list with current sentence
                if sent_len + combined_sent_len <= 64:
                    combined_sent = combined_sent + ' . ' + sent
                    if i == len(x) - 1:
                        combined_sent_list.append(combined_sent) 
                else:
                # when len is longer then append current combined sent into final list and reinitialize combined sent with current sent 
                    combined_sent_list.append(combined_sent)
                    combined_sent = sent
                
    return combined_sent_list

In [18]:
NOTEEVENTS_DISCHARGE_NLP_DF['PREPROC_TEXT_COMB'] = NOTEEVENTS_DISCHARGE_NLP_DF['PREPROC_TEXT']\
                                    .transform(lambda x:combine_sent(x))

### 2.3 Clinical BERT Word tokenizer

In [19]:
def create_token(x):
    tokens = []
    for i in x:
        tokens.append(tokenizer.convert_tokens_to_ids(tokenizer.tokenize(i)))
    return tokens

NOTEEVENTS_DISCHARGE_NLP_DF['TEXT_TOKENS'] = NOTEEVENTS_DISCHARGE_NLP_DF['PREPROC_TEXT_COMB']\
                                                .transform(lambda x:create_token(x))

### 2.4 Explore sentence length and count of tokens

In [20]:
arr = []
for j in range(10):
    arr.extend(np.array([len(i) for i in NOTEEVENTS_DISCHARGE_NLP_DF['TEXT_TOKENS'].iloc[j]]))
arr = np.array(arr)

In [21]:
for i in range(62, 260, 16):
    print(str(i) + ': '+str(len(arr[arr<=i])/len(arr)))

62: 0.13793103448275862
78: 0.3146551724137931
94: 0.6422413793103449
110: 0.8275862068965517
126: 0.9008620689655172
142: 0.9224137931034483
158: 0.9482758620689655
174: 0.9525862068965517
190: 0.9568965517241379
206: 0.9568965517241379
222: 0.9568965517241379
238: 0.9568965517241379
254: 0.9568965517241379


### 2.4 Merge with chapter labels

In [22]:
CHAPTER_LABEL_DF = pd.read_csv('C:/Users/kfpj179/Desktop/Final Project/data/chapter_label.csv', 
                            sep=',', header = 0).astype({'HADM_ID': 'str'})

In [23]:
NOTEEVENTS_CHAPTER_DF = NOTEEVENTS_DISCHARGE_NLP_DF[['HADM_ID','TEXT_TOKENS']].set_index('HADM_ID')\
            .join(CHAPTER_LABEL_DF.set_index('HADM_ID'), how='left')

## 3. Prepare input for BERT model

### 3.1 Truncate and zero pad

In [24]:
def truncate_sent(x):
    max_len = 126
    sent_list = []
    for i in x:
        sent_len = len(i)
        if sent_len < max_len:
            i.extend(np.zeros(max_len - sent_len).astype(int))
            sent_list.append(i)

        else:
            sent_list.append(i[:max_len])

    return sent_list

In [25]:
NOTEEVENTS_CHAPTER_DF['TEXT_TOKENS'] = NOTEEVENTS_CHAPTER_DF['TEXT_TOKENS'].transform(lambda x:truncate_sent(x))

### 3.2 Include cls and sep tags

In [26]:
def bert_tags(x):
    sent_list = []
    for i in x:
        sent = [101]
        sent.extend(i)
        sent.append(102)
        sent_list.append(sent)
    return sent_list

In [27]:
NOTEEVENTS_CHAPTER_DF['TEXT_TOKENS'] = NOTEEVENTS_CHAPTER_DF['TEXT_TOKENS'].transform(lambda x:bert_tags(x))

### 3.3 Attention mask

In [28]:
def attention_mask(x):
    max_word = 128
    sent_len = len(x)
    attention_mask = []
    for x_ITEM in x:
        ones_list = np.ones(max_word).astype(int)
        for j, token in enumerate(x_ITEM):
            if token==0:
                ones_list[j] = 0 
        attention_mask.append(ones_list)

    return attention_mask

In [29]:
NOTEEVENTS_CHAPTER_DF['ATTENTION'] = NOTEEVENTS_CHAPTER_DF['TEXT_TOKENS'].transform(lambda x:attention_mask(x))

## 4. Query BERT Pretrained weights to generate embeddings

### 4.1 Generate embeddings from clinical bert base model

In [30]:
# Approximately 6.5 seconds/ note on 32 GB, 1.9GHz i7 CPU; i.e, 100 hrs (4.5 days) for all notes
embeddings = {}
#count = 0
for i in NOTEEVENTS_CHAPTER_DF.index:
    #now = time.time()
    tokens = torch.tensor(NOTEEVENTS_CHAPTER_DF.loc[i, 'TEXT_TOKENS'])
    attention_mask = torch.tensor(NOTEEVENTS_CHAPTER_DF.loc[i, 'ATTENTION'])
    embeddings[i] = np.array(model(tokens,attention_mask=attention_mask)[0].data)
    #model.zero_grad()
    #count = count + 1
    #print(count)
    #print(time.time()-now)

### 4.2 Pickle the embeddings for later use

In [32]:
with open('embeddings.pickle', 'wb') as handle:
    pickle.dump(embeddings, handle)

In [31]:
embeddings_embedding_np = np.array(list(embeddings.items()))
embeddings_list = list(zip(embeddings_embedding_np[:,0], [np.float16(i) for i in embeddings_embedding_np[:,1]]))

In [34]:
with open('embeddings_list.pickle', 'wb') as handle:
    pickle.dump(embeddings_list, handle)