
Metal FFT for powers of 2 up to 2048 #915

Merged · 7 commits · Apr 12, 2024

Conversation

barronalex (Collaborator)

Proposed changes

#399

Add a GPU FFT algorithm for powers of 2 from 4 to 2048.

Only supports 1D, forward, complex-to-complex transforms at the moment, but I'm planning to follow up with more features shortly.

Performance

On my M1 Max:
[Figure: fft_plot — FFT throughput vs. size benchmark]

For 64 <= N <= 512, we're doing ~360 GB/s with a large batch size, which isn't far off the maximum memory bandwidth of an M1 Max (~400 GB/s). The other sizes are slightly slower but can be addressed in a follow-up PR.
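As a sanity check on figures like these, effective bandwidth for a batched complex64 FFT can be estimated from bytes moved per transform. This is my own sketch with hypothetical numbers, not the PR's benchmark code:

```python
# Hypothetical example: estimating effective bandwidth of a batched
# complex64 FFT. Each N-point transform reads and writes 8 bytes per
# element, so a batch moves roughly 2 * 8 * N * batch bytes in total.
def effective_bandwidth_gb_s(n, batch, seconds):
    bytes_moved = 2 * 8 * n * batch  # read input + write output, complex64
    return bytes_moved / seconds / 1e9

# e.g. 100k size-256 FFTs finishing in ~1.14 ms works out to ~360 GB/s
bw = effective_bandwidth_gb_s(256, 100_000, 1.14e-3)
```

The batch size and timing here are made up for illustration; the point is just that the reported GB/s is (bytes in + bytes out) / time.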

The GPU implementation is 20 to 35 times faster than the CPU implementation for the FFT sizes it implements.

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

atol = 1e-4
rtol = 1e-4
np.random.seed(7)
with mx.stream(mx.gpu):
Member

I would just run these tests on both devices rather than skipping them for the CPU and specifying the context manager. The goal is to remove the context manager for the CPU above once all the ops can run on the GPU.

Member

(usually we don't specify the device in the tests but just run the full test suite twice with the CPU and GPU as default).

Collaborator Author

Ah that makes sense -- updated.

mlx/backend/metal/fft.cpp (outdated; resolved)

size_t n = in.shape(axes_[0]);

if (!is_power_of_2(n) || n > 2048 || n < 4) {
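The `is_power_of_2(n)` helper isn't shown in this excerpt, but such guards are conventionally the single-set-bit trick; a Python sketch of the check and the resulting supported sizes (my reconstruction, not the PR's C++):

```python
def is_power_of_2(n: int) -> bool:
    # A power of two has exactly one bit set, so n & (n - 1) clears it to 0.
    return n > 0 and (n & (n - 1)) == 0

# The Metal path above accepts only power-of-two sizes with 4 <= n <= 2048.
supported = [n for n in range(1, 2049) if is_power_of_2(n) and n >= 4]
```

That yields the ten sizes 4, 8, ..., 2048; everything else falls back to the CPU path.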
Member

Just curious, how difficult would it be to support non powers of 2? (We could easily pad with zeros.. but maybe there is a clean way to do it in the kernel itself?).

Collaborator Author

It's not completely trivial, since there are a couple of different algorithms that libraries typically use depending on the prime factors of N.

VkFFT seems to do this:

For N > 2048, we'll probably want to use the 4-step FFT algorithm.
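The four-step algorithm mentioned here factors N = N1·N2 and computes the length-N DFT with two passes of smaller FFTs plus a twiddle multiply and a transpose. A minimal numpy sketch of the decomposition (not the Metal implementation):

```python
import numpy as np

def four_step_fft(x, n1, n2):
    """Length n1*n2 DFT via smaller FFTs (Cooley-Tukey four-step)."""
    a = x.reshape(n2, n1)                      # a[j2, j1] = x[n1*j2 + j1]
    a = np.fft.fft(a, axis=0)                  # step 1: n1 FFTs of size n2
    k2 = np.arange(n2)[:, None]
    j1 = np.arange(n1)[None, :]
    a = a * np.exp(-2j * np.pi * k2 * j1 / (n1 * n2))  # step 2: twiddles
    a = a.T                                    # step 3: transpose -> [j1, k2]
    a = np.fft.fft(a, axis=0)                  # step 4: n2 FFTs of size n1
    return a.reshape(-1)                       # X[n2*k1 + k2]
```

Since each pass only needs FFTs of size N1 or N2, kernels that top out at 2048 could in principle cover sizes up to 2048², which is presumably the appeal for N > 2048.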

I was thinking I'd first have a go at implementing the 3/5/7/11/13-radix kernels. Then we'd have about 17% of 1 < N <= 2048 covered.

I'm happy to work on Rader's/Bluestein's/4 step after too.
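The ~17% coverage figure can be sanity-checked: a size N is covered by radix kernels exactly when all of its prime factors lie in the implemented radix set. A quick count (my own sketch, not from the PR):

```python
def only_has_factors(n, primes):
    # Divide out the allowed radices; n is covered iff nothing remains.
    for p in primes:
        while n % p == 0:
            n //= p
    return n == 1

radices = (2, 3, 5, 7, 11, 13)
covered = [n for n in range(2, 2049) if only_has_factors(n, radices)]
frac = len(covered) / 2047  # fraction of 1 < N <= 2048 covered
```

Enumerating the 13-smooth sizes this way lands right around the quoted 17%.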

Comment on lines 62 to 61
// FFT dim has stride 1 so there are no gaps
flags.contiguous = true;
Member

That doesn't seem quite right. Even if the FFT dim has stride 1, you could have gaps due to another dimension having a larger stride?

Collaborator Author

You're definitely right there's a bug in the no_copy case, which should now be fixed. Thanks for catching it!
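The bug class being discussed is easy to demonstrate with numpy's analogous layout machinery (used here purely as an illustration, not mlx's flags): a view can have unit stride along its innermost axis while the buffer still has gaps between rows.

```python
import numpy as np

# A sliced view: the last axis keeps unit stride, but consecutive rows
# are separated by the parent array's row stride, so the buffer has gaps.
parent = np.arange(32, dtype=np.float32).reshape(4, 8)
view = parent[:, :4]

unit_stride_last = view.strides[-1] == view.itemsize   # True
overall_contig = view.flags['C_CONTIGUOUS']            # False
```

So "the FFT dim has stride 1" alone does not justify setting a contiguous flag on the output.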

return x_copy;
}
};
const array& in_contiguous = check_input(inputs[0]);
Member

This array, is it actually contiguous or just contiguous in the FFT dim? If it's not truly row_contiguous, then we presumably need to specify the strides to the kernel?

Collaborator Author

Good catch! I've changed it to do the GeneralGeneral copy when the input array isn't contiguous and added a test for this case. Is that alright for now?

I started working on passing the strides directly to the kernel to avoid the extra copy, but I might save that for a future PR.

@djphoenix Apr 4, 2024

@barronalex FWIW, this is indeed the place where follow-up operations on the FFT result (like .abs()) cause a fatal error. Replacing in_contiguous.flags() with empty flags is exactly what makes my experimental code work. BTW, big thank you for your work on Metal FFT.

Collaborator Author

Thanks for catching this! Will push a fix shortly.

Member

@barronalex would you mind adding a test or two with non-contiguous arrays (e.g. output of transpose or broadcast). The logic here is a bit subtle so some tests would be good.

@barronalex (Collaborator Author)

If this goes in, I have a follow-up working that implements ifft, rfft, and irfft with similar performance characteristics.

@yury commented Apr 5, 2024

I wonder how pocketfft compares to Accelerate.framework's vDSP_fft in terms of performance. Or does the Accelerate framework use pocketfft internally?

auto check_input = [this, &copies, &s](const array& x) {
// TODO: Pass the strides to the kernel so
// we can avoid the copy when x is not contiguous.
bool no_copy = x.strides()[axes_[0]] == 1 && x.flags().contiguous;
Member

I think you want x.flags().row_contiguous || x.flags().col_contiguous here instead of x.flags().contiguous (which has a different meaning).
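The row/column-contiguous distinction has a direct analogue in numpy's flags (shown here only as an illustration of the concept, not of mlx's flag semantics): a transposed view of a row-major array is column-contiguous but not row-contiguous.

```python
import numpy as np

a = np.ones((4, 8), dtype=np.float32)    # freshly allocated: row-major
at = a.T                                 # transposed view: column-major

row_contig = a.flags['C_CONTIGUOUS']     # True
col_contig_t = at.flags['F_CONTIGUOUS']  # True
row_contig_t = at.flags['C_CONTIGUOUS']  # False
```

Both layouts are gap-free, which is why a kernel that assumes a dense buffer can accept either, while a weaker "contiguous" flag may admit arrays it can't handle.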

@awni (Member) commented Apr 7, 2024

This looks great and I think we can merge it soon! Just left a couple more comments. Let me know when you've addressed and will re-run tests.

@barronalex (Collaborator Author)

@awni Thanks for the comments! I added some contiguity tests and changed some of the logic. Let me know if they look reasonable.

mlx/backend/metal/fft.cpp (outdated; resolved)
mlx/backend/metal/fft.cpp (outdated; resolved)
@awni (Member) left a comment

This is awesome. Let's get it landed and focus on additions in #981

#include "mlx/primitives.h"

namespace mlx::core {

void FFT::eval_gpu(const std::vector<array>& inputs, array& out) {
auto& s = out.primitive().stream();
@awni (Member) Apr 11, 2024

auto& s = stream() works here ;)

@awni (Member) left a comment

I'm cool to land this and continue the discussion in #981

@barronalex (Collaborator Author)

Sounds good to me!

@awni awni merged commit 2e7c02d into ml-explore:main Apr 12, 2024
5 checks passed
@barronalex barronalex deleted the ab-metal-fft branch May 10, 2024 20:46