# Multi-label classification with the Longformer

In a previous post I explored the functionality of the Longformer for text classification. In this post I will explore how the Longformer performs on a multi-label classification problem.

This dataset has to be downloaded manually from Kaggle. Once extracted, it can be loaded with the `datasets` library as usual, e.g. `load_dataset('jigsaw_toxicity_pred', data_dir='/path/to/extracted/data/')`; in this post we simply read the extracted `train.csv` with pandas.

In [1]:
import torch
import transformers
import pandas as pd
import numpy as np
from torch.nn import BCEWithLogitsLoss
from transformers import LongformerTokenizerFast, LongformerModel, LongformerConfig
from transformers.models.longformer.modeling_longformer import LongformerPreTrainedModel, LongformerClassificationHead
from torch.utils.data import Dataset, DataLoader
import wandb

We are going to instantiate a raw Longformer model and add a classification head on top.

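A note on `pos_weight`: `BCEWithLogitsLoss` accepts an optional per-label `pos_weight` tensor that scales the loss contributed by positive examples. Toxic labels are rare in this dataset (most comments in the sample below carry no label at all), so up-weighting positives is one way to counter the imbalance. The model class below accepts an optional `pos_weight` and forwards it to the loss.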

In [2]:
# Load the Jigsaw training CSV. Every column after `id` and `comment_text` is a
# label column, so they are gathered into a single list-valued `labels` column.
insults = pd.read_csv('../data/jigsaw/train.csv')
label_columns = insults.columns[2:]
insults['labels'] = insults[label_columns].values.tolist()
insults = insults.loc[:, ['id', 'comment_text', 'labels']].reset_index(drop=True)

In [3]:
# Hold out 20% of the rows for validation; random_state makes the split reproducible.
train_size = 0.8
train_dataset = insults.sample(frac=train_size, random_state=200)
# `insults` has a unique 0..n-1 index (reset when it was loaded), so dropping the
# sampled index labels leaves exactly the complementary rows.
test_dataset = insults.drop(train_dataset.index).reset_index(drop=True)
train_dataset = train_dataset.reset_index(drop=True)

In [4]:
# Display the randomly sampled training split as a quick sanity check of the
# `id` / `comment_text` / `labels` layout.
train_dataset

Unnamed: 0,id,comment_text,labels
0,6725d5a6391e5c77,Goal scored for Portugal \n\nThis could be mil...,"[0, 0, 0, 0, 0, 0]"
1,28ea0d2c61db3137,My mistake someone was vandalizing the page so...,"[0, 0, 0, 0, 0, 0]"
2,4de3d6b966b58ec7,Test card F music \nCeefax music isn't the sam...,"[0, 0, 0, 0, 0, 0]"
3,af602c4c5f1b09bc,""":Meh, I guess I can live with either outcome,...","[0, 0, 0, 0, 0, 0]"
4,9e412a7965873237,"UV is my error, above. I'm told by Kimberly Ja...","[0, 0, 0, 0, 0, 0]"
...,...,...,...
127652,c123b71e899fe446,"""\n\nFair use rationale for Image:Leonidas_01....","[0, 0, 0, 0, 0, 0]"
127653,00040093b2687caa,alignment on this subject and which are contra...,"[0, 0, 0, 0, 0, 0]"
127654,40fbd29d34b1a027,", especially kickstarter-level films, or any i...","[0, 0, 0, 0, 0, 0]"
127655,1c25899b09994ad0,Commenting on my talk page \n\nThose who are ...,"[0, 0, 0, 0, 0, 0]"


In [5]:
# instantiate a class that will handle the data
class Data_Processing(Dataset):
    """Map-style dataset that tokenizes comments lazily, one item at a time.

    Each item is ``(input_ids, attention_mask, labels, id)``:
      * ``input_ids`` / ``attention_mask`` have shape ``(1, max_length)``
        because ``return_tensors='pt'`` keeps a leading batch axis of 1;
      * ``labels`` is a float tensor built from the row's label list;
      * ``id`` is the raw id value from ``id_column``.

    Args:
        tokenizer: object exposing ``encode_plus`` (e.g. a
            ``LongformerTokenizerFast``). It is stored on the instance and used
            for every item.
        id_column: comment ids (a pandas Series or any iterable).
        text_column: raw comment strings.
        label_column: one label list per row.
        max_length: number of tokens every comment is padded/truncated to.
    """

    def __init__(self, tokenizer, id_column, text_column, label_column, max_length=2048):
        super().__init__()
        # Keep our own reference instead of relying on a global `tokenizer`.
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Materialise every column as a plain list so indexing is positional,
        # independent of whatever index a pandas Series happens to carry.
        self.text_column = self._as_list(text_column)
        self.label_column = self._as_list(label_column)
        self.id_column = self._as_list(id_column)

    @staticmethod
    def _as_list(column):
        """Return `column` as a plain list (pandas Series, ndarray, or any iterable)."""
        return column.tolist() if hasattr(column, 'tolist') else list(column)

    # fetch one element and tokenize it with the stored tokenizer
    def __getitem__(self, index):
        # Collapse runs of whitespace (newlines, tabs, repeated spaces) into single spaces.
        comment_text = " ".join(str(self.text_column[index]).split())

        inputs = self.tokenizer.encode_plus(comment_text,
                                            add_special_tokens=True,
                                            max_length=self.max_length,
                                            padding='max_length',
                                            return_attention_mask=True,
                                            truncation=True,
                                            return_tensors='pt')
        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']

        labels_ = torch.tensor(self.label_column[index], dtype=torch.float)
        id_ = self.id_column[index]
        return input_ids, attention_mask, labels_, id_

    def __len__(self):
        return len(self.text_column)

In [6]:
batch_size = 4
# Padding and truncation are requested explicitly on every encode_plus call in
# Data_Processing, so the tokenizer is loaded with its defaults instead of with
# call-time options that from_pretrained would merely store and never use.
tokenizer = LongformerTokenizerFast.from_pretrained('allenai/longformer-base-4096')

# wrap the training and validation splits in datasets that tokenize on the fly
training_data = Data_Processing(tokenizer,
                                train_dataset['id'],
                                train_dataset['comment_text'],
                                train_dataset['labels'])

test_data = Data_Processing(tokenizer,
                            test_dataset['id'],
                            test_dataset['comment_text'],
                            test_dataset['labels'])

# Shuffle only the training data: a fixed validation order keeps evaluation
# deterministic and comparable across epochs and runs.
dataloaders_dict = {'train': DataLoader(training_data, batch_size=batch_size, shuffle=True, num_workers=2),
                    'val': DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=2)
                   }

dataset_sizes = {'train': len(training_data),
                 'val': len(test_data)
                }

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [7]:
# Pull one validation batch to sanity-check what the DataLoader yields:
# (input_ids, attention_mask, labels, ids).
val_loader = dataloaders_dict['val']
a = next(iter(val_loader))
a

[tensor([[[    0,   100,   524,  ...,     1,     1,     1]],
 
         [[    0, 39858, 46975,  ...,     1,     1,     1]],
 
         [[    0, 10643,   768,  ...,     1,     1,     1]],
 
         [[    0, 14517,    36,  ...,     1,     1,     1]]]),
 tensor([[[1, 1, 1,  ..., 0, 0, 0]],
 
         [[1, 1, 1,  ..., 0, 0, 0]],
 
         [[1, 1, 1,  ..., 0, 0, 0]],
 
         [[1, 1, 1,  ..., 0, 0, 0]]]),
 tensor([[0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.]]),
 ('8cb75aa42b956c08',
  'cbcbf871629fa5ae',
  '51025f6b0c82d402',
  'cdd4fb7e2462c8a6')]

In [8]:
# instantiate a Longformer for multilabel classification class

class LongformerForMultiLabelSequenceClassification(LongformerPreTrainedModel):
    """Longformer encoder plus a classification head for multi-label targets.

    The encoder's final hidden states are passed to ``LongformerClassificationHead``
    (which, in transformers, reads the state of the first ``<s>`` token) to get
    one logit per label. Training uses ``BCEWithLogitsLoss``, i.e. an independent
    sigmoid + binary cross-entropy per label, instead of the softmax cross-entropy
    used for single-label classification, so several labels can be active at once.

    Args:
        config: Longformer config; ``config.num_labels`` sets the number of logits.
        pos_weight: optional per-label weights for positive examples, forwarded to
            ``BCEWithLogitsLoss`` (a tensor or anything ``torch.as_tensor`` accepts).
    """

    def __init__(self, config, pos_weight=None):
        super(LongformerForMultiLabelSequenceClassification, self).__init__(config)
        self.num_labels = config.num_labels
        self.pos_weight = pos_weight
        self.longformer = LongformerModel(config)
        self.classifier = LongformerClassificationHead(config)
        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, global_attention_mask=None,
                token_type_ids=None, position_ids=None, inputs_embeds=None,
                labels=None):
        """Encode the batch, classify it, and optionally compute the loss.

        Returns:
            ``(loss, logits)``. ``logits`` holds one logit per label; ``loss`` is the
            BCE-with-logits loss when ``labels`` is given and ``None`` otherwise.
        """
        # By default give global attention to the first token only (`<s>`, the
        # equivalent of BERT's CLS token); all other tokens use local attention.
        if global_attention_mask is None:
            if input_ids is not None:
                global_attention_mask = torch.zeros_like(input_ids)
            else:
                # inputs_embeds is (batch, seq_len, hidden); build the mask from its first two dims.
                global_attention_mask = torch.zeros(inputs_embeds.shape[:2], dtype=torch.long,
                                                    device=inputs_embeds.device)
            global_attention_mask[:, 0] = 1

        outputs = self.longformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            global_attention_mask=global_attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds)

        # The last hidden state is the first element both of a ModelOutput
        # (return_dict=True) and of a plain tuple (return_dict=False).
        sequence_output = outputs[0]

        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            pos_weight = self.pos_weight
            if pos_weight is not None:
                # Keep the weights on the logits' device/dtype (e.g. after model.to('cuda')).
                pos_weight = torch.as_tensor(pos_weight, dtype=logits.dtype, device=logits.device)
            loss_fct = BCEWithLogitsLoss(pos_weight=pos_weight)
            loss = loss_fct(logits.view(-1, self.num_labels),
                            labels.float().view(-1, self.num_labels))

        return loss, logits


In [9]:
# Load the pretrained Longformer encoder into the multi-label wrapper; as the
# load warning reports, the classifier weights are newly initialised.
pretrained_kwargs = {
    'gradient_checkpointing': False,
    'attention_window': 512,
    'num_labels': 6,  # one logit per Jigsaw label column
    'cache_dir': '/media/data_files/github/website_tutorials/data',
    'return_dict': True,
}
model = LongformerForMultiLabelSequenceClassification.from_pretrained(
    'allenai/longformer-base-4096', **pretrained_kwargs)
model

Some weights of the model checkpoint at allenai/longformer-base-4096 were not used when initializing LongformerForMultiLabelSequenceClassification: ['lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight']
- This IS expected if you are initializing LongformerForMultiLabelSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LongformerForMultiLabelSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of LongformerForMultiLabelSequenceClassification were not initialized from the model checkpoint at allenai/longformer-base-4096 and are newly initialized: ['classifier.d

LongformerForMultiLabelSequenceClassification(
  (longformer): LongformerModel(
    (embeddings): LongformerEmbeddings(
      (word_embeddings): Embedding(50265, 768, padding_idx=1)
      (position_embeddings): Embedding(4098, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): LongformerEncoder(
      (layer): ModuleList(
        (0): LongformerLayer(
          (attention): LongformerAttention(
            (self): LongformerSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (query_global): Linear(in_features=768, out_features=768, bias=True)
              (key_global): Linear(in_features=768, out_features=768, bias=True)
        

In [10]:
from transformers import AdamW
from transformers.optimization import get_linear_schedule_with_warmup
from sklearn.metrics import f1_score
# define optimizer
learning_rate = 3e-5
epochs = 3
# Must match the accumulation_steps passed to train(): the optimizer and the
# scheduler are stepped only once every `accumulation_steps` batches.
accumulation_steps = 8
optimizer_ft = AdamW(model.parameters(), lr=learning_rate)

# Linear warmup over the first 500 optimizer steps, then linear decay to 0 at
# the last optimizer step. The schedule counts optimizer steps, not batches,
# so the number of batches is divided by the accumulation factor.
num_optimizer_steps = (len(dataloaders_dict['train']) // accumulation_steps) * epochs
exp_lr_scheduler = get_linear_schedule_with_warmup(optimizer_ft,
                                                   num_warmup_steps=500,
                                                   num_training_steps=num_optimizer_steps)

In [11]:
import tqdm
wandb.init(project='huggingface', name = 'longformer_multilabel_run')


def _logits_to_preds(logits, threshold=0.5):
    """Convert multi-label logits into a 0/1 float numpy array, one column per label.

    A sigmoid maps every logit to an independent probability; labels whose
    probability exceeds `threshold` are predicted as present.
    """
    probs = torch.sigmoid(logits.detach().float().cpu())
    return (probs > threshold).numpy().astype(np.float64)


def _unpack_batch(batch, device):
    """Move the tensors of one ``(input_ids, attention_mask, labels, ids)`` batch to `device`.

    The dataset returns token tensors with a leading axis of 1 per item, so the
    collated tensors are ``(batch, 1, seq_len)`` and ``squeeze(1)`` removes that
    axis. The ids are not needed for training and are dropped.
    """
    inputs = batch[0].squeeze(1).to(device)
    att_mask = batch[1].squeeze(1).to(device)
    labels = batch[2].to(device, dtype=torch.float)
    return inputs, att_mask, labels


def _train_one_epoch(model, dataloader, optimizer, scheduler, accumulation_steps, device, threshold):
    """Run one training epoch and return ``(per_batch_losses, sampled_batch_f1s)``.

    Losses are recorded unscaled (before the division used for gradient
    accumulation). Every 50 batches the micro-F1 of the current batch is
    recorded and the running mean is printed as a rough progress signal.
    """
    model.train()
    model.zero_grad()
    step_losses, f1_scores = [], []
    for step, batch in enumerate(dataloader):
        inputs, att_mask, labels = _unpack_batch(batch, device)
        loss, logits = model(input_ids=inputs, attention_mask=att_mask, labels=labels)
        step_losses.append(loss.item())

        # Divide so the gradient summed over `accumulation_steps` batches is their mean.
        (loss / accumulation_steps).backward()
        if (step + 1) % accumulation_steps == 0:
            optimizer.step()
            scheduler.step()
            model.zero_grad()

        if (step + 1) % 50 == 0:
            preds = _logits_to_preds(logits, threshold)
            f1_scores.append(f1_score(labels.cpu().numpy(), preds, average='micro'))
            print(f' the current f1 score is {np.mean(f1_scores)}')
    return step_losses, f1_scores


def _evaluate(model, dataloader, device, threshold):
    """Evaluate `model` on `dataloader` and return ``(mean_loss, micro_f1)``.

    Predictions and labels from every batch are pooled, so micro-F1 is computed
    once over the whole set instead of being averaged over small batches.
    Returns ``(nan, nan)`` for an empty dataloader.
    """
    model.eval()
    losses, all_preds, all_labels = [], [], []
    with torch.no_grad():
        for step, batch in enumerate(dataloader):
            inputs, att_mask, labels = _unpack_batch(batch, device)
            loss, logits = model(input_ids=inputs, attention_mask=att_mask, labels=labels)
            losses.append(loss.item())
            all_preds.append(_logits_to_preds(logits, threshold))
            all_labels.append(labels.cpu().numpy())
            if (step + 1) % 10 == 0:
                print(f'current validation loss is {np.mean(losses)}')
    if not losses:
        return float('nan'), float('nan')
    eval_f1 = f1_score(np.concatenate(all_labels), np.concatenate(all_preds), average='micro')
    return float(np.mean(losses)), float(eval_f1)


# define the training
def train(model = None, train_dataloader = None, eval_dataloader = None, epochs = 3,
          optimizer = None, scheduler = None, accumulation_steps = 8,
          device = None, threshold = 0.5, save_path = '../results/longformer_multilabel.pt'):
    """Fine-tune `model` with gradient accumulation, evaluating after every epoch.

    Args:
        model: module whose forward returns ``(loss, logits)`` when given labels.
        train_dataloader: loader of ``(input_ids, attention_mask, labels, ids)`` batches.
        eval_dataloader: loader with the same batch layout, used for validation.
        epochs: number of passes over the training data.
        optimizer: optimizer stepped once every `accumulation_steps` batches.
        scheduler: LR scheduler stepped together with the optimizer.
        accumulation_steps: batches whose gradients are accumulated (each loss is
            divided by this value) before one optimizer step.
        device: where batches are moved; defaults to the device of the model's
            first parameter.
        threshold: probability above which a label counts as predicted.
        save_path: file the ``state_dict`` of the best model (by validation
            micro-F1) is written to; missing parent directories are created.

    Returns:
        The model after the final epoch (not necessarily the best checkpoint).
    """
    import os

    if device is None:
        device = next(model.parameters()).device
    wandb.watch(model, log="all", log_freq=25)

    best_eval_f1 = 0.0
    for epoch in range(epochs):
        step_losses, f1_scores = _train_one_epoch(model, train_dataloader, optimizer, scheduler,
                                                  accumulation_steps, device, threshold)
        train_loss = float(np.mean(step_losses)) if step_losses else float('nan')
        train_f1 = float(np.mean(f1_scores)) if f1_scores else float('nan')
        print(f'the mean training loss for epoch {epoch + 1} was {train_loss} and '
              f'the f1 for epoch was {train_f1}')

        # Evaluate after every epoch so the best checkpoint can actually be tracked.
        eval_loss, eval_f1 = _evaluate(model, eval_dataloader, device, threshold)
        print(f'validation loss is {eval_loss} and validation f1 score is {eval_f1}')
        wandb.log({'epoch': epoch + 1, 'train_loss': train_loss, 'train_f1': train_f1,
                   'val_loss': eval_loss, 'val_f1': eval_f1})

        if eval_f1 >= best_eval_f1:
            best_eval_f1 = eval_f1
            print(f'saving the model with f1 score of {best_eval_f1:,.2%}')
            # torch.save needs a file path, not a directory.
            os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)
            torch.save(model.state_dict(), save_path)
        else:
            print('model did not improve')

    print('Training Completed!')
    return model

[34m[1mwandb[0m: Currently logged in as: [33mjlealtru[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.15 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [12]:
# Move the model to the selected device, then fine-tune it; the returned model
# is the cell's displayed value.
model.to(device)
train(
    model,
    train_dataloader=dataloaders_dict['train'],
    eval_dataloader=dataloaders_dict['val'],
    epochs=3,
    optimizer=optimizer_ft,
    scheduler=exp_lr_scheduler,
    accumulation_steps=8,
)

 the current f1 score is 0.12500000000000003
 the current f1 score is 0.125
 the current f1 score is 0.1346153846153846
 the current f1 score is 0.10096153846153846


KeyboardInterrupt: 

Let's look at the data so we can decide on the maximum length of the text.

In [None]:
# Get the length of every comment in both splits. The lengths are counted in
# characters, not tokens, so they only approximate the tokenized length.
# (The split cell produces `train_dataset` / `test_dataset`; the older
# `insults_train` / `insults_test` names are never defined.)
length_text_train = [len(x) for x in train_dataset['comment_text']]
length_text_test = [len(x) for x in test_dataset['comment_text']]
length_text = length_text_train + length_text_test

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt

# Histogram (with KDE) of comment lengths in 200-character bins.
plt.figure(figsize=(10, 10))
text_length_axes = sns.histplot(data=length_text, binwidth=200, alpha=0.55, kde=True)
# Note: set_title returns the matplotlib Text object, which is what gets displayed.
hist_text_length = text_length_axes.set_title('Distribution of text length')
hist_text_length

Note that there are several observations longer than 1,000 characters.

In [None]:
# NOTE(review): reference only. This cell's content is a string literal, so
# none of the code inside it runs; it appears to record an earlier
# single-label setup. `LongformerForSequenceClassification` is not imported
# anywhere in this notebook, so re-enabling it would also need that import.
'''
model = LongformerForSequenceClassification.from_pretrained('allenai/longformer-base-4096',
                                                           gradient_checkpointing=False,
                                                           attention_window = 512,
                                                           cache_dir='/media/data_files/github/website_tutorials/data')
tokenizer = LongformerTokenizerFast.from_pretrained('allenai/longformer-base-4096', max_length = 2048)
'''

In [None]:
# NOTE(review): reference only. This whole cell is a string literal, so the
# `LongFormerMultilabelClass` draft inside it is never defined or executed.
# Observations if it were ever revived:
#   * `self.dropout` is created in __init__ but never applied in forward;
#   * no global_attention_mask is passed to the Longformer model;
#   * `self.classifier` (a Linear layer) is applied to `outputs[0]` directly,
#     not to a single token's state as LongformerClassificationHead does;
#   * forward returns a tuple whose length depends on whether labels are given.
'''from torch.nn import BCEWithLogitsLoss, Dropout, Linear
from transformers import LongformerTokenizerFast, LongformerModel, LongformerConfig
from transformers.models.longformer.modeling_longformer import LongformerPreTrainedModel,LongformerClassificationHead


# instantiate the multi-label classification class

class LongFormerMultilabelClass(LongformerPreTrainedModel):
    def __init__(self, config, pos_weight = None):
        super(LongFormerMultilabelClass, self).__init__(config)
        self.num_labels = config.num_labels
        self.LongformerModel = LongformerModel(config)
        self.dropout = Dropout(0.3)
        self.classifier = Linear(config.hidden_size, config.num_labels)
        self.init_weights()
        

    def forward(self, input_ids = None, attention_mask = None, token_type_ids = None, position_ids = None,
                head_mask = None, inputs_embeds=None, labels = None):
        
        outputs = self.LongformerModel(input_ids, attention_mask=attention_mask, 
                                       token_type_ids=token_type_ids, position_ids=position_ids,
                                       head_mask=head_mask,
                                       inputs_embeds=inputs_embeds)
        
        sequence_output = outputs[0]
        logits = self.classifier(sequence_output)

        outputs = (logits,) + outputs[2:]
        if labels is not None:
            loss_fct = BCEWithLogitsLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1, self.num_labels))

            outputs = (loss,) + outputs

        return outputs 
'''