diff --git a/jax/experimental/jax2tf/g3doc/no_xla_limitations.md b/jax/experimental/jax2tf/g3doc/no_xla_limitations.md
index 4351eb74a2a0..c47a34bdc1d8 100644
--- a/jax/experimental/jax2tf/g3doc/no_xla_limitations.md
+++ b/jax/experimental/jax2tf/g3doc/no_xla_limitations.md
@@ -150,20 +150,29 @@ function `lax.reduce_window_p` with the following conditions:
 
 We provide partial support for all these ops, with the following limitations:
 
-* `computation` should be one of `lax.min`, `lax.max`, or `lax.add`.
-* For `lax.min` and `lax.max`, dtypes `np.bool`, `np.uint32`, `np.uint64`,
-  `np.complex64`, and `np.complex128` are not supported.
-* Additionally, for `lax.min`, dtypes `np.uint8` and `np.uint16` are not
-  supported.
-* For `lax.add`, only dtypes `np.float16`, `np.float32`, and `np.float64` are
-  supported.
-* We support at most 2 spatial dimension.
-* Base dilations other than `(1,) * len(operand)` are not supported.
-* `padding` should either be `VALID` or `SAME`.
-* Using `lax.add` on TPU may give very large deviations. This is due to the way
-  the conversion is implemented (first take the average over the window and then
-  multiply by window size). This gives large deviations on TPU due to the fact
-  that it uses `bfloat16` for computations.
+* `computation` should be one of `lax.min`, `lax.max`, or `lax.add`.
+* For `lax.min` and `lax.max`, dtypes `np.bool`, `np.uint32`, `np.uint64`,
+  `np.complex64`, and `np.complex128` are not supported.
+* Additionally, for `lax.min`, dtypes `np.uint8` and `np.uint16` are not
+  supported.
+* For `lax.add`, only dtypes `np.float16`, `np.float32`, and `np.float64` are
+  supported.
+* We support at most 2 spatial dimensions.
+* Base dilations other than `(1,) * len(operand)` are not supported.
+* `padding` should either be `VALID` or `SAME`.
+* We compute `lax.reduce_window_sum_p` by calling `tf.nn.avg_pool` (through
+  `tf.nn.pool`) and then multiplying the result by
+  `np.prod(window_dimensions)`. If you are using an NN library that implements
+  `avg_pool` using `lax.reduce_window` (such as Flax's
+  [pooling.py](https://github.com/google/flax/blob/main/flax/linen/pooling.py)),
+  that library typically divides the result by `np.prod(window_dimensions)`.
+  When converting such a function, the resulting computation for `avg_pool`
+  is `(tf.nn.avg_pool(xs) * np.prod(window)) / np.prod(window)`, which is
+  redundant and can be optimized.
+* Using `lax.add` on TPU may give very large deviations. This is due to the
+  way the conversion is implemented (we first take the average over the
+  window and then multiply by the window size), which is inaccurate on TPU
+  because it uses `bfloat16` for computations.
 
 We implement all reductions using the Tensorflow function
 [tf.nn.pool](https://www.tensorflow.org/api_docs/python/tf/nn/pool).
diff --git a/jax/experimental/jax2tf/impl_no_xla.py b/jax/experimental/jax2tf/impl_no_xla.py
index a56f1470a39c..1aef21d272ab 100644
--- a/jax/experimental/jax2tf/impl_no_xla.py
+++ b/jax/experimental/jax2tf/impl_no_xla.py
@@ -120,10 +120,11 @@ def pads_to_padtype(in_shape, window_shape, window_strides, padding) -> str:
 
 def _pad_spatial_dims(x, x_shape, padding):
   """Pads `x` using `padding`, which specifies padding for the spatial dimensions."""
-  # Add empty padding for batch and feature dimensions.
-  no_pad = ((0, 0),)
   padding = tuple(padding)
-  padding = no_pad + padding + no_pad
+  if len(padding) == len(x_shape) - 2:
+    # If necessary, add empty padding for batch and feature dimensions.
+    no_pad = ((0, 0),)
+    padding = no_pad + padding + no_pad
   x = tf.pad(x, padding)
   assert len(x.shape) == len(padding)
   x_shape = tuple(p0 + xs + p1 for xs, (p0, p1) in zip(x_shape, padding))
@@ -517,11 +518,9 @@ def _argminmax(is_min: bool, operand: TfVal, axes: Sequence[int],
 tf_impl_no_xla[lax.argmax_p] = partial(_argminmax, False)
 
 
-def _reduce_monoid(operand, window_dimensions, window_strides, padding,
-                   base_dilation, window_dilation, computation_name,
-                   _in_avals: Sequence[core.ShapedArray],
-                   _out_aval: core.ShapedArray):
-  dtype = operand.dtype
+def _validate_reduce_window_inputs(operand_shape, computation_name, dtype,
+                                   window_dimensions, window_strides,
+                                   base_dilation, window_dilation):
   if computation_name not in ["min", "max", "add"]:
     raise _reduce_error("Reduction function should be either min, max, or add.")
   if computation_name in ["min", "max"] and dtype in [
@@ -540,35 +539,117 @@ def _reduce_monoid(operand, window_dimensions, window_strides, padding,
     raise _reduce_error("Add pooling does not support operands of type "
                         f"{dtype}")
 
-  # In presence of shape polymorphism, operand.shape may contain None. The
-  # actual dimension polynomial shapes are in _in_avals.
-  operand_shape = _in_avals[0].shape
+  if (len(operand_shape) != len(window_dimensions) != len(window_strides) !=
+      len(window_dilation)):
+    raise _reduce_error("Input shapes, window dimensions, window stride "
+                        "dimensions, and window dilation dimensions should "
+                        "match.")
+
+  has_only_spatial_dims = True
+  if len(operand_shape) > 4:
+    raise _reduce_error("Only 1D or 2D inputs are supported.")
+  if len(operand_shape) > 2:
+    # operand_shape = (batch, spatial_dims, ..., channel).
+    has_only_spatial_dims = False
+
+    for name, value in [("window_dimensions", window_dimensions),
+                        ("window_strides", window_strides),
+                        ("window_dilation", window_dilation)]:
+      if value[0] != value[-1] != 1:
+        raise _reduce_error("Only 1D or 2D inputs are supported, expected "
+                            f"{name}=(1, spatial_dims, ..., 1), but got "
+                            f"{value}")
 
   if list(base_dilation) != [1] * len(operand_shape):
     # TODO(marcvanzee): Add support for base dilations. We can do this using
     # a scatter on operand.
     raise _reduce_error("Unimplemented support for base dilation.")
 
+  return has_only_spatial_dims
+
+
+def _padding_reduce_window(operand, operand_shape, computation_name,
+                           window_dimensions, window_strides, padding):
   padding_type = pads_to_padtype(operand_shape, window_dimensions,
                                  window_strides, padding)
-  if padding_type == "EXPLICIT":
-    # TODO(marcvanzee): Add support for explicit padding. This can be done
-    # similarly like we did for convolutions.
-    raise _reduce_error("Only 'VALID' and 'SAME' padding are currently "
-                        "supported.")
-
-  def tf_pool(op, pooling_type):
-    # Add batch and channel dimensions, these are expected by TF.
-    op = tf.reshape(op, (1,) + operand_shape + (1,))
-    op = tf.nn.pool(
-        input=op,
+
+  # https://github.com/google/jax/issues/11874.
+  needs_manual_padding = (
+      padding_type == "SAME" and computation_name == "add" and
+      window_dimensions != [1] * len(operand_shape))
+
+  if needs_manual_padding or padding_type == "EXPLICIT":
+    operand, operand_shape = _pad_spatial_dims(operand, operand_shape, padding)
+    padding_type = "VALID"
+
+  return operand, operand_shape, padding_type
+
+
+def _reshape_reduce_window(operand, operand_shape, window_dimensions,
+                           window_strides, window_dilation, *,
+                           has_only_spatial_dims):
+  # Reshape inputs so they are accepted by tf.nn.pool, which expects batch and
+  # channel dimensions for operand but not for any of the other inputs.
+  if has_only_spatial_dims:  # len(operand_shape) <= 2
+    # Call eval_shape on a shape that may contain polynomials, otherwise TF does
+    # not know what to do with polynomials in the shape.
+    operand_shape = jax2tf._eval_shape(operand_shape)
+    # Add batch and channel dimensions to operand.
+    operand = tf.reshape(operand, (1,) + operand_shape + (1,))
+  else:
+    # This branch assumes operand.shape = (batch, spatial_dims, ..., channel),
+    # and dimensions, strides, dilation are all (1, spatial_values, ..., 1).
+    # Input validation for this is done in _validate_reduce_window_inputs.
+    window_dimensions = window_dimensions[1:-1]
+    window_strides = window_strides[1:-1]
+    window_dilation = window_dilation[1:-1]
+
+  return operand, window_dimensions, window_strides, window_dilation
+
+
+def _reduce_monoid(operand, window_dimensions, window_strides, padding,
+                   base_dilation, window_dilation, computation_name,
+                   _in_avals: Sequence[core.ShapedArray],
+                   _out_aval: core.ShapedArray):
+  dtype = operand.dtype
+  # In presence of shape polymorphism, operand.shape may contain None. The
+  # actual dimension polynomial shapes are in _in_avals.
+  operand_shape = _in_avals[0].shape
+
+  # TODO(marcvanzee): Put reduce_window arguments into dataclass, similar to
+  # Gather, to simplify function calls.
+  has_only_spatial_dims = _validate_reduce_window_inputs(
+      operand_shape, computation_name, dtype, window_dimensions, window_strides,
+      base_dilation, window_dilation)
+
+  operand, operand_shape, padding_type = _padding_reduce_window(
+      operand, operand_shape, computation_name, window_dimensions,
+      window_strides, padding)
+
+  operand, window_dimensions, window_strides, dilations = _reshape_reduce_window(
+      operand,
+      operand_shape,
+      window_dimensions,
+      window_strides,
+      window_dilation,
+      has_only_spatial_dims=has_only_spatial_dims)
+
+  def tf_pool(inputs, pooling_type):
+    result = tf.nn.pool(
+        inputs,
         window_shape=window_dimensions,
         pooling_type=pooling_type,
         padding=padding_type,
         strides=window_strides,
-        dilations=window_dilation)
-    op = tf.reshape(op, jax2tf._aval_to_tf_shape(_out_aval))
-    return op
+        dilations=dilations)
+
+    if has_only_spatial_dims:
+      # If the input only had spatial dimensions, we need to contract the batch
+      # and channel dimensions before returning the output.
+      result = tf.squeeze(result, [0, -1])
+
+    jax2tf._assert_matching_abstract_shape(result, _out_aval.shape)
+    return result
 
   negate = lambda x: tf.multiply(x, tf.constant(-1, dtype))
   if computation_name == "max":
@@ -577,8 +658,8 @@ def tf_pool(op, pooling_type):
     return negate(tf_pool(negate(operand), "MAX"))
   elif computation_name == "add":
     # TODO(marcvanzee): This may give very large deviations on TPU when using
-    # floats as inputs. We should think of a different implementation if users
-    # run into this often.
+    # floats as inputs. Alternatively, we could implement this using a
+    # convolution with an all-1's kernel.
     return tf.multiply(tf_pool(operand, "AVG"), np.prod(window_dimensions))
 
 
diff --git a/jax/experimental/jax2tf/tests/jax2tf_limitations.py b/jax/experimental/jax2tf/tests/jax2tf_limitations.py
index fae8e52ffa52..39c308b221cc 100644
--- a/jax/experimental/jax2tf/tests/jax2tf_limitations.py
+++ b/jax/experimental/jax2tf/tests/jax2tf_limitations.py
@@ -124,25 +124,20 @@ def limitations_for_harness(
 
   # We keep here the explicit set of groups for which we don't have limitations
   harness_groups_no_limitations = {
-      "abs", "add", "add_any", "and", "atan2",
-      "bitcast_convert_type", "broadcast", "broadcast_in_dim", "cbrt", "ceil",
-      "clamp", "concatenate", "cos", "cosh", "complex", "conj",
-      "convert_element_type",
-      "cummax", "cummin", "device_put", "dynamic_slice",
-      "dynamic_update_slice", "exp", "eq", "floor", "gather", "ge", "gt",
-      "imag",
-      "iota", "is_finite", "le", "lt", "log", "mul", "ne", "neg", "not",
-      "or", "pad", "population_count",
+      "abs", "add", "add_any", "and", "atan2", "bitcast_convert_type",
+      "broadcast", "broadcast_in_dim", "cbrt", "ceil", "clamp", "concatenate",
+      "cos", "cosh", "complex", "conj", "convert_element_type", "cummax",
+      "cummin", "device_put", "dynamic_slice", "dynamic_update_slice", "exp",
+      "eq", "floor", "gather", "ge", "gt", "imag", "iota", "is_finite", "le",
+      "lt", "log", "mul", "ne", "neg", "not", "or", "pad", "population_count",
       "random_categorical", "random_split", "random_uniform", "random_randint",
-      "reduce",
-      "reduce_and", "reduce_prod", "reduce_or", "reduce_sum",
-      "reduce_window_mul", "reduce_window_min",
-      "reduce_window_max",
-      "real", "reshape", "rev", "rsqrt", "scatter_max", "scatter_min",
-      "select_n", "select_and_scatter_add",
-      "shift_left", "shift_right_logical", "shift_right_arithmetic", "sign",
-      "sin", "sinh", "slice", "sqrt", "squeeze", "stop_gradient", "sub",
-      "tie_in", "transpose", "xor", "zeros_like"
+      "reduce", "reduce_and", "reduce_prod", "reduce_or", "reduce_sum",
+      "reduce_window_mul", "reduce_window_min", "reduce_window_max", "real",
+      "reshape", "rev", "rsqrt", "scatter_max", "scatter_min", "select_n",
+      "select_and_scatter_add", "shift_left", "shift_right_logical",
+      "shift_right_arithmetic", "sign", "sin", "sinh", "slice", "sqrt",
+      "squeeze", "stop_gradient", "sub", "tie_in", "transpose", "xor",
+      "zeros_like"
   }
 
   @classmethod
@@ -910,6 +905,15 @@ def reduce_min(cls, harness: primitive_harness.Harness):
   @classmethod
   def reduce_window_add(cls, harness: primitive_harness.Harness):
     return [
+        Jax2TfLimitation(
+            "Small deviations on GPU for large inputs and enable_xla=False",
+            dtypes=[np.float32],
+            devices="gpu",
+            modes=("eager", "graph", "compiled"),
+            expect_tf_error=False,
+            skip_comparison=False,
+            enabled=not harness.params["enable_xla"],
+            tol=3e-5),
         Jax2TfLimitation(
             "Large deviations on TPU for enable_xla=False",
             dtypes=[np.float16, np.float32],
diff --git a/jax/experimental/jax2tf/tests/primitive_harness.py b/jax/experimental/jax2tf/tests/primitive_harness.py
index 14ad40f3717c..b152eed4e0ce 100644
--- a/jax/experimental/jax2tf/tests/primitive_harness.py
+++ b/jax/experimental/jax2tf/tests/primitive_harness.py
@@ -828,7 +828,7 @@ def _make_argminmax_harness(prim,
   for index_dtype in jtu.dtypes.all_integer + jtu.dtypes.all_unsigned:
     _make_argminmax_harness(prim, "index_dtype", index_dtype=index_dtype)
 
-# Some special cases, with equal elements and NaN
+# Some special cases, with equal elements and NaN
 for name, operand in [
     ("nan_0", np.array([np.nan, np.nan, 2., -2., -np.nan, -np.nan], np.float32)),
     ("nan_1", np.array([np.nan, -np.nan, 2., -2.], np.float32)),
@@ -2488,19 +2488,22 @@ def requires_xla_for_reduce(name, dtype):
     requires_xla=True)
 # Validate window_dilation
 _make_reduce_window_harness("window_dilation", window_dilation=(1, 2))
-# Validate squeezing behavior and dimensions in tf.nn.max_pool
-for shape, window_dimensions in [
-    ((2,), (2,)),  # 1 spatial dimension, left and right squeeze
-    ((2, 1), (2, 1)),  # 1 spatial dimension, left squeeze
-    ((1, 2), (1, 2)),  # 1 spatial dimension, right squeeze
-    ((1, 2, 1), (1, 2, 1)),  # 1 spatial dimension no squeeze
-    ((2, 4), (2, 2)),  # 2 spatial dimensions, left and right squeeze
-    ((2, 4, 3), (2, 2, 2)),  # 3 spatial dimensions, left and right squeeze
-    ((1, 4, 3, 2, 1), (1, 2, 2, 2, 1))  # 3 spatial dimensions, no squeeze
+# Validate the behavior of batch and channel dimensions: lax.reduce_window
+# accepts inputs that either have or do not have batch and channel dimensions.
+# N=batch, DHW=spatial, C=channel.
+# Without XLA, only 1D/2D reductions are supported.
+for shape, window_dimensions, requires_xla in [
+    ((2,), (2,), False),  # W
+    ((2, 1), (2, 1), False),  # WC
+    ((1, 2), (1, 2), False),  # NW
+    ((1, 2, 1), (1, 2, 1), False),  # NWC
+    ((2, 4), (2, 2), False),  # HW
+    ((1, 2, 4, 1), (1, 2, 2, 1), False),  # NHWC
+    ((2, 4, 3), (2, 2, 2), True),  # DHW
+    ((1, 4, 3, 2, 1), (1, 2, 2, 2, 1), True)  # NDHWC
 ]:
-  requires_xla = len(shape) > 2  # Without XLA only supports 1D/2D reductions.
   _make_reduce_window_harness(
-      "squeeze_dim",
+      "batch_channel_dims",
       computation=lax.max,
       shape=shape,
       dtype=np.float32,
@@ -2512,17 +2515,31 @@ def requires_xla_for_reduce(name, dtype):
       window_dimensions=window_dimensions,
       requires_xla=requires_xla)
 
-# This corresponds to SAME padding.
-_make_reduce_window_harness(
-    "same_padding",
-    shape=(112, 112),
-    init_value=-np.inf,
-    computation=lax.max,
-    window_dimensions=(3, 3),
-    window_strides=(2, 2),
-    padding="SAME")
+for computation, id_value in [(lax.max, _get_max_identity(np.float32)),
+                              (lax.min, _get_min_identity(np.float32)),
+                              (lax.add, 0.)]:
+  _make_reduce_window_harness(
+      "same_padding",
+      shape=(112, 112),
+      init_value=id_value,
+      computation=computation,
+      window_dimensions=(3, 3),
+      window_strides=(2, 2),
+      padding="SAME")
+
+# A few additional test cases for manual padding, which is applied when calling
+# reduce_window with lax.add, SAME padding, and window_dimensions != (1, 1, ...).
+for window_dimensions, window_strides in [((2, 2), (1, 1)), ((3, 3), (2, 2)),
+                                          ((13, 13), (5, 6))]:
+  _make_reduce_window_harness(
+      "manual_padding",
+      shape=(12, 12),
+      init_value=0.,
+      computation=lax.add,
+      window_dimensions=window_dimensions,
+      window_strides=window_strides,
+      padding="SAME")
 
-# b/240647139
 _make_reduce_window_harness(
     "init_value_1d",
     shape=(1, 16000),
diff --git a/jax/experimental/jax2tf/tests/primitives_test.py b/jax/experimental/jax2tf/tests/primitives_test.py
index 9324baf96b3a..9f82d541eb55 100644
--- a/jax/experimental/jax2tf/tests/primitives_test.py
+++ b/jax/experimental/jax2tf/tests/primitives_test.py
@@ -102,7 +102,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
   @primitive_harness.parameterized(
       primitive_harness.all_harnesses,
       include_jax_unimpl=False,
-      #one_containing="reduce_window_max",
+      #one_containing="reduce_window_add_same_padding",
   )
   @jtu.ignore_warning(
       category=UserWarning, message="Using reduced precision for gradient.*")
diff --git a/jax/experimental/jax2tf/tests/shape_poly_test.py b/jax/experimental/jax2tf/tests/shape_poly_test.py
index 13a6dd4eec94..657ce2fa1c87 100644
--- a/jax/experimental/jax2tf/tests/shape_poly_test.py
+++ b/jax/experimental/jax2tf/tests/shape_poly_test.py
@@ -1163,7 +1163,8 @@ def _make_harness(group_name: str, name: str,
   match the expected exception string.
 
   enable_and_disable_xla=True means that we generate two harnesses,
-  one with enable_xla=False.
+  one with enable_xla=False and one with enable_xla=True. Otherwise we create
+  only one harness with enable_xla=True.
   """
   if enable_and_disable_xla:
     return [
@@ -1189,7 +1190,8 @@ def _make_harness(group_name: str, name: str,
                  dtype=np.float32,
                  poly_axes=poly_axes, check_result=check_result,
                  skip_jax_run=skip_jax_run, expect_error=expect_error,
-                 tol=tol)
+                 tol=tol,
+                 **params)
 
 
 _f32 = np.float32
@@ -1637,13 +1639,26 @@ def _make_harness(group_name: str, name: str,
                   lambda x: lax.reduce_window(x, np.array(1., _f32), lax.min,
                                               (2, 2), (1, 1), "VALID"),
                   [RandArg((3, 8), _f32)],
-                  poly_axes=[0]),
+                  poly_axes=[0],
+                  enable_and_disable_xla=True),
     _make_harness("reduce_window", "add",
                   # x.shape = (b, 8)
                   lambda x: lax.reduce_window(x, 0, lax.add, (2, 2), (1, 1),
                                               "VALID"),
                   [RandArg((3, 8), _f32)],
-                  poly_axes=[0]),
+                  poly_axes=[0],
+                  enable_and_disable_xla=True),
+    # https://github.com/google/jax/issues/11804
+    # Use the reshape trick to simulate a polymorphic dimension of 16*b.
+    # (See test "conv_general_dilated.1d_1" above for more details.)
+    _make_harness("reduce_window", "add",
+                  # x.shape = (1, 16*b, 1)
+                  lambda x: lax.reduce_window(
+                      jnp.reshape(x, (1, -1, 1)),
+                      0., lax.add, (1, 4, 1), (1, 2, 1), "SAME"),
+                  [RandArg((1, 128, 16), _f32)],
+                  poly_axes=[1],
+                  enable_and_disable_xla=True),
     # TODO(necula): not yet supported, but also unlikely to come up.
     # _make_harness("random_uniform", "odd",
     #               lambda key, a: jax.random.uniform(key, (2 * a.shape[0] + 1, a.shape[1]),
@@ -1897,7 +1912,7 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
   # to parameterized below.
   @primitive_harness.parameterized(
       _flatten_harnesses(_POLY_SHAPE_TEST_HARNESSES),
-      #one_containing="reduce_window_add",
+      #one_containing="reduce_window_add_noxla_poly_axes=[1]",
   )
   def test_prim(self, harness: Harness):
     _test_one_harness(self, harness)
diff --git a/jax/experimental/jax2tf/tests/tf_test_util.py b/jax/experimental/jax2tf/tests/tf_test_util.py
index b9bd78629cfd..990d416583fc 100644
--- a/jax/experimental/jax2tf/tests/tf_test_util.py
+++ b/jax/experimental/jax2tf/tests/tf_test_util.py
@@ -392,7 +392,8 @@ def CheckShapePolymorphism(self, f_jax: Callable, *,
     """
     f_tf = jax2tf.convert(f_jax, polymorphic_shapes=polymorphic_shapes,
                           enable_xla=enable_xla)
-    f_tf_func = tf.function(f_tf, autograph=False, input_signature=input_signature)
+    f_tf_func = tf.function(
+        f_tf, autograph=False, input_signature=input_signature)
     concrete_f_tf = f_tf_func.get_concrete_function(*input_signature)
     if expected_output_signature:
       # Strangely, output_shapes can be a single shape for a function with a
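The core of the conversion in this patch is the identity "sum pooling equals average pooling times the window size". The sketch below is illustrative only (hypothetical shapes and values, not part of the patch); it mirrors what the new `_reduce_monoid` does for `lax.add` without XLA: reshape the operand to NHWC, call `tf.nn.pool` with pooling_type "AVG", squeeze the batch and channel dimensions, and multiply by `np.prod(window_dimensions)`.

import numpy as np
import tensorflow as tf
from jax import lax

x = np.arange(24, dtype=np.float32).reshape(4, 6)

# JAX reference: a 2x2 sum-pooling with VALID padding.
jax_out = lax.reduce_window(x, 0., lax.add, (2, 2), (1, 1), "VALID")

# TF emulation without XLA: add batch/channel dimensions (NHWC), average-pool,
# then multiply by the window size (np.prod((2, 2)) == 4) to recover the sum.
tf_in = tf.reshape(tf.constant(x), (1, 4, 6, 1))
tf_avg = tf.nn.pool(tf_in, window_shape=(2, 2), pooling_type="AVG",
                    strides=(1, 1), padding="VALID")
tf_out = tf.squeeze(tf_avg, [0, 3]) * 4.0

np.testing.assert_allclose(np.asarray(jax_out), tf_out.numpy(), rtol=1e-6)

On TPU the average pool runs in `bfloat16`, which is why the limitations doc above warns about large deviations for `lax.add`.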
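The manual-padding path added in `_padding_reduce_window` (and exercised by the new `manual_padding` and shape-polymorphism harnesses) is needed because TF's average pooling excludes padded cells from its divisor under SAME padding, so "average times window size" over-counts at the borders; see https://github.com/google/jax/issues/11874. The sketch below is illustrative only, with hypothetical 1D data (window 3, stride 1): zero-padding the spatial dimension explicitly and switching to VALID keeps every window at full size, matching `lax.reduce_window` with `lax.add`.

import numpy as np
import tensorflow as tf
from jax import lax

x = np.ones((1, 5, 1), dtype=np.float32)  # NWC layout.

# Reference sums under SAME padding: [2, 3, 3, 3, 2] along the spatial axis.
reference = lax.reduce_window(x, 0., lax.add, (1, 3, 1), (1, 1, 1), "SAME")

# Naive emulation: AVG over the shrunken border windows is still 1.0, so
# multiplying by the full window size gives 3.0 at the borders (wrong).
naive = tf.nn.pool(x, window_shape=(3,), pooling_type="AVG",
                   strides=(1,), padding="SAME") * 3.0
assert not np.allclose(naive.numpy(), np.asarray(reference))

# Manual zero-padding of the spatial dimension, then VALID pooling.
padded = tf.pad(x, [(0, 0), (1, 1), (0, 0)])
manual = tf.nn.pool(padded, window_shape=(3,), pooling_type="AVG",
                    strides=(1,), padding="VALID") * 3.0

np.testing.assert_allclose(manual.numpy(), np.asarray(reference), rtol=1e-6)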