# Fine-tuning BERT for multi-label text classification

In this notebook, we are going to fine-tune a BERT-like model (DistilBERT) to predict multiple labels for a given piece of text.


## Load dataset

Next, let's load our **processed** dataset, which was prepared to address the class imbalance in the raw dataset.

For this experiment, I use the Hugging Face `datasets` library, as it integrates well with fine-tuning BERT-like models in the Hugging Face ecosystem:

* load the canonical dataset splits created earlier as pandas DataFrames;
* then, convert them into a Hugging Face `DatasetDict`.


In [2]:
import pandas as pd
from datasets import Dataset, DatasetDict

train_df = pd.read_csv('../data/processed/clean_train.csv')
valid_df = pd.read_csv('../data/processed/clean_valid.csv')


# dataset = load_dataset("ethos", "multilabel")

ds_dict = {'train' : Dataset.from_pandas(train_df),
           'valid' : Dataset.from_pandas(valid_df)}

dataset = DatasetDict(ds_dict)
dataset

DatasetDict({
    train: Dataset({
        features: ['clean_content', 'cyber_label', 'environmental_issue'],
        num_rows: 1008
    })
    valid: Dataset({
        features: ['clean_content', 'cyber_label', 'environmental_issue'],
        num_rows: 252
    })
})

As we can see, the dataset contains 2 splits: one for training (1,008 examples) and one for validation (252 examples).

Let's check the first example of the training split:

In [4]:
example = dataset['train'][0]
example

{'clean_content': "It is a process that puts a heavier burden\nof work on the auditees because they\nmust be prepared to work closely with the\ninternal audit function to share knowledge.\nThat includes details of the AI models they\nhave deployed, which can be complex and\nincreases the level of specialist knowledge\nneeded within the internal audit team.\nOne roundtable attendee said he had\nbrought in these skills initially from\nan external supplier. However, as the\norganisation digitalised it had decided\nto create a dedicated team of experts inhouse that internal audit could buy days\nDIGITAL DISRUPTION\nAND NEW TECHNOLOGY\nfrom through an internal exchange system.\nWhile that worked well, finding subject\nmatter experts in the business areas\naffected by AI remained challenging.\nThe European Union's proposal for draft\nregulation on artificial intelligence, which\nwas published in 2021, is well underway.\nThat is likely to require certification for\nAI models and business area

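The dataset consists of text passages (such as the excerpt above), each labeled with zero, one, or both of two binary labels: `cyber_label` and `environmental_issue`.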

Let's create a list that contains the labels, as well as 2 dictionaries that map labels to integers and back.

In [5]:
labels = [label for label in dataset['train'].features.keys() if label not in ['clean_content']]
id2label = {idx:label for idx, label in enumerate(labels)}
label2id = {label:idx for idx, label in enumerate(labels)}
labels

['cyber_label', 'environmental_issue']
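To make the role of the two mappings concrete, here is a minimal sketch that converts between label names and the multi-hot vectors the model works with. `names_to_multi_hot` and `multi_hot_to_names` are hypothetical helpers for illustration only; the label list is assumed to be the one printed above.

```python
# Assumption: the same two label columns as printed above
labels = ['cyber_label', 'environmental_issue']
id2label = {idx: label for idx, label in enumerate(labels)}
label2id = {label: idx for idx, label in enumerate(labels)}


def names_to_multi_hot(names):
    """Hypothetical helper: list of label names -> float multi-hot vector."""
    vector = [0.0] * len(labels)
    for name in names:
        vector[label2id[name]] = 1.0
    return vector


def multi_hot_to_names(vector):
    """Hypothetical helper: multi-hot vector -> list of label names."""
    return [id2label[idx] for idx, value in enumerate(vector) if value == 1.0]


print(names_to_multi_hot(['environmental_issue']))  # [0.0, 1.0]
print(multi_hot_to_names([1.0, 1.0]))               # ['cyber_label', 'environmental_issue']
```

An example with no labels maps to `[0.0, 0.0]`, which is exactly what the first training example turns out to be further below.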

## Preprocess data

Models like BERT don't take raw text as input; they expect `input_ids`, `attention_mask`, etc. So I tokenize the text with the `AutoTokenizer` API, which automatically loads the appropriate tokenizer for the checkpoint on the Hub.

What's a bit tricky is that we also need to provide labels to the model. For multi-label text classification, this is a matrix of shape (batch_size, num_labels). Also important: this should be a tensor of floats rather than integers, otherwise PyTorch's `BCEWithLogitsLoss` (which the model uses) will complain, as explained [here](https://discuss.pytorch.org/t/multi-label-binary-classification-result-type-float-cant-be-cast-to-the-desired-output-type-long/117915/3).
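As a minimal sketch of just the label part, the (batch_size, num_labels) float matrix can be built by stacking the label columns side by side. It assumes a batch arrives column-oriented (a mapping from column name to a list of values), as with `Dataset.map(batched=True)`; the toy batch and `build_label_matrix` are hypothetical.

```python
import numpy as np

labels = ['cyber_label', 'environmental_issue']

# Hypothetical toy batch in column-oriented form: one list per column
examples = {
    'clean_content': ['text a', 'text b', 'text c'],
    'cyber_label': [0, 1, 1],
    'environmental_issue': [1, 0, 1],
}


def build_label_matrix(examples, labels):
    """Stack the label columns into a (batch_size, num_labels) float32 matrix."""
    # axis=1 places each label column next to the others, one row per example
    return np.stack(
        [np.asarray(examples[name], dtype=np.float32) for name in labels], axis=1
    )


matrix = build_label_matrix(examples, labels)
print(matrix.shape, matrix.dtype)  # (3, 2) float32
print(matrix.tolist())             # [[0.0, 1.0], [1.0, 0.0], [1.0, 1.0]]
```

Casting to a float dtype up front is what keeps `BCEWithLogitsLoss` from rejecting integer targets.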

In [6]:
from transformers import AutoTokenizer
import numpy as np

model_id = "distilbert/distilbert-base-uncased"

tokenizer = AutoTokenizer.from_pretrained(model_id)

def preprocess_data(examples):
    # take a batch of texts
    text = examples["clean_content"]
    # encode them, padding/truncating every text to 128 tokens
    encoding = tokenizer(text, padding="max_length", truncation=True, max_length=128)
    # collect the label columns of this batch
    labels_batch = {k: examples[k] for k in examples.keys() if k in labels}
    # create numpy array of shape (batch_size, num_labels)
    labels_matrix = np.zeros((len(text), len(labels)))
    # fill numpy array, one column per label
    for idx, label in enumerate(labels):
        labels_matrix[:, idx] = labels_batch[label]
    # store the labels as floats so they become float tensors later on
    encoding["labels"] = labels_matrix.tolist()

    return encoding

Let's apply this preprocessing function to our dataset using the `map` method.

In [7]:
encoded_dataset = dataset.map(preprocess_data, batched=True, remove_columns=dataset['train'].column_names)

In [8]:
example = encoded_dataset['train'][0]
print(example.keys())

dict_keys(['input_ids', 'attention_mask', 'labels'])


In [9]:
tokenizer.decode(example['input_ids'])

'[CLS] it is a process that puts a heavier burden of work on the auditees because they must be prepared to work closely with the internal audit function to share knowledge. that includes details of the ai models they have deployed, which can be complex and increases the level of specialist knowledge needed within the internal audit team. one roundtable attendee said he had brought in these skills initially from an external supplier. however, as the organisation digitalised it had decided to create a dedicated team of experts inhouse that internal audit could buy days digital disruption and new technology from through an internal exchange system. while that worked well, finding subject matter experts in [SEP]'
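The decoded text stops mid-sentence because of `truncation=True` with `max_length=128`. The following is a simplified sketch of what `padding="max_length"` plus truncation does to an already-tokenized sequence, assuming the uncased BERT vocabulary where `[CLS]`, `[SEP]` and `[PAD]` have IDs 101, 102 and 0. `truncate_and_pad` is a hypothetical helper; the real tokenizer additionally performs WordPiece splitting and supports other truncation strategies.

```python
def truncate_and_pad(token_ids, max_length=128, cls_id=101, sep_id=102, pad_id=0):
    """Hypothetical sketch of padding="max_length", truncation=True for one sequence."""
    # Reserve two positions for the [CLS] and [SEP] special tokens
    body = token_ids[: max_length - 2]
    input_ids = [cls_id] + body + [sep_id]
    attention_mask = [1] * len(input_ids)
    # Pad up to max_length; padded positions get attention_mask 0
    n_pad = max_length - len(input_ids)
    input_ids = input_ids + [pad_id] * n_pad
    attention_mask = attention_mask + [0] * n_pad
    return input_ids, attention_mask


short_ids, short_mask = truncate_and_pad([2009, 2003], max_length=8)
print(short_ids)   # [101, 2009, 2003, 102, 0, 0, 0, 0]
print(short_mask)  # [1, 1, 1, 1, 0, 0, 0, 0]

long_ids, long_mask = truncate_and_pad(list(range(1000, 1300)))
print(len(long_ids), long_ids[-1], sum(long_mask))  # 128 102 128
```

So a long passage keeps at most 126 word-piece tokens of content, which is why only the beginning of each document reaches the model.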

In [10]:
example['labels']

[0.0, 0.0]

In [11]:
[id2label[idx] for idx, label in enumerate(example['labels']) if label == 1.0]

[]

Finally, we set the format of our data to PyTorch tensors. This makes the training and validation sets return PyTorch tensors, so they can be used like standard PyTorch [datasets](https://pytorch.org/docs/stable/data.html).

In [12]:
encoded_dataset.set_format("torch")

## Define model

Here we define a model consisting of a pre-trained base (the weights of `distilbert/distilbert-base-uncased`; `bert-base-uncased` or other checkpoints would work the same way) with a randomly initialized classification head on top. The head is fine-tuned together with the pre-trained base on our labeled dataset.

The warning printed when loading the model confirms this: the head weights (`pre_classifier` and `classifier`) are newly initialized.

We set the `problem_type` to be "multi_label_classification", as this will make sure the appropriate loss function is used (namely [`BCEWithLogitsLoss`](https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html)). We also make sure the output layer has `len(labels)` output neurons, and we set the id2label and label2id mappings.
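Why a per-label sigmoid (as used with `BCEWithLogitsLoss`) rather than a softmax? A small numeric sketch with made-up logits: sigmoids score each label independently, so several labels can be "on" at once, whereas a softmax forces the labels to share one unit of probability.

```python
import numpy as np

logits = np.array([2.0, 1.0])  # made-up scores for two labels

# Multi-label: one independent sigmoid per label
sigmoid = 1.0 / (1.0 + np.exp(-logits))
# Single-label: softmax makes the labels compete with each other
softmax = np.exp(logits) / np.exp(logits).sum()

print(sigmoid.round(3), sigmoid.sum().round(3))  # both labels above 0.5, sum > 1
print(softmax.round(3), softmax.sum().round(3))  # probabilities sum to 1
```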

In [13]:
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(
   model_id,
   problem_type="multi_label_classification",
   num_labels=len(labels),
   id2label=id2label,
   label2id=label2id,
)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert/distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## Train the model!

We are going to train the model using Hugging Face's Trainer API. This requires us to define 2 things:

* `TrainingArguments`, which specify the training hyperparameters. All options can be found in the [docs](https://huggingface.co/transformers/main_classes/trainer.html#trainingarguments). Below, for example, we evaluate and save a checkpoint after every epoch, and we set the learning rate, the training/evaluation batch size, the number of epochs, and so on.
* a `Trainer` object (docs can be found [here](https://huggingface.co/transformers/main_classes/trainer.html#id1)).
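A quick sanity check on the schedule these arguments produce, assuming a single device and no gradient accumulation: the number of optimizer steps is the number of batches per epoch times the number of epochs.

```python
import math

num_train_examples = 1008  # rows in the training split
batch_size = 4             # per_device_train_batch_size (single device assumed)
num_train_epochs = 20

steps_per_epoch = math.ceil(num_train_examples / batch_size)
total_steps = steps_per_epoch * num_train_epochs
print(steps_per_epoch, total_steps)  # 252 5040
```

This matches `global_step=5040` in the `TrainOutput` after training. With the default `logging_steps=500`, the first training loss is logged at step 500 (during epoch 2), which is why epoch 1 shows "No log" in the training table.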

In [14]:
batch_size = 4
metric_name = "f1"

In [15]:
from transformers import TrainingArguments, Trainer

args = TrainingArguments(
    "bert-finetuned-sem_eval-english",
    # newer transformers releases rename this argument to `eval_strategy`
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=1e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=20,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
    #push_to_hub=True,
)

We are also going to compute metrics while training. For this, we need to define a `compute_metrics` function that returns a dictionary with the desired metric values.

I also added per-label F1, ROC AUC and classification reports to get better visibility into how well the model is learning each label.
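One subtlety: scikit-learn's `accuracy_score` on multi-label indicator arrays computes *subset* (exact-match) accuracy, so an example only counts as correct if every one of its labels is right. This is why the overall `accuracy` logged below is lower than the per-label accuracies in the classification reports. A small sketch on made-up predictions, using plain NumPy:

```python
import numpy as np

# Made-up ground truth and thresholded predictions: 4 examples, 2 labels
y_true = np.array([[1, 0], [0, 1], [1, 1], [0, 0]])
y_pred = np.array([[1, 0], [0, 0], [1, 1], [0, 1]])

# Subset (exact-match) accuracy: a row counts only if all labels match
subset_accuracy = float(np.mean(np.all(y_true == y_pred, axis=1)))
# Per-label accuracy: each label column scored on its own
per_label_accuracy = np.mean(y_true == y_pred, axis=0)

print(subset_accuracy)              # 0.5
print(per_label_accuracy.tolist())  # [1.0, 0.5]
```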

In [16]:
from sklearn.metrics import f1_score, roc_auc_score, accuracy_score, classification_report
from transformers import EvalPrediction
import torch

# source: https://jesusleal.io/2021/04/21/Longformer-multilabel-classification/
def multi_label_metrics(predictions, labels, threshold=0.5):
    # first, apply sigmoid on predictions which are of shape (batch_size, num_labels)
    sigmoid = torch.nn.Sigmoid()
    probs = sigmoid(torch.Tensor(predictions))
    # next, use threshold to turn them into integer predictions
    y_pred = np.zeros(probs.shape)
    y_pred[np.where(probs >= threshold)] = 1
    # finally, compute metrics
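    y_true = labels
    f1_micro_average = f1_score(y_true=y_true, y_pred=y_pred, average='micro')
    # note: ROC AUC is computed on the thresholded 0/1 predictions, not on `probs`,
    # so it reflects a single operating point rather than the full ROC curve
    roc_auc = roc_auc_score(y_true, y_pred, average='micro')
    # on multi-label arrays, accuracy_score is subset (exact-match) accuracy
    accuracy = accuracy_score(y_true, y_pred)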
    # return as dictionary
    metrics = {'f1': f1_micro_average,
               'roc_auc': roc_auc,
               'accuracy': accuracy}

    # Compute metrics for each label
    for i in range(y_true.shape[1]):  # Assuming labels are in the columns
        metrics[f'f1_label_{id2label[i]}'] = f1_score(y_true[:, i], y_pred[:, i], average='binary', zero_division=0)
        metrics[f'roc_auc_label_{id2label[i]}'] = roc_auc_score(y_true[:, i], y_pred[:, i])
        report = classification_report(y_true[:, i], y_pred[:, i], output_dict=True, zero_division=0)
        metrics[f'classification_report_label_{id2label[i]}'] = report

    return metrics

def compute_metrics(p: EvalPrediction):
    preds = p.predictions[0] if isinstance(p.predictions,
            tuple) else p.predictions
    result = multi_label_metrics(
        predictions=preds,
        labels=p.label_ids)
    return result

Let's verify a batch as well as a forward pass:

In [17]:
encoded_dataset['train'][0]['labels'].type()

'torch.FloatTensor'

In [18]:
encoded_dataset['train'][0]['input_ids']

tensor([  101,  2009,  2003,  1037,  2832,  2008,  8509,  1037, 11907, 10859,
         1997,  2147,  2006,  1996, 15727, 10285,  2138,  2027,  2442,  2022,
         4810,  2000,  2147,  4876,  2007,  1996,  4722, 15727,  3853,  2000,
         3745,  3716,  1012,  2008,  2950,  4751,  1997,  1996,  9932,  4275,
         2027,  2031,  7333,  1010,  2029,  2064,  2022,  3375,  1998,  7457,
         1996,  2504,  1997,  8325,  3716,  2734,  2306,  1996,  4722, 15727,
         2136,  1012,  2028,  2461, 10880,  5463,  4402,  2056,  2002,  2018,
         2716,  1999,  2122,  4813,  3322,  2013,  2019,  6327, 17024,  1012,
         2174,  1010,  2004,  1996,  5502,  3617,  5084,  2009,  2018,  2787,
         2000,  3443,  1037,  4056,  2136,  1997,  8519,  1999,  4580,  2008,
         4722, 15727,  2071,  4965,  2420,  3617, 20461,  1998,  2047,  2974,
         2013,  2083,  2019,  4722,  3863,  2291,  1012,  2096,  2008,  2499,
         2092,  1010,  4531,  3395,  3043,  8519,  1999,   102])

In [19]:
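# forward pass on the first training example (unsqueeze(0) adds a batch dimension)
example = encoded_dataset['train'][0]
outputs = model(input_ids=example['input_ids'].unsqueeze(0),
                attention_mask=example['attention_mask'].unsqueeze(0),
                labels=example['labels'].unsqueeze(0))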
outputs

SequenceClassifierOutput(loss=tensor(0.6676, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>), logits=tensor([[-0.1347,  0.0277]], grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)
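As a sanity check, the loss above can be reproduced by hand from the printed logits and the example's labels `[0.0, 0.0]`. This sketch uses the numerically stable form of binary cross-entropy with logits and averages over every (example, label) pair, which is the default `reduction="mean"` of `BCEWithLogitsLoss`. The logits are the rounded values from the output above, so the result matches to about 4 decimals.

```python
import numpy as np

logits = np.array([[-0.1347, 0.0277]])  # logits printed above (rounded)
targets = np.array([[0.0, 0.0]])        # labels of the first training example


def bce_with_logits(x, y):
    # Stable form of -[y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))]
    per_element = np.maximum(x, 0) - x * y + np.log1p(np.exp(-np.abs(x)))
    # reduction="mean": average over all (example, label) pairs
    return float(per_element.mean())


print(round(bce_with_logits(logits, targets), 4))  # 0.6676
```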

Let's start training!

In [20]:
trainer = Trainer(
    model,
    args,
    train_dataset=encoded_dataset["train"],
    eval_dataset=encoded_dataset["valid"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)



In [21]:
trainer.train()

Validation metrics per epoch (a per-label classification report was also logged at each epoch):

| Epoch | Training Loss | Validation Loss | F1 (micro) | ROC AUC (micro) | Accuracy | F1 cyber_label | ROC AUC cyber_label | F1 environmental_issue | ROC AUC environmental_issue |
|---|---|---|---|---|---|---|---|---|---|
| 1 | No log | 0.303823 | 0.292683 | 0.585807 | 0.781746 | 0.0 | 0.5 | 0.369231 | 0.612885 |
| 2 | 0.341200 | 0.249917 | 0.59322 | 0.737531 | 0.821429 | 0.533333 | 0.724656 | 0.613636 | 0.737115 |
| 3 | 0.341200 | 0.246388 | 0.576 | 0.737881 | 0.805556 | 0.5 | 0.743429 | 0.606742 | 0.734615 |
| 4 | 0.188200 | 0.307806 | 0.666667 | 0.846477 | 0.801587 | 0.5 | 0.720401 | 0.707692 | 0.862308 |
| 5 | 0.188200 | 0.327929 | 0.666667 | 0.811844 | 0.825397 | 0.5 | 0.697372 | 0.707965 | 0.832115 |
| 6 | 0.111700 | 0.302297 | 0.666667 | 0.801949 | 0.833333 | 0.482759 | 0.695244 | 0.716981 | 0.825385 |
| 7 | 0.111700 | 0.347805 | 0.622222 | 0.776762 | 0.809524 | 0.594595 | 0.80438 | 0.632653 | 0.760577 |
| 8 | 0.062200 | 0.441145 | 0.662722 | 0.855222 | 0.785714 | 0.606061 | 0.781352 | 0.676471 | 0.847308 |
| 9 | 0.062200 | 0.379679 | 0.666667 | 0.811844 | 0.821429 | 0.533333 | 0.724656 | 0.702703 | 0.825 |
| 10 | 0.039900 | 0.383758 | 0.680272 | 0.830135 | 0.825397 | 0.588235 | 0.779224 | 0.707965 | 0.832115 |


TrainOutput(global_step=5040, training_loss=0.08106707925834353, metrics={'train_runtime': 415.7248, 'train_samples_per_second': 48.494, 'train_steps_per_second': 12.123, 'total_flos': 667635689226240.0, 'train_loss': 0.08106707925834353, 'epoch': 20.0})
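With `load_best_model_at_end=True` and `metric_for_best_model="f1"`, the Trainer keeps the checkpoint with the highest validation F1 (since the metric name does not end in "loss", `greater_is_better` defaults to `True`). A sketch of that selection, restricted to the ten epochs that appear in the training log above:

```python
# Validation F1 per epoch, copied from the training log above (epochs 1-10)
eval_f1 = {1: 0.292683, 2: 0.59322, 3: 0.576, 4: 0.666667, 5: 0.666667,
           6: 0.666667, 7: 0.622222, 8: 0.662722, 9: 0.666667, 10: 0.680272}

# metric_for_best_model="f1" with greater_is_better=True -> pick the maximum
best_epoch = max(eval_f1, key=eval_f1.get)
print(best_epoch, eval_f1[best_epoch])  # 10 0.680272
```

The scores returned by `trainer.evaluate()` below (`eval_f1` 0.6803, `eval_loss` 0.3838) are identical to the epoch-10 row.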

## Evaluate

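After training, we evaluate our model on the validation set. Because `load_best_model_at_end=True`, the trainer has already reloaded the checkpoint with the best validation F1 (`metric_for_best_model="f1"`), so these scores correspond to that checkpoint.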

In [22]:
scores = trainer.evaluate()

In [27]:
scores

{'eval_loss': 0.38375845551490784,
 'eval_f1': 0.6802721088435374,
 'eval_roc_auc': 0.8301349325337332,
 'eval_accuracy': 0.8253968253968254,
 'eval_f1_label_cyber_label': 0.5882352941176471,
 'eval_roc_auc_label_cyber_label': 0.7792240300375469,
 'eval_classification_report_label_cyber_label': {'0.0': {'precision': 0.9702127659574468,
   'recall': 0.9702127659574468,
   'f1-score': 0.9702127659574468,
   'support': 235.0},
  '1.0': {'precision': 0.5882352941176471,
   'recall': 0.5882352941176471,
   'f1-score': 0.5882352941176471,
   'support': 17.0},
  'accuracy': 0.9444444444444444,
  'macro avg': {'precision': 0.7792240300375469,
   'recall': 0.7792240300375469,
   'f1-score': 0.7792240300375469,
   'support': 252.0},
  'weighted avg': {'precision': 0.9444444444444444,
   'recall': 0.9444444444444444,
   'f1-score': 0.9444444444444444,
   'support': 252.0}},
 'eval_f1_label_environmental_issue': 0.7079646017699115,
 'eval_roc_auc_label_environmental_issue': 0.8321153846153846,
 'e
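The overall `eval_f1` is a *micro* average: true positives, false positives and false negatives are pooled over both labels before computing a single F1, so the more frequent `environmental_issue` label weighs more. The sketch below reproduces it from per-label counts derived from the validation reports: for `cyber_label`, support 17 with precision and recall 10/17 gives TP=10, FP=7, FN=7; for `environmental_issue`, support 52 with recall 40/52 and precision 40/61 gives TP=40, FP=21, FN=12. The macro average is shown only for comparison.

```python
# Per-label counts on the validation set, derived from support, precision and recall
counts = {
    'cyber_label': {'tp': 10, 'fp': 7, 'fn': 7},
    'environmental_issue': {'tp': 40, 'fp': 21, 'fn': 12},
}


def f1(tp, fp, fn):
    return 2 * tp / (2 * tp + fp + fn)


# Micro average: pool the counts over labels first, then compute F1 once
tp = sum(c['tp'] for c in counts.values())
fp = sum(c['fp'] for c in counts.values())
fn = sum(c['fn'] for c in counts.values())
micro_f1 = f1(tp, fp, fn)

# Macro average (for comparison): per-label F1, then an unweighted mean
per_label_f1 = {name: f1(**c) for name, c in counts.items()}
macro_f1 = sum(per_label_f1.values()) / len(per_label_f1)

print(round(micro_f1, 6))                                 # 0.680272
print({k: round(v, 6) for k, v in per_label_f1.items()})  # {'cyber_label': 0.588235, 'environmental_issue': 0.707965}
print(round(macro_f1, 4))                                 # 0.6481
```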

## Inference

Let's test the model on a new sentence:

In [24]:
text = "I'm happy I can finally train a model for multi-label classification"

encoding = tokenizer(text, return_tensors="pt")
encoding = {k: v.to(trainer.model.device) for k, v in encoding.items()}

# no gradients are needed for inference
with torch.no_grad():
    outputs = trainer.model(**encoding)

The logits that come out of the model have shape (batch_size, num_labels). As we only pass a single sentence through the model, `batch_size` equals 1. The logits are the unnormalized scores for each individual label; a sigmoid turns each one into an independent probability.

In [None]:
logits = outputs.logits
# apply sigmoid + threshold
sigmoid = torch.nn.Sigmoid()
probs = sigmoid(logits.squeeze().cpu())
predictions = np.zeros(probs.shape)
predictions[np.where(probs >= 0.5)] = 1
# turn predicted id's into actual label names
predicted_labels = [id2label[idx] for idx, label in enumerate(predictions) if label == 1.0]
print(predicted_labels)
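Since the sigmoid is monotonic, thresholding probabilities at 0.5 is the same as thresholding logits at 0. A common variation for imbalanced labels such as `cyber_label` is a separate threshold per label, tuned on the validation set. The sketch below wraps the steps above in a hypothetical `predict_labels` helper that accepts either a scalar or a per-label threshold; the sample logits are simply the untrained model's output from the earlier forward pass, used as input values, and the thresholds `[0.45, 0.5]` are arbitrary.

```python
import numpy as np

id2label = {0: 'cyber_label', 1: 'environmental_issue'}


def predict_labels(logits, id2label, threshold=0.5):
    """Hypothetical helper: logits of one example -> list of predicted label names."""
    probs = 1.0 / (1.0 + np.exp(-np.asarray(logits, dtype=np.float64)))
    # threshold may be a scalar or one value per label
    thresholds = np.broadcast_to(np.asarray(threshold, dtype=np.float64), probs.shape)
    return [id2label[int(idx)] for idx in np.flatnonzero(probs >= thresholds)]


sample_logits = [-0.1347, 0.0277]
print(predict_labels(sample_logits, id2label))                        # ['environmental_issue']
print(predict_labels(sample_logits, id2label, threshold=[0.45, 0.5]))  # ['cyber_label', 'environmental_issue']
```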