From 4e224bcfb99c3bd9b6a32b8ad7836d12517e788f Mon Sep 17 00:00:00 2001
From: jax authors
Date: Thu, 30 Jun 2022 11:02:07 -0700
Subject: [PATCH] [jax2tf] Add support for common audio convolutions (1D
 variants, dilated depthwise, transpose with SAME padding).

PiperOrigin-RevId: 458266485
---
 jax/experimental/jax2tf/impl_no_xla.py        | 221 ++++++++++++------
 .../jax2tf/tests/primitive_harness.py         |  81 ++++++-
 2 files changed, 228 insertions(+), 74 deletions(-)

diff --git a/jax/experimental/jax2tf/impl_no_xla.py b/jax/experimental/jax2tf/impl_no_xla.py
index b29dbd3fc70e..246d7835d3ff 100644
--- a/jax/experimental/jax2tf/impl_no_xla.py
+++ b/jax/experimental/jax2tf/impl_no_xla.py
@@ -50,6 +50,12 @@ def _xla_disabled_error(primitive_name: str,
   return NotImplementedError(msg)
 
 
+def _conv_error(msg):
+  suffix = ("See source code for the precise conditions under which "
+            "convolutions can be converted without XLA.")
+  return _xla_disabled_error("conv_general_dilated", f"{msg} - {suffix}")
+
+
 def _unimplemented(name):
 
   def op(*arg, **kwargs):
@@ -66,13 +72,20 @@ def _transpose_for_tf_conv(lhs, rhs, dimension_numbers):
   # TODO(marcvanzee): Consider merging transposes if we want to optimize.
   # For `lhs_perm` / `output_perm`, perm (0, 1, 2, 3) corresponds to "NCHW".
   lhs = tf.transpose(lhs, lhs_perm)  # lhs --> "NCHW"
+  if len(lhs_perm) == 3:
+    # For 1D convolution, we add a trivial "W" dimension, so that 2D
+    # convolution logic can be applied downstream.
+    lhs = lhs[:, :, :, np.newaxis]
   # However, the TF ops only support "NHWC" on CPU, so we transpose again.
   lhs = tf.transpose(lhs, (0, 2, 3, 1))  # "NCHW" --> "NHWC"
+
   # For `rhs_perm`, perm (0, 1, 2, 3) corresponds to "OIHW".
   rhs = tf.transpose(rhs, rhs_perm)  # rhs --> "OIHW"
+  # Handle conv1d case.
+  if len(rhs_perm) == 3:
+    rhs = rhs[:, :, :, np.newaxis]
   # For the tf ops, rhs is expected to be "HWIO".
   rhs = tf.transpose(rhs, (2, 3, 1, 0))  # "OIHW" --> "HWIO"
-
   return lhs, rhs
 
 
@@ -84,26 +97,116 @@ def pads_to_padtype(in_shape, window_shape, window_strides, padding) -> str:
   return "EXPLICIT"
 
 
-def _pad_spatial_dims(in_shape, padding):
+def _pad_spatial_dims(in_shape, padding, is_conv1d):
   """Pads `in_shape` using `padding`, which specifies padding for the spatial dimensions."""
   # Add empty padding for batch and feature dimensions.
   no_pad = tf.constant([[0, 0]])
-  padding = tf.concat([no_pad, padding, no_pad], 0)
-  in_shape = tf.pad(in_shape, padding)
+  if is_conv1d:
+    # Add empty padding for the dummy "W" dimension, too; the input is 4D by
+    # the time it is padded.
+    padding = tf.concat([no_pad, padding, no_pad, no_pad], 0)
+  else:
+    padding = tf.concat([no_pad, padding, no_pad], 0)
+  in_shape = tf.pad(in_shape, padding)
   return in_shape
 
 
-def _is_valid_padding(kernel_sdims, strides, padding):
-  """Returns True if `padding` corresponds to "VALID" padding for a transposed convolution."""
-  # This is simply the padding == 'VALID' part of lax._conv_transpose_padding.
-  for (begin, end), k, s in zip(padding, kernel_sdims, strides):
-    pad_len = k + s - 2 + builtins.max(k - s, 0)
+def _conv_transpose_pads_to_padtype(kernel_sdims, lhs_dilation, padding):
+  """Finds the padding type for a transpose convolution."""
+  # This is simply checking agreement with lax._conv_transpose_padding.
+  is_valid = True
+  is_same = True
+  if not len(kernel_sdims) == len(lhs_dilation) == len(padding):
+    raise ValueError(f'Found different lengths for '
+                     f'kernel_sdims ({kernel_sdims}), '
+                     f'lhs_dilation ({lhs_dilation}), '
+                     f'and padding ({padding}).')
+  for k, s, (begin, end) in zip(kernel_sdims, lhs_dilation, padding):
+    # Check for VALID padding.
+    pad_len_valid = k + s - 2 + builtins.max(k - s, 0)
     pad_a = k - 1
-    pad_b = pad_len - pad_a
+    pad_b = pad_len_valid - pad_a
     if begin != pad_a or end != pad_b:
-      return False
+      is_valid = False
 
-  return True
+    # Check for SAME padding.
+    pad_len_same = k + s - 2
+    if s > k - 1:
+      pad_a = k - 1
+    else:
+      pad_a = int(np.ceil(pad_len_same / 2))
+    pad_b = pad_len_same - pad_a
+    if begin != pad_a or end != pad_b:
+      is_same = False
+
+  if is_valid:
+    return 'VALID'
+  elif is_same:
+    return 'SAME'
+  raise ValueError('Transpose convolution padding mode must be '
+                   '`SAME` or `VALID`.')
+
+
+def _validate_spatial_dimensions(nr_spatial_dimensions):
+  """Check spatial dimension support."""
+  # Currently we only support 1D+2D convolutions because it keeps the code
+  # relatively simple and covers most cases.
+  if nr_spatial_dimensions > 2:
+    raise _conv_error(
+        "We only support 1D or 2D convolutions, but found "
+        f"{nr_spatial_dimensions}.")
+
+
+def _normalize_padding_and_dilations(
+    padding, lhs_dilation, rhs_dilation, is_conv1d):
+  if is_conv1d:
+    lhs_dilation = list(lhs_dilation) + [1]
+    rhs_dilation = list(rhs_dilation) + [1]
+    # Empty padding in the dummy dimension.
+    # Note that when kernel_size=stride=1, padding of (0, 0) is both 'VALID'
+    # and 'SAME', so the inferred padding type will still register according
+    # to the first dimension's padding.
+    padding = list(padding) + [(0, 0)]
+  return padding, lhs_dilation, rhs_dilation
+
+
+def _normalize_window_strides(window_strides):
+  """Ensure window_strides has length 4."""
+  # Some TF ops require len(window_strides) == 4 while others do not. We simply
+  # ensure it always has length 4.
+  if len(window_strides) == 1:
+    # This is the Conv1D case. We add a dummy dimension to allow using 2D ops,
+    # and use stride=1 on the dummy dimension.
+    window_strides = list(window_strides) + [1]
+  if len(window_strides) == 2:
+    window_strides = [1] + list(window_strides) + [1]
+  return window_strides
+
+
+def _normalize_output_perm(output_perm, is_conv1d):
+  """Ensure that output_perm has length 4."""
+  if is_conv1d:
+    output_perm = list(output_perm) + [1]
+  return output_perm
+
+
+def _validate_conv_features(
+    is_transpose, is_atrous, is_depthwise, feature_group_count,
+    batch_group_count, preferred_element_type, lhs_dtype):
+  if feature_group_count > 1 and not is_depthwise:
+    raise _conv_error("Grouped convolutions are unsupported")
+  if (is_depthwise and is_atrous) and not is_transpose:
+    # We allow dilated depthwise convolutions.
+    pass
+  elif [is_depthwise, is_atrous, is_transpose].count(True) > 1:
+    raise _conv_error(
+        f"Can only do one of depthwise ({is_depthwise}), atrous ({is_atrous}) "
+        f"and transposed convolutions ({is_transpose})")
+
+  # We can implement batch grouping when there is a need for it.
+  if batch_group_count != 1:
+    raise _conv_error("Unimplemented support for batch_group_count != 1 "
+                      f"(found {batch_group_count})")
+
+  if (preferred_element_type is not None and
+      preferred_element_type != lhs_dtype):
+    raise _conv_error("Unimplemented support for preferred_element_type")
 
 
 def _conv_general_dilated(
@@ -116,64 +219,40 @@
   """Implementation of lax.conv_general_dilated_p using XlaConv."""
   del lhs_shape, rhs_shape, precision  # Unused arguments.
   out_shape = jax2tf._aval_to_tf_shape(_out_aval)
+  _validate_spatial_dimensions(len(lhs.shape) - 2)
+  is_conv1d = len(lhs.shape) - 2 == 1
 
-  def error(msg):
-    suffix = ("See source code for the precise conditions under which "
-              "convolutions can be converted without XLA.")
-    return _xla_disabled_error("conv_general_dilated", f"{msg} - {suffix}")
-
-  nr_spatial_dimensions = len(lhs.shape) - 2
-
-  # Currently we only support 2D convolutions because it keeps the code
-  # relatively simple and covers most cases.
-  if nr_spatial_dimensions != 2:
-    error(
-        f"We only support 2D convolutions, but found {nr_spatial_dimensions}.")
-
-  # We can implement batch grouping when there is a need for it.
-  if batch_group_count != 1:
-    raise error("Unimplemented support for batch_group_count != 1 "
-                f"(found {batch_group_count})")
-
-  if (preferred_element_type is not None and
-      preferred_element_type != lhs.dtype.as_numpy_dtype):
-    raise error("Unimplemented support for preferred_element_type")
+  tf_window_strides = _normalize_window_strides(window_strides)
+  padding, lhs_dilation, rhs_dilation = _normalize_padding_and_dilations(
+      padding, lhs_dilation, rhs_dilation, is_conv1d)
 
   lhs, rhs = _transpose_for_tf_conv(lhs, rhs, dimension_numbers)
-  output_perm = dimension_numbers[2]
 
   in_channels = lhs.shape[-1]
   *rhs_spatial_shapes, _, rhs_out_channel = rhs.shape
 
+  is_transpose = any([d != 1 for d in lhs_dilation])
+  is_atrous = any([d != 1 for d in rhs_dilation])
   is_depthwise = in_channels == feature_group_count and feature_group_count > 1
-  is_transpose = list(lhs_dilation) != [1] * nr_spatial_dimensions
-  is_atrous = list(rhs_dilation) != [1] * nr_spatial_dimensions
-
-  if feature_group_count > 1 and not is_depthwise:
-    raise error("Grouped convolutions are unsupported")
-
-  if is_transpose:
-    # We provide support for transposed convolutions called through
-    # lax.conv2d_tranpose, but only if the provided padding was VALID.
-    if not _is_valid_padding(rhs_spatial_shapes, window_strides, padding):
-      raise error(
-          "Can only convert Transposed Convolutions with 'VALID' padding")
-
-  if [is_depthwise, is_atrous, is_transpose].count(True) > 1:
-    raise error(
-        "Can only do one of depthwise, atrous and tranposed convolutions")
+  _validate_conv_features(is_transpose, is_atrous, is_depthwise,
+                          feature_group_count, batch_group_count,
+                          preferred_element_type, lhs.dtype.as_numpy_dtype)
 
   rhs_dilated_shape = [
       (k - 1) * r + 1 for k, r in zip(rhs_spatial_shapes, rhs_dilation)
   ]
+  output_perm = dimension_numbers[2]
 
-  padding_type = pads_to_padtype(lhs.shape[1:3], rhs_dilated_shape, window_strides, padding)
-
-  # We only manually pad if we aren't using a tranposed convolutions, because
-  # there we don't do any padding.
-  if padding_type == "EXPLICIT" and not is_transpose:
-    lhs = _pad_spatial_dims(lhs, padding)
-    padding_type = "VALID"
+  if is_transpose:
+    padding_type = _conv_transpose_pads_to_padtype(
+        rhs_spatial_shapes, lhs_dilation, padding)
+  else:
+    padding_type = pads_to_padtype(
+        lhs.shape[1:3], rhs_dilated_shape, window_strides, padding)
+    # We only manually pad if we aren't using a transposed convolution.
+    if padding_type == "EXPLICIT":
+      lhs = _pad_spatial_dims(lhs, padding, is_conv1d)
+      padding_type = "VALID"
 
   if any(r > l for l, r in zip(lhs.shape[1:3], rhs_dilated_shape)
         ) and padding_type != "SAME":
@@ -182,13 +261,6 @@ def error(msg):
     # We thus return zeros to make sure the behavior is consistent.
     return tf.broadcast_to(tf.constant(0, dtype=tf.float32), out_shape)
 
-  # Some TF ops require len(window_strides) == 4 while others do not. We simply
-  # ensure it always has len(4).
-  if type(window_strides) == int:
-    window_strides = [window_strides] * 2
-  if len(window_strides) == 2:
-    window_strides = [1] + list(window_strides) + [1]
-
   if is_depthwise:
     # Reshape filter from
     # [filter_height, filter_width, 1, in_channels * channel_multiplier] to
@@ -198,7 +270,7 @@ def error(msg):
     output = tf.nn.depthwise_conv2d(
         input=lhs,
         filter=tf.reshape(rhs, new_rhs_shape),
-        strides=window_strides,
+        strides=tf_window_strides,
         padding=padding_type,
         dilations=rhs_dilation)
 
@@ -207,29 +279,34 @@ def error(msg):
     rhs_t = tf.reverse(rhs, [0, 1])
     rhs_t = tf.transpose(rhs_t, (0, 1, 3, 2))
 
-    # We should tranpose `out_shape` so it conforms to what TF expects.
-    tf_out_shape = tuple(out_shape[i] for i in output_perm)  # "NCHW"
-    tf_out_shape = tuple(
-        tf_out_shape[i] for i in (0, 2, 3, 1))  # "NCHW" -> "NHWC"
-
+    # We should transpose `out_shape` to "NHWC", which is what TF expects.
+    # First transpose to "NCHW".
+    if is_conv1d:
+      tf_out_shape = tuple(out_shape[i] for i in output_perm) + (1,)
+    else:
+      tf_out_shape = tuple(out_shape[i] for i in output_perm)
+    # Then transpose "NCHW" to "NHWC".
+    tf_out_shape = tuple(tf_out_shape[i] for i in (0, 2, 3, 1))
     output = tf.nn.conv2d_transpose(
        input=lhs,
        filters=rhs_t,
        output_shape=tf_out_shape,
        strides=lhs_dilation,
-       padding="VALID")
+       padding=padding_type)
 
   else:
     output = tf.nn.conv2d(
        input=lhs,
        filters=rhs,
-       strides=window_strides,
+       strides=tf_window_strides,
        padding=padding_type,
        dilations=rhs_dilation)
 
   # TF outputs in format "NHWC", so convert to "NCHW", which is lax's default
   # format.
   output = tf.transpose(output, (0, 3, 1, 2))  # "NHWC" --> "NCHW"
+  if is_conv1d:
+    output = output[:, :, :, 0]
 
   # To determine the right permutation, we compute the inverse permutation of
   # `output_perm`, so that when `output_perm` is applied to `output`, we obtain
   # the output in NCHW format.
diff --git a/jax/experimental/jax2tf/tests/primitive_harness.py b/jax/experimental/jax2tf/tests/primitive_harness.py
index 2ff06e857d2f..e7909d885870 100644
--- a/jax/experimental/jax2tf/tests/primitive_harness.py
+++ b/jax/experimental/jax2tf/tests/primitive_harness.py
@@ -2889,8 +2889,19 @@ def _make_conv_harness(name,
       feature_group_count=feature_group_count,
       batch_group_count=batch_group_count)
 
+#--- BEGIN Tests for conv_general_dilated with works_without_xla=True ---
 
-#--- BEGIN Tests for conv_general_dilated with works_for_xla=True ---
 
+# Validate Conv1D.
+_make_conv_harness(
+    "conv1d",
+    lhs_shape=(2, 3, 10),
+    rhs_shape=(3, 3, 5),
+    window_strides=(1,),
+    padding=((0, 0),),
+    lhs_dilation=(1,),
+    rhs_dilation=(1,),
+    dimension_numbers=("NCH", "OIH", "NCH"),
+    works_without_xla=True)
 
 
 # feature_group_count is supported for enable_xla=False only if we are doing a
@@ -2904,6 +2915,39 @@ def _make_conv_harness(name,
     feature_group_count=3,
     works_without_xla=True)
 
+_make_conv_harness(
+    "depthwise2d_dilated",
+    lhs_shape=(2, 3, 9, 9),  # "NCHW": in_channels == 3
+    rhs_shape=(12, 1, 3, 3),  # "OIHW": channel_multiplier = 12/3 = 4
+    feature_group_count=3,
+    lhs_dilation=(1, 1),
+    rhs_dilation=(2, 1),
+    works_without_xla=True)
+
+_make_conv_harness(
+    "depthwise1d",
+    lhs_shape=(2, 3, 9),  # "NCH": in_channels == 3
+    rhs_shape=(12, 1, 3),  # "OIH": channel_multiplier = 12/3 = 4
+    feature_group_count=3,
+    lhs_dilation=(1,),
+    rhs_dilation=(1,),
+    window_strides=(1,),
+    padding=((0, 0),),
+    dimension_numbers=("NCH", "OIH", "NCH"),
+    works_without_xla=True)
+
+_make_conv_harness(
+    "depthwise1d_dilated",
+    lhs_shape=(2, 3, 9),  # "NCH": in_channels == 3
+    rhs_shape=(12, 1, 3),  # "OIH": channel_multiplier = 12/3 = 4
+    feature_group_count=3,
+    lhs_dilation=(1,),
+    rhs_dilation=(2,),
+    window_strides=(1,),
+    padding=((0, 0),),
+    dimension_numbers=("NCH", "OIH", "NCH"),
+    works_without_xla=True)
+
 # Validate variations of window_strides
 for window_strides in [(2, 3)]:
   _make_conv_harness(
@@ -2939,6 +2983,39 @@ def _make_conv_harness(name,
       dimension_numbers=("NHWC", "HWIO", "NHWC"),
       works_without_xla=True)
 
+# Simulate a call from lax.conv_transpose.
+_make_conv_harness(
+    "conv_tranpose1d_valid_padding",
+    lhs_shape=(1, 16, 2),
+    rhs_shape=(3, 2, 2),
+    window_strides=(1,),
+    lhs_dilation=(2,),
+    rhs_dilation=(1,),
+    padding=((2, 2),),
+    dimension_numbers=("NHC", "HIO", "NHC"),
+    works_without_xla=True)
+
+_make_conv_harness(
+    "conv_tranpose1d_same_padding",
+    lhs_shape=(1, 16, 2),
+    rhs_shape=(3, 2, 2),
+    window_strides=(1,),
+    lhs_dilation=(2,),
+    rhs_dilation=(1,),
+    padding=((2, 1),),
+    dimension_numbers=("NHC", "HIO", "NHC"),
+    works_without_xla=True)
+
+_make_conv_harness(
+    "conv_tranpose2d_same_padding",
+    lhs_shape=(1, 16, 16, 2),
+    rhs_shape=(2, 3, 2, 2),
+    window_strides=(1, 1),
+    lhs_dilation=(2, 2),
+    padding=((1, 1), (2, 1)),
+    dimension_numbers=("NHWC", "HWIO", "NHWC"),
+    works_without_xla=True)
+
 # Validate rhs > lhs.
 # One dimension of rhs is bigger than lhs.
 _make_conv_harness(
@@ -3008,7 +3085,7 @@ def _make_conv_harness(name,
       rhs_dilation=rhs_dilation,
       works_without_xla=True)
 
-#--- END Tests for conv_general_dilated with works_for_xla=True ---
+#--- END Tests for conv_general_dilated with works_without_xla=True ---
 
 for lhs_dilation, rhs_dilation in [
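
Usage sketch (illustrative, not part of the commit): with this patch applied, the new conv1d support can be exercised end to end by converting a JAX function with XLA disabled. This assumes TensorFlow is installed; `conv1d` below is a hypothetical helper that simply mirrors the shapes of the "conv1d" harness above.

import numpy as np
from jax import lax
from jax.experimental import jax2tf

def conv1d(lhs, rhs):
  # "NCH" / "OIH" / "NCH" dimension numbers, as in the conv1d harness.
  return lax.conv_general_dilated(
      lhs, rhs,
      window_strides=(1,),
      padding=((0, 0),),
      lhs_dilation=(1,),
      rhs_dilation=(1,),
      dimension_numbers=("NCH", "OIH", "NCH"))

lhs = np.ones((2, 3, 10), dtype=np.float32)  # batch=2, in_channels=3, width=10
rhs = np.ones((3, 3, 5), dtype=np.float32)   # out_channels=3, in_channels=3, kernel=5

# enable_xla=False routes the converted function through impl_no_xla.py.
tf_fn = jax2tf.convert(conv1d, enable_xla=False)
np.testing.assert_allclose(conv1d(lhs, rhs), tf_fn(lhs, rhs).numpy(), rtol=1e-6)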
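
The padding classification performed by _conv_transpose_pads_to_padtype can also be checked in isolation. The sketch below (illustrative only; `transpose_padtype` is not a function from this patch) mirrors its arithmetic: for a transposed convolution the stride arrives as `lhs_dilation`, and the given pads are classified by recomputing what lax._conv_transpose_padding produces for VALID and SAME.

import numpy as np

def transpose_padtype(kernel_sdims, strides, padding):
  # Recompute the (begin, end) pads lax uses for VALID and SAME transposed
  # convolutions, then classify `padding` by which one it matches.
  def valid_pads(k, s):
    pad_len = k + s - 2 + max(k - s, 0)
    return (k - 1, pad_len - (k - 1))

  def same_pads(k, s):
    pad_len = k + s - 2
    pad_a = k - 1 if s > k - 1 else int(np.ceil(pad_len / 2))
    return (pad_a, pad_len - pad_a)

  dims = list(zip(kernel_sdims, strides, padding))
  if all(valid_pads(k, s) == tuple(p) for k, s, p in dims):
    return "VALID"
  if all(same_pads(k, s) == tuple(p) for k, s, p in dims):
    return "SAME"
  raise ValueError("padding matches neither VALID nor SAME")

# The conv_tranpose1d harnesses above use kernel width 3 and lhs_dilation 2:
assert transpose_padtype((3,), (2,), ((2, 2),)) == "VALID"
assert transpose_padtype((3,), (2,), ((2, 1),)) == "SAME"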