[Feature] Support fft based convolution #811
Comments
I wouldn't mind looking into implementing this, what do you think? |
One challenge here is that FFT is not yet supported on the GPU (in Metal). So you could use it, but on the CPU it would almost certainly be much slower than our GPU convolution. Also, I think FFT-based convolution is more of an implementation detail. If there are some sizes that are slow for you, please share any benchmarks. We can then figure out the best way to make them faster (which may or may not require an FFT-based convolution). |
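For context, the FFT primitives can be pinned to the CPU stream by passing it explicitly; a minimal illustration (the array shape is arbitrary):

```python
import mlx.core as mx

x = mx.random.normal((1024, 1024))
# The FFT runs on the CPU stream here; mx.conv2d can stay on the default GPU stream.
X = mx.fft.rfft2(x, stream=mx.cpu)
```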
Thanks @awni and @sebblanchet! I did a quick implementation of an FFT-based convolution in MLX:

```python
import mlx.core as mx


def _centered(arr, newshape):
    """Crop `arr` to `newshape`, centered on the array."""
    newshape = mx.array(newshape)
    currshape = mx.array(arr.shape)
    startind = (currshape - newshape) // 2
    endind = startind + newshape
    myslice = [slice(startind[k].item(), endind[k].item()) for k in range(len(endind))]
    return arr[tuple(myslice)]


def convolve_fft(image, kernel, stream):
    """FFT-based convolution for MLX arrays."""
    image_2d, kernel_2d = image[0, 0], kernel[0, 0]
    # Pad to the full linear-convolution size to avoid circular wrap-around
    shape = [image_2d.shape[i] + kernel_2d.shape[i] - 1 for i in range(image_2d.ndim)]
    image_ft = mx.fft.rfft2(image, s=shape, stream=stream)
    kernel_ft = mx.fft.rfft2(kernel, s=shape, stream=stream)
    result = mx.fft.irfft2(image_ft * kernel_ft, s=shape, stream=stream)
    return _centered(result, image.shape)
```

I also did a simple benchmark. It uses a random image of size 1024x1024 and varying kernel sizes, and compares the direct and the FFT-based convolution. I think the results follow exactly the expectation.
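For reference, a minimal way to exercise the function above (the shapes are just for illustration, and the CPU stream is chosen because the FFTs run there):

```python
import mlx.core as mx

# Hypothetical shapes: NCHW-style layout with a single image and channel
image = mx.random.normal((1, 1, 1024, 1024))
kernel = mx.random.normal((1, 1, 65, 65))

out = convolve_fft(image, kernel, stream=mx.cpu)
mx.eval(out)
print(out.shape)  # (1, 1, 1024, 1024): cropped back to the input size by _centered
```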
In general I think it is still worth having an FFT-based convolution. For NNs with small kernels there is no point, but there are many scientific applications that rely on large kernels (think of cross-correlations, convolution with pathological point spread functions, etc.). I think it is worth re-opening. |
Ok sounds good! Thanks for the benchmarks, that's really interesting! |
One option is to update the CPU convolution to dispatch to an FFT implementation when the input sizes make sense. We would want to benchmark it in a few settings to be sure it's a strict improvement. |
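A minimal sketch of what such a dispatch could look like, written in Python for illustration only (the real change would live in the CPU backend; the kernel-size threshold is a made-up placeholder, and mx.conv2d cross-correlates while the FFT path convolves, so a real implementation would also have to reconcile the two conventions):

```python
import mlx.core as mx


def conv2d_auto(x, w, stream=mx.cpu, fft_threshold=17):
    """Illustrative dispatcher: direct convolution for small kernels,
    FFT-based convolution for large ones. The threshold is a placeholder
    that would need to be calibrated from benchmarks."""
    kh, kw = w.shape[-2], w.shape[-1]
    if max(kh, kw) < fft_threshold:
        # Direct path: mx.conv2d expects NHWC inputs and (C_out, kh, kw, C_in)
        # weights, so transpose from the NCHW layout used in the snippets above.
        x_nhwc = mx.transpose(x, (0, 2, 3, 1))
        w_ohwi = mx.transpose(w, (0, 2, 3, 1))
        y = mx.conv2d(x_nhwc, w_ohwi, padding=(kh // 2, kw // 2), stream=stream)
        return mx.transpose(y, (0, 3, 1, 2))
    # Large-kernel path: reuse the FFT-based convolution from the comment above.
    return convolve_fft(x, w, stream)
```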
Thanks for re-opening @awni!
This is what SciPy does too (see https://github.com/scipy/scipy/blob/v1.12.0/scipy/signal/_signaltools.py#L1161). There is the option to either measure or to actually compute the FLOPs. Measuring only makes sense for repeated convolutions, but probably gives the most accurate results on arbitrary architectures. Looking at the SciPy code, it seems that computing the FLOPs is maybe too complex. Or is there a general way to predict FLOPs for mlx operations? (That would be nice to have...) In general, the performance of MLX operations is probably much more predictable across the more homogeneous M-series architectures. So there could be a third option of just parametrizing the scaling laws based on empirical benchmarks or something similar... |
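As an illustration of the "measure" option in MLX terms (the function name and structure are hypothetical, not SciPy's or MLX's actual API; mx.eval is used to force MLX's lazy evaluation):

```python
import time

import mlx.core as mx


def choose_conv_method(direct_fn, fft_fn, repeats=3):
    """Sketch of a SciPy-style ``measure`` option: time both candidate
    implementations on the actual inputs and return the faster label.
    Only worthwhile when the same shapes are convolved many times."""

    def _time(fn):
        mx.eval(fn())  # warm-up; MLX is lazy, so force evaluation
        start = time.perf_counter()
        for _ in range(repeats):
            mx.eval(fn())
        return (time.perf_counter() - start) / repeats

    return "fft" if _time(fft_fn) < _time(direct_fn) else "direct"


# e.g. choose_conv_method(lambda: direct_conv(image, kernel),
#                         lambda: convolve_fft(image, kernel, mx.cpu))
```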
Here is the gist with the code for the benchmark: https://gist.github.com/adonath/3f16b30498c60f25cf1349792c15283c |
Can I work on it? |
Go for it! |
It's my first contribution to the repo. Could you guide me a bit on which files I need to modify?
It would be nice to have FFT-based convolution supported in mlx. FFT-based convolution shows much better performance for large images/arrays and kernels. The FFT building blocks are already supported in mlx, so it is mostly a matter of combining them into a convolution operation.