From 5d1ed558ea6559363bb56374025e1f97d8540fe9 Mon Sep 17 00:00:00 2001 From: John QiangZhang Date: Sat, 10 Jun 2023 23:20:19 -0700 Subject: [PATCH] Fix jax2tf graph_serialization irfft length issue. PiperOrigin-RevId: 539405879 --- jax/experimental/jax2tf/jax2tf.py | 25 +++++++++---------- .../jax2tf/tests/primitive_harness.py | 9 ++++--- .../jax2tf/tests/shape_poly_test.py | 2 +- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index b4d8f0fe3331..222d17ed55aa 100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -3066,26 +3066,25 @@ def _fft(x, *, fft_type, fft_lengths, _in_avals: Sequence[core.ShapedArray], _out_aval: core.ShapedArray): FFT, IFFT, RFFT, IRFFT = list(map(xla_client.FftType, [0, 1, 2, 3])) - x_aval, = _in_avals - x_shape = x_aval.shape - if fft_type == IRFFT: - expected_lengths = x_shape[-len(fft_lengths):-1] + ((x_shape[-1] - 1) * 2,) - else: - expected_lengths = x_shape[-len(fft_lengths):] - if expected_lengths != fft_lengths: - raise NotImplementedError( - f"Unsupported {fft_lengths=} for {fft_type=} of " - f"array with shape={x.shape}.") tf_funcs = { FFT: [tf.signal.fft, tf.signal.fft2d, tf.signal.fft3d], IFFT: [tf.signal.ifft, tf.signal.ifft2d, tf.signal.ifft3d], RFFT: [tf.signal.rfft, tf.signal.rfft2d, tf.signal.rfft3d], IRFFT: [tf.signal.irfft, tf.signal.irfft2d, tf.signal.irfft3d] } - res = tf_funcs[fft_type][len(fft_lengths) - 1](x) + tf_func = tf_funcs[fft_type][len(fft_lengths) - 1] + if fft_type in (RFFT, IRFFT): + # https://www.tensorflow.org/api_docs/python/tf/signal/irfft + # Here we only set `fft_lengths` argument for non-default value. + (x_aval,) = _in_avals + x_shape = x_aval.shape + expected_lengths = x_shape[-len(fft_lengths) : -1] + ( + (x_shape[-1] - 1) * 2, + ) + if fft_lengths != expected_lengths: + tf_func = partial(tf_func, fft_length=_eval_shape(fft_lengths)) + res = tf_func(x) return _ensure_tf_shape_if_dynamic(res, _aval_to_tf_shape(_out_aval)) - - tf_impl_with_avals[lax.fft_p] = _fft diff --git a/jax/experimental/jax2tf/tests/primitive_harness.py b/jax/experimental/jax2tf/tests/primitive_harness.py index 2200dc95b414..4134aeb8d37c 100644 --- a/jax/experimental/jax2tf/tests/primitive_harness.py +++ b/jax/experimental/jax2tf/tests/primitive_harness.py @@ -1764,10 +1764,11 @@ def _fft_rng_factory(dtype): for dtype in (jtu.dtypes.floating if fft_type == xla_client.FftType.RFFT else jtu.dtypes.complex): shape = (14, 15, 16, 17) - for fft_lengths in [ - (shape[-1],) if fft_type != xla_client.FftType.IRFFT else - ((shape[-1] - 1) * 2,) - ]: + if fft_type != xla_client.FftType.IRFFT: + fft_lengths_list = [ (shape[-1],) ] + else: + fft_lengths_list = [ ((shape[-1] - 1) * 2,), (shape[-1] * 2 - 1,) ] + for fft_lengths in fft_lengths_list: _make_fft_harness( "dtypes", shape=shape, diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py index 9e8401be2371..5f30a85caae6 100644 --- a/jax/experimental/jax2tf/tests/shape_poly_test.py +++ b/jax/experimental/jax2tf/tests/shape_poly_test.py @@ -2769,7 +2769,7 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase): ) def test_harness(self, harness: PolyHarness): # Exclude some harnesses that are known to fail for native serialization - # FOR GRAPH SERIALIZATION + # FOR NATIVE SERIALIZATION if config.jax2tf_default_native_serialization: if not harness.enable_xla: raise unittest.SkipTest("disabled for native_serialization and enable_xla=False")