Skip to content

Commit

Permalink
feat: add ifftn to jax frontend along with the test (#28550)
Browse files Browse the repository at this point in the history
passing
  • Loading branch information
Medo072 committed Mar 23, 2024
1 parent 4539031 commit a3146df
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 0 deletions.
8 changes: 8 additions & 0 deletions ivy/functional/frontends/jax/numpy/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,14 @@ def ifft2(a, s=None, axes=(-2, -1), norm=None):
return ivy.array(ivy.ifft2(a, s=s, dim=axes, norm=norm), dtype=ivy.dtype(a))


@with_unsupported_dtypes({"1.24.3 and below": ("complex64", "bfloat16")}, "numpy")
@to_ivy_arrays_and_back
def ifftn(a, s=None, axes=None, norm=None):
a = ivy.asarray(a, dtype=ivy.complex128)
a = ivy.ifftn(a, s=s, axes=axes, norm=norm)
return a


@to_ivy_arrays_and_back
@with_unsupported_dtypes({"1.25.2 and below": ("float16", "bfloat16")}, "numpy")
def rfft(a, n=None, axis=-1, norm=None):
Expand Down
29 changes: 29 additions & 0 deletions ivy_tests/test_ivy/test_frontends/test_jax/test_numpy/test_fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
# local
import ivy_tests.test_ivy.helpers as helpers
from ivy_tests.test_ivy.helpers import handle_frontend_test
from ivy_tests.test_ivy.test_functional.test_experimental.test_nn.test_layers import (
_x_and_ifftn_jax,
)


# fft
Expand Down Expand Up @@ -242,6 +245,32 @@ def test_jax_numpy_ifft2(
)


@handle_frontend_test(
fn_tree="jax.numpy.fft.ifftn",
dtype_and_x=_x_and_ifftn_jax(),
)
def test_jax_numpy_ifftn(
dtype_and_x, backend_fw, frontend, test_flags, fn_tree, on_device
):
input_dtype, x, s, axes, norm = dtype_and_x

helpers.test_frontend_function(
input_dtypes=input_dtype,
frontend=frontend,
backend_to_test=backend_fw,
test_flags=test_flags,
fn_tree=fn_tree,
on_device=on_device,
test_values=True,
atol=1e-09,
rtol=1e-08,
a=x,
s=s,
axes=axes,
norm=norm,
)


# rfft
@handle_frontend_test(
fn_tree="jax.numpy.fft.rfft",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,45 @@ def _x_and_ifftn(draw):
return dtype, x, s, axes, norm


@st.composite
def _x_and_ifftn_jax(draw):
min_fft_points = 2
dtype = draw(helpers.get_dtypes("complex"))
x_dim = draw(
helpers.get_shape(
min_dim_size=2, max_dim_size=100, min_num_dims=1, max_num_dims=4
)
)
x = draw(
helpers.array_values(
dtype=dtype[0],
shape=tuple(x_dim),
min_value=-1e-10,
max_value=1e10,
)
)
axes = draw(
st.lists(
st.integers(0, len(x_dim) - 1),
min_size=1,
max_size=min(len(x_dim), 3),
unique=True,
)
)
norm = draw(st.sampled_from(["forward", "ortho", "backward"]))

# Shape for s can be larger, smaller or equal to the size of the input
# along the axes specified by axes.
# Here, we're generating a list of integers corresponding to each axis in axes.
s = draw(
st.lists(
st.integers(min_fft_points, 256), min_size=len(axes), max_size=len(axes)
)
)

return dtype, x, s, axes, norm


@st.composite
def _x_and_rfft(draw):
min_fft_points = 2
Expand Down

0 comments on commit a3146df

Please sign in to comment.