
ops.swiglu: Tests, baseline, benchmarks #487

Merged
merged 8 commits into gh/danthe3rd/50/base from gh/danthe3rd/50/head on Nov 10, 2022

Conversation

@danthe3rd danthe3rd mentioned this pull request Oct 21, 2022
danthe3rd pushed a commit that referenced this pull request Oct 21, 2022
ghstack-source-id: a105e64c1cadcf383ccaaef2401042b588e9ec73
Pull Request resolved: #487
@facebook-github-bot facebook-github-bot added the CLA Signed label Oct 21, 2022

@fmassa fmassa left a comment


Thanks for the PR, Daniel!

I have a few comments. In particular, I wouldn't add this nn.Module or the torch.autograd.Function in xformers/ops; I would instead just implement a vanilla function in the tests / benchmarks, like the following:

def swiglu(...):
    x1 = F.linear(...)
    x2 = F.linear(...)
    x3 = F.silu(x1)
    x4 = x3 * x2  # SwiGLU gating: silu(x1) * x2
    return F.linear(...)

And for the tests, I would have a wrapper nn.Module that just calls into this swiglu.
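
A minimal sketch of what that could look like, filling in the elided arguments with one plausible signature; the names swiglu and SwiGLUWrapper and the constructor arguments are illustrative, not the ones used in this PR:

from torch import nn
import torch.nn.functional as F


def swiglu(x, w1, b1, w2, b2, w3, b3):
    # Functional reference: two input projections, SiLU gating, one output projection.
    x1 = F.linear(x, w1, b1)
    x2 = F.linear(x, w2, b2)
    hidden = F.silu(x1) * x2
    return F.linear(hidden, w3, b3)


class SwiGLUWrapper(nn.Module):
    # Thin nn.Module wrapper used only by the tests; it just calls into swiglu.
    def __init__(self, in_features, hidden_features, out_features):
        super().__init__()
        self.w1 = nn.Linear(in_features, hidden_features)
        self.w2 = nn.Linear(in_features, hidden_features)
        self.w3 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        return swiglu(
            x,
            self.w1.weight, self.w1.bias,
            self.w2.weight, self.w2.bias,
            self.w3.weight, self.w3.bias,
        )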

Also, I wouldn't write the backward by hand for the reference implementation, but would instead use the autograd implementation as the reference (see the sketch below). This way we are sure we didn't introduce a bug in the reference backward.

On second look, you might have implemented the backward by hand in PyTorch to get an idea of the numerics that are expected to differ for this operation. Is that right?

Thoughts?
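
One way to set up that comparison, as a sketch only: reference_swiglu, check_grads_against_autograd, and the tolerances below are assumptions, and the custom implementation is passed in as a callable with the same signature.

import torch
import torch.nn.functional as F


def reference_swiglu(x, w1, b1, w2, b2, w3, b3):
    x1 = F.linear(x, w1, b1)
    x2 = F.linear(x, w2, b2)
    return F.linear(F.silu(x1) * x2, w3, b3)


def check_grads_against_autograd(custom_swiglu, *tensors, atol=1e-4, rtol=1e-4):
    # Autograd through the reference forward provides the reference gradients,
    # so no hand-written backward has to be trusted as ground truth.
    ref_inputs = [t.detach().clone().requires_grad_() for t in tensors]
    cus_inputs = [t.detach().clone().requires_grad_() for t in tensors]
    reference_swiglu(*ref_inputs).sum().backward()
    custom_swiglu(*cus_inputs).sum().backward()
    for ref, cus in zip(ref_inputs, cus_inputs):
        torch.testing.assert_close(cus.grad, ref.grad, atol=atol, rtol=rtol)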

return (dy.float() * sigm * (1 + x.float() * (1 - sigm))).to(x.dtype)


class SwiGLU_Decomposed(torch.autograd.Function):

nit: snake_case and CamelCase mix?

    # 952us
    @classmethod
    def forward(cls, ctx, x, w1, b1, w2, b2, w3, b3):
        x1 = x @ w1.transpose(-2, -1) + b1  # 275us

Note that you can also call torch.nn.functional.linear here, so that you measure the performance of the code path PyTorch users actually hit; that makes the time measurements comparable to what users would get.
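
For instance (a sketch with placeholder shapes; nothing here is taken from the PR), the explicit matmul above and F.linear compute the same result, but F.linear goes through the dispatch path users actually run:

import torch
import torch.nn.functional as F

x = torch.randn(8, 64)
w1 = torch.randn(256, 64)
b1 = torch.randn(256)

# Numerically equivalent; timing F.linear reflects what users actually run.
out_manual = x @ w1.transpose(-2, -1) + b1
out_linear = F.linear(x, w1, b1)
# Allow for tiny differences between mm+add and the fused addmm kernel.
torch.testing.assert_close(out_manual, out_linear, rtol=1e-5, atol=1e-5)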

from torch import nn


class SwiGLUFFN_Reference(nn.Module):

I wonder why we provide this implementation at all. Can't we just test the numerics with a functional API?

danthe3rd pushed a commit that referenced this pull request Oct 21, 2022
ghstack-source-id: bd9a7c784e8acc97cc87e84f7145bf39b7890a10
Pull Request resolved: #487

@fmassa fmassa left a comment


Thanks!

Good to merge; I have a few suggestions for the future.

return (dx, dw1, db1, dw2, db2, dw3, db3)


def functional_swiglu(

Suggested change
def functional_swiglu(
def _swiglu_reference(

?

Comment on lines +38 to +41
x1 = self.w1(x)
x2 = self.w2(x)
hidden = F.silu(x1) * x2
return self.w3(hidden)

Maybe reuse the function _swiglu_reference in here?

Suggested change
x1 = self.w1(x)
x2 = self.w2(x)
hidden = F.silu(x1) * x2
return self.w3(hidden)
return _swiglu_reference(x, self.w1.weight, self.w1.bias, self.w2.weight, self.w2.bias, self.w3.weight, self.w3.bias)

from torch import nn


class _SwiGLUModule(nn.Module):

If we want to expose this to the user through other ops, we can make this public as well, and expose an op argument.

Suggested change
class _SwiGLUModule(nn.Module):
class _SwiGLU(nn.Module):
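
A sketch of what exposing an op argument might look like; SwiGLU, the op parameter, and the exact _swiglu_reference signature are assumptions for illustration, not the PR's actual API.

from torch import nn
import torch.nn.functional as F


def _swiglu_reference(x, w1, b1, w2, b2, w3, b3):
    x1 = F.linear(x, w1, b1)
    x2 = F.linear(x, w2, b2)
    return F.linear(F.silu(x1) * x2, w3, b3)


class SwiGLU(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, op=None):
        super().__init__()
        self.w1 = nn.Linear(in_features, hidden_features)
        self.w2 = nn.Linear(in_features, hidden_features)
        self.w3 = nn.Linear(hidden_features, out_features)
        # op selects the underlying implementation (e.g. a fused kernel);
        # fall back to the reference implementation when none is given.
        self.op = op if op is not None else _swiglu_reference

    def forward(self, x):
        return self.op(
            x,
            self.w1.weight, self.w1.bias,
            self.w2.weight, self.w2.bias,
            self.w3.weight, self.w3.bias,
        )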

danthe3rd pushed a commit that referenced this pull request Oct 21, 2022
ghstack-source-id: 0c60acdac384e28a6318f3dc91d37ccc02f60251
Pull Request resolved: #487
@danthe3rd danthe3rd mentioned this pull request Oct 21, 2022
@codecov-commenter

codecov-commenter commented Oct 21, 2022

Codecov Report

Base: 90.59% // Head: 90.59% // Increases project coverage by +0.00% 🎉

Coverage data is based on head (0f0a602) compared to base (261b164).
Patch coverage: 90.62% of modified lines in pull request are covered.

Additional details and impacted files
@@                  Coverage Diff                  @@
##           gh/danthe3rd/50/base     #487   +/-   ##
=====================================================
  Coverage                 90.59%   90.59%           
=====================================================
  Files                        78       79    +1     
  Lines                      4572     4636   +64     
=====================================================
+ Hits                       4142     4200   +58     
- Misses                      430      436    +6     
Flag Coverage Δ
Python 90.59% <90.62%> (+<0.01%) ⬆️


Impacted Files Coverage Δ
xformers/ops/swiglu.py 90.47% <90.47%> (ø)
xformers/ops/__init__.py 81.25% <100.00%> (+1.25%) ⬆️


This was referenced Oct 24, 2022
@danthe3rd danthe3rd merged commit ca2c6a7 into gh/danthe3rd/50/base Nov 10, 2022
danthe3rd pushed a commit that referenced this pull request Nov 10, 2022
ghstack-source-id: b751901948cdf1796cc42a5640ef1b117b4ae9f7
Pull Request resolved: #487
@danthe3rd danthe3rd deleted the gh/danthe3rd/50/head branch November 10, 2022 18:11