Refactoring: improving readability and reducing the number of permutations
mathieujouffroy committed Oct 7, 2022
1 parent a6aac0f commit 1978d26
Showing 1 changed file with 32 additions and 49 deletions.
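The refactor below follows one pattern throughout: keep tensors channels-last inside the model so that flattening to and from token sequences is a pure reshape, and pay the NCHW/NHWC transpose only once at the encoder boundary. A minimal sketch of the equivalence (variable names and sizes here are illustrative, not from the diff):

    import tensorflow as tf

    # Old pattern: flatten a channels-first tensor with a reshape plus a transpose.
    def flatten_nchw(x: tf.Tensor) -> tf.Tensor:
        batch_size, num_channels, height, width = x.shape
        x = tf.reshape(x, (batch_size, num_channels, height * width))
        return tf.transpose(x, perm=(0, 2, 1))  # one extra permutation per call site

    # New pattern: stay channels-last, so flattening is a single reshape.
    def flatten_nhwc(x: tf.Tensor) -> tf.Tensor:
        batch_size, height, width, num_channels = x.shape
        return tf.reshape(x, (batch_size, height * width, num_channels))

    x_nchw = tf.random.uniform((2, 64, 14, 14))
    x_nhwc = tf.transpose(x_nchw, perm=(0, 2, 3, 1))  # done once, at the boundary
    tf.debugging.assert_near(flatten_nchw(x_nchw), flatten_nhwc(x_nhwc))

Both paths produce the same (batch_size, height*width, num_channels) tokens; the second simply avoids repeating the permutation in every layer.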
81 changes: 32 additions & 49 deletions src/transformers/models/cvt/modeling_tf_cvt.py
@@ -133,7 +133,7 @@ class TFCvtConvEmbeddings(tf.keras.layers.Layer):

def __init__(self, config: CvtConfig, patch_size: int, embed_dim: int, stride: int, padding: int, **kwargs):
super().__init__(**kwargs)
- self.pad_value = padding
+ self.padding = tf.keras.layers.ZeroPadding2D(padding=padding)
self.patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
self.projection = tf.keras.layers.Conv2D(
filters=embed_dim,
@@ -147,31 +147,20 @@ def __init__(self, config: CvtConfig, patch_size: int, embed_dim: int, stride: i
# Using the same default epsilon as PyTorch
self.normalization = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="normalization")

- def convolution(self, hidden_state: tf.Tensor) -> tf.Tensor:
-     # Custom padding to match the model implementation in PyTorch
-     height_pad = width_pad = (self.pad_value, self.pad_value)
-     hidden_state = tf.pad(hidden_state, [(0, 0), height_pad, width_pad, (0, 0)])
-     hidden_state = self.projection(hidden_state)
-     return hidden_state

def call(self, pixel_values: tf.Tensor) -> tf.Tensor:
if isinstance(pixel_values, dict):
pixel_values = pixel_values["pixel_values"]

- # When running on CPU, `tf.keras.layers.Conv2D` doesn't support (batch_size, num_channels, height, width)
- # as input format. So change the input format to (batch_size, height, width, num_channels).
- pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))
- pixel_values = self.convolution(pixel_values)
+ pixel_values = self.projection(self.padding(pixel_values))

# rearrange "batch_size, height, width, num_channels -> batch_size, (height, width), num_channels"
# "batch_size, height, width, num_channels -> batch_size, (height*width), num_channels"
batch_size, height, width, num_channels = shape_list(pixel_values)
hidden_size = height * width
pixel_values = tf.reshape(pixel_values, shape=(batch_size, hidden_size, num_channels))
pixel_values = self.normalization(pixel_values)

# rearrange "batch_size, (height, width), num_channels -> batch_size, num_channels, height, width"
pixel_values = tf.transpose(pixel_values, perm=(0, 2, 1))
pixel_values = tf.reshape(pixel_values, shape=(batch_size, num_channels, height, width))
# "batch_size, (height*width), num_channels -> batch_size, height, width, num_channels"
pixel_values = tf.reshape(pixel_values, shape=(batch_size, height, width, num_channels))
return pixel_values


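For reference, the `ZeroPadding2D` layer that replaces the custom `tf.pad` call above is numerically identical to it for symmetric spatial padding; a quick standalone check (padding amount and shapes invented for the example):

    import tensorflow as tf

    pad = 2
    x = tf.random.uniform((1, 14, 14, 64))  # channels-last, as in the refactored code

    manual = tf.pad(x, [(0, 0), (pad, pad), (pad, pad), (0, 0)])  # old custom padding
    layered = tf.keras.layers.ZeroPadding2D(padding=pad)(x)       # new dedicated layer

    tf.debugging.assert_equal(manual, layered)  # same zeros, same (1, 18, 18, 64) shape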
@@ -180,8 +169,8 @@ class TFCvtSelfAttentionConvProjection(tf.keras.layers.Layer):

def __init__(self, config: CvtConfig, embed_dim: int, kernel_size: int, stride: int, padding: int, **kwargs):
super().__init__(**kwargs)
- self.pad_value = padding
- self.conv = tf.keras.layers.Conv2D(
+ self.padding = tf.keras.layers.ZeroPadding2D(padding=padding)
+ self.convolution = tf.keras.layers.Conv2D(
filters=embed_dim,
kernel_size=kernel_size,
kernel_initializer=get_initializer(config.initializer_range),
@@ -194,15 +183,8 @@ def __init__(self, config: CvtConfig, embed_dim: int, kernel_size: int, stride:
# Using the same default epsilon & momentum as PyTorch
self.normalization = tf.keras.layers.BatchNormalization(epsilon=1e-5, momentum=0.1, name="normalization")

- def convolution(self, hidden_state: tf.Tensor) -> tf.Tensor:
-     # Custom padding to match the model implementation in PyTorch
-     height_pad = width_pad = (self.pad_value, self.pad_value)
-     hidden_state = tf.pad(hidden_state, [(0, 0), height_pad, width_pad, (0, 0)])
-     hidden_state = self.conv(hidden_state)
-     return hidden_state

def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor:
- hidden_state = self.convolution(hidden_state)
+ hidden_state = self.convolution(self.padding(hidden_state))
hidden_state = self.normalization(hidden_state, training=training)
return hidden_state

@@ -211,7 +193,7 @@ class TFCvtSelfAttentionLinearProjection(tf.keras.layers.Layer):
"""Linear projection layer used to flatten tokens into 1D."""

def call(self, hidden_state: tf.Tensor) -> tf.Tensor:
# rearrange "batch_size, height, width, num_channels -> batch_size, (height, width), num_channels"
# "batch_size, height, width, num_channels -> batch_size, (height*width), num_channels"
batch_size, height, width, num_channels = shape_list(hidden_state)
hidden_size = height * width
hidden_state = tf.reshape(hidden_state, shape=(batch_size, hidden_size, num_channels))
@@ -239,9 +221,6 @@ def __init__(
self.linear_projection = TFCvtSelfAttentionLinearProjection()

def call(self, hidden_state: tf.Tensor, training: bool = False) -> tf.Tensor:
- # When running on CPU, `tf.keras.layers.Conv2D` doesn't support (batch_size, num_channels, height, width)
- # as input format. So change the input format to (batch_size, height, width, num_channels).
- hidden_state = tf.transpose(hidden_state, perm=(0, 2, 3, 1))
hidden_state = self.convolution_projection(hidden_state, training=training)
hidden_state = self.linear_projection(hidden_state)
return hidden_state
@@ -337,10 +316,9 @@ def call(self, hidden_state: tf.Tensor, height: int, width: int, training: bool
if self.with_cls_token:
cls_token, hidden_state = tf.split(hidden_state, [1, height * width], 1)

# rearrange "batch_size, (height, width), num_channels -> batch_size, num_channels, height, width"
# "batch_size, (height*width), num_channels -> batch_size, height, width, num_channels"
batch_size, hidden_size, num_channels = shape_list(hidden_state)
- hidden_state = tf.transpose(hidden_state, perm=(0, 2, 1))
- hidden_state = tf.reshape(hidden_state, shape=(batch_size, num_channels, height, width))
+ hidden_state = tf.reshape(hidden_state, shape=(batch_size, height, width, num_channels))

key = self.convolution_projection_key(hidden_state, training=training)
query = self.convolution_projection_query(hidden_state, training=training)
@@ -352,16 +330,17 @@ def call(self, hidden_state: tf.Tensor, height: int, width: int, training: bool
value = tf.concat((cls_token, value), axis=1)

head_dim = self.embed_dim // self.num_heads

query = self.rearrange_for_multi_head_attention(self.projection_query(query))
key = self.rearrange_for_multi_head_attention(self.projection_key(key))
value = self.rearrange_for_multi_head_attention(self.projection_value(value))

attention_score = tf.matmul(query, key, transpose_b=True) * self.scale
attention_probs = stable_softmax(logits=attention_score, axis=-1)
attention_probs = self.dropout(attention_probs, training=training)
- context = tf.matmul(attention_probs, value)

- # rearrange "batch_size, num_heads, hidden_size, head_dim -> batch_size, hidden_size, (num_heads, head_dim)"
+ context = tf.matmul(attention_probs, value)
+ # "batch_size, num_heads, hidden_size, head_dim -> batch_size, hidden_size, (num_heads*head_dim)"
_, _, hidden_size, _ = shape_list(context)
context = tf.transpose(context, perm=(0, 2, 1, 3))
context = tf.reshape(context, (batch_size, hidden_size, self.num_heads * head_dim))
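The head merge in this hunk is the inverse of the usual head split, and the transpose-then-reshape order matters; a standalone roundtrip check (dimensions invented for the example):

    import tensorflow as tf

    batch_size, num_heads, hidden_size, head_dim = 2, 4, 196, 16
    context = tf.random.uniform((batch_size, num_heads, hidden_size, head_dim))

    # "batch_size, num_heads, hidden_size, head_dim -> batch_size, hidden_size, (num_heads*head_dim)"
    merged = tf.transpose(context, perm=(0, 2, 1, 3))
    merged = tf.reshape(merged, (batch_size, hidden_size, num_heads * head_dim))

    # Splitting back recovers the original tensor, so the rearrange is lossless.
    split = tf.reshape(merged, (batch_size, hidden_size, num_heads, head_dim))
    split = tf.transpose(split, perm=(0, 2, 1, 3))
    tf.debugging.assert_equal(context, split)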
@@ -601,11 +580,10 @@ def call(self, hidden_state: tf.Tensor, training: bool = False):
cls_token = None
hidden_state = self.embedding(hidden_state, training)

- batch_size, num_channels, height, width = shape_list(hidden_state)
- # rearrange "batch_size, num_channels, height, width -> batch_size, (height, width), num_channels"
+ # "batch_size, height, width, num_channels -> batch_size, (height*width), num_channels"
+ batch_size, height, width, num_channels = shape_list(hidden_state)
hidden_size = height * width
- hidden_state = tf.reshape(hidden_state, shape=(batch_size, num_channels, hidden_size))
- hidden_state = tf.transpose(hidden_state, perm=(0, 2, 1))
+ hidden_state = tf.reshape(hidden_state, shape=(batch_size, hidden_size, num_channels))

if self.config.cls_token[self.stage]:
cls_token = tf.repeat(self.cls_token, repeats=batch_size, axis=0)
@@ -618,9 +596,8 @@ def call(self, hidden_state: tf.Tensor, training: bool = False):
if self.config.cls_token[self.stage]:
cls_token, hidden_state = tf.split(hidden_state, [1, height * width], 1)

# rearrange -> "batch_size, (height, width), num_channels -> batch_size, num_channels, height, width"
hidden_state = tf.transpose(hidden_state, (0, 2, 1))
hidden_state = tf.reshape(hidden_state, shape=(batch_size, num_channels, height, width))
# "batch_size, (height*width), num_channels -> batch_size, height, width, num_channels"
hidden_state = tf.reshape(hidden_state, shape=(batch_size, height, width, num_channels))
return hidden_state, cls_token


@@ -651,13 +628,21 @@ def call(
) -> Union[TFBaseModelOutputWithCLSToken, Tuple[tf.Tensor]]:
all_hidden_states = () if output_hidden_states else None
hidden_state = pixel_values
+ # When running on CPU, `tf.keras.layers.Conv2D` doesn't support (batch_size, num_channels, height, width)
+ # as input format. So change the input format to (batch_size, height, width, num_channels).
+ hidden_state = tf.transpose(hidden_state, perm=(0, 2, 3, 1))

cls_token = None
for _, (stage_module) in enumerate(self.stages):
hidden_state, cls_token = stage_module(hidden_state, training=training)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_state,)

+ # Change back to (batch_size, num_channels, height, width) format to have uniformity in the modules
+ hidden_state = tf.transpose(hidden_state, perm=(0, 3, 1, 2))
+ if output_hidden_states:
+     all_hidden_states = tuple([tf.transpose(hs, perm=(0, 3, 1, 2)) for hs in all_hidden_states])

if not return_dict:
return tuple(v for v in [hidden_state, cls_token, all_hidden_states] if v is not None)

@@ -727,9 +712,7 @@ def dummy_inputs(self) -> Dict[str, tf.Tensor]:
Returns:
`Dict[str, tf.Tensor]`: The dummy inputs.
"""
- VISION_DUMMY_INPUTS = tf.random.uniform(
-     shape=(3, self.config.num_channels, self.config.image_size, self.config.image_size), dtype=tf.float32
- )
+ VISION_DUMMY_INPUTS = tf.random.uniform(shape=(3, self.config.num_channels, 224, 224), dtype=tf.float32)
return {"pixel_values": tf.constant(VISION_DUMMY_INPUTS)}

@tf.function(
@@ -880,7 +863,7 @@ def __init__(self, config: CvtConfig, *inputs, **kwargs):
self.num_labels = config.num_labels
self.cvt = TFCvtMainLayer(config, name="cvt")
# Using same default epsilon as in the original implementation.
- self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm")
+ self.layernorm = tf.keras.layers.LayerNormalization(epsilon=1e-5, name="layernorm")

# Classifier head
self.classifier = tf.keras.layers.Dense(
@@ -942,13 +925,13 @@ def call(
sequence_output = outputs[0]
cls_token = outputs[1]
if self.config.cls_token[-1]:
- sequence_output = self.LayerNorm(cls_token)
+ sequence_output = self.layernorm(cls_token)
else:
# rearrange "batch_size, num_channels, height, width -> batch_size, (height, width), num_channels"
# rearrange "batch_size, num_channels, height, width -> batch_size, (height*width), num_channels"
batch_size, num_channels, height, width = shape_list(sequence_output)
sequence_output = tf.reshape(sequence_output, shape=(batch_size, num_channels, height * width))
sequence_output = tf.transpose(sequence_output, perm=(0, 2, 1))
- sequence_output = self.LayerNorm(sequence_output)
+ sequence_output = self.layernorm(sequence_output)

sequence_output_mean = tf.reduce_mean(sequence_output, axis=1)
logits = self.classifier(sequence_output_mean)
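The encoder hunk above is what keeps the public contract intact: pixel inputs and hidden states are still exposed as (batch_size, num_channels, height, width), even though everything in between now runs channels-last. A sketch of that boundary pattern, with the stages elided (function name and shapes are illustrative):

    import tensorflow as tf

    def encoder_boundary(pixel_values: tf.Tensor) -> tf.Tensor:
        # Compute in channels-last internally, expose channels-first externally.
        hidden_state = tf.transpose(pixel_values, perm=(0, 2, 3, 1))  # NCHW -> NHWC, once
        # ... stages would run here in NHWC, reshaping but never permuting ...
        return tf.transpose(hidden_state, perm=(0, 3, 1, 2))  # NHWC -> NCHW, once

    x = tf.random.uniform((2, 3, 224, 224))
    tf.debugging.assert_near(encoder_boundary(x), x)  # identity here because the stages are elided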
