This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

[SCRIPT] Reproduce MNLI tasks based on BERT #571

Merged
merged 14 commits into from Feb 19, 2019
68 changes: 67 additions & 1 deletion scripts/bert/dataset.py
@@ -14,10 +14,11 @@
# limitations under the License.
"""BERT datasets."""

__all__ = ['MRPCDataset', 'BERTDatasetTransform']
__all__ = ['MRPCDataset', 'MNLIDataset', 'BERTDatasetTransform']

import os
import numpy as np
from mxnet.metric import Accuracy
from gluonnlp.data import TSVDataset, BERTSentenceTransform
from gluonnlp.data.registry import register

@@ -50,6 +51,71 @@ def get_labels():
return ['0', '1']


class GLUEDataset(TSVDataset):
"""GLUEDataset class"""

def __init__(self, path, num_discard_samples, fields):
self.fields = fields
super(GLUEDataset, self).__init__(
path, num_discard_samples=num_discard_samples)

def _read(self):
all_samples = super(GLUEDataset, self)._read()
largest_field = max(self.fields)
        # filter out malformed records that are missing required columns
final_samples = [[s[f] for f in self.fields] for s in all_samples
if len(s) >= largest_field + 1]
return final_samples


@register(segment=[
'dev_matched', 'dev_mismatched', 'test_matched', 'test_mismatched',
'diagnostic'
]) # pylint: disable=c0301
class MNLIDataset(GLUEDataset):
"""Task class for Multi-Genre Natural Language Inference
Parameters
----------
segment : str or list of str, default 'train'
Dataset segment. Options are 'dev_matched', 'dev_mismatched',
'test_matched', 'test_mismatched', 'diagnostic' or their combinations.
root : str, default '$GLUE_DIR/MNLI'
Path to the folder which stores the MNLI dataset.
The datset can be downloaded by the following script:
https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e
"""

def __init__(self,
segment='train',
root=os.path.join(os.getenv('GLUE_DIR', 'glue_data'),
'MNLI')): # pylint: disable=c0330
self._supported_segments = [
'train', 'dev_matched', 'dev_mismatched',
'test_matched', 'test_mismatched',
]
assert segment in self._supported_segments, 'Unsupported segment: %s' % segment
path = os.path.join(root, '%s.tsv' % segment)
        if segment in ['train', 'dev_matched', 'dev_mismatched']:
            # columns 8 and 9 hold sentence A (premise) and sentence B (hypothesis)
            A_IDX, B_IDX = 8, 9
            # the gold label is in column 11 of train.tsv and column 15 of the dev files
            LABEL_IDX = 11 if segment == 'train' else 15
            fields = [A_IDX, B_IDX, LABEL_IDX]
        elif segment in ['test_matched', 'test_mismatched']:
            # the test files contain no gold labels
            A_IDX, B_IDX = 8, 9
            fields = [A_IDX, B_IDX]
super(MNLIDataset, self).__init__(
path, num_discard_samples=1, fields=fields)

@staticmethod
def get_labels():
"""Get classification label ids of the dataset."""
return ['neutral', 'entailment', 'contradiction']

@staticmethod
def get_metric():
"""Get metrics Accuracy"""
return Accuracy()
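
# A minimal usage sketch (assumes the MNLI TSV files were downloaded to
# $GLUE_DIR/MNLI with the GLUE download script referenced above):
#
#   dev = MNLIDataset('dev_matched')
#   premise, hypothesis, label = dev[0]
#   label_id = MNLIDataset.get_labels().index(label)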


class BERTDatasetTransform(object):
"""Dataset Transformation for BERT-style Sentence Classification or Regression.

244 changes: 244 additions & 0 deletions scripts/bert/finetune_mnli.py
@@ -0,0 +1,244 @@
"""
Sentence Pair Classification with Bidirectional Encoder Representations from Transformers
=========================================================================================

This example shows how to fine-tune a model with pre-trained BERT parameters for
sentence pair classification, with the Gluon NLP Toolkit.

@article{devlin2018bert,
title={BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding},
author={Devlin, Jacob and Chang, Ming-Wei and Lee, Kenton and Toutanova, Kristina},
journal={arXiv preprint arXiv:1810.04805},
year={2018}
}
"""

# coding: utf-8

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint:disable=redefined-outer-name,logging-format-interpolation

import time
import argparse
import random
import logging
import warnings
import numpy as np
import mxnet as mx
from mxnet import gluon
import gluonnlp as nlp
from gluonnlp.model import bert_12_768_12
from gluonnlp.data import BERTTokenizer

from bert import BERTClassifier
from dataset import MNLIDataset, BERTDatasetTransform

parser = argparse.ArgumentParser(description='BERT sentence pair classification example. '
                                             'We fine-tune the BERT model on MNLI.')
parser.add_argument('--epochs', type=int, default=3, help='number of epochs, default is 3')
parser.add_argument('--batch_size', type=int, default=32,
help='Batch size. Number of examples per gpu in a minibatch, default is 32')
parser.add_argument('--dev_batch_size', type=int, default=8,
help='Batch size for dev set, default is 8')
parser.add_argument('--optimizer', type=str, default='bertadam',
help='Optimization algorithm, default is bertadam')
parser.add_argument('--lr', type=float, default=5e-5,
help='Initial learning rate, default is 5e-5')
parser.add_argument('--warmup_ratio', type=float, default=0.1,
                    help='ratio of warmup steps in the linear warmup learning rate schedule, default is 0.1')
parser.add_argument('--log_interval', type=int, default=100, help='report interval, default is 100')
parser.add_argument('--max_len', type=int, default=80,
help='Maximum length of the sentence pairs, default is 80')
parser.add_argument('--seed', type=int, default=2, help='Random seed, default is 2')
parser.add_argument('--accumulate', type=int, default=None, help='The number of batches for '
                    'gradient accumulation to simulate a large batch size. Default is None')
parser.add_argument('--gpu', action='store_true', help='whether to use gpu for finetuning')
args = parser.parse_args()
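
# Example invocation (a sketch; assumes the MNLI data is under $GLUE_DIR/MNLI and
# a GPU is available):
#   GLUE_DIR=glue_data python finetune_mnli.py --epochs 3 --batch_size 32 \
#       --dev_batch_size 8 --optimizer bertadam --lr 5e-5 --max_len 80 --gpu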

logging.getLogger().setLevel(logging.DEBUG)
logging.info(args)

batch_size = args.batch_size
dev_batch_size = args.dev_batch_size
lr = args.lr
accumulate = args.accumulate
log_interval = args.log_interval * accumulate if accumulate else args.log_interval
if accumulate:
logging.info('Using gradient accumulation. Effective batch size = %d', accumulate * batch_size)
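# For example, --batch_size 32 with --accumulate 4 performs one parameter update
# every 4 batches, i.e. an effective batch size of 128.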

# random seed
np.random.seed(args.seed)
random.seed(args.seed)
mx.random.seed(args.seed)

ctx = mx.cpu() if not args.gpu else mx.gpu()

# model and loss
dataset = 'book_corpus_wiki_en_uncased'
bert, vocabulary = bert_12_768_12(dataset_name=dataset,
pretrained=True, ctx=ctx, use_pooler=True,
use_decoder=False, use_classifier=False)
Member

Can we also add an option to load from a specific checkpoint file? For example, users can specify --load_checkpoint ~/.mxnet/bert/xxx.params

Member

@eric-haibin-lin
GLUE tasks only require 3 to 5 epochs of fine-tuning, so I don't think a checkpoint option is needed at the moment.

Member

Good question. What I mean is that if someone uses our pre-training script to further pre-train on their own dataset (work in progress, not merged yet), they would be interested in loading that BERT checkpoint and fine-tuning on other tasks.

Contributor Author

Yes, adding this option would let the script cover that more specific use case, and it is easy to implement. I can add it later.
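
For reference, a minimal sketch of what such an option could look like (the --load_checkpoint flag and its placement are hypothetical, not part of this PR): add the flag to the argument parser and, right after bert_12_768_12 returns, overwrite the encoder weights:

parser.add_argument('--load_checkpoint', type=str, default=None,
                    help='Path to a further pre-trained BERT checkpoint, '
                         'e.g. ~/.mxnet/bert/xxx.params')
...
if args.load_checkpoint:
    # Overwrite the pre-trained encoder weights; the classifier head is
    # still initialized from scratch below.
    bert.load_parameters(args.load_checkpoint, ctx=ctx, ignore_extra=True)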

model = BERTClassifier(bert, num_classes=3, dropout=0.1)
model.classifier.initialize(init=mx.init.Normal(0.02), ctx=ctx)
model.hybridize(static_alloc=True)

loss_function = gluon.loss.SoftmaxCELoss()
loss_function.hybridize(static_alloc=True)
metric = mx.metric.Accuracy()

# data processing
do_lower_case = 'uncased' in dataset
bert_tokenizer = BERTTokenizer(vocabulary, lower=do_lower_case)
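# Roughly, BERTDatasetTransform maps each (premise, hypothesis, label) record to
# (token ids for '[CLS] premise [SEP] hypothesis [SEP]', valid length,
#  segment ids, label id), matching the batchify_fn defined below.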


def preprocess_data(tokenizer, batch_size, dev_batch_size, max_len):
"""Data preparation function."""
# transformation
train_trans = BERTDatasetTransform(tokenizer, max_len,
labels=MNLIDataset.get_labels(),
pad=False, label_dtype='int32')
dev_trans = BERTDatasetTransform(tokenizer, max_len,
labels=MNLIDataset.get_labels(),
label_dtype='int32')
data_train = MNLIDataset('train').transform(train_trans, lazy=False)
data_dev_matched = MNLIDataset('dev_matched').transform(dev_trans, lazy=False)
data_dev_mismatched = MNLIDataset('dev_mismatched').transform(dev_trans, lazy=False)
data_train_len = data_train.transform(lambda input_id, length, segment_id, label_id: length)
num_samples_train = len(data_train)
# bucket sampler
batchify_fn = nlp.data.batchify.Tuple(nlp.data.batchify.Pad(axis=0), # input token ids
nlp.data.batchify.Stack(), # valid length
nlp.data.batchify.Pad(axis=0), # input token type ids
nlp.data.batchify.Stack()) # label id
batch_sampler = nlp.data.sampler.FixedBucketSampler(data_train_len,
batch_size=batch_size,
num_buckets=10,
ratio=0,
shuffle=True)
# data loaders
dataloader = gluon.data.DataLoader(dataset=data_train, num_workers=1,
batch_sampler=batch_sampler,
batchify_fn=batchify_fn)
dataloader_dev_matched = mx.gluon.data.DataLoader(data_dev_matched, batch_size=dev_batch_size,
num_workers=1, shuffle=False)
dataloader_dev_mismatched = mx.gluon.data.DataLoader(
data_dev_mismatched, batch_size=dev_batch_size, num_workers=1, shuffle=False)
return dataloader, dataloader_dev_matched, dataloader_dev_mismatched, num_samples_train


train_data, dev_data_matched, dev_data_mismatched, num_train_examples = preprocess_data(
bert_tokenizer, batch_size, dev_batch_size, args.max_len)


def evaluate(dataloader_eval):
"""Evaluate the model on validation dataset.
"""
step_loss = 0
metric.reset()
for _, seqs in enumerate(dataloader_eval):
input_ids, valid_len, type_ids, label = seqs
out = model(input_ids.as_in_context(ctx), type_ids.as_in_context(ctx),
valid_len.astype('float32').as_in_context(ctx))
ls = loss_function(out, label.as_in_context(ctx)).mean()

step_loss += ls.asscalar()
metric.update([label], [out])
logging.info('Validation accuracy: {:.3f}'.format(metric.get()[1]))


def train():
"""Training function."""
optimizer_params = {'learning_rate': lr, 'wd': 0.01}
try:
trainer = gluon.Trainer(model.collect_params(), args.optimizer,
optimizer_params, update_on_kvstore=False)
except ValueError as e:
print(e)
warnings.warn('AdamW optimizer is not found. Please consider upgrading to '
'mxnet>=1.5.0. Now the original Adam optimizer is used instead.')
trainer = gluon.Trainer(model.collect_params(), 'adam',
optimizer_params, update_on_kvstore=False)

step_size = batch_size * accumulate if accumulate else batch_size
num_train_steps = int(num_train_examples / step_size * args.epochs)
warmup_ratio = args.warmup_ratio
num_warmup_steps = int(num_train_steps * warmup_ratio)
step_num = 0

# Do not apply weight decay on LayerNorm and bias terms
for _, v in model.collect_params('.*beta|.*gamma|.*bias').items():
v.wd_mult = 0.0
# Collect differentiable parameters
params = [p for p in model.collect_params().values() if p.grad_req != 'null']
# Set grad_req if gradient accumulation is required
if accumulate:
for p in params:
p.grad_req = 'add'

for epoch_id in range(args.epochs):
metric.reset()
step_loss = 0
tic = time.time()
for batch_id, seqs in enumerate(train_data):
# set grad to zero for gradient accumulation
if accumulate:
if batch_id % accumulate == 0:
model.collect_params().zero_grad()
step_num += 1
else:
step_num += 1
            # learning rate schedule: linear warmup followed by linear decay
if step_num < num_warmup_steps:
new_lr = lr * step_num / num_warmup_steps
else:
offset = (step_num - num_warmup_steps) * lr / (num_train_steps - num_warmup_steps)
new_lr = lr - offset
trainer.set_learning_rate(new_lr)
# forward and backward
with mx.autograd.record():
input_ids, valid_length, type_ids, label = seqs
out = model(input_ids.as_in_context(ctx), type_ids.as_in_context(ctx),
valid_length.astype('float32').as_in_context(ctx))
ls = loss_function(out, label.as_in_context(ctx)).mean()
ls.backward()
# update
if not accumulate or (batch_id + 1) % accumulate == 0:
trainer.allreduce_grads()
nlp.utils.clip_grad_global_norm(params, 1)
trainer.update(accumulate if accumulate else 1)
step_loss += ls.asscalar()
metric.update([label], [out])
if (batch_id + 1) % log_interval == 0:
logging.info('[Epoch {} Batch {}/{}] loss={:.4f}, lr={:.7f}, acc={:.3f}'
.format(epoch_id, batch_id + 1, len(train_data),
step_loss / log_interval,
trainer.learning_rate, metric.get()[1]))
step_loss = 0
mx.nd.waitall()
logging.info('On MNLI Matched: ')
evaluate(dev_data_matched)
logging.info('On MNLI Mismatched: ')
evaluate(dev_data_mismatched)
toc = time.time()
logging.info('Time cost={:.1f}s'.format(toc - tic))
tic = toc


if __name__ == '__main__':
train()