
Commit d795fc4
TokenLearner module v1.1 now supports input tensors with non-square shapes.

PiperOrigin-RevId: 443365186
Michael Ryoo authored and Scenic Authors committed Apr 21, 2022
1 parent 77ba56d commit d795fc4
Showing 1 changed file with 33 additions and 111 deletions.
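The new code paths operate on flattened `[bs, hw, c]` token sequences and pool them into a fixed number of learned tokens with a softmax-weighted einsum, so the spatial grid no longer has to be square. Below is a minimal standalone sketch of that selection step, not part of the commit; the shapes, the random stand-in attention logits, and the layer-free setup are illustrative assumptions only.

import jax
import jax.numpy as jnp

# Illustrative non-square grid: 8x12 = 96 positions, 16 channels, 4 learned tokens.
bs, h, w, c, n_token = 2, 8, 12, 16, 4
k_feat, k_attn = jax.random.split(jax.random.PRNGKey(0))
feat = jax.random.normal(k_feat, (bs, h * w, c))             # [bs, hw, c]

# Stand-in for the learned per-token spatial attention logits; in the module
# these come from MLP/conv layers applied to the input features.
attn_logits = jax.random.normal(k_attn, (bs, h * w, n_token))

selected = jnp.transpose(attn_logits, (0, 2, 1))             # [bs, n_token, hw]
selected = jax.nn.softmax(selected, axis=-1)                  # weights over the hw positions

# Weighted pooling of the original features into n_token learned tokens.
tokens = jnp.einsum('...si,...id->...sd', selected, feat)    # [bs, n_token, c]
print(tokens.shape)                                           # (2, 4, 16)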
144 changes: 33 additions & 111 deletions scenic/projects/token_learner/model.py
@@ -56,11 +56,19 @@ def __call__(self, inputs: jnp.ndarray) -> jnp.ndarray:
"""Applies learnable tokenization to the 2D inputs.
Args:
inputs: Inputs of shape `[bs, h, w, c]`.
inputs: Inputs of shape `[bs, h, w, c]` or `[bs, hw, c]`.
Returns:
Output of shape `[bs, n_token, c]`.
"""
if inputs.ndim == 3:
n, hw, c = inputs.shape
h = int(math.sqrt(hw))
inputs = jnp.reshape(inputs, [n, h, h, c])

if h * h != hw:
raise ValueError('Only square inputs supported.')

feature_shape = inputs.shape

selected = inputs
@@ -132,13 +140,9 @@ def __call__(self, inputs: jnp.ndarray, deterministic: bool) -> jnp.ndarray:
Returns:
Output of shape `[bs, n_token, c]`.
"""
if inputs.ndim == 3:
n, hw, c = inputs.shape
h = int(math.sqrt(hw))
inputs = jnp.reshape(inputs, [n, h, h, c])

if h * h != hw:
raise ValueError('Only square inputs supported.')
if inputs.ndim == 4:
n, h, w, c = inputs.shape
inputs = jnp.reshape(inputs, [n, h*w, c])

feature_shape = inputs.shape

@@ -155,86 +159,20 @@ def __call__(self, inputs: jnp.ndarray, deterministic: bool) -> jnp.ndarray:
selected, deterministic=deterministic)

selected = jnp.reshape(
selected, [feature_shape[0], feature_shape[1] * feature_shape[2], -1
]) # Shape: [bs, h*w, n_token].
selected,
[feature_shape[0], -1, self.num_tokens]) # Shape: [bs, h*w, n_token].
selected = jnp.transpose(selected, [0, 2, 1]) # Shape: [bs, n_token, h*w].
selected = jax.nn.softmax(selected, axis=-1)

feat = inputs
feat = jnp.reshape(
feat, [feature_shape[0], feature_shape[1] * feature_shape[2], -1
]) # Shape: [bs, h*w, c].
feat, [feature_shape[0], -1, feature_shape[-1]]) # Shape: [bs, h*w, c].

feat = jnp.einsum('...si,...id->...sd', selected, feat)

return feat


class TokenLearnerModuleMixer(nn.Module):
"""TokenLearner module using the MLPMixer block instead of conv layers..
Attributes:
num_tokens: Number of tokens.
dropout_rate: Dropout rate.
"""
num_tokens: int
dropout_rate: float = 0.

@nn.compact
def __call__(self, inputs: jnp.ndarray, deterministic: bool) -> jnp.ndarray:
"""Applies learnable tokenization to the 2D inputs.
Args:
inputs: Inputs of shape `[bs, h, w, c]`.
deterministic: Whether we are in the deterministic mode (e.g. inference
time) or not.
Returns:
Output of shape `[bs, n_token, c]`.
"""
b, h, w, c = inputs.shape

selected = jnp.reshape(inputs, [b, h*w, c])
skip = selected
selected = nn.LayerNorm()(selected)
selected = jnp.transpose(selected, [0, 2, 1])

selected = nn.Dense(
h*w,
kernel_init=nn.initializers.zeros)(selected)
selected = nn.gelu(selected)
selected = nn.Dense(
h*w,
kernel_init=nn.initializers.zeros)(selected)

selected = jnp.transpose(selected, [0, 2, 1])

selected = selected + skip

selected = nn.LayerNorm()(selected)
selected = nn.Dense(
c,
kernel_init=nn.initializers.zeros)(selected)
selected = nn.gelu(selected)
selected = nn.Dense(
self.num_tokens,
kernel_init=nn.initializers.zeros)(selected)
# Shape: [bs, h*w, n_token].

selected = jnp.transpose(selected, [0, 2, 1]) # Shape: [bs, n_token, h*w].
selected = nn.LayerNorm()(selected)
selected = nn.sigmoid(selected) # Shape: [bs, n_token, h*w].

selected = nn.Dropout(rate=self.dropout_rate)(selected,
deterministic=deterministic)

feat = inputs
feat = jnp.reshape(
feat, [b, h * w, c]) # Shape: [bs, h*w, c].

return jnp.einsum('...si,...id->...sd', selected, feat)


class TokenFuser(nn.Module):
"""Token fusion module.
@@ -256,16 +194,20 @@ def __call__(self, inputs: jnp.ndarray, original: jnp.ndarray,
Args:
inputs: Inputs of shape `[bs, n_token, c]`.
original: Inputs of shape `[bs, h, w, c]`.
original: Inputs of shape `[bs, hw, c]` or `[bs, h, w, c]`.
deterministic: Whether we are in the deterministic mode (e.g. inference
time) or not.
Returns:
Output of shape `[bs, h, w, c]`.
Output tensor with the shape identical to `original'.
"""
feature_shape = inputs.shape
num_tokens = feature_shape[-2]

if original.ndim == 4:
n, h, w, c = original.shape
original = jnp.reshape(original, [n, h*w, c])

if self.use_normalization:
inputs = nn.LayerNorm(name='fuser_mix_norm1')(inputs)

@@ -286,16 +228,16 @@ def __call__(self, inputs: jnp.ndarray, original: jnp.ndarray,
activation_fn=nn.gelu,
name='token_masking')(
original, deterministic=deterministic)
mix = nn.sigmoid(mix)[..., None]
mix = nn.sigmoid(mix)

inputs = inputs[:, None, None, ...]

inputs = inputs * mix
inputs = jnp.sum(inputs, axis=-2)
inputs = jnp.einsum('...sc,...hs->...hc', inputs, mix)

inputs = nn.Dropout(rate=self.dropout_rate)(
inputs, deterministic=deterministic)

if original.ndim == 4:
inputs = jnp.reshape(inputs, [n, h, w, -1])

return inputs
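For reference, the reworked TokenFuser masking step contracts over the token axis with a single einsum, `jnp.einsum('...sc,...hs->...hc', inputs, mix)`, instead of the earlier broadcast-multiply-and-sum over `[..., None]`-expanded tensors. A quick shape check with illustrative values (not taken from the commit):

import jax
import jax.numpy as jnp

bs, n_token, hw, c = 2, 4, 96, 16
k_in, k_mix = jax.random.split(jax.random.PRNGKey(0))
inputs = jax.random.normal(k_in, (bs, n_token, c))                  # fused token features [bs, s, c]
mix = jax.nn.sigmoid(jax.random.normal(k_mix, (bs, hw, n_token)))   # per-location weights [bs, hw, s]

out = jnp.einsum('...sc,...hs->...hc', inputs, mix)                 # [bs, hw, c]
print(out.shape)                                                     # (2, 96, 16)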


@@ -351,9 +293,8 @@ def __call__(self, inputs: jnp.ndarray, *, train: bool = False):
tl_index = tl_locs.index(lyr)

n, thw, c = x.shape
hw = thw / self.temporal_dimensions
h = int(math.sqrt(hw))
x = jnp.reshape(x, [n * self.temporal_dimensions, h, h, c])
hw = thw // self.temporal_dimensions
x = jnp.reshape(x, [n * self.temporal_dimensions, hw, c])
if self.use_v11:
x = TokenLearnerModuleV11(
tl_size[tl_index], dropout_rate=self.dropout_rate)(
@@ -436,9 +377,8 @@ def __call__(self, inputs: jnp.ndarray, *, train: bool = False):
if (self.tokenizer_type in {'dynamic', 'video'} and
lyr >= self.tokenlearner_loc):
n, thw, c = x.shape
hw = thw / self.temporal_dimensions
h = int(math.sqrt(hw))
x = jnp.reshape(x, [n * self.temporal_dimensions, h, h, c])
hw = thw // self.temporal_dimensions
x = jnp.reshape(x, [n * self.temporal_dimensions, hw, c])
residual = x
if self.use_v11:
x = TokenLearnerModuleV11(
Expand All @@ -461,9 +401,9 @@ def __call__(self, inputs: jnp.ndarray, *, train: bool = False):
x = jnp.reshape(x, [n * self.temporal_dimensions, n_tokens, c])
x = TokenFuser(dropout_rate=self.dropout_rate)(
x, residual,
deterministic=not train) # [n * t, n_tokens, c], [n * t, h, h, c]
deterministic=not train) # [n * t, n_tokens, c], [n * t, hw, c]
x = x + residual
x = jnp.reshape(x, [n, self.temporal_dimensions * h * h, c])
x = jnp.reshape(x, [n, self.temporal_dimensions * hw, c])

else:
x = vit.Encoder1DBlock(
@@ -628,6 +568,7 @@ def __call__(self, x: jnp.ndarray, *, train: bool, debug: bool = False):
n, t, h, w, _ = x.shape
x = jnp.reshape(x, [n * t, h, w, -1])
else:
n = x.shape[0]
t = 1

x = nn.Conv(
@@ -671,25 +612,6 @@ def __call__(self, x: jnp.ndarray, *, train: bool, debug: bool = False):
name='Transformer')(
x, train=train)

if self.use_concat_final:
n, hw, _ = x.shape
per_loc_channel_dim = self.target_channel_dim // hw
x = nn.Dense(
per_loc_channel_dim,
kernel_init=nn.initializers.zeros,
name='output_projection')(
x)

x = nn.tanh(x)
x = jnp.reshape(x, [n, -1])
else:
x = jnp.mean(x, axis=1)

x = nn.Dense(
self.target_channel_dim,
kernel_init=nn.initializers.zeros,
name='output_projection')(
x)
return x
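The video encoder hunks above now compute the number of spatial positions as `hw = thw // self.temporal_dimensions` and reshape directly to `[n * t, hw, c]`, rather than assuming a square grid via `int(math.sqrt(hw))`. A small arithmetic sketch of that reshape round-trip, with illustrative shapes only:

import jax
import jax.numpy as jnp

n, t, h, w, c = 2, 4, 8, 12, 16                                   # non-square 8x12 spatial grid
x = jax.random.normal(jax.random.PRNGKey(0), (n, t * h * w, c))   # [n, thw, c]

hw = (t * h * w) // t                                             # 96; no square-root assumption
x = jnp.reshape(x, (n * t, hw, c))                                # [n*t, hw, c] fed to TokenLearner
# ... TokenLearner / TokenFuser would operate on [n*t, hw, c] here ...
x = jnp.reshape(x, (n, t * hw, c))                                # back to [n, t*hw, c]
print(x.shape)                                                     # (2, 384, 16)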

