
[feat] Dropout(Activation(x+bias)), now with partial BW fusion #144

Closed
blefaudeux wants to merge 7 commits into main from dropout_bw_fusion

Conversation

blefaudeux
Contributor

@blefaudeux blefaudeux commented Dec 9, 2021

What does this PR do?

This was a long time in the making: fusing the BW part of the activation/bias/dropout kernel. It's not quite perfect, but in some places the speedup is dramatic (x3 or x4 over the naive calls).
Fusing this meant flipping the whole problem upside down: the seeds have to be per column, and the kernels (FW and BW) work that way too. This lets us fuse the bias gradient computation, since it is a sum over that same direction.
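For reference, a minimal PyTorch sketch of the semantics being fused (the function name and the eager implementation below are illustrative, not the Triton kernel itself):

```python
import torch

# Reference semantics of the fused op: dropout(activation(x + bias)),
# with x of shape (M, N) and bias of shape (N,).
def fused_dropout_reference(x, bias, p, activation=torch.nn.functional.gelu):
    y = activation(x + bias)
    keep = (torch.rand_like(y) > p).to(y.dtype)
    return y * keep / (1.0 - p)

# In the BW pass the bias gradient is a reduction over the rows:
#   grad_bias = grad_act.sum(dim=0)
# i.e. a sum along the same direction the per-column seeds run, which is
# what makes it possible to fold this reduction into the backward kernel.
```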

TODO:

  • add more unit tests to check that the dropout probability is respected on average
  • possibly make sure that the random mask does not repeat (may or may not be a big deal); this is doable by making the kernels cooperate on the same column, as Phil does for LayerNorm
  • improve the scheduling for small buffers
  • fix the atomic-add funkiness (it works for now, but it does not look completely right and is num_warps dependent)

Before submitting

  • Did you have fun?
    • Make sure you had fun coding 🙃
  • Did you read the contributor guideline?
  • Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
    • N/A
  • Did you make sure to update the docs?
    • N/A
  • Did you write any new necessary tests?
    • N/A
  • Did you update the changelog? (if needed)
    • N/A

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Dec 9, 2021
@blefaudeux blefaudeux changed the base branch from main to bench_triton_sum December 9, 2021 21:27
@blefaudeux blefaudeux marked this pull request as draft December 9, 2021 21:28
@blefaudeux blefaudeux force-pushed the dropout_bw_fusion branch 3 times, most recently from a18bda0 to 147e2c9 Compare December 10, 2021 00:04
@blefaudeux blefaudeux changed the base branch from bench_triton_sum to main December 10, 2021 00:05
@blefaudeux blefaudeux force-pushed the dropout_bw_fusion branch 6 times, most recently from c7ab5b0 to 8dea61d Compare December 10, 2021 06:20
@blefaudeux
Contributor Author

blefaudeux commented Dec 10, 2021

[Plot: Dropout_Bias_False_FW+BW_torch float16_Act:_gelu]

[Plot: Dropout_Bias_True_FW+BW_torch float16_Act:_squared_relu]

[Plot: Dropout_Bias_True_FW_torch float16_Act:_squared_relu]

Interested @suchenzang? This took a while to get right. Some speed for small tensors should be recoverable; I didn't play with the settings too much.

edit: old plots, see below for up to date numbers

@blefaudeux blefaudeux changed the title from [DRAFT] Dropout BW fusion to [feat] Dropout(Activation(x+bias)), now with BW fusion Dec 10, 2021
@blefaudeux blefaudeux requested review from dianaml0, jieru-hu and fmassa and removed request for dianaml0 December 10, 2021 06:24
@blefaudeux blefaudeux marked this pull request as ready for review December 10, 2021 06:24
@suchenzang

Oh man, coming to xFormers to shop for parts is great. @blefaudeux these numbers are on V100s or A100s?

@blefaudeux
Contributor Author

Oh man, coming to xFormers to shop for parts is great. @blefaudeux these numbers are on V100s or A100s?

ahah, I missed you @suchenzang! This is on an Ampere laptop, working with what I have around. 400GB/s is the max bandwidth, so there's not much left to win on the inference side. The reported training number is not exact GB-wise (some operations in the middle are not counted), and better scheduling could well bring back another 10-20%; these are very raw numbers. It should be at parity with PyTorch accuracy-wise, no compromise there, just fusing the kernels. I'll also pull in @Chillee, since he has an NVFuser solution for the same part that you could test out (we were wondering whether xformers could host that as well); he was actually the incentive for me to revisit this (he's got very good numbers!)
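For context, a rough roofline estimate of that memory-bound floor (a sketch with a hypothetical shape; the 400GB/s figure is the bandwidth mentioned above):

```python
# Lower bound on the FW time of a memory-bound dropout(activation(x + bias)):
# roughly read x once and write y once; the bias read is negligible.
batch, seq, dim = 32, 1024, 1024           # hypothetical tensor shape
bytes_per_elem = 2                         # fp16
bytes_moved = 2 * batch * seq * dim * bytes_per_elem
bandwidth = 400e9                          # bytes/s, the laptop GPU figure above
print(f"lower bound ~ {bytes_moved / bandwidth * 1e6:.0f} µs")
```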

@blefaudeux
Contributor Author

blefaudeux commented Dec 10, 2021

[Plot: Dropout_Bias_True_FW_torch float32_Act:_None]
The scheduling probably still has some margin for improvement, since fp32 saturates the bandwidth better than fp16 does. Anyway, that's the easy part.

edit: old plot, see below for up to date numbers

@blefaudeux
Contributor Author

cc @min-xu-ai, in case it helps to have a look at a Triton kernel for Mevo

@blefaudeux
Contributor Author

I'll clean up the PR, sorry for all the extra changes

@blefaudeux
Contributor Author

I'll clean up the PR, sorry for all the extra changes

I just pushed a cleaned up version, should be better

```diff
@@ -25,42 +25,48 @@
 class _dropout(torch.autograd.Function):
     @staticmethod
     @custom_fwd(cast_inputs=torch.float16)
-    def forward(ctx, x, p, bias, activation, activation_grad):
+    def forward(ctx, x, p, bias, activation, activation_grad, trainable_bias):
```
Contributor Author


The trainable (or frozen) bias was not properly handled before: the bias was always assumed to be trainable, which is mostly true but not always.
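A minimal sketch of how the flag can be threaded through (eager PyTorch stands in for the fused Triton kernels, and the activation arguments are dropped for brevity):

```python
import torch

class _dropout_sketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, p, bias, trainable_bias):
        y = x + bias if bias is not None else x
        # pre-scaled keep mask, saved for the backward pass
        mask = (torch.rand_like(y) > p).to(y.dtype) / (1.0 - p)
        ctx.save_for_backward(mask)
        ctx.trainable_bias = trainable_bias and bias is not None
        return y * mask

    @staticmethod
    def backward(ctx, grad_out):
        (mask,) = ctx.saved_tensors
        grad_x = grad_out * mask
        # the bias gradient is a column-wise sum, skipped when the bias is frozen
        grad_bias = grad_x.sum(dim=0) if ctx.trainable_bias else None
        # one gradient slot per forward input: x, p, bias, trainable_bias
        return grad_x, None, grad_bias, None
```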

using less seeds

tiling + vertical seeds

Computing the FW and BW per tile over M

workaround atomics, reintroduce locks, should not be too bad

yet another take, partial sum

better scheduling defaults, improves across the board

giving atomic add a go

back to the locks, completely fuse the BW
@blefaudeux
Contributor Author

I probably spent too long on this; it was a side project and I kept having to get back into the proper context (plus a couple of small hiccups with Triton along the way). Performance is now good enough, I think: there's one case which is not faster than PyTorch (really small buffers + gelu + fp16), everything else is significantly faster, and the FW speed almost doubled with this PR. I updated all the graphs; keep in mind when comparing that the previous ones were on a V100 (HBM memory, roughly 900GB/s max bandwidth) and the new ones are on a 3080 laptop (GDDR6 memory, roughly 440GB/s max bandwidth). I think this could be revisited with a newer Triton, or with NVFuser/functorch (if it's not too much work a PR would be great @Chillee; if the speed is there out of the box I would definitely take it!)

Up for review; next up is #153, probably much more impactful if it proves doable.

@blefaudeux blefaudeux marked this pull request as ready for review December 18, 2021 07:09
@blefaudeux
Contributor Author

blefaudeux commented Dec 18, 2021

Testing a training run with microGPT, it looks like something is wrong: the loss plateaus.
edit: checked, the stats on the dropout percentage are wrong; randint4x is not returning what I thought it was. Fixing that.

@blefaudeux blefaudeux marked this pull request as draft December 18, 2021 07:54
@blefaudeux
Contributor Author

Testing a training run with microGPT, it looks like something is wrong: the loss plateaus. edit: checked, the stats on the dropout percentage are wrong; randint4x is not returning what I thought it was. Fixing that.

Fixed, with no impact on perf. I'm checking with loss curves right now that everything is fine, but it should be. I've also improved the unit test to catch this: basically p=0.5 was correct but p=0.1 was not.
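For the record, a sketch of that kind of statistical check (`fused_dropout` and the tolerance are placeholders, not the actual test in this PR):

```python
import torch
import pytest

@pytest.mark.parametrize("p", [0.1, 0.3, 0.5, 0.7])
def test_dropout_rate(p):
    # `fused_dropout` is a placeholder for the kernel under test
    x = torch.ones(4096, 1024, device="cuda", dtype=torch.float16)
    y = fused_dropout(x, p=p)
    drop_rate = (y == 0).float().mean().item()
    # the observed drop rate has to track p itself, not just be right at p=0.5
    assert abs(drop_rate - p) < 0.02
```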

@blefaudeux blefaudeux marked this pull request as ready for review December 18, 2021 23:12
@blefaudeux
Contributor Author

Screenshot from 2021-12-18 15-39-03
good, now training properly

@suchenzang

Screenshot from 2021-12-18 15-39-03

good, now training properly

Do you have a baseline plot to compare against?

@blefaudeux
Contributor Author

blefaudeux commented Dec 19, 2021

Screenshot from 2021-12-18 15-39-03
good, now training properly

Do you have a baseline plot to compare against?

This is against the previous main (blue, still FusedMLP) and against MLP (red, pure PyTorch). The scale is log, which emphasizes the small differences; otherwise you cannot distinguish one from the other. The end test (sampled text) looked good. In both fused cases there's a small but measurable difference with pure PyTorch though; I'm not sure why, except that maybe the AMP execution is not the same (in the FusedMLP case inputs are cast to fp16 and stay that way over the layer, while in the non-fused PyTorch case maybe not everything runs in fp16).

Screenshot from 2021-12-18 20-17-00
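One possible sketch of that difference (an assumption, not a confirmed diagnosis): the fused path pins its inputs to fp16 via `custom_fwd(cast_inputs=torch.float16)` as in the diff above, while under native autocast the unfused MLP only casts op by op, so some intermediates can stay in fp32.

```python
import torch
from torch.cuda.amp import custom_fwd, custom_bwd

class _pinned_to_fp16(torch.autograd.Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)  # inputs are cast once and stay fp16
    def forward(ctx, x):
        return x * 1.0  # placeholder body, stands in for the fused kernel

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_out):
        return grad_out

# By contrast, under torch.autocast the eager MLP is cast per-op, so parts of
# the graph that autocast keeps in fp32 never go through a single fp16 cast.
```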

@suchenzang

suchenzang commented Dec 19, 2021

So the only mildly worrisome thing here is seeing some divergence as training progresses. It seems like that delta shows up a bit vs. the blue. If the two converge to different points, then we have a problem :(

@blefaudeux
Contributor Author

blefaudeux commented Dec 19, 2021

So the only mildly worrisome thing here is seeing some divergence as training progresses. It seems like that delta shows up a bit vs. the blue. If the two converge to different points, then we have a problem :(

Totally, and it's very strange, because there's no shortcut in the implementation; it's supposed to give the same results. I've tried AMP/fp16, it's not that, same results. It could just be about the seed, I'll check that next.

@blefaudeux
Contributor Author

So the only mildly worrisome thing here is seeing some divergence as training progresses. It seems like that delta shows up a bit vs. the blue. If the two converge to different points, then we have a problem :(

Totally, and it's very strange, because there's no shortcut in the implementation; it's supposed to give the same results. I've tried AMP/fp16, it's not that, same results. It could just be about the seed, I'll check that next.

Hmm, testing different seeds does change the result a bit, but still. With this PR we're using fewer seeds for the random number generation and generating more numbers from each of them; off the top of my head that's the only measurable difference with the previous take, the rest is just a different kernel architecture.

@blefaudeux blefaudeux marked this pull request as draft December 19, 2021 17:57
@blefaudeux
Contributor Author

Discussing with Phil: it could be that the RNG is not good enough, checking that next.

@blefaudeux
Contributor Author

Discussing with Phil: it could be that the RNG is not good enough, checking that next.

That could well be the reason; the fused dropout on main is fine. I checked whether there was a pattern in the dropout, but nothing obvious shows up on a heatmap. Dropping this for now; generating fewer seeds may well have been the main reason for the very fast FW, but it cannot come at the expense of training accuracy.
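For reference, a sketch of that heatmap check (`fused_dropout` is again a placeholder for the kernel under test):

```python
import torch
import matplotlib.pyplot as plt

# Average many dropout masks and look for structure (stripes, tiles) that
# would betray a weak or repeating RNG; a flat ~0.5 field means no obvious pattern.
x = torch.ones(256, 256, device="cuda", dtype=torch.float16)
acc = torch.zeros(256, 256, device="cuda")
for _ in range(1000):
    acc += (fused_dropout(x, p=0.5) == 0).float()
plt.imshow((acc / 1000).cpu(), vmin=0.4, vmax=0.6)
plt.colorbar()
plt.savefig("dropout_heatmap.png")
```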

@blefaudeux blefaudeux closed this Dec 20, 2021
xwhan pushed a commit to xwhan/xformers that referenced this pull request Feb 8, 2022