# Fine-tuning BERT (and friends) for multi-label text classification

In this notebook, we are going to fine-tune BERT to predict one or more labels for a given piece of text. Note that this notebook illustrates how to fine-tune a bert-base-uncased model, but you can also fine-tune a RoBERTa, DeBERTa, DistilBERT, CANINE, ... checkpoint in the same way. 

All of those work in the same way: they add a linear layer on top of the base model, which is used to produce a tensor of shape (batch_size, num_labels), indicating the unnormalized scores for a number of labels for every example in the batch.
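Because one example can carry several labels at once, these unnormalized scores are usually turned into probabilities with a per-label sigmoid rather than a softmax across labels. The following is a minimal sketch in plain Python; the logits are made-up values and the helper names are illustrative, not part of the Transformers library:

```python
import math

def sigmoid(x):
    """Map one unnormalized score to an independent probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))

def softmax(xs):
    """Map scores to a distribution that sums to 1 (single-label setting)."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical scores for one example and four labels.
logits = [2.0, 1.5, -1.0, -3.0]

multi_label_probs = [sigmoid(x) for x in logits]
single_label_probs = softmax(logits)

# With a sigmoid, several labels can pass a 0.5 threshold at the same time.
predicted = [i for i, p in enumerate(multi_label_probs) if p >= 0.5]

print([round(p, 3) for p in multi_label_probs])
print(round(sum(single_label_probs), 3))
print(predicted)
```

The sigmoid probabilities do not need to sum to 1, which is what allows zero, one or many labels to be predicted for the same text.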



## Set-up environment

First, we install the libraries which we'll use: HuggingFace Transformers and Datasets.

In [1]:
# !pip install -q transformers datasets

## Load dataset

Next, let's download a multi-label text classification dataset from the [hub](https://huggingface.co/).

At the time of writing, I picked a random one as follows:   

* first, go to the "datasets" tab on huggingface.co
* next, select the "multi-label-classification" tag on the left as well as the "1k<10k" tag (to find a relatively small dataset).

Note that you can also easily load your local data (e.g. csv files, txt files, Parquet files, JSON, ...) as explained [here](https://huggingface.co/docs/datasets/loading.html#local-and-remote-files).



In [16]:
import numpy as np
from sklearn.preprocessing import MultiLabelBinarizer

In [17]:
from datasets import load_dataset

dataset = load_dataset("nlpaueb/multi_eurlex", "en")

Reusing dataset multi_eurlex (/home/davo/.cache/huggingface/datasets/nlpaueb___multi_eurlex/en/1.0.0/1addee7110a20c2b01cc3de89456786482e4eea1d2ead0bea3d5383b16cc9fce)


  0%|          | 0/3 [00:00<?, ?it/s]

As we can see, the dataset contains 3 splits: one for training, one for validation and one for testing.

In [18]:

dataset = dataset.rename_column("labels", "old_labels")

In [19]:

dataset

DatasetDict({
    train: Dataset({
        features: ['celex_id', 'text', 'old_labels'],
        num_rows: 11000
    })
    test: Dataset({
        features: ['celex_id', 'text', 'old_labels'],
        num_rows: 5000
    })
    validation: Dataset({
        features: ['celex_id', 'text', 'old_labels'],
        num_rows: 1000
    })
})

Let's check the first example of the training split:

In [20]:
example = dataset['train'][0]
example

{'celex_id': '32003R1012',
 'text': 'Commission Regulation (EC) No 1012/2003\nof 12 June 2003\namending for the 19th time Council Regulation (EC) No 881/2002 imposing certain specific restrictive measures directed against certain persons and entities associated with Usama bin Laden, the Al-Qaida network and the Taliban, and repealing Council Regulation (EC) No 467/2001\nTHE COMMISSION OF THE EUROPEAN COMMUNITIES,\nHaving regard to the Treaty establishing the European Community,\nHaving regard to Council Regulation (EC) No 881/2002 of 27 May 2002 imposing certain specific restrictive measures directed against certain persons and entities associated with Usama bin Laden, the Al-Qaida network and the Taliban, and repealing Council Regulation (EC) No 467/2001 prohibiting the export of certain goods and services to Afghanistan, strengthening the flight ban and extending the freeze of funds and other financial resources in respect of the Taliban of Afghanistan(1), as last amended by Commissi

The dataset consists of EU legal documents, each labeled with one or more topic categories.

Let's create a list that contains the labels, as well as 2 dictionaries that map labels to integers and back.

In [21]:
labels = ['social questions',
 'industry',
 'finance',
 'trade',
 'business and competition',
 'international relations',
 'agriculture, forestry and fisheries',
 'production, technology and research',
 'transport',
 'employment and working conditions',
 'politics',
 'law',
 'education and communications',
 'international organisations',
 'energy',
 'EUROPEAN UNION',
 'science',
 'agri-foodstuffs',
 'geography',
 'economics',
 'environment']
id2label = {0: 'social questions',
 1: 'industry',
 2: 'finance',
 3: 'trade',
 4: 'business and competition',
 5: 'international relations',
 6: 'agriculture, forestry and fisheries',
 7: 'production, technology and research',
 8: 'transport',
 9: 'employment and working conditions',
 10: 'politics',
 11: 'law',
 12: 'education and communications',
 13: 'international organisations',
 14: 'energy',
 15: 'EUROPEAN UNION',
 16: 'science',
 17: 'agri-foodstuffs',
 18: 'geography',
 19: 'economics',
 20: 'environment'}
label2id = {'social questions': 0,
 'industry': 1,
 'finance': 2,
 'trade': 3,
 'business and competition': 4,
 'international relations': 5,
 'agriculture, forestry and fisheries': 6,
 'production, technology and research': 7,
 'transport': 8,
 'employment and working conditions': 9,
 'politics': 10,
 'law': 11,
 'education and communications': 12,
 'international organisations': 13,
 'energy': 14,
 'EUROPEAN UNION': 15,
 'science': 16,
 'agri-foodstuffs': 17,
 'geography': 18,
 'economics': 19,
 'environment': 20}
labels

['social questions',
 'industry',
 'finance',
 'trade',
 'business and competition',
 'international relations',
 'agriculture, forestry and fisheries',
 'production, technology and research',
 'transport',
 'employment and working conditions',
 'politics',
 'law',
 'education and communications',
 'international organisations',
 'energy',
 'EUROPEAN UNION',
 'science',
 'agri-foodstuffs',
 'geography',
 'economics',
 'environment']
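The two dictionaries above are typed out by hand. They could also be derived from the list, so that all three stay in sync. A small sketch, using a shortened list and `demo_` names (both are illustrative, chosen so the snippet runs on its own without overwriting the notebook's variables):

```python
# Shortened list for illustration; in the notebook this would be the full 21-item `labels`.
demo_labels = ['social questions', 'industry', 'finance', 'trade']

# The position in the list is the integer id used in the multi-hot label vector.
demo_id2label = {idx: label for idx, label in enumerate(demo_labels)}
demo_label2id = {label: idx for idx, label in enumerate(demo_labels)}

print(demo_id2label[2])
print(demo_label2id['trade'])
```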

## Preprocess data

As models like BERT don't expect text as direct input, but rather `input_ids`, etc., we tokenize the text using the tokenizer. Here I'm using the `AutoTokenizer` API, which will automatically load the appropriate tokenizer based on the checkpoint on the hub.

What's a bit tricky is that we also need to provide labels to the model. For multi-label text classification, this is a matrix of shape (batch_size, num_labels). Also important: this should be a tensor of floats rather than integers, otherwise PyTorch's `BCEWithLogitsLoss` (which the model will use) will complain, as explained [here](https://discuss.pytorch.org/t/multi-label-binary-classification-result-type-float-cant-be-cast-to-the-desired-output-type-long/117915/3).

In [22]:
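def numbers_to_classes(l=(0, 10, 20)):
  # build a multi-hot float vector of length 21 (one slot per label)
  zero_cl = [0.0] * 21
  for i in l:
    zero_cl[i] = 1.0

  # use the builtin float: the np.float alias was removed in NumPy 1.24
  return np.array(zero_cl, dtype=float)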

In [23]:
numbers_to_classes(l=[0,10,20])

array([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 1.])

In [24]:
from transformers import AutoTokenizer
import numpy as np
import pandas as pd

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

def preprocess_data(examples):
  # take a batch of texts
  text = examples["text"]
  # encode them
  encoding = tokenizer(text, padding="max_length", truncation=True, max_length=300)
  # add labels

  labels_matrix = []
  for r in examples["old_labels"]:
    labels_matrix.append(np.array(numbers_to_classes(r), dtype=float))
  # print(labels_matrix)
  encoding["labels"] = np.array(labels_matrix, dtype=float)
  
  return encoding

In [25]:
encoded_dataset = dataset.map(preprocess_data, batched=True)



  0%|          | 0/11 [00:00<?, ?ba/s]

  0%|          | 0/5 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

In [26]:
example = encoded_dataset['train'][0]
print(example.keys())

dict_keys(['celex_id', 'text', 'old_labels', 'input_ids', 'token_type_ids', 'attention_mask', 'labels'])


In [27]:
tokenizer.decode(example['input_ids'])

'[CLS] commission regulation ( ec ) no 1012 / 2003 of 12 june 2003 amending for the 19th time council regulation ( ec ) no 881 / 2002 imposing certain specific restrictive measures directed against certain persons and entities associated with usama bin laden, the al - qaida network and the taliban, and repealing council regulation ( ec ) no 467 / 2001 the commission of the european communities, having regard to the treaty establishing the european community, having regard to council regulation ( ec ) no 881 / 2002 of 27 may 2002 imposing certain specific restrictive measures directed against certain persons and entities associated with usama bin laden, the al - qaida network and the taliban, and repealing council regulation ( ec ) no 467 / 2001 prohibiting the export of certain goods and services to afghanistan, strengthening the flight ban and extending the freeze of funds and other financial resources in respect of the taliban of afghanistan ( 1 ), as last amended by commission regul

In [28]:
[id2label[idx] for idx, label in enumerate(example['labels']) if label == 1.0]

['finance',
 'trade',
 'international relations',
 'transport',
 'politics',
 'EUROPEAN UNION',
 'geography']

Finally, we set the format of our data to PyTorch tensors. This will turn the training, validation and test sets into standard PyTorch [datasets](https://pytorch.org/docs/stable/data.html). 

In [29]:
encoded_dataset.set_format("torch")

## Define model

Here we define a model that includes a pre-trained base (i.e. the weights from bert-base-uncased are loaded), with a randomly initialized classification head (linear layer) on top. One should fine-tune this head, together with the pre-trained base, on a labeled dataset.

This is also printed by the warning.

We set the `problem_type` to be "multi_label_classification", as this will make sure the appropriate loss function is used (namely [`BCEWithLogitsLoss`](https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html)). We also make sure the output layer has `len(labels)` output neurons, and we set the id2label and label2id mappings.

In [30]:
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", 
                                                           problem_type="multi_label_classification", 
                                                           num_labels=len(labels),
                                                           id2label=id2label,
                                                           label2id=label2id)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

## Train the model!

We are going to train the model using HuggingFace's Trainer API. This requires us to define 2 things: 

* `TrainingArguments`, which specify training hyperparameters. All options can be found in the [docs](https://huggingface.co/transformers/main_classes/trainer.html#trainingarguments). Below, for example, we specify that we want to evaluate and save the model after every epoch, and we set the learning rate, the batch size to use for training/evaluation, how many epochs to train for, and so on.
* a `Trainer` object (docs can be found [here](https://huggingface.co/transformers/main_classes/trainer.html#id1)).

In [31]:
batch_size = 16
metric_name = "f1"

In [32]:
from transformers import TrainingArguments, Trainer

args = TrainingArguments(
    f"bert-finetuned-sem_eval-english",
    evaluation_strategy = "epoch",
    save_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=50,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
    #push_to_hub=True,
)

We are also going to compute metrics while training. For this, we need to define a `compute_metrics` function, that returns a dictionary with the desired metric values.

In [33]:
from sklearn.metrics import f1_score, roc_auc_score, accuracy_score
from transformers import EvalPrediction
import torch
    
# source: https://jesusleal.io/2021/04/21/Longformer-multilabel-classification/
def multi_label_metrics(predictions, labels, threshold=0.5):
    # first, apply sigmoid on predictions which are of shape (batch_size, num_labels)
    sigmoid = torch.nn.Sigmoid()
    probs = sigmoid(torch.Tensor(predictions))
    # next, use threshold to turn them into integer predictions
    y_pred = np.zeros(probs.shape)
    y_pred[np.where(probs >= threshold)] = 1
    # finally, compute metrics
    y_true = labels
    f1_micro_average = f1_score(y_true=y_true, y_pred=y_pred, average='micro')
    roc_auc = roc_auc_score(y_true, y_pred, average = 'micro')
    accuracy = accuracy_score(y_true, y_pred)
    # return as dictionary
    metrics = {'f1': f1_micro_average,
               'roc_auc': roc_auc,
               'accuracy': accuracy}
    return metrics

def compute_metrics(p: EvalPrediction):
    preds = p.predictions[0] if isinstance(p.predictions, 
            tuple) else p.predictions
    result = multi_label_metrics(
        predictions=preds, 
        labels=p.label_ids)
    return result

Let's verify a batch as well as a forward pass:

In [34]:
encoded_dataset['train']['labels'][0]

tensor([0., 0., 1., 1., 0., 1., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0.,
        1., 0., 0.])

In [35]:

encoded_dataset['train']['labels']

tensor([[0., 0., 1.,  ..., 1., 0., 0.],
        [0., 0., 0.,  ..., 1., 0., 0.],
        [0., 0., 0.,  ..., 0., 1., 0.],
        ...,
        [0., 0., 0.,  ..., 1., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [36]:
encoded_dataset['train']['input_ids'][0]

tensor([  101,  3222,  7816,  1006, 14925,  1007,  2053,  7886,  2475,  1013,
         2494,  1997,  2260,  2238,  2494, 27950,  2075,  2005,  1996,  3708,
         2051,  2473,  7816,  1006, 14925,  1007,  2053,  6070,  2487,  1013,
         2526, 16625,  3056,  3563, 25986,  5761,  2856,  2114,  3056,  5381,
         1998, 11422,  3378,  2007,  3915,  2863,  8026, 14887,  1010,  1996,
         2632,  1011,  1053, 14326,  2050,  2897,  1998,  1996, 16597,  1010,
         1998, 21825,  2075,  2473,  7816,  1006, 14925,  1007,  2053,  4805,
         2581,  1013,  2541,  1996,  3222,  1997,  1996,  2647,  4279,  1010,
         2383,  7634,  2000,  1996,  5036,  7411,  1996,  2647,  2451,  1010,
         2383,  7634,  2000,  2473,  7816,  1006, 14925,  1007,  2053,  6070,
         2487,  1013,  2526,  1997,  2676,  2089,  2526, 16625,  3056,  3563,
        25986,  5761,  2856,  2114,  3056,  5381,  1998, 11422,  3378,  2007,
         3915,  2863,  8026, 14887,  1010,  1996,  2632,  1011, 

In [37]:
#forward pass
outputs = model(input_ids=encoded_dataset['train']['input_ids'][0].unsqueeze(0), labels=encoded_dataset['train']['labels'][0].unsqueeze(0))
outputs

SequenceClassifierOutput(loss=tensor(0.7208, grad_fn=<BinaryCrossEntropyWithLogitsBackward>), logits=tensor([[ 0.6385,  0.0477,  0.0472,  1.1657, -0.4790,  0.0196, -0.0694,  0.7476,
          0.0020,  0.1310,  0.8710,  0.0487, -0.0142,  0.1482,  0.9864, -0.4025,
          0.2935, -0.0678,  0.3514, -0.4830,  0.0882]],
       grad_fn=<AddmmBackward>), hidden_states=None, attentions=None)
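The loss in this output can be reproduced by hand, which is a useful check that logits and labels line up. `BCEWithLogitsLoss` with its default mean reduction averages the binary cross-entropy over all 21 labels. A sketch in plain Python, with the logits and labels copied from the outputs above (so they are rounded to four decimals); the function name is illustrative:

```python
import math

# Logits and labels for the first training example, copied from the outputs above.
logits = [0.6385, 0.0477, 0.0472, 1.1657, -0.4790, 0.0196, -0.0694, 0.7476,
          0.0020, 0.1310, 0.8710, 0.0487, -0.0142, 0.1482, 0.9864, -0.4025,
          0.2935, -0.0678, 0.3514, -0.4830, 0.0882]
targets = [0., 0., 1., 1., 0., 1., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0.,
           1., 0., 0.]

def bce_with_logits(x, y):
    """Binary cross-entropy on a raw score x for a target y in {0, 1}.

    Uses the numerically stable form max(x, 0) - x*y + log(1 + exp(-|x|)).
    """
    return max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))

# Default reduction is the mean over all elements (here 1 example x 21 labels).
loss = sum(bce_with_logits(x, y) for x, y in zip(logits, targets)) / len(logits)
print(round(loss, 4))
```

The result should agree with the 0.7208 reported by the model up to rounding. A value close to ln 2 ≈ 0.693 is what one would expect before training, since the randomly initialized head produces logits near zero.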

Let's start training!

In [38]:
trainer = Trainer(
    model,
    args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["validation"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)



In [39]:
trainer.train()

The following columns in the training set  don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text, celex_id, old_labels. If text, celex_id, old_labels are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 11000
  Num Epochs = 50
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 34400
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"
wandb: Currently logged in as: dyada. Use `wandb login --relogin` to force relogin


| Epoch | Training Loss | Validation Loss | F1 | Roc Auc | Accuracy |
|---|---|---|---|---|---|
| 1 | 0.2974 | 0.242892 | 0.699967 | 0.782958 | 0.16 |
| 2 | 0.1956 | 0.206527 | 0.768737 | 0.831223 | 0.226 |
| 3 | 0.144 | 0.194616 | 0.780488 | 0.839774 | 0.238 |
| 4 | 0.1264 | 0.192727 | 0.77869 | 0.842318 | 0.238 |
| 5 | 0.1149 | 0.1919 | 0.786677 | 0.848659 | 0.238 |
| 6 | 0.0901 | 0.196106 | 0.787879 | 0.850874 | 0.249 |
| 7 | 0.0813 | 0.197205 | 0.795419 | 0.861934 | 0.242 |
| 8 | 0.0661 | 0.202194 | 0.796569 | 0.862962 | 0.244 |
| 9 | 0.0567 | 0.21339 | 0.778095 | 0.852925 | 0.215 |
| 10 | 0.0512 | 0.208805 | 0.791625 | 0.857577 | 0.217 |


The following columns in the evaluation set  don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text, celex_id, old_labels. If text, celex_id, old_labels are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 1000
  Batch size = 16
Saving model checkpoint to bert-finetuned-sem_eval-english/checkpoint-688
Configuration saved in bert-finetuned-sem_eval-english/checkpoint-688/config.json
Model weights saved in bert-finetuned-sem_eval-english/checkpoint-688/pytorch_model.bin
tokenizer config file saved in bert-finetuned-sem_eval-english/checkpoint-688/tokenizer_config.json
Special tokens file saved in bert-finetuned-sem_eval-english/checkpoint-688/special_tokens_map.json



Training completed. Do not forget to share your model on huggingface.co/models =)


Loading best model from bert-finetuned-sem_eval-english/checkpoint-22704 (score: 0.8050717570015328).


TrainOutput(global_step=34400, training_loss=0.030621453419674274, metrics={'train_runtime': 13689.2812, 'train_samples_per_second': 40.177, 'train_steps_per_second': 2.513, 'total_flos': 8.480611359e+16, 'train_loss': 0.030621453419674274, 'epoch': 50.0})
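The step counts in the log follow from the dataset size and the batch size. A quick check, assuming a single device and no gradient accumulation (as the log indicates); the variable names are illustrative:

```python
import math

num_train_examples = 11000   # "Num examples" in the training log
batch_size = 16              # per-device batch size, one device
num_epochs = 50

# The last, smaller batch still counts as a step, hence the ceiling.
steps_per_epoch = math.ceil(num_train_examples / batch_size)
total_steps = steps_per_epoch * num_epochs

# Checkpoints are saved once per epoch, so the step in a checkpoint name maps back to an epoch.
best_checkpoint_step = 22704
best_epoch = best_checkpoint_step // steps_per_epoch

print(steps_per_epoch, total_steps, best_epoch)
```

This matches the first checkpoint name (`checkpoint-688`) and the 34400 total optimization steps, and it places the best checkpoint by F1 at epoch 33 of 50.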

## Evaluate

After training, we evaluate our model on the validation set.

In [40]:
trainer.evaluate()

The following columns in the evaluation set  don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text, celex_id, old_labels. If text, celex_id, old_labels are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 1000
  Batch size = 16


{'eval_loss': 0.31552064418792725,
 'eval_f1': 0.8050717570015328,
 'eval_roc_auc': 0.8746604597581796,
 'eval_accuracy': 0.276,
 'eval_runtime': 7.1405,
 'eval_samples_per_second': 140.045,
 'eval_steps_per_second': 8.823,
 'epoch': 50.0}

## Inference

Let's test the model on a document from the test split:

In [52]:
dataset['test']['text'][0]

'COUNCIL REGULATION (EU) No 1390/2013\nof 16 December 2013\non the allocation of fishing opportunities under the Protocol agreed between the European Union and the Union of the Comoros setting out the fishing opportunities and financial contribution provided for in the Fisheries Partnership Agreement currently in force between the two parties\nTHE COUNCIL OF THE EUROPEAN UNION,\nHaving regard to the Treaty on the Functioning of the European Union, and in particular Article 43(3) thereof,\nHaving regard to the proposal from the European Commission,\nWhereas:\n(1)\nOn 5 October 2006, the Council approved the conclusion of the Partnership Agreement in the fisheries sector between the European Community and the Union of the Comoros (the ‘Partnership Agreement’) by adopting Regulation (EC) No 1563/2006 (1).\n(2)\nThe European Union negotiated with the Union of the Comoros a new Protocol to the Partnership Agreement granting vessels of the European Union fishing opportunities in Comoros wate

### Below are a few examples with the labels predicted by the BERT model and the true labels

In [77]:
def predict_labels(text, thres_prob=0.5):
    # tokenize with the same truncation settings as used for training,
    # so that long documents cannot exceed the model's maximum input length
    encoding = tokenizer(text, truncation=True, max_length=300, return_tensors="pt")
    encoding = {k: v.to(trainer.model.device) for k,v in encoding.items()}

    # no gradients are needed at inference time
    with torch.no_grad():
        outputs = trainer.model(**encoding)
    logits = outputs.logits
    # apply sigmoid + threshold
    sigmoid = torch.nn.Sigmoid()
    probs = sigmoid(logits.squeeze().cpu())
    predictions = np.zeros(probs.shape)
    predictions[np.where(probs >= thres_prob)] = 1
    # turn predicted id's into actual label names
    predicted_labels = [id2label[idx] for idx, label in enumerate(predictions) if label == 1.0]

    return predicted_labels

In [78]:
predicted_labels = predict_labels(text=dataset['test']['text'][0][:512])

In [79]:
print('predicted_labels>>>', predicted_labels)

predicted_labels>>> ['international relations', 'agriculture, forestry and fisheries', 'geography']


In [80]:
print('true labels>>>', [id2label[idx] for idx, label in enumerate(numbers_to_classes(dataset['test']['old_labels'][0])) if label == 1.0])

true labels>>> ['international relations', 'agriculture, forestry and fisheries', 'EUROPEAN UNION', 'geography']


### In the case above, the BERT model picked up three of the four true categories but missed the 'EUROPEAN UNION' category

In [81]:
predicted_labels = predict_labels(text=dataset['test']['text'][100][:512])

In [82]:
print('predicted_labels>>>', predicted_labels)

predicted_labels>>> ['trade', 'agriculture, forestry and fisheries', 'agri-foodstuffs', 'geography']


In [83]:
print('true labels>>>', [id2label[idx] for idx, label in enumerate(numbers_to_classes(dataset['test']['old_labels'][100])) if label == 1.0])

true labels>>> ['trade', 'agriculture, forestry and fisheries', 'production, technology and research', 'agri-foodstuffs', 'geography']


### In the case above, the BERT model picked up four of the five true categories but missed the 'production, technology and research' category

In [84]:
predicted_labels = predict_labels(text=dataset['test']['text'][1000][:512])

In [85]:
print('predicted_labels>>>', predicted_labels)

predicted_labels>>> ['international relations', 'EUROPEAN UNION', 'geography']


In [86]:
print('true labels>>>', [id2label[idx] for idx, label in enumerate(numbers_to_classes(dataset['test']['old_labels'][1000])) if label == 1.0])

true labels>>> ['trade', 'international relations', 'EUROPEAN UNION', 'geography']


### In the case above, the BERT model again picked up three of the four true categories but missed the 'trade' category

In [102]:
predicted_labels = predict_labels(text=dataset['test']['text'][1000][:512], thres_prob=0.01)

In [103]:
print('predicted_labels>>>', predicted_labels)

predicted_labels>>> ['trade', 'international relations', 'EUROPEAN UNION', 'geography', 'economics']


In [104]:
print('true labels>>>', [id2label[idx] for idx, label in enumerate(numbers_to_classes(dataset['test']['old_labels'][1000])) if label == 1.0])

true labels>>> ['trade', 'international relations', 'EUROPEAN UNION', 'geography']


### If we reduce the threshold for accepting a label to 1%, we capture 'trade' but wrongly add 'economics'
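The effect of the threshold can be illustrated with made-up probabilities. The model's actual probabilities for this document are not shown in the notebook, so the numbers below are purely hypothetical, and `predict` is an illustrative helper:

```python
# Hypothetical per-label probabilities for one document; not taken from the model.
probs = {
    'trade': 0.20,
    'international relations': 0.97,
    'EUROPEAN UNION': 0.88,
    'geography': 0.93,
    'economics': 0.03,
    'energy': 0.001,
}
true_labels = {'trade', 'international relations', 'EUROPEAN UNION', 'geography'}

def predict(probs, threshold):
    """Return the set of labels whose probability reaches the threshold."""
    return {label for label, p in probs.items() if p >= threshold}

for threshold in (0.5, 0.1, 0.01):
    predicted = predict(probs, threshold)
    missed = true_labels - predicted      # false negatives
    spurious = predicted - true_labels    # false positives
    print(threshold, sorted(missed), sorted(spurious))
```

Lowering the threshold trades missed labels for spurious ones. A threshold is usually better tuned on the validation set than on individual test documents.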

In [105]:
predicted_labels = predict_labels(text=dataset['test']['text'][1111][:512])

In [106]:
print('predicted_labels>>>', predicted_labels)

predicted_labels>>> ['social questions', 'trade', 'agri-foodstuffs']


In [107]:
print('true labels>>>', [id2label[idx] for idx, label in enumerate(numbers_to_classes(dataset['test']['old_labels'][1111])) if label == 1.0])

true labels>>> ['trade', 'agri-foodstuffs']


### We capture 'trade' and 'agri-foodstuffs' but wrongly predict the 'social questions' category

In [130]:
predicted_labels = predict_labels(text=dataset['test']['text'][2230][:512])

In [131]:
print('predicted_labels>>>', predicted_labels)

predicted_labels>>> ['trade', 'agri-foodstuffs']


In [132]:
print('true labels>>>', [id2label[idx] for idx, label in enumerate(numbers_to_classes(dataset['test']['old_labels'][2230])) if label == 1.0])

true labels>>> ['trade', 'agri-foodstuffs']


### The output above shows that BERT can also return a fully correct list of labels

### Based on the examples above, it is clear that the BERT model learned to pick up the majority of the categories, but it often missed one category from the true category list
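The five examples above also show why micro-averaged F1 and exact-match accuracy can diverge. The sketch below scores just these five predictions (at the 0.5 threshold), using the label sets printed above. Five documents are far too few to say anything about the model overall, so this only illustrates how the two metrics are computed:

```python
# (predicted, true) label sets for test documents 0, 100, 1000, 1111 and 2230.
examples = [
    ({'international relations', 'agriculture, forestry and fisheries', 'geography'},
     {'international relations', 'agriculture, forestry and fisheries', 'EUROPEAN UNION',
      'geography'}),
    ({'trade', 'agriculture, forestry and fisheries', 'agri-foodstuffs', 'geography'},
     {'trade', 'agriculture, forestry and fisheries', 'production, technology and research',
      'agri-foodstuffs', 'geography'}),
    ({'international relations', 'EUROPEAN UNION', 'geography'},
     {'trade', 'international relations', 'EUROPEAN UNION', 'geography'}),
    ({'social questions', 'trade', 'agri-foodstuffs'},
     {'trade', 'agri-foodstuffs'}),
    ({'trade', 'agri-foodstuffs'},
     {'trade', 'agri-foodstuffs'}),
]

# Micro-averaging pools true positives, false positives and false negatives over all labels.
tp = sum(len(pred & true) for pred, true in examples)
fp = sum(len(pred - true) for pred, true in examples)
fn = sum(len(true - pred) for pred, true in examples)
micro_f1 = 2 * tp / (2 * tp + fp + fn)

# Exact-match accuracy only credits a document when the whole label set is right.
exact_match = sum(pred == true for pred, true in examples) / len(examples)

print(tp, fp, fn)
print(round(micro_f1, 3), exact_match)
```

Most individual labels are right, so micro-F1 is high, yet only one of the five documents has a fully correct label set. For multi-label inputs, scikit-learn's `accuracy_score` computes this exact-match (subset) accuracy.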

### Useful metrics for multi-label classification are the F1 score and ROC AUC. We will stick to F1

In [134]:
{'eval_loss': 0.31552064418792725,
 'eval_f1': 0.8050717570015328,
 'eval_roc_auc': 0.8746604597581796,
 'eval_accuracy': 0.276,
 'eval_runtime': 7.1405,
 'eval_samples_per_second': 140.045,
 'eval_steps_per_second': 8.823,
 'epoch': 50.0}

### Above we can observe the performance measures of the BERT model trained for 50 epochs: on the validation set, the micro-averaged F1 score is around 80%

In [133]:
1/21

0.047619047619047616

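### For comparison: with 21 categories, a classifier that always predicts the same single category has roughly a 1/21 ≈ 5% chance of hitting any given category. The accuracy reported above is exact-match accuracy (the whole predicted label set must equal the true one), which is stricter still, and our model reaches 27.6% on the validation set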