Skip to content


[jax2tf] Add support for common audio convolutions (1D variants, dila…
Browse files Browse the repository at this point in the history
…ted depthwise, transpose with SAME padding).

PiperOrigin-RevId: 458266485
  • Loading branch information
jax authors committed Jun 30, 2022
1 parent 61b3dc5 commit 4e224bc
Show file tree
Hide file tree
Showing 2 changed files with 228 additions and 74 deletions.
221 changes: 149 additions & 72 deletions jax/experimental/jax2tf/
Expand Up @@ -50,6 +50,12 @@ def _xla_disabled_error(primitive_name: str,
return NotImplementedError(msg)

def _conv_error(msg):
  """Build the error reported when a convolution cannot be converted without XLA.

  Args:
    msg: description of the unsupported convolution feature.

  Returns:
    The exception object produced by `_xla_disabled_error` for the
    "conv_general_dilated" primitive. The error is returned, not raised, so
    call sites use `raise _conv_error(...)`.
  """
  # Every conversion failure points the user at the source for the exact
  # supported conditions.
  suffix = ("See source code for the precise conditions under which "
            "convolutions can be converted without XLA.")
  return _xla_disabled_error("conv_general_dilated", f"{msg} - {suffix}")

def _unimplemented(name):

def op(*arg, **kwargs):
Expand All @@ -66,13 +72,20 @@ def _transpose_for_tf_conv(lhs, rhs, dimension_numbers):
# TODO(marcvanzee): Consider merging transposes if we want to optimize.
# For `lhs_perm` / `output_perm`, perm (0, 1, 2, 3) corresponds to "NCHW".
lhs = tf.transpose(lhs, lhs_perm) # lhs --> "NCHW"
if len(lhs_perm) == 3:
# For 1D convolution, we add a trivial "W" dimension, so that 2D Convolution
# logic can be applied downstream.
lhs = lhs[:, :, :, np.newaxis]
# However, the TF ops only support "NHWC" on CPU, so we transpose again.
lhs = tf.transpose(lhs, (0, 2, 3, 1)) # "NCHW" --> "NHWC"

# For `rhs_perm`, perm (0, 1, 2, 3) corresponds to "OIHW".
rhs = tf.transpose(rhs, rhs_perm) # rhs --> "OIHW"
# Handle conv1d case.
if len(rhs_perm) == 3:
rhs = rhs[:, :, :, np.newaxis]
# However, the TF ops expect rhs in "HWIO" layout, so we transpose again.
rhs = tf.transpose(rhs, (2, 3, 1, 0)) # "OIHW" --> "HWIO"

return lhs, rhs

Expand All @@ -84,26 +97,116 @@ def pads_to_padtype(in_shape, window_shape, window_strides, padding) -> str:
return "EXPLICIT"

def _pad_spatial_dims(in_shape, padding):
def _pad_spatial_dims(in_shape, padding, is_conv1d):
"""Pads `in_shape` using `padding`, which specifies padding for the spatial dimensions."""
# Add empty padding for batch and feature dimensions.
no_pad = tf.constant([[0, 0]])
padding = tf.concat([no_pad, padding, no_pad], 0)
in_shape = tf.pad(in_shape, padding)
if is_conv1d:
padding = tf.concat([no_pad, padding, no_pad], 0)
# Add empty padding for dummy dimension, too.
padding = tf.concat([no_pad, padding, no_pad, no_pad], 0)
padding = tf.concat([no_pad, padding, no_pad], 0)
in_shape = tf.pad(in_shape, padding)
return in_shape

def _is_valid_padding(kernel_sdims, strides, padding):
"""Returns True if `padding` corresponds to "VALID" padding for a transposed convolution."""
# This is simply the padding == 'VALID' part of lax._conv_transpose_padding.
for (begin, end), k, s in zip(padding, kernel_sdims, strides):
pad_len = k + s - 2 + builtins.max(k - s, 0)
def _conv_transpose_pads_to_padtype(kernel_sdims, lhs_dilation, padding):
"""Finds the padding type for a transpose convolution."""
# This is simply checking agreement with lax._conv_transpose_padding.
is_valid = True
is_same = True
if not len(kernel_sdims) == len(lhs_dilation) == len(padding):
raise ValueError(f'Found different lengths for '
f'kernel_sdims ({kernel_sdims}), '
f'lhs_dilation ({lhs_dilation}), '
f'and padding ({padding}).')
for k, s, (begin, end) in zip(kernel_sdims, lhs_dilation, padding):
# Check for VALID padding.
pad_len_valid = k + s - 2 + builtins.max(k - s, 0)
pad_a = k - 1
pad_b = pad_len - pad_a
pad_b = pad_len_valid - pad_a
if begin != pad_a or end != pad_b:
return False
is_valid = False

return True
# Check for SAME padding.
pad_len_same = k + s - 2
if s > k - 1:
pad_a = k - 1
pad_a = int(np.ceil(pad_len_same / 2))
pad_b = pad_len_same - pad_a
if begin != pad_a or end != pad_b:
is_same = False

if is_valid:
return 'VALID'
elif is_same:
return 'SAME'
raise ValueError('Transpose convolution padding mode must be '
'`SAME` or `VALID`.')

def _validate_spatial_dimensions(nr_spatial_dimensions):
"""Check spatial dimension support."""
# Currently we only support 1D+2D convolutions because it keeps the code
# relatively simple and covers most cases.
if nr_spatial_dimensions > 2:
raise _conv_error(
"We only support 1D or 2D convolutions, but found "

def _normalize_padding_and_dilations(
padding, lhs_dilation, rhs_dilation, is_conv1d):
if is_conv1d:
lhs_dilation = list(lhs_dilation) + [1]
rhs_dilation = list(rhs_dilation) + [1]
# Empty padding in the dummy dimension.
# Note that when kernel_size=stride=1, padding of (0, 0) is both 'VALID' and
# 'SAME'. So the inferred padding type will still register according to the
# first dimension padding.
padding = list(padding) + [(0, 0)]
return padding, lhs_dilation, rhs_dilation

def _normalize_window_strides(window_strides):
"""Ensure window_strides has length 4."""
# Some TF ops require len(window_strides) == 4 while others do not. We simply
# ensure it always has len(4).
if len(window_strides) == 1:
# This is the Conv1D case. We add a dummy dimension to allow using 2D ops,
# and use stride=1 on the dummy dimension.
window_strides = list(window_strides) + [1]
if len(window_strides) == 2:
window_strides = [1] + list(window_strides) + [1]
return window_strides

def _normalize_output_perm(output_perm, is_conv1d):
"""Ensure that output_perm has length 4."""
if is_conv1d:
output_perm = list(output_perm) + [1]
return output_perm

def _validate_conv_features(
    is_transpose, is_atrous, is_depthwise, feature_group_count,
    batch_group_count, preferred_element_type, lhs_dtype):
  """Reject convolution feature combinations that cannot be converted.

  Args:
    is_transpose: whether the convolution is a transposed convolution.
    is_atrous: whether the convolution is dilated (atrous).
    is_depthwise: whether the convolution is depthwise.
    feature_group_count: the convolution's feature group count.
    batch_group_count: the convolution's batch group count.
    preferred_element_type: requested result element type, or None.
    lhs_dtype: element type of the lhs operand, compared against
      `preferred_element_type`.

  Raises:
    The error built by `_conv_error` when the convolution is grouped but not
    depthwise, combines more than one of depthwise / atrous / transposed
    (other than dilated depthwise), uses `batch_group_count != 1`, or asks
    for a `preferred_element_type` that differs from `lhs_dtype`.
  """
  if feature_group_count > 1 and not is_depthwise:
    raise _conv_error("Grouped convolutions are unsupported")

  # We allow dilated depthwise convolutions; every other pairing of
  # depthwise / atrous / transposed is rejected.
  is_dilated_depthwise = (is_depthwise and is_atrous) and not is_transpose
  if (not is_dilated_depthwise and
      [is_depthwise, is_atrous, is_transpose].count(True) > 1):
    raise _conv_error(
        f"Can only do one of depthwise ({is_depthwise}), atrous ({is_atrous}) "
        f"and tranposed convolutions ({is_transpose})")

  # We can implement batch grouping when there is a need for it.
  if batch_group_count != 1:
    raise _conv_error("Unimplemented support for batch_group_count != 1 "
                      f"(found {batch_group_count})")

  if (preferred_element_type is not None and
      preferred_element_type != lhs_dtype):
    raise _conv_error("Unimplemented support for preferred_element_type")

def _conv_general_dilated(
Expand All @@ -116,64 +219,40 @@ def _conv_general_dilated(
"""Implementation of lax.conv_general_dilated_p using XlaConv."""
del lhs_shape, rhs_shape, precision # Unused arguments.
out_shape = jax2tf._aval_to_tf_shape(_out_aval)
_validate_spatial_dimensions(len(lhs.shape) - 2)
is_conv1d = len(lhs.shape) - 2 == 1

def error(msg):
suffix = ("See source code for the precise conditions under which "
"convolutions can be converted without XLA.")
return _xla_disabled_error("conv_general_dilated", f"{msg} - {suffix}")

nr_spatial_dimensions = len(lhs.shape) - 2

# Currently we only support 2D convolutions because it keeps the code
# relatively simple and covers most cases.
if nr_spatial_dimensions != 2:
f"We only support 2D convolutions, but found {nr_spatial_dimensions}.")

# We can implement batch grouping when there is a need for it.
if batch_group_count != 1:
raise error("Unimplemented support for batch_group_count != 1 "
f"(found {batch_group_count})")

if (preferred_element_type is not None and
preferred_element_type != lhs.dtype.as_numpy_dtype):
raise error("Unimplemented support for preferred_element_type")
tf_window_strides = _normalize_window_strides(window_strides)
padding, lhs_dilation, rhs_dilation = _normalize_padding_and_dilations(
padding, lhs_dilation, rhs_dilation, is_conv1d)

lhs, rhs = _transpose_for_tf_conv(lhs, rhs, dimension_numbers)
output_perm = dimension_numbers[2]

in_channels = lhs.shape[-1]
*rhs_spatial_shapes, _, rhs_out_channel = rhs.shape

is_transpose = any([d != 1 for d in lhs_dilation])
is_atrous = any([d != 1 for d in rhs_dilation])
is_depthwise = in_channels == feature_group_count and feature_group_count > 1
is_transpose = list(lhs_dilation) != [1] * nr_spatial_dimensions
is_atrous = list(rhs_dilation) != [1] * nr_spatial_dimensions

if feature_group_count > 1 and not is_depthwise:
raise error("Grouped convolutions are unsupported")

if is_transpose:
# We provide support for transposed convolutions called through
# lax.conv2d_tranpose, but only if the provided padding was VALID.
if not _is_valid_padding(rhs_spatial_shapes, window_strides, padding):
raise error(
"Can only convert Transposed Convolutions with 'VALID' padding")

if [is_depthwise, is_atrous, is_transpose].count(True) > 1:
raise error(
"Can only do one of depthwise, atrous and tranposed convolutions")
_validate_conv_features(is_transpose, is_atrous, is_depthwise,
feature_group_count, batch_group_count,
preferred_element_type, lhs.dtype.as_numpy_dtype)

rhs_dilated_shape = [
(k - 1) * r + 1 for k, r in zip(rhs_spatial_shapes, rhs_dilation)
output_perm = dimension_numbers[2]

padding_type = pads_to_padtype(lhs.shape[1:3], rhs_dilated_shape, window_strides, padding)

# We only manually pad if we aren't using a transposed convolution, because
# there we don't do any padding.
if padding_type == "EXPLICIT" and not is_transpose:
lhs = _pad_spatial_dims(lhs, padding)
padding_type = "VALID"
if is_transpose:
padding_type = _conv_transpose_pads_to_padtype(
rhs_spatial_shapes, lhs_dilation, padding)
padding_type = pads_to_padtype(
lhs.shape[1:3], rhs_dilated_shape, window_strides, padding)
# We only manually pad if we aren't using a transposed convolution.
if padding_type == "EXPLICIT":
lhs = _pad_spatial_dims(lhs, padding, is_conv1d)
padding_type = "VALID"

if any(r > l for l, r in zip(lhs.shape[1:3], rhs_dilated_shape)
) and padding_type != "SAME":
Expand All @@ -182,13 +261,6 @@ def error(msg):
# We thus return zeros to make sure the behavior is consistent.
return tf.broadcast_to(tf.constant(0, dtype=tf.float32), out_shape)

# Some TF ops require len(window_strides) == 4 while others do not. We simply
# ensure it always has len(4).
if type(window_strides) == int:
window_strides = [window_strides] * 2
if len(window_strides) == 2:
window_strides = [1] + list(window_strides) + [1]

if is_depthwise:
# Reshape filter from
# [filter_height, filter_width, 1, in_channels * channel_multiplier] to
Expand All @@ -198,7 +270,7 @@ def error(msg):
output = tf.nn.depthwise_conv2d(
filter=tf.reshape(rhs, new_rhs_shape),

Expand All @@ -207,29 +279,34 @@ def error(msg):
rhs_t = tf.reverse(rhs, [0, 1])
rhs_t = tf.transpose(rhs_t, (0, 1, 3, 2))

# We should transpose `out_shape` so it conforms to what TF expects.
tf_out_shape = tuple(out_shape[i] for i in output_perm) # "NCHW"
tf_out_shape = tuple(
tf_out_shape[i] for i in (0, 2, 3, 1)) # "NCHW" -> "NHWC"

# We should transpose `out_shape` to "NHWC", which is what TF expects.
# First transpose to "NCHW".
if is_conv1d:
tf_out_shape = tuple(out_shape[i] for i in output_perm) + (1,)
tf_out_shape = tuple(out_shape[i] for i in output_perm)
# Then transpose "NCHW" to "NHWC".
tf_out_shape = tuple(tf_out_shape[i] for i in (0, 2, 3, 1))
output = tf.nn.conv2d_transpose(

output = tf.nn.conv2d(

# TF outputs in format "NHWC", so convert to "NCHW", which is lax's default
# format.
output = tf.transpose(output, (0, 3, 1, 2)) # "NHWC" --> "NCHW"
if is_conv1d:
output = output[:, :, :, 0]
# To determine the right permutation, we compute the inverse permutation of
# `output_perm`, so that when `output_perm` is applied to `output`, we obtain
# the output in NCHW format.
Expand Down

0 comments on commit 4e224bc

Please sign in to comment.