[jax2tf] lax.reduce_window (enable_xla=False): bug fix and improvements.
* Fixes #11804: we only supported `lax.reduce_window` on inputs without batch and channel dimensions, which is wrong. Inputs with batch and channel dimensions are supported, and are in fact what most users pass (this case is simply not spelled out in the [operational semantics for XLA::ReduceWindow](https://www.tensorflow.org/xla/operation_semantics#reducewindow)). I have fixed this and clarified a number of test cases with batch and channel dimensions; see the sketch after this list.

* Also, @sdenton4 gave a failing example in a Colab using polymorphic dimensions. I've added this as a test case to make sure it works now.

* Adds support for explicit padding using the existing padding logic from convolutions.

* Fixes #11874: we were not handling SAME padding for `lax.add` correctly, because we used `tf.nn.avg_pool`, which averages only over the non-padded elements of each window (see the issue for more details). I resolved this by padding the operand manually and added some additional tests for this.

* Ensures we call eval_shape on a shape containing polynomials before calling a TF op.

* Fixes #11929 (comment): we weren't running any of the shape_poly_test.py tests for `enable_xla=False`.
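
For illustration, a minimal sketch of the newly supported case, combining batch/channel dimensions with a polymorphic batch dimension (the shapes, the `polymorphic_shapes` spec, and the `input_signature` below are illustrative assumptions, not copied from the tests in this commit):

```python
import numpy as np
import tensorflow as tf
import jax.numpy as jnp
from jax import lax
from jax.experimental import jax2tf

def max_pool_nhwc(x):
  # The operand has explicit batch and channel dimensions (NHWC).
  return lax.reduce_window(
      x, -jnp.inf, lax.max,
      window_dimensions=(1, 2, 2, 1),
      window_strides=(1, 2, 2, 1),
      padding="VALID")

tf_fn = tf.function(
    jax2tf.convert(max_pool_nhwc, enable_xla=False,
                   polymorphic_shapes=["(b, _, _, _)"]),
    autograph=False,
    input_signature=[tf.TensorSpec((None, 4, 4, 1), tf.float32)])

x = np.arange(16, dtype=np.float32).reshape(1, 4, 4, 1)
np.testing.assert_allclose(tf_fn(x), max_pool_nhwc(x))
```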

PiperOrigin-RevId: 467879449
marcvanzee authored and jax authors committed Aug 16, 2022
1 parent da168a1 commit df5f3c5
Showing 7 changed files with 215 additions and 88 deletions.
37 changes: 23 additions & 14 deletions jax/experimental/jax2tf/g3doc/no_xla_limitations.md
@@ -150,20 +150,29 @@ function `lax.reduce_window_p` with the following conditions:

We provide partial support for all these ops, with the following limitations:

* `computation` should be one of `lax.min`, `lax.max`, or `lax.add`.
* For `lax.min` and `lax.max`, dtypes `np.bool`, `np.uint32`, `np.uint64`,
`np.complex64`, and `np.complex128` are not supported.
* Additionally, for `lax.min`, dtypes `np.uint8` and `np.uint16` are not
supported.
* For `lax.add`, only dtypes `np.float16`, `np.float32`, and `np.float64` are
supported.
* We support at most 2 spatial dimensions.
* Base dilations other than `(1,) * len(operand)` are not supported.
* `padding` should either be `VALID` or `SAME`.
* Using `lax.add` on TPU may give very large deviations. This is due to the way
the conversion is implemented (first take the average over the window and then
multiply by window size). This gives large deviations on TPU due to the fact
that it uses `bfloat16` for computations.
* `computation` should be one of `lax.min`, `lax.max`, or `lax.add`.
* For `lax.min` and `lax.max`, dtypes `np.bool`, `np.uint32`, `np.uint64`,
`np.complex64`, and `np.complex128` are not supported.
* Additionally, for `lax.min`, dtypes `np.uint8` and `np.uint16` are not
supported.
* For `lax.add`, only dtypes `np.float16`, `np.float32`, and `np.float64` are
supported.
* We support at most 2 spatial dimensions.
* Base dilations other than `(1,) * len(operand)` are not supported.
* `padding` should either be `VALID` or `SAME`.
* We compute `lax.reduce_window_sum_p` by calling `tf.nn.avg_pool` (through
  `tf.nn.pool`) and then multiplying the result by
  `np.prod(window_dimensions)`. If you are using an NN library that implements
  `avg_pool` using `lax.reduce_window` (such as Flax's
  [pooling.py](https://github.com/google/flax/blob/main/flax/linen/pooling.py)),
  that library usually divides the `reduce_window` result by
  `np.prod(window_dimensions)`. The converted `avg_pool` therefore computes
  `(tf.nn.avg_pool(xs) * np.prod(window)) / np.prod(window)`, which is
  redundant and could be optimized.
* Using `lax.add` on TPU may give very large deviations. This is due to the
  way the conversion is implemented (we first take the average over the window
  and then multiply by the window size), which produces large deviations on
  TPU because the intermediate computations use `bfloat16`.

We implement all reductions using the Tensorflow function
[tf.nn.pool](https://www.tensorflow.org/api_docs/python/tf/nn/pool).
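
For example, here is a minimal sketch of the `add` (sum) reduction emulated via average pooling (the concrete input and window size are illustrative assumptions, not taken from the tests):

```python
import numpy as np
import tensorflow as tf

x = tf.constant([1., 2., 3., 4.])      # a single spatial dimension
op = tf.reshape(x, (1, 4, 1))          # add batch and channel dimensions
avg = tf.nn.pool(op, window_shape=(2,), pooling_type="AVG",
                 strides=(1,), padding="VALID")
window_sums = tf.squeeze(avg * np.prod((2,)), [0, -1])
# window_sums == [3., 5., 7.], i.e. the sum over each window of size 2.
```

An NN-library `avg_pool` built on `lax.reduce_window` would then divide these sums by the window size again, which is the redundancy mentioned above.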
135 changes: 108 additions & 27 deletions jax/experimental/jax2tf/impl_no_xla.py
@@ -120,10 +120,11 @@ def pads_to_padtype(in_shape, window_shape, window_strides, padding) -> str:

def _pad_spatial_dims(x, x_shape, padding):
"""Pads `x` using `padding`, which specifies padding for the spatial dimensions."""
# Add empty padding for batch and feature dimensions.
no_pad = ((0, 0),)
padding = tuple(padding)
padding = no_pad + padding + no_pad
if len(padding) == len(x_shape) - 2:
# If necessary, add empty padding for batch and feature dimensions.
no_pad = ((0, 0),)
padding = no_pad + padding + no_pad
x = tf.pad(x, padding)
assert len(x.shape) == len(padding)
x_shape = tuple(p0 + xs + p1 for xs, (p0, p1) in zip(x_shape, padding))
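# Illustrative example (a sketch, not part of the original file): with
# x_shape=(1, 8, 8, 1) and padding=((1, 1), (2, 2)), the padding is extended to
# ((0, 0), (1, 1), (2, 2), (0, 0)) and the resulting x_shape is (1, 10, 12, 1).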
@@ -517,11 +518,9 @@ def _argminmax(is_min: bool, operand: TfVal, axes: Sequence[int],
tf_impl_no_xla[lax.argmax_p] = partial(_argminmax, False)


def _reduce_monoid(operand, window_dimensions, window_strides, padding,
base_dilation, window_dilation, computation_name,
_in_avals: Sequence[core.ShapedArray],
_out_aval: core.ShapedArray):
dtype = operand.dtype
def _validate_reduce_window_inputs(operand_shape, computation_name, dtype,
window_dimensions, window_strides,
base_dilation, window_dilation):
if computation_name not in ["min", "max", "add"]:
raise _reduce_error("Reduction function should be either min, max, or add.")
if computation_name in ["min", "max"] and dtype in [
@@ -540,35 +539,117 @@ def _reduce_monoid(operand, window_dimensions, window_strides, padding,
raise _reduce_error("Add pooling does not support operands of type "
f"{dtype}")

# In presence of shape polymorphism, operand.shape may contain None. The
# actual dimension polynomial shapes are in _in_avals.
operand_shape = _in_avals[0].shape
if (len(operand_shape) != len(window_dimensions) != len(window_strides) !=
len(window_dilation)):
raise _reduce_error("Input shapes, window dimensions, window stride "
"dimensions, and window dilation dimensions should "
"match.")

has_only_spatial_dims = True
if len(operand_shape) > 4:
raise _reduce_error("Only 1D or 2D input are supported.")
if len(operand_shape) > 2:
# operand_shape = (batch, spatial_dims, ..., channel).
has_only_spatial_dims = False

for name, value in [("window_dimensions", window_dimensions),
("window_strides", window_strides),
("window_dilation", window_dilation)]:
if value[0] != value[-1] != 1:
raise _reduce_error("Only 1D or 2D input are supported, expected "
f"{name}=(1, spatial_dims, ..., 1), but got "
f"{value}")

if list(base_dilation) != [1] * len(operand_shape):
# TODO(marcvanzee): Add support for base dilations. We can do this using
# a scatter on operand.
raise _reduce_error("Unimplemented support for base dilation.")

return has_only_spatial_dims
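# Illustration (assumed values, not part of the original file): an NHWC
# operand of shape (1, 28, 28, 3) with window_dimensions=(1, 2, 2, 1),
# window_strides=(1, 2, 2, 1), base_dilation=(1, 1, 1, 1) and
# window_dilation=(1, 1, 1, 1) passes validation and returns
# has_only_spatial_dims=False; a purely spatial operand of shape (28, 28)
# with the corresponding 2D arguments returns True.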


def _padding_reduce_window(operand, operand_shape, computation_name,
window_dimensions, window_strides, padding):
padding_type = pads_to_padtype(operand_shape, window_dimensions,
window_strides, padding)
if padding_type == "EXPLICIT":
# TODO(marcvanzee): Add support for explicit padding. This can be done
# similarly like we did for convolutions.
raise _reduce_error("Only 'VALID' and 'SAME' padding are currently "
"supported.")

def tf_pool(op, pooling_type):
# Add batch and channel dimensions, these are expected by TF.
op = tf.reshape(op, (1,) + operand_shape + (1,))
op = tf.nn.pool(
input=op,

# https://github.com/google/jax/issues/11874.
needs_manual_padding = (
padding_type == "SAME" and computation_name == "add" and
window_dimensions != [1] * len(operand_shape))

if needs_manual_padding or padding_type == "EXPLICIT":
operand, operand_shape = _pad_spatial_dims(operand, operand_shape, padding)
padding_type = "VALID"

return operand, operand_shape, padding_type
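# Why manual padding is needed for "add" with SAME padding (an illustrative
# sketch, not part of the original file): for a 1D operand [1., 2., 3.] with
# window_dimensions=(2,), strides (1,) and SAME padding, lax.reduce_window
# with lax.add yields [3., 5., 3.] (the padded cell contributes 0), whereas
# tf.nn.avg_pool averages only over the valid elements, giving [1.5, 2.5, 3.];
# multiplying by the window size 2 gives [3., 5., 6.], which is wrong in the
# last position. Zero-padding the operand manually and switching to VALID
# padding avoids this.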


def _reshape_reduce_window(operand, operand_shape, window_dimensions,
window_strides, window_dilation, *,
has_only_spatial_dims):
# Reshape inputs so they are accepted by tf.nn.pool, which expects batch and
# channel dimensions for operand but not for any of the other inputs.
if has_only_spatial_dims: # len(operand_shape) <= 2
# Call eval_shape on a shape that may contain polynomials, otherwise TF does
# not know what to do with polynomials in the shape.
operand_shape = jax2tf._eval_shape(operand_shape)
# Add batch and channel dimensions to operand.
operand = tf.reshape(operand, (1,) + operand_shape + (1,))
else:
# This branch assumes operand.shape = (batch, spatial_dims, ..., channel),
# and dimensions, strides, dilation are all (1, spatial_values, ..., 1).
# Input validation for this is done in _validate_reduce_window_inputs.
window_dimensions = window_dimensions[1:-1]
window_strides = window_strides[1:-1]
window_dilation = window_dilation[1:-1]

return operand, window_dimensions, window_strides, window_dilation
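# Illustration (assumed shapes, not part of the original file): a purely
# spatial operand of shape (28, 28) is reshaped to (1, 28, 28, 1) before
# calling tf.nn.pool, while an NHWC operand of shape (1, 28, 28, 3) is left
# unchanged and window_dimensions=(1, 2, 2, 1), window_strides and
# window_dilation are trimmed to their spatial parts, e.g. (2, 2).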


def _reduce_monoid(operand, window_dimensions, window_strides, padding,
base_dilation, window_dilation, computation_name,
_in_avals: Sequence[core.ShapedArray],
_out_aval: core.ShapedArray):
dtype = operand.dtype
# In presence of shape polymorphism, operand.shape may contain None. The
# actual dimension polynomial shapes are in _in_avals.
operand_shape = _in_avals[0].shape

# TODO(marcvanzee): Put reduce_window arguments into dataclass, similar to
# Gather, to simplify function calls.
has_only_spatial_dims = _validate_reduce_window_inputs(
operand_shape, computation_name, dtype, window_dimensions, window_strides,
base_dilation, window_dilation)

operand, operand_shape, padding_type = _padding_reduce_window(
operand, operand_shape, computation_name, window_dimensions,
window_strides, padding)

operand, window_dimensions, window_strides, dilations = _reshape_reduce_window(
operand,
operand_shape,
window_dimensions,
window_strides,
window_dilation,
has_only_spatial_dims=has_only_spatial_dims)

def tf_pool(inputs, pooling_type):
result = tf.nn.pool(
inputs,
window_shape=window_dimensions,
pooling_type=pooling_type,
padding=padding_type,
strides=window_strides,
dilations=window_dilation)
op = tf.reshape(op, jax2tf._aval_to_tf_shape(_out_aval))
return op
dilations=dilations)

if has_only_spatial_dims:
# If the input only had spatial dimensions we need to contract the batch
# and channel dimensions before returning the output.
result = tf.squeeze(result, [0, -1])

jax2tf._assert_matching_abstract_shape(result, _out_aval.shape)
return result

negate = lambda x: tf.multiply(x, tf.constant(-1, dtype))
if computation_name == "max":
@@ -577,8 +658,8 @@ def tf_pool(op, pooling_type):
return negate(tf_pool(negate(operand), "MAX"))
elif computation_name == "add":
# TODO(marcvanzee): This may give very large deviations on TPU when using
# floats as inputs. We should think of a different implementation if users
# run into this often.
# floats as inputs. Alternatively, we could implement this using a
# convolution with an all-1's kernel.
return tf.multiply(tf_pool(operand, "AVG"), np.prod(window_dimensions))


40 changes: 22 additions & 18 deletions jax/experimental/jax2tf/tests/jax2tf_limitations.py
@@ -124,25 +124,20 @@ def limitations_for_harness(

# We keep here the explicit set of groups for which we don't have limitations
harness_groups_no_limitations = {
"abs", "add", "add_any", "and", "atan2",
"bitcast_convert_type", "broadcast", "broadcast_in_dim", "cbrt", "ceil",
"clamp", "concatenate", "cos", "cosh", "complex", "conj",
"convert_element_type",
"cummax", "cummin", "device_put", "dynamic_slice",
"dynamic_update_slice", "exp", "eq", "floor", "gather", "ge", "gt",
"imag",
"iota", "is_finite", "le", "lt", "log", "mul", "ne", "neg", "not",
"or", "pad", "population_count",
"abs", "add", "add_any", "and", "atan2", "bitcast_convert_type",
"broadcast", "broadcast_in_dim", "cbrt", "ceil", "clamp", "concatenate",
"cos", "cosh", "complex", "conj", "convert_element_type", "cummax",
"cummin", "device_put", "dynamic_slice", "dynamic_update_slice", "exp",
"eq", "floor", "gather", "ge", "gt", "imag", "iota", "is_finite", "le",
"lt", "log", "mul", "ne", "neg", "not", "or", "pad", "population_count",
"random_categorical", "random_split", "random_uniform", "random_randint",
"reduce",
"reduce_and", "reduce_prod", "reduce_or", "reduce_sum",
"reduce_window_mul", "reduce_window_min",
"reduce_window_max",
"real", "reshape", "rev", "rsqrt", "scatter_max", "scatter_min",
"select_n", "select_and_scatter_add",
"shift_left", "shift_right_logical", "shift_right_arithmetic", "sign",
"sin", "sinh", "slice", "sqrt", "squeeze", "stop_gradient", "sub",
"tie_in", "transpose", "xor", "zeros_like"
"reduce", "reduce_and", "reduce_prod", "reduce_or", "reduce_sum",
"reduce_window_mul", "reduce_window_min", "reduce_window_max", "real",
"reshape", "rev", "rsqrt", "scatter_max", "scatter_min", "select_n",
"select_and_scatter_add", "shift_left", "shift_right_logical",
"shift_right_arithmetic", "sign", "sin", "sinh", "slice", "sqrt",
"squeeze", "stop_gradient", "sub", "tie_in", "transpose", "xor",
"zeros_like"
}

@classmethod
@@ -910,6 +905,15 @@ def reduce_min(cls, harness: primitive_harness.Harness):
@classmethod
def reduce_window_add(cls, harness: primitive_harness.Harness):
return [
Jax2TfLimitation(
"Small deviations on GPU for large inputs and enable_xla=False",
dtypes=[np.float32],
devices="gpu",
modes=("eager", "graph", "compiled"),
expect_tf_error=False,
skip_comparison=False,
enabled=not harness.params["enable_xla"],
tol=3e-5),
Jax2TfLimitation(
"Large deviations on TPU for enable_xla=False",
dtypes=[np.float16, np.float32],
61 changes: 39 additions & 22 deletions jax/experimental/jax2tf/tests/primitive_harness.py
@@ -828,7 +828,7 @@ def _make_argminmax_harness(prim,
for index_dtype in jtu.dtypes.all_integer + jtu.dtypes.all_unsigned:
_make_argminmax_harness(prim, "index_dtype", index_dtype=index_dtype)

# Some special cases, with equal elements and NaN
# Some special cases, with equal elements and NaN
for name, operand in [
("nan_0", np.array([np.nan, np.nan, 2., -2., -np.nan, -np.nan], np.float32)),
("nan_1", np.array([np.nan, -np.nan, 2., -2.], np.float32)),
@@ -2488,19 +2488,22 @@ def requires_xla_for_reduce(name, dtype):
requires_xla=True)
# Validate window_dilation
_make_reduce_window_harness("window_dilation", window_dilation=(1, 2))
# Validate squeezing behavior and dimensions in tf.nn.max_pool
for shape, window_dimensions in [
((2,), (2,)), # 1 spatial dimension, left and right squeeze
((2, 1), (2, 1)), # 1 spatial dimension, left squeeze
((1, 2), (1, 2)), # 1 spatial dimension, right squeeze
((1, 2, 1), (1, 2, 1)), # 1 spatial dimension no squeeze
((2, 4), (2, 2)), # 2 spatial dimensions, left and right squeeze
((2, 4, 3), (2, 2, 2)), # 3 spatial dimensions, left and right squeeze
((1, 4, 3, 2, 1), (1, 2, 2, 2, 1)) # 3 spatial dimensions, no squeeze
# Validate batch and channel dimensions behavior. lax.reduce_window accepts
# inputs that either have or do not have batch and channel dimensions.
# N=batch, DHW=spatial, C=channel.
# Without XLA only supports 1D/2D reductions.
for shape, window_dimensions, requires_xla in [
((2,), (2,), False), # W
((2, 1), (2, 1), False), # WC
((1, 2), (1, 2), False), # NW
((1, 2, 1), (1, 2, 1), False), # NWC
((2, 4), (2, 2), False), # HW
((1, 2, 4, 1), (1, 2, 2, 1), False), # NHWC
((2, 4, 3), (2, 2, 2), True), # DHW
((1, 4, 3, 2, 1), (1, 2, 2, 2, 1), True) # NDHWC
]:
requires_xla = len(shape) > 2 # Without XLA only supports 1D/2D reductions.
_make_reduce_window_harness(
"squeeze_dim",
"batch_channel_dims",
computation=lax.max,
shape=shape,
dtype=np.float32,
@@ -2512,17 +2515,31 @@ def requires_xla_for_reduce(name, dtype):
window_dimensions=window_dimensions,
requires_xla=requires_xla)

# This corresponds to SAME padding.
_make_reduce_window_harness(
"same_padding",
shape=(112, 112),
init_value=-np.inf,
computation=lax.max,
window_dimensions=(3, 3),
window_strides=(2, 2),
padding="SAME")
for computation, id_value in [(lax.max, _get_max_identity(np.float32)),
(lax.min, _get_min_identity(np.float32)),
(lax.add, 0.)]:
_make_reduce_window_harness(
"same_padding",
shape=(112, 112),
init_value=id_value,
computation=computation,
window_dimensions=(3, 3),
window_strides=(2, 2),
padding="SAME")

# A few additional test cases for manual padding, which is applied when calling
# reduce_window with lax.add, SAME padding and window_dimensions != (1, 1, ...).
for window_dimensions, window_strides in [((2, 2), (1, 1)), ((3, 3), (2, 2)),
((13, 13), (5, 6))]:
_make_reduce_window_harness(
"manual_padding",
shape=(12, 12),
init_value=0.,
computation=lax.add,
window_dimensions=window_dimensions,
window_strides=window_strides,
padding="SAME")

# b/240647139
_make_reduce_window_harness(
"init_value_1d",
shape=(1, 16000),
2 changes: 1 addition & 1 deletion jax/experimental/jax2tf/tests/primitives_test.py
@@ -102,7 +102,7 @@ class JaxPrimitiveTest(tf_test_util.JaxToTfTestCase):
@primitive_harness.parameterized(
primitive_harness.all_harnesses,
include_jax_unimpl=False,
#one_containing="reduce_window_max",
#one_containing="reduce_window_add_same_padding",
)
@jtu.ignore_warning(
category=UserWarning, message="Using reduced precision for gradient.*")
