This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

[FEATURE] Add transformer inference code #852

Merged 20 commits on Sep 8, 2019
Changes from 7 commits
22 changes: 22 additions & 0 deletions scripts/machine_translation/index.rst
@@ -47,3 +47,25 @@ obtain BLEU=27.05 with ``--bleu 13a``, BLEU=27.81 with ``--bleu intl``, and BLEU
The pre-trained model can be downloaded from http://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/models/transformer_en_de_512_WMT2014-e25287c5.zip.

For users in China, this mirror may be faster: https://apache-mxnet.s3.cn-north-1.amazonaws.com.cn/gluon/models/transformer_en_de_512_WMT2014-e25287c5.zip.


Use the following command to run inference with the Transformer model on the WMT14 test dataset for English-to-German translation.

.. code-block:: console

$ python inference_transformer.py --dataset WMT2014BPE \
--src_lang en \
--tgt_lang de \
--batch_size 2700 \
--num_accumulated 16 \
--scaled --average_start 5 \
--num_buckets 20 \
--bucket_scheme exp \
--bleu 13a \
--log_interval 10 \
--model_parameter PATH/TO/valid_best.params

Before running inference, you should either complete a full training run at least once to obtain the model parameters, or download the pre-trained model from http://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/models/transformer_en_de_512_WMT2014-e25287c5.zip.

For users in China, this mirror may be faster: https://apache-mxnet.s3.cn-north-1.amazonaws.com.cn/gluon/models/transformer_en_de_512_WMT2014-e25287c5.zip.
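As a minimal sketch of the download step (assuming ``wget`` and ``unzip`` are available; the parameter file name inside the archive is assumed here):

.. code-block:: console

    $ wget http://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/models/transformer_en_de_512_WMT2014-e25287c5.zip
    $ unzip transformer_en_de_512_WMT2014-e25287c5.zip -d pretrained
    # then pass the extracted parameter file via --model_parameter pretrained/valid_best.params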

316 changes: 316 additions & 0 deletions scripts/machine_translation/inference_transformer.py
@@ -0,0 +1,316 @@
"""
Transformer
=================================

This example shows how to implement the Transformer model with Gluon NLP Toolkit.

@inproceedings{vaswani2017attention,
title={Attention is all you need},
author={Vaswani, Ashish and Shazeer, Noam and Parmar, Niki and Uszkoreit, Jakob and Jones,
Llion and Gomez, Aidan N and Kaiser, Lukasz and Polosukhin, Illia},
booktitle={Advances in Neural Information Processing Systems},
pages={6000--6010},
year={2017}
}
"""

# coding: utf-8

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint:disable=redefined-outer-name,logging-format-interpolation

import argparse
import time
import random
import os
import logging
import math
import numpy as np
import mxnet as mx
from mxnet import gluon
import gluonnlp as nlp

from gluonnlp.loss import MaskedSoftmaxCELoss, LabelSmoothing
from gluonnlp.model.translation import NMTModel
from gluonnlp.model.transformer import get_transformer_encoder_decoder, ParallelTransformer
from gluonnlp.utils.parallel import Parallel
from translation import BeamSearchTranslator
from utils import logging_config
from bleu import _bpe_to_words, compute_bleu
import dataprocessor

np.random.seed(100)
random.seed(100)
mx.random.seed(10000)

parser = argparse.ArgumentParser(description='Neural Machine Translation Example. '
'We use this script only for transformer inference.')
parser.add_argument('--dataset', type=str, default='WMT2016BPE', help='Dataset to use.')
Contributor comment: Set default value to "WMT2014BPE"?

parser.add_argument('--src_lang', type=str, default='en', help='Source language')
parser.add_argument('--tgt_lang', type=str, default='de', help='Target language')
parser.add_argument('--epochs', type=int, default=10, help='upper epoch limit')
Contributor comment: Do we really need --epochs for inference mode?

parser.add_argument('--num_units', type=int, default=512, help='Dimension of the embedding '
'vectors and states.')
parser.add_argument('--hidden_size', type=int, default=2048,
help='Dimension of the hidden state in position-wise feed-forward networks.')
parser.add_argument('--dropout', type=float, default=0.1,
help='dropout applied to layers (0 = no dropout)')
parser.add_argument('--epsilon', type=float, default=0.1,
Contributor comment: Training-only parameters?

help='epsilon parameter for label smoothing')
parser.add_argument('--num_layers', type=int, default=6,
help='number of layers in the encoder and decoder')
parser.add_argument('--num_heads', type=int, default=8,
help='number of heads in multi-head attention')
parser.add_argument('--scaled', action='store_true', help='Turn on to use scale in attention')
parser.add_argument('--batch_size', type=int, default=1024,
Contributor comment: --batch_size should not be tied to the hardware back-end.

help='Batch size. Number of tokens per gpu in a minibatch')
parser.add_argument('--beam_size', type=int, default=4, help='Beam size')
parser.add_argument('--lp_alpha', type=float, default=0.6,
help='Alpha used in calculating the length penalty')
parser.add_argument('--lp_k', type=int, default=5, help='K used in calculating the length penalty')
parser.add_argument('--test_batch_size', type=int, default=256, help='Test batch size')
Contributor comment: Redundant parameter?

parser.add_argument('--num_buckets', type=int, default=10, help='Bucket number')
parser.add_argument('--bucket_scheme', type=str, default='constant',
help='Strategy for generating bucket keys. It supports: '
'"constant": all the buckets have the same width; '
'"linear": the width of bucket increases linearly; '
'"exp": the width of bucket increases exponentially')
parser.add_argument('--bucket_ratio', type=float, default=0.0, help='Ratio for increasing the '
'throughput of the bucketing')
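# Note (a sketch of the mapping, assuming dataprocessor uses the gluonnlp.data
# bucket schemes): 'constant', 'linear' and 'exp' correspond to
# ConstWidthBucket, LinearWidthBucket and ExpWidthBucket; with 'exp' the
# bucket widths grow geometrically, so rare long sentences share fewer,
# wider buckets.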
parser.add_argument('--src_max_len', type=int, default=-1, help='Maximum length of the source '
'sentence, -1 means no clipping')
parser.add_argument('--tgt_max_len', type=int, default=-1, help='Maximum length of the target '
'sentence, -1 means no clipping')
parser.add_argument('--optimizer', type=str, default='adam', help='optimization algorithm')
parser.add_argument('--lr', type=float, default=1.0, help='Initial learning rate')
parser.add_argument('--warmup_steps', type=float, default=4000,
help='number of warmup steps used in NOAM\'s stepsize schedule')
parser.add_argument('--num_accumulated', type=int, default=1,
help='Number of steps to accumulate the gradients. '
'This is useful to mimic large batch training with limited gpu memory')
parser.add_argument('--magnitude', type=float, default=3.0,
help='Magnitude of Xavier initialization')
parser.add_argument('--average_checkpoint', action='store_true',
help='Turn on to perform final testing based on '
'the average of last few checkpoints')
parser.add_argument('--num_averages', type=int, default=5,
help='Perform final testing based on the '
'average of last num_averages checkpoints. '
'This is only used if average_checkpoint is True')
parser.add_argument('--average_start', type=int, default=5,
Contributor comment: --optimizer through --average_start are all training-only parameters; please remove them from the inference script.

help='Perform average SGD on last average_start epochs')
parser.add_argument('--full', action='store_true',
help='By default, we use the test dataset in'
' http://statmt.org/wmt14/test-filtered.tgz.'
' When the option full is turned on, we use the test dataset in'
' http://statmt.org/wmt14/test-full.tgz')
parser.add_argument('--bleu', type=str, default='tweaked',
help='Schemes for computing bleu score. It can be: '
'"tweaked": it uses similar steps in get_ende_bleu.sh in tensor2tensor '
'repository, where compound words are put in ATAT format; '
'"13a": This uses official WMT tokenization and produces the same results'
' as the official script (mteval-v13a.pl) used by WMT; '
'"intl": This uses international tokenization from mteval-v14a.pl')
parser.add_argument('--log_interval', type=int, default=100, metavar='N',
help='report interval')
parser.add_argument('--save_dir', type=str, default='transformer_out',
help='directory path to save the final model and training log')
parser.add_argument('--gpus', type=str,
help='list of gpus to run, e.g. 0 or 0,2,5. empty means using cpu.'
'(using single gpu is suggested)')
parser.add_argument('--model_parameter', type=str, default='',
help='path to the model parameter file used for inference.')

args = parser.parse_args()
logging_config(args.save_dir)
logging.info(args)

# data preparation
data_train, data_val, data_test, val_tgt_sentences, test_tgt_sentences, src_vocab, tgt_vocab \
= dataprocessor.load_translation_data(dataset=args.dataset, bleu=args.bleu, args=args)

dataprocessor.write_sentences(val_tgt_sentences, os.path.join(args.save_dir, 'val_gt.txt'))
dataprocessor.write_sentences(test_tgt_sentences, os.path.join(args.save_dir, 'test_gt.txt'))

data_train = data_train.transform(lambda src, tgt: (src, tgt, len(src), len(tgt)), lazy=False)
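# each instance is tagged with its original position i so that translations
# can be restored to corpus order at the end of inference()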
data_val = gluon.data.SimpleDataset([(ele[0], ele[1], len(ele[0]), len(ele[1]), i)
for i, ele in enumerate(data_val)])
data_test = gluon.data.SimpleDataset([(ele[0], ele[1], len(ele[0]), len(ele[1]), i)
for i, ele in enumerate(data_test)])

data_train_lengths, data_val_lengths, data_test_lengths = [dataprocessor.get_data_lengths(x)
for x in
[data_train, data_val, data_test]]
detokenizer = nlp.data.SacreMosesDetokenizer()

# model preparation
ctx = [mx.cpu()] if args.gpus is None or args.gpus == '' else \
[mx.gpu(int(x)) for x in args.gpus.split(',')]
num_ctxs = len(ctx)

if args.src_max_len <= 0 or args.tgt_max_len <= 0:
max_len = np.max(
[np.max(data_train_lengths, axis=0), np.max(data_val_lengths, axis=0),
np.max(data_test_lengths, axis=0)],
axis=0)
if args.src_max_len > 0:
src_max_len = args.src_max_len
else:
src_max_len = max_len[0]
if args.tgt_max_len > 0:
tgt_max_len = args.tgt_max_len
else:
tgt_max_len = max_len[1]

encoder, decoder = get_transformer_encoder_decoder(units=args.num_units,
hidden_size=args.hidden_size,
dropout=args.dropout,
num_layers=args.num_layers,
num_heads=args.num_heads,
max_src_length=max(src_max_len, 500),
max_tgt_length=max(tgt_max_len, 500),
scaled=args.scaled)
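# the positional encodings must cover the longest sequence in the data; the
# floor of 500 positions presumably matches the shapes used during training,
# so pre-trained parameters load without shape mismatches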
model = NMTModel(src_vocab=src_vocab, tgt_vocab=tgt_vocab, encoder=encoder, decoder=decoder,
share_embed=args.dataset != 'TOY', embed_size=args.num_units,
tie_weights=args.dataset != 'TOY', embed_initializer=None, prefix='transformer_')

model.load_parameters(args.model_parameter, ctx)

static_alloc = True
model.hybridize(static_alloc=static_alloc)
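# static_alloc lets the hybridized graph cache and reuse memory across
# forward passes, speeding up repeated inference calls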
#logging.info(model)

# translator preparation
translator = BeamSearchTranslator(model=model, beam_size=args.beam_size,
scorer=nlp.model.BeamSearchScorer(alpha=args.lp_alpha,
K=args.lp_k),
max_length=200)
logging.info('Use beam_size={}, alpha={}, K={}'.format(args.beam_size, args.lp_alpha, args.lp_k))
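# BeamSearchScorer implements the GNMT-style length penalty (Wu et al., 2016):
#   score(Y) = log P(Y) / lp(Y), with lp(Y) = (K + |Y|)^alpha / (K + 1)^alpha,
# so beam search does not unduly favor short candidates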

test_loss_function = MaskedSoftmaxCELoss()
test_loss_function.hybridize(static_alloc=static_alloc)

def inference():
"""inference function."""
logging.info('Inference on the test dataset!')

# data loader preparation
_, _, test_data_loader \
= dataprocessor.make_dataloader(data_train, data_val, data_test, args,
use_average_length=True, num_shards=len(ctx))

if args.bleu == 'tweaked':
bpe = bool(args.dataset != 'IWSLT2015' and args.dataset != 'TOY')
split_compound_word = bpe
tokenized = True
elif args.bleu == '13a' or args.bleu == 'intl':
bpe = False
split_compound_word = False
tokenized = False
else:
raise NotImplementedError

translation_out = []
all_inst_ids = []
avg_loss_denom = 0
avg_loss = 0.0
total_wc = 0
total_time = 0
batch_total_bleu = 0

for batch_id, (src_seq, tgt_seq, src_test_length, tgt_test_length, inst_ids) \
in enumerate(test_data_loader):

total_wc += src_test_length.sum().asscalar() + tgt_test_length.sum().asscalar()

src_seq = src_seq.as_in_context(ctx[0])
tgt_seq = tgt_seq.as_in_context(ctx[0])
src_test_length = src_test_length.as_in_context(ctx[0])
tgt_test_length = tgt_test_length.as_in_context(ctx[0])
# calculate the teacher-forced loss: feed the target prefix tgt_seq[:, :-1]
# and score the shifted target tgt_seq[:, 1:]
out, _ = model(src_seq, tgt_seq[:, :-1], src_test_length, tgt_test_length - 1)
loss = test_loss_function(out, tgt_seq[:, 1:], tgt_test_length - 1).mean().asscalar()
all_inst_ids.extend(inst_ids.asnumpy().astype(np.int32).tolist())
avg_loss += loss * (tgt_seq.shape[1] - 1)
avg_loss_denom += (tgt_seq.shape[1] - 1)

start = time.time()
# Translate to get a bleu score
samples, _, sample_test_length = \
translator.translate(src_seq=src_seq, src_valid_length=src_test_length)
total_time += (time.time() - start)

# collect the best-scoring translation for each instance in the batch,
# stripping the leading BOS and trailing EOS tokens
max_score_sample = samples[:, 0, :].asnumpy()
sample_test_length = sample_test_length[:, 0].asnumpy()
translation_tmp = []
translation_tmp_sentences = []
for i in range(max_score_sample.shape[0]):
translation_tmp.append([tgt_vocab.idx_to_token[ele] for ele in \
max_score_sample[i][1:(sample_test_length[i] - 1)]])

# detokenize each translation result
for sentence in translation_tmp:
if args.bleu == 'tweaked':
translation_tmp_sentences.append(sentence)
translation_out.append(sentence)
elif args.bleu == '13a' or args.bleu == 'intl':
translation_tmp_sentences.append(detokenizer(_bpe_to_words(sentence)))
translation_out.append(detokenizer(_bpe_to_words(sentence)))
else:
raise NotImplementedError

# gather the reference sentences for this batch's BLEU calculation
tgt_sen_tmp = [test_tgt_sentences[index]
for index in inst_ids.asnumpy().astype(np.int32).tolist()]
batch_test_bleu_score, _, _, _, _ = compute_bleu([tgt_sen_tmp], translation_tmp_sentences,
tokenized=tokenized, tokenizer=args.bleu,
split_compound_word=split_compound_word,
bpe=bpe)
batch_total_bleu += batch_test_bleu_score

# log every ten batches
if batch_id % 10 == 0 and batch_id != 0:
batch_ave_bleu = batch_total_bleu / 10
batch_total_bleu = 0
logging.info('batch id={:d}, loss={:.4f}, batch_bleu={:.4f}'
.format(batch_id, loss, batch_ave_bleu * 100))

# restore the translations to their original corpus order using inst_ids
real_translation_out = [None for _ in range(len(all_inst_ids))]
for ind, sentence in zip(all_inst_ids, translation_out):
real_translation_out[ind] = sentence

# get bleu score, n-gram precisions, brevity penalty, reference length, and translation length
test_bleu_score, _, _, _, _ = compute_bleu([test_tgt_sentences], real_translation_out,
tokenized=tokenized, tokenizer=args.bleu,
split_compound_word=split_compound_word,
bpe=bpe)
# final logging over the whole test set
test_ave_loss = avg_loss / avg_loss_denom
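# perplexity is the exponential of the average per-token cross-entropy loss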
logging.info('Inference on the test dataset. Loss={:.4f}, \
inference ppl={:.4f}, inference bleu={:.4f}, throughput={:.4f}K wps'
.format(test_ave_loss, np.exp(test_ave_loss),
test_bleu_score * 100, total_wc / total_time / 1000))


if __name__ == '__main__':
if args.model_parameter:
inference()
else:
logging.error('Inference requested but --model_parameter was not provided!')