Metal FFT for powers of 2 up to 2048 #915
Conversation
python/tests/test_fft.py
Outdated
atol = 1e-4
rtol = 1e-4
np.random.seed(7)
with mx.stream(mx.gpu):
I would just run these tests on both devices rather than skipping them for the CPU and specifying the context manager. The goal is to remove the context manager for the CPU above once all the ops can run on the GPU.
(usually we don't specify the device in the tests but just run the full test suite twice with the CPU and GPU as default).
Ah that makes sense -- updated.
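For anyone following along, a minimal sketch of exercising both devices in a single test (illustrative only, not the PR's actual test code; in practice the suite is simply run twice with different default devices, as noted above):

```python
import unittest

import mlx.core as mx
import numpy as np


class TestFFT(unittest.TestCase):
    def test_fft_matches_numpy(self):
        # Illustrative: check both devices against numpy with the same
        # tolerances used in the PR's tests.
        np.random.seed(7)
        x = np.random.rand(8, 256).astype(np.complex64)
        expected = np.fft.fft(x, axis=-1)
        for device in (mx.cpu, mx.gpu):
            mx.set_default_device(device)
            out = mx.fft.fft(mx.array(x), axis=-1)
            np.testing.assert_allclose(
                np.array(out), expected, atol=1e-4, rtol=1e-4
            )
```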
size_t n = in.shape(axes_[0]);

if (!is_power_of_2(n) || n > 2048 || n < 4) { |
Just curious, how difficult would it be to support non powers of 2? (We could easily pad with zeros.. but maybe there is a clean way to do it in the kernel itself?).
It's not completely trivial since there are a couple of different algorithms that libraries typically use, depending on the prime factors of N.
VkFFT seems to do this:
- Pure radix decomposition (as currently implemented) if N factorizes into primes <= 13. This would require adding custom DFTs for 3, 5, 7, 11 and 13.
- Rader's algorithm for everything else except Sophie Germain primes
- Bluestein's algorithm for Sophie Germain primes
For N > 2048, we'll probably want to use the 4-step FFT algorithm.
I was thinking I'd first have a go at implementing the 3/5/7/11/13-radix kernels. Then we'd have about 17% of 1 < N <= 2048 covered.
I'm happy to work on Rader's/Bluestein's/4 step after too.
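To make that breakdown concrete, here's a rough sketch (not project code; the thresholds and dispatch names are assumptions) of choosing an algorithm from the prime factorization of N:

```python
def prime_factors(n):
    factors, p = [], 2
    while p * p <= n:
        while n % p == 0:
            factors.append(p)
            n //= p
        p += 1
    if n > 1:
        factors.append(n)
    return factors


def is_prime(n):
    return n > 1 and all(n % d for d in range(2, int(n ** 0.5) + 1))


def choose_fft_algorithm(n, max_single_kernel=2048):
    # Hypothetical dispatcher mirroring the breakdown above.
    if n > max_single_kernel:
        return "four-step"        # split N into smaller FFTs
    if all(f <= 13 for f in prime_factors(n)):
        return "mixed-radix"      # radix-2..13 kernels
    if is_prime(n) and is_prime(2 * n + 1):
        return "bluestein"        # Sophie Germain prime case
    return "rader"                # other large prime factors
```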
mlx/backend/metal/fft.cpp
Outdated
// FFT dim has stride 1 so there are no gaps
flags.contiguous = true;
That doesn't seem quite right. Even if the FFT dim has stride 1, you could have gaps due to another dimension having a larger stride?
You're definitely right, there's a bug in the no_copy case, which should now be fixed. Thanks for catching it!
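To illustrate the reviewer's point with a numpy analogy (MLX tracks strides similarly, but the details differ): the FFT axis can have unit stride while the array as a whole still has gaps.

```python
import numpy as np

x = np.zeros((8, 16), dtype=np.complex64)
view = x[::2]  # every other row

# The last (FFT) axis still has unit element stride...
print(view.strides[-1] // view.itemsize)  # 1
# ...but consecutive rows of the view are 32 elements apart in memory,
# so the data the view covers is not gap-free.
print(view.strides[0] // view.itemsize)   # 32
print(view.flags["C_CONTIGUOUS"])         # False
```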
return x_copy;
}
};
const array& in_contiguous = check_input(inputs[0]);
This array, is it actually contiguous or just contiguous in the FFT dim? If it's not truly row_contiguous, then we presumably need to specify the strides to the kernel?
Good catch! I've changed it to do the GeneralGeneral copy when the input array isn't contiguous and added a test for this case. Is that alright for now?
I started working on passing the strides directly to the kernel to avoid the extra copy, but I might save that for a future PR.
@barronalex FWIW, this is indeed the place where follow-up operations on the FFT result (like .abs()) cause a fatal error. Replacing in_contiguous.flags() with empty flags is exactly what makes my experimental code work. BTW, big thank you for your work on the Metal FFT.
Thanks for catching this! Will push a fix shortly.
@barronalex would you mind adding a test or two with non-contiguous arrays (e.g. output of transpose or broadcast)? The logic here is a bit subtle so some tests would be good.
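Something along these lines (a hypothetical sketch, not the tests that were actually added) would exercise the non-contiguous paths:

```python
import mlx.core as mx
import numpy as np


def test_fft_non_contiguous_inputs():
    np.random.seed(7)
    x = np.random.rand(4, 256).astype(np.complex64)

    # Transposed input: the FFT axis is no longer the stride-1 axis.
    out = mx.fft.fft(mx.array(x).T, axis=-1)
    np.testing.assert_allclose(
        np.array(out), np.fft.fft(x.T, axis=-1), atol=1e-4, rtol=1e-4
    )

    # Broadcast input: the batch dimension is expanded from size 1.
    y = np.random.rand(1, 256).astype(np.complex64)
    out = mx.fft.fft(mx.broadcast_to(mx.array(y), (4, 256)), axis=-1)
    np.testing.assert_allclose(
        np.array(out),
        np.fft.fft(np.broadcast_to(y, (4, 256)), axis=-1),
        atol=1e-4,
        rtol=1e-4,
    )
```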
Force-pushed from ccfa3b4 to 98a9197.
If this goes in, I have a follow-up working that implements …
I wonder how pocketfft compares to Accelerate.framework's vDSP_fft in terms of performance.
mlx/backend/metal/fft.cpp
Outdated
auto check_input = [this, &copies, &s](const array& x) {
  // TODO: Pass the strides to the kernel so
  // we can avoid the copy when x is not contiguous.
  bool no_copy = x.strides()[axes_[0]] == 1 && x.flags().contiguous;
I think you want x.flags().row_contiguous || x.flags().col_contiguous here instead of x.flags().contiguous (which has a different meaning).
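For intuition, the analogous distinction in numpy (MLX's flags aren't identical, so treat this purely as an illustration): an array can occupy one gap-free buffer yet be neither row- nor column-contiguous, and the FFT axis can still look fine in isolation.

```python
import numpy as np

x = np.zeros((2, 3, 4), dtype=np.complex64)
y = x.transpose(1, 0, 2)  # permute the leading axes

# Every element of x appears exactly once, so the underlying buffer has no
# gaps, but the layout is neither row-major nor column-major:
print(y.flags["C_CONTIGUOUS"], y.flags["F_CONTIGUOUS"])  # False False
# The last axis is still unit-stride, so a check on the FFT dim alone
# would accept a layout the kernel can't index as row-contiguous data.
print(y.strides[-1] // y.itemsize)  # 1
```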
This looks great and I think we can merge it soon! Just left a couple more comments. Let me know when you've addressed them and I'll re-run the tests.
@awni Thanks for the comments! I added some contiguity tests and changed some of the logic. Let me know if they look reasonable.
This is awesome. Let's get it landed and focus on additions in #981
#include "mlx/primitives.h"

namespace mlx::core {

void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {
  auto& s = out.primitive().stream();
auto& s = stream() works here ;)
I'm cool to land this and continue the discussion in #981
Sounds good to me!
Proposed changes
#399
Add a GPU FFT algorithm for the powers of 2 from 4 to 2048.
It only supports 1D, forward, complex-to-complex transforms at the moment, but I plan to follow up with more features shortly.
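A minimal usage sketch of what this enables (assuming the standard Python API; the shapes and dtype here are just an example):

```python
import mlx.core as mx

x = mx.random.normal((8, 1024)).astype(mx.complex64)  # batch of 1D signals
y = mx.fft.fft(x, axis=-1, stream=mx.gpu)             # runs on the Metal backend
mx.eval(y)
print(y.shape, y.dtype)  # (8, 1024) complex64
```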
Performance
On my M1 Max:
For 64 <= N <= 512, we're doing ~360 GB/s with a large batch size, which isn't far off the maximum memory bandwidth of an M1 Max (~400 GB/s). The other sizes are slightly slower but can be addressed in a follow-up PR.
The GPU implementation is 20 to 35 times faster than the CPU implementation for the FFT sizes it implements.
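The numbers above are the author's M1 Max measurements. For reference, a rough sketch of how such a bandwidth figure can be estimated (the batch size, iteration count, and bytes-moved model are assumptions):

```python
import time

import mlx.core as mx


def fft_bandwidth_gbs(n, batch=4096, iters=100):
    x = mx.random.normal((batch, n)).astype(mx.complex64)
    mx.eval(x)
    mx.eval(mx.fft.fft(x))  # warm-up so kernel setup isn't timed
    tic = time.perf_counter()
    for _ in range(iters):
        y = mx.fft.fft(x)
    mx.eval(y)
    elapsed = time.perf_counter() - tic
    # Assume each FFT reads the input and writes the output once.
    bytes_moved = 2 * x.nbytes * iters
    return bytes_moved / elapsed / 1e9


for n in (64, 256, 512, 2048):
    print(n, round(fft_bandwidth_gbs(n), 1), "GB/s")
```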
Checklist
Put an x in the boxes that apply.
- I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes