# Fine Tuning BERT For Named Entity Recognition On United Nations Documents

Humans understand the world by putting labels on things and examining how these labels relate to each other. In natural language processing and information retrieval, this idea is reflected in a technique called Named Entity Recognition (NER). The objective is to detect the entity type of segments of text in a document. These entities could be organizations, locations, persons, or others.

In this blog post, I will go through an example of training a named entity recognition model on a specific domain. Instead of creating an NER model from scratch, I will use transfer learning: I take a pre-trained language model, BERT, which was trained on a large number of general examples, and fine-tune that neural network on a very specific domain.

Alongside the tutorial on learning an NER model, I will run this project on Layer in order to make use of their metadata store for storing and tracking the datasets and model artifacts as well as their free GPU compute instances. 

Firstly, let's define the problem. We are working with a set of documents from the United Nations (UN). Diplomatic jargon is the norm at the UN, and these documents contain many specific entities that we don't encounter in everyday language, such as the Office for the Coordination of Humanitarian Affairs of the Secretariat and the Office of the United Nations High Commissioner for Refugees. We would like to automatically detect these entities along with their corresponding types. With the entities flagged, we can power many interesting use cases such as information retrieval, question answering, and document similarity.

The dataset is generously made available to the public by Leslie Huang. It consists of transcribed speeches given at the UN General Assembly from 1993-2016, which were scraped from the UN website, parsed (e.g. from PDF), and cleaned. More than 50,000 tokens were manually annotated for NER tags.
https://github.com/leslie-huang/UN-named-entity-recognition

## Installing/Importing Libraries

Let's start by creating a project on Layer so that the project is reproducible and its datasets and model artifacts are logged along with their parameters for future reference. Layer helps you build, train and track all your machine learning project metadata, including ML models and datasets, with semantic versioning. It also allows you to use their cloud infrastructure free of charge, including access to GPUs. We will work with a pretrained transformer-based language model, so the added processing power is very welcome.

We will start by importing the necessary libraries. We then log in to Layer, initialize our ML project called "united-nations-ner-finetuning", and define the hyperparameters used throughout the post.

In [58]:
from collections import Counter
from torch.utils.data import Dataset, DataLoader
import torch

import layer
from layer.decorators import dataset, model, pip_requirements, fabric, resources

layer.login()
layer.init("united-nations-ner-finetuning")

TRAIN_EXAMPLES_RATIO = 0.8
MAX_LEN = 128
TRAIN_BATCH_SIZE = 4
VALID_BATCH_SIZE = 2
EPOCHS = 1
LEARNING_RATE = 1e-05
MAX_GRAD_NORM = 10
DEVICE = "cpu"


After setting up the ML metadata store, we will now clone the GitHub repository that hosts the dataset files.

In [None]:
!git clone https://github.com/leslie-huang/UN-named-entity-recognition

## Dataset

At this step, we will load the tagged documents from both the training and test sets and store them in a DataFrame.
The "create_dataset" function in this section uses decorators from Layer to define a dataset artifact that is logged to our cloud project on Layer. By calling "layer.run()", we will run the function "create_dataset" on the cloud infrastructure.
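
To make the loading step more concrete, here is a minimal sketch of the parsing logic on an in-memory sample. It assumes the layout that the "create_dataset" code relies on: one token and one tag per line, separated by a tab, with an empty line between sentences. The sample text is made up for illustration and is not taken from the corpus.

```python
import itertools

# Hypothetical sample in the assumed layout: "token<TAB>tag" per line,
# with an empty line between sentences.
sample = "Tim\tI-PER\nCook\tI-PER\nspoke\tO\n\nGeneva\tI-LOC\nhosted\tO\n"

lines = sample.splitlines(keepends=True)

# Group consecutive non-blank lines into sentences.
sentences = [
    list(group)
    for is_blank, group in itertools.groupby(lines, key=lambda line: line == "\n")
    if not is_blank
]

# Split every line into its token and its tag.
tokens = [[line.rstrip("\n").split("\t")[0] for line in sent] for sent in sentences]
ner_tags = [[line.rstrip("\n").split("\t")[1] for line in sent] for sent in sentences]

print(tokens)
print(ner_tags)
```

Each sentence becomes one row, with a list of tokens and a list of tags of the same length.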

We also log some descriptive text metadata with the raw dataset. This improves the readability and reproducibility of our ML project. Just as code is read more often than it is written, so are ML projects.

Next, we will load the dataset into local memory by fetching it from Layer with the layer.get_dataset() function.

We will then examine the dataset. The annotation follows a specific Named Entity Recognition annotation scheme called IOB tagging, which stands for Inside-Outside-Beginning. The document is tagged at the word level, and an entity sometimes spans a group of words. To mark entities that cover several words, we use the Beginning (B) and Inside (I) tags.
Example: Tim Cook works at Apple.
[Tim, Cook, works, at, Apple] -> [B-PER, I-PER, O, O, B-ORG]
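
As a small illustration of how IOB tags translate back into entities, the sketch below groups tagged tokens into spans. The helper name iob_to_spans is hypothetical and not part of the project code. It also treats an I tag without a preceding B tag as the start of an entity, which is an assumption made so that it works on I-only annotations as well.

```python
def iob_to_spans(tokens, tags):
    """Group IOB-tagged tokens into (entity_text, entity_type) pairs."""
    spans = []
    current_tokens, current_type = [], None
    for token, tag in zip(tokens, tags):
        prefix, _, entity_type = tag.partition("-")
        # A B tag, or an I tag whose type differs from the open span, starts a new entity.
        starts_new = prefix == "B" or (prefix == "I" and entity_type != current_type)
        if tag == "O" or starts_new:
            if current_tokens:
                spans.append((" ".join(current_tokens), current_type))
            current_tokens, current_type = [], None
        if tag != "O":
            current_tokens.append(token)
            current_type = entity_type
    if current_tokens:
        spans.append((" ".join(current_tokens), current_type))
    return spans


tokens = ["Tim", "Cook", "works", "at", "Apple"]
tags = ["B-PER", "I-PER", "O", "O", "B-ORG"]
print(iob_to_spans(tokens, tags))  # [('Tim Cook', 'PER'), ('Apple', 'ORG')]
```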

Our dataset consists of two columns in which each item is a list. In the "tokens" column, we have the words of the document as a list. In the "ner_tags" column, we have the corresponding tags.

We will now create a Counter object from the NER tags. As expected, the most common tag is "O", denoting "Outside", for words that are not part of a named entity. The second most common is the "I-ORG" tag, denoting organisation entities, followed by the location tag.
An interesting finding is that while we have Inside (I) tags, we don't have their Beginning (B) counterparts. We also have a few misspelled tags that occur very rarely.
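
The counting and cleanup idea can be sketched on a toy example. The tag lists below are made up, and "I-OR" stands in for a rare misspelled tag. Note that this sketch uses an allowlist of valid tags (the tag set named in the dataset description), whereas the project code uses a list of tags to remove; both map unwanted tags to "O".

```python
from collections import Counter

# Hypothetical tag lists; "I-OR" stands in for a rare typo tag.
ner_tags = [["I-ORG", "I-ORG", "O"], ["O", "I-LOC", "I-OR"], ["O", "O", "I-PER"]]

tag_counter = Counter(tag for tags in ner_tags for tag in tags)
print(tag_counter.most_common())

# Treat anything outside the assumed set of valid tags as "O".
valid_tags = {"O", "I-PER", "I-ORG", "I-LOC", "I-MISC"}
cleaned = [[tag if tag in valid_tags else "O" for tag in tags] for tags in ner_tags]

clean_counter = Counter(tag for tags in cleaned for tag in tags)
print(clean_counter.most_common())
```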

In [12]:
def clean_tags(tags, tags_to_remove):
    clean_list = []
    for tag in list(tags):
        if tag != "O":
            if tag not in tags_to_remove:
                clean_list.append(tag)
            else:
                clean_list.append("O")
        else:
            clean_list.append("O")
    return clean_list

@dataset("un_ner_dataset")
@resources(path="./UN-named-entity-recognition")
def create_dataset():
    import os
    import itertools
    import pandas as pd
    from collections import Counter
    
    directories = [
        "./UN-named-entity-recognition/tagged-training/",
        "./UN-named-entity-recognition/tagged-test/",
    ]
    data_files = []
    for dir in directories:
        for filename in os.listdir(dir):
            file_path = os.path.join(dir, filename)

            with open(file_path, "r", encoding="utf8") as f:
                lines = f.readlines()
                split_list = [list(y) for x, y in itertools.groupby(lines, lambda z: z == "\n") if not x]
                tokens = [[x.split("\t")[0] for x in y] for y in split_list]
                entities = [[x.split("\t")[1][:-1] for x in y] for y in split_list]
                data_files.append(pd.DataFrame({"tokens": tokens, "ner_tags": entities}))

    dataset = pd.concat(data_files).reset_index().drop("index", axis=1)

    # Cleaning and removing bad tags
    pre_cleanup_tag_counter = Counter([tag for tags in dataset["ner_tags"] for tag in tags])
    tags_to_remove = ["I-PRG", "I-I-MISC", "I-OR", "VMISC", "I-", "0"]
    dataset["ner_tags"] = dataset["ner_tags"].apply(lambda x: clean_tags(x, tags_to_remove))
    tag_counter = Counter([tag for tags in dataset["ner_tags"] for tag in tags])
    dataset_description = """The corpus consists of a sample of transcribed speeches given at the UN General Assembly 
    from 1993-2016, which were scraped from the UN website, parsed (e.g. from PDF), and cleaned. More than 50,000 tokens 
    in the test data were manually tagged for Named Entity Recognition (O - Not a Named Entity; I-PER - Person; I-ORG - 
    Organization; I-LOC - Location; I-MISC - Other Named Entity)."""
    layer.log({"# Examples": len(dataset)})
    layer.log({"Dataset Description": dataset_description})
    layer.log({"Source": "https://github.com/leslie-huang/UN-named-entity-recognition"})
    layer.log({"Raw Tags Counter": pre_cleanup_tag_counter})
    layer.log({"Clean Tags Counter": tag_counter})

    return dataset

ner_dataset = create_dataset()


In [19]:
@pip_requirements(packages=["transformers"])
@fabric("f-medium")
@model(name="bert-base-uncased-tokenizer")
def download_tokenizer():
    from transformers import BertTokenizerFast

    tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
    return tokenizer

tokenizer = download_tokenizer()


In [59]:
class PytorchDataset(Dataset):
    def __init__(self, dataframe, tokenizer, tag_to_id, max_len):
        self.len = len(dataframe)
        self.data = dataframe
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.tag_to_id = tag_to_id

    def __getitem__(self, index):

        label_all_tokens = True
        tokenized_inputs = self.tokenizer(
            [list(self.data.tokens[index])],
            truncation=True,
            is_split_into_words=True,
            max_length=self.max_len,
            padding="max_length",
        )

        labels = []
        for i, label in enumerate([list(self.data.ner_tags[index])]):
            word_ids = tokenized_inputs.word_ids(batch_index=i)
            previous_word_idx = None
            label_ids = []
            for word_idx in word_ids:
                if word_idx is None:
                    label_ids.append(-100)
                elif label[word_idx] == "0":
                    # a stray zero is a typo for the "O" (Outside) tag
                    label_ids.append(self.tag_to_id["O"])
                elif word_idx != previous_word_idx:
                    label_ids.append(self.tag_to_id[label[word_idx]])
                else:
                    label_ids.append(self.tag_to_id[label[word_idx]] if label_all_tokens else -100)
                previous_word_idx = word_idx
            labels.append(label_ids)

        tokenized_inputs["labels"] = labels

        single_tokenized_input = {}
        for k, v in tokenized_inputs.items():
            single_tokenized_input[k] = torch.as_tensor(v[0])

        return single_tokenized_input

    def __len__(self):
        return self.len
    
def create_model_inputs(dataset, tag_to_id):

    train_dataset = dataset.sample(frac=TRAIN_EXAMPLES_RATIO, random_state=200)
    test_dataset = dataset.drop(train_dataset.index).reset_index(drop=True)
    train_dataset = train_dataset.reset_index(drop=True)

    print("FULL Dataset: {}".format(dataset.shape))
    print("TRAIN Dataset: {}".format(train_dataset.shape))
    print("TEST Dataset: {}".format(test_dataset.shape))

    train = PytorchDataset(train_dataset, tokenizer, tag_to_id, MAX_LEN)
    test = PytorchDataset(test_dataset, tokenizer, tag_to_id, MAX_LEN)

    return train, test

tag_counter = Counter([tag for tags in ner_dataset["ner_tags"] for tag in tags])
tag_to_id = {tag: ix for ix, tag in enumerate(tag_counter.keys())}
train_set, test_set = create_model_inputs(ner_dataset, tag_to_id)

FULL Dataset: (5731, 2)
TRAIN Dataset: (4585, 2)
TEST Dataset: (1146, 2)
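
The label alignment inside "__getitem__" can be easier to follow on a tiny example. The sketch below mirrors that logic without a tokenizer: the word_ids list is written by hand as a stand-in for what tokenized_inputs.word_ids() returns, and the tags and the tag_to_id mapping are made up for illustration.

```python
# Hypothetical word-level tags and tag vocabulary.
word_tags = ["I-ORG", "I-ORG", "O"]
tag_to_id = {"O": 0, "I-ORG": 1}

# Hand-written stand-in for tokenized_inputs.word_ids(): None marks special or
# padding tokens, and a repeated index marks word pieces of the same word.
word_ids = [None, 0, 1, 1, 2, None, None]


def align_labels(word_ids, word_tags, tag_to_id, label_all_tokens=True):
    label_ids = []
    previous_word_idx = None
    for word_idx in word_ids:
        if word_idx is None:
            label_ids.append(-100)  # special/padding tokens are ignored by the loss
        elif word_idx != previous_word_idx or label_all_tokens:
            label_ids.append(tag_to_id[word_tags[word_idx]])
        else:
            label_ids.append(-100)  # later word pieces are masked out
        previous_word_idx = word_idx
    return label_ids


print(align_labels(word_ids, word_tags, tag_to_id))
print(align_labels(word_ids, word_tags, tag_to_id, label_all_tokens=False))
```

With label_all_tokens=True every word piece inherits the tag of its word. With False, only the first word piece of each word keeps a label.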


In [60]:
def train(train_set):
    from sklearn.metrics import accuracy_score
    from transformers import BertForTokenClassification
    from torch.utils.data import DataLoader

    train_params = {"batch_size": TRAIN_BATCH_SIZE, "shuffle": True, "num_workers": 0}
    training_loader = DataLoader(train_set, **train_params)

    model = BertForTokenClassification.from_pretrained("bert-base-uncased", num_labels=len(tag_to_id))
    model.to(DEVICE)

    optimizer = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

    for epoch in range(EPOCHS):
        print(f"Training epoch: {epoch + 1}")
        tr_loss, tr_accuracy = 0, 0
        nb_tr_examples, nb_tr_steps = 0, 0
        tr_preds, tr_labels = [], []

        model.train()  # model in training mode

        for idx, batch in enumerate(training_loader):

            ids = batch["input_ids"].to(DEVICE, dtype=torch.long)
            mask = batch["attention_mask"].to(DEVICE, dtype=torch.long)
            labels = batch["labels"].to(DEVICE, dtype=torch.long)

            outputs = model(input_ids=ids, attention_mask=mask, labels=labels)
            loss = outputs[0]
            tr_logits = outputs[1]
            tr_loss += loss.item()

            nb_tr_steps += 1
            nb_tr_examples += labels.size(0)

            if idx % 100 == 0:
                loss_step = tr_loss / nb_tr_steps
                print(f"Training loss per 100 training steps: {loss_step}")

            # compute training accuracy
            flattened_targets = labels.view(-1)
            active_logits = tr_logits.view(-1, model.num_labels)
            flattened_predictions = torch.argmax(active_logits, axis=1)

            # only compute accuracy at active labels
            active_accuracy = labels.view(-1) != -100
            labels = torch.masked_select(flattened_targets, active_accuracy)
            predictions = torch.masked_select(flattened_predictions, active_accuracy)

            tr_labels.extend(labels)
            tr_preds.extend(predictions)

            tmp_tr_accuracy = accuracy_score(labels.cpu().numpy(), predictions.cpu().numpy())
            tr_accuracy += tmp_tr_accuracy

            # backward pass
            optimizer.zero_grad()
            loss.backward()

            # gradient clipping (must run after backward(), once gradients exist)
            torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=MAX_GRAD_NORM)

            optimizer.step()

        epoch_loss = tr_loss / nb_tr_steps
        tr_accuracy = tr_accuracy / nb_tr_steps
        print(f"Training loss epoch: {epoch_loss}")
        print(f"Training accuracy epoch: {tr_accuracy}")

    return model


def evaluate(model, test_set, tag_to_id):
    from sklearn.metrics import classification_report
    from sklearn.metrics import accuracy_score
    
    id_to_tag = {ix: tag for tag, ix in tag_to_id.items()}
    test_params = {"batch_size": VALID_BATCH_SIZE, "shuffle": True, "num_workers": 0}
    testing_loader = DataLoader(test_set, **test_params)

    model.eval()  # model in evaluation mode

    eval_loss, eval_accuracy = 0, 0
    nb_eval_examples, nb_eval_steps = 0, 0
    eval_preds, eval_labels = [], []
    device = DEVICE  # keep evaluation on the same device as training
    with torch.no_grad():
        for idx, batch in enumerate(testing_loader):

            ids = batch["input_ids"].to(device, dtype=torch.long)
            mask = batch["attention_mask"].to(device, dtype=torch.long)
            labels = batch["labels"].to(device, dtype=torch.long)

            outputs = model(input_ids=ids, attention_mask=mask, labels=labels)
            loss = outputs[0]
            eval_logits = outputs[1]

            eval_loss += loss.item()

            nb_eval_steps += 1
            nb_eval_examples += labels.size(0)

            if idx % 100 == 0:
                loss_step = eval_loss / nb_eval_steps
                print(f"Validation loss per 100 evaluation steps: {loss_step}")

            # compute evaluation accuracy
            flattened_targets = labels.view(-1)
            active_logits = eval_logits.view(-1, model.num_labels)
            flattened_predictions = torch.argmax(active_logits, axis=1)

            # only compute accuracy at active labels
            active_accuracy = labels.view(-1) != -100

            labels = torch.masked_select(flattened_targets, active_accuracy)
            predictions = torch.masked_select(flattened_predictions, active_accuracy)

            eval_labels.extend(labels)
            eval_preds.extend(predictions)

            tmp_eval_accuracy = accuracy_score(labels.cpu().numpy(), predictions.cpu().numpy())
            eval_accuracy += tmp_eval_accuracy

    labels = [id_to_tag[id.item()] for id in eval_labels]
    predictions = [id_to_tag[id.item()] for id in eval_preds]

    eval_loss = eval_loss / nb_eval_steps
    eval_accuracy = eval_accuracy / nb_eval_steps
    layer.log({"Test Loss": eval_loss, "Test Accuracy": eval_accuracy})

    print(f"Validation Loss: {eval_loss}")
    print(f"Validation Accuracy: {eval_accuracy}")

    print(classification_report(labels, predictions))
    layer.log(classification_report(labels, predictions, output_dict=True))


@pip_requirements(packages=["transformers", "scikit-learn", "torch"])
@fabric("f-gpu-small")
@model("un_ner_fine-tuned_bert")
def run_model_training():
    model = train(train_set)
    evaluate(model, test_set, tag_to_id)
    return model

model = run_model_training()
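
Both "train" and "evaluate" compute accuracy only at active labels, that is, at positions whose label is not -100. Here is a plain-Python sketch of that masking step. The helper name masked_accuracy and the example values are hypothetical, and the real code performs the same filtering with torch.masked_select.

```python
IGNORE_INDEX = -100  # label value used for special, padding and masked tokens


def masked_accuracy(label_ids, predicted_ids):
    """Accuracy over the positions whose label is not IGNORE_INDEX."""
    pairs = [(y, p) for y, p in zip(label_ids, predicted_ids) if y != IGNORE_INDEX]
    if not pairs:
        return 0.0
    return sum(y == p for y, p in pairs) / len(pairs)


# Made-up label and prediction ids for one short sequence.
labels = [-100, 1, 1, 0, 0, -100, -100]
preds = [0, 1, 0, 0, 0, 0, 0]
print(masked_accuracy(labels, preds))  # 0.75
```

Since "O" is the most common tag, token-level accuracy can look high even when few entities are found, which is one reason to read the per-class classification report as well.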


In [55]:
def predict_ner_example(sentence):
    inputs = tokenizer(
        sentence.split(),
        is_split_into_words=True,
        return_offsets_mapping=True,
        padding="max_length",
        truncation=True,
        max_length=MAX_LEN,
        return_tensors="pt",
    )

    id_to_tag = {ix: tag for tag, ix in tag_to_id.items()}
    
    ids = inputs["input_ids"]
    mask = inputs["attention_mask"]
    # forward pass (no gradients are needed for inference)
    with torch.no_grad():
        outputs = model(ids, attention_mask=mask)
    logits = outputs[0]

    active_logits = logits.view(-1, model.num_labels)  # shape (batch_size * seq_len, num_labels)
    flattened_predictions = torch.argmax(
        active_logits, axis=1
    )  # shape (batch_size*seq_len,) - predictions at the token level

    tokens = tokenizer.convert_ids_to_tokens(ids.squeeze().tolist())
    token_predictions = [id_to_tag[i] for i in flattened_predictions.cpu().numpy()]
    wp_preds = list(zip(tokens, token_predictions))  # list of tuples. Each tuple = (wordpiece, prediction)

    prediction = []
    for token_pred, mapping in zip(wp_preds, inputs["offset_mapping"].squeeze().tolist()):
        # only predictions on first word pieces are important
        if mapping[0] == 0 and mapping[1] != 0:
            prediction.append(token_pred[1])
        else:
            continue
            
    return sentence, prediction

sentence = """Expressing deep concern about the impact of the food security crisis on the
assistance provided by United Nations humanitarian agencies, in particular the World
Food Programme."""

sentence, prediction = predict_ner_example(sentence)
print(sentence.split())
print(prediction)

['Expressing', 'deep', 'concern', 'about', 'the', 'impact', 'of', 'the', 'food', 'security', 'crisis', 'on', 'the', 'assistance', 'provided', 'by', 'United', 'Nations', 'humanitarian', 'agencies,', 'in', 'particular', 'the', 'World', 'Food', 'Programme.']
['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
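
The last step of "predict_ner_example" keeps one prediction per word by looking at the offset mapping. A minimal sketch of that filter follows. It assumes that offsets are reported relative to each word, which is what the condition in the function relies on. The word pieces, offsets and predictions are written by hand for illustration and are not output of the trained model.

```python
# Hand-written example for the words "Food Programme." (not real model output).
wordpieces = ["[CLS]", "food", "programme", ".", "[SEP]", "[PAD]"]
piece_predictions = ["O", "I-ORG", "I-ORG", "O", "O", "O"]
# (start, end) character offsets within each word; special tokens get (0, 0).
offsets = [(0, 0), (0, 4), (0, 9), (9, 10), (0, 0), (0, 0)]

# Keep only first word pieces: they start at offset 0 and are not empty.
word_predictions = [
    pred
    for pred, (start, end) in zip(piece_predictions, offsets)
    if start == 0 and end != 0
]
print(word_predictions)  # ['I-ORG', 'I-ORG']
```

The trailing "." of "Programme." starts at offset 9, so its prediction is dropped and the result has one tag per whitespace-separated word.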
