
Commit d3c7026

Use lax.conv_general_dilated_patches to extract image patches in time linear to the number of channels (vs quadratic).

Note that the complexity of `extract_patches` is quadratic in the number of channels `C`, while `lax.conv_general_dilated_patches` is linear in `C` because it uses a depthwise convolution. AFAIK `extract_image_patches` already has the same linear complexity, but it can be shortened by delegating to `lax.conv_general_dilated_patches` as well.
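A rough sketch of where the factor of `C` comes from (illustrative numbers only, not code from this change; the dense OIHW layout is what the removed `extract_patches` passed to `lax.conv`, the depthwise layout is an assumption about the filter `lax.conv_general_dilated_patches` builds internally):

import numpy as np

C, h, w = 256, 3, 3  # channels and patch size (illustrative)

# extract_patches built a dense one-hot filter that mixes all C input
# channels into C*h*w output channels, so its weight count grows as C**2.
dense_filter_shape = (C * h * w, C, h, w)      # OIHW, as used by lax.conv
# A depthwise convolution (feature_group_count=C) sees one input channel
# per group, so the weight count grows only linearly in C.
depthwise_filter_shape = (h, w, 1, C * h * w)  # HWIO

print(np.prod(dense_filter_shape) // np.prod(depthwise_filter_shape))  # 256 == C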

PiperOrigin-RevId: 495102958
romanngg authored and Scenic Authors committed Oct 23, 2023
1 parent 6b01e9b commit d3c7026
Showing 3 changed files with 23 additions and 92 deletions.
93 changes: 15 additions & 78 deletions scenic/model_lib/layers/nn_ops.py
@@ -51,8 +51,8 @@ def extract_image_patches(lhs,
       `'NCHW'`.
   Returns:
-    A 4-D Tensor. Has the same type and data format as `lhs`, and with shape
-    `[batch, num_patches_col, num_patches_row, rhs_shape[1], rhs_shape[2], C]`.
+    A 4-D Tensor. Has the same type as `lhs`, and with shape
+    `[batch, num_patches_row, num_patches_col, rhs_shape[1], rhs_shape[2], C]`.
   """
   num_dims = lhs.ndim
   num_spatial_dims = num_dims - 2
@@ -76,83 +76,20 @@ def extract_image_patches(lhs,
        'Current implementation does not support dilations in the batch '
        'and depth dimensions.')

-  # Replicating tensorflow's implementation.
-  lhs_perm = lax.conv_general_permutations(
-      (data_format, 'HWIO', data_format))[0]
-  kernel_shape = [rhs_shape[i] for i in lhs_perm[2:]]
-
-  kernel_size = np.prod(kernel_shape)
-  conv_filter_shape = kernel_shape[:]
-  conv_filter_shape.append(1)
-  conv_filter_shape.append(kernel_size * depth)
-
-  iota_kernel_shape = (kernel_size, depth, kernel_size)
-
-  conv_filter = lax.eq(
-      lax.broadcasted_iota(jnp.int32, iota_kernel_shape, 0),
-      lax.broadcasted_iota(jnp.int32, iota_kernel_shape, 2),
+  filter_shape = tuple(rhs_shape[i] for i in range(num_dims) if i not in
+                       (batch_dim, feature_dim))
+  patches = lax.conv_general_dilated_patches(
+      lhs=lhs,
+      filter_shape=filter_shape,
+      padding=padding,
+      window_strides=tuple(strides[i] for i in range(num_dims) if i not in
+                           (batch_dim, feature_dim)),
+      dimension_numbers=(data_format, 'HWIO', 'NHWC')
  )
-  conv_filter = lax.convert_element_type(conv_filter, lhs.dtype)
-  conv_filter = lax.reshape(conv_filter, conv_filter_shape)
-
-  dim_num = lax.conv_dimension_numbers(lhs.shape, conv_filter.shape,
-                                       (data_format, 'HWIO', data_format))
-  conv_strides = [0] * num_spatial_dims
-  conv_rhs_dilation = [0] * num_spatial_dims
-  for i in range(num_spatial_dims):
-    dim = dim_num.lhs_spec[i + 2]
-    conv_strides[i] = strides[dim]
-    conv_rhs_dilation[i] = rhs_dilation[dim]
-
-  conv = lax.conv_general_dilated(lhs, conv_filter, conv_strides, padding, None,
-                                  conv_rhs_dilation, dim_num, depth)
-
-  conv_dims = list(conv.shape[:-1])
-  conv_dims.append(depth)
-  conv_dims.extend(kernel_shape)
-  conv = lax.reshape(conv, conv_dims)
-
-  permutation = list(range(len(conv_dims)))
-  depth_dim = permutation.pop(-3)
-  permutation.append(depth_dim)
-
-  return lax.transpose(conv, permutation)
-
-
-def extract_patches(lhs, rhs_shape, strides=(1, 1)):
-  """Extracts patches from an image using a convolution operator.
-  Args:
-    lhs: A tensor of images of shapes (B, H, W, C).
-    rhs_shape: The size of the patches to extract (h, w).
-    strides: The shift between extracted patches (s1, s2)
-  Returns:
-    All the patches in a tensor of dimension
-    (B, (H - h + 1) // s1, (W - w + 1) // s2, h, w, C).
-  """
-  # [batch, channels, height, width]
-  lhs = jnp.moveaxis(lhs, -1, 1)
-  d = lhs.shape[1]
-  h, w = rhs_shape
-
-  # Construct the lookup conv weights.
-  dim_out = jnp.arange(d * h * w).reshape((-1, 1, 1, 1))
-  dim_in = jnp.arange(d).reshape((1, -1, 1, 1))
-  i = jnp.arange(h).reshape((1, 1, -1, 1))
-  j = jnp.arange(w).reshape((1, 1, 1, -1))
-  weights = ((w * i + j) * d + dim_in == dim_out).astype(jnp.float32)
-
-  # [batch, h * w * d, (H - h + 1) // s1, (W - w + 1) // s2]
-  concatenated_patches = lax.conv(
-      lhs, weights, window_strides=strides, padding='VALID')
-
-  # [batch, (H - h + 1) // s1, (W - w + 1) // s2, h * w * d]
-  concatenated_patches = jnp.moveaxis(concatenated_patches, 1, -1)
-
-  # [batch, (H - h + 1) // s1, (W - w + 1) // s2, h, w, d]
-  shape = concatenated_patches.shape[:3] + (h, w, d)
-  return concatenated_patches.reshape(shape)
+  shape = patches.shape[:-1] + (depth,) + filter_shape
+  patches = patches.reshape(shape)
+  patches = jnp.moveaxis(patches, -1 - num_spatial_dims, -1)
+  return patches


def compute_relative_positions(query_spatial_shape,
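For reference, a minimal sketch of the new code path with assumed shapes (a 2x2 patch grid over an 8x8, 3-channel NHWC input). `lax.conv_general_dilated_patches` flattens each patch into the feature dimension with the channel axis major, which is why the new tail reshapes to `(depth,) + filter_shape` and then moves the channel axis last:

import jax.numpy as jnp
from jax import lax

x = jnp.arange(2 * 8 * 8 * 3, dtype=jnp.float32).reshape(2, 8, 8, 3)  # NHWC

patches = lax.conv_general_dilated_patches(
    lhs=x,
    filter_shape=(2, 2),
    window_strides=(2, 2),
    padding='VALID',
    dimension_numbers=('NHWC', 'HWIO', 'NHWC'))
print(patches.shape)  # (2, 4, 4, 12): feature dim is (C=3, h=2, w=2) flattened

# The same reshape/moveaxis as the new extract_image_patches tail:
patches = patches.reshape(patches.shape[:-1] + (3, 2, 2))
patches = jnp.moveaxis(patches, -3, -1)
print(patches.shape)  # (2, 4, 4, 2, 2, 3) == (batch, rows, cols, h, w, C)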
11 changes: 0 additions & 11 deletions scenic/model_lib/layers/tests/test_nn_ops.py
@@ -140,17 +140,6 @@ def test_central_crop(self):
    self.assertEqual(output[0, 0, 0, 0], 11.)
    self.assertEqual(output[0, -1, -1, 0], 88.)

-  def test_extract_patches(self):
-    """Tests extract_patches."""
-    input_shape = (16, 3, 3, 32)
-    inputs = np.array(np.random.normal(size=input_shape))
-
-    # patching a 3x3 image to 3x3 patches, with no stride 1x1 should do nothing
-    # but reshaping the (bs, h, w, c) to (bs, 1, 1, h, w, c)
-    patched = nn_ops.extract_patches(inputs, (3, 3), (1, 1))
-    self.assertEqual(patched.shape, (16, 1, 1, 3, 3, 32))
-    np.testing.assert_allclose(inputs, patched.reshape(input_shape), atol=1e-2)
-
  @parameterized.named_parameters([('test_avg_pooling', 'avg_pooling'),
                                   ('test_max_pooling', 'max_pooling'),
                                   ('test_avg_pooling_bu', 'avg_pooling'),
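No replacement test is added in this commit. A hypothetical equivalent check against extract_image_patches (signature and expected output shape inferred from the diff above; not part of this change) could look like:

  def test_extract_image_patches(self):
    """Tests extract_image_patches (hypothetical, not in this commit)."""
    input_shape = (16, 3, 3, 32)
    inputs = np.array(np.random.normal(size=input_shape))

    # Extracting a single 3x3 patch from a 3x3 input should only reshape
    # (bs, h, w, c) to (bs, 1, 1, h, w, c).
    patched = nn_ops.extract_image_patches(
        inputs, rhs_shape=(1, 3, 3, 1), strides=(1, 1, 1, 1),
        padding='VALID', rhs_dilation=(1, 1, 1, 1))
    self.assertEqual(patched.shape, (16, 1, 1, 3, 3, 32))
    np.testing.assert_allclose(inputs, patched.reshape(input_shape), atol=1e-2)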
11 changes: 8 additions & 3 deletions scenic/projects/fast_vit/xvit.py
@@ -422,9 +422,14 @@ def __call__(self,
    if transformer_encoder_type in ['grid', 'grid_attention']:
      # First put patches of patches in rows. (we already projected pixels to
      # patches in the stem and at this level, the input tokens are patches).
-      # TODO(dehghani): Check if nn_ops.extract_image_patches is faster.
      pp_size = self.transformer_encoder_configs.patches_of_patches_size
-      x = nn_ops.extract_patches(lhs=x, rhs_shape=pp_size, strides=pp_size)
+      x = nn_ops.extract_image_patches(
+          lhs=x,
+          rhs_shape=(1,) + pp_size + (1,),
+          strides=(1,) + pp_size + (1,),
+          padding='VALID',
+          rhs_dilation=(1,) * 4
+      )
      # TODO(dehghani): Check if we can output a 4D tensor directly and get
      # rid of this reshape.
      bs, ph, pw, h, w, c = x.shape
@@ -436,6 +441,7 @@ def __call__(self,
      elif transformer_encoder_type == 'grid_attention':
        transformer_encoder_type = 'axial_attention'

+      cls_token = x[:, 0, 0]
      x = EncoderAxial(
          mlp_dim=self.mlp_dim,
          num_layers=self.num_layers,
@@ -446,7 +452,6 @@ def __call__(self,
          attention_dropout_rate=self.attention_dropout_rate,
          name='Transformer')(
              x, train=train)
-      cls_token = x[:, 0, 0]

    else:
      raise ValueError(
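Shape-wise, the new call is a drop-in replacement for the old extract_patches call. A small sketch with assumed sizes (pp_size = (2, 2) over an 8-example batch of 6x6 grids of 128-d tokens):

import jax.numpy as jnp
from scenic.model_lib.layers import nn_ops

x = jnp.ones((8, 6, 6, 128))  # (bs, grid_h, grid_w, emb)
pp_size = (2, 2)
x = nn_ops.extract_image_patches(
    lhs=x,
    rhs_shape=(1,) + pp_size + (1,),
    strides=(1,) + pp_size + (1,),
    padding='VALID',
    rhs_dilation=(1,) * 4)
bs, ph, pw, h, w, c = x.shape  # (8, 3, 3, 2, 2, 128): a 3x3 grid of 2x2 token patches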
