BiasDropoutFusion #4167
Merged
Merged
Conversation
Dropout kernel for residual input BiasDropout Fusion to take residual input Fix BiasDropout Kernel Optimize DropoutGrad with 4 elements per thread
896551e to
b6da6ac
Compare
edgchen1
reviewed
Jun 16, 2020
edgchen1
reviewed
Jun 16, 2020
edgchen1
reviewed
Jun 16, 2020
edgchen1
reviewed
Jun 16, 2020
edgchen1
reviewed
Jun 16, 2020
edgchen1
reviewed
Jun 16, 2020
edgchen1
reviewed
Jun 16, 2020
edgchen1
reviewed
Jun 16, 2020
edgchen1
reviewed
Jun 16, 2020
edgchen1
reviewed
Jun 16, 2020
edgchen1
reviewed
Jun 16, 2020
edgchen1
reviewed
Jun 16, 2020
edgchen1
reviewed
Jun 16, 2020
edgchen1
reviewed
Jun 16, 2020
edgchen1
reviewed
Jun 16, 2020
edgchen1
reviewed
Jun 16, 2020
edgchen1
reviewed
Jun 16, 2020
edgchen1
reviewed
Jun 16, 2020
edgchen1
reviewed
Jun 16, 2020
edgchen1
reviewed
Jun 16, 2020
edgchen1
reviewed
Jun 16, 2020
edgchen1
reviewed
Jun 22, 2020
added 7 commits
June 22, 2020 23:19
edgchen1
reviewed
Jun 24, 2020
edgchen1
reviewed
Jun 24, 2020
edgchen1
reviewed
Jun 24, 2020
edgchen1
reviewed
Jun 24, 2020
edgchen1
previously approved these changes
Jun 24, 2020
ytaous
reviewed
Jun 25, 2020
edgchen1
previously approved these changes
Jun 25, 2020
edgchen1
reviewed
Jun 30, 2020
| if (li < N) { | ||
| mask_data[li] = (&rand.x)[i] < p; | ||
| Y_data[li] = X_data[li] * T(mask_data[li]) * scale; | ||
| Y_data[li] = T(float(X_data[li]) * mask_data[li] * scale); |
Contributor
There was a problem hiding this comment.
should we do the math with floats even when T is double?
edgchen1
approved these changes
Jun 30, 2020
kkaranasos
reviewed
Jun 30, 2020
| Fuse Add + Dropout + optional Add to BiasDropoutFusion | ||
|
|
||
| */ | ||
| class BiasDropoutFusion : public GraphTransformer { |
Contributor
There was a problem hiding this comment.
Why don't we make this a RewriteRule? From a quick skim, it seems that the fusion is quite local, so we can avoid traversing the whole tree and call Resolve.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Fuse Add + Dropout + Add into a single BiasDropout op.
E2E test passed: https://aiinfra.visualstudio.com/Lotus/_build/results?buildId=127697&view=results
Kernel Benchmark
Before
2.12% 422.14ms 7300 DropoutKernel
1.57% 312.42ms 7300 DropoutGradientKernel
After
1.43% 280.37ms 7300 DropoutGradientKernel
1.39% 272.51ms 4800 BiasDropoutKernel
1.05% 206.41ms 2500 DropoutKernel
BERT-L run Benchmark
Before
Stabilized Throughput: 180.876 Examples / Second
After
Stabilized Throughput: 182.82 Examples / Second
Gain: 1.07%