Skip to content

Commit

Permalink
Merge pull request #19378 from Micky774:fft_overflow
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 604468627
  • Loading branch information
jax authors committed Feb 6, 2024
2 parents 155958b + 1c844ae commit fb6fa04
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 2 deletions.
8 changes: 6 additions & 2 deletions jax/_src/numpy/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from jax import lax
from jax._src.lib import xla_client
from jax._src.util import safe_zip
from jax._src.numpy.util import check_arraylike, implements
from jax._src.numpy.util import check_arraylike, implements, promote_dtypes_inexact
from jax._src.numpy import lax_numpy as jnp
from jax._src.numpy import ufuncs, reductions
from jax._src.typing import Array, ArrayLike
Expand All @@ -32,7 +32,11 @@
def _fft_norm(s: Array, func_name: str, norm: str) -> Array:
if norm == "backward":
return jnp.array(1)
elif norm == "ortho":

# Avoid potential integer overflow
s, = promote_dtypes_inexact(s)

if norm == "ortho":
return ufuncs.sqrt(reductions.prod(s)) if func_name.startswith('i') else 1/ufuncs.sqrt(reductions.prod(s))
elif norm == "forward":
return reductions.prod(s) if func_name.startswith('i') else 1/reductions.prod(s)
Expand Down
18 changes: 18 additions & 0 deletions tests/fft_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from jax._src import dtypes
from jax._src import test_util as jtu
from jax._src.numpy.util import promote_dtypes_complex
from jax._src.numpy.fft import _fft_norm

config.parse_flags_with_absl()

Expand Down Expand Up @@ -445,5 +446,22 @@ def testIfftshift(self, shape, dtype, axes):
np_fn = lambda arg: np.fft.ifftshift(arg, axes=axes)
self._CheckAgainstNumpy(np_fn, jnp_fn, args_maker)

@jtu.sample_product(
norm=["ortho", "forward"],
func_name = ["fft", "ifft"],
dtype=jtu.dtypes.integer
)
def testFftnormOverflow(self, norm, func_name, dtype):
# non-regression test for gh-18453

shape = jnp.array([3] + [900] * 3, dtype=dtype)
jax_norm = _fft_norm(shape, func_name, norm)
np_norm = np.array(shape).prod(dtype=np.float64)
if norm == "ortho":
np_norm = np.sqrt(np_norm)
if func_name[0] != "i":
np_norm = np.reciprocal(np_norm)
self.assertArraysAllClose(jax_norm, np_norm, rtol=3e-8, check_dtypes=False)

if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit fb6fa04

Please sign in to comment.