diff --git a/docs/source/model_doc/swin.mdx b/docs/source/model_doc/swin.mdx index c1fd4e86d3a4..1f247284afbc 100644 --- a/docs/source/model_doc/swin.mdx +++ b/docs/source/model_doc/swin.mdx @@ -34,6 +34,8 @@ The hierarchical design and the shifted window approach also prove beneficial fo Tips: - One can use the [`AutoFeatureExtractor`] API to prepare images for the model. +- Swin pads the inputs supporting any input height and width (if divisible by `32`). +- Swin can be used as a *backbone*. When `output_hidden_states = True`, it will output both `hidden_states` and `reshaped_hidden_states`. The `reshaped_hidden_states` have a shape of `(batch, num_channels, height, width)` rather than `(batch_size, sequence_length, num_channels)`. drawing diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index bdfc66b0dc00..45bf23d3cb65 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -17,6 +17,8 @@ import collections.abc import math +from dataclasses import dataclass +from typing import Optional, Tuple import torch import torch.utils.checkpoint @@ -25,12 +27,12 @@ from ...activations import ACT2FN from ...file_utils import ( + ModelOutput, add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings, ) -from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, MaskedLMOutput, SequenceClassifierOutput from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import logging from .configuration_swin import SwinConfig @@ -56,10 +58,150 @@ # See all Swin models at https://huggingface.co/models?filter=swin ] - # to_2tuple, drop_path, SwinPatchEmbeddings, SwinPatchMerging and SwinDropPath are from the timm library. +@dataclass +class SwinEncoderOutput(ModelOutput): + """ + Swin encoder's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, hidden_size, height, width)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to + include the spatial dimensions. + """ + + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class SwinModelOutput(ModelOutput): + """ + Swin model's outputs that also contains a pooling of the last hidden states. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): + Average pooling of the last layer hidden-state. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, hidden_size, height, width)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to + include the spatial dimensions. + """ + + last_hidden_state: torch.FloatTensor = None + pooler_output: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class SwinMaskedImageModelingOutput(ModelOutput): + """ + Swin masked image model outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided): + Masked image modeling (MLM) loss. + logits (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Reconstructed pixel values. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, hidden_size, height, width)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to + include the spatial dimensions. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class SwinImageClassifierOutput(ModelOutput): + """ + Swin outputs for image classification. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, hidden_size, height, width)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to + include the spatial dimensions. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + # Copied from transformers.models.vit.modeling_vit.to_2tuple def to_2tuple(x): if isinstance(x, collections.abc.Iterable): @@ -130,7 +272,7 @@ def __init__(self, config, use_mask_token=False): self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, pixel_values, bool_masked_pos=None): - embeddings = self.patch_embeddings(pixel_values) + embeddings, output_dimensions = self.patch_embeddings(pixel_values) embeddings = self.norm(embeddings) batch_size, seq_len, _ = embeddings.size() @@ -145,7 +287,7 @@ def forward(self, pixel_values, bool_masked_pos=None): embeddings = self.dropout(embeddings) - return embeddings + return embeddings, output_dimensions class SwinPatchEmbeddings(nn.Module): @@ -165,9 +307,25 @@ def __init__(self, image_size=224, patch_size=16, num_channels=3, embed_dim=768) self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size) + def maybe_pad(self, pixel_values, height, width): + if width % self.patch_size[1] != 0: + pad_values = (0, self.patch_size[1] - width % self.patch_size[1]) + pixel_values = nn.functional.pad(pixel_values, pad_values) + if height % self.patch_size[0] != 0: + pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0]) + pixel_values = nn.functional.pad(pixel_values, pad_values) + return pixel_values + def forward(self, pixel_values): - embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2) - return embeddings + _, _, height, width = pixel_values.shape + # pad the input to be divisible by self.patch_size, if needed + pixel_values = self.maybe_pad(pixel_values, height, width) + embeddings = self.projection(pixel_values) + _, _, height, width = embeddings.shape + output_dimensions = (height, width) + embeddings = embeddings.flatten(2).transpose(1, 2) + + return embeddings, output_dimensions class SwinPatchMerging(nn.Module): @@ -190,17 +348,30 @@ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) self.norm = norm_layer(4 * dim) - def forward(self, input_feature): - height, width = self.input_resolution + def maybe_pad(self, input_feature, height, width): + should_pad = (height % 2 == 1) or (width % 2 == 1) + if should_pad: + pad_values = (0, 0, 0, width % 2, 0, height % 2) + input_feature = nn.functional.pad(input_feature, pad_values) + + return input_feature + + def forward(self, input_feature, input_dimensions): + height, width = input_dimensions # `dim` is height * width batch_size, dim, num_channels = input_feature.shape input_feature = input_feature.view(batch_size, height, width, num_channels) - - input_feature_0 = input_feature[:, 0::2, 0::2, :] # batch_size height/2 width/2 num_channels - input_feature_1 = input_feature[:, 1::2, 0::2, :] # batch_size height/2 width/2 num_channels - input_feature_2 = input_feature[:, 0::2, 1::2, :] # batch_size height/2 width/2 num_channels - input_feature_3 = input_feature[:, 1::2, 1::2, :] # batch_size height/2 width/2 num_channels + # pad input to be disible by width and height, if needed + input_feature = self.maybe_pad(input_feature, height, width) + # [batch_size, height/2, width/2, num_channels] + input_feature_0 = input_feature[:, 0::2, 0::2, :] + # [batch_size, height/2, width/2, num_channels] + input_feature_1 = input_feature[:, 1::2, 0::2, :] + # [batch_size, height/2, width/2, num_channels] + input_feature_2 = input_feature[:, 0::2, 1::2, :] + # [batch_size, height/2, width/2, num_channels] + input_feature_3 = input_feature[:, 1::2, 1::2, :] # batch_size height/2 width/2 4*num_channels input_feature = torch.cat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1) input_feature = input_feature.view(batch_size, -1, 4 * num_channels) # batch_size height/2*width/2 4*C @@ -393,19 +564,14 @@ def forward(self, hidden_states): return hidden_states -class SwinBlock(nn.Module): +class SwinLayer(nn.Module): def __init__(self, config, dim, input_resolution, num_heads, shift_size=0): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.shift_size = shift_size self.window_size = config.window_size self.input_resolution = input_resolution - - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(self.input_resolution) - + self.set_shift_and_window_size(input_resolution) self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps) self.attention = SwinAttention(config, dim, num_heads) self.drop_path = SwinDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity() @@ -413,9 +579,15 @@ def __init__(self, config, dim, input_resolution, num_heads, shift_size=0): self.intermediate = SwinIntermediate(config, dim) self.output = SwinOutput(config, dim) + def set_shift_and_window_size(self, input_resolution): + if min(input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(input_resolution) + + def get_attn_mask(self, height, width): if self.shift_size > 0: # calculate attention mask for SW-MSA - height, width = self.input_resolution img_mask = torch.zeros((1, height, width, 1)) height_slices = ( slice(0, -self.window_size), @@ -439,17 +611,27 @@ def __init__(self, config, dim, input_resolution, num_heads, shift_size=0): attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) else: attn_mask = None - - self.attn_mask = attn_mask - - def forward(self, hidden_states, head_mask=None, output_attentions=False): - height, width = self.input_resolution - batch_size, dim, channels = hidden_states.size() + return attn_mask + + def maybe_pad(self, hidden_states, height, width): + pad_right = (self.window_size - width % self.window_size) % self.window_size + pad_bottom = (self.window_size - height % self.window_size) % self.window_size + pad_values = (0, 0, 0, pad_right, 0, pad_bottom) + hidden_states = nn.functional.pad(hidden_states, pad_values) + return hidden_states, pad_values + + def forward(self, hidden_states, input_dimensions, head_mask=None, output_attentions=False): + self.set_shift_and_window_size(input_dimensions) + height, width = input_dimensions + batch_size, _, channels = hidden_states.size() shortcut = hidden_states hidden_states = self.layernorm_before(hidden_states) hidden_states = hidden_states.view(batch_size, height, width, channels) + # pad hidden_states to multiples of window size + hidden_states, pad_values = self.maybe_pad(hidden_states, height, width) + _, height_pad, width_pad, _ = hidden_states.shape # cyclic shift if self.shift_size > 0: shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) @@ -459,23 +641,18 @@ def forward(self, hidden_states, head_mask=None, output_attentions=False): # partition windows hidden_states_windows = window_partition(shifted_hidden_states, self.window_size) hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels) + attn_mask = self.get_attn_mask(height_pad, width_pad) + if attn_mask is not None: + attn_mask = attn_mask.to(hidden_states_windows.device) - if self.attn_mask is not None: - self.attn_mask = self.attn_mask.to(hidden_states_windows.device) - - self_attention_outputs = self.attention( - hidden_states_windows, - self.attn_mask, - head_mask, - output_attentions=output_attentions, + attention_outputs = self.attention( + hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions ) - attention_output = self_attention_outputs[0] - - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + attention_output = attention_outputs[0] attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels) - shifted_windows = window_reverse(attention_windows, self.window_size, height, width) # B H' W' C + shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad) # reverse cyclic shift if self.shift_size > 0: @@ -483,6 +660,10 @@ def forward(self, hidden_states, head_mask=None, output_attentions=False): else: attention_windows = shifted_windows + was_padded = pad_values[3] > 0 or pad_values[5] > 0 + if was_padded: + attention_windows = attention_windows[:, :height, :width, :].contiguous() + attention_windows = attention_windows.view(batch_size, height * width, channels) hidden_states = shortcut + self.drop_path(attention_windows) @@ -491,19 +672,18 @@ def forward(self, hidden_states, head_mask=None, output_attentions=False): layer_output = self.intermediate(layer_output) layer_output = hidden_states + self.output(layer_output) - outputs = (layer_output,) + outputs - - return outputs + layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,) + return layer_outputs -class SwinLayer(nn.Module): +class SwinStage(nn.Module): def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample): super().__init__() self.config = config self.dim = dim self.blocks = nn.ModuleList( [ - SwinBlock( + SwinLayer( config=config, dim=dim, input_resolution=input_resolution, @@ -522,29 +702,28 @@ def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, d self.pointing = False - def forward(self, hidden_states, head_mask=None, output_attentions=False, output_hidden_states=False): - all_hidden_states = () if output_hidden_states else None - - for i, block_module in enumerate(self.blocks): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) + def forward(self, hidden_states, input_dimensions, head_mask=None, output_attentions=False): + height, width = input_dimensions + for i, layer_module in enumerate(self.blocks): layer_head_mask = head_mask[i] if head_mask is not None else None - layer_outputs = block_module( - hidden_states, - layer_head_mask, - output_attentions, - ) + layer_outputs = layer_module(hidden_states, input_dimensions, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] if self.downsample is not None: - layer_outputs_list = list(layer_outputs) - layer_outputs_list[0] = self.downsample(layer_outputs[0]) - layer_outputs = tuple(layer_outputs_list) + height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2 + output_dimensions = (height, width, height_downsampled, width_downsampled) + hidden_states = self.downsample(layer_outputs[0], input_dimensions) + else: + output_dimensions = (height, width, height, width) - return layer_outputs + stage_outputs = (hidden_states, output_dimensions) + + if output_attentions: + stage_outputs += layer_outputs[1:] + return stage_outputs class SwinEncoder(nn.Module): @@ -555,7 +734,7 @@ def __init__(self, config, grid_size): dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))] self.layers = nn.ModuleList( [ - SwinLayer( + SwinStage( config=config, dim=int(config.embed_dim * 2**i_layer), input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)), @@ -573,18 +752,26 @@ def __init__(self, config, grid_size): def forward( self, hidden_states, + input_dimensions, head_mask=None, output_attentions=False, output_hidden_states=False, return_dict=True, ): + all_input_dimensions = () all_hidden_states = () if output_hidden_states else None + all_reshaped_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None - for i, layer_module in enumerate(self.layers): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) + if output_hidden_states: + batch_size, _, hidden_size = hidden_states.shape + # rearrange b (h w) c -> b c h w + reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size) + reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2) + all_hidden_states += (hidden_states,) + all_reshaped_hidden_states += (reshaped_hidden_state,) + for i, layer_module in enumerate(self.layers): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: @@ -596,23 +783,36 @@ def custom_forward(*inputs): return custom_forward layer_outputs = torch.utils.checkpoint.checkpoint( - create_custom_forward(layer_module), hidden_states, layer_head_mask + create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask ) else: - layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) + layer_outputs = layer_module(hidden_states, input_dimensions, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] - if output_attentions: - all_self_attentions = all_self_attentions + (layer_outputs[1],) + output_dimensions = layer_outputs[1] - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) + input_dimensions = (output_dimensions[-2], output_dimensions[-1]) + all_input_dimensions += (input_dimensions,) + + if output_hidden_states: + batch_size, _, hidden_size = hidden_states.shape + # rearrange b (h w) c -> b c h w + reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size) + reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2) + all_hidden_states += (hidden_states,) + all_reshaped_hidden_states += (reshaped_hidden_state,) + + if output_attentions: + all_self_attentions += layer_outputs[2:] if not return_dict: return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) - return BaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions + return SwinEncoderOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + reshaped_hidden_states=all_reshaped_hidden_states, ) @@ -712,7 +912,7 @@ class PreTrainedModel @add_code_sample_docstrings( processor_class=_FEAT_EXTRACTOR_FOR_DOC, checkpoint=_CHECKPOINT_FOR_DOC, - output_type=BaseModelOutputWithPooling, + output_type=SwinModelOutput, config_class=_CONFIG_FOR_DOC, modality="vision", expected_output=_EXPECTED_OUTPUT_SHAPE, @@ -742,10 +942,11 @@ def forward( # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, len(self.config.depths)) - embedding_output = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos) + embedding_output, input_dimensions = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos) encoder_outputs = self.encoder( embedding_output, + input_dimensions, head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -761,13 +962,16 @@ def forward( pooled_output = torch.flatten(pooled_output, 1) if not return_dict: - return (sequence_output, pooled_output) + encoder_outputs[1:] + output = (sequence_output, pooled_output) + encoder_outputs[1:] + + return output - return BaseModelOutputWithPooling( + return SwinModelOutput( last_hidden_state=sequence_output, pooler_output=pooled_output, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, + reshaped_hidden_states=encoder_outputs.reshaped_hidden_states, ) @@ -791,7 +995,7 @@ def __init__(self, config): self.post_init() @add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC) + @replace_return_docstrings(output_type=SwinMaskedImageModelingOutput, config_class=_CONFIG_FOR_DOC) def forward( self, pixel_values=None, @@ -869,11 +1073,12 @@ def forward( output = (reconstructed_pixel_values,) + outputs[2:] return ((masked_im_loss,) + output) if masked_im_loss is not None else output - return MaskedLMOutput( + return SwinMaskedImageModelingOutput( loss=masked_im_loss, logits=reconstructed_pixel_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + reshaped_hidden_states=outputs.reshaped_hidden_states, ) @@ -903,7 +1108,7 @@ def __init__(self, config): @add_code_sample_docstrings( processor_class=_FEAT_EXTRACTOR_FOR_DOC, checkpoint=_IMAGE_CLASS_CHECKPOINT, - output_type=SequenceClassifierOutput, + output_type=SwinImageClassifierOutput, config_class=_CONFIG_FOR_DOC, expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, ) @@ -963,9 +1168,10 @@ def forward( output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output - return SequenceClassifierOutput( + return SwinImageClassifierOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + reshaped_hidden_states=outputs.reshaped_hidden_states, ) diff --git a/tests/swin/test_modeling_swin.py b/tests/swin/test_modeling_swin.py index 25acbb724f68..373e812d5fe7 100644 --- a/tests/swin/test_modeling_swin.py +++ b/tests/swin/test_modeling_swin.py @@ -230,15 +230,6 @@ def test_attention_outputs(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config.return_dict = True - image_size = to_2tuple(self.model_tester.image_size) - patch_size = to_2tuple(self.model_tester.patch_size) - num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) - seq_len = num_patches - encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len) - chunk_length = getattr(self.model_tester, "chunk_length", None) - if chunk_length is not None and hasattr(self.model_tester, "num_hashes"): - encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes - for model_class in self.all_model_classes: inputs_dict["output_attentions"] = True inputs_dict["output_hidden_states"] = False @@ -248,8 +239,9 @@ def test_attention_outputs(self): model.eval() with torch.no_grad(): outputs = model(**self._prepare_for_class(inputs_dict, model_class)) - attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions - self.assertEqual(len(attentions), len(self.model_tester.depths)) + attentions = outputs.attentions + expected_num_attentions = len(self.model_tester.depths) + self.assertEqual(len(attentions), expected_num_attentions) # check that output_attentions also work using config del inputs_dict["output_attentions"] @@ -260,19 +252,13 @@ def test_attention_outputs(self): model.eval() with torch.no_grad(): outputs = model(**self._prepare_for_class(inputs_dict, model_class)) - attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions - self.assertEqual(len(attentions), len(self.model_tester.depths)) - - if chunk_length is not None: - self.assertListEqual( - list(attentions[0].shape[-4:]), - [self.model_tester.num_heads[0], window_size_squared, chunk_length, window_size_squared], - ) - else: - self.assertListEqual( - list(attentions[0].shape[-3:]), - [self.model_tester.num_heads[0], window_size_squared, window_size_squared], - ) + attentions = outputs.attentions + self.assertEqual(len(attentions), expected_num_attentions) + + self.assertListEqual( + list(attentions[0].shape[-3:]), + [self.model_tester.num_heads[0], window_size_squared, window_size_squared], + ) out_len = len(outputs) # Check attention is always last and order is fine @@ -286,25 +272,19 @@ def test_attention_outputs(self): if hasattr(self.model_tester, "num_hidden_states_types"): added_hidden_states = self.model_tester.num_hidden_states_types - elif self.is_encoder_decoder: - added_hidden_states = 2 else: - added_hidden_states = 1 + # also another +1 for reshaped_hidden_states + added_hidden_states = 2 self.assertEqual(out_len + added_hidden_states, len(outputs)) - self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions + self_attentions = outputs.attentions - self.assertEqual(len(self_attentions), len(self.model_tester.depths)) - if chunk_length is not None: - self.assertListEqual( - list(self_attentions[0].shape[-4:]), - [self.model_tester.num_heads[0], window_size_squared, chunk_length, window_size_squared], - ) - else: - self.assertListEqual( - list(self_attentions[0].shape[-3:]), - [self.model_tester.num_heads[0], window_size_squared, window_size_squared], - ) + self.assertEqual(len(self_attentions), expected_num_attentions) + + self.assertListEqual( + list(self_attentions[0].shape[-3:]), + [self.model_tester.num_heads[0], window_size_squared, window_size_squared], + ) def test_hidden_states_output(self): def check_hidden_states_output(inputs_dict, config, model_class): @@ -315,7 +295,7 @@ def check_hidden_states_output(inputs_dict, config, model_class): with torch.no_grad(): outputs = model(**self._prepare_for_class(inputs_dict, model_class)) - hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states + hidden_states = outputs.hidden_states expected_num_layers = getattr( self.model_tester, "expected_num_hidden_layers", len(self.model_tester.depths) + 1 @@ -325,6 +305,7 @@ def check_hidden_states_output(inputs_dict, config, model_class): # Swin has a different seq_length image_size = to_2tuple(self.model_tester.image_size) patch_size = to_2tuple(self.model_tester.patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) self.assertListEqual( @@ -332,6 +313,18 @@ def check_hidden_states_output(inputs_dict, config, model_class): [num_patches, self.model_tester.embed_dim], ) + reshaped_hidden_states = outputs.reshaped_hidden_states + self.assertEqual(len(reshaped_hidden_states), expected_num_layers) + + batch_size, num_channels, height, width = reshaped_hidden_states[0].shape + reshaped_hidden_states = ( + reshaped_hidden_states[0].view(batch_size, num_channels, height * width).permute(0, 2, 1) + ) + self.assertListEqual( + list(reshaped_hidden_states.shape[-2:]), + [num_patches, self.model_tester.embed_dim], + ) + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: @@ -395,7 +388,5 @@ def test_inference_image_classification_head(self): # verify the logits expected_shape = torch.Size((1, 1000)) self.assertEqual(outputs.logits.shape, expected_shape) - expected_slice = torch.tensor([-0.0948, -0.6454, -0.0921]).to(torch_device) - self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))