Added interpolation for ViTMAE model in PyTorch as well as TF #30732

Merged
164 changes: 133 additions & 31 deletions src/transformers/models/vit_mae/modeling_tf_vit_mae.py
@@ -241,6 +241,38 @@ def build(self, input_shape=None):
with tf.name_scope(self.patch_embeddings.name):
self.patch_embeddings.build(None)

def interpolate_pos_encoding(self, embeddings, height, width) -> tf.Tensor:
"""
Interpolates the pre-trained position encodings so that the model can be used on higher-resolution
images.

Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""

batch_size, seq_len, dim = shape_list(embeddings)
num_patches = seq_len - 1

_, num_positions, _ = shape_list(self.position_embeddings)
num_positions -= 1

if num_patches == num_positions and height == width:
return self.position_embeddings
class_pos_embed = self.position_embeddings[:, :1]
patch_pos_embed = self.position_embeddings[:, 1:]
h0 = height // self.config.patch_size
w0 = width // self.config.patch_size
patch_pos_embed = tf.image.resize(
images=tf.reshape(
patch_pos_embed, shape=(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
),
size=(h0, w0),
method="bicubic",
)

patch_pos_embed = tf.reshape(tensor=patch_pos_embed, shape=(1, -1, dim))
return tf.concat(values=(class_pos_embed, patch_pos_embed), axis=1)
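For intuition, here is a minimal standalone sketch of the resize step above, under assumed sizes (a checkpoint pretrained at 224×224 with `patch_size=16`, i.e. a 14×14 grid of position embeddings, interpolated for a 384×384 input; `hidden_size=768` is an assumption, not read from any config):

```python
import math

import tensorflow as tf

hidden_size = 768                               # assumed ViT-Base width
num_positions = 14 * 14                         # pretrained grid: 224 // 16 = 14
patch_pos_embed = tf.random.normal((1, num_positions, hidden_size))

h0, w0 = 384 // 16, 384 // 16                   # target grid for a 384x384 input
grid = int(math.sqrt(num_positions))

# Reshape the flat sequence of embeddings back into its 2D grid, resize it
# bicubically to the new grid, then flatten it again.
resized = tf.image.resize(
    tf.reshape(patch_pos_embed, (1, grid, grid, hidden_size)),
    size=(h0, w0),
    method="bicubic",
)
resized = tf.reshape(resized, (1, -1, hidden_size))
print(resized.shape)                            # (1, 576, 768)
```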

def random_masking(self, sequence: tf.Tensor, noise: tf.Tensor | None = None):
"""
Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random
@@ -282,17 +314,23 @@ def random_masking(self, sequence: tf.Tensor, noise: tf.Tensor | None = None):

return sequence_unmasked, mask, ids_restore
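The argsort-based shuffling that `random_masking` relies on can be seen in isolation in this sketch (a simplified stand-in, with `mask_ratio=0.75` assumed rather than taken from the config):

```python
import tensorflow as tf

def random_masking_sketch(sequence: tf.Tensor, mask_ratio: float = 0.75):
    batch_size, seq_len, _ = sequence.shape
    len_keep = int(seq_len * (1 - mask_ratio))

    noise = tf.random.uniform((batch_size, seq_len))      # noise in [0, 1)
    ids_shuffle = tf.argsort(noise, axis=1)               # ascending: low noise is kept
    ids_restore = tf.argsort(ids_shuffle, axis=1)         # inverse permutation

    ids_keep = ids_shuffle[:, :len_keep]
    sequence_unmasked = tf.gather(sequence, ids_keep, axis=1, batch_dims=1)

    # Binary mask in the original token order: 0 = kept, 1 = masked.
    mask = tf.concat(
        [tf.zeros((batch_size, len_keep)), tf.ones((batch_size, seq_len - len_keep))], axis=1
    )
    mask = tf.gather(mask, ids_restore, axis=1, batch_dims=1)
    return sequence_unmasked, mask, ids_restore

tokens = tf.random.normal((2, 196, 768))
kept, mask, ids_restore = random_masking_sketch(tokens)
print(kept.shape, mask.shape)                             # (2, 49, 768) (2, 196)
```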

def call(self, pixel_values: tf.Tensor, noise: tf.Tensor = None) -> tf.Tensor:
embeddings = self.patch_embeddings(pixel_values)

def call(
self, pixel_values: tf.Tensor, noise: tf.Tensor = None, interpolate_pos_encoding: bool = False
) -> tf.Tensor:
batch_size, num_channels, height, width = shape_list(pixel_values)
embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
if interpolate_pos_encoding:
position_embeddings = self.interpolate_pos_encoding(embeddings, height, width)
else:
position_embeddings = self.position_embeddings
# add position embeddings w/o cls token
embeddings = embeddings + self.position_embeddings[:, 1:, :]
embeddings = embeddings + position_embeddings[:, 1:, :]

# masking: length -> length * config.mask_ratio
embeddings, mask, ids_restore = self.random_masking(embeddings, noise)

# append cls token
cls_token = self.cls_token + self.position_embeddings[:, :1, :]
cls_token = self.cls_token + position_embeddings[:, :1, :]
cls_tokens = tf.tile(cls_token, (shape_list(embeddings)[0], 1, 1))
embeddings = tf.concat([cls_tokens, embeddings], axis=1)

@@ -330,15 +368,17 @@ def __init__(self, config: ViTMAEConfig, **kwargs):
name="projection",
)

def call(self, pixel_values: tf.Tensor, training: bool = False) -> tf.Tensor:
def call(
self, pixel_values: tf.Tensor, training: bool = False, interpolate_pos_encoding: bool = False
) -> tf.Tensor:
batch_size, num_channels, height, width = shape_list(pixel_values)
if tf.executing_eagerly():
if num_channels != self.num_channels:
raise ValueError(
"Make sure that the channel dimension of the pixel values match with the one set in the"
" configuration."
)
if height != self.image_size[0] or width != self.image_size[1]:
if not interpolate_pos_encoding and (height != self.image_size[0] or width != self.image_size[1]):
raise ValueError(
f"Input image size ({height}*{width}) doesn't match model"
f" ({self.image_size[0]}*{self.image_size[1]})."
@@ -742,9 +782,13 @@ def call(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
interpolate_pos_encoding: bool = False,
) -> Union[TFViTMAEModelOutput, Tuple[tf.Tensor]]:
embedding_output, mask, ids_restore = self.embeddings(
pixel_values=pixel_values, training=training, noise=noise
pixel_values=pixel_values,
training=training,
noise=noise,
interpolate_pos_encoding=interpolate_pos_encoding,
)

# Prepare head mask if needed
@@ -875,6 +919,9 @@ class TFViTMAEPreTrainedModel(TFPreTrainedModel):
training (`bool`, *optional*, defaults to `False`):
Whether or not to use the model in training mode (some modules like dropout modules have different
behaviors between training and evaluation).

interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether to interpolate the position encodings at the encoder and decoder.
"""


@@ -903,6 +950,7 @@ def call(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
interpolate_pos_encoding: bool = False,
) -> Union[TFViTMAEModelOutput, Tuple[tf.Tensor]]:
r"""
Returns:
@@ -932,6 +980,7 @@ def call(
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
interpolate_pos_encoding=interpolate_pos_encoding,
)

return outputs
@@ -1005,17 +1054,50 @@ def build(self, input_shape=None):
with tf.name_scope(layer.name):
layer.build(None)

def interpolate_pos_encoding(self, embeddings) -> tf.Tensor:
"""
This method is a modified version of the encoder's interpolation function, adapted for the ViT-MAE
decoder: it interpolates the pre-trained decoder position encodings so that the model can be used on
higher-resolution images.

Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""

# [batch_size, num_patches + 1, hidden_size]
_, num_positions, dim = shape_list(self.decoder_pos_embed)

# -1 removes the class token, since it is appended back later without interpolation
seq_len = shape_list(embeddings)[1] - 1
num_positions = num_positions - 1

# Separation of class token and patch tokens
class_pos_embed = self.decoder_pos_embed[:, :1, :]
patch_pos_embed = self.decoder_pos_embed[:, 1:, :]

# interpolate the position embeddings
patch_pos_embed = tf.image.resize(
images=tf.reshape(patch_pos_embed, shape=(1, 1, -1, dim)),
size=(1, seq_len),
method="bicubic",
)

# [1, seq_len, hidden_size]
patch_pos_embed = tf.reshape(tensor=patch_pos_embed, shape=(1, -1, dim))
# Adding the class token back
return tf.concat(values=(class_pos_embed, patch_pos_embed), axis=1)
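Note that, unlike the encoder's 2D grid resize, the decoder variant above stretches the embeddings along the sequence axis as a 1D signal. A sketch with assumed sizes (`hidden_size=512` for the decoder and a 576-patch target are illustrative, not read from any config):

```python
import tensorflow as tf

hidden_size = 512                                  # assumed decoder width
num_positions = 196                                # pretrained patch positions (no cls)
patch_pos_embed = tf.random.normal((1, num_positions, hidden_size))

seq_len = 576                                      # patches in the higher-resolution input

# The sequence is treated as a 1 x num_positions "image" and stretched along
# its width to the new sequence length, rather than resized on a 2D grid.
resized = tf.image.resize(
    tf.reshape(patch_pos_embed, (1, 1, -1, hidden_size)),
    size=(1, seq_len),
    method="bicubic",
)
resized = tf.reshape(resized, (1, -1, hidden_size))
print(resized.shape)                               # (1, 576, 512)
```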

def call(
self,
hidden_states,
ids_restore,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
interpolate_pos_encoding=False,
):
# embed tokens
x = self.decoder_embed(hidden_states)

# append mask tokens to sequence
mask_tokens = tf.tile(
self.mask_token,
@@ -1024,10 +1106,12 @@
x_ = tf.concat([x[:, 1:, :], mask_tokens], axis=1) # no cls token
x_ = tf.gather(x_, axis=1, batch_dims=1, indices=ids_restore) # unshuffle
x = tf.concat([x[:, :1, :], x_], axis=1) # append cls token

if interpolate_pos_encoding:
decoder_pos_embed = self.interpolate_pos_encoding(x)
else:
decoder_pos_embed = self.decoder_pos_embed
# add pos embed
hidden_states = x + self.decoder_pos_embed

hidden_states = x + decoder_pos_embed
# apply Transformer layers (blocks)
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
@@ -1084,11 +1168,13 @@ def get_input_embeddings(self):
def _prune_heads(self, heads_to_prune):
raise NotImplementedError

def patchify(self, pixel_values):
def patchify(self, pixel_values, interpolate_pos_encoding: bool = False):
"""
Args:
pixel_values (`tf.Tensor` of shape `(batch_size, height, width, num_channels)` or `(batch_size, num_channels, height, width)`):
Pixel values.
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether position encodings were interpolated in the forward pass. If `True`, the squared-size
sanity check is skipped so non-square inputs are supported.

Returns:
`tf.Tensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
@@ -1100,11 +1186,12 @@ def patchify(self, pixel_values):
pixel_values = tf.transpose(pixel_values, perm=(0, 2, 3, 1))

# sanity checks
tf.debugging.assert_equal(
shape_list(pixel_values)[1],
shape_list(pixel_values)[2],
message="Make sure the pixel values have a squared size",
)
if not interpolate_pos_encoding:
tf.debugging.assert_equal(
shape_list(pixel_values)[1],
shape_list(pixel_values)[2],
message="Make sure the pixel values have a squared size",
)
tf.debugging.assert_equal(
shape_list(pixel_values)[1] % patch_size,
0,
@@ -1120,51 +1207,61 @@

# patchify
batch_size = shape_list(pixel_values)[0]
num_patches_one_direction = shape_list(pixel_values)[2] // patch_size
num_patches_h = shape_list(pixel_values)[1] // patch_size
num_patches_w = shape_list(pixel_values)[2] // patch_size
patchified_pixel_values = tf.reshape(
pixel_values,
(batch_size, num_patches_one_direction, patch_size, num_patches_one_direction, patch_size, num_channels),
(batch_size, num_patches_h, patch_size, num_patches_w, patch_size, num_channels),
)
patchified_pixel_values = tf.einsum("nhpwqc->nhwpqc", patchified_pixel_values)
patchified_pixel_values = tf.reshape(
patchified_pixel_values,
(batch_size, num_patches_one_direction * num_patches_one_direction, patch_size**2 * num_channels),
(batch_size, num_patches_h * num_patches_w, patch_size**2 * num_channels),
)
return patchified_pixel_values
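To see the new h/w split in action, here is a small sketch of the same reshape/einsum pipeline on a hypothetical non-square input (sizes are illustrative):

```python
import tensorflow as tf

# Hypothetical non-square input enabled by the h/w split above:
# a 224 x 384 image with patch_size 16 gives 14 x 24 = 336 patches.
pixel_values = tf.random.normal((2, 224, 384, 3))   # channels-last
patch_size = 16
batch_size, height, width, num_channels = pixel_values.shape
num_patches_h, num_patches_w = height // patch_size, width // patch_size

patches = tf.reshape(
    pixel_values,
    (batch_size, num_patches_h, patch_size, num_patches_w, patch_size, num_channels),
)
patches = tf.einsum("nhpwqc->nhwpqc", patches)
patches = tf.reshape(
    patches, (batch_size, num_patches_h * num_patches_w, patch_size**2 * num_channels)
)
print(patches.shape)                                 # (2, 336, 768)
```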

def unpatchify(self, patchified_pixel_values):
def unpatchify(self, patchified_pixel_values, original_image_size: Optional[Tuple[int, int]] = None):
"""
Args:
patchified_pixel_values (`tf.Tensor` of shape `(batch_size, num_patches, patch_size**2 * num_channels)`:
Patchified pixel values.
original_image_size (`Tuple[int, int]`, *optional*, defaults to `None`):
Original (height, width) of the image. When `None`, the square `image_size` from the model
configuration is used.

Returns:
`tf.Tensor` of shape `(batch_size, height, width, num_channels)`:
Pixel values.
"""
patch_size, num_channels = self.config.patch_size, self.config.num_channels
num_patches_one_direction = int(shape_list(patchified_pixel_values)[1] ** 0.5)
original_image_size = (
original_image_size
if original_image_size is not None
else (self.config.image_size, self.config.image_size)
)
original_height, original_width = original_image_size
num_patches_h = original_height // patch_size
num_patches_w = original_width // patch_size
# sanity check
tf.debugging.assert_equal(
num_patches_one_direction * num_patches_one_direction,
num_patches_h * num_patches_w,
shape_list(patchified_pixel_values)[1],
message="Make sure that the number of patches can be squared",
message=f"The number of patches in the patchified pixel values is {shape_list(patchified_pixel_values)[1]} does not match the patches of original image {num_patches_w}*{num_patches_h}",
)

# unpatchify
batch_size = shape_list(patchified_pixel_values)[0]
patchified_pixel_values = tf.reshape(
patchified_pixel_values,
(batch_size, num_patches_one_direction, num_patches_one_direction, patch_size, patch_size, num_channels),
(batch_size, num_patches_h, num_patches_w, patch_size, patch_size, num_channels),
)
patchified_pixel_values = tf.einsum("nhwpqc->nhpwqc", patchified_pixel_values)
pixel_values = tf.reshape(
patchified_pixel_values,
(batch_size, num_patches_one_direction * patch_size, num_patches_one_direction * patch_size, num_channels),
(batch_size, num_patches_h * patch_size, num_patches_w * patch_size, num_channels),
)
return pixel_values
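Assuming `model` is a built `TFViTMAEForPreTraining` instance and `pixel_values` the non-square batch from the patchify sketch above, the two methods should round-trip exactly, since they are pure reshapes and transposes:

```python
# Round-trip sanity check (sketch): unpatchify(patchify(x)) reproduces the
# input when original_image_size matches the actual input size.
patches = model.patchify(pixel_values, interpolate_pos_encoding=True)
recovered = model.unpatchify(patches, original_image_size=(224, 384))
tf.debugging.assert_equal(recovered, pixel_values)
```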

def forward_loss(self, pixel_values, pred, mask):
def forward_loss(self, pixel_values, pred, mask, interpolate_pos_encoding: bool = False):
"""
Args:
pixel_values (`tf.Tensor` of shape `(batch_size, height, width, num_channels)`):
Expand All @@ -1173,11 +1270,13 @@ def forward_loss(self, pixel_values, pred, mask):
Predicted pixel values.
mask (`tf.Tensor` of shape `(batch_size, sequence_length)`):
Tensor indicating which patches are masked (1) and which are not (0).
interpolate_pos_encoding (`bool`, *optional*, defaults to `False`):
Whether position encodings were interpolated in the forward pass; forwarded to `patchify` so
that non-square inputs are handled.

Returns:
`tf.Tensor`: Pixel reconstruction loss.
"""
target = self.patchify(pixel_values)
target = self.patchify(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
if self.config.norm_pix_loss:
mean = tf.reduce_mean(target, axis=-1, keepdims=True)
var = tf.math.reduce_variance(target, axis=-1, keepdims=True)
Expand All @@ -1202,6 +1301,7 @@ def call(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
training: bool = False,
interpolate_pos_encoding: bool = False,
) -> Union[TFViTMAEForPreTrainingOutput, Tuple[tf.Tensor]]:
r"""
Returns:
@@ -1235,16 +1335,18 @@ def call(
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training,
interpolate_pos_encoding=interpolate_pos_encoding,
)

latent = outputs.last_hidden_state
ids_restore = outputs.ids_restore
mask = outputs.mask

decoder_outputs = self.decoder(latent, ids_restore) # [batch_size, num_patches, patch_size**2*3]
# [batch_size, num_patches, patch_size**2*3]
decoder_outputs = self.decoder(latent, ids_restore, interpolate_pos_encoding=interpolate_pos_encoding)
logits = decoder_outputs.logits

loss = self.forward_loss(pixel_values, logits, mask)
loss = self.forward_loss(pixel_values, logits, mask, interpolate_pos_encoding=interpolate_pos_encoding)

if not return_dict:
output = (logits, mask, ids_restore) + outputs[2:]
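Taken together, the new flag can be exercised end to end roughly like this (a sketch: the input size is illustrative, and the checkpoint is assumed to be the standard `facebook/vit-mae-base` pretrained at 224×224):

```python
import tensorflow as tf
from transformers import TFViTMAEForPreTraining

model = TFViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")

# A 384x384 input against a checkpoint pretrained at 224x224 (channels-first).
pixel_values = tf.random.normal((1, 3, 384, 384))
outputs = model(pixel_values, interpolate_pos_encoding=True)

# 24 x 24 = 576 patches, each reconstructed as 16*16*3 = 768 values.
print(outputs.logits.shape)  # (1, 576, 768)
print(outputs.mask.shape)    # (1, 576)
```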