[SCRIPTS] Add transformer inference code (#852)
* add transformer inference code

* fix lint

* change inference to translator step and improve OOB feature

* add inference doc and fix pylint

* fix CI fail

* fix CI fail

* fix logging info

* fix pylint

* delete code related to training

* trigger CI

* update the data loading method

* fix review comments

* remove num_shards on the get_dataloader fn

* fix reviews

* Download trained params if needed for transformer inference

* Fix lint error and add warning for downloading param file

* Change mx.test_utils.download to mx.gluon.utils.download, update params link

* retrigger CI
pengxin99 authored and eric-haibin-lin committed Sep 8, 2019
1 parent bb18ae2 commit 5059208
Showing 4 changed files with 401 additions and 49 deletions.
110 changes: 61 additions & 49 deletions scripts/machine_translation/dataprocessor.py
@@ -202,20 +202,11 @@ def get_data_lengths(dataset):
     get_lengths = lambda *args: (args[2], args[3])
     return list(dataset.transform(get_lengths))
 
 
-def make_dataloader(data_train, data_val, data_test, args,
-                    use_average_length=False, num_shards=0, num_workers=8):
+def get_dataloader(data_set, args, dataset_type,
+                   use_average_length=False, num_shards=0, num_workers=8):
     """Create data loaders for training/validation/test."""
-    data_train_lengths = get_data_lengths(data_train)
-    data_val_lengths = get_data_lengths(data_val)
-    data_test_lengths = get_data_lengths(data_test)
-    train_batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(),
-                                  btf.Stack(dtype='float32'), btf.Stack(dtype='float32'))
-    test_batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(),
-                                 btf.Stack(dtype='float32'), btf.Stack(dtype='float32'),
-                                 btf.Stack())
-    target_val_lengths = list(map(lambda x: x[-1], data_val_lengths))
-    target_test_lengths = list(map(lambda x: x[-1], data_test_lengths))
+    assert dataset_type in ['train', 'val', 'test']
 
     if args.bucket_scheme == 'constant':
         bucket_scheme = nlp.data.ConstWidthBucket()
     elif args.bucket_scheme == 'linear':
@@ -224,44 +215,65 @@ def make_dataloader(data_train, data_val, data_test, args,
         bucket_scheme = nlp.data.ExpWidthBucket(bucket_len_step=1.2)
     else:
         raise NotImplementedError
-    train_batch_sampler = nlp.data.FixedBucketSampler(lengths=data_train_lengths,
-                                                      batch_size=args.batch_size,
-                                                      num_buckets=args.num_buckets,
-                                                      ratio=args.bucket_ratio,
-                                                      shuffle=True,
-                                                      use_average_length=use_average_length,
-                                                      num_shards=num_shards,
-                                                      bucket_scheme=bucket_scheme)
-    logging.info('Train Batch Sampler:\n%s', train_batch_sampler.stats())
-    train_data_loader = nlp.data.ShardedDataLoader(data_train,
-                                                   batch_sampler=train_batch_sampler,
-                                                   batchify_fn=train_batchify_fn,
-                                                   num_workers=num_workers)
-
-    val_batch_sampler = nlp.data.FixedBucketSampler(lengths=target_val_lengths,
-                                                    batch_size=args.test_batch_size,
-                                                    num_buckets=args.num_buckets,
-                                                    ratio=args.bucket_ratio,
-                                                    shuffle=False,
-                                                    use_average_length=use_average_length,
-                                                    bucket_scheme=bucket_scheme)
-    logging.info('Valid Batch Sampler:\n%s', val_batch_sampler.stats())
-    val_data_loader = gluon.data.DataLoader(data_val,
-                                            batch_sampler=val_batch_sampler,
-                                            batchify_fn=test_batchify_fn,
-                                            num_workers=num_workers)
-
-    test_batch_sampler = nlp.data.FixedBucketSampler(lengths=target_test_lengths,
-                                                     batch_size=args.test_batch_size,
-                                                     num_buckets=args.num_buckets,
-                                                     ratio=args.bucket_ratio,
-                                                     shuffle=False,
-                                                     use_average_length=use_average_length,
-                                                     bucket_scheme=bucket_scheme)
-    logging.info('Test Batch Sampler:\n%s', test_batch_sampler.stats())
-    test_data_loader = gluon.data.DataLoader(data_test,
-                                             batch_sampler=test_batch_sampler,
-                                             batchify_fn=test_batchify_fn,
-                                             num_workers=num_workers)
+
+    data_lengths = get_data_lengths(data_set)
+
+    if dataset_type == 'train':
+        train_batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(),
+                                      btf.Stack(dtype='float32'), btf.Stack(dtype='float32'))
+    else:
+        data_lengths = list(map(lambda x: x[-1], data_lengths))
+        test_batchify_fn = btf.Tuple(btf.Pad(), btf.Pad(),
+                                     btf.Stack(dtype='float32'), btf.Stack(dtype='float32'),
+                                     btf.Stack())
+
+    batch_sampler = nlp.data.FixedBucketSampler(lengths=data_lengths,
+                                                batch_size=(args.batch_size
+                                                            if dataset_type == 'train'
+                                                            else args.test_batch_size),
+                                                num_buckets=args.num_buckets,
+                                                ratio=args.bucket_ratio,
+                                                shuffle=(dataset_type == 'train'),
+                                                use_average_length=use_average_length,
+                                                num_shards=num_shards,
+                                                bucket_scheme=bucket_scheme)
+
+    if dataset_type == 'train':
+        logging.info('Train Batch Sampler:\n%s', batch_sampler.stats())
+        data_loader = nlp.data.ShardedDataLoader(data_set,
+                                                 batch_sampler=batch_sampler,
+                                                 batchify_fn=train_batchify_fn,
+                                                 num_workers=num_workers)
+    else:
+        if dataset_type == 'val':
+            logging.info('Valid Batch Sampler:\n%s', batch_sampler.stats())
+        else:
+            logging.info('Test Batch Sampler:\n%s', batch_sampler.stats())
+
+        data_loader = gluon.data.DataLoader(data_set,
+                                            batch_sampler=batch_sampler,
+                                            batchify_fn=test_batchify_fn,
+                                            num_workers=num_workers)
+
+    return data_loader
+
+
+def make_dataloader(data_train, data_val, data_test, args,
+                    use_average_length=False, num_shards=0, num_workers=8):
+    """Create data loaders for training/validation/test."""
+    train_data_loader = get_dataloader(data_train, args, dataset_type='train',
+                                       use_average_length=use_average_length,
+                                       num_shards=num_shards,
+                                       num_workers=num_workers)
+
+    val_data_loader = get_dataloader(data_val, args, dataset_type='val',
+                                     use_average_length=use_average_length,
+                                     num_workers=num_workers)
+
+    test_data_loader = get_dataloader(data_test, args, dataset_type='test',
+                                      use_average_length=use_average_length,
+                                      num_workers=num_workers)
+
     return train_data_loader, val_data_loader, test_data_loader
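For orientation, here is a minimal usage sketch of the refactored loader factory, written against the diff above. The SimpleNamespace stand-in for the script's parsed CLI options, the illustrative test_batch_size value, and the five-field batch unpacking order are assumptions, not code from this commit; data_test is assumed to be a preprocessed dataset produced earlier in this module.

# Hypothetical sketch: drive get_dataloader directly for inference only.
from types import SimpleNamespace

args = SimpleNamespace(batch_size=2700,      # training batch size (unused for 'test')
                       test_batch_size=256,  # illustrative value
                       num_buckets=20,
                       bucket_ratio=0.0,
                       bucket_scheme='exp')

# data_test: a preprocessed dataset from this module's pipeline (assumed).
test_data_loader = get_dataloader(data_test, args, dataset_type='test',
                                  use_average_length=False, num_workers=8)

# The test batchify_fn stacks five fields per batch; the order is assumed.
for src_seq, tgt_seq, src_len, tgt_len, sample_idx in test_data_loader:
    pass  # run translation / BLEU scoring on the batch

Because validation and test splits reuse one code path, the translator step can build a single loader this way instead of constructing all three via make_dataloader.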
21 changes: 21 additions & 0 deletions scripts/machine_translation/index.rst
@@ -47,3 +47,24 @@ obtain BLEU=27.05 with ``--bleu 13a``, BLEU=27.81 with ``--bleu intl``, and BLEU
 The pre-trained model can be downloaded from http://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/models/transformer_en_de_512_WMT2014-e25287c5.zip.
 
 For the users from China, it might be faster with this link instead: https://apache-mxnet.s3.cn-north-1.amazonaws.com.cn/gluon/models/transformer_en_de_512_WMT2014-e25287c5.zip.
+
+
+Use the following command to run inference with the Transformer model on the WMT14 test dataset for English-to-German translation.
+
+.. code-block:: console
+
+   $ python inference_transformer.py --dataset WMT2014BPE \
+                                     --src_lang en \
+                                     --tgt_lang de \
+                                     --batch_size 2700 \
+                                     --scaled --average_start 5 \
+                                     --num_buckets 20 \
+                                     --bucket_scheme exp \
+                                     --bleu 13a \
+                                     --log_interval 10 \
+                                     --model_parameter PATH/TO/valid_best.params
+
+Before running inference, you should either complete at least one full training run to obtain the trained parameters, or download the pre-trained model from http://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/models/transformer_en_de_512_WMT2014-e25287c5.zip.
+
+For the users from China, it might be faster with this link instead: https://apache-mxnet.s3.cn-north-1.amazonaws.com.cn/gluon/models/transformer_en_de_512_WMT2014-e25287c5.zip.
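Per the commit log above, the script can now download the trained parameters itself via mx.gluon.utils.download when none are supplied. A minimal sketch of that step follows; the local file name and the assumption that the archive contains the .params file are illustrative, not taken from the diff.

# Sketch of fetching the pre-trained parameters (paths are assumptions).
import zipfile
import mxnet as mx

url = ('http://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/'
       'gluon/models/transformer_en_de_512_WMT2014-e25287c5.zip')
zip_path = mx.gluon.utils.download(url, path='transformer_en_de_512_WMT2014.zip')
with zipfile.ZipFile(zip_path) as zf:
    zf.extractall('.')  # expected to contain the trained .params file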
