# Fine-Tuning BERT for Text Classification


## 1. Introduction
BERT (Bidirectional Encoder Representations from Transformers) is a machine learning model based on transformers, i.e. stacks of attention layers that learn contextual relations between words.

The Natural Language Processing (NLP) community can leverage powerful tools like BERT in (at least) two ways:

1. Feature-based approach:
- Download a pre-trained BERT model.
- Use BERT to turn natural language sentences into a vector representation.
- Feed the pre-trained vector representations into a model for a downstream task (such as text classification).
2. Fine-tuning:
- Download a pre-trained BERT model.
- Update the model weights on the downstream task.

In this post, we follow the fine-tuning approach on a binary text classification example. We share code snippets that can be easily copied and executed on Google Colab.


## 2. Environment setup
Although it is not essential, the training procedure benefits from a GPU. In Colab, we can enable one by selecting Runtime > Change runtime type and choosing GPU as the hardware accelerator.

Then, we install the Hugging Face⁴ transformers library as follows:

In [82]:
!pip install transformers torch scikit-learn pandas tqdm tabulate datasets evaluate

We import the needed dependencies:

In [2]:
import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from transformers import BertTokenizer, BertForSequenceClassification
from sklearn.model_selection import train_test_split

import pandas as pd
import numpy as np

from tabulate import tabulate
from tqdm import trange
import random

## 3. Dataset
We use the public SMS Spam Collection Data Set⁵ from the UCI Machine Learning Repository⁶. The data consists of a text file with a set of SMS messages labeled as either spam or ham. From the Colab notebook:

Download the dataset as a zip folder:

In [3]:
!mkdir datasets
!wget 'https://archive.ics.uci.edu/ml/machine-learning-databases/00228/smsspamcollection.zip' -O './datasets/smsspamcollection.zip'

mkdir: data: File exists
--2022-09-24 02:54:42--  https://archive.ics.uci.edu/ml/machine-learning-databases/00228/smsspamcollection.zip
Resolving archive.ics.uci.edu (archive.ics.uci.edu)... 128.195.10.252
Connecting to archive.ics.uci.edu (archive.ics.uci.edu)|128.195.10.252|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 203415 (199K) [application/x-httpd-php]
Saving to: './data/smsspamcollection.zip'

2022-09-24 02:54:43 (339 KB/s) - './data/smsspamcollection.zip' saved [203415/203415]



Unpack the folder:

In [4]:
!unzip -o ./datasets/smsspamcollection.zip -d data

Archive:  ./data/smsspamcollection.zip
  inflating: data/SMSSpamCollection  
  inflating: data/readme             


Inspect the first rows of the data file:

In [5]:
!head -10 ./data/SMSSpamCollection

ham	Go until jurong point, crazy.. Available only in bugis n great world la e buffet... Cine there got amore wat...
ham	Ok lar... Joking wif u oni...
spam	Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's
ham	U dun say so early hor... U c already then say...
ham	Nah I don't think he goes to usf, he lives around here though
spam	FreeMsg Hey there darling it's been 3 week's now and no word back! I'd like some fun you up for it still? Tb ok! XxX std chgs to send, £1.50 to rcv
ham	Even my brother is not like to speak with me. They treat me like aids patent.
ham	As per your request 'Melle Melle (Oru Minnaminunginte Nurungu Vettam)' has been set as your callertune for all Callers. Press *9 to copy your friends Callertune
spam	WINNER!! As a valued network customer you have been selected to receivea £900 prize reward! To claim call 09061701461. Claim code KL341. Valid 12 hours only

In [6]:
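# The file has no header row, so we name the columns explicitly;
# otherwise the first message would be consumed as the header.
df = pd.read_csv('./data/SMSSpamCollection', sep='\t', header=None, names=['label', 'text'])
# Encode the labels: spam -> 1, ham -> 0
df['label'] = df['label'].apply(lambda x: 1 if x == 'spam' else 0)

In [7]:
df.head()

| | label | text |
|---|---|---|
| 0 | 0 | Go until jurong point, crazy.. Available only ... |
| 1 | 0 | Ok lar... Joking wif u oni... |
| 2 | 1 | Free entry in 2 a wkly comp to win FA Cup fina... |
| 3 | 0 | U dun say so early hor... U c already then say... |
| 4 | 0 | Nah I don't think he goes to usf, he lives aro... |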


In [34]:
import json
spam_datasets = json.loads(df.to_json(orient='records'))

We extract text and label values:

In [8]:
text = df.text.values
labels = df.label.values

## 4. Preprocessing
We need to preprocess the text source before feeding it to BERT. To do so, we download the BertTokenizer:

In [9]:
tokenizer = BertTokenizer.from_pretrained(
    'bert-base-uncased',
    do_lower_case = True
    )

Let us observe how the tokenizer splits a random sentence into WordPiece (subword) tokens and maps them to their respective IDs in the BERT vocabulary:

In [10]:
def print_rand_sentence():
    '''Displays the tokens and respective IDs of a random text sample'''
    index = random.randint(0, len(text)-1)
    table = np.array([tokenizer.tokenize(text[index]), 
                    tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text[index]))]).T
    print(tabulate(table,
                 headers = ['Tokens', 'Token IDs'],
                 tablefmt = 'fancy_grid'))

print_rand_sentence()

╒═══════════╤═════════════╕
│ Tokens    │   Token IDs │
╞═══════════╪═════════════╡
│ not       │        2025 │
├───────────┼─────────────┤
│ yet       │        2664 │
├───────────┼─────────────┤
│ .         │        1012 │
├───────────┼─────────────┤
│ just      │        2074 │
├───────────┼─────────────┤
│ i         │        1045 │
├───────────┼─────────────┤
│ '         │        1005 │
├───────────┼─────────────┤
│ d         │        1040 │
├───────────┼─────────────┤
│ like      │        2066 │
├───────────┼─────────────┤
│ to        │        2000 │
├───────────┼─────────────┤
│ keep      │        2562 │
├───────────┼─────────────┤
│ in        │        1999 │
├───────────┼─────────────┤
│ touch     │        3543 │
├───────────┼─────────────┤
│ and       │        1998 │
├───────────┼─────────────┤
│ it        │        2009 │
├───────────┼─────────────┤
│ will      │        2097 │
├───────────┼─────────────┤
│ be        │        2022 │
├───────────┼─────────────┤
│ the       │       

BERT requires the following preprocessing steps:

1. Add special tokens:
- [CLS]: at the beginning of each sentence (ID 101)
- [SEP]: at the end of each sentence (ID 102)
2. Make sentences of the same length:
- This is achieved by padding, i.e. adding filler values to shorter sequences so that they match the desired length. Longer sequences are truncated.
- The padding ([PAD]) tokens have ID 0.
- The maximum sequence length allowed is 512 tokens¹.
3. Create an attention mask:
- A list of 0/1 values indicating whether the model should attend to each token when learning its contextual representation. [PAD] tokens get the value 0.

The process can be represented as follows:

![](img/1_vaw98m1VVncgKxNFWI0d2Q.png)
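
The three steps above can be reproduced by hand on a list of token IDs. The following is a minimal pure-Python sketch, not the tokenizer's actual implementation; the helper name `encode_manually` is hypothetical, and the special-token IDs are the ones quoted in the list above.

```python
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0  # special-token IDs quoted in the list above


def encode_manually(token_ids, max_length=32):
    """Add [CLS]/[SEP], truncate, right-pad, and build the attention mask."""
    # Reserve two positions for [CLS] and [SEP]; longer inputs are truncated.
    body = list(token_ids)[: max_length - 2]
    input_ids = [CLS_ID] + body + [SEP_ID]
    # Real tokens get mask value 1; the padding added below gets 0.
    attention_mask = [1] * len(input_ids)
    n_pad = max_length - len(input_ids)
    input_ids += [PAD_ID] * n_pad
    attention_mask += [0] * n_pad
    return input_ids, attention_mask


# Token IDs of the sample inspected later in this section (special tokens removed).
sample = [2130, 2026, 2567, 2003, 2025, 2066, 2000, 3713, 2007, 2033, 1012,
          2027, 7438, 2033, 2066, 8387, 7353, 1012]
ids, mask = encode_manually(sample)
print(ids)
print(mask)
```

Padding is added on the right, so the real tokens keep the positions they would have in an unpadded sequence.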

We can perform all the needed steps by using the tokenizer.encode_plus⁷ method. When called, it returns a `transformers.tokenization_utils_base.BatchEncoding` object with the following fields:

- `input_ids`: list of token IDs.
- `token_type_ids`: list of token type IDs.
- `attention_mask`: list of 0/1 values indicating which tokens should be considered by the model (`return_attention_mask = True`).

As we choose `max_length = 32`, longer sentences will be truncated, while shorter sentences will be padded with `[PAD]` tokens (ID 0) until they reach the desired length.


Note: the idea of using the tokenizer.encode_plus method (plus the code for it) was borrowed from this post: BERT Fine-Tuning Tutorial with PyTorch⁸ by Chris McCormick and Nick Ryan.


In [11]:
token_id = []
attention_masks = []

def preprocessing(input_text, tokenizer):
  '''
  Returns <class transformers.tokenization_utils_base.BatchEncoding> with the following fields:
    - input_ids: list of token ids
    - token_type_ids: list of token type ids
    - attention_mask: list of indices (0,1) specifying which tokens should be considered by the model (return_attention_mask = True).
  '''
  return tokenizer.encode_plus(
                        input_text,
                        add_special_tokens = True,
                        max_length = 32,
                        padding = 'max_length',   # replaces the deprecated pad_to_max_length = True
                        truncation = True,        # explicitly truncate sequences longer than max_length
                        return_attention_mask = True,
                        return_tensors = 'pt'
                   )


# Encode every message and collect the per-sample tensors
for sample in text:
  encoding_dict = preprocessing(sample, tokenizer)
  token_id.append(encoding_dict['input_ids'])
  attention_masks.append(encoding_dict['attention_mask'])


# Stack the per-sample tensors into tensors of shape (num_samples, max_length)
token_id = torch.cat(token_id, dim = 0)
attention_masks = torch.cat(attention_masks, dim = 0)
labels = torch.tensor(labels)


We can observe the token IDs for a text sample and recognize the presence of the special tokens [CLS] and [SEP], as well as the padding [PAD] up to the desired max_length:

In [12]:
token_id[5]

tensor([ 101, 2130, 2026, 2567, 2003, 2025, 2066, 2000, 3713, 2007, 2033, 1012,
        2027, 7438, 2033, 2066, 8387, 7353, 1012,  102,    0,    0,    0,    0,
           0,    0,    0,    0,    0,    0,    0,    0])

![](img/1_I--QXIaxEu9kJT2_UQQK5w.png)

We can also verify the output of `tokenizer.encode_plus` by inspecting tokens, their IDs and the attention mask for random text samples as follows:

In [13]:
def print_rand_sentence_encoding():
  '''Displays tokens, token IDs and attention mask of a random text sample'''
  index = random.randint(0, len(text) - 1)
  # Map the stored IDs straight back to tokens so the three columns always align
  tokens = tokenizer.convert_ids_to_tokens(token_id[index].tolist())
  token_ids = token_id[index].tolist()
  attention = attention_masks[index].tolist()

  table = np.array([tokens, token_ids, attention]).T
  print(tabulate(table, 
                 headers = ['Tokens', 'Token IDs', 'Attention Mask'],
                 tablefmt = 'fancy_grid'))

print_rand_sentence_encoding()

╒══════════╤═════════════╤══════════════════╕
│ Tokens   │   Token IDs │   Attention Mask │
╞══════════╪═════════════╪══════════════════╡
│ [CLS]    │         101 │                1 │
├──────────┼─────────────┼──────────────────┤
│ enjoy    │        5959 │                1 │
├──────────┼─────────────┼──────────────────┤
│ ur       │       24471 │                1 │
├──────────┼─────────────┼──────────────────┤
│ ##sel    │       11246 │                1 │
├──────────┼─────────────┼──────────────────┤
│ ##f      │        2546 │                1 │
├──────────┼─────────────┼──────────────────┤
│ t        │        1056 │                1 │
├──────────┼─────────────┼──────────────────┤
│ ##m      │        2213 │                1 │
├──────────┼─────────────┼──────────────────┤
│ ##r      │        2099 │                1 │
├──────────┼─────────────┼──────────────────┤
│ .        │        1012 │                1 │
├──────────┼─────────────┼──────────────────┤
│ .        │        1012 │        
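
The `##` prefix in the tokens above marks a WordPiece continuation piece: a word missing from the vocabulary is split greedily into the longest pieces that are in it. The sketch below illustrates that greedy longest-match-first rule on a toy vocabulary. The vocabulary and the function name `wordpiece` are assumptions made for illustration; the real BERT vocabulary has 30522 entries.

```python
def wordpiece(word, vocab, unk="[UNK]"):
    """Greedy longest-match-first split of a single lowercase word."""
    pieces, start = [], 0
    while start < len(word):
        end, match = len(word), None
        # Try the longest remaining substring first, then shrink it.
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # continuation pieces carry the prefix
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            return [unk]  # no piece fits: the whole word becomes unknown
        pieces.append(match)
        start = end
    return pieces


# Toy vocabulary (an assumption), chosen to cover the pieces shown in the table.
toy_vocab = {"enjoy", "ur", "##sel", "##f", "t", "##m", "##r"}
print(wordpiece("urself", toy_vocab))  # ['ur', '##sel', '##f']
print(wordpiece("tmr", toy_vocab))     # ['t', '##m', '##r']
print(wordpiece("xyz", toy_vocab))     # ['[UNK]']
```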

Note: BERT is a model with absolute position embeddings, so it is usually advised to pad the inputs on the right (end of the sequence) rather than the left (beginning of the sequence). In our case, tokenizer.encode_plus takes care of the needed preprocessing.

## 5. Data split
We split the dataset into train (85%) and test (15%) sets with the Hugging Face datasets library (`test_size=0.15` below). The Trainer used in the next section wraps each split in a torch.utils.data.DataLoader, which provides an iterable over the given dataset.

More information on DataLoader can be found here:

- PyTorch Tutorials: Datasets & DataLoaders⁹
- DataLoader Documentation¹⁰

In [38]:
val_ratio = 0.2
# Recommended batch size: 16, 32. See: https://arxiv.org/pdf/1810.04805.pdf
batch_size = 16

In [80]:
def preprocess_function(examples):
    return tokenizer(examples["text"], truncation=True)

In [132]:
from datasets import Dataset

dataset = Dataset.from_pandas(df)
dataset = dataset.train_test_split(test_size=0.15, seed=228)

In [133]:
tokenized_imdb = dataset.map(preprocess_function, batched=True)

In [134]:
from transformers import DataCollatorWithPadding

In [135]:
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
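
Because `preprocess_function` only truncates, the tokenized examples keep their natural lengths, and `DataCollatorWithPadding` pads each batch up to its longest sequence when the batch is assembled. The following pure-Python sketch shows that idea only; `pad_batch` is a hypothetical helper, not the library's implementation.

```python
def pad_batch(batch_input_ids, pad_id=0):
    """Right-pad every sequence in a batch to the length of the longest one."""
    longest = max(len(ids) for ids in batch_input_ids)
    input_ids, attention_mask = [], []
    for ids in batch_input_ids:
        n_pad = longest - len(ids)
        input_ids.append(list(ids) + [pad_id] * n_pad)
        # 1 marks a real token, 0 marks padding.
        attention_mask.append([1] * len(ids) + [0] * n_pad)
    return {"input_ids": input_ids, "attention_mask": attention_mask}


batch = pad_batch([[101, 2130, 102], [101, 2130, 2026, 2567, 102]])
print(batch["input_ids"])
print(batch["attention_mask"])
```

Padding per batch rather than to a fixed `max_length` avoids spending computation on padding tokens when a batch contains only short messages.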

## 6. Train
It is time for the fine-tuning task.

Select hyperparameters based on the recommendations from the BERT paper¹:

"The optimal hyperparameter values are task-specific, but we found the following range of possible values to work well across all tasks:"

- Batch size: 16, 32
- Learning rate (Adam): 5e-5, 3e-5, 2e-5
- Number of epochs: 2, 3, 4

Define some functions to assess validation metrics (accuracy, precision, recall and specificity) during the training process:
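
One possible shape for such helpers is sketched below in pure Python, treating spam (label 1) as the positive class. The function name `binary_metrics` and its return keys are illustrative names, not taken from the original code.

```python
def binary_metrics(preds, labels):
    """Accuracy, precision, recall and specificity for 0/1 labels (1 = spam)."""
    pairs = list(zip(preds, labels))
    tp = sum(1 for p, y in pairs if p == 1 and y == 1)  # spam predicted as spam
    tn = sum(1 for p, y in pairs if p == 0 and y == 0)  # ham predicted as ham
    fp = sum(1 for p, y in pairs if p == 1 and y == 0)  # ham predicted as spam
    fn = sum(1 for p, y in pairs if p == 0 and y == 1)  # spam predicted as ham

    def ratio(num, den):
        # A ratio with an empty denominator is undefined and reported as NaN.
        return num / den if den else float("nan")

    return {
        "accuracy": ratio(tp + tn, len(pairs)),
        "precision": ratio(tp, tp + fp),
        "recall": ratio(tp, tp + fn),
        "specificity": ratio(tn, tn + fp),
    }


print(binary_metrics([1, 0, 1, 0, 0, 1], [1, 0, 0, 0, 1, 1]))
```

On an imbalanced dataset such as this one, recall and precision on the spam class are more informative than accuracy alone.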


Download transformers.BertForSequenceClassification¹¹, which is a BERT model with a linear layer for sentence classification (or regression) on top of the pooled output:

In [136]:
# Load the BertForSequenceClassification model
model = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased',
    num_labels = 2
)

loading configuration file config.json from cache at /Users/asmazaev/.cache/huggingface/hub/models--bert-base-uncased/snapshots/bdb420bf56ef3f72ee07cd75ab6df1b765b6012a/config.json
Model config BertConfig {
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.22.1",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

loading weights file pytorch_model.bin from cache at /Users/asmazaev/.cache/huggingface/hub/models--bert-base-uncased/snapshots/bdb420bf56ef3f72ee07cd75ab6df1b765b6012a/pytorch_model.bin
Some weight

In [149]:
import evaluate
# datasets.load_metric is deprecated; the evaluate library provides the same metric
metric = evaluate.load('f1')

Note: it is preferable to run this notebook on a GPU. The Trainer used below places the model on a GPU automatically when one is available and otherwise runs on the CPU, so no explicit model.cuda() call is needed.

Perform the training procedure:

In [140]:
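from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="./results",
    # 2e-5 lies in the range recommended by the BERT paper; the run logged
    # below used 2e-3, which is well above that range.
    learning_rate=2e-5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    num_train_epochs=1,
    weight_decay=0.01,
)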

PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).


In [141]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_imdb["train"],
    eval_dataset=tokenized_imdb["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
)

In [142]:
trainer.train()

***** Running training *****
  Num examples = 4735
  Num Epochs = 1
  Instantaneous batch size per device = 32
  Total train batch size (w. parallel, distributed & accumulation) = 32
  Gradient Accumulation steps = 1
  Total optimization steps = 148
The following columns in the training set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: text. If text are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.



Training completed. Do not forget to share your model on huggingface.co/models =)




TrainOutput(global_step=148, training_loss=0.5019146687275654, metrics={'train_runtime': 482.7773, 'train_samples_per_second': 9.808, 'train_steps_per_second': 0.307, 'total_flos': 182372447467860.0, 'train_loss': 0.5019146687275654, 'epoch': 1.0})
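
The step count and throughput figures in this log can be checked by hand from the number of examples, the batch size and the runtime. The sketch below does that arithmetic; `optimization_steps` is a hypothetical helper written for this check, and all inputs are copied from the log above.

```python
import math


def optimization_steps(num_examples, batch_size, num_epochs=1, grad_accum_steps=1):
    """Number of optimizer updates for a run; the last batch may be partial."""
    batches_per_epoch = math.ceil(num_examples / batch_size)
    updates_per_epoch = max(batches_per_epoch // grad_accum_steps, 1)
    return updates_per_epoch * num_epochs


steps = optimization_steps(num_examples=4735, batch_size=32)
runtime = 482.7773  # 'train_runtime' in seconds, from the log above

print(steps)                     # 148
print(round(steps / runtime, 3))  # steps per second
print(round(4735 / runtime, 3))   # samples per second
```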

In [158]:
from evaluate import evaluator
import evaluate
# Named task_evaluator to avoid shadowing Python's built-in eval()
task_evaluator = evaluator("text-classification")

In [170]:
metric = evaluate.combine(["accuracy", "f1"])

In [171]:
results = task_evaluator.compute(model_or_pipeline=model, data=tokenized_imdb["test"], tokenizer=tokenizer, metric=metric,
                                 label_mapping={"LABEL_0": 0, "LABEL_1": 1})

In [172]:
results

{'accuracy': 0.8624401913875598,
 'f1': 0.0,
 'total_time_in_seconds': 42.448227750000115,
 'samples_per_second': 19.694579592901796,
 'latency_in_seconds': 0.050775392045454684}
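
An F1 score of 0.0 next to an accuracy of about 0.86 deserves a closer look: F1 is zero only when no spam message is correctly identified. One reading, offered as an assumption rather than a verified diagnosis, is that the model predicted ham for every test message. The sketch below shows that a test set of 721 ham and 115 spam messages (a composition inferred from the reported accuracy, not stated in the source) yields exactly these two numbers under an always-ham classifier.

```python
def accuracy_and_f1(preds, labels):
    """Accuracy and F1 for the positive class (1 = spam)."""
    tp = sum(1 for p, y in zip(preds, labels) if p == 1 and y == 1)
    fp = sum(1 for p, y in zip(preds, labels) if p == 1 and y == 0)
    fn = sum(1 for p, y in zip(preds, labels) if p == 0 and y == 1)
    correct = sum(1 for p, y in zip(preds, labels) if p == y)
    accuracy = correct / len(labels)
    # F1 = 2TP / (2TP + FP + FN); defined as 0.0 here when the denominator is empty.
    denom = 2 * tp + fp + fn
    f1 = 2 * tp / denom if denom else 0.0
    return accuracy, f1


# Assumed test-set composition (inferred, not stated in the source).
labels = [0] * 721 + [1] * 115
preds = [0] * len(labels)  # a classifier that always answers "ham"
accuracy, f1 = accuracy_and_f1(preds, labels)
print(accuracy, f1)
```

If this reading is right, the run did not learn to detect spam. A learning rate far above the 2e-5 to 5e-5 range listed in Section 6 is one plausible cause, though that is not verified here.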

## 7. Predict
After a training procedure, it is good practice to assess the model's performance on a test set. For the purpose of this example, we simply predict the class (ham vs. spam) of a new text sample:
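
The last step of such a prediction, turning the two output logits into a class name and a probability, can be sketched in pure Python as follows. The logit values are made up for illustration, and `classify` and `ID2LABEL` are hypothetical names; the label encoding matches the one used in Section 3.

```python
import math

ID2LABEL = {0: "ham", 1: "spam"}  # same encoding as the label column in Section 3


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    shifted = [x - max(logits) for x in logits]  # subtract the max for numerical stability
    exps = [math.exp(x) for x in shifted]
    total = sum(exps)
    return [e / total for e in exps]


def classify(logits):
    """Return the most likely label and its probability."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return ID2LABEL[best], probs[best]


# In the notebook, the logits would come from the fine-tuned model, for example
# (assumed usage, not run here; inputs must be on the same device as the model):
#   inputs = tokenizer(new_text, return_tensors="pt", truncation=True)
#   logits = model(**inputs).logits[0].tolist()
print(classify([2.0, -1.0]))   # made-up logits favouring class 0 (ham)
print(classify([-0.5, 1.5]))   # made-up logits favouring class 1 (spam)
```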