In [1]:
import pandas as pd
import pickle
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_score, recall_score
from torch.utils.data import Dataset, DataLoader
from transformers import BertModel, BertTokenizer, BertForSequenceClassification, PreTrainedModel

bert_model = BertModel.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [2]:
with open('data/label_mapping.pickle', 'rb') as f:
    map_dct = pickle.load(f)

In [3]:
class CasesDataset(Dataset):
    def __init__(self, filename, maxlen):
        
        self.df = pd.read_csv(filename)
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.maxlen = maxlen
        
    def __len__(self):
        
        return len(self.df)
    
    def __getitem__(self, index):
        
        sentence = self.df.loc[index, 'text'] 
        label = self.df.loc[index, 'label']
        
        tokens = self.tokenizer.tokenize(sentence)
        tokens = ['[CLS]'] + tokens 
        
        if len(tokens) < self.maxlen:
            tokens = tokens + ['[PAD]' for _ in range(self.maxlen - len(tokens))]
        else:
            tokens = tokens[:self.maxlen] 
        
        tokens_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        tokens_ids_tensor = torch.tensor(tokens_ids)
        #print('Dataloader:')
        #print(tokens_ids_tensor.shape)
        attn_mask = (tokens_ids_tensor != 0).long()
        
        return tokens_ids_tensor, attn_mask, label

In [4]:
class BertForMultiLabelSequenceClassification(nn.Module):
    """BERT model for classification.
    This module is composed of the BERT model with a linear layer on top of
    the pooled output.
    """
    def __init__(self,num_labels=2, freeze_bert = True):
        super(BertForMultiLabelSequenceClassification, self).__init__()
        
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        
        if freeze_bert:
            for p in self.bert.parameters():
                p.require_grad = False
        
        self.num_labels = num_labels
       
        self.dropout = torch.nn.Dropout(0.2)
        self.classifier = torch.nn.Linear(768, num_labels)
        
    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        
        #_, pooled_output = self.bert(input_ids, token_type_ids, attention_mask)
        #pooled_output = self.dropout(pooled_output)
        #logits = self.classifier(pooled_output)
        
        cont_reps, _ = self.bert(input_ids, attention_mask = attention_mask)
        cls_rep = cont_reps[:, 0]
        x = self.dropout(cls_rep)
        x = self.classifier(x)
        
        
        return x
        

In [5]:
train_set = CasesDataset('data/train_small.csv', maxlen = 512)
test_set = CasesDataset('data/test.csv', maxlen = 512)

In [6]:
batch_size = 16
#class_sample_count = train.label.value_counts(sort = False).values.tolist() 
#weights = 1 / torch.Tensor(class_sample_count)
#weights = weights.double()
#sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, batch_size)

In [7]:
#train_loader = DataLoader(train_set, batch_size = batch_size, num_workers = 4, shuffle =True)
#val_loader = DataLoader(test_set, batch_size = batch_size, num_workers = 4)
train_loader = DataLoader(dataset=train_set,
                          batch_size=batch_size,
                          num_workers = 4,
                          shuffle = True)
val_loader = DataLoader(dataset=test_set, batch_size=batch_size, num_workers = 4)

In [8]:
model = BertForMultiLabelSequenceClassification(num_labels = 7, freeze_bert = False)
model.cuda()

BertForMultiLabelSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-1

In [9]:
#criterion = nn.BCEWithLogitsLoss(reduction = 'mean')
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr = 2e-5)

In [10]:
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline

def evaluate(model, criterion, dataloader):
    
    model.eval()

    mean_loss = 0
   
    count = 0
    with torch.no_grad():
        for it, (seq, attn_masks, labels) in enumerate(dataloader):
            #print(seq.size())
            #print(attn_masks.size())
            #print(labels.size())
            seq, attn_masks, labels = seq.cuda(0), attn_masks.cuda(0),labels.cuda(0)
            #print(seq.size())
            #print(attn_masks.size())
            #print(labels.size())
            logits = model(seq, attention_mask = attn_masks)
            #print(logits.size())
            #print(labels.size())
            mean_loss += criterion(logits, labels).item()
            
            count += 1

    return mean_loss / count

def evaluate_v2(model, test_loader, map_dct):
    y_pred = []
    y_true = []

    model.eval()
    with torch.no_grad():
        for it, (seq, attn_masks, labels) in enumerate(test_loader):

                seq, attn_masks, labels = seq.cuda(0), attn_masks.cuda(0),labels.cuda(0)          
                logits = model(seq, attention_mask = attn_masks)
                
                y_pred.extend(torch.argmax(logits, 1).tolist())
                y_true.extend(labels.tolist())
                if it%100==0:
                    print(it, ' iterations complete')
    
    print('Classification Report:')
    print(classification_report(y_true, y_pred, digits=4, target_names = map_dct.keys()))
    return y_true, y_pred
    #cm = confusion_matrix(y_true, y_pred)
    #ax= plt.subplot()
    #sns.heatmap(cm, annot=True, ax = ax, cmap='Blues', fmt="d")

    #ax.set_title('Confusion Matrix')

    #ax.set_xlabel('Predicted Labels')
    #ax.set_ylabel('True Labels')



In [11]:
len(train_loader)

242

In [12]:
len(test_set)

26836

In [13]:
num_epochs = 4
num_labels = 7
for epoch in range(num_epochs):
    for it, (seq, attn_masks, labels) in enumerate(train_loader):
       
        optimizer.zero_grad()
        seq, attn_masks, labels = seq.cuda(0), attn_masks.cuda(0), labels.cuda(0)
        logits = model(seq, attention_mask = attn_masks)
        #print(labels.size())
        #print(logits.size())
        loss = criterion(logits, labels)
        loss.backward()
        
        optimizer.step()
        
        if (it + 1) % 10 == 0 :
            print("Iteration {} of epoch {} complete. Loss : {} ".format(it+1, epoch+1, loss.item()))
    print('='*100)
    
    val_loss = evaluate(model, criterion, val_loader)
    print("Epoch {} complete!  Validation Loss : {}".format(epoch+1, val_loss))
    evaluate_v2(model, val_loader, map_dct)

Iteration 10 of epoch 1 complete. Loss : 1.6598114967346191 
Iteration 20 of epoch 1 complete. Loss : 1.5736589431762695 
Iteration 30 of epoch 1 complete. Loss : 1.7666893005371094 
Iteration 40 of epoch 1 complete. Loss : 1.6009200811386108 
Iteration 50 of epoch 1 complete. Loss : 1.5415629148483276 
Iteration 60 of epoch 1 complete. Loss : 1.290280818939209 
Iteration 70 of epoch 1 complete. Loss : 1.2392823696136475 
Iteration 80 of epoch 1 complete. Loss : 1.0721619129180908 
Iteration 90 of epoch 1 complete. Loss : 0.8319083452224731 
Iteration 100 of epoch 1 complete. Loss : 1.1615474224090576 
Iteration 110 of epoch 1 complete. Loss : 1.1091748476028442 
Iteration 120 of epoch 1 complete. Loss : 0.6013498306274414 
Iteration 130 of epoch 1 complete. Loss : 0.8384524583816528 
Iteration 140 of epoch 1 complete. Loss : 1.11197829246521 
Iteration 150 of epoch 1 complete. Loss : 0.901842474937439 
Iteration 160 of epoch 1 complete. Loss : 0.8336580991744995 
Iteration 170 of epoc

Iteration 220 of epoch 4 complete. Loss : 0.01659846305847168 
Iteration 230 of epoch 4 complete. Loss : 0.07759702205657959 
Iteration 240 of epoch 4 complete. Loss : 0.12310564517974854 
Epoch 4 complete!  Validation Loss : 0.695194383690198
Classification Report:
                  precision    recall  f1-score   support

    bank_service     0.7512    0.7733    0.7621      2007
     credit_card     0.6627    0.8190    0.7326      2955
credit_reporting     0.8720    0.8051    0.8372      8123
 debt_collection     0.8383    0.7691    0.8022      6146
            loan     0.7022    0.8289    0.7603      3104
 money_transfers     0.7108    0.6237    0.6644       473
        mortgage     0.9190    0.8898    0.9041      4028

        accuracy                         0.8083     26836
       macro avg     0.7795    0.7870    0.7804     26836
    weighted avg     0.8168    0.8083    0.8102     26836

