This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

[SCRIPT] Reproduce MNLI tasks based on BERT #571

Merged: 14 commits, Feb 19, 2019
28 changes: 13 additions & 15 deletions scripts/bert/dataset.py
@@ -80,7 +80,7 @@ def __init__(self, path, num_discard_samples, fields):
def _read(self):
all_samples = super(GLUEDataset, self)._read()
largest_field = max(self.fields)
#to filter out error records
# to filter out error records
final_samples = [[s[f] for f in self.fields] for s in all_samples
if len(s) >= largest_field + 1]
residuals = len(all_samples) - len(final_samples)
@@ -358,43 +358,41 @@ def get_labels():

@register(segment=[
'dev_matched', 'dev_mismatched', 'test_matched', 'test_mismatched',
'diagnostic'
]) #pylint: disable=c0301
'train'
]) # pylint: disable=c0301
class MNLIDataset(GLUEDataset):
"""Task class for Multi-Genre Natural Language Inference

Parameters
----------
segment : str or list of str, default 'train'
Dataset segment. Options are 'dev_matched', 'dev_mismatched',
'test_matched', 'test_mismatched', 'diagnostic' or their combinations.
'test_matched', 'test_mismatched', 'train' or their combinations.
root : str, default '$GLUE_DIR/MNLI'
Path to the folder which stores the MNLI dataset.
        The dataset can be downloaded with the following script:
https://gist.github.com/W4ngatang/60c2bdb54d156a41194446737ce03e2e
"""

task_name = 'MNLI'
is_pair = True

def __init__(self,
segment='dev_matched',
root=os.path.join(
os.getenv('GLUE_DIR', 'glue_data'), task_name)): #pylint: disable=c0330
segment='train',
root=os.path.join(os.getenv('GLUE_DIR', 'glue_data'),
'MNLI')): # pylint: disable=c0330
self._supported_segments = [
'dev_matched', 'dev_mismatched', 'test_matched', 'test_mismatched',
'diagnostic'
'train', 'dev_matched', 'dev_mismatched',
'test_matched', 'test_mismatched',
]
assert segment in self._supported_segments, 'Unsupported segment: %s' % segment
path = os.path.join(root, '%s.tsv' % segment)
if segment in ['dev_matched', 'dev_mismatched']:
A_IDX, B_IDX, LABEL_IDX = 8, 9, 15
if segment in ['train', 'dev_matched', 'dev_mismatched']:
A_IDX, B_IDX = 8, 9
LABEL_IDX = 11 if segment == 'train' else 15
fields = [A_IDX, B_IDX, LABEL_IDX]
elif segment in ['test_matched', 'test_mismatched']:
A_IDX, B_IDX = 8, 9
fields = [A_IDX, B_IDX]
elif segment == 'diagnostic':
A_IDX, B_IDX = 1, 2
fields = [A_IDX, B_IDX]
super(MNLIDataset, self).__init__(
path, num_discard_samples=1, fields=fields)
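
For reference, a minimal usage sketch of the class registered above, assuming the MNLI TSVs have been downloaded into $GLUE_DIR/MNLI with the gist linked in the docstring (the import path and printed values are illustrative):

import os
os.environ.setdefault('GLUE_DIR', 'glue_data')  # set before importing: the default root is built at import time

from dataset import MNLIDataset  # scripts/bert/dataset.py

train = MNLIDataset('train')              # sentence pair from columns 8/9, label from column 11
dev_matched = MNLIDataset('dev_matched')  # label comes from column 15 in the dev files
print(train[0])                           # e.g. ['premise ...', 'hypothesis ...', 'neutral']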

112 changes: 85 additions & 27 deletions scripts/bert/finetune_classifier.py
@@ -93,8 +93,7 @@
'--warmup_ratio',
type=float,
default=0.1,
help=
'ratio of warmup steps used in NOAM\'s stepsize schedule, default is 0.1')
help='ratio of warmup steps used in NOAM\'s stepsize schedule, default is 0.1')
parser.add_argument(
'--log_interval',
type=int,
@@ -135,6 +134,25 @@
'for both bert_24_1024_16 and bert_12_768_12.'
'\'wiki_cn\', \'wiki_multilingual\' and \'wiki_multilingual_cased\' for bert_12_768_12 only.'
)
parser.add_argument(
'--pretrained_bert_parameters',
type=str,
default=None,
help='Pre-trained bert model parameter file. default is None'
)
parser.add_argument(
'--model_parameters',
type=str,
default=None,
help='Model parameter file. default is None'
)
parser.add_argument(
'--output_dir',
type=str,
default='./output_dir',
help='The output directory where the model params will be written.'
' default is ./output_dir'
)

args = parser.parse_args()
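
Taken together, the new flags support a simple save-and-resume workflow; an illustrative pair of invocations (not taken from the training log, and the epoch index in the checkpoint name is made up). The --pretrained_bert_parameters flag works analogously, but warm-starts only the BERT encoder rather than the full classifier:

$ GLUE_DIR=glue_data python3 finetune_classifier.py --task_name MNLI --max_len 80 --gpu --output_dir ./output_mnli
$ GLUE_DIR=glue_data python3 finetune_classifier.py --task_name MNLI --max_len 80 --gpu --model_parameters ./output_mnli/model_bert_MNLI_2.params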

@@ -183,6 +201,20 @@
model.classifier.initialize(init=mx.init.Normal(0.02), ctx=ctx)
loss_function = gluon.loss.SoftmaxCELoss()

# load checkpointing
pretrained_bert_parameters = args.pretrained_bert_parameters
model_parameters = args.model_parameters
output_dir = args.output_dir
if pretrained_bert_parameters:
logging.info('loading bert params from {0}'.format(pretrained_bert_parameters))
model.bert.load_parameters(pretrained_bert_parameters)
if model_parameters:
logging.info('loading model params from {0}'.format(model_parameters))
model.load_parameters(model_parameters)
if not os.path.exists(output_dir):
os.makedirs(output_dir)


logging.info(model)
model.hybridize(static_alloc=True)
loss_function.hybridize(static_alloc=True)
@@ -203,13 +235,7 @@ def preprocess_data(tokenizer, task, batch_size, dev_batch_size, max_len):
pair=task.is_pair,
label_dtype='float32' if not task.get_labels() else 'int32')

if task.task_name == 'MNLI':
data_train = task('dev_matched').transform(trans, lazy=False)
data_dev = task('dev_mismatched').transform(trans, lazy=False)
else:
data_train = task('train').transform(trans, lazy=False)
data_dev = task('dev').transform(trans, lazy=False)

data_train = task('train').transform(trans, lazy=False)
data_train_len = data_train.transform(
lambda input_id, length, segment_id, label_id: length)

@@ -227,29 +253,49 @@ def preprocess_data(tokenizer, task, batch_size, dev_batch_size, max_len):
ratio=0,
shuffle=True)
# data loaders
dataloader = gluon.data.DataLoader(
dataloader_train = gluon.data.DataLoader(
dataset=data_train,
num_workers=1,
batch_sampler=batch_sampler,
batchify_fn=batchify_fn)
dataloader_dev = mx.gluon.data.DataLoader(
data_dev,
batch_size=dev_batch_size,
num_workers=1,
shuffle=False,
batchify_fn=batchify_fn)
return dataloader, dataloader_dev, num_samples_train


train_data, dev_data, num_train_examples = preprocess_data(
bert_tokenizer, task, batch_size, dev_batch_size, args.max_len)
if task.task_name == 'MNLI':
data_dev_matched = task('dev_matched').transform(trans, lazy=False)
data_dev_mismatched = task('dev_mismatched').transform(trans, lazy=False)

dataloader_dev_matched = mx.gluon.data.DataLoader(
data_dev_matched, batch_size=dev_batch_size,
num_workers=1, shuffle=False, batchify_fn=batchify_fn)
dataloader_dev_mismatched = mx.gluon.data.DataLoader(
data_dev_mismatched, batch_size=dev_batch_size,
num_workers=1, shuffle=False, batchify_fn=batchify_fn)
return dataloader_train, dataloader_dev_matched, \
dataloader_dev_mismatched, num_samples_train
else:
data_dev = task('dev').transform(trans, lazy=False)
dataloader_dev = mx.gluon.data.DataLoader(
data_dev,
batch_size=dev_batch_size,
num_workers=1,
shuffle=False,
batchify_fn=batchify_fn)
return dataloader_train, dataloader_dev, num_samples_train
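
As a side note, the training dataloader above batches by bucketed sequence length. A standalone sketch of gluonnlp's FixedBucketSampler with made-up lengths and batch size (not the script's exact configuration):

import gluonnlp as nlp

lengths = [12, 80, 45, 33, 77, 15]   # pretend token counts for six training examples
sampler = nlp.data.FixedBucketSampler(lengths, batch_size=2, num_buckets=3,
                                      ratio=0, shuffle=True)
for batch_indices in sampler:
    print(batch_indices)             # indices grouped so each batch has similar lengths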


# Get the dataloaders. The MNLI task needs special handling (matched and mismatched dev sets).
logging.info('processing dataset...')
if task.task_name == 'MNLI':
train_data, dev_data_matched, dev_data_mismatched, num_train_examples = preprocess_data(
bert_tokenizer, task, batch_size, dev_batch_size, args.max_len)
else:
train_data, dev_data, num_train_examples = preprocess_data(
bert_tokenizer, task, batch_size, dev_batch_size, args.max_len)
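
Whichever branch is taken, each batch from the returned dataloaders unpacks into four arrays, which is how evaluate() and train() below consume them; a quick illustrative check (shapes depend on the batchify function):

# Illustrative only: peek at one training batch. The tuple layout
# (input_ids, valid_length, segment_ids, label) matches the unpacking in evaluate().
for input_ids, valid_len, type_ids, label in train_data:
    print(input_ids.shape, valid_len.shape, type_ids.shape, label.shape)
    break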


def evaluate(metric):
def evaluate(dataloader_eval, metric):
"""Evaluate the model on validation dataset.
"""
metric.reset()
for _, seqs in enumerate(dev_data):
for _, seqs in enumerate(dataloader_eval):
input_ids, valid_len, type_ids, label = seqs
out = model(
input_ids.as_in_context(ctx), type_ids.as_in_context(ctx),
@@ -344,12 +390,24 @@ def train(metric):
metric_val = [metric_val]
eval_str = '[Epoch {} Batch {}/{}] loss={:.4f}, lr={:.7f}, metrics=' + \
','.join([i + ':{:.4f}' for i in metric_nm])
logging.info(eval_str.format(epoch_id + 1, batch_id + 1, len(train_data), \
step_loss / args.log_interval, \
trainer.learning_rate, *metric_val))
logging.info(eval_str.format(epoch_id + 1, batch_id + 1, len(train_data),
step_loss / args.log_interval,
trainer.learning_rate, *metric_val))
step_loss = 0
mx.nd.waitall()
evaluate(metric)
if task.task_name == 'MNLI':
logging.info('On MNLI Matched: ')
evaluate(dev_data_matched, metric)
logging.info('On MNLI Mismatched: ')
evaluate(dev_data_mismatched, metric)
else:
evaluate(dev_data, metric)

# save params
params_saved = os.path.join(output_dir,
'model_bert_{0}_{1}.params'.format(task.task_name, epoch_id))
model.save_parameters(params_saved)
logging.info('params saved in : {0}'.format(params_saved))
toc = time.time()
logging.info('Time cost={:.1f}s'.format(toc - tic))
tic = toc
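
After training, any of the per-epoch checkpoints written above can be restored for re-evaluation. A sketch reusing the script's own names; the epoch index in the filename and the Accuracy metric are illustrative:

# Reload a saved checkpoint and re-run both MNLI dev evaluations.
metric = mx.metric.Accuracy()
model.load_parameters(os.path.join(output_dir, 'model_bert_MNLI_2.params'), ctx=ctx)
evaluate(dev_data_matched, metric)
evaluate(dev_data_mismatched, metric)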
10 changes: 9 additions & 1 deletion scripts/bert/index.rst
@@ -45,4 +45,12 @@ It gets validation accuracy of `88.7% <https://raw.githubusercontent.com/dmlc/we
It gets RTE validation accuracy of `70.8% <https://raw.githubusercontent.com/dmlc/web-data/master/gluonnlp/logs/bert/finetuned_rte.log>`_
, whereas the original Tensorflow implementation gives an evaluation result of 66.4%.

.. code-block:: console

$ MXNET_GPU_MEM_POOL_TYPE=Round GLUE_DIR=glue_data python3 finetune_classifier.py --task_name MNLI --max_len 80 --log_interval 100 --gpu

It gets MNLI validation accuracy of 84.6% on dev_matched.tsv and 84.7% on dev_mismatched.tsv. `log <https://github.com/dmlc/web-data/blob/master/gluonnlp/logs/bert/finetuned_mnli.log>`_


Other tasks can be fine-tuned by changing the `--task_name` parameter, as shown below.
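
For example, an illustrative RTE run (hyperparameters left at their defaults; the output directory name is made up):

.. code-block:: console

    $ GLUE_DIR=glue_data python3 finetune_classifier.py --task_name RTE --gpu --output_dir ./output_rte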