In [1]:
import numpy as np
import pandas as pd
import torch 
from torch import nn
from torch.utils.data import DataLoader, Dataset, TensorDataset, SequentialSampler
import tqdm
from transformers import BertForSequenceClassification, AdamW, BertTokenizer, BertModel
from sklearn.preprocessing import LabelEncoder

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])


In [31]:
notes = pd.read_csv('NOTEEVENTS.csv')

  interactivity=interactivity, compiler=compiler, result=result)


In [32]:
notes = notes[notes['CATEGORY'] == 'Discharge summary'][['SUBJECT_ID', 'HADM_ID', 'CHARTDATE', 'TEXT']]

In [33]:
icd = pd.read_csv('DIAGNOSES_ICD.csv')

In [34]:
notes_icd = notes.merge(icd[['HADM_ID', 'ICD9_CODE']], how = 'left', on = 'HADM_ID')

In [35]:
# Only keep the earliest patient record

notes_icd = notes_icd.sort_values(['SUBJECT_ID', 'CHARTDATE'])
notes_icd = notes_icd.drop_duplicates('SUBJECT_ID')

In [36]:
notes_icd['TEXT'].str.len().describe()

count    41127.000000
mean      9867.789968
std       5054.366511
min        215.000000
25%       6297.500000
50%       8937.000000
75%      12368.500000
max      55728.000000
Name: TEXT, dtype: float64

In [37]:
notes_icd = notes_icd.sample(frac = 0.08)

In [38]:
# Split data

train, validate, test = np.split(notes_icd.sample(frac=1), [int(.7*len(notes_icd)), int(.9*len(notes_icd))])

In [39]:
validate = validate[validate['ICD9_CODE'].isin(train['ICD9_CODE'])]
test = test[test['ICD9_CODE'].isin(train['ICD9_CODE'])]

In [40]:
train.reset_index(drop =  True, inplace = True)
validate.reset_index(drop =  True, inplace = True)
test.reset_index(drop =  True, inplace = True)

In [41]:
le = LabelEncoder()
le.fit(train['ICD9_CODE'].astype('str'))

train['LABEL'] = le.transform(train['ICD9_CODE'].astype('str'))
validate['LABEL'] = le.transform(validate['ICD9_CODE'].astype('str'))
test['LABEL'] = le.transform(test['ICD9_CODE'].astype('str'))

In [42]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case = True)

In [43]:
def bert_tokenize(data, max_length = 128, split_num = 512):
    
    input_ids = []
    attention_masks = []
    
    for sentence in data['TEXT']:
        partial_num = 0
        partial_sentence = sentence[0:split_num]
        input_id = torch.zeros(max_length)
        attention_mask = torch.ones(max_length)
        while len(partial_sentence) == split_num:
            encoded_dict = tokenizer.encode_plus(partial_sentence, add_special_tokens = True, max_length = max_length,\
                                                 pad_to_max_length = True, return_attention_mask = True, return_tensors = 'pt')

            input_id = input_id+encoded_dict['input_ids']
            attention_mask = attention_mask*encoded_dict['attention_mask']
                
            partial_num += 1
            partial_sentence = sentence[partial_num*split_num:(partial_num+1)*split_num]
#             print(input_id)
#             print(attention_mask)
#             print('---------------------------------')
        
        encoded_dict = tokenizer.encode_plus(partial_sentence, add_special_tokens = True, max_length = max_length,\
                                             pad_to_max_length = True, return_attention_mask = True, return_tensors = 'pt')  
        input_id = input_id+encoded_dict['input_ids']
        attention_mask = attention_mask*encoded_dict['attention_mask']
#         print('--------------------------------------------')
#         print(input_id)
        
        input_ids.append(input_id)
        attention_masks.append(attention_mask)
        
    input_ids = torch.cat(input_ids, dim=0)
    attention_masks = torch.cat(attention_masks, dim=0)
    labels = torch.tensor(data['LABEL'].tolist())
        
    return input_ids, attention_masks, labels

In [44]:
batch_size = 32

input_ids_train, attention_masks_train, labels_train = bert_tokenize(train)
train_dataset_bert = TensorDataset(input_ids_train, attention_masks_train, labels_train)
train_loader_bert = DataLoader(train_dataset_bert, shuffle = True, batch_size = batch_size)

input_ids_val, attention_masks_val, labels_val = bert_tokenize(validate)
val_dataset_bert = TensorDataset(input_ids_val, attention_masks_val, labels_val)
val_loader_bert = DataLoader(val_dataset_bert, shuffle = True, batch_size = batch_size)

input_ids_test, attention_masks_test, labels_test = bert_tokenize(test)
test_dataset_bert = TensorDataset(input_ids_test, attention_masks_test, labels_test)
test_sampler_bert = SequentialSampler(test_dataset_bert)
test_loader_bert = DataLoader(test_dataset_bert, sampler = test_sampler_bert, batch_size = batch_size)

In [45]:
bert = BertForSequenceClassification.from_pretrained('bert-base-cased', output_attentions=True)

In [46]:
# device = torch.device('cuda') if torch.cuda.is_available() else troch.device('cpu')

In [47]:
device = torch.device('cpu')

In [48]:
model_bert = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels = train['LABEL'].nunique(),\
                                                            output_attentions = False, output_hidden_states = False).to(device)

In [49]:
optimizer_bert = AdamW(model_bert.parameters(), lr = 1e-5)

In [51]:
model_bert

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

In [50]:
train_loss_list_bert = []
train_accuracy_list_bert = []
val_loss_list_bert = []
val_accuracy_list_bert = []

for epoch in range(5):
    train_loss = 0
    train_correct = 0
    train_total = 0
    val_loss = 0
    val_correct = 0
    val_total = 0
    model_bert.train()
    for i, (input_ids, attention_masks, labels) in enumerate(train_loader_bert):
        optimizer_bert.zero_grad()
        input_ids = input_ids.long()
        input_ids, attention_masks, labels = input_ids.to(device), attention_masks.to(device), labels.to(device)
        print(input_ids.shape)
        print(input_ids)
        loss, outputs = model_bert(input_ids, token_type_ids = None, attention_mask = attention_masks, labels = labels)
#         _, preds = torch.max(outputs, dim = 1)
#         loss.backward()
#         optimizer_bert.step()
        
#         train_loss += loss.item()
#         train_correct += torch.sum(preds == labels).item()
#         train_total += input_ids.size()[0]
        
#     train_avg_loss = train_loss / train_total   
#     train_acc = train_correct / train_total  
#     train_loss_list_bert.append(train_avg_loss)
#     train_accuracy_list_bert.append(train_acc)
    
#     model_bert.eval()
#     with torch.no_grad():
#         for i, (input_ids, attention_masks, labels) in enumerate(val_loader_bert):
#             input_ids = input_ids.long()
#             input_ids, attention_masks, labels = input_ids.to(device), attention_masks.to(device), labels.to(device)
#             loss, outputs = model_bert(input_ids, token_type_ids = None, attention_mask = attention_masks, labels = labels)
#             _, preds = torch.max(outputs, dim = 1)
#             val_loss += loss.item()
#             val_correct += torch.sum(preds == labels).item()
#             val_total += input_ids.size()[0]

#     val_avg_loss = val_loss / val_total  
#     val_acc = val_correct / val_total 
#     val_loss_list_bert.append(val_avg_loss)
#     val_accuracy_list_bert.append(val_acc)

torch.Size([32, 128])
tensor([[  1515, 118405, 117677,  ...,  59465,  72025,   1224],
        [  1313, 105439,  63500,  ...,  47584,  74988,    918],
        [  1010,  75649,  70947,  ...,  19197,  34121,    816],
        ...,
        [   808,  64829,  66890,  ...,   9285,  25655,    510],
        [  1919, 168584,  84960,  ..., 102714,  69387,   1326],
        [  2222, 231132, 120415,  ..., 112088,  42805,   1326]])


RuntimeError: index out of range: Tried to access index 118405 out of table with 30521 rows. at /pytorch/aten/src/TH/generic/THTensorEvenMoreMath.cpp:418

In [None]:
plt.plot(np.arange(5), train_loss_list_bert, label = 'train')
plt.plot(np.arange(5), val_loss_list_bert, label = 'validation')
plt.legend()
plt.title('Bert Average Loss over Epoch')
plt.xlabel('Epoch')
plt.ylabel('Average Loss')

In [None]:
plt.plot(np.arange(5), train_accuracy_list_bert, label = 'train')
plt.plot(np.arange(5), val_accuracy_list_bert, label = 'validation')
plt.legend()
plt.title('Bert Accuracy over Epoch')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')

In [None]:
test_correct_bert = 0
test_total_bert = 0
test_pred_list_bert = []
model_bert.eval()
with torch.no_grad():
    for i, (input_ids, attention_masks, labels) in enumerate(test_loader_bert):
        input_ids, attention_masks, labels = input_ids.to(device), attention_masks.to(device), labels.to(device)
        loss, outputs = model_bert(input_ids, token_type_ids = None, attention_mask = attention_masks, labels = labels)
        _, preds = torch.max(outputs, dim = 1)   
        test_pred_list_bert.extend(preds.tolist())
        
        test_correct_bert += torch.sum(preds == labels).item()
        test_total_bert += input_ids.size()[0]
        
test_acc_bert = test_correct_bert / test_total_bert
print('Bert test data accuracy is {:10.4f}'.format(test_acc_bert)) 