This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

[FEATURE] Add transformer inference code #852

Merged: 20 commits merged on Sep 8, 2019
Changes from 16 commits
110 changes: 61 additions & 49 deletions scripts/machine_translation/dataprocessor.py
@@ -202,20 +202,11 @@ def get_data_lengths(dataset):
get_lengths = lambda *args: (args[2], args[3])
return list(dataset.transform(get_lengths))


def make_dataloader(data_train, data_val, data_test, args,
use_average_length=False, num_shards=0, num_workers=8):
def get_dataloader(data_set, args, dataset_type,
use_average_length=False, num_shards=0, num_workers=8):
"""Create data loaders for training/validation/test."""
data_train_lengths = get_data_lengths(data_train)
data_val_lengths = get_data_lengths(data_val)
data_test_lengths = get_data_lengths(data_test)
train_batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(),
btf.Stack(dtype='float32'), btf.Stack(dtype='float32'))
test_batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(),
btf.Stack(dtype='float32'), btf.Stack(dtype='float32'),
btf.Stack())
target_val_lengths = list(map(lambda x: x[-1], data_val_lengths))
target_test_lengths = list(map(lambda x: x[-1], data_test_lengths))
assert dataset_type in ['train', 'val', 'test']

if args.bucket_scheme == 'constant':
bucket_scheme = nlp.data.ConstWidthBucket()
elif args.bucket_scheme == 'linear':
@@ -224,44 +215,65 @@ def make_dataloader(data_train, data_val, data_test, args,
bucket_scheme = nlp.data.ExpWidthBucket(bucket_len_step=1.2)
else:
raise NotImplementedError
train_batch_sampler = nlp.data.FixedBucketSampler(lengths=data_train_lengths,
batch_size=args.batch_size,
num_buckets=args.num_buckets,
ratio=args.bucket_ratio,
shuffle=True,
use_average_length=use_average_length,
num_shards=num_shards,
bucket_scheme=bucket_scheme)
logging.info('Train Batch Sampler:\n%s', train_batch_sampler.stats())
train_data_loader = nlp.data.ShardedDataLoader(data_train,
batch_sampler=train_batch_sampler,
batchify_fn=train_batchify_fn,
num_workers=num_workers)

val_batch_sampler = nlp.data.FixedBucketSampler(lengths=target_val_lengths,
batch_size=args.test_batch_size,
num_buckets=args.num_buckets,
ratio=args.bucket_ratio,
shuffle=False,
use_average_length=use_average_length,
bucket_scheme=bucket_scheme)
logging.info('Valid Batch Sampler:\n%s', val_batch_sampler.stats())
val_data_loader = gluon.data.DataLoader(data_val,
batch_sampler=val_batch_sampler,

data_lengths = get_data_lengths(data_set)

if dataset_type == 'train':
train_batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(),
btf.Stack(dtype='float32'), btf.Stack(dtype='float32'))

else:
data_lengths = list(map(lambda x: x[-1], data_lengths))
test_batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(),
btf.Stack(dtype='float32'), btf.Stack(dtype='float32'),
btf.Stack())

batch_sampler = nlp.data.FixedBucketSampler(lengths=data_lengths,
batch_size=(args.batch_size \
if dataset_type == 'train' \
else args.test_batch_size),
num_buckets=args.num_buckets,
ratio=args.bucket_ratio,
shuffle=(dataset_type == 'train'),
use_average_length=use_average_length,
num_shards=num_shards,
bucket_scheme=bucket_scheme)

if dataset_type == 'train':
logging.info('Train Batch Sampler:\n%s', batch_sampler.stats())
data_loader = nlp.data.ShardedDataLoader(data_set,
batch_sampler=batch_sampler,
batchify_fn=train_batchify_fn,
num_workers=num_workers)
else:
if dataset_type == 'val':
logging.info('Valid Batch Sampler:\n%s', batch_sampler.stats())
else:
logging.info('Test Batch Sampler:\n%s', batch_sampler.stats())

data_loader = gluon.data.DataLoader(data_set,
batch_sampler=batch_sampler,
batchify_fn=test_batchify_fn,
num_workers=num_workers)
test_batch_sampler = nlp.data.FixedBucketSampler(lengths=target_test_lengths,
batch_size=args.test_batch_size,
num_buckets=args.num_buckets,
ratio=args.bucket_ratio,
shuffle=False,
use_average_length=use_average_length,
bucket_scheme=bucket_scheme)
logging.info('Test Batch Sampler:\n%s', test_batch_sampler.stats())
test_data_loader = gluon.data.DataLoader(data_test,
batch_sampler=test_batch_sampler,
batchify_fn=test_batchify_fn,
num_workers=num_workers)

return data_loader

def make_dataloader(data_train, data_val, data_test, args,
use_average_length=False, num_shards=0, num_workers=8):
"""Create data loaders for training/validation/test."""
train_data_loader = get_dataloader(data_train, args, dataset_type='train',
use_average_length=use_average_length,
num_shards=num_shards,
num_workers=num_workers)

val_data_loader = get_dataloader(data_val, args, dataset_type='val',
use_average_length=use_average_length,
num_workers=num_workers)

test_data_loader = get_dataloader(data_test, args, dataset_type='test',
use_average_length=use_average_length,
num_workers=num_workers)

return train_data_loader, val_data_loader, test_data_loader
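
For context, here is a minimal usage sketch of the refactored helpers (illustration only, not part of the diff). The ``args`` fields and the ``data_train``/``data_val``/``data_test`` variables are assumptions standing in for the argparse namespace and processed datasets that the surrounding scripts build.

.. code-block:: python

    from types import SimpleNamespace

    # Hypothetical argparse-style namespace carrying only the fields the
    # bucket samplers above read; real values come from the CLI flags.
    args = SimpleNamespace(batch_size=2700, test_batch_size=256,
                           num_buckets=20, bucket_ratio=0.0,
                           bucket_scheme='exp')

    # data_train/data_val/data_test are assumed to be the processed
    # (src, tgt, src_len, tgt_len[, idx]) datasets produced earlier
    # in dataprocessor.py.

    # All three loaders at once, matching the pre-refactor interface:
    train_loader, val_loader, test_loader = make_dataloader(
        data_train, data_val, data_test, args, num_shards=4)

    # Or a single split via the new helper, e.g. only the test loader
    # when running inference:
    test_loader = get_dataloader(data_test, args, dataset_type='test')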


21 changes: 21 additions & 0 deletions scripts/machine_translation/index.rst
@@ -47,3 +47,24 @@ obtain BLEU=27.05 with ``--bleu 13a``, BLEU=27.81 with ``--bleu intl``, and BLEU
The pre-trained model can be downloaded from http://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/models/transformer_en_de_512_WMT2014-e25287c5.zip.

For users in China, this link might be faster: https://apache-mxnet.s3.cn-north-1.amazonaws.com.cn/gluon/models/transformer_en_de_512_WMT2014-e25287c5.zip.


Use the following command to run inference with the Transformer model on the WMT14 test dataset for English-to-German translation.

.. code-block:: console

$ python inference_transformer.py --dataset WMT2014BPE \
--src_lang en \
--tgt_lang de \
--batch_size 2700 \
--scaled --average_start 5 \
--num_buckets 20 \
--bucket_scheme exp \
--bleu 13a \
--log_interval 10 \
--model_parameter PATH/TO/valid_best.params

Before running inference, you need trained model parameters: either complete at least one full training run, or download the pre-trained model from http://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/models/transformer_en_de_512_WMT2014-e25287c5.zip.

For users in China, this link might be faster: https://apache-mxnet.s3.cn-north-1.amazonaws.com.cn/gluon/models/transformer_en_de_512_WMT2014-e25287c5.zip.
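
As an illustration only (not part of this PR), the archive could also be fetched and unpacked with a few lines of Python; the target directory name here is an assumption, and the archive is expected to contain the ``valid_best.params`` file that ``--model_parameter`` points to.

.. code-block:: python

    import zipfile
    from urllib.request import urlretrieve

    # Download the pre-trained archive (URL from the paragraph above).
    url = ('http://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/'
           'gluon/models/transformer_en_de_512_WMT2014-e25287c5.zip')
    archive, _ = urlretrieve(url, 'transformer_en_de_512_WMT2014.zip')

    # Unpack it next to the inference script and point --model_parameter
    # at the extracted valid_best.params file.
    with zipfile.ZipFile(archive) as zf:
        zf.extractall('transformer_en_de_512_WMT2014')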