diff --git a/jax/_src/lax/fft.py b/jax/_src/lax/fft.py index bb0dbae595b2..b89b29d12beb 100644 --- a/jax/_src/lax/fft.py +++ b/jax/_src/lax/fft.py @@ -30,20 +30,13 @@ from jax._src.lib.mlir.dialects import mhlo from jax._src.lib import xla_client from jax._src.lib import pocketfft +from jax._src.numpy.util import _promote_dtypes_complex, _promote_dtypes_inexact __all__ = [ "fft", "fft_p", ] -def _promote_to_complex(arg): - dtype = dtypes.result_type(arg, np.complex64) - return lax.convert_element_type(arg, dtype) - -def _promote_to_real(arg): - dtype = dtypes.result_type(arg, np.float32) - return lax.convert_element_type(arg, dtype) - def _str_to_fft_type(s: str) -> xla_client.FftType: if s == "FFT": return xla_client.FftType.FFT @@ -68,9 +61,9 @@ def fft(x, fft_type: Union[xla_client.FftType, str], fft_lengths: Sequence[int]) if typ == xla_client.FftType.RFFT: if np.iscomplexobj(x): raise ValueError("only real valued inputs supported for rfft") - x = _promote_to_real(x) + x, = _promote_dtypes_inexact(x) else: - x = _promote_to_complex(x) + x, = _promote_dtypes_complex(x) if len(fft_lengths) == 0: # XLA FFT doesn't support 0-rank. return x @@ -137,7 +130,7 @@ def _irfft_transpose(t, fft_lengths): x = fft(t, xla_client.FftType.RFFT, fft_lengths) n = x.shape[-1] is_odd = fft_lengths[-1] % 2 - full = partial(lax.full_like, t, dtype=t.dtype) + full = partial(lax.full_like, t, dtype=x.dtype) mask = lax.concatenate( [full(1.0, shape=(1,)), full(2.0, shape=(n - 2 + is_odd,)), diff --git a/jax/_src/numpy/fft.py b/jax/_src/numpy/fft.py index 77be0587d668..4335c88ddefa 100644 --- a/jax/_src/numpy/fft.py +++ b/jax/_src/numpy/fft.py @@ -16,6 +16,7 @@ import operator import numpy as np +from jax import dtypes from jax import lax from jax._src.lib import xla_client from jax._src.util import safe_zip @@ -87,7 +88,7 @@ def _fft_core(func_name, fft_type, a, s, axes, norm): else: s = [a.shape[axis] for axis in axes] transformed = lax.fft(a, fft_type, tuple(s)) - transformed *= _fft_norm(jnp.array(s, dtype=transformed.real.dtype), func_name, norm) + transformed *= _fft_norm(jnp.array(s, dtype=transformed.dtype), func_name, norm) if orig_axes is not None: transformed = jnp.moveaxis(transformed, axes, orig_axes) @@ -199,6 +200,7 @@ def irfft2(a, s=None, axes=(-2,-1), norm=None): @_wraps(np.fft.fftfreq) def fftfreq(n, d=1.0): + dtype = dtypes.canonicalize_dtype(jnp.float_) if isinstance(n, (list, tuple)): raise ValueError( "The n argument of jax.numpy.fft.fftfreq only takes an int. " @@ -209,26 +211,27 @@ def fftfreq(n, d=1.0): "The d argument of jax.numpy.fft.fftfreq only takes a single value. " "Got d = %s." % list(d)) - k = jnp.zeros(n) + k = jnp.zeros(n, dtype=dtype) if n % 2 == 0: # k[0: n // 2 - 1] = jnp.arange(0, n // 2 - 1) - k = k.at[0: n // 2].set( jnp.arange(0, n // 2)) + k = k.at[0: n // 2].set( jnp.arange(0, n // 2, dtype=dtype)) # k[n // 2:] = jnp.arange(-n // 2, -1) - k = k.at[n // 2:].set( jnp.arange(-n // 2, 0)) + k = k.at[n // 2:].set( jnp.arange(-n // 2, 0, dtype=dtype)) else: # k[0: (n - 1) // 2] = jnp.arange(0, (n - 1) // 2) - k = k.at[0: (n - 1) // 2 + 1].set(jnp.arange(0, (n - 1) // 2 + 1)) + k = k.at[0: (n - 1) // 2 + 1].set(jnp.arange(0, (n - 1) // 2 + 1, dtype=dtype)) # k[(n - 1) // 2 + 1:] = jnp.arange(-(n - 1) // 2, -1) - k = k.at[(n - 1) // 2 + 1:].set(jnp.arange(-(n - 1) // 2, 0)) + k = k.at[(n - 1) // 2 + 1:].set(jnp.arange(-(n - 1) // 2, 0, dtype=dtype)) return k / (d * n) @_wraps(np.fft.rfftfreq) def rfftfreq(n, d=1.0): + dtype = dtypes.canonicalize_dtype(jnp.float_) if isinstance(n, (list, tuple)): raise ValueError( "The n argument of jax.numpy.fft.rfftfreq only takes an int. " @@ -240,10 +243,10 @@ def rfftfreq(n, d=1.0): "Got d = %s." % list(d)) if n % 2 == 0: - k = jnp.arange(0, n // 2 + 1) + k = jnp.arange(0, n // 2 + 1, dtype=dtype) else: - k = jnp.arange(0, (n - 1) // 2 + 1) + k = jnp.arange(0, (n - 1) // 2 + 1, dtype=dtype) return k / (d * n) diff --git a/jax/_src/numpy/util.py b/jax/_src/numpy/util.py index e79ef9498857..af88ffbdaf8d 100644 --- a/jax/_src/numpy/util.py +++ b/jax/_src/numpy/util.py @@ -275,6 +275,17 @@ def _promote_dtypes_inexact(*args): for x in args] +def _promote_dtypes_complex(*args): + """Convenience function to apply Numpy argument dtype promotion. + + Promotes arguments to a complex type.""" + to_dtype, weak_type = dtypes._lattice_result_type(*args) + to_dtype = dtypes.canonicalize_dtype(to_dtype) + to_dtype_complex = dtypes._to_complex_dtype(to_dtype) + return [lax_internal._convert_element_type(x, to_dtype_complex, weak_type) + for x in args] + + def _complex_elem_type(dtype): """Returns the float type of the real/imaginary parts of a complex dtype.""" return np.abs(np.zeros((), dtype)).dtype diff --git a/jax/_src/scipy/fft.py b/jax/_src/scipy/fft.py index 9ec80455c68e..459e3e6f1b77 100644 --- a/jax/_src/scipy/fft.py +++ b/jax/_src/scipy/fft.py @@ -17,9 +17,10 @@ import scipy.fftpack as osp_fft # TODO use scipy.fft once scipy>=1.4.0 is used from jax import lax, numpy as jnp from jax._src.util import canonicalize_axis -from jax._src.numpy.util import _wraps +from jax._src.numpy.util import _wraps, _promote_dtypes_complex def _W4(N, k): + N, k = _promote_dtypes_complex(N, k) return jnp.exp(-.5j * jnp.pi * k / N) def _dct_interleave(x, axis): @@ -49,7 +50,7 @@ def dct(x, type=2, n=None, axis=-1, norm=None): N = x.shape[axis] v = _dct_interleave(x, axis) V = jnp.fft.fft(v, axis=axis) - k = lax.expand_dims(jnp.arange(N), [a for a in range(x.ndim) if a != axis]) + k = lax.expand_dims(jnp.arange(N, dtype=V.real.dtype), [a for a in range(x.ndim) if a != axis]) out = V * _W4(N, k) out = 2 * out.real if norm == 'ortho': @@ -62,8 +63,10 @@ def _dct2(x, axes, norm): N1, N2 = x.shape[axis1], x.shape[axis2] v = _dct_interleave(_dct_interleave(x, axis1), axis2) V = jnp.fft.fftn(v, axes=axes) - k1 = lax.expand_dims(jnp.arange(N1), [a for a in range(x.ndim) if a != axis1]) - k2 = lax.expand_dims(jnp.arange(N2), [a for a in range(x.ndim) if a != axis2]) + k1 = lax.expand_dims(jnp.arange(N1, dtype=V.dtype), + [a for a in range(x.ndim) if a != axis1]) + k2 = lax.expand_dims(jnp.arange(N2, dtype=V.dtype), + [a for a in range(x.ndim) if a != axis2]) out = _W4(N1, k1) * (_W4(N2, k2) * V + _W4(N2, -k2) * jnp.roll(jnp.flip(V, axis=axis2), shift=1, axis=axis2)) out = 2 * out.real if norm == 'ortho': diff --git a/tests/fft_test.py b/tests/fft_test.py index 742ade9d7604..5dd539c7ccda 100644 --- a/tests/fft_test.py +++ b/tests/fft_test.py @@ -23,7 +23,9 @@ import jax from jax import lax from jax import numpy as jnp +from jax._src import dtypes from jax._src import test_util as jtu +from jax._src.numpy.util import _promote_dtypes_complex from jax.config import config config.parse_flags_with_absl() @@ -111,7 +113,7 @@ def testLaxFftAcceptsStringTypes(self): def testLaxIrfftDoesNotMutateInputs(self, dtype): if dtype == np.float64 and not config.x64_enabled: raise self.skipTest("float64 requires jax_enable_x64=true") - x = jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=dtype) * (1+1j) + x = (1 + 1j) * jnp.array([[1.0, 2.0], [3.0, 4.0]], dtype=dtypes._to_complex_dtype(dtype)) y = np.asarray(jnp.fft.irfft2(x)) z = np.asarray(jnp.fft.irfft2(x)) self.assertAllClose(y, z) @@ -157,7 +159,9 @@ def build_matrix(linear_func, size): return jax.vmap(linear_func)(jnp.eye(size, size)) def func(x): - return jnp.fft.irfft(jnp.concatenate([jnp.zeros(1), x[:2] + 1j*x[2:]])) + x, = _promote_dtypes_complex(x) + return jnp.fft.irfft(jnp.concatenate([jnp.zeros_like(x, shape=1), + x[:2] + 1j*x[2:]])) def func_transpose(x): return jax.linear_transpose(func, x)(x)[0]