Skip to content

Commit

Permalink
Check dtypes in fft_p's abstract eval rule.
Browse files Browse the repository at this point in the history
In particular, this catches a bad error when a bfloat16 is passed to rfft.
  • Loading branch information
hawkinsp committed Oct 6, 2023
1 parent f4bb1c0 commit 4e1b8fc
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 16 deletions.
8 changes: 8 additions & 0 deletions jax/_src/lax/fft.py
Expand Up @@ -82,20 +82,28 @@ def fft_abstract_eval(x, fft_type, fft_lengths):
raise ValueError(f"FFT input shape {x.shape} must have at least as many "
f"input dimensions as fft_lengths {fft_lengths}.")
if fft_type == xla_client.FftType.RFFT:
if x.dtype not in (np.float32, np.float64):
raise ValueError(f"RFFT input must be float32 or float64, got {x.dtype}")
if x.shape[-len(fft_lengths):] != fft_lengths:
raise ValueError(f"RFFT input shape {x.shape} minor dimensions must "
f"be equal to fft_lengths {fft_lengths}")
shape = (x.shape[:-len(fft_lengths)] + fft_lengths[:-1]
+ (fft_lengths[-1] // 2 + 1,))
dtype = _complex_dtype(x.dtype)
elif fft_type == xla_client.FftType.IRFFT:
if not np.issubdtype(x.dtype, np.complexfloating):
raise ValueError("IRFFT input must be complex64 or complex128, got "
f"{x.dtype}")
if x.shape[-len(fft_lengths):-1] != fft_lengths[:-1]:
raise ValueError(f"IRFFT input shape {x.shape} minor dimensions must "
"be equal to all except the last fft_length, got "
f"{fft_lengths=}")
shape = x.shape[:-len(fft_lengths)] + fft_lengths
dtype = _real_dtype(x.dtype)
else:
if not np.issubdtype(x.dtype, np.complexfloating):
raise ValueError("FFT input must be complex64 or complex128, got "
f"{x.dtype}")
if x.shape[-len(fft_lengths):] != fft_lengths:
raise ValueError(f"FFT input shape {x.shape} minor dimensions must "
f"be equal to fft_lengths {fft_lengths}")
Expand Down
36 changes: 20 additions & 16 deletions tests/fft_test.py
Expand Up @@ -102,26 +102,30 @@ def testLaxFftAcceptsStringTypes(self):
lax.fft(x, "fft", fft_lengths=(10,)))

def testLaxFftErrors(self):
with self.assertRaises(
ValueError,
msg="FFT input shape (14, 15) must have at least as many input "
"dimensions as fft_lengths (4, 5, 6)"):
with self.assertRaisesRegex(
ValueError,
r"FFT input shape \(14, 15\) must have at least as many input "
r"dimensions as fft_lengths \(4, 5, 6\)"):
lax.fft(np.ones((14, 15)), fft_type="fft", fft_lengths=(4, 5, 6))
with self.assertRaises(
ValueError,
msg="FFT input shape (14, 15) minor dimensions must be equal to "
"fft_lengths (17,)"):
with self.assertRaisesRegex(
ValueError,
r"FFT input shape \(14, 15\) minor dimensions must be equal to "
r"fft_lengths \(17,\)"):
lax.fft(np.ones((14, 15)), fft_type="fft", fft_lengths=(17,))
with self.assertRaises(
ValueError,
msg="RFFT input shape (14, 15) minor dimensions must be equal to "
"fft_lengths (14, 15,)"):
with self.assertRaisesRegex(
ValueError,
r"RFFT input shape \(2, 14, 15\) minor dimensions must be equal to "
r"fft_lengths \(14, 12\)"):
lax.fft(np.ones((2, 14, 15)), fft_type="rfft", fft_lengths=(14, 12))
with self.assertRaises(
ValueError,
msg="IRFFT input shape (14, 15) minor dimensions must be equal to "
"all except the last fft_length, got fft_lengths=(14, 15,)"):
with self.assertRaisesRegex(
ValueError,
r"IRFFT input shape \(2, 14, 15\) minor dimensions must be equal to "
r"all except the last fft_length, got fft_lengths=\(13, 15\)"):
lax.fft(np.ones((2, 14, 15)), fft_type="irfft", fft_lengths=(13, 15))
with self.assertRaisesRegex(
ValueError, "RFFT input must be float32 or float64, got bfloat16"):
lax.fft(np.ones((14, 15), jnp.bfloat16), fft_type="rfft",
fft_lengths=(5, 6))

@parameterized.parameters((np.float32,), (np.float64,))
def testLaxIrfftDoesNotMutateInputs(self, dtype):
Expand Down

0 comments on commit 4e1b8fc

Please sign in to comment.