…time linear in the number of channels (vs. quadratic).

Note that the complexity of `extract_patches` is quadratic in the number of channels `C`, while `lax.conv_general_dilated_patches` is linear because it uses a depthwise convolution. As far as I know, `extract_image_patches` already has the same linear complexity, but its implementation can also be shortened by calling `lax.conv_general_dilated_patches`.

PiperOrigin-RevId: 495102958
romanngg authored and Scenic Authors committed Oct 23, 2023
1 parent 6b01e9b commit d3c7026
93 changes: 15 additions & 78 deletions scenic/model_lib/layers/
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ def extract_image_patches(lhs,
A 4-D Tensor. Has the same type and data format as `lhs`, and with shape
`[batch, num_patches_col, num_patches_row, rhs_shape[1], rhs_shape[2], C]`.
A 4-D Tensor. Has the same type as `lhs`, and with shape
`[batch, num_patches_row, num_patches_col, rhs_shape[1], rhs_shape[2], C]`.
num_dims = lhs.ndim
num_spatial_dims = num_dims - 2
Expand All @@ -76,83 +76,20 @@ def extract_image_patches(lhs,
'Current implementation does not support dilations in the batch '
'and depth dimensions.')

# Replicating tensorflow's implementation.
lhs_perm = lax.conv_general_permutations(
(data_format, 'HWIO', data_format))[0]
kernel_shape = [rhs_shape[i] for i in lhs_perm[2:]]

kernel_size =
conv_filter_shape = kernel_shape[:]
conv_filter_shape.append(kernel_size * depth)

iota_kernel_shape = (kernel_size, depth, kernel_size)

conv_filter = lax.eq(
lax.broadcasted_iota(jnp.int32, iota_kernel_shape, 0),
lax.broadcasted_iota(jnp.int32, iota_kernel_shape, 2),
filter_shape = tuple(rhs_shape[i] for i in range(num_dims) if i not in
(batch_dim, feature_dim))
patches = lax.conv_general_dilated_patches(
window_strides=tuple(strides[i] for i in range(num_dims) if i not in
(batch_dim, feature_dim)),
dimension_numbers=(data_format, 'HWIO', 'NHWC')
conv_filter = lax.convert_element_type(conv_filter, lhs.dtype)
conv_filter = lax.reshape(conv_filter, conv_filter_shape)

dim_num = lax.conv_dimension_numbers(lhs.shape, conv_filter.shape,
(data_format, 'HWIO', data_format))
conv_strides = [0] * num_spatial_dims
conv_rhs_dilation = [0] * num_spatial_dims
for i in range(num_spatial_dims):
dim = dim_num.lhs_spec[i + 2]
conv_strides[i] = strides[dim]
conv_rhs_dilation[i] = rhs_dilation[dim]

conv = lax.conv_general_dilated(lhs, conv_filter, conv_strides, padding, None,
conv_rhs_dilation, dim_num, depth)

conv_dims = list(conv.shape[:-1])
conv = lax.reshape(conv, conv_dims)

permutation = list(range(len(conv_dims)))
depth_dim = permutation.pop(-3)

return lax.transpose(conv, permutation)

def extract_patches(lhs, rhs_shape, strides=(1, 1)):
  """Extracts patches from an image using a convolution operator.

  Patches are taken with 'VALID' padding, i.e. only windows that lie entirely
  inside the image are returned.

  Args:
    lhs: A tensor of images of shapes (B, H, W, C).
    rhs_shape: The size of the patches to extract (h, w).
    strides: The shift between extracted patches (s1, s2).

  Returns:
    All the patches in a tensor of dimension
      (B, (H - h) // s1 + 1, (W - w) // s2 + 1, h, w, C),
    with the same dtype as `lhs`, such that
      out[b, y, x, i, j, c] == lhs[b, y * s1 + i, x * s2 + j, c].
  """
  h, w = rhs_shape
  d = lhs.shape[-1]

  # `conv_general_dilated_patches` gathers the patches with a depthwise
  # convolution, so the cost is linear in the number of channels instead of
  # the quadratic cost of convolving with a dense (d * h * w, d, h, w) one-hot
  # lookup kernel. It also builds its kernel in the dtype of `lhs`.
  # [batch, out_rows, out_cols, d * h * w]
  patches = lax.conv_general_dilated_patches(
      lhs,
      filter_shape=(h, w),
      window_strides=tuple(strides),
      padding='VALID',
      dimension_numbers=('NHWC', 'HWIO', 'NHWC'))

  # The flattened feature axis is ordered channel-major, i.e. as (d, h, w).
  # [batch, out_rows, out_cols, d, h, w]
  patches = patches.reshape(patches.shape[:3] + (d, h, w))

  # [batch, out_rows, out_cols, h, w, d]
  return jnp.moveaxis(patches, 3, -1)
shape = patches.shape[:-1] + (depth,) + filter_shape
patches = patches.reshape(shape)
patches = jnp.moveaxis(patches, -1 - num_spatial_dims, -1)
return patches

def compute_relative_positions(query_spatial_shape,
Expand Down
11 changes: 0 additions & 11 deletions scenic/model_lib/layers/tests/
Original file line number Diff line number Diff line change
Expand Up @@ -140,17 +140,6 @@ def test_central_crop(self):
self.assertEqual(output[0, 0, 0, 0], 11.)
self.assertEqual(output[0, -1, -1, 0], 88.)

def test_extract_patches(self):
  """Checks that one full-image patch is just a reshaped copy of the input."""
  batch, height, width, channels = 16, 3, 3, 32
  input_shape = (batch, height, width, channels)
  inputs = np.array(np.random.normal(size=input_shape))

  # A 3x3 window over a 3x3 image with stride 1x1 fits in exactly one
  # position, so patching should do nothing but reshape (bs, h, w, c) into
  # (bs, 1, 1, h, w, c).
  patched = nn_ops.extract_patches(inputs, (3, 3), (1, 1))
  expected_shape = (batch, 1, 1, height, width, channels)
  self.assertEqual(patched.shape, expected_shape)

  restored = patched.reshape(input_shape)
  np.testing.assert_allclose(inputs, restored, atol=1e-2)

@parameterized.named_parameters([('test_avg_pooling', 'avg_pooling'),
('test_max_pooling', 'max_pooling'),
('test_avg_pooling_bu', 'avg_pooling'),
Expand Down
11 changes: 8 additions & 3 deletions scenic/projects/fast_vit/
Original file line number Diff line number Diff line change
Expand Up @@ -422,9 +422,14 @@ def __call__(self,
if transformer_encoder_type in ['grid', 'grid_attention']:
# First put patches of patches in rows. (we already projected pixels to
# patches in the stem and at this level, the input tokens are patches).
# TODO(dehghani): Check if nn_ops.extract_image_patches is faster.
pp_size = self.transformer_encoder_configs.patches_of_patches_size
x = nn_ops.extract_patches(lhs=x, rhs_shape=pp_size, strides=pp_size)
x = nn_ops.extract_image_patches(
rhs_shape=(1,) + pp_size + (1,),
strides=(1,) + pp_size + (1,),
rhs_dilation=(1,) * 4
# TODO(dehghani): Check if we can output a 4D tensor directly and get
# rid of this reshape.
bs, ph, pw, h, w, c = x.shape
Expand All @@ -436,6 +441,7 @@ def __call__(self,
elif transformer_encoder_type == 'grid_attention':
transformer_encoder_type = 'axial_attention'

cls_token = x[:, 0, 0]
x = EncoderAxial(
Expand All @@ -446,7 +452,6 @@ def __call__(self,
x, train=train)
cls_token = x[:, 0, 0]

raise ValueError(
Expand Down

