BiasDropoutFusion #4167

Merged
SherlockNoMad merged 23 commits into master from bahuang/bias_dropout
Jun 30, 2020
@SherlockNoMad SherlockNoMad commented Jun 9, 2020

Fuse Add + Dropout + Add into a single BiasDropout op.
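
For readers who want the fused computation spelled out, here is a minimal CPU sketch. It is not the kernel from this PR. It assumes the bias is broadcast along the last dimension, that kept elements are scaled by 1 / (1 - ratio), and that the caller supplies the mask instead of drawing it from a random generator. The name `BiasDropoutReference` is hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference for the fused pattern:
//   Y = Dropout(X + bias) + residual
// Assumptions (not taken from the PR): bias is broadcast along the last
// dimension of X; residual is either empty or the same size as X;
// mask[i] == true means element i is kept.
std::vector<float> BiasDropoutReference(const std::vector<float>& x,
                                        const std::vector<float>& bias,
                                        const std::vector<float>& residual,
                                        const std::vector<bool>& mask,
                                        float ratio) {
  assert(!bias.empty() && x.size() % bias.size() == 0);
  assert(mask.size() == x.size());
  assert(residual.empty() || residual.size() == x.size());
  assert(ratio >= 0.0f && ratio < 1.0f);

  const float scale = 1.0f / (1.0f - ratio);
  std::vector<float> y(x.size());
  for (std::size_t i = 0; i < x.size(); ++i) {
    const float biased = x[i] + bias[i % bias.size()];        // first Add
    const float dropped = mask[i] ? biased * scale : 0.0f;    // Dropout
    y[i] = residual.empty() ? dropped : dropped + residual[i];  // optional second Add
  }
  return y;
}
```

Passing an empty `residual` covers the case where the second Add is absent.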

E2E test passed: https://aiinfra.visualstudio.com/Lotus/_build/results?buildId=127697&view=results

Kernel Benchmark

nvprof --print-gpu-summary ./onnxruntime_training_bert --model_name /bert_ort/bert_models/nv/bert-large/bert-large-uncased_L_24_H_1024_A_16_V_30528_S_512_Dp_0.1_optimized_layer_norm --train_data_dir /bert_data/128/books_wiki_en_corpus/train --test_data_dir /bert_data/128/books_wiki_en_corpus/test --train_batch_size 32 --mode train --num_train_steps 100 --display_loss_steps 1 --warmup_ratio=0.2843 --warmup_mode=Poly --optimizer lamb --gradient_accumulation_steps 1 --max_predictions_per_seq=20 --use_nccl  --use_mixed_precision --allreduce_in_fp16 

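Before

| Time (%) | Time | Calls | Kernel |
| --- | --- | --- | --- |
| 2.12% | 422.14ms | 7300 | DropoutKernel |
| 1.57% | 312.42ms | 7300 | DropoutGradientKernel |

After

| Time (%) | Time | Calls | Kernel |
| --- | --- | --- | --- |
| 1.43% | 280.37ms | 7300 | DropoutGradientKernel |
| 1.39% | 272.51ms | 4800 | BiasDropoutKernel |
| 1.05% | 206.41ms | 2500 | DropoutKernel |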
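
The nvprof figures above give total time and call count per kernel, so the average cost per launch is their ratio. A small helper for that arithmetic (the function name is hypothetical):

```cpp
#include <cassert>

// Average time per kernel launch in microseconds, computed from a total in
// milliseconds and a launch count as reported by nvprof.
double AvgMicrosPerCall(double total_ms, int calls) {
  assert(calls > 0);
  return total_ms * 1000.0 / calls;
}
```

For example, `AvgMicrosPerCall(422.14, 7300)` is about 57.8 µs for DropoutKernel before the change, and `AvgMicrosPerCall(272.51, 4800)` is about 56.8 µs for BiasDropoutKernel after it.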

BERT-L run Benchmark

./onnxruntime_training_bert --model_name /bert_ort/bert_models/nv/bert-large/bert-large-uncased_L_24_H_1024_A_16_V_30528_S_512_Dp_0.1_optimized_layer_norm --train_data_dir /bert_data/128/books_wiki_en_corpus/train --test_data_dir /bert_data/128/books_wiki_en_corpus/test --train_batch_size 64 --mode train --num_train_steps 228 --display_loss_steps 1 --warmup_ratio=0.2843 --warmup_mode=Poly --optimizer lamb --gradient_accumulation_steps 1 --max_predictions_per_seq=20 --use_nccl  --use_mixed_precision --allreduce_in_fp16 

Before
Stabilized Throughput: 180.876 Examples / Second

After
Stabilized Throughput: 182.82 Examples / Second

Gain: 1.07%
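
The gain figure is the relative change in stabilized throughput. As a quick check of the arithmetic (helper name is hypothetical):

```cpp
#include <cassert>

// Relative improvement of `after` over `before`, in percent.
double RelativeGainPercent(double before, double after) {
  assert(before > 0.0);
  return (after - before) / before * 100.0;
}
```

`RelativeGainPercent(180.876, 182.82)` evaluates to roughly 1.07.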

@SherlockNoMad SherlockNoMad requested a review from a team as a code owner June 9, 2020 05:44
@SherlockNoMad SherlockNoMad added the training label (issues related to ONNX Runtime training; typically submitted using template) Jun 10, 2020
Dropout kernel for residual input

BiasDropout Fusion to take residual input

Fix BiasDropout Kernel

Optimize DropoutGrad with 4 elements per thread
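
The "4 elements per thread" idea in the last commit can be illustrated with a host-side emulation of the indexing. This is a sketch under assumptions, not the kernel from the PR: the block size of 256 and the strided access pattern are guesses, and `DropoutGradEmulated` is a hypothetical name. It computes dX[i] = dY[i] * mask[i] * scale.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical launch parameters; the values used in the PR are not shown here.
constexpr std::size_t kThreadsPerBlock = 256;
constexpr std::size_t kElementsPerThread = 4;

// CPU emulation of a dropout-gradient kernel in which each thread handles
// kElementsPerThread elements instead of one.
std::vector<float> DropoutGradEmulated(const std::vector<float>& dy,
                                       const std::vector<bool>& mask,
                                       float scale) {
  assert(dy.size() == mask.size());
  const std::size_t n = dy.size();
  const std::size_t per_block = kThreadsPerBlock * kElementsPerThread;
  const std::size_t num_blocks = (n + per_block - 1) / per_block;

  std::vector<float> dx(n, 0.0f);
  for (std::size_t block = 0; block < num_blocks; ++block) {            // blockIdx.x
    for (std::size_t thread = 0; thread < kThreadsPerBlock; ++thread) {  // threadIdx.x
      std::size_t id = block * per_block + thread;
      for (std::size_t k = 0; k < kElementsPerThread; ++k) {
        if (id < n) {  // guard the tail when n is not a multiple of per_block
          dx[id] = dy[id] * (mask[id] ? 1.0f : 0.0f) * scale;
        }
        id += kThreadsPerBlock;  // adjacent threads touch adjacent addresses
      }
    }
  }
  return dx;
}
```

With this layout a grid needs a quarter as many threads as a one-element-per-thread launch over the same input.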
@SherlockNoMad SherlockNoMad force-pushed the bahuang/bias_dropout branch from 896551e to b6da6ac Compare June 15, 2020 17:18
@SherlockNoMad SherlockNoMad changed the title [Draft] BiasDropoutFusion BiasDropoutFusion Jun 16, 2020
@SherlockNoMad SherlockNoMad requested a review from edgchen1 June 16, 2020 07:43
Comment thread onnxruntime/core/graph/graph_utils.cc Outdated
Comment thread onnxruntime/test/testdata/transform/fusion/bias_dropout_residual_gen.py Outdated
Comment thread orttraining/orttraining/core/optimizer/bias_dropout_fusion.cc Outdated
Comment thread orttraining/orttraining/core/optimizer/bias_dropout_fusion.cc Outdated
Comment thread orttraining/orttraining/core/optimizer/bias_dropout_fusion.cc Outdated
Comment thread orttraining/orttraining/core/optimizer/bias_dropout_fusion.cc Outdated
Comment thread orttraining/orttraining/core/optimizer/bias_dropout_fusion.cc Outdated
Comment thread orttraining/orttraining/core/optimizer/bias_dropout_fusion.cc
Comment thread orttraining/orttraining/test/optimizer/graph_transform_test.cc Outdated
Comment thread orttraining/orttraining/test/optimizer/graph_transform_test.cc Outdated
Comment thread orttraining/orttraining/test/optimizer/graph_transform_test.cc Outdated
Comment thread orttraining/orttraining/test/optimizer/graph_transform_test.cc Outdated
Comment thread orttraining/orttraining/test/training_ops/cpu/nn/dropout_op_test.cc Outdated
Comment thread orttraining/orttraining/test/training_ops/cpu/nn/dropout_op_test.cc Outdated
Comment thread orttraining/orttraining/test/training_ops/cpu/nn/dropout_op_test.cc Outdated
Comment thread orttraining/orttraining/test/training_ops/cpu/nn/dropout_op_test.cc Outdated
Comment thread orttraining/orttraining/test/training_ops/cpu/nn/dropout_op_test.cc Outdated
Comment thread orttraining/orttraining/test/training_ops/cpu/nn/dropout_op_test.cc Outdated
Comment thread orttraining/orttraining/test/training_ops/cpu/nn/dropout_op_test.cc Outdated
Comment thread orttraining/orttraining/test/training_ops/cpu/nn/dropout_op_test.cc Outdated
Comment thread orttraining/orttraining/training_ops/cuda/nn/dropout.cc Outdated
Comment thread orttraining/orttraining/test/training_ops/cpu/nn/dropout_op_test.cc Outdated
Comment thread orttraining/orttraining/training_ops/cuda/nn/dropout.cc Outdated
Comment thread orttraining/orttraining/training_ops/cuda/nn/dropout.cc Outdated
Comment thread orttraining/orttraining/training_ops/cuda/nn/dropout.h Outdated
Comment thread onnxruntime/core/providers/cuda/nn/dropout.h Outdated
edgchen1 previously approved these changes Jun 24, 2020
Comment thread orttraining/orttraining/training_ops/cuda/nn/dropout_impl.cu
Comment thread orttraining/orttraining/core/optimizer/bias_dropout_fusion.cc
edgchen1 previously approved these changes Jun 25, 2020
  if (li < N) {
    mask_data[li] = (&rand.x)[i] < p;
-   Y_data[li] = X_data[li] * T(mask_data[li]) * scale;
+   Y_data[li] = T(float(X_data[li]) * mask_data[li] * scale);

Should we do the math in float even when T is double?
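
To make the concern concrete, here is a host-side sketch comparing the two forms of the expression. The function names are hypothetical, and the sketch does not claim which form the merged kernel uses. It only shows that routing a double through float loses precision.

```cpp
#include <cassert>

// Mirrors the expression under review: the product is computed in float and
// the result is then converted to T.
template <typename T>
T DropoutViaFloat(T x, bool keep, float scale) {
  return T(float(x) * keep * scale);
}

// Alternative that keeps the arithmetic in T.
template <typename T>
T DropoutInT(T x, bool keep, float scale) {
  return x * T(keep) * T(scale);
}
```

For `T = double` and `x = 0.1`, `DropoutViaFloat` returns the double nearest to `0.1f` (about 0.10000000149), while `DropoutInT` returns `0.1` unchanged. For `T = float` the two agree.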

@SherlockNoMad SherlockNoMad merged commit 6365760 into master Jun 30, 2020
@SherlockNoMad SherlockNoMad deleted the bahuang/bias_dropout branch June 30, 2020 22:43
Fuse Add + Dropout + optional Add to BiasDropoutFusion

*/
class BiasDropoutFusion : public GraphTransformer {

Why don't we make this a RewriteRule? From a quick skim, the fusion looks quite local, so we could avoid traversing the whole graph and calling Resolve.
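
The locality the reviewer points to can be sketched with a toy graph. Everything here is hypothetical: `ToyNode`, `ToyMatch`, and `MatchBiasDropout` are illustration-only names and say nothing about the actual ONNX Runtime `RewriteRule` or `GraphTransformer` interfaces. The sketch assumes each node has at most one consumer.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Toy stand-in for a graph node; this is not the ONNX Runtime Node class.
struct ToyNode {
  std::string op_type;
  int consumer;  // index of the single consumer node, or -1 if none
};

// Result of a local match that starts at one Add node.
struct ToyMatch {
  bool matched = false;
  bool has_residual_add = false;
};

// Checks Add -> Dropout -> (optional Add) by following consumer links from
// `start` only; no other part of the graph is visited.
ToyMatch MatchBiasDropout(const std::vector<ToyNode>& nodes, std::size_t start) {
  ToyMatch result;
  assert(start < nodes.size());
  if (nodes[start].op_type != "Add") return result;

  const int dropout = nodes[start].consumer;
  if (dropout < 0) return result;
  const std::size_t d = static_cast<std::size_t>(dropout);
  assert(d < nodes.size());
  if (nodes[d].op_type != "Dropout") return result;
  result.matched = true;

  const int next = nodes[d].consumer;
  if (next >= 0) {
    const std::size_t r = static_cast<std::size_t>(next);
    assert(r < nodes.size());
    result.has_residual_add = nodes[r].op_type == "Add";
  }
  return result;
}
```

The match inspects at most three nodes reachable from the starting Add, which is the sense in which the pattern is local.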


5 participants