[FEATURE] Add transformer inference code #852
Conversation
Codecov Report
@@ Coverage Diff @@
## master #852 +/- ##
==========================================
+ Coverage 90.38% 90.47% +0.08%
==========================================
Files 66 66
Lines 6367 6405 +38
==========================================
+ Hits 5755 5795 +40
+ Misses 612 610 -2
Please update the module docstring and the argparse description to reflect that this script is only for inference. Thanks!
@pengxin99 please help resolve the comments; we need to merge this soon.
Job PR-852/3 is complete.
Thanks for your comments; the code has been modified accordingly. Please review. @leezu @ciyongch @eric-haibin-lin
Job PR-852/4 is complete.
Job PR-852/5 is complete.
Job PR-852/6 is complete.
@pengxin99 please take a look at the CI failure, a pylint/format error.
For inference, only the test dataset is required (both src_test and target_test), while the training and validation datasets are for the training phase. Loss is also useless in inference mode; BLEU/PPL is enough.
Please clean up all the training-related arguments/datasets/metrics (loss).
parser = argparse.ArgumentParser(description='Neural Machine Translation Example.'
                                             'We use this script only for transformer inference.')
parser.add_argument('--dataset', type=str, default='WMT2016BPE', help='Dataset to use.')
Set the default value to "WMT2014BPE"?
parser.add_argument('--dataset', type=str, default='WMT2016BPE', help='Dataset to use.')
parser.add_argument('--src_lang', type=str, default='en', help='Source language')
parser.add_argument('--tgt_lang', type=str, default='de', help='Target language')
parser.add_argument('--epochs', type=int, default=10, help='upper epoch limit')
Do we really need --epochs for inference mode?
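To make the cleanup the reviewers ask for concrete, here is a minimal sketch of an inference-only argument parser. The flag names mirror the snippets quoted in this thread; which flags to keep, the WMT2014BPE default, and the `build_inference_parser` helper name are assumptions, not the PR's actual code.

```python
import argparse

def build_inference_parser():
    """Sketch of an argument parser stripped down to inference-time options."""
    parser = argparse.ArgumentParser(
        description='Neural Machine Translation Example. '
                    'We use this script only for transformer inference.')
    parser.add_argument('--dataset', type=str, default='WMT2014BPE', help='Dataset to use.')
    parser.add_argument('--src_lang', type=str, default='en', help='Source language')
    parser.add_argument('--tgt_lang', type=str, default='de', help='Target language')
    parser.add_argument('--test_batch_size', type=int, default=256, help='Test batch size')
    parser.add_argument('--model_parameter', type=str, required=True,
                        help='Path to the pre-trained parameter file')
    # Training-only flags such as --epochs, --epsilon, --dropout, --optimizer,
    # and --average_start are deliberately omitted for inference.
    return parser
```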
help='Dimension of the hidden state in position-wise feed-forward networks.')
parser.add_argument('--dropout', type=float, default=0.1,
                    help='dropout applied to layers (0 = no dropout)')
parser.add_argument('--epsilon', type=float, default=0.1,
Training-only parameters?
parser.add_argument('--num_heads', type=int, default=8,
                    help='number of heads in multi-head attention')
parser.add_argument('--scaled', action='store_true', help='Turn on to use scale in attention')
parser.add_argument('--batch_size', type=int, default=1024,
--batch_size should not be tied to the hardware back end.
parser.add_argument('--lp_alpha', type=float, default=0.6,
                    help='Alpha used in calculating the length penalty')
parser.add_argument('--lp_k', type=int, default=5, help='K used in calculating the length penalty')
parser.add_argument('--test_batch_size', type=int, default=256, help='Test batch size')
Redundant parameter?
help='Perform final testing based on the '
     'average of last num_averages checkpoints. '
     'This is only used if average_checkpoint is True')
parser.add_argument('--average_start', type=int, default=5,
--optimizer through --average_start are all training-only parameters; please remove all of these from the inference script.
Job PR-852/7 is complete.
@ciyongch thanks for the review :)

@pengxin99, only the BLEU score should be fine for inference. It's ok to reuse
@pengxin99 please take a look at the failure and check whether it's related to your latest code changes.
There were some unrelated CI failures. They will go away after #875 is merged and once this PR merges or is rebased on current master.
LGTM.
Job PR-852/16 is complete.
Still some issues left with the testing: http://ci.mxnet.io/blue/organizations/jenkins/GluonNLP-py3-gpu-integration/detail/PR-852/16/pipeline#step-126-log-504
@ciyongch the test failed because the param file is not available.
mxnet.base.MXNetError: [23:00:08] src/io/local_filesys.cc:209: Check failed: allow_null: LocalFileSystem::Open "./scripts/machine_translation/transformer_en_de_u512/valid_best.params": No such file or directory
Can you make sure the checkpoint is downloaded in the test?
@eric-haibin-lin Sure, I will add this.
@eric-haibin-lin, I've added support for downloading the params file if needed. Do you have a preferred location to store this file (~387 MB)? Currently I just put it on Google Drive.
Job PR-852/17 is complete.
param_name = args.model_parameter
if not os.path.exists(param_name):
    download("https://drive.google.com/open?id=1588i6OoaL8qC0K8gI3p2iFOYY5AEuRIN", fname=param_name)
Can you add a warning that the provided file does not exist and the download will happen?
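One way to implement the requested warning is to log before falling back to the download. This is only a sketch: `ensure_params` is a hypothetical helper name, and `download_fn` stands in for the project's download utility.

```python
import logging
import os

def ensure_params(param_name, url, download_fn):
    """Warn when the provided parameter file is missing, then download it.

    `download_fn` is assumed to accept (url, fname=...), like gluon's helper.
    """
    if not os.path.exists(param_name):
        logging.warning('Provided file %s does not exist; downloading from %s',
                        param_name, url)
        download_fn(url, fname=param_name)
    return param_name
```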
@eric-haibin-lin I've created a new commit to address your comments :).
And I think it's better to host these params on S3 to keep them aligned with the other dataset/param files.
Job PR-852/18 is complete.
@eric-haibin-lin looks like CI failed to download the complete params file (
@ciyongch I can help upload the file. Just let me know where I can download the complete file, and I can share the link with you once done.
Thanks @szha :)
Job PR-852/19 is complete.
Job PR-852/20 is complete.
@szha @eric-haibin-lin The params file is updated to use the S3 link, and a sha1_hash is added to verify the file. Please take a look.
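The sha1_hash check mentioned above boils down to comparing a hex digest against the downloaded file; gluon's download helper accepts a truncated digest. A standalone sketch of that verification (the `sha1_matches` name is hypothetical):

```python
import hashlib

def sha1_matches(filename, expected_sha1):
    """Return True if the file's SHA-1 digest starts with expected_sha1.

    Accepting a prefix mirrors how truncated hashes are commonly used
    to verify downloaded artifacts.
    """
    sha1 = hashlib.sha1()
    with open(filename, 'rb') as f:
        # Read in 1 MB chunks so large parameter files don't load into memory.
        for chunk in iter(lambda: f.read(1 << 20), b''):
            sha1.update(chunk)
    return sha1.hexdigest().startswith(expected_sha1)
```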
@ciyongch @pengxin99 nice work. Thanks! |
@ciyongch @pengxin99 @eric-haibin-lin what is the difference between the trained parameters added in this PR (http://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/models/transformer_en_de_512_WMT2014-97ffd554a.zip) and the ones we used earlier, which are still linked on the website (http://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/models/transformer_en_de_512_WMT2014-e25287c5.zip)?
@leezu @eric-haibin-lin The new params 97ffd554a were introduced for testing the transformer inference script when it was added, as we didn't notice the already-available pre-trained params.
Description
Add transformer inference code to make it easy and convenient to analyze the performance of transformer inference.
@TaoLv @juliusshufan @pengzhao-intel
You can use the command below to run inference:
python inference_transformer.py --dataset WMT2014BPE --src_lang en --tgt_lang de --batch_size 2700 --scaled --average_start 5 --num_buckets 20 --bucket_scheme exp --bleu 13a --model_parameter PATH/TO/valid_best.params
which will produce output:
Checklist
Essentials
Changes
Comments