FFT precision/performance #2952

Closed
s-zymon opened this issue May 4, 2020 · 20 comments · Fixed by #4699
Labels: performance (make things lean and fast)

Comments

s-zymon commented May 4, 2020

I noticed a significant difference between the outputs of numpy.fft.fft and jax.numpy.fft.fft.
The difference also varies with the device: on CPU the error is larger than on GPU. By comparison, the mean absolute error of a GPU FFT implementation such as PyTorch's is around 1e-8, which seems reasonable. I suspect this might be a minor bug.

The second issue is the performance of jax.numpy.fft.fft on CPU. I am aware that JAX is intended for GPU/TPU, but the overhead of the JAX FFT on CPU seems oddly large.

Below is a simple script to reproduce both issues.

%env JAX_ENABLE_X64=1
%env JAX_PLATFORM_NAME=cpu

import time
import numpy as np

import jax
from jax import numpy as jnp

np.random.seed(0)

signal = np.random.randn(2**20)
signal_jax = jnp.array(signal)

jfft = jax.jit(jnp.fft.fft)

X_np = np.fft.fft(signal)
X_jax = jfft(signal_jax)

print(np.mean(np.abs(X_np)))
print('max:\t', jnp.max(jnp.abs(X_np - X_jax)))
print('mean:\t', jnp.mean(jnp.abs(X_np - X_jax)))
print('min:\t', jnp.min(jnp.abs(X_np - X_jax)))

### CPU
# 907.3490574884647
# max:	 2.8773885332210747
# mean:	 0.3903197564919141
# min:	 2.4697454729898156e-05

### GPU
# 907.3490574884647
# max:	 0.001166179716824765
# mean:	 0.00020841654559267488
# min:	 2.741492442122853e-07

R = 100
ts = time.time()
for i in range(R):
    _ = np.fft.fft(signal)
print('numpy fft execution time [ms]:\t', (time.time()-ts)/R * 1000)

# Compile
_ = jfft(signal_jax).block_until_ready()

ts = time.time()
for i in range(R):
    _ = jfft(signal_jax).block_until_ready()
print('jax fft execution time [ms]:\t', (time.time()-ts)/R * 1000)

### CPU
# numpy fft execution time [ms]:	 36.75990343093872
# jax fft execution time [ms]:	         219.37960147857666

### GPU
# numpy fft execution time [ms]:	 38.53107929229736
# jax fft execution time [ms]:	         0.38921356201171875
jakevdp (Collaborator) commented May 4, 2020

Please see the related discussion in #2874.

mattjj (Member) commented May 9, 2020

IIUC, #2874 is only about a GPU bug, whereas this issue seems to be mainly about CPU, covering both performance and correctness.

I think XLA:CPU is using Eigen's FFT. Maybe it's slow or something. I'll ping the XLA:CPU folks to see if they know anything about it. On the JAX side, we could possibly do a CustomCall into some other FFT implementation on CPU, like we CustomCall into LAPACK kernels for matrix decompositions.

@mattjj mattjj self-assigned this May 9, 2020
@mattjj mattjj added the performance make things lean and fast label May 9, 2020
mattjj (Member) commented May 12, 2020

I confirmed with the XLA:CPU folks that XLA is just calling into Eigen here, and "it's possible but unlikely that XLA is doing something bad here that triggers slowness." (I'd like to double-check just by executing this benchmark that Eigen is being multithreaded properly for FFTs, but I'm not sure when I'll get a chance to do that.)

Depending on whether XLA:CPU folks have the bandwidth to improve this, we might want to look into JAX-side solutions. I'll update this thread again when I learn more.

MuawizChaudhary commented Jun 18, 2020

I am working on implementing a Jax backend for kymatio (kymat.io), a Python package implementing the scattering transform.

When we compare the JAX FFT implementation against a closed-form expression of the discrete Fourier transform of a box of ones with dtype float32 (that is, the Dirichlet kernel), we note a large deviation. Additionally, comparing the results of the JAX FFT with the results of the NumPy and SciPy FFTs shows significant discrepancies.

Are there plans to address this? PR #3290 does not appear to have solved this issue.

import numpy as np
import scipy.fft
from jax import numpy as jnp


def box_dirichlet(N, FFT):
  # Build a centered box of 31 ones (|x| < 16) and its analytic DFT, the Dirichlet kernel.
  x = np.arange(N)
  x -= len(x) // 2
  n = 16
  box = np.abs(x) < n
  fbox = np.fft.fftshift(FFT(np.fft.ifftshift(box.astype('float32'))))
  fbox = fbox / 2 / np.pi
  k = x / (-x.min()) * np.pi
  n = 15
  dirichlet = np.sin((n + .5) * k) / (2 * np.pi * np.sin(.5 * k))
  dirichlet[int(N / 2)] = (n + .5) / np.pi  # fill in the k = 0 limit (avoids the 0/0 at the center)
  return dirichlet, fbox


def comparison(dirichlet, fbox):
  # "Difference" here is the l2 norm of the elementwise error (absolute and relative).
  print("The absolute difference is: ", np.linalg.norm(dirichlet - fbox))
  print("The relative difference is: ", np.linalg.norm(dirichlet - fbox) / np.linalg.norm(dirichlet))
# comparison of the FFT'ed box of ones with the Dirichlet kernel
dirichlet, jax_fft = box_dirichlet(2**20., jnp.fft.fft)
comparison(dirichlet, jax_fft)

The absolute difference is: 0.47198237618500993
The relative difference is: 0.0005201455019338586

dirichlet, numpy_fft = box_dirichlet(2**20., np.fft.fft)
comparison(dirichlet, numpy_fft)

The absolute difference is: 3.6838866295306374e-13
The relative difference is: 4.0598063755531826e-16

dirichlet, scipy_fft = box_dirichlet(2**20., scipy.fft.fft)
comparison(dirichlet, scipy_fft)

The absolute difference is: 0.00010274223272981098
The relative difference is: 1.132264951183364e-07

Even with smaller arrays we note large differences.

dirichlet, jax_fft = box_dirichlet(2**15., jnp.fft.fft)
comparison(dirichlet, jax_fft)

The absolute difference is: 0.025780742747920873
The relative difference is: 0.00016071983551150154
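
For reference, the closed form box_dirichlet compares against is the standard Dirichlet kernel identity: a centered box of 2n + 1 ones (here n = 15) has, after the division by 2π above, the DFT

$$
\frac{1}{2\pi} \sum_{j=-n}^{n} e^{-ikj}
  = \frac{\sin\!\left((n + \tfrac{1}{2})\,k\right)}{2\pi \sin(k/2)},
\qquad k = \frac{2\pi m}{N},\; m = -\tfrac{N}{2}, \dots, \tfrac{N}{2} - 1,
$$

with the limiting value $(n + \tfrac{1}{2})/\pi$ at $k = 0$, which is exactly the value the snippet fills in at index N/2.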

hawkinsp (Member):
What version of jaxlib and what hardware platform are you using?

jaxlib 0.1.48 adds 64-bit FFT support on CPU and GPU, which may help if you have accuracy problems. Note also that I believe the NumPy FFT you are comparing against always computes in 64-bit. Can you verify that you are using a 64-bit FFT in JAX (i.e., that you have 64-bit input types and have JAX_ENABLE_X64 set or similar)?
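
A minimal way to verify this (a sketch, assuming a fresh interpreter so the flag is set before any arrays are created) is to enable 64-bit mode and inspect the FFT output dtype:

```python
import numpy as np
from jax.config import config
config.update('jax_enable_x64', True)  # must run before any JAX arrays are created

from jax import numpy as jnp

x = jnp.asarray(np.random.randn(1024))  # float64 once x64 mode is enabled
X = jnp.fft.fft(x)
print(x.dtype, X.dtype)  # expect float64 / complex128; complex64 means the FFT ran in 32-bit
```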

MuawizChaudhary commented Jun 18, 2020

Hi,

I was using jaxlib 0.1.47 and 0.1.48 on Google Colab, but this is something my collaborators have noticed on their machines too. We are aware that NumPy upcasts to 64-bit; however, SciPy and, it appears, JAX do not.

The input was a box of ones as float32s. Testing with the input as a box of float64s, we obtain similar inaccuracies. This is with version 0.1.48, using config.update('jax_enable_x64', True).

# comparison of the FFT'ed box of ones with the Dirichlet kernel
dirichlet, jax_fft = box_dirichlet(2**20., jnp.fft.fft)
comparison(dirichlet, jax_fft)

The absolute difference is: 0.47198237618500993
The relative difference is: 0.0005201455019338586

dirichlet, numpy_fft = box_dirichlet(2**20., np.fft.fft)
comparison(dirichlet, numpy_fft)

The absolute difference is: 3.6838866295306374e-13
The relative difference is: 4.0598063755531826e-16

dirichlet, scipy_fft = box_dirichlet(2**20., scipy.fft.fft)
comparison(dirichlet, scipy_fft)

The absolute difference is: 3.781620542617657e-13
The relative difference is: 4.167513480402116e-16

hawkinsp (Member):
Out of curiosity, do you see the same results from TensorFlow?

JAX uses Eigen for its FFT implementation on CPU, as does TensorFlow, so one hypothesis is that this is simply due to the quality of the Eigen implementation. That might be nice to verify, if you have time. If they did differ, that would be very interesting to know.
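
A sketch of that cross-check, assuming TensorFlow 2.x is installed (tf.signal.fft expects a complex input, so the float32 box is cast first):

```python
import numpy as np
import tensorflow as tf
from jax import numpy as jnp

N = 2**20
box = (np.abs(np.arange(N) - N // 2) < 16).astype('float32')  # same box of ones as above

X_jax = np.asarray(jnp.fft.fft(box))
X_tf = tf.signal.fft(tf.cast(box, tf.complex64)).numpy()

# Matching outputs on CPU would point at the shared Eigen backend rather than JAX itself.
print(np.max(np.abs(X_jax - X_tf)))
```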

MuawizChaudhary commented Jun 18, 2020

It appears that they give the same outputs on both CPU and GPU!

Edit: My interpretation was that Jax is supposed to be as similar as possible to Numpy. Is this interpretation wrong?

hawkinsp (Member) commented Jun 18, 2020

"Shouldn't Jax be closer to Numpy?"

Ultimately they are two different pieces of code and they will not act the same in all circumstances. And it's not a goal to precisely match NumPy everywhere.

There are at least three things you could mean:
a) JAX should default to float64 precision even when performing float32 FFTs.
b) JAX should return a better quality float64 result on CPU.
c) JAX should return a better quality float64 result on GPU.

For (a): perhaps. We don't try to follow NumPy precisely, and in a number of cases we default to float32 to be more GPU friendly.

For (b): I suspect we need to find a higher quality implementation of FFT on CPU. The obvious candidate is probably Intel's MKL library.

For (c): JAX uses completely different FFT implementations on CPU and GPU. On GPU it uses cufft (which pretty much everyone uses as far as I am aware). I would actually expect that you would see high quality results on GPU. Can you confirm that you were actually running in 64-bit mode on GPU?
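
(A minimal check for that last point, sketched here assuming a single-GPU machine: confirm the default backend and the FFT result dtype.)

```python
import jax
from jax import numpy as jnp

print(jax.devices()[0].platform)  # expect 'gpu'
X = jnp.fft.fft(jnp.ones(16))
print(X.dtype)  # complex128 only if 64-bit mode is enabled; complex64 otherwise
```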

MuawizChaudhary:
The other results were on CPU, both 32- and 64-bit. It looks like the 64-bit results on GPU match up.

dirichlet, jax_fft = box_dirichlet(2**20., jnp.fft.fft)
comparison(dirichlet, jax_fft)

The absolute difference is: 0.00019981246168760387
The relative difference is: 2.2020219063518382e-07

dirichlet, tf_fft = box_dirichlet(2**20., tf.signal.fft)
comparison(dirichlet, tf_fft)

The absolute difference is: 0.00019981246168760387
The relative difference is: 2.2020219063518382e-07

janden commented Jun 22, 2020

Thank you for looking into this, @hawkinsp. From our perspective the best thing would be to have the Jax FFT be more accurate (comparable to NumPy, SciPy, and PyFFTW) on the CPU (for both float32 and float64). Plugging into MKL, as you suggest, might be a good idea here.

If this is the Eigen FFT interface that JAX is using, it looks like switching from the default backend (kissfft) to FFTW should be possible by setting a compiler flag.

shoyer (Member) commented Jul 2, 2020

Licensing might be the trickiest part here. FFTW is GPL and MKL is proprietary.

NumPy uses pocketfft these days. Writing a custom call in JAX to use pocketfft on CPU could be a good option -- or perhaps XLA CPU should use pocketfft.

Joshuaalbert (Contributor):
I can also add that the radio astronomy community would be greatly interested if the JAX FFT on CPU were both accurate and fast.

@mattjj Re: is the result the same as with TensorFlow? Yes:

With JAX:
max:	 4.362646594533903e-08
mean:	 6.237288307614869e-09
min:	 0.0
With Tensorflow:
max:	 4.362646594533903e-08
mean:	 6.237288307614869e-09
min:	 0.0
numpy fft execution time [ms]:	 44.88363027572632
jax fft execution time [ms]:	 84.56079244613647
tensorflow fft execution time [ms]:	 84.12498950958252

pacargile:
@mattjj Has there been any progress in understanding why JAX's FFT is ~twice as slow as NumPy's on CPU? I second @Joshuaalbert's comment: other areas of astrophysics would also be very interested in a fast and accurate JAX FFT.

hawkinsp (Member):
I think we're pretty clear on what to do here: replace the Eigen FFT on CPU with something else, probably PocketFFT, same as NumPy. We just need someone to actually do it!

Joshuaalbert (Contributor):
@hawkinsp What about MKL's FFT? It's the fastest that I've seen. FFTW is currently what radio astronomers use, due to its popularity. An informative comparison of FFT libraries is here: https://github.com/project-gemmi/benchmarking-fft/

shoyer (Member) commented Oct 19, 2020

I think optionally using MKL could be viable, but MKL is closed-source software. At the very least, we want to preserve an open source option.

mattjj (Member) commented Oct 19, 2020

And FFTW is GPL :/

I think the main limiting factor here is just developer bandwidth on the JAX core team, where we have to balance a lot of considerations (code licenses, ensuring it works in OSS as well as internally at Google, etc).

Until we improve this, it might be useful to look at how you can rig up a call into any implementation you want by registering a custom backend-specific kernel with XLA. One example of how to do that is mpi4jax. You could also look at how JAX calls into LAPACK on CPU and cuSolver on GPU, e.g. starting at lapack.pyx for the CPU stuff.
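
Short of a full XLA custom call, a lighter-weight stopgap is to hand the CPU FFT back to NumPy's pocketfft from inside a jitted function. (A sketch only; this is not the custom-kernel route described above, it relies on jax.pure_callback, an API added in later JAX releases, and each call pays a host round-trip.)

```python
import jax
import jax.numpy as jnp
import numpy as np

def _host_fft(x):
    # Runs on the host via NumPy (pocketfft), outside of XLA.
    return np.fft.fft(x).astype(np.complex64)

@jax.jit
def fft_via_numpy(x):
    out_spec = jax.ShapeDtypeStruct(x.shape, jnp.complex64)
    return jax.pure_callback(_host_fft, out_spec, x)

x = jnp.asarray(np.random.randn(2**20).astype(np.float32))
X = fft_via_numpy(x)  # opaque to autodiff; intended only as a CPU accuracy workaround
```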

@hawkinsp hawkinsp assigned hawkinsp and unassigned mattjj Oct 22, 2020
copybara-service bot pushed a commit that referenced this issue Oct 23, 2020
PocketFFT is the same FFT library used by NumPy (although we are using the C++ variant rather than the C variant).

For the benchmark in #2952 on my workstation:

Before:
```
907.3490574884647
max:     4.362646594533903e-08
mean:    6.237288307614869e-09
min:     0.0
numpy fft execution time [ms]:   37.088446617126465
jax fft execution time [ms]:     74.93342399597168
```

After:
```
907.3490574884647
max:     1.9057386696477137e-12
mean:    3.9326737908882566e-13
min:     0.0
numpy fft execution time [ms]:   37.756404876708984
jax fft execution time [ms]:     28.128278255462646
```

Fixes #2952

PiperOrigin-RevId: 338530400
hawkinsp (Member):
This issue should be fixed, but it requires a jaxlib rebuild. You can either build from source or wait for us to make a new jaxlib release. Hope that helps!

hawkinsp (Member):
The updated FFT has been released as part of jaxlib 0.1.57. Hope that helps!
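
To confirm the fix is picked up after upgrading, checking the installed jaxlib version and re-running the accuracy comparison from the top of this thread should be enough (a minimal sketch):

```python
import jaxlib.version
print(jaxlib.version.__version__)  # expect 0.1.57 or newer
```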
