
Fast Metal FFT for all N #981

Closed · wants to merge 16 commits

Conversation

@barronalex (Collaborator) commented Apr 11, 2024

Proposed changes

A feature-complete Metal FFT that's faster than both CPU and PyTorch MPS in the majority of 1D cases.

Fully functional, but still needs some cleanup.

Resolves #399.

Supports

  • All values of N (tested up to 2^20)
  • Real and Inverse FFTs: fft, ifft, rfft, irfft
  • ND FFTs: fft2, ifft2, rfft2, irfft2, fftn, ifftn, rfftn, irfftn
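The transforms listed above follow the usual `numpy.fft` naming and conventions. As a rough usage sketch, here is the equivalent NumPy code (shown as a stand-in for the `mlx.core.fft` ops, which use the same function names):

```python
import numpy as np

# NumPy stand-in for the mlx.core.fft API surface listed above.
x = np.random.default_rng(0).standard_normal(1024)

X = np.fft.fft(x)                # complex FFT, any N
x_back = np.fft.ifft(X)          # inverse FFT recovers the input
assert np.allclose(x, x_back.real)

Xr = np.fft.rfft(x)              # real FFT: only N//2 + 1 unique bins
assert Xr.shape == (1024 // 2 + 1,)
x_r = np.fft.irfft(Xr, n=1024)   # inverse real FFT
assert np.allclose(x, x_r)
```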

Performance

For N < 1024, 1D FFTs on my M1 Max:

  • Faster than PyTorch MPS for ~90% of FFT sizes
  • ~1.5X higher average throughput than PyTorch MPS
  • ~13X higher average throughput than CPU

[Plot: throughput across all FFT sizes]

We're only behind MPS on some multiples of 7 and all multiples of 11 and 13:

[Plot: radix 2-13 sizes]

Our Bluestein's implementation is significantly more efficient for N < 1024:

[Plot: Bluestein's FFT sizes]

Note: For the sake of time, I ran at a slightly lower batch size than is required to max out the bandwidth for the powers of 2. I'll run a full one shortly, but in my experiments so far the relative speeds seem to hold.

Implementation Details

For N <= 2048 whose prime factors are all <= 7:

  • Stockham's algorithm in threadgroup memory with a mixed-radix, out-of-place FFT
  • Codelets for radix-2, 3, 4, 5, and 7
  • Kernels are specialized at runtime for each N with Metal function constants
  • Threadgroup batching for small N to improve performance
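For reference, the Stockham structure avoids the bit-reversal permutation of Cooley-Tukey by ping-ponging between two buffers, which maps well to out-of-place threadgroup memory. A minimal radix-2 NumPy sketch of that structure (the actual kernel is mixed-radix 2/3/4/5/7 and specialized per N):

```python
import numpy as np

def stockham_fft(x):
    """Radix-2 Stockham autosort FFT: out-of-place, no bit reversal.

    Illustrative sketch only; assumes len(x) is a power of two.
    """
    n = len(x)
    a = np.asarray(x, dtype=np.complex128).copy()
    b = np.empty_like(a)
    ns = 1                                  # size of already-transformed blocks
    while ns < n:
        for j in range(n // 2):
            base = (j // ns) * (ns * 2)     # output block for this butterfly
            offs = j % ns
            w = np.exp(-2j * np.pi * offs / (ns * 2))
            a0, a1 = a[j], a[j + n // 2]
            b[base + offs] = a0 + w * a1
            b[base + offs + ns] = a0 - w * a1
        a, b = b, a                         # ping-pong the two buffers
        ns *= 2
    return a
```

Because each stage writes its outputs already in order, no separate reordering pass is needed, which is what makes the out-of-place formulation attractive on GPUs.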

For all other N <= 1024:

  • A fused Bluestein's algorithm implementation
  • Bluestein twiddles are computed on CPU in float64 to maintain acceptable precision for the overall algorithm.
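Bluestein's algorithm re-expresses a length-n DFT of any n as a circular convolution that can be evaluated with power-of-two FFTs. An unfused NumPy sketch (the kernel fuses these steps; the `k*k % 2n` reduction keeps the chirp angle small, in the spirit of the float64 precision note above):

```python
import numpy as np

def bluestein_fft(x):
    """Length-n DFT for arbitrary n via Bluestein's chirp-z trick (sketch)."""
    n = len(x)
    k = np.arange(n)
    # Chirp w[k] = exp(-i*pi*k^2/n); reduce k^2 mod 2n exactly in integers.
    w = np.exp(-1j * np.pi * (k * k % (2 * n)) / n)
    m = 1 << (2 * n - 1).bit_length()       # power-of-two FFT size >= 2n-1
    a = np.zeros(m, dtype=np.complex128)
    a[:n] = np.asarray(x) * w
    b = np.zeros(m, dtype=np.complex128)    # symmetric conjugate-chirp kernel
    b[:n] = np.conj(w)
    b[m - n + 1:] = np.conj(w[1:])[::-1]
    conv = np.fft.ifft(np.fft.fft(a) * np.fft.fft(b))
    return conv[:n] * w
```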

For N > 1024:

  • The four step FFT algorithm
  • If N has prime factors > 1024, we use a manual version of Bluestein's implemented with MLX ops
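The four-step decomposition splits a length N = n1*n2 FFT into two batches of smaller FFTs joined by a twiddle multiply and a transpose. A NumPy sketch of the structure (the kernel implementation batches these steps on GPU):

```python
import numpy as np

def four_step_fft(x, n1, n2):
    """Four-step FFT of length N = n1*n2 (sketch)."""
    assert len(x) == n1 * n2
    a = np.asarray(x, dtype=np.complex128).reshape(n1, n2)
    a = np.fft.fft(a, axis=0)               # step 1: n2 FFTs of length n1
    k1 = np.arange(n1)[:, None]
    j2 = np.arange(n2)[None, :]
    a *= np.exp(-2j * np.pi * k1 * j2 / (n1 * n2))  # step 2: twiddle factors
    a = np.fft.fft(a, axis=1)               # step 3: n1 FFTs of length n2
    return a.T.reshape(-1)                  # step 4: transpose to output order
```

The transpose in step 4 (and the twiddle multiply in step 2) are exactly the extra passes the "Areas for Improvement" section below proposes fusing.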

RFFT:

  • We implement a custom kernel for real FFTs that uses a trick to perform two at a time, doubling the bandwidth
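The two-at-a-time trick packs one real signal into the real part and another into the imaginary part of a single complex FFT, then unscrambles the two spectra using the conjugate symmetry of real-input FFTs. A NumPy sketch:

```python
import numpy as np

def rfft_two_at_a_time(x, y):
    """FFTs of two real signals from one complex FFT (sketch)."""
    n = len(x)
    z = np.fft.fft(np.asarray(x, dtype=np.float64)
                   + 1j * np.asarray(y, dtype=np.float64))
    zr = np.conj(z[(-np.arange(n)) % n])    # conj(Z[-k mod n])
    X = (z + zr) / 2                        # even part -> spectrum of x
    Y = (z - zr) / (2j)                     # odd part  -> spectrum of y
    return X, Y
```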

Areas for Improvement

Codelet optimizations and additions

The radix codelets are currently extremely naive O(N^2) DFTs and could be replaced with hand-tuned or compiled ones that perform fewer operations. We should also add radix-11 and radix-13 codelets to match MPS and VkFFT.
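For context, a "naive" radix-r codelet is just the r x r DFT matrix applied directly, costing O(r^2) multiplies; hand-tuned codelets (e.g. the factored forms FFTW's generator emits) need far fewer. A NumPy sketch of the naive form:

```python
import numpy as np

def naive_codelet(x, r):
    """Naive radix-r DFT codelet: full r x r matrix-vector product, O(r^2)."""
    j = np.arange(r)
    W = np.exp(-2j * np.pi * np.outer(j, j) / r)   # r x r DFT matrix
    return W @ np.asarray(x, dtype=np.complex128)
```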

Performance on ND and four step FFT cases

These have quite a few unnecessary copies currently. A fused implementation incorporating the transpose and twiddle factors would bring us closer to the max bandwidth.

Accuracy

Accuracy is comparable to the MPS implementation but about an order of magnitude behind pocketfft. More careful twiddle factor computation, inspired by pocketfft, could help here. Precision also suffers a bit at very large N; computing the twiddle factors in float64, as we do for Bluestein's, would help.

IRFFT

irfft on GPU currently only works for outputs of rfft (there are a couple of exceptions in the tests to account for this).

Convolution theorem

The fused Bluestein's implementation contains a convolution implemented with FFTs via the convolution theorem. For larger kernel sizes we might want to adapt this and add it to the main convolution implementation as suggested in #811.
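The convolution theorem step referred to above is the standard identity conv(a, b) = ifft(fft(a) * fft(b)); for a linear (non-circular) convolution the inputs are zero-padded so the circular wrap-around does not alias. A NumPy sketch, not the fused kernel:

```python
import numpy as np

def fft_linear_conv(a, b):
    """Linear convolution via the convolution theorem (sketch).

    Zero-pads to at least len(a) + len(b) - 1 to avoid circular aliasing.
    """
    n = len(a) + len(b) - 1
    m = 1 << (n - 1).bit_length()           # round up to a power-of-two FFT size
    A = np.fft.fft(a, m)
    B = np.fft.fft(b, m)
    return np.real(np.fft.ifft(A * B))[:n]
```

For large kernels this costs O(N log N) instead of the O(N*K) of direct convolution, which is why it could benefit the main convolution implementation.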

Checklist

Put an x in the boxes that apply.

  • I have read the CONTRIBUTING document
  • I have run pre-commit run --all-files to format my code / installed pre-commit prior to committing changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have updated the necessary documentation (if needed)

@awni mentioned this pull request Apr 11, 2024

@awni (Member) commented Apr 12, 2024

Do you mind rebasing this, @barronalex?

Successfully merging this pull request may close these issues.

[Feature] Add Metal support for FFT