Use lax.conv_general_dilated_patches to extract image patches in time linear in the number of channels (vs. quadratic). #630

Open · wants to merge 1 commit into base: main
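For context on the complexity claim in the title: `lax.conv_general_dilated_patches` lowers to a depthwise convolution (`feature_group_count` equal to the channel count) with a one-hot kernel, so its cost grows linearly with the number of channels C. The `extract_patches` helper deleted below instead built a dense lookup kernel of shape `(C*h*w, C, h, w)` and ran a full `lax.conv`, which is quadratic in C. A minimal sketch of the primitive (public JAX API; the shapes are illustrative, not taken from this PR):

```python
import jax.numpy as jnp
from jax import lax

x = jnp.ones((2, 8, 8, 5))  # NHWC: batch=2, an 8x8 grid, C=5 channels.
patches = lax.conv_general_dilated_patches(
    lhs=x,
    filter_shape=(3, 3),
    window_strides=(1, 1),
    padding='VALID',
    dimension_numbers=('NHWC', 'HWIO', 'NHWC'))
# The feature dim flattens each patch channel-major, i.e. it unflattens to
# (C, 3, 3); this is why the new extract_image_patches below reshapes to
# (..., depth) + filter_shape and then moves the depth axis last.
print(patches.shape)  # (2, 6, 6, 45), where 45 == 5 * 3 * 3.
```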
93 changes: 15 additions & 78 deletions scenic/model_lib/layers/nn_ops.py
@@ -51,8 +51,8 @@ def extract_image_patches(lhs,
       `'NCHW'`.
 
   Returns:
-    A 4-D Tensor. Has the same type and data format as `lhs`, and with shape
-    `[batch, num_patches_col, num_patches_row, rhs_shape[1], rhs_shape[2], C]`.
+    A 6-D Tensor. Has the same type as `lhs`, and with shape
+    `[batch, num_patches_row, num_patches_col, rhs_shape[1], rhs_shape[2], C]`.
   """
   num_dims = lhs.ndim
   num_spatial_dims = num_dims - 2
@@ -76,83 +76,20 @@
         'Current implementation does not support dilations in the batch '
         'and depth dimensions.')
 
-  # Replicating tensorflow's implementation.
-  lhs_perm = lax.conv_general_permutations(
-      (data_format, 'HWIO', data_format))[0]
-  kernel_shape = [rhs_shape[i] for i in lhs_perm[2:]]
-
-  kernel_size = np.prod(kernel_shape)
-  conv_filter_shape = kernel_shape[:]
-  conv_filter_shape.append(1)
-  conv_filter_shape.append(kernel_size * depth)
-
-  iota_kernel_shape = (kernel_size, depth, kernel_size)
-
-  conv_filter = lax.eq(
-      lax.broadcasted_iota(jnp.int32, iota_kernel_shape, 0),
-      lax.broadcasted_iota(jnp.int32, iota_kernel_shape, 2),
+  filter_shape = tuple(rhs_shape[i] for i in range(num_dims) if i not in
+                       (batch_dim, feature_dim))
+  patches = lax.conv_general_dilated_patches(
+      lhs=lhs,
+      filter_shape=filter_shape,
+      padding=padding,
+      window_strides=tuple(strides[i] for i in range(num_dims) if i not in
+                           (batch_dim, feature_dim)),
+      dimension_numbers=(data_format, 'HWIO', 'NHWC')
   )
-  conv_filter = lax.convert_element_type(conv_filter, lhs.dtype)
-  conv_filter = lax.reshape(conv_filter, conv_filter_shape)
-
-  dim_num = lax.conv_dimension_numbers(lhs.shape, conv_filter.shape,
-                                       (data_format, 'HWIO', data_format))
-  conv_strides = [0] * num_spatial_dims
-  conv_rhs_dilation = [0] * num_spatial_dims
-  for i in range(num_spatial_dims):
-    dim = dim_num.lhs_spec[i + 2]
-    conv_strides[i] = strides[dim]
-    conv_rhs_dilation[i] = rhs_dilation[dim]
-
-  conv = lax.conv_general_dilated(lhs, conv_filter, conv_strides, padding, None,
-                                  conv_rhs_dilation, dim_num, depth)
-
-  conv_dims = list(conv.shape[:-1])
-  conv_dims.append(depth)
-  conv_dims.extend(kernel_shape)
-  conv = lax.reshape(conv, conv_dims)
-
-  permutation = list(range(len(conv_dims)))
-  depth_dim = permutation.pop(-3)
-  permutation.append(depth_dim)
-
-  return lax.transpose(conv, permutation)
-
-
-def extract_patches(lhs, rhs_shape, strides=(1, 1)):
-  """Extracts patches from an image using a convolution operator.
-
-  Args:
-    lhs: A tensor of images of shapes (B, H, W, C).
-    rhs_shape: The size of the patches to extract (h, w).
-    strides: The shift between extracted patches (s1, s2)
-
-  Returns:
-    All the patches in a tensor of dimension
-    (B, (H - h + 1) // s1, (W - w + 1) // s2, h, w, C).
-  """
-  # [batch, channels, height, width]
-  lhs = jnp.moveaxis(lhs, -1, 1)
-  d = lhs.shape[1]
-  h, w = rhs_shape
-
-  # Construct the lookup conv weights.
-  dim_out = jnp.arange(d * h * w).reshape((-1, 1, 1, 1))
-  dim_in = jnp.arange(d).reshape((1, -1, 1, 1))
-  i = jnp.arange(h).reshape((1, 1, -1, 1))
-  j = jnp.arange(w).reshape((1, 1, 1, -1))
-  weights = ((w * i + j) * d + dim_in == dim_out).astype(jnp.float32)
-
-  # [batch, h * w * d, (H - h + 1) // s1, (W - w + 1) // s2]
-  concatenated_patches = lax.conv(
-      lhs, weights, window_strides=strides, padding='VALID')
-
-  # [batch, (H - h + 1) // s1, (W - w + 1) // s2, h * w * d]
-  concatenated_patches = jnp.moveaxis(concatenated_patches, 1, -1)
-
-  # [batch, (H - h + 1) // s1, (W - w + 1) // s2, h, w, d]
-  shape = concatenated_patches.shape[:3] + (h, w, d)
-  return concatenated_patches.reshape(shape)
+  shape = patches.shape[:-1] + (depth,) + filter_shape
+  patches = patches.reshape(shape)
+  patches = jnp.moveaxis(patches, -1 - num_spatial_dims, -1)
+  return patches


 def compute_relative_positions(query_spatial_shape,
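Not part of the diff, but a quick way to sanity-check the new code path against plain slicing (assumes the `scenic` import layout used in the files above):

```python
import jax.numpy as jnp
import numpy as np
from scenic.model_lib.layers import nn_ops

x = jnp.asarray(np.random.normal(size=(2, 8, 8, 5)))  # NHWC input.
patches = nn_ops.extract_image_patches(
    lhs=x,
    rhs_shape=(1, 3, 3, 1),
    strides=(1, 1, 1, 1),
    padding='VALID',
    rhs_dilation=(1, 1, 1, 1))
# 3x3 windows at stride 1 over an 8x8 grid -> a 6x6 grid of patches.
assert patches.shape == (2, 6, 6, 3, 3, 5)
# Each patch must equal the corresponding window of the input.
np.testing.assert_allclose(patches[0, 2, 4], x[0, 2:5, 4:7, :])
```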
11 changes: 0 additions & 11 deletions scenic/model_lib/layers/tests/test_nn_ops.py
@@ -140,17 +140,6 @@ def test_central_crop(self):
     self.assertEqual(output[0, 0, 0, 0], 11.)
     self.assertEqual(output[0, -1, -1, 0], 88.)
 
-  def test_extract_patches(self):
-    """Tests extract_patches."""
-    input_shape = (16, 3, 3, 32)
-    inputs = np.array(np.random.normal(size=input_shape))
-
-    # patching a 3x3 image to 3x3 patches, with no stride 1x1 should do nothing
-    # but reshaping the (bs, h, w, c) to (bs, 1, 1, h, w, c)
-    patched = nn_ops.extract_patches(inputs, (3, 3), (1, 1))
-    self.assertEqual(patched.shape, (16, 1, 1, 3, 3, 32))
-    np.testing.assert_allclose(inputs, patched.reshape(input_shape), atol=1e-2)
-
   @parameterized.named_parameters([('test_avg_pooling', 'avg_pooling'),
                                    ('test_max_pooling', 'max_pooling'),
                                    ('test_avg_pooling_bu', 'avg_pooling'),
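The PR removes this test without adding coverage for the replacement. A sketch of how the deleted identity check could be ported to `extract_image_patches` (a hypothetical test following the conventions of `test_nn_ops.py`; not part of this diff):

```python
def test_extract_image_patches_identity(self):
  """Extracting one full-size patch should merely reshape the input."""
  input_shape = (16, 3, 3, 32)
  inputs = np.array(np.random.normal(size=input_shape))
  # A 3x3 window over a 3x3 grid yields a single patch per image:
  # (bs, h, w, c) -> (bs, 1, 1, h, w, c).
  patched = nn_ops.extract_image_patches(
      lhs=inputs,
      rhs_shape=(1, 3, 3, 1),
      strides=(1, 1, 1, 1),
      padding='VALID',
      rhs_dilation=(1, 1, 1, 1))
  self.assertEqual(patched.shape, (16, 1, 1, 3, 3, 32))
  np.testing.assert_allclose(inputs, patched.reshape(input_shape), atol=1e-5)
```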
11 changes: 8 additions & 3 deletions scenic/projects/fast_vit/xvit.py
@@ -422,9 +422,14 @@ def __call__(self,
     if transformer_encoder_type in ['grid', 'grid_attention']:
       # First put patches of patches in rows. (we already projected pixels to
       # patches in the stem and at this level, the input tokens are patches).
-      # TODO(dehghani): Check if nn_ops.extract_image_patches is faster.
       pp_size = self.transformer_encoder_configs.patches_of_patches_size
-      x = nn_ops.extract_patches(lhs=x, rhs_shape=pp_size, strides=pp_size)
+      x = nn_ops.extract_image_patches(
+          lhs=x,
+          rhs_shape=(1,) + pp_size + (1,),
+          strides=(1,) + pp_size + (1,),
+          padding='VALID',
+          rhs_dilation=(1,) * 4
+      )
       # TODO(dehghani): Check if we can output a 4D tensor directly and get
       # rid of this reshape.
       bs, ph, pw, h, w, c = x.shape
@@ -436,6 +441,7 @@
       elif transformer_encoder_type == 'grid_attention':
         transformer_encoder_type = 'axial_attention'
 
+      cls_token = x[:, 0, 0]
       x = EncoderAxial(
           mlp_dim=self.mlp_dim,
           num_layers=self.num_layers,
@@ -446,7 +452,6 @@
           attention_dropout_rate=self.attention_dropout_rate,
           name='Transformer')(
               x, train=train)
-      cls_token = x[:, 0, 0]
 
     else:
       raise ValueError(
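For the grid encoder, the new call tiles the token grid into non-overlapping `pp_size` blocks (stride equals window size), which reproduces the `(bs, ph, pw, h, w, c)` layout the deleted `extract_patches` produced here. A sketch with hypothetical sizes (a 6x6 token grid and `pp_size = (2, 3)`):

```python
import jax.numpy as jnp
from scenic.model_lib.layers import nn_ops

bs, gh, gw, c = 4, 6, 6, 64    # Hypothetical token grid from the stem.
pp_size = (2, 3)               # patches_of_patches_size.
x = jnp.ones((bs, gh, gw, c))
x = nn_ops.extract_image_patches(
    lhs=x,
    rhs_shape=(1,) + pp_size + (1,),
    strides=(1,) + pp_size + (1,),  # Stride == window: non-overlapping tiles.
    padding='VALID',
    rhs_dilation=(1,) * 4)
bs, ph, pw, h, w, c = x.shape       # (4, 3, 2, 2, 3, 64)
# The collapsed lines that follow in the diff then presumably flatten each
# block into a row of tokens, e.g. x.reshape(bs, ph * pw, h * w, c).
```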