ops.swiglu: Tests, baseline, benchmarks #487
Conversation
Thanks for the PR, Daniel!
I have a few comments. In particular, I wouldn't add this nn.Module or the torch.autograd.Function in xformers/ops, and would instead just implement a vanilla function in tests / benchmarks like the following:
def swiglu(...):
    x1 = F.linear(...)
    x2 = F.linear(...)
    x3 = F.silu(x1)
    x4 = x3 * x2
    return F.linear(...)
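Filled in, that sketch could look like the following (parameter names are hypothetical; the weights follow F.linear's (out_features, in_features) layout):

import torch.nn.functional as F

def swiglu(x, w1, b1, w2, b2, w3, b3):
    x1 = F.linear(x, w1, b1)   # first projection
    x2 = F.linear(x, w2, b2)   # gate projection
    x3 = F.silu(x1)            # x1 * sigmoid(x1)
    x4 = x3 * x2               # gated hidden state
    return F.linear(x4, w3, b3)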
And for the tests, I would have a wrapper nn.Module that just calls into this swiglu.
Also, I wouldn't write the backward by hand for the reference implementation, but would instead use the autograd implementation as the reference. This way we are sure we didn't introduce a bug in our autograd.
On second look, you might have implemented the backward by hand in PyTorch to get an idea of the numerics, which are expected to differ for this operation. Is that right?
Thoughts?
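As an illustration of leaning on autograd for the reference, torch.autograd.gradcheck can validate the plain implementation's gradients against finite differences. A minimal sketch, with made-up shapes and names:

import torch
import torch.nn.functional as F

def swiglu_reference(x, w1, b1, w2, b2, w3, b3):
    # Plain PyTorch ops only; autograd derives the backward automatically.
    return F.linear(F.silu(F.linear(x, w1, b1)) * F.linear(x, w2, b2), w3, b3)

torch.manual_seed(0)
make = lambda *shape: torch.randn(*shape, dtype=torch.float64, requires_grad=True)
x = make(4, 8)
weights = (make(16, 8), make(16),   # w1, b1
           make(16, 8), make(16),   # w2, b2
           make(8, 16), make(8))    # w3, b3

# Compares autograd's analytic gradients against numerical ones.
torch.autograd.gradcheck(swiglu_reference, (x, *weights))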
xformers/ops/swiglu.py (Outdated)
    return (dy.float() * sigm * (1 + x.float() * (1 - sigm))).to(x.dtype)


class SwiGLU_Decomposed(torch.autograd.Function):
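For reference, the return line above implements the SiLU derivative: with σ(x) the logistic sigmoid and silu(x) = x·σ(x), the product rule gives silu′(x) = σ(x)·(1 + x·(1 − σ(x))), which is exactly dy * sigm * (1 + x * (1 - sigm)), computed in float32 before casting back to the input dtype.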
nit: snake_case and CamelCase mix?
# 952us
@classmethod
def forward(cls, ctx, x, w1, b1, w2, b2, w3, b3):
    x1 = x @ w1.transpose(-2, -1) + b1  # 275us
Note that you can also call torch.nn.functional.linear here, so that you measure the performance of PyTorch's own op and the timings match what users would actually get.
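Concretely, that would replace the matmul-plus-bias line with the equivalent call (assuming w1 uses F.linear's (out_features, in_features) layout):

    x1 = F.linear(x, w1, b1)  # same result as x @ w1.transpose(-2, -1) + b1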
xformers/ops/swiglu.py (Outdated)
from torch import nn


class SwiGLUFFN_Reference(nn.Module):
I wonder why we provide this implementation. Can't we just test the numerics with a functional API?
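A functional numerics test could then look roughly like the sketch below; custom_swiglu stands in for the op under test, and swiglu_reference for the plain implementation sketched earlier (names, shapes, and tolerances are all hypothetical):

import torch
import torch.nn.functional as F

def swiglu_reference(x, w1, b1, w2, b2, w3, b3):
    return F.linear(F.silu(F.linear(x, w1, b1)) * F.linear(x, w2, b2), w3, b3)

def check_swiglu_numerics(custom_swiglu):
    torch.manual_seed(0)
    x = torch.randn(4, 8)
    weights = (torch.randn(16, 8), torch.randn(16),   # w1, b1
               torch.randn(16, 8), torch.randn(16),   # w2, b2
               torch.randn(8, 16), torch.randn(8))    # w3, b3
    expected = swiglu_reference(x, *weights)
    actual = custom_swiglu(x, *weights)
    torch.testing.assert_close(actual, expected, rtol=1e-4, atol=1e-4)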
Thanks!
Good to merge; I have a few suggestions for the future.
    return (dx, dw1, db1, dw2, db2, dw3, db3)


def functional_swiglu(
Suggested change:
- def functional_swiglu(
+ def _swiglu_reference(
?
    x1 = self.w1(x)
    x2 = self.w2(x)
    hidden = F.silu(x1) * x2
    return self.w3(hidden)
Maybe reuse the function _swiglu_reference in here?
Suggested change:
- x1 = self.w1(x)
- x2 = self.w2(x)
- hidden = F.silu(x1) * x2
- return self.w3(hidden)
+ return _swiglu_reference(x, self.w1.weight, self.w1.bias, self.w2.weight, self.w2.bias, self.w3.weight, self.w3.bias)
from torch import nn


class _SwiGLUModule(nn.Module):
If we want to expose this to the user through other ops, we can make this public as well and expose an op argument.
Suggested change:
- class _SwiGLUModule(nn.Module):
+ class _SwiGLU(nn.Module):
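If that direction is ever taken, one possible shape for it is sketched below (an entirely hypothetical API, not what the PR implements; op=None falls back to plain PyTorch):

import torch.nn.functional as F
from torch import nn

class SwiGLU(nn.Module):
    # SwiGLU feed-forward block; `op` selects an optional fused backend.
    def __init__(self, in_features, hidden_features, op=None):
        super().__init__()
        self.w1 = nn.Linear(in_features, hidden_features)
        self.w2 = nn.Linear(in_features, hidden_features)
        self.w3 = nn.Linear(hidden_features, in_features)
        self.op = op  # e.g. a fused kernel; None means plain PyTorch

    def forward(self, x):
        if self.op is not None:
            return self.op(x, self.w1.weight, self.w1.bias,
                           self.w2.weight, self.w2.bias,
                           self.w3.weight, self.w3.bias)
        return self.w3(F.silu(self.w1(x)) * self.w2(x))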
Codecov Report: Base: 90.59% // Head: 90.59%

Additional details and impacted files:

@@            Coverage Diff             @@
##  gh/danthe3rd/50/base   #487  +/- ##
=====================================================
  Coverage       90.59%   90.59%
=====================================================
  Files              78       79     +1
  Lines            4572     4636    +64
=====================================================
+ Hits             4142     4200    +58
- Misses            430      436     +6

Flags with carried forward coverage won't be shown. View full report at Codecov.