In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
import torch 
from torch import nn
from torch.utils.data import DataLoader, Dataset, TensorDataset, SequentialSampler
import tqdm
from transformers import BertForSequenceClassification, AdamW, BertTokenizer, BertModel
from sklearn.preprocessing import LabelEncoder

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])


In [2]:
train = pd.read_csv('train_1k.csv', index_col = 0)
val = pd.read_csv('val_1k.csv', index_col = 0)
test = pd.read_csv('test_1k.csv', index_col = 0)

train.dropna(subset = ['ICD9_CODE_1k'], inplace = True)
val.dropna(subset = ['ICD9_CODE_1k'], inplace = True)
test.dropna(subset = ['ICD9_CODE_1k'], inplace = True)

In [3]:
# Get all labels

label_list = []
for code in train['ICD9_CODE_1k']:
    labels = code.split(',')
    label_list.extend([label for label in labels if label not in label_list])

In [4]:
label_dict = {}
for i, label in enumerate(label_list):
    label_dict[label] = i

In [2]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case = True)

In [3]:
bert = BertModel.from_pretrained('bert-base-cased', output_attentions=True)

In [4]:
bert.embeddings.word_embeddings = nn.Embedding(tokenizer.vocab_size, 768, padding_idx = 0)

In [None]:
# Split the sequence in half and then tokenize seperately

def bert_tokenize(data, max_length, label_dict):
    
    input_ids_first = []
    input_ids_second = []
    attention_masks_first = []
    attention_masks_second = []
    labels = []
    
    for sentence in data['TEXT']:
        text_len = len(sentence)

        encoded_dict_first = tokenizer.encode_plus(sentence[:text_len//2], add_special_tokens = True, max_length = max_length,\
                                             pad_to_max_length = True, return_attention_mask = True, return_tensors = 'pt')
        input_ids_first.append(encoded_dict_first['input_ids'])
        attention_masks_first.append(encoded_dict_first['attention_mask'])
   
        
        encoded_dict_second = tokenizer.encode_plus(sentence[text_len//2:], add_special_tokens = True, max_length = max_length,\
                                             pad_to_max_length = True, return_attention_mask = True, return_tensors = 'pt')
        input_ids_second.append(encoded_dict_second['input_ids'])
        attention_masks_second.append(encoded_dict_second['attention_mask'])
        
    for codes in data['ICD9_CODE_1k']:
        label = [0]*1000
        
        all_code = codes.split(',')
        for code in all_code:
            label[label_dict[code]] = 1
        
        labels.append(torch.tensor(label))


        
    input_ids_first = torch.cat(input_ids_first, dim=0)
    attention_masks_first = torch.cat(attention_masks_first, dim=0)
    input_ids_second = torch.cat(input_ids_second, dim=0)
    attention_masks_second = torch.cat(attention_masks_second, dim=0)
    labels = torch.stack(labels, dim=0)
        
    return input_ids_first, attention_masks_first, input_ids_second, attention_masks_second, labels

In [None]:
batch_size = 4

input_ids_first_train, attention_masks_first_train, input_ids_second_train,\
attention_masks_second_train, labels_train = bert_tokenize(train, 512, label_dict)

train_dataset_bert = TensorDataset(input_ids_first_train, attention_masks_first_train, input_ids_second_train,\
                                   attention_masks_second_train, labels_train)
train_loader_bert = DataLoader(train_dataset_bert, shuffle = True, batch_size = batch_size)

input_ids_first_val, attention_masks_first_val, input_ids_second_val,\
attention_masks_second_val, labels_val = bert_tokenize(val, 512, label_dict)

val_dataset_bert = TensorDataset(input_ids_first_val, attention_masks_first_val, input_ids_second_val,\
                                 attention_masks_second_val, labels_val)
val_loader_bert = DataLoader(val_dataset_bert, shuffle = True, batch_size = batch_size)

input_ids_first_test, attention_masks_first_test, input_ids_second_test,\
attention_masks_second_test, labels_test = bert_tokenize(test, 512, label_dict)

test_dataset_bert = TensorDataset(input_ids_first_test, attention_masks_first_test, input_ids_second_test,\
                                  attention_masks_second_test, labels_test)
test_sampler_bert = SequentialSampler(test_dataset_bert)
test_loader_bert = DataLoader(test_dataset_bert, sampler = test_sampler_bert, batch_size = batch_size)

In [None]:
# torch.save(train_loader_bert, 'train_dataloader.pth')
# torch.save(val_loader_bert, 'val_dataloader.pth')
# torch.save(test_loader_bert, 'test_dataloader.pth')

In [2]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
train_loader_bert = torch.load('train_dataloader.pth')
val_loader_bert = torch.load('val_dataloader.pth')
test_loader_bert = torch.load('test_dataloader.pth')

In [4]:
# Train first and second half of the sequence seperately, then concatenate the hidden state output

class BERTClassifier(nn.Module):
    def __init__(self, bert, num_classes):
        super().__init__()
        self.bert = bert
        self.linear = nn.Linear(bert.config.hidden_size*2, num_classes)
        self.num_classes = num_classes
    
    def forward(self, input_ids_first, attention_masks_first, input_ids_second, attention_masks_second):
        h1, _, _ = self.bert(input_ids = input_ids_first, attention_mask = attention_masks_first)
        h1_cls = h1[:, 0]
        h2, _, _ = self.bert(input_ids = input_ids_second, attention_mask = attention_masks_second)
        h2_cls = h2[:, 0]
        h_cls = torch.cat((h1_cls, h2_cls), dim = -1)
        logits = self.linear(h_cls)
        return logits

In [12]:
model_bert = BERTClassifier(bert, 1000).to(device)

In [None]:
criterion = nn.BCEWithLogitsLoss(reduction = 'sum')
optimizer_bert = AdamW(model_bert.parameters(), lr = 1e-5)

In [None]:
train_loss_list_bert = []
val_loss_list_bert = []

for epoch in range(5):
    print("current epoch is "+str(epoch))
    train_loss = 0
    train_correct = 0
    train_total = 0
    val_loss = 0
    val_correct = 0
    val_total = 0
    model_bert.train()
    for i, (input_ids_first, attention_masks_first, input_ids_second,
            attention_masks_second, labels) in enumerate(train_loader_bert):
        
        optimizer_bert.zero_grad()
        input_ids_first = input_ids_first.to(device)
        attention_masks_first = attention_masks_first.to(device)
        input_ids_second = input_ids_second.to(device)
        attention_masks_second = attention_masks_second.to(device)
        labels = labels.to(device).float()

        
        
        logits = model_bert(input_ids_first, attention_masks_first, input_ids_second, attention_masks_second)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer_bert.step()
        
        train_loss += loss.item()
        train_total += input_ids_first.size()[0]
        
    train_avg_loss = train_loss / train_total   
    train_loss_list_bert.append(train_avg_loss)
    
    model_bert.eval()
    with torch.no_grad():
        for i, (input_ids_first, attention_masks_first, input_ids_second,
                attention_masks_second, labels) in enumerate(val_loader_bert):
        
            input_ids_first = input_ids_first.to(device)
            attention_masks_first = attention_masks_first.to(device)
            input_ids_second = input_ids_second.to(device)
            attention_masks_second = attention_masks_second.to(device)
            labels = labels.to(device).float()

            logits = model_bert(input_ids_first, attention_masks_first, input_ids_second, attention_masks_second)
            loss = criterion(logits, labels)

            val_loss += loss.item()
            val_total += input_ids_first.size()[0]       
        

    val_avg_loss = val_loss / val_total  
    val_loss_list_bert.append(val_avg_loss)

current epoch is 0


In [None]:
plt.plot(np.arange(5), train_loss_list_bert, label = 'train')
plt.plot(np.arange(5), val_loss_list_bert, label = 'validation')
plt.legend()
plt.title('Bert Average Loss over Epoch')
plt.xlabel('Epoch')
plt.ylabel('Average Loss')
plt.savefig('loss.jpg')

In [None]:
# torch.save(model_bert, 'model.pt')

In [5]:
model_bert = torch.load('model.pt').to(device)

In [18]:
test_label = []
test_logits = []
test_prediction = []

model_bert.eval()
with torch.no_grad():
    for i, (input_ids_first, attention_masks_first, input_ids_second,
            attention_masks_second, labels) in enumerate(test_loader_bert):

        input_ids_first = input_ids_first.to(device)
        attention_masks_first = attention_masks_first.to(device)
        input_ids_second = input_ids_second.to(device)
        attention_masks_second = attention_masks_second.to(device)
        labels = labels.to(device).float()

        logits = model_bert(input_ids_first, attention_masks_first, input_ids_second, attention_masks_second)
        sigmoid_logits = torch.sigmoid(logits)
        prediction = torch.where(sigmoid_logits > 0.5, torch.tensor(1).to(device), torch.tensor(0).to(device))
        
        test_label.extend(labels.tolist())
        test_logits.extend(logits.tolist())
        test_prediction.extend(prediction.tolist())

In [25]:
test_label_array = np.array(test_label)
test_logits_array = np.array(test_logits)
test_prediction_array = np.array(test_prediction)

In [41]:
# Show precision and recall (at 10 and at 5)

from sklearn.metrics import precision_score, recall_score

In [37]:
# At 10

label_at_10 = []
prediction_at_10 = []

top_10 = test_logits_array.argsort(axis = 1)[:,-10:]
for row,top in enumerate(top_10):
    test_label_top = test_label_array[row][top]
    test_prediction_top = test_prediction_array[row][top]

    label_at_10.extend(test_label_top)
    prediction_at_10.extend(test_prediction_top)
    

In [42]:
precision_at_10 = precision_score(label_at_10, prediction_at_10)
recall_at_10 = recall_score(label_at_10, prediction_at_10)

In [46]:
print(precision_at_10)
print(recall_at_10)

0.7183430418737992
0.5064995806722147


In [47]:
# At 5

label_at_5 = []
prediction_at_5 = []

top_5 = test_logits_array.argsort(axis = 1)[:,-5:]
for row,top in enumerate(top_5):
    test_label_top = test_label_array[row][top]
    test_prediction_top = test_prediction_array[row][top]

    label_at_5.extend(test_label_top)
    prediction_at_5.extend(test_prediction_top)

In [50]:
precision_at_5 = precision_score(label_at_5, prediction_at_5)
recall_at_5 = recall_score(label_at_5, prediction_at_5)

In [51]:
print(precision_at_5)
print(recall_at_5)

0.7276980053277889
0.7065688825802382
