# Fine-tuning Sentence Pair Classification with BERT
### Importing necessary modules

In [1]:
import warnings
warnings.filterwarnings('ignore')

import io
import random
import numpy as np
import mxnet as mx
import gluonnlp as nlp
from bert import data, model
import time

In [2]:
np.random.seed(100)
random.seed(100)
mx.random.seed(10000)
# change `ctx` to `mx.cpu()` if no GPU is available.
ctx = mx.gpu(0)
# ctx = mx.cpu()

## Using the pre-trained BERT model

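The list of pre-trained BERT models available in GluonNLP can be found
[here](../../model_zoo/bert/index.rst).

In this tutorial, the model we use is BERT BASE (`bert_12_768_12`) from the
GluonNLP model zoo. The cell below loads the `biobert_v1.1_pubmed_cased` weights.
The alternative named in the code comment, `book_corpus_wiki_en_uncased`, was
trained on an uncased corpus of books and the English Wikipedia dataset.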

### Get BERT

Let's first take a look at the BERT model architecture for sentence pair
classification below:
<div style="width: 500px;">![bert-sentence-pair](bert-sentence-pair.png)</div>
where the model takes a pair of sequences and pools the representation of the
first token in the sequence.
Note that the original BERT model was trained on masked language modeling and
next-sentence prediction tasks, so it includes layers for language model decoding
and classification. These layers are not used when fine-tuning for sentence pair
classification.

We can load the
pre-trained BERT fairly easily
using the model API in GluonNLP, which returns the vocabulary
along with the
model. We include the pooler layer of the pre-trained model by setting
`use_pooler` to `True`.

In [3]:
# Load BERT BASE with the BioBERT weights.
# Set dataset_name='book_corpus_wiki_en_uncased' for the general-domain weights.
bert_base, vocabulary = nlp.model.get_model('bert_12_768_12',
                                             dataset_name='biobert_v1.1_pubmed_cased',
                                             pretrained=True, ctx=ctx, use_pooler=True,
                                             use_decoder=False, use_classifier=False)
# print(bert_base)
print(vocabulary)

Vocab(size=28996, unk="[UNK]", reserved="['[CLS]', '[SEP]', '[MASK]', '[PAD]']")


### Transform the model for `SentencePair` classification

Now that we have loaded
the BERT model, we only need to attach an additional layer for classification.
The `BERTClassifier` class uses a BERT base model to encode sentence
representation, followed by a `nn.Dense` layer for classification.
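
To make the role of the added layer concrete, here is a minimal NumPy sketch of a classification head applied to pooled first-token vectors. It is an illustration only, not the `BERTClassifier` implementation: the function name `classify_pooled`, the random inputs, and the inverted-dropout formulation are assumptions made for this example.

```python
import numpy as np

def classify_pooled(pooled, weight, bias, dropout_rate=0.0, rng=None):
    """Apply optional dropout and a dense layer to pooled vectors.

    pooled: (batch, hidden), weight: (num_classes, hidden), bias: (num_classes,)
    """
    if dropout_rate > 0.0 and rng is not None:
        # Inverted dropout: zero random units and rescale the rest (training only).
        keep = rng.random(pooled.shape) >= dropout_rate
        pooled = pooled * keep / (1.0 - dropout_rate)
    # Dense layer: one logit per class.
    return pooled @ weight.T + bias

rng = np.random.default_rng(0)
hidden, num_classes = 768, 2                      # BERT BASE hidden size, two classes
pooled = rng.normal(size=(4, hidden))             # stand-in for pooled encoder output
weight = rng.normal(scale=0.02, size=(num_classes, hidden))
bias = np.zeros(num_classes)

logits = classify_pooled(pooled, weight, bias)    # dropout disabled, as at inference
print(logits.shape)
```

Only `weight` and `bias` are new parameters here, which mirrors why the cell below initializes just the classifier layer.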

In [4]:
bert_classifier = model.classification.BERTClassifier(bert_base, num_classes=2, dropout=0.1)
# only need to initialize the classifier layer.
bert_classifier.classifier.initialize(init=mx.init.Normal(0.02), ctx=ctx)
bert_classifier.hybridize(static_alloc=True)

# softmax cross entropy loss for classification
loss_function = mx.gluon.loss.SoftmaxCELoss()
loss_function.hybridize(static_alloc=True)

metric = mx.metric.Accuracy()
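
For reference, the loss and metric set up above can be sketched in plain NumPy. This is a hedged illustration of softmax cross entropy and accuracy in general, not the MXNet implementation; the function name `softmax_cross_entropy` and the toy logits are made up for the example.

```python
import numpy as np

def softmax_cross_entropy(logits, labels):
    """Per-sample cross entropy between softmax(logits) and integer labels."""
    # Subtract the row maximum for numerical stability.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Pick the log-probability of the true class for each row.
    return -log_probs[np.arange(len(labels)), labels]

logits = np.array([[0.0, 0.0], [2.0, -2.0]])   # toy two-class logits
labels = np.array([1, 0])

losses = softmax_cross_entropy(logits, labels)
accuracy = (logits.argmax(axis=1) == labels).mean()
print(losses, accuracy)
```

The first sample has equal logits, so its loss is `log(2)`; the second is confidently correct, so its loss is close to zero.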

## Data preprocessing for BERT

For this tutorial, we need to do a bit of preprocessing before feeding our data into
the BERT model. Here we want to leverage the dataset included in the archive downloaded at the
beginning of this tutorial.

### Loading the dataset

This tutorial was written around the dev set of the Microsoft Research Paraphrase
Corpus dataset, in a file named 'dev.tsv'. The cell below instead loads a
tab-separated file named 'bert_train.tsv'. Let's take a look at the first sample
of the raw dataset.

In [5]:
# Skip the first line, which is the schema
num_discard_samples = 1
# Split fields by tabs
field_separator = nlp.data.Splitter('\t')
# Fields to select from the file
field_indices = [3, 4, 0]
data_train_raw = nlp.data.TSVDataset(filename='bert_train.tsv',
                                 field_separator=field_separator,
                                 num_discard_samples=num_discard_samples,
                                 field_indices=field_indices)
# data_train_val_raw = nlp.data.TSVDataset(filename='bert_train_val.tsv',
#                                  field_separator=field_separator,
#                                  num_discard_samples=num_discard_samples,
#                                  field_indices=field_indices)


sample_id = 0
# Sentence A
print(data_train_raw[sample_id][0])
# Sentence B
print(data_train_raw[sample_id][1])
# 1 means equivalent, 0 means not equivalent
print(data_train_raw[sample_id][2])

Moreover, Wnt-1-inducible secreted protein-1 (WISP-1), which is a responsive gene of Wnt activation , can promote angiogenesis in post-MI heart via regulating histone deacetylase [[**##**]].
Angiokine Wisp-1 is increased in myocardial infarction and regulates cardiac endothelial signaling. Myocardial infarctions (MIs) cause the loss of myocytes due to lack of sufficient oxygenation and latent revascularization. Although the administration of histone deacetylase (HDAC) inhibitors reduces the size of infarctions and improves cardiac physiology in small-animal models of MI injury, the cellular targets of the HDACs, which the drugs inhibit, are largely unspecified. Here, we show that WNT-inducible secreted protein-1 (Wisp-1), a matricellular protein that promotes angiogenesis in cancers as well as cell survival in isolated cardiac myocytes and neurons, is a target of HDACs. Further, Wisp-1 transcription is regulated by HDACs and can be modified by the HDAC inhibitor, suberanilohydroxamic a
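
The header skipping and field selection used above can be approximated with the standard library. The sketch below is an assumption about the behaviour, not the `TSVDataset` source; `read_tsv_fields` and the column names in the sample are hypothetical.

```python
import csv
import io

def read_tsv_fields(handle, field_indices, num_discard_samples=0):
    """Read tab-separated rows, skip leading rows, and keep selected columns in order."""
    reader = csv.reader(handle, delimiter='\t', quoting=csv.QUOTE_NONE)
    rows = []
    for line_no, fields in enumerate(reader):
        if line_no < num_discard_samples:
            continue  # skip the schema line(s)
        rows.append([fields[i] for i in field_indices])
    return rows

# Hypothetical file content: label first, the two sentences in columns 3 and 4.
sample = io.StringIO(
    "label\tid_a\tid_b\tsentence_a\tsentence_b\n"
    "1\t10\t11\tA first sentence.\tA second sentence.\n"
)
rows = read_tsv_fields(sample, field_indices=[3, 4, 0], num_discard_samples=1)
print(rows)
```

With `field_indices=[3, 4, 0]`, each sample is ordered as sentence A, sentence B, label, which is the order the notebook indexes with `[0]`, `[1]`, and `[2]`.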

To use the pre-trained BERT model, we need to pre-process the data in the same
way it was trained. The following figure shows the input representation in BERT:
<div style="width: 500px;">![bert-embed](bert-embed.png)</div>

We will use `BERTDatasetTransform` to perform the following transformations:
- tokenize the input sequences
- insert [CLS] at the beginning
- insert [SEP] between sentence A and sentence B, and at the end
- generate segment ids to indicate whether a token belongs to the first sequence or the second sequence
- generate valid length
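
The steps listed above can be sketched in plain Python for already-tokenized input. This is a simplified illustration, not the `BERTDatasetTransform` code: it skips wordpiece tokenization and truncation, and `build_pair_input` and the toy vocabulary are hypothetical (only the `[PAD]`, `[CLS]`, and `[SEP]` ids are taken from the output shown in this tutorial).

```python
def build_pair_input(tokens_a, tokens_b, vocab, max_len):
    """Build padded token ids, valid length, and segment ids for a sentence pair."""
    tokens = ['[CLS]'] + tokens_a + ['[SEP]'] + tokens_b + ['[SEP]']
    # Segment 0 covers [CLS], sentence A and the first [SEP]; segment 1 covers the rest.
    segment_ids = [0] * (len(tokens_a) + 2) + [1] * (len(tokens_b) + 1)
    valid_length = len(tokens)
    if valid_length > max_len:
        raise ValueError('pair is longer than max_len; truncation is not handled here')
    pad = max_len - valid_length
    token_ids = [vocab.get(t, vocab['[UNK]']) for t in tokens] + [vocab['[PAD]']] * pad
    segment_ids = segment_ids + [0] * pad
    return token_ids, valid_length, segment_ids

# Toy vocabulary for illustration only.
vocab = {'[PAD]': 0, '[UNK]': 100, '[CLS]': 101, '[SEP]': 102,
         'heart': 200, 'repair': 201, 'cells': 202}

token_ids, valid_length, segment_ids = build_pair_input(
    ['heart', 'repair'], ['cells', 'divide'], vocab, max_len=8)
print(token_ids)
print(valid_length)
print(segment_ids)
```

The valid length lets the model ignore the padded positions at the end of each sequence.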

In [6]:
bert_tokenizer = nlp.data.BERTTokenizer(vocabulary, lower=False)
max_len = 512
all_labels = ["0", "1"]

# whether to transform the data as sentence pairs.
# for single sentence classification, set pair=False
# for regression task, set class_labels=None
# for inference without label available, set has_label=False
pair = True
transform = data.transform.BERTDatasetTransform(bert_tokenizer, max_len,
                                                class_labels=all_labels,
                                                has_label=True,
                                                pad=True,
                                                pair=pair)

data_train = data_train_raw.transform(transform)
# data_train_val = data_train_val_raw.transform(transform)

print('vocabulary used for tokenization = \n%s'%vocabulary)
print('%s token id = %s'%(vocabulary.padding_token, vocabulary[vocabulary.padding_token]))
print('%s token id = %s'%(vocabulary.cls_token, vocabulary[vocabulary.cls_token]))
print('%s token id = %s'%(vocabulary.sep_token, vocabulary[vocabulary.sep_token]))
print('token ids = \n%s'%data_train[sample_id][0])
print('valid length = \n%s'%data_train[sample_id][1])
print('segment ids = \n%s'%data_train[sample_id][2])
print('label = \n%s'%data_train[sample_id][3])

vocabulary used for tokenization = 
Vocab(size=28996, unk="[UNK]", reserved="['[CLS]', '[SEP]', '[MASK]', '[PAD]']")
[PAD] token id = 0
[CLS] token id = 101
[SEP] token id = 102
token ids = 
[  101  9841   117   160  2227   118   122   118  1107  7641 16240  3318
  1174  4592   118   122   113   160  6258  2101   118   122   114   117
  1134  1110   170  1231 20080  4199  2109  5565  1104   160  2227 14915
   117  1169  4609  1126 10712 27364  1107  2112   118 26574  1762  2258
 24717  1117  4793  1260  7954  2340 26572   164   164   115   115   108
   108   115   115   166   166   119   102 26285  2660  4314  1162   160
  1548  1643   118   122  1110  2569  1107  1139 13335  2881  2916  1107
 14794  5796  1105 16146  1116 17688  1322 12858 21091  1348 16085   119
  1422 13335  2881  2916  1107 14794 13945   113 26574  1116   114  2612
  1103  2445  1104  1139 26431  1496  1106  2960  1104  6664  7621  1891
  1105  1523  2227  1231 11509 11702  2734   119  1966  1103  3469  1104
  1117

In [7]:
len(data_train)

1249934

## Fine-tuning the model

Now we have all the pieces in place, and we can finally start fine-tuning the
model for a few epochs. For demonstration, we use a fixed learning rate and
skip the validation steps. For the optimizer, we use Adam, which is a common
choice for fine-tuning BERT models.
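
For intuition about the optimizer, a single textbook Adam update can be sketched in NumPy. This is the standard formulation with bias correction, offered as an assumption-laden illustration rather than MXNet's exact implementation; `adam_step` is a hypothetical helper, and only the learning rate and epsilon are taken from this tutorial.

```python
import numpy as np

def adam_step(param, grad, m, v, t, lr=5e-6, beta1=0.9, beta2=0.999, eps=1e-9):
    """One textbook Adam update; returns the new parameter and moment estimates."""
    m = beta1 * m + (1 - beta1) * grad          # first moment (mean of gradients)
    v = beta2 * v + (1 - beta2) * grad ** 2     # second moment (mean of squared gradients)
    m_hat = m / (1 - beta1 ** t)                # bias correction for step t
    v_hat = v / (1 - beta2 ** t)
    param = param - lr * m_hat / (np.sqrt(v_hat) + eps)
    return param, m, v

param = np.zeros(2)
grad = np.array([1.0, -2.0])
param, m, v = adam_step(param, grad, m=np.zeros(2), v=np.zeros(2), t=1)
print(param)
```

On the first step the update has magnitude close to `lr` for every coordinate regardless of gradient scale, which is one reason a small fixed learning rate is workable here.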

In [8]:
batch_size = 4
train_sampler = nlp.data.FixedBucketSampler(lengths=[123] * len(data_train),
                                            batch_size=batch_size,
                                            shuffle=True)
bert_dataloader = mx.gluon.data.DataLoader(data_train, batch_sampler=train_sampler, prefetch=500, num_workers=2)

# train_val_sampler = nlp.data.FixedBucketSampler(lengths=[123] * len(data_train_val),
#                                             batch_size=batch_size * 2,
#                                             shuffle=False)
# bert_dataloader_val = mx.gluon.data.DataLoader(data_train_val, batch_sampler=train_val_sampler)

In [9]:
lr = 5e-6
trainer = mx.gluon.Trainer(bert_classifier.collect_params(), 'adam',
                           {'learning_rate': lr, 'epsilon': 1e-9})

# Collect all differentiable parameters
# `grad_req == 'null'` indicates no gradients are calculated (e.g. constant parameters)
# The gradients for these params are clipped later
params = [p for p in bert_classifier.collect_params().values() if p.grad_req != 'null']
grad_clip = 1

# Train the model for `num_epochs` epochs
log_interval = 512
test_log_interval = 32 * 10
num_epochs = 10
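
The gradient clipping mentioned in the comments above rescales all gradients together when their combined norm exceeds a threshold. A minimal NumPy sketch of that idea follows; it is not the `clip_grad_global_norm` source, and `clip_by_global_norm` and the toy gradients are made up for the example.

```python
import numpy as np

def clip_by_global_norm(grads, max_norm):
    """Scale a list of gradient arrays so their joint L2 norm is at most max_norm."""
    total_norm = np.sqrt(sum(float((g ** 2).sum()) for g in grads))
    scale = max_norm / (total_norm + 1e-8)
    if scale < 1.0:
        # Every array is scaled by the same factor, so the direction is preserved.
        grads = [g * scale for g in grads]
    return grads, total_norm

grads = [np.array([3.0, 0.0]), np.array([[0.0, 4.0]])]
clipped, norm_before = clip_by_global_norm(grads, max_norm=1.0)
norm_after = np.sqrt(sum(float((g ** 2).sum()) for g in clipped))
print(norm_before, norm_after)
```

Because one scale factor is shared across all parameters, the relative sizes of the per-parameter gradients are unchanged.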

In [10]:
# def evaluate(dataset):
#     total_L = 0.0
#     total_sample_num = 0
#     total_correct_num = 0
#     start_log_interval_time = time.time()
    
#     t_metric = mx.metric.Accuracy()

#     print('Begin Testing...')
#     for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(dataset):
        
#         # Load the data to the GPU
#         token_ids = token_ids.as_in_context(ctx)
#         valid_length = valid_length.as_in_context(ctx)
#         segment_ids = segment_ids.as_in_context(ctx)
        
#         label = label.as_in_context(ctx)
#         out = bert_classifier(token_ids, segment_ids, valid_length.astype('float32'))

        
#         t_metric.update([label], [out])

#         if (batch_id + 1) % test_log_interval == 0:
#             print('[Batch {}/{}] elapsed {:.2f} s'.format(
#                 batch_id + 1, len(dataset),
#                 time.time() - start_log_interval_time))
#             start_log_interval_time = time.time()
#     return t_metric.get()[1]
    

In [11]:
for epoch_id in range(num_epochs):
    metric.reset()
    step_loss = 0
    for batch_id, (token_ids, valid_length, segment_ids, label) in enumerate(bert_dataloader):
        with mx.autograd.record():

            # Load the data to the GPU
            token_ids = token_ids.as_in_context(ctx)
            valid_length = valid_length.as_in_context(ctx)
            segment_ids = segment_ids.as_in_context(ctx)
            label = label.as_in_context(ctx)

            # Forward computation
            out = bert_classifier(token_ids, segment_ids, valid_length.astype('float32'))
            ls = loss_function(out, label).mean()

        # And backwards computation
        ls.backward()

        # Gradient clipping
        trainer.allreduce_grads()
        nlp.utils.clip_grad_global_norm(params, grad_clip)
        trainer.update(1)

        step_loss += ls.asscalar()
        metric.update([label], [out])
        
        
        
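        # Printing vital information.
        # Note: the first log line (batch 1) divides a single batch's loss by
        # `log_interval`, so its reported loss is far smaller than the true value.
        if (batch_id) % (log_interval) == 0:
            # conf: number of class-1 logits above 1.5 vs. number of positive labels in the batch
            conf_rate = '{} {} {}'.format((out[:,1]>1.5).sum().asnumpy()[0], 'of',(label==1).sum().asnumpy()[0])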
            print('[Epoch {} Batch {}/{}] loss={:.4f}, lr={:.7f}, acc={:.3f}, conf={}'
                         .format(epoch_id, batch_id + 1, len(bert_dataloader),
                                 step_loss / log_interval,
                                 trainer.learning_rate, metric.get()[1],
                                 conf_rate))
            print(out[:,1].asnumpy())
            step_loss = 0
            
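        # Save a checkpoint every 20000 batches; the epoch is part of the
        # file name so later epochs do not overwrite earlier checkpoints.
        if (batch_id + 1) % (20000) == 0:
            bert_classifier.save_parameters('bisai/epoch{}_batch{}.params'.format(epoch_id, batch_id + 1))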

[Epoch 0 Batch 1/312484] loss=0.0014, lr=0.0000050, acc=0.500, conf=0.0 of 2
[-0.01505742 -0.10734081 -0.14366901 -0.15694007]
[Epoch 0 Batch 513/312484] loss=0.4042, lr=0.0000050, acc=0.848, conf=1.0 of 4
[1.2760909 1.4083455 1.505771  1.4063264]
[Epoch 0 Batch 1025/312484] loss=0.4152, lr=0.0000050, acc=0.859, conf=2.0 of 2
[ 1.5577323 -1.9759296 -1.7627802  1.8772725]
[Epoch 0 Batch 1537/312484] loss=0.4015, lr=0.0000050, acc=0.869, conf=0.0 of 0
[-2.454082  -2.3445866 -2.1327221 -1.2225958]
[Epoch 0 Batch 2049/312484] loss=0.4211, lr=0.0000050, acc=0.874, conf=0.0 of 2
[ 0.6504191  -2.4373806   0.10464716 -2.4724102 ]


KeyboardInterrupt: 
