# Fine-tuning Entity Pair Classification Task with BERT

This example fine-tunes a pre-trained BERT model for entity pair (relation) classification using the GluonNLP API, following the GluonNLP BERT tutorial:

https://gluon-nlp.mxnet.io/examples/sentence_embedding/bert.html




### Importing the necessary modules

In [3]:
import warnings
warnings.filterwarnings('ignore')

import io
import random
import numpy as np
import mxnet as mx
import gluonnlp as nlp
from Bert import data, model  # local package with the BERT helper scripts (assumed: GluonNLP's scripts/bert)

### Setting up the environment

In [4]:
np.random.seed(100)
random.seed(100)
mx.random.seed(10000)
ctx = mx.cpu(0)

## Using the pre-trained BERT model

We use BERT-Base (`bert_12_768_12`: 12 layers, 768 hidden units, 12 attention heads) from the GluonNLP model zoo, pre-trained on an uncased corpus of books and English Wikipedia (`book_corpus_wiki_en_uncased`).
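
Because the checkpoint is uncased, it was pre-trained on lowercased text. Input text must therefore be normalized the same way, which is why the tokenizer later in this notebook is built with `lower=True`. The following is a minimal sketch of that normalization. It assumes the normalization mirrors the lowercase-and-strip-accents step of the reference BERT tokenizer. `normalize_uncased` is a hypothetical helper, not a GluonNLP function:

```python
import unicodedata

def normalize_uncased(text: str) -> str:
    """Lowercase and strip accents (assumed to approximate BERT's uncased preprocessing)."""
    text = text.lower()
    # NFD splits an accented character into its base character plus a combining mark
    decomposed = unicodedata.normalize("NFD", text)
    # Drop combining marks (Unicode category "Mn") and keep the base characters
    return "".join(ch for ch in decomposed if unicodedata.category(ch) != "Mn")

print(normalize_uncased("Café Résumé DATABASE"))  # cafe resume database
```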

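### Get BERT

Load the pre-trained BERT model and its vocabulary. `use_pooler=True` keeps the pooler, which turns the hidden state of the [CLS] token into a fixed-size vector used for classification. `use_decoder=False` and `use_classifier=False` drop the masked-language-model decoder and the next-sentence-prediction classifier, which are not needed for fine-tuning.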

In [5]:
bert_base, vocabulary = nlp.model.get_model('bert_12_768_12',
                                             dataset_name='book_corpus_wiki_en_uncased',
                                             pretrained=True, ctx=ctx, use_pooler=True,
                                             use_decoder=False, use_classifier=False)
print(bert_base)

Vocab file is not found. Downloading.
Downloading C:\Users\CRS\.mxnet\models\1578777863.8742678book_corpus_wiki_en_uncased-a6607397.zip from https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/dataset/vocab/book_corpus_wiki_en_uncased-a6607397.zip...
Downloading C:\Users\CRS\.mxnet\models\bert_12_768_12_book_corpus_wiki_en_uncased-75cc780f.zip from https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/models/bert_12_768_12_book_corpus_wiki_en_uncased-75cc780f.zip...
BERTModel(
  (encoder): BERTEncoder(
    (dropout_layer): Dropout(p = 0.1, axes=())
    (layer_norm): BERTLayerNorm(eps=1e-12, axis=-1, center=True, scale=True, in_channels=768)
    (transformer_cells): HybridSequential(
      (0): BERTEncoderCell(
        (dropout_layer): Dropout(p = 0.1, axes=())
        (attention_cell): MultiHeadAttentionCell(
          (_base_cell): DotProductAttentionCell(
            (_dropout_layer): Dropout(p = 0.1, axes=())
          )
          (proj_query): Dense(768 -> 7

### Adapting the model for entity pair classification

BERT accepts either a single sentence or a pair of sentences as input, and part of its pre-training (next-sentence prediction) was done on sentence pairs. Here, each input pair consists of the context around the first entity and the context around the second entity. Both are treated as sentences, and the model predicts the relation label of the pair.

After the BERT model is loaded, a classification layer is attached to it. The `BERTClassifier` class uses the BERT base model to encode the entity pair representation, followed by an `nn.Dense` layer for classification.
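
Conceptually, the head works on the hidden state of the [CLS] token:

- The BERT pooler (kept via `use_pooler=True`) applies a dense layer with tanh.
- `BERTClassifier` then applies dropout followed by a dense layer with one output per class.

The pure-Python sketch below shows that forward pass with toy sizes. It assumes this pooler-then-dense structure and is an illustration, not the GluonNLP code itself. The helpers `dense` and `classifier_head` are hypothetical:

```python
import math
import random

def dense(x, weight, bias):
    """y = W x + b for a single vector x (weight has one row per output unit)."""
    return [sum(w_i * x_i for w_i, x_i in zip(row, x)) + b for row, b in zip(weight, bias)]

def classifier_head(cls_vector, pooler_w, pooler_b, clf_w, clf_b):
    # Pooler: dense layer with tanh applied to the [CLS] hidden state
    pooled = [math.tanh(v) for v in dense(cls_vector, pooler_w, pooler_b)]
    # Dropout is the identity at inference time, so it is omitted here
    return dense(pooled, clf_w, clf_b)  # one unnormalized score (logit) per class

# Toy sizes: hidden size 4 instead of 768, and 6 relation classes as in this notebook
random.seed(0)
hidden, num_classes = 4, 6
cls_vector = [random.gauss(0.0, 1.0) for _ in range(hidden)]
pooler_w = [[random.gauss(0.0, 0.02) for _ in range(hidden)] for _ in range(hidden)]
pooler_b = [0.0] * hidden
# Classifier weights drawn from Normal(0, 0.02), mirroring mx.init.Normal(0.02); biases start at 0
clf_w = [[random.gauss(0.0, 0.02) for _ in range(hidden)] for _ in range(num_classes)]
clf_b = [0.0] * num_classes

logits = classifier_head(cls_vector, pooler_w, pooler_b, clf_w, clf_b)
print(len(logits))  # 6
```

Only the last dense layer is new. That is why only `bert_classifier.classifier` is initialized below.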

In [6]:
# One output per relation label (six labels)
bert_classifier = model.classification.BERTClassifier(bert_base, num_classes=6, dropout=0.1)
# Only the new classifier layer needs initializing; the BERT weights are pre-trained.
bert_classifier.classifier.initialize(init=mx.init.Normal(0.02), ctx=ctx)
bert_classifier.hybridize(static_alloc=True)

# softmax cross entropy loss for classification
loss_function = mx.gluon.loss.SoftmaxCELoss()
loss_function.hybridize(static_alloc=True)

metric = mx.metric.Accuracy()
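
`SoftmaxCELoss`, with its default integer (sparse) labels, computes the negative log-probability that a softmax over the logits assigns to the true class. `Accuracy` counts how often the arg-max prediction equals the label. Below is a plain-Python sketch of both, for illustration only. `softmax_ce` and `accuracy` are hypothetical helpers:

```python
import math

def softmax_ce(logits, label):
    """-log softmax(logits)[label], using the log-sum-exp trick for numerical stability."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]

def accuracy(batch_logits, labels):
    """Fraction of samples whose highest-scoring class equals the label."""
    correct = sum(
        max(range(len(z)), key=z.__getitem__) == y for z, y in zip(batch_logits, labels)
    )
    return correct / len(labels)

batch_logits = [[2.0, 0.5, 0.1, 0.0, 0.0, 0.0],   # predicts class 0
                [0.1, 0.2, 3.0, 0.0, 0.0, 0.0]]   # predicts class 2
labels = [0, 1]
mean_loss = sum(softmax_ce(z, y) for z, y in zip(batch_logits, labels)) / len(labels)
print(round(accuracy(batch_logits, labels), 2))  # 0.5
print(round(softmax_ce([0.0] * 6, 0), 4))        # 1.7918, i.e. log(6)
```

With six classes, a uniform prediction costs log 6 ≈ 1.79. This is a useful reference point for the losses logged during fine-tuning.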

## Data preprocessing for BERT


### Loading the dataset

1. Load the pickled entity pair data into a pandas DataFrame.
2. Write the DataFrame to a TSV (tab-separated values) file.
3. Read the TSV file back and print a few examples to see what the data looks like before it is fed into the model.
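
Steps 2 and 3 can be sketched with only the standard library:

- Write rows with tab separators.
- Read them back the way the `TSVDataset` call in this notebook is configured: discard one header line, split on tabs, and keep fields 0, 1 and 2.

The rows are taken from this notebook's printed sample. `read_tsv` is a hypothetical helper that approximates, rather than reproduces, `TSVDataset`. Because it splits on every tab, it assumes that no field contains a tab or newline:

```python
import csv
import io

# Rows in the same three-column layout as entity_pair (header first)
records = [
    ("entity_text_1", "entity_text_2", "label"),
    ("database traditional information retrieval techniques use a",
     "histogram of keywords as the", "USAGE"),
    ("a large database of tv", "database of tv shows emotions and", "PART_WHOLE"),
]

# Write tab-separated text (an in-memory buffer stands in for train.tsv)
buffer = io.StringIO()
csv.writer(buffer, delimiter="\t", lineterminator="\n").writerows(records)

def read_tsv(text, num_discard_samples=1, field_indices=(0, 1, 2)):
    """Skip header lines, split each line on tabs, and keep the selected fields."""
    lines = text.splitlines()[num_discard_samples:]
    rows = [line.split("\t") for line in lines if line]
    return [[row[i] for i in field_indices] for row in rows]

samples = read_tsv(buffer.getvalue())
print(samples[0][2])  # USAGE
```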

In [34]:
import pandas as pd
# Load the pickled DataFrame of entity pairs (columns: entity_text_1, entity_text_2, label)
entity_pair = pd.read_pickle('entity_pair.pkl')
entity_pair.head()
type(entity_pair)

pandas.core.frame.DataFrame

In [37]:
# Write a tab-separated file without the DataFrame's row index
entity_pair.to_csv('train.tsv', sep='\t', index=False)

In [38]:
# Print the header and the first four rows of the raw file (the file is closed afterwards)
with io.open('train.tsv', encoding='utf-8') as tsv_file:
    for i in range(5):
        print(tsv_file.readline())

entity_text_1	entity_text_2	label

database traditional information retrieval techniques use a	histogram of keywords as the	USAGE

representation but oral communication may offer	offer additional indices such as	USAGE

a large database of tv	database of tv shows emotions and	PART_WHOLE

of a distributed message-passing infrastructure for dialogue	infrastructure for dialogue systems which all	MODEL-FEATURE



In [39]:
# Skip the first line, column names
num_discard_samples = 1
# Split fields by tabs
field_separator = nlp.data.Splitter('\t')
# Fields to select from the file
field_indices = [0, 1, 2]
train_data = nlp.data.TSVDataset(filename='train.tsv',
                                 field_separator=field_separator,
                                 num_discard_samples=num_discard_samples,
                                 field_indices=field_indices)
sample_id = 0
# First entity context
print(train_data[sample_id][0])
# Second entity context
print(train_data[sample_id][1])
# Relation type (label)
print(train_data[sample_id][2])

database traditional information retrieval techniques use a
histogram of keywords as the
USAGE


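We use `BERTDatasetTransform` to perform the following transformations:
- tokenize the input sequences with the BERT tokenizer
- truncate pairs that are longer than `max_len`
- insert [CLS] at the beginning
- insert [SEP] between the two entity contexts and at the end
- generate segment ids to indicate whether a token belongs to the first or the second sequence
- compute the valid length (the number of tokens before padding)
- pad the sequences to `max_len` and map each label to its integer id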
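
The transformations listed above can be sketched in plain Python. This sketch makes several simplifications and assumptions:

- It splits on whitespace instead of running BERT's WordPiece tokenizer, so its tokens are not the real ones.
- It assumes the longest-first truncation rule of the reference BERT implementation.
- It assumes that padded positions get segment id 0.

`truncate_pair` and `bert_pair_inputs` are hypothetical helpers. Applied to the first training sample with `max_len = 12`, the sketch gives the same valid length and segment-id pattern that the notebook prints for that sample:

```python
def truncate_pair(tokens_a, tokens_b, max_tokens):
    """Longest-first truncation: drop the last token of the longer sequence until both fit."""
    tokens_a, tokens_b = list(tokens_a), list(tokens_b)
    while len(tokens_a) + len(tokens_b) > max_tokens:
        if len(tokens_a) > len(tokens_b):
            tokens_a.pop()
        else:
            tokens_b.pop()
    return tokens_a, tokens_b

def bert_pair_inputs(tokens_a, tokens_b, max_len, pad_token="[PAD]"):
    # Reserve three positions for [CLS] and the two [SEP] tokens
    tokens_a, tokens_b = truncate_pair(tokens_a, tokens_b, max_len - 3)
    tokens = ["[CLS]"] + tokens_a + ["[SEP]"] + tokens_b + ["[SEP]"]
    # Segment 0: [CLS], sequence A and its [SEP]; segment 1: sequence B and the final [SEP]
    segment_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
    valid_length = len(tokens)
    # Pad up to max_len; padded segment ids are assumed to be 0
    num_pad = max_len - valid_length
    return tokens + [pad_token] * num_pad, valid_length, segment_ids + [0] * num_pad

all_labels = ["USAGE", "PART_WHOLE", "MODEL-FEATURE", "RESULT", "COMPARE", "TOPIC"]
label_to_id = {label: i for i, label in enumerate(all_labels)}

a = "database traditional information retrieval techniques use a".split()
b = "histogram of keywords as the".split()
tokens, valid_length, segment_ids = bert_pair_inputs(a, b, max_len=12)
print(tokens)
print(valid_length, segment_ids, label_to_id["USAGE"])
```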

In [41]:
# Use the vocabulary from pre-trained model for tokenization
bert_tokenizer = nlp.data.BERTTokenizer(vocabulary, lower=True)

# Maximum length of the whole input ([CLS] + context 1 + [SEP] + context 2 + [SEP]);
# longer pairs are truncated. Chosen to approximate the average length of the entity contexts.
max_len = 12

# The six relation labels; their order defines the integer class ids
all_labels = ["USAGE", "PART_WHOLE", "MODEL-FEATURE", "RESULT", "COMPARE", "TOPIC"]

pair = True
transform = data.transform.BERTDatasetTransform(bert_tokenizer, max_len,
                                                class_labels=all_labels,
                                                has_label=True,
                                                pad=True,
                                                pair=pair)
data_train = train_data.transform(transform)

In [42]:
print('vocabulary used for tokenization = \n%s'%vocabulary)
print('%s token id = %s'%(vocabulary.padding_token, vocabulary[vocabulary.padding_token]))
print('%s token id = %s'%(vocabulary.cls_token, vocabulary[vocabulary.cls_token]))
print('%s token id = %s'%(vocabulary.sep_token, vocabulary[vocabulary.sep_token]))
print('token ids = \n%s'%data_train[sample_id][0])
print('valid length = \n%s'%data_train[sample_id][1])
print('segment ids = \n%s'%data_train[sample_id][2])
print('label = \n%s'%data_train[sample_id][3])

vocabulary used for tokenization = 
Vocab(size=30522, unk="[UNK]", reserved="['[CLS]', '[SEP]', '[MASK]', '[PAD]']")
[PAD] token id = 1
[CLS] token id = 2
[SEP] token id = 3
token ids = 
[    2  7809  3151  2592 26384  5461     3  2010  3406 13113  1997     3]
valid length = 
12
segment ids = 
[0 0 0 0 0 0 0 1 1 1 1 1]
label = 
[0]
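
For sample 0, the valid length equals `max_len`. The two contexts contain 12 words, which is more than the 12 - 3 = 9 positions left after [CLS] and the two [SEP] tokens, so this pair was truncated. Before settling on `max_len`, it can help to measure how many pairs would be cut. In practice, the lengths would come from `bert_tokenizer`. The counts below are made-up placeholders, and `truncation_rate` is a hypothetical helper:

```python
def truncation_rate(pair_lengths, max_len):
    """Fraction of pairs whose [CLS] A [SEP] B [SEP] form needs more than max_len tokens."""
    too_long = sum(1 for len_a, len_b in pair_lengths if len_a + len_b + 3 > max_len)
    return too_long / len(pair_lengths)

# Placeholder token counts (len_a, len_b) for four entity pairs
pair_lengths = [(7, 5), (5, 5), (3, 4), (6, 7)]
for candidate in (12, 16, 20):
    print(candidate, truncation_rate(pair_lengths, candidate))
```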


## Fine-tuning the model
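
The training loop in this section clips gradients with `nlp.utils.clip_grad_global_norm`. Every gradient is multiplied by one common factor so that their combined L2 norm does not exceed the threshold, which leaves the direction of the update unchanged. Below is a minimal pure-Python sketch of that rule. It assumes the usual min(1, max_norm / total_norm) scaling. `clip_global_norm` is a hypothetical helper, not the GluonNLP function:

```python
import math

def clip_global_norm(grads, max_norm):
    """Scale all gradient arrays by one factor so that their joint L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    scale = min(1.0, max_norm / (total_norm + 1e-12))  # never scale gradients up
    return [[g * scale for g in grad] for grad in grads], total_norm

# Two parameter "arrays" whose joint norm is sqrt(3^2 + 4^2) = 5
grads = [[3.0], [4.0]]
clipped, norm_before = clip_global_norm(grads, max_norm=1.0)
print(norm_before, clipped)
```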

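The loop draws its mini-batches from `FixedBucketSampler`. The sampler groups samples of similar (valid) length into buckets and forms each batch from a single bucket, so little padding is needed when sequences are padded per batch. Here, `BERTDatasetTransform` already pads every sample to `max_len`, so bucketing saves no padding in this notebook. The grouping idea is still easy to see in a simplified sketch. The equal-width buckets and the helper `fixed_bucket_batches` are assumptions for illustration, not GluonNLP's exact scheme:

```python
import random

def fixed_bucket_batches(lengths, batch_size, num_buckets=3, seed=0):
    """Group sample indices into equal-width length buckets, then batch within each bucket."""
    rng = random.Random(seed)
    lo, hi = min(lengths), max(lengths)
    width = max(1, -(-(hi - lo + 1) // num_buckets))  # ceiling division
    buckets = {}
    for idx, length in enumerate(lengths):
        buckets.setdefault((length - lo) // width, []).append(idx)
    batches = []
    for indices in buckets.values():
        rng.shuffle(indices)  # shuffle inside each bucket
        batches.extend(indices[i:i + batch_size] for i in range(0, len(indices), batch_size))
    rng.shuffle(batches)  # shuffle the order of the batches
    return batches

lengths = [12, 12, 5, 6, 12, 7, 11, 5]
batches = fixed_bucket_batches(lengths, batch_size=2)
for batch in batches:
    print([lengths[i] for i in batch])
```

Each printed batch contains only similar lengths. Shuffling happens both inside the buckets and over the order of the batches.
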
In [43]:
# The hyperparameters
batch_size = 32
lr = 5e-6

# The FixedBucketSampler and the DataLoader for making the mini-batches
train_sampler = nlp.data.FixedBucketSampler(lengths=[int(item[1]) for item in data_train],
                                            batch_size=batch_size,
                                            shuffle=True)
bert_dataloader = mx.gluon.data.DataLoader(data_train, batch_sampler=train_sampler)

trainer = mx.gluon.Trainer(bert_classifier.collect_params(), 'adam',
                           {'learning_rate': lr, 'epsilon': 1e-9})

# Collect all differentiable parameters
# `grad_req == 'null'` indicates no gradients are calculated (e.g. constant parameters)
# The gradients for these params are clipped later
params = [p for p in bert_classifier.collect_params().values() if p.grad_req != 'null']
grad_clip = 1  # maximum global L2 norm of the gradients

# Training the model with only three epochs
log_interval = 4
num_epochs = 3
for epoch_id in range(num_epochs):
    metric.reset()
    step_loss = 0
    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(bert_dataloader):
        with mx.autograd.record():

            # Move the data to the device given by ctx (the CPU in this notebook)
            token_ids = token_ids.as_in_context(ctx)
            valid_length = valid_length.as_in_context(ctx)
            segment_ids = segment_ids.as_in_context(ctx)
            label = label.as_in_context(ctx)

            # Forward computation
            out = bert_classifier(token_ids, segment_ids, valid_length.astype('float32'))
            ls = loss_function(out, label).mean()

        # Backward computation
        ls.backward()

        # Gradient clipping
        trainer.allreduce_grads()
        nlp.utils.clip_grad_global_norm(params, grad_clip)
        # Step size 1 because the loss is already averaged over the batch
        trainer.update(1)

        step_loss += ls.asscalar()
        metric.update([label], [out])

        # Printing vital information
        if (batch_id + 1) % (log_interval) == 0:
            print('[Epoch {} Batch {}/{}] loss={:.4f}, lr={:.7f}, acc={:.3f}'
                         .format(epoch_id, batch_id + 1, len(bert_dataloader),
                                 step_loss / log_interval,
                                 trainer.learning_rate, metric.get()[1]))
            step_loss = 0

[Epoch 0 Batch 4/40] loss=1.7237, lr=0.0000050, acc=0.336
[Epoch 0 Batch 8/40] loss=1.7798, lr=0.0000050, acc=0.300
[Epoch 0 Batch 12/40] loss=1.6881, lr=0.0000050, acc=0.335
[Epoch 0 Batch 16/40] loss=1.7406, lr=0.0000050, acc=0.329
[Epoch 0 Batch 20/40] loss=1.6590, lr=0.0000050, acc=0.345
[Epoch 0 Batch 24/40] loss=1.6636, lr=0.0000050, acc=0.345
[Epoch 0 Batch 28/40] loss=1.6490, lr=0.0000050, acc=0.340
[Epoch 0 Batch 32/40] loss=1.7028, lr=0.0000050, acc=0.329
[Epoch 0 Batch 36/40] loss=1.5662, lr=0.0000050, acc=0.341
[Epoch 0 Batch 40/40] loss=1.6497, lr=0.0000050, acc=0.344
[Epoch 1 Batch 4/40] loss=1.5859, lr=0.0000050, acc=0.383
[Epoch 1 Batch 8/40] loss=1.5810, lr=0.0000050, acc=0.367
[Epoch 1 Batch 12/40] loss=1.5659, lr=0.0000050, acc=0.357
[Epoch 1 Batch 16/40] loss=1.4698, lr=0.0000050, acc=0.380
[Epoch 1 Batch 20/40] loss=1.5260, lr=0.0000050, acc=0.382
[Epoch 1 Batch 24/40] loss=1.4942, lr=0.0000050, acc=0.381
[Epoch 1 Batch 28/40] loss=1.3811, lr=0.0000050, acc=0.394
[