diff --git a/jax/_src/lax/fft.py b/jax/_src/lax/fft.py index da32a6b17f2b..ed14db5a1876 100644 --- a/jax/_src/lax/fft.py +++ b/jax/_src/lax/fft.py @@ -82,6 +82,8 @@ def fft_abstract_eval(x, fft_type, fft_lengths): raise ValueError(f"FFT input shape {x.shape} must have at least as many " f"input dimensions as fft_lengths {fft_lengths}.") if fft_type == xla_client.FftType.RFFT: + if x.dtype not in (np.float32, np.float64): + raise ValueError(f"RFFT input must be float32 or float64, got {x.dtype}") if x.shape[-len(fft_lengths):] != fft_lengths: raise ValueError(f"RFFT input shape {x.shape} minor dimensions must " f"be equal to fft_lengths {fft_lengths}") @@ -89,6 +91,9 @@ def fft_abstract_eval(x, fft_type, fft_lengths): + (fft_lengths[-1] // 2 + 1,)) dtype = _complex_dtype(x.dtype) elif fft_type == xla_client.FftType.IRFFT: + if not np.issubdtype(x.dtype, np.complexfloating): + raise ValueError("IRFFT input must be complex64 or complex128, got " + f"{x.dtype}") if x.shape[-len(fft_lengths):-1] != fft_lengths[:-1]: raise ValueError(f"IRFFT input shape {x.shape} minor dimensions must " "be equal to all except the last fft_length, got " @@ -96,6 +101,9 @@ def fft_abstract_eval(x, fft_type, fft_lengths): shape = x.shape[:-len(fft_lengths)] + fft_lengths dtype = _real_dtype(x.dtype) else: + if not np.issubdtype(x.dtype, np.complexfloating): + raise ValueError("FFT input must be complex64 or complex128, got " + f"{x.dtype}") if x.shape[-len(fft_lengths):] != fft_lengths: raise ValueError(f"FFT input shape {x.shape} minor dimensions must " f"be equal to fft_lengths {fft_lengths}") diff --git a/tests/fft_test.py b/tests/fft_test.py index c4e808fd24be..b4cddee3b716 100644 --- a/tests/fft_test.py +++ b/tests/fft_test.py @@ -102,26 +102,30 @@ def testLaxFftAcceptsStringTypes(self): lax.fft(x, "fft", fft_lengths=(10,))) def testLaxFftErrors(self): - with self.assertRaises( - ValueError, - msg="FFT input shape (14, 15) must have at least as many input " - "dimensions as fft_lengths (4, 5, 6)"): + with self.assertRaisesRegex( + ValueError, + r"FFT input shape \(14, 15\) must have at least as many input " + r"dimensions as fft_lengths \(4, 5, 6\)"): lax.fft(np.ones((14, 15)), fft_type="fft", fft_lengths=(4, 5, 6)) - with self.assertRaises( - ValueError, - msg="FFT input shape (14, 15) minor dimensions must be equal to " - "fft_lengths (17,)"): + with self.assertRaisesRegex( + ValueError, + r"FFT input shape \(14, 15\) minor dimensions must be equal to " + r"fft_lengths \(17,\)"): lax.fft(np.ones((14, 15)), fft_type="fft", fft_lengths=(17,)) - with self.assertRaises( - ValueError, - msg="RFFT input shape (14, 15) minor dimensions must be equal to " - "fft_lengths (14, 15,)"): + with self.assertRaisesRegex( + ValueError, + r"RFFT input shape \(2, 14, 15\) minor dimensions must be equal to " + r"fft_lengths \(14, 12\)"): lax.fft(np.ones((2, 14, 15)), fft_type="rfft", fft_lengths=(14, 12)) - with self.assertRaises( - ValueError, - msg="IRFFT input shape (14, 15) minor dimensions must be equal to " - "all except the last fft_length, got fft_lengths=(14, 15,)"): + with self.assertRaisesRegex( + ValueError, + r"IRFFT input shape \(2, 14, 15\) minor dimensions must be equal to " + r"all except the last fft_length, got fft_lengths=\(13, 15\)"): lax.fft(np.ones((2, 14, 15)), fft_type="irfft", fft_lengths=(13, 15)) + with self.assertRaisesRegex( + ValueError, "RFFT input must be float32 or float64, got bfloat16"): + lax.fft(np.ones((14, 15), jnp.bfloat16), fft_type="rfft", + fft_lengths=(5, 6)) @parameterized.parameters((np.float32,), (np.float64,)) def testLaxIrfftDoesNotMutateInputs(self, dtype):