
[Feature] Add Metal support for FFT #399

Open
awni opened this issue Jan 8, 2024 · 9 comments
Labels
enhancement New feature or request

Comments

@awni
Member

awni commented Jan 8, 2024

Add support for Metal backend with FFT primitive as mentioned here ml-explore/mlx-examples#249

@awni awni added the enhancement New feature or request label Jan 8, 2024
@aneeshk1412
Contributor

I would like to take up this issue, @awni. Please let me know if I can start on the implementation.

@awni
Member Author

awni commented Feb 1, 2024

Do you have experience with GPU programming? This is not a trivial one, so if not, I would recommend starting with something simpler.

@aneeshk1412
Contributor

I have some experience from my internship at Amazon HPC and from a GPU programming course. I am not completely familiar with Metal, but I am looking through its documentation, and I'm familiar with the basics of FFT. Would that be enough to start on this?

@awni
Member Author

awni commented Feb 2, 2024

It's hard for me to answer. I would recommend you take a look at parallel implementations of FFT. There is also this code, which does FFT on Metal. Maybe a good place to start is to benchmark that code against a CPU implementation and see how usable it might be?

We can use this thread to discuss your findings.
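For the CPU side of that comparison, a minimal timing harness could look like the sketch below (an assumption on my part, using NumPy as the CPU reference; the Metal FFT would be timed the same way and the two averages compared):

```python
import time
import numpy as np

def bench_cpu_fft(n=2**20, reps=10):
    """Average wall-clock time of a 1-D complex FFT of length n."""
    x = (np.random.rand(n) + 1j * np.random.rand(n)).astype(np.complex64)
    np.fft.fft(x)  # warm-up run, excluded from timing
    start = time.perf_counter()
    for _ in range(reps):
        np.fft.fft(x)
    return (time.perf_counter() - start) / reps

if __name__ == "__main__":
    print(f"CPU FFT, n=2^20: {bench_cpu_fft() * 1e3:.2f} ms/run")
```

The same loop structure would wrap the Metal kernel launch (with a synchronization before reading the timer), so both backends are measured under identical conditions.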

@Rifur13
Contributor

Rifur13 commented Feb 12, 2024

What's the status on this? It's blocking some audio processing work I'm trying to do. It looks like there's an MPS implementation. Any interest in adding this?

A tangential question: why doesn't the codebase leverage MPS more? IIUC it handles broadcasting between matrices correctly and has optimized ops for lower dimensions.

@AndreSlavescu
Contributor

AndreSlavescu commented Feb 16, 2024

Just to chime in: if this operator is implemented, then when I eventually finish implementing conv3d, I can accelerate the operation for larger kernel shapes. FFT-based convolution tends to outperform Winograd convolution for larger kernels, so it would be great to dispatch the appropriate conv kernel depending on shape.

I can also take a stab at this if nobody is working on it anymore @awni.
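The FFT-convolution idea above rests on the convolution theorem: a linear convolution equals an inverse FFT of the product of zero-padded FFTs. A small NumPy sketch (illustrative only, not MLX's implementation) showing the equivalence in 1-D:

```python
import numpy as np

def fft_conv1d(x, k):
    """Linear convolution of x and k via the convolution theorem."""
    # Zero-pad both signals to the full linear-convolution length,
    # multiply pointwise in the frequency domain, then invert.
    n = len(x) + len(k) - 1
    X = np.fft.rfft(x, n)
    K = np.fft.rfft(k, n)
    return np.fft.irfft(X * K, n)

x = np.random.rand(1024)
k = np.random.rand(63)
# Matches direct convolution to floating-point tolerance.
assert np.allclose(fft_conv1d(x, k), np.convolve(x, k))
```

The FFT route costs O(n log n) versus O(n·m) for direct convolution, which is why it wins once the kernel m gets large; for small kernels, direct or Winograd convolution has less overhead.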

@adonath

adonath commented Mar 11, 2024

@AndreSlavescu See also: #811 (comment)

@avinashahuja

Could you reuse the code written for PyTorch? They added MPS support for FFT. pytorch/pytorch#119670

@awni
Member Author

awni commented Mar 14, 2024

It's not a great fit for us as we don't wrap MPSGraph and likely wouldn't make an exception for FFTs.

6 participants