This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

[FEATURE] Add transformer inference code #852

Merged · 20 commits · Sep 8, 2019
Changes from 1 commit
22 changes: 22 additions & 0 deletions scripts/machine_translation/index.rst
@@ -47,3 +47,25 @@ obtain BLEU=27.05 with ``--bleu 13a``, BLEU=27.81 with ``--bleu intl``, and BLEU
The pre-trained model can be downloaded from http://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/models/transformer_en_de_512_WMT2014-e25287c5.zip.

For users in China, it might be faster to use this link instead: https://apache-mxnet.s3.cn-north-1.amazonaws.com.cn/gluon/models/transformer_en_de_512_WMT2014-e25287c5.zip.


Use the following command to run inference with the Transformer model on the WMT2014 test dataset for English-to-German translation.

.. code-block:: console

    $ python inference_transformer.py --dataset WMT2014BPE \
                                      --src_lang en \
                                      --tgt_lang de \
                                      --batch_size 2700 \
                                      --num_accumulated 16 \
                                      --scaled --average_start 5 \
                                      --num_buckets 20 \
                                      --bucket_scheme exp \
                                      --bleu 13a \
                                      --log_interval 10 \
                                      --model_parameter PATH/TO/valid_best.params

Before running inference you need trained parameters: either complete a full training run at least once, or download the pre-trained model from http://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/models/transformer_en_de_512_WMT2014-e25287c5.zip.

For users in China, it might be faster to use this link instead: https://apache-mxnet.s3.cn-north-1.amazonaws.com.cn/gluon/models/transformer_en_de_512_WMT2014-e25287c5.zip.
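As an alternative to downloading by hand, here is a minimal Python sketch using MXNet's download helper. The extraction target ``.`` is an arbitrary choice, and the archive is assumed to contain the ``valid_best.params`` file referenced above.

.. code-block:: python

    import zipfile
    from mxnet.gluon.utils import download

    # Fetch the pre-trained Transformer parameters.
    zip_path = download('http://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/'
                        'gluon/models/transformer_en_de_512_WMT2014-e25287c5.zip')
    # Unpack; pass the extracted .params file to --model_parameter.
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall('.')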

45 changes: 15 additions & 30 deletions scripts/machine_translation/inference_transformer.py
@@ -45,10 +45,7 @@
from mxnet import gluon
import gluonnlp as nlp

from gluonnlp.loss import MaskedSoftmaxCELoss, LabelSmoothing
from gluonnlp.model.translation import NMTModel
from gluonnlp.model.transformer import get_transformer_encoder_decoder, ParallelTransformer
from gluonnlp.utils.parallel import Parallel
from translation import BeamSearchTranslator
#from loss import SoftmaxCEMaskedLoss, LabelSmoothing
from utils import logging_config
@@ -60,7 +57,7 @@
mx.random.seed(10000)

parser = argparse.ArgumentParser(description='Neural Machine Translation Example.'
'We train the Transformer Model')
'We use this script only for transformer inference.')
parser.add_argument('--dataset', type=str, default='WMT2016BPE', help='Dataset to use.')
Contributor
Set default value to "WMT2014BPE"?
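A sketch of the suggested change (hypothetical; it only swaps the default value in the existing argument definition):

    # Suggested default; everything else stays as in the diff above.
    parser.add_argument('--dataset', type=str, default='WMT2014BPE',
                        help='Dataset to use.')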

parser.add_argument('--src_lang', type=str, default='en', help='Source language')
parser.add_argument('--tgt_lang', type=str, default='de', help='Target language')
@@ -190,7 +187,6 @@
share_embed=args.dataset != 'TOY', embed_size=args.num_units,
tie_weights=args.dataset != 'TOY', embed_initializer=None, prefix='transformer_')

assert args.model_parameter is not ' ', 'INFERENCE BUT DO NOT HAVE PARAMETERS!'
model.load_parameters(args.model_parameter, ctx)

static_alloc = True
@@ -233,17 +229,8 @@ def inference():
avg_loss = 0.0
total_wc = 0
total_time = 0

batch_total_blue = 0

# add profiler on-off
is_profiler_on = os.getenv('GLUONNLP_TRANSFORMER_PROFILING', False)
is_mkldnn_verbose_on = os.getenv('MKLDNN_VERBOSE', 0)
if is_profiler_on:
mx.profiler.set_config(profile_symbolic=True, profile_imperative=True, profile_memory=False,
profile_api=False, filename='profile.json', aggregate_stats=True)
mx.profiler.set_state('run')

for batch_id, (src_seq, tgt_seq, src_test_length, tgt_test_length, inst_ids) \
in enumerate(test_data_loader):

@@ -253,7 +240,6 @@ def inference():
tgt_seq = tgt_seq.as_in_context(ctx[0])
src_test_length = src_test_length.as_in_context(ctx[0])
tgt_test_length = tgt_test_length.as_in_context(ctx[0])

# Calculating Loss
out, _ = model(src_seq, tgt_seq[:, :-1], tgt_test_length, tgt_test_length - 1)
loss = test_loss_function(out, tgt_seq[:, 1:], tgt_test_length - 1).mean().asscalar()
@@ -273,10 +259,9 @@
translation_tmp = []
translation_tmp_sentences = []
for i in range(max_score_sample.shape[0]):
translation_tmp.append([tgt_vocab.idx_to_token[ele] for ele in
max_score_sample[i][1:(sample_test_length[i] - 1)]])
# translation_out.extend(translation_tmp)

translation_tmp.append([tgt_vocab.idx_to_token[ele] for ele in \
max_score_sample[i][1:(sample_test_length[i] - 1)]])

# detokenize each translated sentence
for _, sentence in enumerate(translation_tmp):
if args.bleu == 'tweaked':
@@ -289,20 +274,20 @@
raise NotImplementedError

# generate tgt_sentence for bleu calculation of each batch
tgt_sen_tmp = [test_tgt_sentences[index] for
_, index in enumerate(inst_ids.asnumpy().astype(np.int32).tolist())]
tgt_sen_tmp = [test_tgt_sentences[index] for \
_, index in enumerate(inst_ids.asnumpy().astype(np.int32).tolist())]
batch_test_bleu_score, _, _, _, _ = compute_bleu([tgt_sen_tmp], translation_tmp_sentences,
tokenized=tokenized, tokenizer=args.bleu,
split_compound_word=split_compound_word,
bpe=bpe)
tokenized=tokenized, tokenizer=args.bleu,
split_compound_word=split_compound_word,
bpe=bpe)
batch_total_blue += batch_test_bleu_score

# log every ten batches
if batch_id % 10 == 0 and batch_id != 0:
batch_ave_bleu = batch_total_blue / 10
batch_total_blue = 0
logging.info('batch id={:d}, loss={:.4f}, batch_bleu={:.4f}'
.format(batch_id, loss, batch_ave_bleu * 100))
.format(batch_id, loss, batch_ave_bleu * 100))

# reorder translated sentences by inst_ids
real_translation_out = [None for _ in range(len(all_inst_ids))]
@@ -311,15 +296,15 @@

# get bleu score, n-gram precisions, brevity penalty, reference length, and translation length
test_bleu_score, _, _, _, _ = compute_bleu([test_tgt_sentences], real_translation_out,
tokenized=tokenized, tokenizer=args.bleu,
split_compound_word=split_compound_word,
bpe=bpe)
tokenized=tokenized, tokenizer=args.bleu,
split_compound_word=split_compound_word,
bpe=bpe)
# total batch logging
test_ave_loss = avg_loss / avg_loss_denom
logging.info('Inference at val dataset. Loss={:.4f}, \
val ppl={:.4f}, val bleu={:.4f}, throughput={:.4f}K wps'
.format(test_ave_loss, np.exp(val_ave_loss),
test_bleu_score * 100, total_wc / total_time / 1000))
.format(test_ave_loss, np.exp(test_ave_loss),
test_bleu_score * 100, total_wc / total_time / 1000))


if __name__ == '__main__':
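For context, a minimal sketch of the beam-search decoding step this script performs per batch, following the BeamSearchTranslator interface imported above. The beam size and length-penalty settings are illustrative assumptions, not the script's actual flags:

    import gluonnlp as nlp
    from translation import BeamSearchTranslator

    def translate_batch(model, src_seq, src_valid_length, tgt_vocab):
        """Beam-search one batch and return token lists without BOS/EOS."""
        # Beam settings are assumptions here; the script reads them from argparse.
        translator = BeamSearchTranslator(
            model=model, beam_size=4,
            scorer=nlp.model.BeamSearchScorer(alpha=0.6, K=5),
            max_length=200)
        samples, _, sample_lengths = translator.translate(
            src_seq=src_seq, src_valid_length=src_valid_length)
        best = samples[:, 0, :].asnumpy()           # highest-scoring beam per sentence
        lengths = sample_lengths[:, 0].asnumpy()
        # Strip the BOS/EOS positions, mirroring the [1:(length - 1)] slice in the diff.
        return [[tgt_vocab.idx_to_token[idx] for idx in best[i][1:int(lengths[i]) - 1]]
                for i in range(best.shape[0])]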