*Copyright (c) Microsoft Corporation. All rights reserved.*  
*Licensed under the MIT License.*

# Named Entity Recognition Using BERT

## Before You Start

The running times shown in this notebook were measured on a Standard_NC24s_v3 Azure Deep Learning Virtual Machine with 4 NVIDIA Tesla V100 GPUs. To run through the notebook quickly, set the **`QUICK_RUN`** flag in the cell below to **`True`**; this runs the notebook on a small subset of the data with fewer epochs. 
The table below lists reference running times for different machine configurations.  

|QUICK_RUN|Machine configuration|Running time|
|:---------|:----------------------|:------------|
|True|4 **CPU**s, 14GB memory|~2 minutes|
|False|4 **CPU**s, 14GB memory|~1.5 hours|
|True|1 NVIDIA Tesla K80 GPU, 12GB GPU memory| |
|False|1 NVIDIA Tesla K80 GPU, 12GB GPU memory| |

If you run into a CUDA out-of-memory error, try reducing `BATCH_SIZE` and `MAX_SEQ_LENGTH`, but note that this may degrade model performance. 

In [1]:
from datetime import datetime
startTime = datetime.now()

In [2]:
## Set QUICK_RUN = True to run the notebook on a small subset of data and a smaller number of epochs.
QUICK_RUN = False

TRAIN_DATA_USED_PERCENT = 1
TEST_DATA_USED_PERCENT = 1
NUM_TRAIN_EPOCHS = 5

In [3]:
if QUICK_RUN:
    TRAIN_DATA_USED_PERCENT = 0.1
    TEST_DATA_USED_PERCENT = 0.1
    NUM_TRAIN_EPOCHS = 1

import torch
if torch.cuda.is_available():
    BATCH_SIZE = 16
else:
    BATCH_SIZE = 8

## Summary
This notebook demonstrates how to fine-tune a [pretrained BERT model](https://github.com/huggingface/pytorch-pretrained-BERT) for the named entity recognition (NER) task. Utility functions and classes from the NLP Best Practices repo are used for data preprocessing, model training, model scoring, and model evaluation. 

[BERT (Bidirectional Encoder Representations from Transformers)](https://arxiv.org/pdf/1810.04805.pdf) is a powerful pretrained language model that can be used for many NLP tasks, including text classification, question answering, and named entity recognition. It can achieve state-of-the-art performance with only a few epochs of fine-tuning on task-specific datasets.  
The figure below illustrates how BERT can be fine-tuned for NER. The input is a list of tokens representing a sentence. In the training data, each token has an entity label. After fine-tuning, the model predicts an entity label for each token in a given test sentence. 

<img src="https://nlpbp.blob.core.windows.net/images/bert_architecture.png">
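As a rough illustration of the per-token prediction step in the figure, the sketch below assumes BERT has already produced one hidden vector per token, and applies a single shared linear layer followed by an argmax to get one label per token. The function `classify_tokens`, the toy 2-dimensional vectors, and the weights are all hypothetical; in the real model the hidden vectors come from BERT's final layer (768-dimensional for BERT-base) and the layer's weights are learned during fine-tuning.

```python
from typing import List

def classify_tokens(hidden_states: List[List[float]],
                    weights: List[List[float]],
                    bias: List[float]) -> List[int]:
    """Score every token with one shared linear layer and return the
    index of the highest-scoring label for each token (hypothetical sketch)."""
    predictions = []
    for h in hidden_states:
        # One logit per label k: dot(weights[k], h) + bias[k]
        logits = [sum(w_i * h_i for w_i, h_i in zip(w_k, h)) + b_k
                  for w_k, b_k in zip(weights, bias)]
        # Argmax over labels gives the predicted label id for this token
        predictions.append(max(range(len(logits)), key=logits.__getitem__))
    return predictions

# Toy example: 3 tokens, 2-dimensional "hidden states", 2 labels (O, I-PER).
labels = ["O", "I-PER"]
hidden = [[0.9, 0.1], [0.2, 0.8], [0.6, 0.3]]
pred_ids = classify_tokens(hidden, weights=[[1.0, 0.0], [0.0, 1.0]], bias=[0.0, 0.0])
print([labels[i] for i in pred_ids])  # ['O', 'I-PER', 'O']
```

Because the same layer is applied independently at every position, the output always has exactly one label per input token, which is what makes this setup a token classifier.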

In [4]:
import sys
import os
import random
from seqeval.metrics import classification_report

nlp_path = os.path.abspath('../../')
if nlp_path not in sys.path:
    sys.path.insert(0, nlp_path)

from utils_nlp.models.bert.token_classification import BERTTokenClassifier, create_label_map, postprocess_token_labels
from utils_nlp.models.bert.common import Language, Tokenizer
from utils_nlp.dataset.wikigold import load_train_test_dfs, get_unique_labels
from utils_nlp.common.timer import Timer

## Configurations

In [5]:
# path configuration
CACHE_DIR="./temp"

# set random seeds
RANDOM_SEED = 100
torch.manual_seed(RANDOM_SEED)

# model configurations
LANGUAGE = Language.ENGLISHCASED
DO_LOWER_CASE = False
MAX_SEQ_LENGTH = 200

# optimizer configuration
LEARNING_RATE = 3e-5

# data configurations
TEXT_COL = "sentence"
LABELS_COL = "labels"

## Preprocess Data

### Get training and testing data
The dataset used in this notebook is the [wikigold dataset](https://www.aclweb.org/anthology/W09-3302). The wikigold dataset consists of 145 manually labeled Wikipedia articles, comprising 1,841 sentences and about 40k tokens in total. The dataset can be downloaded directly from [here](https://github.com/juand-r/entity-recognition-datasets/tree/master/data/wikigold). 

The helper function `load_train_test_dfs` downloads the data file if it doesn't exist in `local_cache_path`. It splits the dataset into training and testing sets according to `test_percentage`. Because this is a relatively small dataset, we set `test_percentage` to 0.5 in order to have enough data for model evaluation. Running this notebook multiple times with different random seeds produces similar results.   
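A seeded train/test split of the kind described above can be sketched as follows. `split_train_test` is a hypothetical helper, not the `load_train_test_dfs` implementation; it assumes the split is done at the sentence level, which may not match how the real helper splits the data. It only shows why a fixed `random_seed` makes the split reproducible and how `test_percentage` controls the test share.

```python
import random
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")

def split_train_test(items: Sequence[T], test_percentage: float,
                     random_seed: int) -> Tuple[List[T], List[T]]:
    """Shuffle a copy of `items` with a fixed seed and cut off the test share."""
    shuffled = list(items)
    # A dedicated Random instance keeps the global random state untouched
    random.Random(random_seed).shuffle(shuffled)
    n_test = int(round(len(shuffled) * test_percentage))
    return shuffled[n_test:], shuffled[:n_test]

sentences = [f"sentence {i}" for i in range(10)]
train, test = split_train_test(sentences, test_percentage=0.5, random_seed=100)
print(len(train), len(test))  # 5 5
```

Calling the function again with the same seed returns the same split; changing the seed changes which sentences land in each set.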

The helper function `get_unique_labels` returns the unique entity labels in the dataset. There are 5 unique labels in the original dataset: 'O' (non-entity), 'I-LOC' (location), 'I-MISC' (miscellaneous), 'I-PER' (person), and 'I-ORG' (organization). 

The maximum number of words in a sentence is 144, so we set `MAX_SEQ_LENGTH` to 200 above to leave room for the extra tokens created by WordPiece tokenization.

In [6]:
train_df, test_df = load_train_test_dfs(local_cache_path=CACHE_DIR, test_percentage=0.5,random_seed=RANDOM_SEED)
label_list = get_unique_labels()
print('\nUnique entity labels: \n{}\n'.format(label_list))
print('Sample sentence: \n{}\n'.format(train_df[TEXT_COL][0]))
print('Sample sentence labels: \n{}\n'.format(train_df[LABELS_COL][0]))

Maximum sequence length in the  data is: 144

Unique entity labels: 
['O', 'I-LOC', 'I-MISC', 'I-PER', 'I-ORG']

Sample sentence: 
['Two', ',', 'Samsung', 'based', ',', 'electronic', 'cash', 'registers', 'were', 'reconstructed', 'in', 'order', 'to', 'expand', 'their', 'functions', 'and', 'adapt', 'them', 'for', 'networking', '.']

Sample sentence labels: 
['O', 'O', 'I-ORG', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']



In [7]:
train_df = train_df.sample(frac=TRAIN_DATA_USED_PERCENT).reset_index(drop=True)
test_df = test_df.sample(frac=TEST_DATA_USED_PERCENT).reset_index(drop=True)

**Note that the input texts are lists of words rather than raw sentences. This format keeps the input words aligned with their token labels when the words are further tokenized by `Tokenizer.tokenize_ner`.**

### Tokenization and Preprocessing


**Create a dictionary that maps labels to numerical values**  
Note the argument `trailing_piece_tag`. BERT uses a WordPiece tokenizer, which breaks some words into multiple tokens; e.g., "criticize" is tokenized into "critic" and "##ize". Since the input data come with only one label for "criticize", `Tokenizer.tokenize_ner` assigns the original label to the first token, "critic", and labels the second token, "##ize", as "X". By default, `trailing_piece_tag` is set to "X". If "X" already exists as a label in your data, set `trailing_piece_tag` to a value that does not appear in your data.
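The labeling rule described above can be sketched in plain Python. This is not the `utils_nlp` implementation: `align_labels` and the `TOY_PIECES` lookup table are hypothetical stand-ins for a real WordPiece tokenizer, used only to show how the first piece keeps the word's label, trailing pieces get the `X` tag, and a first-piece mask falls out of the same loop.

```python
from typing import Dict, List, Tuple

# Hypothetical toy splits; a real WordPiece tokenizer derives these from its vocabulary.
TOY_PIECES: Dict[str, List[str]] = {
    "criticize": ["critic", "##ize"],
}

def align_labels(words: List[str], labels: List[str],
                 trailing_piece_tag: str = "X"
                 ) -> Tuple[List[str], List[str], List[bool]]:
    """Split words into pieces and propagate one label per word.

    The first piece of a word keeps the word's label; every trailing piece
    gets `trailing_piece_tag`. The mask is True only for first pieces.
    """
    pieces, piece_labels, first_piece_mask = [], [], []
    for word, label in zip(words, labels):
        word_pieces = TOY_PIECES.get(word, [word])  # unknown words stay whole
        for i, piece in enumerate(word_pieces):
            pieces.append(piece)
            piece_labels.append(label if i == 0 else trailing_piece_tag)
            first_piece_mask.append(i == 0)
    return pieces, piece_labels, first_piece_mask

pieces, piece_labels, mask = align_labels(["They", "criticize", "Samsung"],
                                          ["O", "O", "I-ORG"])
print(pieces)        # ['They', 'critic', '##ize', 'Samsung']
print(piece_labels)  # ['O', 'O', 'X', 'I-ORG']
print(mask)          # [True, True, False, True]
```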

In [8]:
label_map = create_label_map(label_list, trailing_piece_tag="X")
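Conceptually, a label map is a dictionary from label strings to integer ids, with one extra entry for the trailing-piece tag. The helper below, `build_label_map`, is a hypothetical sketch; the actual `create_label_map` may order the ids differently. Since the classifier is later created with `num_labels=len(label_map)`, the trailing tag counts as one of the labels the model predicts.

```python
from typing import Dict, List

def build_label_map(labels: List[str], trailing_piece_tag: str = "X") -> Dict[str, int]:
    """Give each label a consecutive id, plus one id for the trailing-piece tag."""
    if trailing_piece_tag in labels:
        # The tag must not collide with a real entity label
        raise ValueError(f"{trailing_piece_tag!r} is already used as a label")
    label_map = {label: i for i, label in enumerate(labels)}
    label_map[trailing_piece_tag] = len(labels)
    return label_map

label_map = build_label_map(["O", "I-LOC", "I-MISC", "I-PER", "I-ORG"])
print(label_map)  # {'O': 0, 'I-LOC': 1, 'I-MISC': 2, 'I-PER': 3, 'I-ORG': 4, 'X': 5}
```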

**Create a tokenizer**

In [9]:
tokenizer = Tokenizer(language=LANGUAGE, 
                      to_lower=DO_LOWER_CASE, 
                      cache_dir=CACHE_DIR)

**Tokenize and preprocess text**  
The `tokenize_ner` method of the `Tokenizer` class converts the text and string labels into numerical features in the following steps:
1. WordPiece tokenization.
2. Convert tokens and labels to numerical values, i.e. token ids and label ids.
3. Sequence padding or truncation according to the `max_seq_length` configuration.
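Steps 2 and 3 can be sketched for a single sentence that has already been split into word pieces. The vocabulary and its ids are hypothetical, special tokens such as `[CLS]`/`[SEP]` are ignored, and padding with id 0 is assumed (the sample output further below pads with zeros). The attention mask is built alongside the ids so that padded positions are marked with 0.

```python
from typing import Dict, List, Tuple

def to_padded_ids(pieces: List[str], vocab: Dict[str, int], max_len: int,
                  pad_id: int = 0) -> Tuple[List[int], List[int]]:
    """Map word pieces to ids, then truncate or right-pad to `max_len`.

    Returns the token ids and an attention mask (1 = real token, 0 = padding).
    """
    ids = [vocab[p] for p in pieces][:max_len]  # step 2, then truncation
    mask = [1] * len(ids)
    padding = max_len - len(ids)                # step 3: pad up to max_len
    return ids + [pad_id] * padding, mask + [0] * padding

toy_vocab = {"They": 10, "critic": 11, "##ize": 12, "Samsung": 13}  # hypothetical ids
ids, mask = to_padded_ids(["They", "critic", "##ize", "Samsung"], toy_vocab, max_len=6)
print(ids)   # [10, 11, 12, 13, 0, 0]
print(mask)  # [1, 1, 1, 1, 0, 0]
```

With `max_len=3` the same call truncates to `[10, 11, 12]`, which is why `MAX_SEQ_LENGTH` should leave headroom for WordPiece expansion.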

In [10]:
train_token_ids, train_input_mask, train_trailing_token_mask, train_label_ids = \
    tokenizer.tokenize_ner(text=train_df[TEXT_COL],
                           label_map=label_map,
                           max_len=MAX_SEQ_LENGTH,
                           labels=train_df[LABELS_COL],
                           trailing_piece_tag="X")
test_token_ids, test_input_mask, test_trailing_token_mask, test_label_ids = \
    tokenizer.tokenize_ner(text=test_df[TEXT_COL],
                           label_map=label_map,
                           max_len=MAX_SEQ_LENGTH,
                           labels=test_df[LABELS_COL],
                           trailing_piece_tag="X")

`Tokenizer.tokenize_ner` returns three or four lists of feature lists, where each sublist contains the features of one input sentence: 
1. token ids: list of numerical values, each corresponding to a token.
2. attention mask: list of 1s and 0s, 1 for input tokens and 0 for padded tokens, so that padded tokens are not attended to. 
3. trailing word piece mask: boolean list, `True` for the first word piece of each original word and `False` for the trailing word pieces, e.g. ##ize. This mask is useful for removing predictions on trailing word pieces, so that each original word in the input text has a unique predicted label. 
4. label ids: list of numerical values, each corresponding to an entity label; returned only if `labels` is provided.

In [11]:
print("Sample token ids:\n{}\n".format(train_token_ids[0]))
print("Sample attention mask:\n{}\n".format(train_input_mask[0]))
print("Sample trailing token mask:\n{}\n".format(train_trailing_token_mask[0]))
print("Sample label ids:\n{}\n".format(train_label_ids[0]))

Sample token ids:
[107, 1124, 1674, 183, 112, 189, 1541, 1474, 1155, 1115, 1277, 1133, 3093, 1106, 1243, 1103, 2261, 1694, 1268, 119, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0

## Create Token Classifier
The value of the `language` argument determines which BERT model is used:
* Language.ENGLISH: "bert-base-uncased"
* Language.ENGLISHCASED: "bert-base-cased"
* Language.ENGLISHLARGE: "bert-large-uncased"
* Language.ENGLISHLARGECASED: "bert-large-cased"
* Language.CHINESE: "bert-base-chinese"
* Language.MULTILINGUAL: "bert-base-multilingual-cased"

Here we use the base, cased pretrained model.

In [12]:
token_classifier = BERTTokenClassifier(language=LANGUAGE,
                                       num_labels=len(label_map),
                                       cache_dir=CACHE_DIR)

## Train Model

In [13]:
with Timer() as t:
    token_classifier.fit(token_ids=train_token_ids, 
                         input_mask=train_input_mask, 
                         labels=train_label_ids,
                         num_epochs=NUM_TRAIN_EPOCHS, 
                         batch_size=BATCH_SIZE, 
                         learning_rate=LEARNING_RATE)
print("Training time : {:.3f} hrs".format(t.interval / 3600))

t_total value of -1 results in schedule not being applied
Train loss: 0.3545071521677591



Train loss: 0.10110071701286681



Train loss: 0.05429209667611225



Train loss: 0.03206114985193286



Train loss: 0.023472348440855998
Training time : 1.376 hrs





## Predict on Test Data

In [14]:
with Timer() as t:
    pred_label_ids = token_classifier.predict(token_ids=test_token_ids, 
                                              input_mask=test_input_mask, 
                                              labels=test_label_ids, 
                                              batch_size=BATCH_SIZE)
print("Prediction time : {:.3f} hrs".format(t.interval / 3600))

Iteration: 100%|██████████| 115/115 [03:23<00:00,  1.77s/it]

Evaluation loss: 0.11518098939939038
Prediction time : 0.056 hrs





## Evaluate Model
The `predict` method of the token classifier outputs label ids for all tokens, including the padded tokens. `postprocess_token_labels` is a helper function that removes the predictions on padded tokens. If a `label_map` is provided, it also maps the numerical label ids back to the original token labels, which are usually strings. 
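The padding removal and id-to-label mapping described above can be sketched as follows. `strip_padding_and_decode` is a hypothetical helper, not the library's `postprocess_token_labels`; it assumes the label map is a dictionary from label string to id and simply inverts it.

```python
from typing import Dict, List

def strip_padding_and_decode(label_ids: List[List[int]],
                             input_mask: List[List[int]],
                             label_map: Dict[str, int]) -> List[List[str]]:
    """Drop positions where the attention mask is 0 and map ids back to labels."""
    id_to_label = {i: label for label, i in label_map.items()}  # invert the map
    return [[id_to_label[label_id] for label_id, m in zip(ids, mask) if m == 1]
            for ids, mask in zip(label_ids, input_mask)]

label_map = {"O": 0, "I-LOC": 1, "I-MISC": 2, "I-PER": 3, "I-ORG": 4, "X": 5}
pred_ids = [[0, 0, 5, 4, 0, 0]]    # one sentence, padded to length 6
input_mask = [[1, 1, 1, 1, 0, 0]]  # last two positions are padding
print(strip_padding_and_decode(pred_ids, input_mask, label_map))
# [['O', 'O', 'X', 'I-ORG']]
```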

In [15]:
pred_tags_no_padding = postprocess_token_labels(pred_label_ids, 
                                                test_input_mask, 
                                                label_map)
true_tags_no_padding = postprocess_token_labels(test_label_ids, 
                                                test_input_mask, 
                                                label_map)
print(classification_report(true_tags_no_padding, pred_tags_no_padding, digits=2))

           precision    recall  f1-score   support

        X       0.96      0.97      0.96      1983
      ORG       0.73      0.71      0.72       538
      LOC       0.85      0.90      0.87       543
      PER       0.95      0.93      0.94       550
     MISC       0.61      0.84      0.71       396

micro avg       0.87      0.91      0.89      4010
macro avg       0.88      0.91      0.89      4010



`postprocess_token_labels` also provides an option to remove the predictions on trailing word pieces, e.g. ##ize, so that the final predicted labels correspond to the original words in the input text. The `trailing_token_mask` is obtained from `tokenizer.tokenize_ner`.
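Removing trailing word pieces amounts to filtering each sentence's tags with its first-piece mask. The helper `drop_trailing_pieces` below is a hypothetical sketch, not the library implementation; it assumes tags whose padding was already removed and a mask that may still cover the padded positions.

```python
from typing import List

def drop_trailing_pieces(tags: List[List[str]],
                         trailing_token_mask: List[List[bool]]) -> List[List[str]]:
    """Keep only tags at positions whose mask entry is True (first pieces).

    `zip` stops at the shorter sequence, so a mask that still covers padded
    positions can be applied to tags that already had padding removed.
    """
    return [[tag for tag, is_first in zip(sent_tags, sent_mask) if is_first]
            for sent_tags, sent_mask in zip(tags, trailing_token_mask)]

tags = [["O", "O", "X", "I-ORG"]]                 # padding already removed
mask = [[True, True, False, True, False, False]]  # hypothetical mask, padded to 6
print(drop_trailing_pieces(tags, mask))  # [['O', 'O', 'I-ORG']]
```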

In [16]:
pred_tags_no_padding_no_trailing = postprocess_token_labels(pred_label_ids, 
                                                            test_input_mask, 
                                                            label_map, 
                                                            remove_trailing_word_pieces=True, 
                                                            trailing_token_mask=test_trailing_token_mask)
true_tags_no_padding_no_trailing = postprocess_token_labels(test_label_ids, 
                                                            test_input_mask, 
                                                            label_map, 
                                                            remove_trailing_word_pieces=True, 
                                                            trailing_token_mask=test_trailing_token_mask)
print(classification_report(true_tags_no_padding_no_trailing, pred_tags_no_padding_no_trailing, digits=2))

           precision    recall  f1-score   support

      ORG       0.68      0.70      0.69       442
      LOC       0.85      0.90      0.87       503
      PER       0.94      0.92      0.93       455
     MISC       0.60      0.82      0.69       344

micro avg       0.75      0.84      0.79      1744
macro avg       0.78      0.84      0.81      1744



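We can see that the metrics are worse after excluding trailing word pieces. The trailing pieces are all labeled "X" and are easy to predict (the X row in the first report has an F1-score of 0.96), so including them inflates the averages. 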
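The effect can be illustrated with a simple token-level accuracy on made-up tags. This is only a sketch: the reports above use seqeval's entity-level precision, recall, and F1, not token accuracy, and the tags below are hypothetical. Still, it shows how easy-to-predict "X" positions raise a score computed over all word pieces.

```python
from typing import List

def token_accuracy(true_tags: List[str], pred_tags: List[str]) -> float:
    """Fraction of positions where the predicted tag equals the true tag."""
    correct = sum(t == p for t, p in zip(true_tags, pred_tags))
    return correct / len(true_tags)

# Hypothetical tags for three words, two of which were split into word pieces.
true_all = ["I-ORG", "X", "X", "O", "I-PER", "X"]
pred_all = ["I-MISC", "X", "X", "O", "I-ORG", "X"]
first_piece = [True, False, False, True, True, False]

# Keep only the first piece of each word, as in the second report.
true_words = [t for t, keep in zip(true_all, first_piece) if keep]
pred_words = [p for p, keep in zip(pred_all, first_piece) if keep]

print(token_accuracy(true_all, pred_all))      # 4/6, about 0.67
print(token_accuracy(true_words, pred_words))  # 1/3, about 0.33
```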

## Conclusion
By fine-tuning the pretrained BERT model for token classification, we achieved significantly better results than those reported in the [original paper on the wikigold dataset](https://www.aclweb.org/anthology/W09-3302), using a much smaller training set. 

In [17]:
print(datetime.now() - startTime)

1:26:10.362535
