Skip to content

Commit

Permalink
Fix jax2tf graph_serialization irfft length issue.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 539405879
  • Loading branch information
maxwillzq authored and jax authors committed Jun 11, 2023
1 parent 1a7336d commit 5d1ed55
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 18 deletions.
25 changes: 12 additions & 13 deletions jax/experimental/jax2tf/jax2tf.py
Expand Up @@ -3066,26 +3066,25 @@ def _fft(x, *, fft_type, fft_lengths,
_in_avals: Sequence[core.ShapedArray],
_out_aval: core.ShapedArray):
FFT, IFFT, RFFT, IRFFT = list(map(xla_client.FftType, [0, 1, 2, 3]))
x_aval, = _in_avals
x_shape = x_aval.shape
if fft_type == IRFFT:
expected_lengths = x_shape[-len(fft_lengths):-1] + ((x_shape[-1] - 1) * 2,)
else:
expected_lengths = x_shape[-len(fft_lengths):]
if expected_lengths != fft_lengths:
raise NotImplementedError(
f"Unsupported {fft_lengths=} for {fft_type=} of "
f"array with shape={x.shape}.")
tf_funcs = {
FFT: [tf.signal.fft, tf.signal.fft2d, tf.signal.fft3d],
IFFT: [tf.signal.ifft, tf.signal.ifft2d, tf.signal.ifft3d],
RFFT: [tf.signal.rfft, tf.signal.rfft2d, tf.signal.rfft3d],
IRFFT: [tf.signal.irfft, tf.signal.irfft2d, tf.signal.irfft3d]
}
res = tf_funcs[fft_type][len(fft_lengths) - 1](x)
tf_func = tf_funcs[fft_type][len(fft_lengths) - 1]
if fft_type in (RFFT, IRFFT):
# https://www.tensorflow.org/api_docs/python/tf/signal/irfft
# Here we only set `fft_lengths` argument for non-default value.
(x_aval,) = _in_avals
x_shape = x_aval.shape
expected_lengths = x_shape[-len(fft_lengths) : -1] + (
(x_shape[-1] - 1) * 2,
)
if fft_lengths != expected_lengths:
tf_func = partial(tf_func, fft_length=_eval_shape(fft_lengths))
res = tf_func(x)
return _ensure_tf_shape_if_dynamic(res, _aval_to_tf_shape(_out_aval))


tf_impl_with_avals[lax.fft_p] = _fft


Expand Down
9 changes: 5 additions & 4 deletions jax/experimental/jax2tf/tests/primitive_harness.py
Expand Up @@ -1764,10 +1764,11 @@ def _fft_rng_factory(dtype):
for dtype in (jtu.dtypes.floating
if fft_type == xla_client.FftType.RFFT else jtu.dtypes.complex):
shape = (14, 15, 16, 17)
for fft_lengths in [
(shape[-1],) if fft_type != xla_client.FftType.IRFFT else
((shape[-1] - 1) * 2,)
]:
if fft_type != xla_client.FftType.IRFFT:
fft_lengths_list = [ (shape[-1],) ]
else:
fft_lengths_list = [ ((shape[-1] - 1) * 2,), (shape[-1] * 2 - 1,) ]
for fft_lengths in fft_lengths_list:
_make_fft_harness(
"dtypes",
shape=shape,
Expand Down
2 changes: 1 addition & 1 deletion jax/experimental/jax2tf/tests/shape_poly_test.py
Expand Up @@ -2769,7 +2769,7 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
)
def test_harness(self, harness: PolyHarness):
# Exclude some harnesses that are known to fail for native serialization
# FOR GRAPH SERIALIZATION
# FOR NATIVE SERIALIZATION
if config.jax2tf_default_native_serialization:
if not harness.enable_xla:
raise unittest.SkipTest("disabled for native_serialization and enable_xla=False")
Expand Down

0 comments on commit 5d1ed55

Please sign in to comment.