In [1]:
import pandas as pd
import pickle
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_score, recall_score
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer, BertForSequenceClassification, PreTrainedModel

# Module-level BERT encoder and tokenizer, both loaded from the
# 'bert-base-uncased' checkpoint.
# NOTE(review): neither name is referenced by the cells visible below (the model
# class and the dataset each load their own copy) — confirm before removing.
bert_model = BertModel.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [2]:
#df = pd.read_csv('data/preprocessed_case_study_data.csv')

In [3]:
#train = pd.read_csv('data/train.csv')
#val = pd.read_csv('data/val.csv')

In [4]:
class BertForMultiLabelSequenceClassification(nn.Module):
    """BERT encoder with dropout and a linear classification head.

    The head is applied to the final hidden state of the first token
    (the ``[CLS]`` position) and produces one raw logit per label.

    Args:
        num_labels: Number of output logits produced by the classifier.
        freeze_bert: When True, gradients are disabled for every parameter
            of the BERT encoder so that only the linear head is trained.
    """

    def __init__(self, num_labels=2, freeze_bert=True):
        super().__init__()

        self.bert = BertModel.from_pretrained('bert-base-uncased')

        if freeze_bert:
            for param in self.bert.parameters():
                # The attribute autograd honours is `requires_grad`; a misspelt
                # name would only create an unused attribute and freeze nothing.
                param.requires_grad = False

        self.num_labels = num_labels

        self.dropout = torch.nn.Dropout(0.2)
        # Size the head from the encoder config instead of hard-coding 768.
        self.classifier = torch.nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        """Compute classification logits for a batch of token-id sequences.

        Args:
            input_ids: Token ids, indexed as ``[batch, position]``.
            token_type_ids: Optional segment ids forwarded to BERT.
            attention_mask: Optional mask forwarded to BERT (non-zero marks
                positions that should be attended to).

        Returns:
            Raw (un-normalised) logits with ``num_labels`` values per example.
        """
        encoder_outputs = self.bert(
            input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask,
        )
        # Integer indexing works whether the encoder returns a plain tuple or a
        # ModelOutput; element 0 is the per-token hidden-state tensor.
        sequence_output = encoder_outputs[0]
        cls_rep = sequence_output[:, 0]  # representation at the first position

        return self.classifier(self.dropout(cls_rep))
        
  

In [5]:
class CasesDataset(Dataset):
    """Map-style dataset that turns CSV rows into fixed-length BERT inputs.

    The CSV file is read with pandas and must contain a ``text`` column and a
    ``label`` column. Every item is padded or truncated to ``maxlen`` tokens.

    Args:
        filename: Path of the CSV file to read.
        maxlen: Length of every returned token-id sequence, including the
            ``[CLS]`` and ``[SEP]`` special tokens.

    Raises:
        ValueError: If ``maxlen`` is too small to hold both special tokens.
    """

    def __init__(self, filename, maxlen):
        if maxlen < 2:
            raise ValueError("maxlen must be at least 2 to hold [CLS] and [SEP]")

        self.df = pd.read_csv(filename)
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.maxlen = maxlen

    def __len__(self):
        """Return the number of rows in the underlying DataFrame."""
        return len(self.df)

    def __getitem__(self, index):
        """Return ``(token_ids, attention_mask, label)`` for row ``index``.

        ``token_ids`` and ``attention_mask`` are 1-D long tensors of length
        ``maxlen``; the mask is 1 on real tokens and 0 on padding. ``label`` is
        returned exactly as stored in the ``label`` column.
        """
        # Samplers hand out positions in range(len(self)), so index by position
        # rather than by DataFrame label.
        sentence = self.df['text'].iloc[index]
        label = self.df['label'].iloc[index]

        # BERT expects "[CLS] tokens ... [SEP]"; reserve two slots so that the
        # closing [SEP] survives truncation.
        word_pieces = self.tokenizer.tokenize(sentence)[:self.maxlen - 2]
        tokens = ['[CLS]'] + word_pieces + ['[SEP]']
        num_real = len(tokens)
        tokens = tokens + ['[PAD]'] * (self.maxlen - num_real)

        tokens_ids_tensor = torch.tensor(self.tokenizer.convert_tokens_to_ids(tokens))

        # Build the mask from the known unpadded length instead of inferring
        # padding from the numeric value of the ids.
        attn_mask = torch.zeros(self.maxlen, dtype=torch.long)
        attn_mask[:num_real] = 1

        return tokens_ids_tensor, attn_mask, label

In [6]:
# Build the training and validation datasets from their CSV files; every example
# is padded or truncated to 512 tokens. Note that `test_set` is backed by the
# validation split (data/val.csv).
train_set = CasesDataset('data/train.csv', maxlen = 512)
test_set = CasesDataset('data/val.csv', maxlen = 512)

In [7]:
# Mini-batch size shared by the training and validation DataLoaders.
batch_size = 16
#class_sample_count = train.label.value_counts(sort = False).values.tolist() 
#weights = 1 / torch.Tensor(class_sample_count)
#weights = weights.double()
#sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, batch_size)

In [8]:
#train_loader = DataLoader(train_set, batch_size = batch_size, num_workers = 4, shuffle =True)
#val_loader = DataLoader(test_set, batch_size = batch_size, num_workers = 4)
# Training batches are drawn in a freshly shuffled order every epoch.
train_loader = DataLoader(dataset=train_set,
                          batch_size=batch_size,
                          shuffle = True)
# drop_last=True discards a trailing incomplete batch, so every validation batch
# holds exactly `batch_size` examples (and up to batch_size - 1 rows are skipped).
val_loader = DataLoader(dataset=test_set, batch_size=batch_size, drop_last = True)

In [9]:
# Number of mini-batches the training loader yields per epoch.
len(train_loader)

15096

In [10]:
# Number of mini-batches the validation loader yields per pass.
len(val_loader)

1677

In [11]:
# Classifier with a 7-logit output head; freeze_bert=True asks for the BERT
# encoder's parameters to be excluded from gradient updates.
model = BertForMultiLabelSequenceClassification(num_labels = 7, freeze_bert = True)


In [12]:
# Move the model's parameters and buffers to the GPU. As the cell's last
# expression, the returned module's repr is displayed below.
model.cuda()

BertForMultiLabelSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-1

In [13]:
def accuracy_thresh(y_pred, y_true, thresh=0.5, sigmoid=True):
    """Sum, over rows, the fraction of entries whose thresholded prediction equals the target.

    `y_pred` and `y_true` must have the same 2-D shape. When `sigmoid` is true
    the predictions are squashed with a sigmoid before being compared to `thresh`.
    """
    scores = y_pred.sigmoid() if sigmoid else y_pred

    predicted_on = scores > thresh
    matches = (predicted_on == y_true.byte()).float().cpu().numpy()

    # Per-row mean agreement, then summed across the batch.
    per_row_accuracy = np.mean(matches, axis=1)
    return per_row_accuracy.sum()


In [14]:
#criterion = nn.BCEWithLogitsLoss(reduction = 'mean')
# Cross-entropy loss computed from the model's raw logits and the batch labels.
criterion = nn.CrossEntropyLoss()
# Adam over everything returned by model.parameters(), with learning rate 2e-5.
optimizer = optim.Adam(model.parameters(), lr = 2e-5)

In [15]:
def evaluate(model, criterion, dataloader):
    """Return the mean per-batch loss of ``model`` over ``dataloader``.

    The model is switched to eval mode (and left there) and gradients are
    disabled while iterating.

    Args:
        model: Module called as ``model(seq, attention_mask=attn_masks)``.
        criterion: Loss callable applied to ``(logits, labels)``; its result
            must support ``.item()``.
        dataloader: Iterable yielding ``(seq, attn_masks, labels)`` tensors.

    Returns:
        The arithmetic mean of the batch losses as a float, or NaN when the
        dataloader yields no batches.
    """
    model.eval()

    # Run on whichever device the model already lives on instead of assuming
    # GPU 0; fall back to the previous default for parameter-less models.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for seq, attn_masks, labels in dataloader:
            seq = seq.to(device)
            attn_masks = attn_masks.to(device)
            labels = labels.to(device)

            # Pass the mask by keyword: positionally it would land in the
            # model's `token_type_ids` slot.
            logits = model(seq, attention_mask=attn_masks)
            total_loss += criterion(logits, labels).item()
            num_batches += 1

    if num_batches == 0:
        # Nothing to average (e.g. drop_last removed the only batch).
        return float('nan')

    return total_loss / num_batches

In [None]:
num_epochs = 3
num_labels = 7

# Keep every batch on the device the model already occupies.
device = next(model.parameters()).device

for epoch in range(num_epochs):
    # evaluate() leaves the model in eval mode at the end of each epoch, so
    # training mode (and therefore dropout) must be re-enabled here.
    model.train()

    for it, (seq, attn_masks, labels) in enumerate(train_loader):

        optimizer.zero_grad()
        seq, attn_masks, labels = seq.to(device), attn_masks.to(device), labels.to(device)
        logits = model(seq, attention_mask = attn_masks)
        loss = criterion(logits, labels)
        loss.backward()

        optimizer.step()

        # Report the current batch loss every 100 iterations.
        if (it + 1) % 100 == 0 :
            print("Iteration {} of epoch {} complete. Loss : {} ".format(it+1, epoch+1, loss.item()))
    print('='*100)

    # Mean validation loss after each full pass over the training data.
    val_loss = evaluate(model, criterion, val_loader)
    print("Epoch {} complete!  Validation Loss : {}".format(epoch+1, val_loss))