In [18]:
import pandas as pd
import numpy as np
from string import punctuation
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
# from nltk import word_tokenize
from sklearn.model_selection import train_test_split
from transformers import BertModel, BertConfig, BertTokenizer
# from transformers import pipeline, FeatureExtractionPipeline

In [2]:
import torchtext.data as data
import torchtext.vocab as vocab

# Load BlueBERT Model

# Admissions Data

In [3]:
# strict copd coding
strict_icd9 = [
    "49120",
    "49121",
    "49122",
    "49320",
    "49321",
    "49322",
    "496",
]

# regular copd coding
reg_icd9 = [
    "4911",
    "4920",
    "4928",
]

print("Loading dx codes...")
df = pd.read_csv("~/Documents/data/mimic/DIAGNOSES_ICD.csv")
n = df.HADM_ID.nunique()
copd_hadmids = df[df.ICD9_CODE.isin(strict_icd9 + reg_icd9)].HADM_ID.unique()
n_copd = len(copd_hadmids)
print(f"Admission: {n}")
print(f"COPD Admissions: {n_copd}")

patients = pd.read_csv("~/Documents/data/mimic/PATIENTS.csv", parse_dates=['DOB', 'DOD', 'DOD_HOSP'])
print(f"Num patients: {patients['SUBJECT_ID'].nunique()}")
print(f"Num female: {patients[patients['GENDER'] == 'F']['SUBJECT_ID'].nunique()}")

admission_cols = [
    'HADM_ID',
    'ADMISSION_TYPE',
    'ADMITTIME',
    'DISCHTIME',
    'DEATHTIME',
    'EDREGTIME',
    'EDOUTTIME',
    'HOSPITAL_EXPIRE_FLAG',
    'HAS_CHARTEVENTS_DATA',
]
print("Loading admission events...")
tmp = pd.read_csv("~/Documents/data/mimic/ADMISSIONS.csv", parse_dates=['ADMITTIME', 'DISCHTIME','DEATHTIME', 'EDREGTIME', 'EDOUTTIME',])[admission_cols]

# concat primary dx onto admissions
admits = tmp.merge(df, on=['HADM_ID']).drop_duplicates(subset=["HADM_ID"])

# get rid of spurrious admissions and ignore newborns
admits = admits[(admits['DISCHTIME'] > admits['ADMITTIME']) & (admits.ADMISSION_TYPE != "NEWBORN")]

# add age information
admits = admits.merge(patients[['SUBJECT_ID', 'DOB']], on='SUBJECT_ID', how='left')
admits['age'] = admits.apply(lambda x: (x['ADMITTIME'].date() - x['DOB'].date()).days // 365.242, axis=1)

# tag copd admissions
admits['copd'] = admits.HADM_ID.isin(copd_hadmids)

# get the type and time of the next admission
admits.sort_values(by=['SUBJECT_ID', 'ADMITTIME'],inplace=True)
admits['next_admit_time'] = admits.groupby('SUBJECT_ID').ADMITTIME.shift(-1)
admits['next_admit_type'] = admits.groupby('SUBJECT_ID').ADMISSION_TYPE.shift(-1)
# if the next admission is elective, nullify and back fill
admits.loc[admits.next_admit_type == "ELECTIVE", 'next_admit_time'] = pd.NaT
admits.loc[admits.next_admit_type == "ELECTIVE", 'next_admit_type'] = np.nan
admits[['next_admit_time','next_admit_type']] = admits.groupby(['SUBJECT_ID'])[['next_admit_time','next_admit_type']].fillna(method = 'bfill')

# compute readmission stats
admits['readmit_time'] = admits.groupby('SUBJECT_ID').apply(lambda x: x['next_admit_time'] - x['DISCHTIME']).reset_index(level=0, drop=True)
admits['7d_readmit'] = (admits['readmit_time'].dt.total_seconds() < 7 * 24 * 3600).astype(int)
admits['30d_readmit'] = (admits['readmit_time'].dt.total_seconds() < 30 * 24 * 3600).astype(int)

Loading dx codes...
Admission: 58976
COPD Admissions: 7459
Num patients: 46520
Num female: 20399
Loading admission events...


In [4]:
def print_summary(df):
    gb = df.groupby(['copd','7d_readmit']).HADM_ID.count()
    non_rate = gb[0][1] / gb[0].sum()
    copd_rate = gb[1][1] / gb[1].sum()
    print("Non-COPD 7d readmit rate: {:.1%}".format(non_rate))
    print("COPD 7d readmit rate:     {:.1%}".format(copd_rate))
    print('')

    gb = df.groupby(['copd','30d_readmit']).HADM_ID.count()
    non_rate = gb[0][1] / gb[0].sum()
    copd_rate = gb[1][1] / gb[1].sum()
    print("Non-COPD 30d readmit rate: {:.1%}".format(non_rate))
    print("COPD 30d readmit rate:     {:.1%}".format(copd_rate))
    print('')

    gb = df[df.DEATHTIME.notnull()].drop_duplicates(subset=['SUBJECT_ID']).groupby('copd').size()
    print("Non-COPD mortality rate: {:.1%}".format(gb[0] / df[df.copd == False].shape[0]))
    print("COPD mortality rate:     {:.1%}".format(gb[1] / df[df.copd].shape[0]))

print("<65 Admissions")
print("-"*25)
print_summary(admits[admits['age'] < 65])

print("\n\n65+ Admissions")
print("-"*25)
print_summary(admits[admits['age'] >= 65])

print("\n\nAll Admissions")
print("-"*25)
print_summary(admits)

<65 Admissions
-------------------------
Non-COPD 7d readmit rate: 1.9%
COPD 7d readmit rate:     3.2%

Non-COPD 30d readmit rate: 5.3%
COPD 30d readmit rate:     9.1%

Non-COPD mortality rate: 7.4%
COPD mortality rate:     8.1%


65+ Admissions
-------------------------
Non-COPD 7d readmit rate: 2.2%
COPD 7d readmit rate:     2.6%

Non-COPD 30d readmit rate: 5.7%
COPD 30d readmit rate:     7.6%

Non-COPD mortality rate: 14.5%
COPD mortality rate:     15.2%


All Admissions
-------------------------
Non-COPD 7d readmit rate: 2.1%
COPD 7d readmit rate:     2.8%

Non-COPD 30d readmit rate: 5.5%
COPD 30d readmit rate:     8.1%

Non-COPD mortality rate: 10.8%
COPD mortality rate:     13.0%


# Discharge Notes Data

In [5]:
# subjects that died in the hosp
deceased_subj_ids = admits[admits.DEATHTIME.notnull()].SUBJECT_ID.unique()
# subjects w/ at least one copd related admission
copd_subj_ids = admits[admits.copd].SUBJECT_ID.unique()
# all admissions for subjects w/ at least one copd related admission
hadm_ids_w_copd = admits[admits.SUBJECT_ID.isin(copd_subj_ids)].HADM_ID.unique()


print('Loading medical notes...')

chunk_reader = pd.read_csv("~/Documents/data/mimic/NOTEEVENTS.csv", chunksize=100000, usecols=['SUBJECT_ID','HADM_ID', 'CHARTDATE','CATEGORY', 'DESCRIPTION', 'TEXT',])
chunk_li = []
iteration = 0
for chunk in chunk_reader:
    
    if iteration % 5 == 0:
        print(f"Iteration {iteration}")
        # keep only admissions for subjects that had at least one copd admit
#     chunk_li.append(chunk[(chunk['HADM_ID'].isin(hadm_ids_w_copd))])
        # keep just the discharge summaries
        chunk_li.append(chunk[(chunk['CATEGORY'] == 'Discharge summary')])
    iteration += 1
    
print("Done.")

notes = pd.concat(chunk_li, ignore_index=True)
# keep only one discharge summary per admission
notes = notes.sort_values(by=['SUBJECT_ID', 'HADM_ID', 'CHARTDATE']).groupby(['HADM_ID']).nth(-1)
cols = ['HADM_ID', 'SUBJECT_ID','age', 'copd', 'HOSPITAL_EXPIRE_FLAG','ADMISSION_TYPE', 'ADMITTIME', 'DISCHTIME', 'DEATHTIME','next_admit_time', 'next_admit_type','30d_readmit',]
notes = notes.merge(admits[cols], on=['SUBJECT_ID','HADM_ID'], how='inner')
notes.head()

Loading medical notes...
Iteration 0
Iteration 5
Iteration 10
Iteration 15
Iteration 20
Done.


Unnamed: 0,SUBJECT_ID,HADM_ID,CHARTDATE,CATEGORY,DESCRIPTION,TEXT,age,copd,HOSPITAL_EXPIRE_FLAG,ADMISSION_TYPE,ADMITTIME,DISCHTIME,DEATHTIME,next_admit_time,next_admit_type,30d_readmit
0,58526,100001.0,2117-09-17,Discharge summary,Report,Admission Date: [**2117-9-11**] ...,35.0,False,0,EMERGENCY,2117-09-11 11:46:00,2117-09-17 16:45:00,NaT,2118-07-07 06:26:00,EMERGENCY,0
1,54610,100003.0,2150-04-21,Discharge summary,Report,Admission Date: [**2150-4-17**] ...,59.0,False,0,EMERGENCY,2150-04-17 15:34:00,2150-04-21 17:30:00,NaT,2150-07-13 18:56:00,EMERGENCY,0
2,9895,100006.0,2108-04-18,Discharge summary,Addendum,"Name: [**Known lastname 470**], [**Known firs...",48.0,True,0,EMERGENCY,2108-04-06 15:49:00,2108-04-18 17:18:00,NaT,2108-08-02 15:36:00,EMERGENCY,0
3,23018,100007.0,2145-04-07,Discharge summary,Report,Admission Date: [**2145-3-31**] ...,73.0,False,0,EMERGENCY,2145-03-31 05:33:00,2145-04-07 12:40:00,NaT,NaT,,0
4,533,100009.0,2162-05-21,Discharge summary,Report,Admission Date: [**2162-5-16**] ...,60.0,False,0,EMERGENCY,2162-05-16 15:56:00,2162-05-21 13:37:00,NaT,NaT,,0


# Pre-Process Notes

In [53]:
def preprocess_text(df):
    # This function preprocesses the text by filling not a number 
    # and replacing new lines ('\n') and carriage returns ('\r')
    df.TEXT = df.TEXT.fillna(' ')
    df.TEXT = df.TEXT.str.replace('\n',' ')
    df.TEXT = df.TEXT.str.replace('\r',' ')
    return df

# discard admissions where the patient died during the stay
# use 30d readmission as the label
text_data = preprocess_text(notes[notes.HOSPITAL_EXPIRE_FLAG == 0].sample(frac=0.3, random_state=42))['TEXT'].str[:2500]
y = notes[notes.HOSPITAL_EXPIRE_FLAG == 0].sample(frac=0.3, random_state=42)['30d_readmit']
text_data.str.len().describe()

count    13162.000000
mean      2401.058958
std        374.180210
min        409.000000
25%       2500.000000
50%       2500.000000
75%       2500.000000
max       2500.000000
Name: TEXT, dtype: float64

In [58]:
text_data

31181    Admission Date:  [**2149-11-4**]              ...
28182    Admission Date:  [**2137-3-21**]              ...
21728    Admission Date:  [**2109-7-6**]              D...
45550    Name:  [**Known lastname **], [**Known firstna...
34146    Admission Date:  [**2176-2-8**]       Discharg...
                               ...                        
7409     Admission Date:  [**2106-11-24**]             ...
151      Admission Date: [**2126-2-7**]        Discharg...
32593    Admission Date: [**2199-6-7**]        Discharg...
48852    Admission Date:  [**2200-12-12**]             ...
10715    Admission Date:  [**2146-1-20**]     Discharge...
Name: TEXT, Length: 13162, dtype: object

### Toy example of using BERT and transformers library to tokenize and generate embeddings

In [110]:
text = """Prior to the hospitalization, she had a L parotidectomy for what turned out to be parotiditis 
    and sialadenitis with a large  retained duct stone. Ultimately, it became clear she had no persistent
    infectious process in the parotid bed, but had evolving carbapenem and cephalosporin erythroderm.
    Her rashes improved dramatically with transition to from meropenem to cefepime to aztreonam. Her course
    was further complicated by a fever curve that had regular Tmax in the 101 range, resolving while on 
    vancomycin, aztreonam, clindamycin and micafungin, but  then recurred first low grade then becoming 
    very hectic and high  grade to 104 without any focal findings. The vancomycin was stopped and she 
    defervesced after 72 hours. She soon thereafter recovered her counts and all antibiotics were 
    discontinued when her ANC approached 500."""
bert_base_tokenizer = BertTokenizer.from_pretrained("bert-base-cased")

# 1. Add the special tokens
marked_text = "[CLS] " + text + " [SEP]"

# 2. Split the sentence into tokens
tokenized_text = bert_base_tokenizer.tokenize(marked_text)
tokenized_text = tokenized_text[:150]

# 3. Map the token strings to their vocabulary indices
indexed_tokens = bert_base_tokenizer.convert_tokens_to_ids(tokenized_text)

# 4. Display the words with their indices.
for tup in zip(tokenized_text, indexed_tokens):
    print('{:<12} {:>6,}'.format(tup[0], tup[1]))

# 5. Segment IDs: 0 for sentence 1 and 1 for sentence 2
segments_ids = [1] * len(tokenized_text)

print(indexed_tokens)
print(segments_ids)
# 6. convert the lists to tensors
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])

model = BertModel.from_pretrained('bert-base-uncased')

model.eval()
with torch.no_grad():
    last_hidden_state, pooler_output = model(tokens_tensor, segments_tensors)
last_hidden_state.shape

[CLS]           101
Prior         4,602
to            1,106
the           1,103
hospital      2,704
##ization     2,734
,               117
she           1,131
had           1,125
a               170
L               149
par          14,247
##ot          3,329
##ide         3,269
##ct          5,822
##omy        18,574
for           1,111
what          1,184
turned        1,454
out           1,149
to            1,106
be            1,129
par          14,247
##ot          3,329
##id          2,386
##itis       10,721
and           1,105
si           27,466
##ala         5,971
##den         2,883
##itis       10,721
with          1,114
a               170
large         1,415
retained      5,366
duct         26,862
stone         2,576
.               119
Ultimately   16,266
,               117
it            1,122
became        1,245
clear         2,330
she           1,131
had           1,125
no            1,185
persistent   15,970
infectious   20,342
process       1,965
in            1,107


HBox(children=(FloatProgress(value=0.0, description='Downloading', max=433.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=440473133.0, style=ProgressStyle(descri…

KeyboardInterrupt: 

In [58]:
last_hidden_state.shape

torch.Size([1, 150, 768])

## Load pre-trained BlueBERT model and tokenizer

In [10]:
blue_bert_tokenizer = BertTokenizer.from_pretrained("/Users/kevin/Documents/data/BlueBERT/NCBI_BERT_pubmed_mimic_uncased_L-12_H-768_A-12/")
configuration = BertConfig.from_json_file("/Users/kevin/Documents/data/BlueBERT/NCBI_BERT_pubmed_mimic_uncased_L-12_H-768_A-12/bert_config.json")
model = BertModel.from_pretrained("/Users/kevin/Documents/data/BlueBERT/NCBI_BERT_pubmed_mimic_uncased_L-12_H-768_A-12/pytorch_model.bin", config=configuration)


## Generate embeddings

In [11]:
# tokenizer requires input text to be either a single string or list of strings
batch_size = 32


tokens = blue_bert_tokenizer(
    text_data[:10].values.tolist(),
    padding='max_length',
    truncation=True,
    max_length=128,
    return_tensors="pt",
)

In [12]:
tokens['input_ids'].shape

torch.Size([10, 128])

In [54]:
batch_size = 32
max_seq_length = 64
token_li = []
for i in range(text_data.shape[0] // batch_size + 1):
    start_idx = i * 32
    stop_idx = (i + 1) * 32
    
    if i % 100 == 0:
        print(f"Tokenizing batch {i + 1}...")
    # shape: (32, 128)
    tokens = blue_bert_tokenizer(
        text_data[start_idx:stop_idx].values.tolist(),
        padding='max_length',
        truncation=True,
        max_length=max_seq_length,
        return_tensors="pt",
    )
    token_li.append(tokens)

Tokenizing batch 1...
Tokenizing batch 101...
Tokenizing batch 201...
Tokenizing batch 301...
Tokenizing batch 401...


In [140]:
# 1372*32
len(token_li)

1372

In [55]:
model.eval()
hidden_states = []
with torch.no_grad():
    for i, batch in enumerate(token_li):
        if i % 25 == 0:
            print(f"Embedding batch {i + 1}...")
        last_hidden_state, _ = model(**batch)
        hidden_states.append(last_hidden_state)
        

Embedding batch 1...
Embedding batch 26...
Embedding batch 51...
Embedding batch 76...
Embedding batch 101...
Embedding batch 126...
Embedding batch 151...
Embedding batch 176...
Embedding batch 201...
Embedding batch 226...
Embedding batch 251...
Embedding batch 276...
Embedding batch 301...
Embedding batch 326...
Embedding batch 351...
Embedding batch 376...
Embedding batch 401...


In [59]:
embeddings = torch.cat(hidden_states)
print(embeddings.shape)

torch.Size([13162, 64, 768])


# Generate Tokens

In [54]:
my_list = [0,1,2,3,4,5,6,7,8]
my_list[:5]

[0, 1, 2, 3, 4]

# word2vec model

In [61]:
len(set(vocab.keys()) & set(wv.vocab.keys()))
    

28584

In [3]:
wv = gensim.models.KeyedVectors.load_word2vec_format('/Users/kevin/gensim-data/word2vec-google-news-300/word2vec-google-news-300.gz', binary=True)
# BlueBERT
# wv = gensim.models.KeyedVectors.load_word2vec_format('/Users/kevin/gensim-data/NCBI_BERT_pubmed_mimic_uncased_L-12_H-768_A-12.zip', binary=True)



UnicodeDecodeError: 'utf-8' codec can't decode byte 0x83 in position 10: invalid start byte

In [59]:
for i, word in enumerate(vocab):
    if i == 10:
        break
    print(word)

vitals
lf
bs
cxr
found
parents
guarding
aspirin
clear
nad


# Word Embeddings

In [35]:
pairs = [
    ('car', 'vehicle'),
    ('car', 'minivan'),   # a minivan is a kind of car
    ('car', 'bicycle'),   # still a wheeled vehicle
    ('car', 'airplane'),  # ok, no wheels, but still a vehicle
    ('car', 'cereal'),    # ... and so on
    ('car', 'communism'),
]
for w1, w2 in pairs:
    print('%r\t%r\t%.2f' % (w1, w2, wv.similarity(w1, w2)))

'car'	'vehicle'	0.78
'car'	'minivan'	0.69
'car'	'bicycle'	0.54
'car'	'airplane'	0.42
'car'	'cereal'	0.14
'car'	'communism'	0.06


In [38]:
def get_word2vec(tokens_list, vector, generate_missing=False, k=300):
    if generate_missing:
        vectorized = [vector[word] if word in vector else np.random.uniform(-0.25,0.25,k) for word in tokens_list]
    else:
        vectorized = [vector[word] if word in vector else np.zeros(k) for word in tokens_list]
    return vectorized
sample_sequence = get_word2vec(tokens[0], wv, generate_missing=True)

In [41]:
# len(sample_sequence)
sample_sequence[0].shape

(300,)

In [34]:
class TextCNN(nn.Module):
    """ Text CNN from Yoon Kim's 2014 paper: https://arxiv.org/pdf/1408.5882.pdf
    
    Params
    ------
    sequence_length : int
        The length (in words/tokens) of each sentence.
    num_classes : int
        Number of output classes.
    vocab_size : int
        The number of unique tokens in our vocabulary.
    embedding_size: int
        The vector length for word embeddings. (standard word2vec is 300, BERT is 768)
    num_filters: int
        The number of filters to apply.
    kernel_sizes: tuple(int)
        The kernel size for each desired filter (e.g. [3,4,5])
    dropout_rate: float
        Probability of dropping a neuron in the dropout layer.  Must be in the range [0.0, 1.0]
        Default = 0.5
    embedding_weights: torch.FloatTensor
        Pre-trained embedding weights to optionally pass, otherwise embedding weights will be learned.
        Default is None.
    
    Returns
    -------
    model : nn.Module
    """
    def __init__(
        self,
        num_classes,
        embedding_size,
        num_filters,
        kernel_sizes=(3,4,5),
        dropout_rate=0.5,
    ):
        super(TextCNN, self).__init__()
        
        # convolutional layer
        self.convs = nn.ModuleList([nn.Conv2d(
            in_channels=1,
            out_channels=num_filters,
            kernel_size=(kernel_size, embedding_size),
        ) for kernel_size in kernel_sizes])
        
        # dropout layer
        self.dropout = nn.Dropout(dropout_rate)
        
        # fully connected layer
        self.fc = nn.Linear(
            in_features=num_filters * len(kernel_sizes),
            out_features=num_classes
        )
        
    def forward(self, x):
        x = x.unsqueeze(1)  # (batch_size, in_channels, sequence_length, embedding_size)
        
        x_li = []
        for conv in self.convs:
            _x = F.relu(conv(x)) # (batch_size, out_channels, sequence_length, 1)
            _x = _x.squeeze(3) # (batch_size, out_channels, sequence_length)
            _x = F.max_pool1d(_x, _x.size(2)).squeeze(2) # (batch_size, out_channels)
            x_li.append(_x)
            
        x = torch.cat(x_li, 1)
        x = self.dropout(x) # (batch_size, len(kernel_sizes) * out_channels)
        logits = self.fc(x) # (batch_size, num_classes)
        
        probs = F.softmax(logits, dim=1) # (batch_size, num_classes)
        classes = torch.max(probs, 1)[1] # (batch_size)
        
        return probs, classes

In [60]:
# Convert training data to pytorch dataset
train_tensors = torch.utils.data.TensorDataset(
    embeddings,
    torch.tensor(y.values).long()
)

# Create iterable
trainloader = torch.utils.data.DataLoader(train_tensors, batch_size=32,
                                          shuffle=True, num_workers=2)

In [62]:
def train():
    model = TextCNN(
        num_classes=2,
        embedding_size=768,
        num_filters=128,
        dropout_rate=0.5,
    )
    
    loss_function = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

            
    for epoch in range(2):  # loop over the dataset multiple times

        running_loss = 0
        running_corrects = 0
        for i, data in enumerate(trainloader, 0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward
            probs, classes = model(inputs)
            # backprop
            loss = loss_function(probs, labels)
            loss.backward()
            # update/optimize
            optimizer.step()

#             running_losses.append(loss.data[0])
#             if i % 50 == 0:
#                 disp_loss = sum(running_losses) / len(running_losses)
# #                 writer.add_scalar("train/loss", loss, step)
#                 logging.info("step = {}, loss = {}".format(i, loss))
#                 running_losses = []
            # print statistics
            running_loss += loss.item()
            running_corrects += torch.sum(classes == labels.data)
            if i % 50 == 0 and i != 0:    # print every 50 mini-batches
                print('[%d, %5d] loss: %.3f acc %.3f' %
                      (epoch + 1, i + 1, running_loss / 50, running_corrects / 50))
                running_loss = 0.0
                running_corrects = 0.0
                
#             epoch_loss = running_loss / dataset_sizes[phase]
#             epoch_acc = running_corrects.double() / dataset_sizes[phase]

    print('Finished Training')

In [63]:
train()



[1,    51] loss: 0.383 acc 30.000
[1,   101] loss: 0.378 acc 29.940
[1,   151] loss: 0.390 acc 29.540
[1,   201] loss: 0.381 acc 29.840
[1,   251] loss: 0.377 acc 29.960
[1,   301] loss: 0.383 acc 29.780
[1,   351] loss: 0.376 acc 29.980
[1,   401] loss: 0.369 acc 30.220
[2,    51] loss: 0.384 acc 30.000
[2,   101] loss: 0.392 acc 29.480
[2,   151] loss: 0.368 acc 30.240
[2,   201] loss: 0.385 acc 29.720
[2,   251] loss: 0.380 acc 29.860
[2,   301] loss: 0.381 acc 29.840
[2,   351] loss: 0.373 acc 30.080
[2,   401] loss: 0.368 acc 30.240
Finished Training
