# Fine-tuning BERT for multi-label text classification

In this notebook, we are going to fine-tune a BERT-like model (here, DistilBERT) to predict multiple labels for a given piece of text.


## Load dataset

Next, let's load our processed **balanced** dataset to address the class imbalance in the raw dataset.

For this experiment, I use the Hugging Face `datasets` library, since it plugs directly into the Hugging Face tooling used to fine-tune BERT-like models. The steps are:

* load the canonical dataset splits created earlier as pandas DataFrames,
* then convert them into a Hugging Face `DatasetDict`.
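The notebook loads an already upsampled training file (`clean_train_upsampled.csv`), but the upsampling procedure itself is not shown here. Purely as an illustration, random oversampling of a minority label could look like the sketch below. `upsample_positive` is a hypothetical helper, and balancing a single label column is a simplification for a multi-label dataset.

```python
import pandas as pd


def upsample_positive(df: pd.DataFrame, label_col: str, random_state: int = 0) -> pd.DataFrame:
    """Hypothetical helper: randomly duplicate rows where `label_col == 1`
    (sampling with replacement) until positives match negatives in count."""
    pos = df[df[label_col] == 1]
    neg = df[df[label_col] == 0]
    if pos.empty or len(pos) >= len(neg):
        return df.reset_index(drop=True)
    # Draw the missing positives with replacement from the existing positives
    extra = pos.sample(n=len(neg) - len(pos), replace=True, random_state=random_state)
    # Shuffle so the duplicated rows are not all grouped at the end
    balanced = pd.concat([df, extra]).sample(frac=1, random_state=random_state)
    return balanced.reset_index(drop=True)


toy = pd.DataFrame({
    "clean_content": ["a", "b", "c", "d", "e"],
    "cyber_label": [1, 0, 0, 0, 0],
})
balanced = upsample_positive(toy, "cyber_label")
print(len(balanced), int(balanced["cyber_label"].sum()))  # 8 4
```

Only the training split should be oversampled, which is why the validation split (`clean_valid.csv`) is loaded as-is.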


In [2]:
import pandas as pd
from datasets import Dataset, DatasetDict

# Load the canonical splits (only the training split is upsampled)
train_df = pd.read_csv('../data/processed/clean_train_upsampled.csv')
valid_df = pd.read_csv('../data/processed/clean_valid.csv')

# Convert each DataFrame into a Hugging Face Dataset and group them by split
ds_dict = {'train' : Dataset.from_pandas(train_df),
           'valid' : Dataset.from_pandas(valid_df)}

dataset = DatasetDict(ds_dict)
dataset



DatasetDict({
    train: Dataset({
        features: ['clean_content', 'cyber_label', 'environmental_issue'],
        num_rows: 2904
    })
    valid: Dataset({
        features: ['clean_content', 'cyber_label', 'environmental_issue'],
        num_rows: 252
    })
})

As we can see, the dataset contains 2 splits: one for training (2,904 rows) and one for validation (252 rows).

Let's check the first example of the training split:

In [4]:
example = dataset['train'][0]
example

{'clean_content': 'Additionally the risk of natural catastrophes, solar\nstorms or cyber attacks could impact the infrastructure, (incl. satellites, GPS and communications systems).\nAlso energy transition may impact stability of energy supply. A smoothly functioning digital infrastructure is\nbecoming increasingly important, especially in times of remote home office working.\n2008 & 2011\nCyber Risks The volume and sophistication of malicious cyber activity has increased substantially, and there are growing\nconcerns regarding the security of proprietary corporate data and critical industrial control systems.\nCloud computing poses elevated risks due to increased concentration and accumulations. Operational\nrisks exist for corporations and could also lead to large property losses with high and previously unknown\naccumulation potential if industrial facilities were simultaneously attacked. The growing request for personal\nidentification and authentication, the use of biometric ident

Let's create a list that contains the labels, as well as 2 dictionaries that map labels to integers and back.

In [5]:
labels = [label for label in dataset['train'].features.keys() if label not in ['clean_content']]
id2label = {idx:label for idx, label in enumerate(labels)}
label2id = {label:idx for idx, label in enumerate(labels)}
labels

['cyber_label', 'environmental_issue']

## Preprocess data

Models like BERT don't take raw text as input; they expect `input_ids`, an `attention_mask`, etc., so I first tokenize the text. Here I'm using the `AutoTokenizer` API, which automatically loads the appropriate tokenizer for the checkpoint on the Hugging Face Hub.

What's a bit tricky is that we also need to provide labels to the model. For multi-label text classification, this is a matrix of shape (batch_size, num_labels). Also important: this should be a tensor of floats rather than integers, otherwise PyTorch's `BCEWithLogitsLoss` (which the model uses) will complain, as explained [here](https://discuss.pytorch.org/t/multi-label-binary-classification-result-type-float-cant-be-cast-to-the-desired-output-type-long/117915/3).
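To see why float targets are natural here, the sketch below re-implements in NumPy what `BCEWithLogitsLoss` computes with its default mean reduction and no class weights: every label gets its own independent sigmoid/binary cross-entropy term, and the targets enter the formula as real numbers. It is an illustration, not PyTorch's code, and `bce_with_logits` is a hypothetical name.

```python
import numpy as np


def bce_with_logits(logits: np.ndarray, targets: np.ndarray) -> float:
    """Mean binary cross-entropy on raw logits of shape (batch_size, num_labels).

    Uses the numerically stable form
        max(x, 0) - x * y + log(1 + exp(-|x|)),
    which equals -[y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))].
    """
    x = logits.astype(np.float64)
    y = targets.astype(np.float64)  # targets are used as floats in [0, 1]
    per_element = np.maximum(x, 0) - x * y + np.log1p(np.exp(-np.abs(x)))
    return float(per_element.mean())


# One example, two labels (cyber_label, environmental_issue), both positive
logits = np.array([[0.0291, 0.0749]])
targets = np.array([[1.0, 1.0]])
loss = bce_with_logits(logits, targets)
print(loss)
```

With these logits (taken from the forward-pass check further down, where they are printed rounded) the result is roughly 0.668, close to the 0.6675 loss the model reports there.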

In [6]:
from transformers import AutoTokenizer
import numpy as np

model_id = "distilbert/distilbert-base-uncased"

tokenizer = AutoTokenizer.from_pretrained(model_id)

def preprocess_data(examples):
  # take a batch of texts
  text = examples["clean_content"]
  # encode them
  encoding = tokenizer(text, padding="max_length", truncation=True, max_length=128)
  # add labels
  labels_batch = {k: examples[k] for k in examples.keys() if k in labels}
  # create numpy array of shape (batch_size, num_labels)
  labels_matrix = np.zeros((len(text), len(labels)))
  # fill numpy array, one column per label
  for idx, label in enumerate(labels):
    labels_matrix[:, idx] = labels_batch[label]

  encoding["labels"] = labels_matrix.tolist()

  return encoding

Let's apply this preprocessing function to our dataset using the `map` method.

In [7]:
encoded_dataset = dataset.map(preprocess_data, batched=True, remove_columns=dataset['train'].column_names)



In [8]:
example = encoded_dataset['train'][0]
print(example.keys())

dict_keys(['input_ids', 'attention_mask', 'labels'])


In [9]:
tokenizer.decode(example['input_ids'])

'[CLS] additionally the risk of natural catastrophes, solar storms or cyber attacks could impact the infrastructure, ( incl. satellites, gps and communications systems ). also energy transition may impact stability of energy supply. a smoothly functioning digital infrastructure is becoming increasingly important, especially in times of remote home office working. 2008 & 2011 cyber risks the volume and sophistication of malicious cyber activity has increased substantially, and there are growing concerns regarding the security of proprietary corporate data and critical industrial control systems. cloud computing poses elevated risks due to increased concentration and accumulations. operational risks exist for corporations and could also lead to large property losses with high and [SEP]'
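The decoded text stops mid-sentence ("... with high and [SEP]") because of `truncation=True, max_length=128`: the tokenizer keeps only as many word pieces as fit alongside the special tokens, and pads shorter texts up to 128. The pure-Python sketch below mimics that behaviour for one sequence of already-tokenized ids. `truncate_and_pad` is a hypothetical helper, not the tokenizer's implementation; the ids 101/102/0 for `[CLS]`/`[SEP]`/`[PAD]` are those of the `bert-base-uncased` vocabulary that `distilbert-base-uncased` shares (101 and 102 also appear at the start and end of the `input_ids` printed further down).

```python
from typing import List, Tuple

CLS_ID, SEP_ID, PAD_ID = 101, 102, 0  # bert-base-uncased / distilbert-base-uncased vocab


def truncate_and_pad(token_ids: List[int], max_length: int) -> Tuple[List[int], List[int]]:
    """Hypothetical helper mimicking padding="max_length", truncation=True
    for one sequence: [CLS] + tokens + [SEP], then pad up to max_length."""
    body = token_ids[: max_length - 2]        # leave room for [CLS] and [SEP]
    input_ids = [CLS_ID] + body + [SEP_ID]
    attention_mask = [1] * len(input_ids)     # 1 = real token
    n_pad = max_length - len(input_ids)
    input_ids += [PAD_ID] * n_pad
    attention_mask += [0] * n_pad             # 0 = padding, ignored by attention
    return input_ids, attention_mask


# Short input (arbitrary example ids): padded up to max_length
short_ids, short_mask = truncate_and_pad([7592, 2088], max_length=6)
print(short_ids)   # [101, 7592, 2088, 102, 0, 0]
print(short_mask)  # [1, 1, 1, 1, 0, 0]

# Long input: cut so the result is exactly max_length tokens, ending in [SEP]
long_ids, long_mask = truncate_and_pad(list(range(1000, 1200)), max_length=128)
print(len(long_ids), long_ids[-1], sum(long_mask))  # 128 102 128
```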

In [10]:
example['labels']  # actual labels of the example

[1.0, 1.0]

In [11]:
[id2label[idx] for idx, label in enumerate(example['labels']) if label == 1.0]  # names of the labels if class value is 1

['cyber_label', 'environmental_issue']

Finally, we set the format of our data to PyTorch tensors, so that the training and validation sets return `torch.Tensor` objects and can be used like standard PyTorch [datasets](https://pytorch.org/docs/stable/data.html).

In [12]:
encoded_dataset.set_format("torch")

## Define model

Here we define a model that consists of a pre-trained base (the weights of `distilbert-base-uncased`, a distilled version of `bert-base-uncased`) with a randomly initialized classification head (linear layers) on top. This head is fine-tuned, together with the pre-trained base, on our labeled dataset.

The warning printed when loading the model confirms this: the `pre_classifier` and `classifier` weights are newly initialized.

We set the `problem_type` to be "multi_label_classification", as this will make sure the appropriate loss function is used (namely [`BCEWithLogitsLoss`](https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html)). We also make sure the output layer has `len(labels)` output neurons, and we set the id2label and label2id mappings.
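The practical difference from single-label classification lies in how the head's `len(labels)` logits become probabilities. The NumPy illustration below (not the model's code, and with made-up logits) contrasts the two: with `multi_label_classification` each logit gets its own sigmoid, so several labels can exceed 0.5 at once, whereas a softmax, as used for single-label classification, forces the probabilities to sum to 1.

```python
import numpy as np

logits = np.array([2.0, 1.5])  # hypothetical logits for (cyber_label, environmental_issue)

# Multi-label: an independent sigmoid per label
sigmoid_probs = 1.0 / (1.0 + np.exp(-logits))

# Single-label: softmax across labels (shifted by the max for numerical stability)
exp = np.exp(logits - logits.max())
softmax_probs = exp / exp.sum()

print(sigmoid_probs.round(3))  # both above 0.5, so both labels would be predicted
print(softmax_probs.round(3))  # sums to 1, so the labels compete with each other
```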

In [13]:
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(
   model_id,
   problem_type="multi_label_classification",
   num_labels=len(labels),
   id2label=id2label,
   label2id=label2id,
)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert/distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## Train the model!

We are going to train the model using Hugging Face's `Trainer` API. This requires us to define 2 things:

* `TrainingArguments`, which specify the training hyperparameters. All options can be found in the [docs](https://huggingface.co/transformers/main_classes/trainer.html#trainingarguments). Below, for example, we evaluate and save a checkpoint after every epoch, and set the learning rate, the batch size for training and evaluation, the number of epochs, and so on.
* a `Trainer` object (docs can be found [here](https://huggingface.co/transformers/main_classes/trainer.html#id1)).

In [14]:
batch_size = 4
metric_name = "f1"

In [15]:
from transformers import TrainingArguments, Trainer

args = TrainingArguments(
    "distilbert-finetuned-MLC-upsampled",  # output directory for checkpoints
    evaluation_strategy = "epoch",  # newer transformers releases name this `eval_strategy`
    save_strategy = "epoch",
    learning_rate=3e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=20,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
    #push_to_hub=True,
)

We are also going to compute metrics during training. For this, we need to define a `compute_metrics` function that returns a dictionary with the desired metric values.

I also added per-label F1, ROC AUC and classification reports, to see how well the model is doing on each class individually.

In [16]:
from sklearn.metrics import f1_score, roc_auc_score, accuracy_score, classification_report
from transformers import EvalPrediction
import torch

# source: https://jesusleal.io/2021/04/21/Longformer-multilabel-classification/
def multi_label_metrics(predictions, labels, threshold=0.5):
    # first, apply sigmoid on predictions which are of shape (batch_size, num_labels)
    sigmoid = torch.nn.Sigmoid()
    probs = sigmoid(torch.Tensor(predictions))
    # next, use threshold to turn them into integer predictions
    y_pred = np.zeros(probs.shape)
    y_pred[np.where(probs >= threshold)] = 1
    # finally, compute metrics
    y_true = labels
    f1_micro_average = f1_score(y_true=y_true, y_pred=y_pred, average='micro')
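    # note: ROC AUC is computed on the thresholded 0/1 predictions (here and per label below),
    # not on `probs`, so it reduces to balanced accuracy rather than a ranking-based AUC
    roc_auc = roc_auc_score(y_true, y_pred, average = 'micro')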
    accuracy = accuracy_score(y_true, y_pred)
    # return as dictionary
    metrics = {'f1': f1_micro_average,
               'roc_auc': roc_auc,
               'accuracy': accuracy}

    # Compute metrics for each label
    for i in range(y_true.shape[1]):  # Assuming labels are in the columns
        metrics[f'f1_label_{id2label[i]}'] = f1_score(y_true[:, i], y_pred[:, i], average='binary', zero_division=0)
        metrics[f'roc_auc_label_{id2label[i]}'] = roc_auc_score(y_true[:, i], y_pred[:, i])
        report = classification_report(y_true[:, i], y_pred[:, i], output_dict=True, zero_division=0)
        metrics[f'classification_report_label_{id2label[i]}'] = report

    return metrics

def compute_metrics(p: EvalPrediction):
    preds = p.predictions[0] if isinstance(p.predictions,
            tuple) else p.predictions
    result = multi_label_metrics(
        predictions=preds,
        labels=p.label_ids)
    return result
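Two of the numbers returned by `multi_label_metrics` above are worth unpacking. `average='micro'` pools true positives, false positives and false negatives across both labels before computing F1, and for multi-label inputs scikit-learn's `accuracy_score` is *subset* accuracy: a sample only counts as correct if every one of its labels is right. The NumPy sketch below recomputes both on a toy matrix; the function names are hypothetical, and it is meant as an illustration rather than a replacement for scikit-learn.

```python
import numpy as np


def micro_f1(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Micro-averaged F1: pool TP, FP and FN over all labels, then compute F1."""
    tp = np.sum((y_true == 1) & (y_pred == 1))
    fp = np.sum((y_true == 0) & (y_pred == 1))
    fn = np.sum((y_true == 1) & (y_pred == 0))
    denom = 2 * tp + fp + fn
    return float(2 * tp / denom) if denom else 0.0


def subset_accuracy(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Fraction of rows where *all* labels match (exact-match ratio)."""
    return float(np.mean(np.all(y_true == y_pred, axis=1)))


# Toy data: 4 samples x 2 labels (cyber_label, environmental_issue)
y_true = np.array([[1, 1], [0, 1], [0, 0], [1, 0]])
y_pred = np.array([[1, 0], [0, 1], [0, 0], [1, 1]])

print(micro_f1(y_true, y_pred))         # 0.75
print(subset_accuracy(y_true, y_pred))  # 0.5
```

Because a sample must get both labels right to count, the overall `accuracy` in the training log can never exceed either per-label accuracy reported in the classification reports.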

Let's check the label dtype and an encoded example, then run a forward pass:

In [17]:
encoded_dataset['train'][0]['labels'].type()

'torch.FloatTensor'

In [18]:
encoded_dataset['train']['input_ids'][0]

tensor([  101,  5678,  1996,  3891,  1997,  3019, 25539,  2015,  1010,  5943,
        12642,  2030, 16941,  4491,  2071,  4254,  1996,  6502,  1010,  1006,
         4297,  2140,  1012, 14549,  1010, 14658,  1998,  4806,  3001,  1007,
         1012,  2036,  2943,  6653,  2089,  4254,  9211,  1997,  2943,  4425,
         1012,  1037, 15299, 12285,  3617,  6502,  2003,  3352,  6233,  2590,
         1010,  2926,  1999,  2335,  1997,  6556,  2188,  2436,  2551,  1012,
         2263,  1004,  2249, 16941, 10831,  1996,  3872,  1998,  2061, 21850,
        10074,  3370,  1997, 24391, 16941,  4023,  2038,  3445, 12381,  1010,
         1998,  2045,  2024,  3652,  5936,  4953,  1996,  3036,  1997, 16350,
         5971,  2951,  1998,  4187,  3919,  2491,  3001,  1012,  6112,  9798,
        22382,  8319, 10831,  2349,  2000,  3445,  6693,  1998, 20299,  2015,
         1012,  6515, 10831,  4839,  2005, 11578,  1998,  2071,  2036,  2599,
         2000,  2312,  3200,  6409,  2007,  2152,  1998,   102])

In [19]:
# Making sure the forward pass works: one example passed as a batch of size 1
sample = encoded_dataset['train'][0]
outputs = model(input_ids=sample['input_ids'].unsqueeze(0),
                attention_mask=sample['attention_mask'].unsqueeze(0),
                labels=sample['labels'].unsqueeze(0))
outputs

SequenceClassifierOutput(loss=tensor(0.6675, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>), logits=tensor([[0.0291, 0.0749]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)

Let's start training!

In [20]:
trainer = Trainer(
    model,
    args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["valid"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)



In [21]:
trainer.train()

Epoch,Training Loss,Validation Loss,F1,Roc Auc,Accuracy,F1 Label Cyber Label,Roc Auc Label Cyber Label,Classification Report Label Cyber Label,F1 Label Environmental Issue,Roc Auc Label Environmental Issue,Classification Report Label Environmental Issue
1,0.3807,0.407735,0.59893,0.834533,0.722222,0.5,0.743429,"{'0.0': {'precision': 0.9656652360515021, 'recall': 0.9574468085106383, 'f1-score': 0.9615384615384616, 'support': 235.0}, '1.0': {'precision': 0.47368421052631576, 'recall': 0.5294117647058824, 'f1-score': 0.5, 'support': 17.0}, 'accuracy': 0.9285714285714286, 'macro avg': {'precision': 0.7196747232889089, 'recall': 0.7434292866082604, 'f1-score': 0.7307692307692308, 'support': 252.0}, 'weighted avg': {'precision': 0.9324760398851205, 'recall': 0.9285714285714286, 'f1-score': 0.9304029304029304, 'support': 252.0}}",0.622517,0.821923,"{'0.0': {'precision': 0.9673202614379085, 'recall': 0.74, 'f1-score': 0.8385269121813032, 'support': 200.0}, '1.0': {'precision': 0.47474747474747475, 'recall': 0.9038461538461539, 'f1-score': 0.6225165562913907, 'support': 52.0}, 'accuracy': 0.7738095238095238, 'macro avg': {'precision': 0.7210338680926917, 'recall': 0.821923076923077, 'f1-score': 0.7305217342363469, 'support': 252.0}, 'weighted avg': {'precision': 0.8656782578351205, 'recall': 0.7738095238095238, 'f1-score': 0.7939533466802101, 'support': 252.0}}"
2,0.1071,0.404711,0.592593,0.75997,0.789683,0.5,0.743429,"{'0.0': {'precision': 0.9656652360515021, 'recall': 0.9574468085106383, 'f1-score': 0.9615384615384616, 'support': 235.0}, '1.0': {'precision': 0.47368421052631576, 'recall': 0.5294117647058824, 'f1-score': 0.5, 'support': 17.0}, 'accuracy': 0.9285714285714286, 'macro avg': {'precision': 0.7196747232889089, 'recall': 0.7434292866082604, 'f1-score': 0.7307692307692308, 'support': 252.0}, 'weighted avg': {'precision': 0.9324760398851205, 'recall': 0.9285714285714286, 'f1-score': 0.9304029304029304, 'support': 252.0}}",0.626263,0.758077,"{'0.0': {'precision': 0.8975609756097561, 'recall': 0.92, 'f1-score': 0.908641975308642, 'support': 200.0}, '1.0': {'precision': 0.6595744680851063, 'recall': 0.5961538461538461, 'f1-score': 0.6262626262626263, 'support': 52.0}, 'accuracy': 0.8531746031746031, 'macro avg': {'precision': 0.7785677218474312, 'recall': 0.7580769230769231, 'f1-score': 0.7674523007856342, 'support': 252.0}, 'weighted avg': {'precision': 0.8484526486602253, 'recall': 0.8531746031746031, 'f1-score': 0.8503732207435912, 'support': 252.0}}"
3,0.0487,0.516696,0.609929,0.778261,0.793651,0.5,0.766458,"{'0.0': {'precision': 0.9694323144104804, 'recall': 0.9446808510638298, 'f1-score': 0.9568965517241379, 'support': 235.0}, '1.0': {'precision': 0.43478260869565216, 'recall': 0.5882352941176471, 'f1-score': 0.5, 'support': 17.0}, 'accuracy': 0.9206349206349206, 'macro avg': {'precision': 0.7021074615530662, 'recall': 0.7664580725907384, 'f1-score': 0.728448275862069, 'support': 252.0}, 'weighted avg': {'precision': 0.9333646755328928, 'recall': 0.9206349206349206, 'f1-score': 0.9260741652983032, 'support': 252.0}}",0.653465,0.777308,"{'0.0': {'precision': 0.9064039408866995, 'recall': 0.92, 'f1-score': 0.913151364764268, 'support': 200.0}, '1.0': {'precision': 0.673469387755102, 'recall': 0.6346153846153846, 'f1-score': 0.6534653465346535, 'support': 52.0}, 'accuracy': 0.8611111111111112, 'macro avg': {'precision': 0.7899366643209007, 'recall': 0.7773076923076923, 'f1-score': 0.7833083556494607, 'support': 252.0}, 'weighted avg': {'precision': 0.8583380807166873, 'recall': 0.8611111111111112, 'f1-score': 0.8595653610026015, 'support': 252.0}}"
4,0.0254,0.650661,0.619048,0.822789,0.757937,0.421053,0.707635,"{'0.0': {'precision': 0.961038961038961, 'recall': 0.9446808510638298, 'f1-score': 0.9527896995708155, 'support': 235.0}, '1.0': {'precision': 0.38095238095238093, 'recall': 0.47058823529411764, 'f1-score': 0.42105263157894735, 'support': 17.0}, 'accuracy': 0.9126984126984127, 'macro avg': {'precision': 0.670995670995671, 'recall': 0.7076345431789737, 'f1-score': 0.6869211655748815, 'support': 252.0}, 'weighted avg': {'precision': 0.9219061361918505, 'recall': 0.9126984126984127, 'f1-score': 0.9169185481586656, 'support': 252.0}}",0.676923,0.838077,"{'0.0': {'precision': 0.9540229885057471, 'recall': 0.83, 'f1-score': 0.8877005347593583, 'support': 200.0}, '1.0': {'precision': 0.5641025641025641, 'recall': 0.8461538461538461, 'f1-score': 0.676923076923077, 'support': 52.0}, 'accuracy': 0.8333333333333334, 'macro avg': {'precision': 0.7590627763041555, 'recall': 0.838076923076923, 'f1-score': 0.7823118058412176, 'support': 252.0}, 'weighted avg': {'precision': 0.8735632183908046, 'recall': 0.8333333333333334, 'f1-score': 0.8442067736185384, 'support': 252.0}}"
5,0.0143,0.836014,0.629834,0.849825,0.742063,0.4,0.703379,"{'0.0': {'precision': 0.9606986899563319, 'recall': 0.9361702127659575, 'f1-score': 0.9482758620689655, 'support': 235.0}, '1.0': {'precision': 0.34782608695652173, 'recall': 0.47058823529411764, 'f1-score': 0.4, 'support': 17.0}, 'accuracy': 0.9047619047619048, 'macro avg': {'precision': 0.6542623884564268, 'recall': 0.7033792240300376, 'f1-score': 0.6741379310344828, 'support': 252.0}, 'weighted avg': {'precision': 0.9193541095952336, 'recall': 0.9047619047619048, 'f1-score': 0.9112889983579638, 'support': 252.0}}",0.695035,0.871154,"{'0.0': {'precision': 0.9815950920245399, 'recall': 0.8, 'f1-score': 0.8815426997245179, 'support': 200.0}, '1.0': {'precision': 0.550561797752809, 'recall': 0.9423076923076923, 'f1-score': 0.6950354609929078, 'support': 52.0}, 'accuracy': 0.8293650793650794, 'macro avg': {'precision': 0.7660784448886744, 'recall': 0.8711538461538462, 'f1-score': 0.7882890803587128, 'support': 252.0}, 'weighted avg': {'precision': 0.8926517138414842, 'recall': 0.8293650793650794, 'f1-score': 0.8430570790338683, 'support': 252.0}}"
6,0.0137,0.709693,0.670732,0.852574,0.797619,0.470588,0.716145,"{'0.0': {'precision': 0.9617021276595744, 'recall': 0.9617021276595744, 'f1-score': 0.9617021276595744, 'support': 235.0}, '1.0': {'precision': 0.47058823529411764, 'recall': 0.47058823529411764, 'f1-score': 0.47058823529411764, 'support': 17.0}, 'accuracy': 0.9285714285714286, 'macro avg': {'precision': 0.716145181476846, 'recall': 0.716145181476846, 'f1-score': 0.716145181476846, 'support': 252.0}, 'weighted avg': {'precision': 0.9285714285714286, 'recall': 0.9285714285714286, 'f1-score': 0.9285714285714286, 'support': 252.0}}",0.723077,0.874423,"{'0.0': {'precision': 0.9712643678160919, 'recall': 0.845, 'f1-score': 0.9037433155080213, 'support': 200.0}, '1.0': {'precision': 0.6025641025641025, 'recall': 0.9038461538461539, 'f1-score': 0.7230769230769231, 'support': 52.0}, 'accuracy': 0.8571428571428571, 'macro avg': {'precision': 0.7869142351900973, 'recall': 0.874423076923077, 'f1-score': 0.8134101192924722, 'support': 252.0}, 'weighted avg': {'precision': 0.895183360700602, 'recall': 0.8571428571428571, 'f1-score': 0.8664629488158899, 'support': 252.0}}"
7,0.0138,0.70388,0.58209,0.752724,0.785714,0.421053,0.707635,"{'0.0': {'precision': 0.961038961038961, 'recall': 0.9446808510638298, 'f1-score': 0.9527896995708155, 'support': 235.0}, '1.0': {'precision': 0.38095238095238093, 'recall': 0.47058823529411764, 'f1-score': 0.42105263157894735, 'support': 17.0}, 'accuracy': 0.9126984126984127, 'macro avg': {'precision': 0.670995670995671, 'recall': 0.7076345431789737, 'f1-score': 0.6869211655748815, 'support': 252.0}, 'weighted avg': {'precision': 0.9219061361918505, 'recall': 0.9126984126984127, 'f1-score': 0.9169185481586656, 'support': 252.0}}",0.645833,0.765577,"{'0.0': {'precision': 0.8990384615384616, 'recall': 0.935, 'f1-score': 0.9166666666666666, 'support': 200.0}, '1.0': {'precision': 0.7045454545454546, 'recall': 0.5961538461538461, 'f1-score': 0.6458333333333334, 'support': 52.0}, 'accuracy': 0.8650793650793651, 'macro avg': {'precision': 0.8017919580419581, 'recall': 0.7655769230769232, 'f1-score': 0.78125, 'support': 252.0}, 'weighted avg': {'precision': 0.8589049839049839, 'recall': 0.8650793650793651, 'f1-score': 0.8607804232804233, 'support': 252.0}}"
8,0.0138,0.840524,0.606452,0.795752,0.765873,0.487805,0.76433,"{'0.0': {'precision': 0.9692982456140351, 'recall': 0.9404255319148936, 'f1-score': 0.9546436285097192, 'support': 235.0}, '1.0': {'precision': 0.4166666666666667, 'recall': 0.5882352941176471, 'f1-score': 0.4878048780487805, 'support': 17.0}, 'accuracy': 0.9166666666666666, 'macro avg': {'precision': 0.6929824561403509, 'recall': 0.7643304130162704, 'f1-score': 0.7212242532792499, 'support': 252.0}, 'weighted avg': {'precision': 0.9320175438596492, 'recall': 0.9166666666666666, 'f1-score': 0.9231505382008462, 'support': 252.0}}",0.649123,0.793269,"{'0.0': {'precision': 0.9210526315789473, 'recall': 0.875, 'f1-score': 0.8974358974358975, 'support': 200.0}, '1.0': {'precision': 0.5967741935483871, 'recall': 0.7115384615384616, 'f1-score': 0.6491228070175439, 'support': 52.0}, 'accuracy': 0.8412698412698413, 'macro avg': {'precision': 0.7589134125636672, 'recall': 0.7932692307692308, 'f1-score': 0.7732793522267207, 'support': 252.0}, 'weighted avg': {'precision': 0.854138033255181, 'recall': 0.8412698412698413, 'f1-score': 0.8461966883019515, 'support': 252.0}}"
9,0.0077,0.719106,0.652482,0.803448,0.809524,0.484848,0.718273,"{'0.0': {'precision': 0.961864406779661, 'recall': 0.9659574468085106, 'f1-score': 0.9639065817409767, 'support': 235.0}, '1.0': {'precision': 0.5, 'recall': 0.47058823529411764, 'f1-score': 0.48484848484848486, 'support': 17.0}, 'accuracy': 0.9325396825396826, 'macro avg': {'precision': 0.7309322033898304, 'recall': 0.7182728410513142, 'f1-score': 0.7243775332947308, 'support': 252.0}, 'weighted avg': {'precision': 0.9307068872746839, 'recall': 0.9325396825396826, 'f1-score': 0.9315891704426736, 'support': 252.0}}",0.703704,0.820385,"{'0.0': {'precision': 0.9285714285714286, 'recall': 0.91, 'f1-score': 0.9191919191919192, 'support': 200.0}, '1.0': {'precision': 0.6785714285714286, 'recall': 0.7307692307692307, 'f1-score': 0.7037037037037037, 'support': 52.0}, 'accuracy': 0.873015873015873, 'macro avg': {'precision': 0.8035714285714286, 'recall': 0.8203846153846154, 'f1-score': 0.8114478114478114, 'support': 252.0}, 'weighted avg': {'precision': 0.876984126984127, 'recall': 0.873015873015873, 'f1-score': 0.8747260969483193, 'support': 252.0}}"
10,0.0067,0.924912,0.6125,0.806797,0.757937,0.473684,0.739174,"{'0.0': {'precision': 0.9653679653679653, 'recall': 0.948936170212766, 'f1-score': 0.9570815450643777, 'support': 235.0}, '1.0': {'precision': 0.42857142857142855, 'recall': 0.5294117647058824, 'f1-score': 0.47368421052631576, 'support': 17.0}, 'accuracy': 0.9206349206349206, 'macro avg': {'precision': 0.6969696969696969, 'recall': 0.7391739674593242, 'f1-score': 0.7153828777953467, 'support': 252.0}, 'weighted avg': {'precision': 0.9291555005840719, 'recall': 0.9206349206349206, 'f1-score': 0.9244714074169688, 'support': 252.0}}",0.655738,0.809615,"{'0.0': {'precision': 0.9340659340659341, 'recall': 0.85, 'f1-score': 0.8900523560209425, 'support': 200.0}, '1.0': {'precision': 0.5714285714285714, 'recall': 0.7692307692307693, 'f1-score': 0.6557377049180327, 'support': 52.0}, 'accuracy': 0.8333333333333334, 'macro avg': {'precision': 0.7527472527472527, 'recall': 0.8096153846153846, 'f1-score': 0.7728950304694876, 'support': 252.0}, 'weighted avg': {'precision': 0.859236002093145, 'recall': 0.8333333333333334, 'f1-score': 0.8417017137298659, 'support': 252.0}}"


TrainOutput(global_step=14520, training_loss=0.026142534255315503, metrics={'train_runtime': 1150.1263, 'train_samples_per_second': 50.499, 'train_steps_per_second': 12.625, 'total_flos': 1923426628485120.0, 'train_loss': 0.026142534255315503, 'epoch': 20.0})
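The `global_step=14520` in the output above can be checked by hand. Assuming a single device and the default `gradient_accumulation_steps=1`, the `Trainer` takes one optimizer step per training batch, so with 2,904 training rows, `per_device_train_batch_size=4` and 20 epochs:

```python
import math

num_train_rows = 2904  # from the DatasetDict printed earlier
batch_size = 4         # per_device_train_batch_size
num_epochs = 20        # num_train_epochs
num_devices = 1        # assumption: single GPU/CPU, no gradient accumulation

steps_per_epoch = math.ceil(num_train_rows / (batch_size * num_devices))
total_steps = steps_per_epoch * num_epochs
print(steps_per_epoch, total_steps)  # 726 14520
```

The match confirms that the run covered all 20 epochs, even though the evaluation log above lists only the first 10.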

## Evaluate

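After training, we evaluate the model on the validation set. Because `load_best_model_at_end=True` and `metric_for_best_model="f1"`, the `Trainer` has reloaded the checkpoint with the best validation F1, so the scores below match epoch 6 of the training log (F1 ≈ 0.6707, validation loss ≈ 0.7097). Since the same split was used to pick that checkpoint, these numbers are likely somewhat optimistic.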

In [22]:
scores = trainer.evaluate()


In [27]:
scores

{'eval_loss': 0.7096929550170898,
 'eval_f1': 0.6707317073170732,
 'eval_roc_auc': 0.8525737131434282,
 'eval_accuracy': 0.7976190476190477,
 'eval_f1_label_cyber_label': 0.47058823529411764,
 'eval_roc_auc_label_cyber_label': 0.7161451814768461,
 'eval_classification_report_label_cyber_label': {'0.0': {'precision': 0.9617021276595744,
   'recall': 0.9617021276595744,
   'f1-score': 0.9617021276595744,
   'support': 235.0},
  '1.0': {'precision': 0.47058823529411764,
   'recall': 0.47058823529411764,
   'f1-score': 0.47058823529411764,
   'support': 17.0},
  'accuracy': 0.9285714285714286,
  'macro avg': {'precision': 0.716145181476846,
   'recall': 0.716145181476846,
   'f1-score': 0.716145181476846,
   'support': 252.0},
  'weighted avg': {'precision': 0.9285714285714286,
   'recall': 0.9285714285714286,
   'f1-score': 0.9285714285714286,
   'support': 252.0}},
 'eval_f1_label_environmental_issue': 0.7230769230769231,
 'eval_roc_auc_label_environmental_issue': 0.8744230769230769,
 'e

## Inference

Let's test the model on a new sentence:

In [24]:
text = "I'm happy I can finally train a model for multi-label classification"

encoding = tokenizer(text, return_tensors="pt")
encoding = {k: v.to(trainer.model.device) for k,v in encoding.items()}

with torch.no_grad():  # inference only, no gradients needed
    outputs = trainer.model(**encoding)

The logits that come out of the model have shape (batch_size, num_labels). As we only forward a single sentence through the model, `batch_size` equals 1. The logits are a tensor of (unnormalized) scores, one per label; a sigmoid turns each score into an independent probability.

In [None]:
logits = outputs.logits
# apply sigmoid + threshold
sigmoid = torch.nn.Sigmoid()
probs = sigmoid(logits.squeeze().cpu())
predictions = np.zeros(probs.shape)
predictions[np.where(probs >= 0.5)] = 1
# turn predicted id's into actual label names
predicted_labels = [id2label[idx] for idx, label in enumerate(predictions) if label == 1.0]
print(predicted_labels)
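The same sigmoid-and-threshold step extends to a whole batch of logits. Below is a NumPy-only sketch with a hypothetical `decode_logits` helper; the 0.5 threshold matches the one used above, but it is a tunable assumption rather than a fixed rule. With the model from this notebook, `outputs.logits.detach().cpu().numpy()` could be passed in as `logits`.

```python
from typing import Dict, List

import numpy as np


def decode_logits(logits: np.ndarray, id2label: Dict[int, str], threshold: float = 0.5) -> List[List[str]]:
    """Turn raw logits of shape (batch_size, num_labels) into lists of label names."""
    probs = 1.0 / (1.0 + np.exp(-logits))  # independent sigmoid per label
    preds = probs >= threshold             # boolean multi-hot matrix
    return [[id2label[int(j)] for j in np.flatnonzero(row)] for row in preds]


id2label = {0: "cyber_label", 1: "environmental_issue"}
batch_logits = np.array([
    [0.0291, 0.0749],  # both slightly positive, so both labels
    [2.3, -1.7],       # only the first label
    [-3.0, -0.4],      # no label
])
decoded = decode_logits(batch_logits, id2label)
print(decoded)
# [['cyber_label', 'environmental_issue'], ['cyber_label'], []]
```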