# Multi-label classification with the Longformer

In a previous post I explored the functionality of the Longformer for text classification. In this post I will explore the performance of the Longformer in a setting of multilabel classification problem.

For this dataset we need to download it manually from Kaggle and load it like we usually do with the datasets library. 'jigsaw_toxicity_pred', data_dir='/path/to/extracted/data/'

In [1]:
import torch
import transformers
import pandas as pd
import numpy as np
from torch.nn import BCEWithLogitsLoss
from transformers import LongformerTokenizerFast, \
LongformerModel, LongformerConfig, Trainer, TrainingArguments,EvalPrediction, AutoTokenizer
from transformers.models.longformer.modeling_longformer import LongformerPreTrainedModel, LongformerClassificationHead
from torch.utils.data import Dataset, DataLoader
import wandb
import random

In [2]:
# Ensure deterministic behavior
torch.backends.cudnn.deterministic = True
random.seed(hash("setting random seeds") % 2**32 - 1)
np.random.seed(hash("improves reproducibility") % 2**32 - 1)
torch.manual_seed(hash("by removing stochasticity") % 2**32 - 1)
torch.cuda.manual_seed_all(hash("so runs are repeatable") % 2**32 - 1)

In [3]:
#wandb.login()

We are going to instantiate a raw LongFormer Model and add a classifier head on top. 

talk about pos_weight

In [4]:
# read the dataframe
insults = pd.read_csv('../data/jigsaw/train.csv')
#insults = insults.iloc[0:6000]
insults['labels'] = insults[insults.columns[2:]].values.tolist()
insults = insults[['id','comment_text', 'labels']].reset_index(drop=True)

In [5]:
'''
from sklearn.model_selection import train_test_split
insults_train, insults_test = train_test_split(insults,
                                               random_state = 55,
                                               test_size = 0.35)
insults_test.head()
insults_test.columns
'''
train_size = 0.9
train_dataset=insults.sample(frac=train_size,random_state=200)
test_dataset=insults.drop(train_dataset.index).reset_index(drop=True)
train_dataset = train_dataset.reset_index(drop=True)

In [6]:
#insults_test = pd.read_csv('../data/jigsaw/test.csv')
#insults_test_ids = pd.read_csv('../data/jigsaw/test_labels.csv')
#insults_test['labels']
train_dataset

Unnamed: 0,id,comment_text,labels
0,6725d5a6391e5c77,Goal scored for Portugal \n\nThis could be mil...,"[0, 0, 0, 0, 0, 0]"
1,28ea0d2c61db3137,My mistake someone was vandalizing the page so...,"[0, 0, 0, 0, 0, 0]"
2,4de3d6b966b58ec7,Test card F music \nCeefax music isn't the sam...,"[0, 0, 0, 0, 0, 0]"
3,af602c4c5f1b09bc,""":Meh, I guess I can live with either outcome,...","[0, 0, 0, 0, 0, 0]"
4,9e412a7965873237,"UV is my error, above. I'm told by Kimberly Ja...","[0, 0, 0, 0, 0, 0]"
...,...,...,...
143609,8ef0a57c815d2e28,Silencing references through IP blocks? You're...,"[1, 0, 0, 0, 1, 0]"
143610,67f6241012e2cd6f,REGARDING THE KAINTHESCION BULLSHIT SHAM RFAR\...,"[1, 0, 1, 0, 1, 0]"
143611,1e450b9450c1ac12,Red links are fine \n\nPlease read WP:REDLINK;...,"[0, 0, 0, 0, 0, 0]"
143612,a59468f4d867e9f9,"""\n\nHow is """"And You Will Know Us by the Trai...","[0, 0, 0, 0, 0, 0]"


In [7]:
# instantiate a class that will handle the data
class Data_Processing(object):
    def __init__(self, tokenizer, id_column, text_column, label_column):
        
        # define the text column from the dataframe
        self.text_column = text_column.tolist()
    
        # define the label column and transform it to list
        self.label_column = label_column
        
        # define the id column and transform it to list
        self.id_column = id_column.tolist()
        
    
# iter method to get each element at the time and tokenize it using bert        
    def __getitem__(self, index):
        comment_text = str(self.text_column[index])
        comment_text = " ".join(comment_text.split())
        
        inputs = tokenizer.encode_plus(comment_text,
                                       add_special_tokens = True,
                                       max_length= 3048,
                                       padding = 'max_length',
                                       return_attention_mask = True,
                                       truncation = True,
                                       return_tensors='pt')
        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']
        
        labels_ = torch.tensor(self.label_column[index], dtype=torch.float)
        id_ = self.id_column[index]
        return {'input_ids':input_ids[0], 'attention_mask':attention_mask[0], 
                'labels':labels_, 'id_':id_}
  
    def __len__(self):
        return len(self.text_column) 

In [8]:
batch_size = 4
# create a class to process the traininga and test data
tokenizer = AutoTokenizer.from_pretrained('allenai/longformer-base-4096', 
                                                    padding = 'max_length',
                                                    truncation=True, 
                                                    max_length = 3048)
training_data = Data_Processing(tokenizer, 
                                train_dataset['id'], 
                                train_dataset['comment_text'], 
                                train_dataset['labels'])

test_data =  Data_Processing(tokenizer, 
                             test_dataset['id'], 
                             test_dataset['comment_text'], 
                             test_dataset['labels'])

# use the dataloaders class to load the data
dataloaders_dict = {'train': DataLoader(training_data, batch_size=batch_size, shuffle=True, num_workers=2),
                    'val': DataLoader(test_data, batch_size=batch_size, shuffle=True, num_workers=2)
                   }

dataset_sizes = {'train':len(training_data),
                 'val':len(test_data)
                }

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [9]:
# check we are getting the right output
a = next(iter(dataloaders_dict['val']))
a['id_']
#len(dataloaders_dict['train'])

['f1d0cb1d6f088ae6',
 'e8409c8ccdb62cf7',
 'cca0d9164a93747d',
 '0977b33c30bf9bf0']

In [10]:
# instantiate a Longformer for multilabel classification class

class LongformerForMultiLabelSequenceClassification(LongformerPreTrainedModel):
    """
    We instantiate a class of LongFormer adapted for a multilabel classification task. 
    This instance takes the pooled output of the LongFormer based model and passes it through a
    classification head. We replace the traditional Cross Entropy loss with a BCE loss that generate probabilities
    for all the labels that we feed into the model.
    """

    def __init__(self, config, pos_weight=None):
        super(LongformerForMultiLabelSequenceClassification, self).__init__(config)
        self.num_labels = config.num_labels
        self.pos_weight = pos_weight
        self.longformer = LongformerModel(config)
        self.classifier = LongformerClassificationHead(config)
        self.init_weights()
        
    def forward(self, input_ids=None, attention_mask=None, global_attention_mask=None, 
                token_type_ids=None, position_ids=None, inputs_embeds=None, 
                labels=None):
        
        # create global attention on sequence, and a global attention token on the `s` token
        # the equivalent of the CLS token on BERT models
        if global_attention_mask is None:
            global_attention_mask = torch.zeros_like(input_ids)
            global_attention_mask[:, 0] = 1
        
        # pass arguments to longformer model
        outputs = self.longformer(
            input_ids = input_ids,
            attention_mask = attention_mask,
            global_attention_mask = global_attention_mask,
            token_type_ids = token_type_ids,
            position_ids = position_ids)
        
        # if specified the model can return a dict where each key corresponds to the output of a
        # LongformerPooler output class. In this case we take the last hidden state of the sequence
        # which will have the shape (batch_size, sequence_length, hidden_size). 
        sequence_output = outputs['last_hidden_state']
        
        
        # pass the hidden states through the classifier to obtain thee logits
        logits = self.classifier(sequence_output)
        outputs = (logits,) + outputs[2:]
        if labels is not None:
            loss_fct = BCEWithLogitsLoss(pos_weight=self.pos_weight)
            labels = labels.float()
            loss = loss_fct(logits.view(-1, self.num_labels), 
                            labels.view(-1, self.num_labels))
            #outputs = (loss,) + outputs
            outputs = (loss,) + outputs
        
        
        return outputs


In [11]:
model = LongformerForMultiLabelSequenceClassification.from_pretrained('allenai/longformer-base-4096',
                                                  #'/media/data_files/github/website_tutorials/results/longformer_2048_multilabel_jigsaw',
                                                  gradient_checkpointing=False,
                                                  attention_window = 512,
                                                  num_labels = 6,
                                                  cache_dir='/media/data_files/github/website_tutorials/data',
                                                                     return_dict=True)
model

Some weights of the model checkpoint at allenai/longformer-base-4096 were not used when initializing LongformerForMultiLabelSequenceClassification: ['lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight']
- This IS expected if you are initializing LongformerForMultiLabelSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LongformerForMultiLabelSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of LongformerForMultiLabelSequenceClassification were not initialized from the model checkpoint at allenai/longformer-base-4096 and are newly initialized: ['classifier.d

LongformerForMultiLabelSequenceClassification(
  (longformer): LongformerModel(
    (embeddings): LongformerEmbeddings(
      (word_embeddings): Embedding(50265, 768, padding_idx=1)
      (position_embeddings): Embedding(4098, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): LongformerEncoder(
      (layer): ModuleList(
        (0): LongformerLayer(
          (attention): LongformerAttention(
            (self): LongformerSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (query_global): Linear(in_features=768, out_features=768, bias=True)
              (key_global): Linear(in_features=768, out_features=768, bias=True)
        

In [12]:
from sklearn.metrics import f1_score, roc_auc_score, accuracy_score
#acc = accuracy_score(labels, preds)
    #acc = accuracy_score(labels, preds)
    
def multi_label_metric(
    predictions, 
    references, 
    ):
    sigmoid = torch.nn.Sigmoid()
    probs = sigmoid(torch.Tensor(predictions))
    y_pred = np.zeros(probs.shape)
    y_true = references
    y_pred[np.where(probs >= 0.5)] = 1
    f1_micro_average = f1_score(y_true=y_true, y_pred=y_pred, average='micro')
    roc_auc = roc_auc_score(y_true, y_pred, average = 'micro')
    accuracy = accuracy_score(y_true, y_pred)
    metrics = {'f1':f1_micro_average,
               'roc_auc': roc_auc,
               'accuracy': accuracy}
    return metrics

def compute_metrics(p: EvalPrediction):
    preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
    result = multi_label_metric(
        predictions=preds, 
        references=p.label_ids
    )
    return result

In [13]:
# define the training arguments
training_args = TrainingArguments(
    output_dir = '/media/data_files/github/website_tutorials/results',
    num_train_epochs = 4,
    per_device_train_batch_size = 2,
    gradient_accumulation_steps = 64,    
    per_device_eval_batch_size= 16,
    evaluation_strategy = "epoch",
    disable_tqdm = False, 
    load_best_model_at_end=True,
    warmup_steps = 1500,
    learning_rate = 2e-5,
    weight_decay=0.01,
    logging_steps = 8,
    fp16 = False,
    logging_dir='/media/data_files/github/website_tutorials/logs',
    dataloader_num_workers = 0,
    run_name = 'longformer_multilabel_paper_trainer_3048_2e5'
)

In [14]:
# instantiate the trainer class and check for available devices
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=training_data,
    eval_dataset=test_data,
    compute_metrics = compute_metrics,
    #data_collator = Data_Processing(),

)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [15]:
trainer.train()

[34m[1mwandb[0m: Currently logged in as: [33mjlealtru[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.20 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


Epoch,Training Loss,Validation Loss


KeyboardInterrupt: 

In [None]:
trainer.evaluate()

In [None]:
insults_test = pd.read_csv('../data/jigsaw/test.csv')
#insults_test = insults_test.iloc[0:101]
insults_test

In [None]:
# instantiate a class that will handle the data
class Data_Processing_test():
    def __init__(self, tokenizer, id_column, text_column):
        
        # define the text column from the dataframe
        self.text_column = text_column.tolist()
                    
        # define the id column and transform it to list
        self.id_column = id_column.tolist()
            
# iter method to get each element at the time and tokenize it using bert        
    def __getitem__(self, index):
        comment_text = str(self.text_column[index])
        comment_text = " ".join(comment_text.split())
        
        inputs = tokenizer.encode_plus(comment_text,
                                       add_special_tokens = True,
                                       max_length= 2048,
                                       padding = 'max_length',
                                       return_attention_mask = True,
                                       truncation = True,
                                       return_tensors='pt')
        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']
        id_ = self.id_column[index]
        return {'input_ids':input_ids[0], 'attention_mask':attention_mask[0], 
                'id_':id_}
  
    def __len__(self):
        return len(self.text_column) 

In [None]:
batch_size = 16
# create a class to process the traininga and test data

test_data_pred =  Data_Processing_test(tokenizer,
                                       insults_test['id'], 
                                       insults_test['comment_text'])

# use the dataloaders class to load the data
dataloaders_dict = {'test': DataLoader(test_data_pred,
                                                 batch_size=batch_size, shuffle=True, num_workers=2)}

In [None]:
def prediction():
    prediction_data_frame_list = []
    with torch.no_grad():
        trainer.model.eval()
        for i, batch in enumerate(dataloaders_dict['test']):
            inputs = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            # feed the sequences to the model, specifying the attention mask
            outputs = model(inputs, attention_mask=attention_mask)
            # feed the logits returned by the model to the softmax to classify the function
            sigmoid = torch.nn.Sigmoid()
            probs = sigmoid(torch.Tensor(outputs[0].detach().cpu().data.numpy()))
            #probs.
            probs = np.array(probs)
            #print(np.array([[i] for i in probs]))
            y_pred = np.zeros(probs.shape)
            y_pred = probs
            temp_data = pd.DataFrame(zip(batch['id_'], probs), columns = ['id', 'target'
                                                                         ])
            #print(temp_data)
            prediction_data_frame_list.append(temp_data)                

    prediction_df = pd.concat(prediction_data_frame_list)
    prediction_df[['toxic','severe_toxic',
                   'obscene','threat','insult','identity_hate']] = pd.DataFrame(prediction_df.target.tolist(),
                                                                                index= prediction_df.index)
    prediction_df = prediction_df.drop(columns = 'target')
    return prediction_df

predictions = prediction()

In [None]:
predictions.to_csv('../data/jigsaw/submission_longf_3048_2e5.csv', index=False)

In [None]:
trainer.model.save_pretrained('/media/data_files/github/website_tutorials/results/longformer_base_multilabel_3048_2e5')
tokenizer.save_pretrained('/media/data_files/github/website_tutorials/results/longformer_base_multilabel_3048_2e5')

Mention there several observations longer than 1000