From ba964312573a973139c532437306d45b2a77e847 Mon Sep 17 00:00:00 2001 From: fra Date: Mon, 7 Mar 2022 15:21:26 +0100 Subject: [PATCH 01/25] padding done --- src/transformers/models/swin/modeling_swin.py | 249 +++++++++++++----- 1 file changed, 186 insertions(+), 63 deletions(-) diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index bdfc66b0dc00..97b889564cc9 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -17,6 +17,8 @@ import collections.abc import math +from dataclasses import dataclass +from typing import Optional, Tuple import torch import torch.utils.checkpoint @@ -25,12 +27,13 @@ from ...activations import ACT2FN from ...file_utils import ( + ModelOutput, add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings, ) -from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, MaskedLMOutput, SequenceClassifierOutput +from ...modeling_outputs import BaseModelOutputWithPooling, MaskedLMOutput, SequenceClassifierOutput from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import logging from .configuration_swin import SwinConfig @@ -57,6 +60,71 @@ ] +@dataclass +class SwinBaseModelOutput(ModelOutput): + """ + Class for SwinEncoder's outputs. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + hidden_states_spatial_dimensions (`tuple(tuple(int, int))`, *optional*): + A tuple containing the spatial dimension of each `hidden_state` needed to reshape the `hidden_states` to + `batch, channels, height, width`. Due to padding, their spatial size cannot inferred before the `forward` + method. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + hidden_states_spatial_dimensions: Tuple[Tuple[int, int]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class SwinModelOutputWithPooling(ModelOutput): + """ + Class for SwinModel's outputs that also contains the spatial dimensions of the hidden states. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): + Last layer hidden-state after a mean pooling operation. 
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + hidden_states_spatial_dimensions (`tuple(tuple(int, int))`, *optional*): + A tuple containing the spatial dimension of each `hidden_state` needed to reshape the `hidden_states` to + `batch, channels, height, width`. Due to padding, their spatial size cannot be inferred before the + `forward` method. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + last_hidden_state: torch.FloatTensor = None + pooler_output: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + hidden_states_spatial_dimensions: Tuple[Tuple[int, int]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + # to_2tuple, drop_path, SwinPatchEmbeddings, SwinPatchMerging and SwinDropPath are from the timm library. @@ -130,7 +198,7 @@ def __init__(self, config, use_mask_token=False): self.dropout = nn.Dropout(config.hidden_dropout_prob) def forward(self, pixel_values, bool_masked_pos=None): - embeddings = self.patch_embeddings(pixel_values) + embeddings, output_dimensions = self.patch_embeddings(pixel_values) embeddings = self.norm(embeddings) batch_size, seq_len, _ = embeddings.size() @@ -145,7 +213,7 @@ def forward(self, pixel_values, bool_masked_pos=None): embeddings = self.dropout(embeddings) - return embeddings + return embeddings, output_dimensions class SwinPatchEmbeddings(nn.Module): @@ -165,9 +233,25 @@ def __init__(self, image_size=224, patch_size=16, num_channels=3, embed_dim=768) self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size) + def maybe_pad(self, pixel_values, height, width): + if width % self.patch_size[1] != 0: + pad_values = (0, self.patch_size[1] - width % self.patch_size[1]) + pixel_values = nn.functional.pad(pixel_values, pad_values) + if height % self.patch_size[0] != 0: + pad_values = (0, 0, 0, self.patch_size[0] - height % self.patch_size[0]) + pixel_values = nn.functional.pad(pixel_values, pad_values) + return pixel_values + def forward(self, pixel_values): - embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2) - return embeddings + _, _, height, width = pixel_values.shape + # pad the input to be divisible by self.patch_size, if needed + pixel_values = self.maybe_pad(pixel_values, height, width) + embeddings = self.projection(pixel_values) + _, _, height, width = embeddings.shape + output_dimensions = (height, width) + embeddings = embeddings.flatten(2).transpose(1, 2) + + return embeddings, output_dimensions class SwinPatchMerging(nn.Module): @@ -190,17 +274,30 @@ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) self.norm = norm_layer(4 * dim) - def forward(self, input_feature): - height, width = self.input_resolution + def maybe_pad(self, input_feature, width, height): + 
should_pad = (height % 2 == 1) or (width % 2 == 1) + if should_pad: + pad_values = (0, 0, 0, width % 2, 0, height % 2) + input_feature = nn.functional.pad(input_feature, pad_values) + + return input_feature + + def forward(self, input_feature, input_dimensions): + height, width = input_dimensions # `dim` is height * width batch_size, dim, num_channels = input_feature.shape input_feature = input_feature.view(batch_size, height, width, num_channels) - - input_feature_0 = input_feature[:, 0::2, 0::2, :] # batch_size height/2 width/2 num_channels - input_feature_1 = input_feature[:, 1::2, 0::2, :] # batch_size height/2 width/2 num_channels - input_feature_2 = input_feature[:, 0::2, 1::2, :] # batch_size height/2 width/2 num_channels - input_feature_3 = input_feature[:, 1::2, 1::2, :] # batch_size height/2 width/2 num_channels + # pad input to be disible by width and height, if needed + input_feature = self.maybe_pad(input_feature, height, width) + # [batch_size, height/2, width/2, num_channels] + input_feature_0 = input_feature[:, 0::2, 0::2, :] + # [batch_size, height/2, width/2, num_channels] + input_feature_1 = input_feature[:, 1::2, 0::2, :] + # [batch_size, height/2, width/2, num_channels] + input_feature_2 = input_feature[:, 0::2, 1::2, :] + # [batch_size, height/2, width/2, num_channels] + input_feature_3 = input_feature[:, 1::2, 1::2, :] # batch_size height/2 width/2 4*num_channels input_feature = torch.cat([input_feature_0, input_feature_1, input_feature_2, input_feature_3], -1) input_feature = input_feature.view(batch_size, -1, 4 * num_channels) # batch_size height/2*width/2 4*C @@ -360,9 +457,9 @@ def prune_heads(self, heads): self.pruned_heads = self.pruned_heads.union(heads) def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False): - self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions) - attention_output = self.output(self_outputs[0], hidden_states) - outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + self_outputs, self_attentions = self.self(hidden_states, attention_mask, head_mask, output_attentions) + attention_output = self.output(self_outputs, hidden_states) + outputs = (attention_output, self_attentions) if output_attentions else (attention_output) return outputs @@ -413,9 +510,10 @@ def __init__(self, config, dim, input_resolution, num_heads, shift_size=0): self.intermediate = SwinIntermediate(config, dim) self.output = SwinOutput(config, dim) + def get_attn_mask(self, input_resolution): if self.shift_size > 0: # calculate attention mask for SW-MSA - height, width = self.input_resolution + height, width = input_resolution img_mask = torch.zeros((1, height, width, 1)) height_slices = ( slice(0, -self.window_size), @@ -439,17 +537,27 @@ def __init__(self, config, dim, input_resolution, num_heads, shift_size=0): attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) else: attn_mask = None - - self.attn_mask = attn_mask - - def forward(self, hidden_states, head_mask=None, output_attentions=False): - height, width = self.input_resolution - batch_size, dim, channels = hidden_states.size() + return attn_mask + + def maybe_pad(self, hidden_states, height, width): + pad_left = pad_top = 0 + pad_rigth = (self.window_size - width % self.window_size) % self.window_size + pad_bottom = (self.window_size - height % self.window_size) % self.window_size + pad_values = (0, 0, pad_left, pad_rigth, pad_top, pad_bottom) + hidden_states = 
nn.functional.pad(hidden_states, pad_values) + return hidden_states, pad_values + + def forward(self, hidden_states, input_dimensions, head_mask=None, output_attentions=False): + height, width = input_dimensions + batch_size, _, channels = hidden_states.size() shortcut = hidden_states hidden_states = self.layernorm_before(hidden_states) hidden_states = hidden_states.view(batch_size, height, width, channels) + # pad hidden_states to multiples of window size + hidden_states, pad_values = self.maybe_pad(hidden_states, height, width) + _, height_pad, width_pad, _ = hidden_states.shape # cyclic shift if self.shift_size > 0: shifted_hidden_states = torch.roll(hidden_states, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) @@ -459,23 +567,18 @@ def forward(self, hidden_states, head_mask=None, output_attentions=False): # partition windows hidden_states_windows = window_partition(shifted_hidden_states, self.window_size) hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels) + attn_mask = self.get_attn_mask((height_pad, width_pad)) + if attn_mask is not None: + attn_mask = attn_mask.to(hidden_states_windows.device) - if self.attn_mask is not None: - self.attn_mask = self.attn_mask.to(hidden_states_windows.device) - - self_attention_outputs = self.attention( - hidden_states_windows, - self.attn_mask, - head_mask, - output_attentions=output_attentions, + attention_outputs = self.attention( + hidden_states_windows, attn_mask, head_mask, output_attentions=output_attentions ) - attention_output = self_attention_outputs[0] - - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + attention_output = attention_outputs[0] attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels) - shifted_windows = window_reverse(attention_windows, self.window_size, height, width) # B H' W' C + shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad) # B H' W' C # reverse cyclic shift if self.shift_size > 0: @@ -483,6 +586,10 @@ def forward(self, hidden_states, head_mask=None, output_attentions=False): else: attention_windows = shifted_windows + was_padded = pad_values[3] > 0 or pad_values[5] > 0 + if was_padded: + attention_windows = attention_windows[:, :height, :width, :].contiguous() + attention_windows = attention_windows.view(batch_size, height * width, channels) hidden_states = shortcut + self.drop_path(attention_windows) @@ -491,9 +598,8 @@ def forward(self, hidden_states, head_mask=None, output_attentions=False): layer_output = self.intermediate(layer_output) layer_output = hidden_states + self.output(layer_output) - outputs = (layer_output,) + outputs - - return outputs + layer_outputs = (layer_output, attention_output[1:]) if output_attentions else (layer_output) + return layer_outputs class SwinLayer(nn.Module): @@ -522,28 +628,26 @@ def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, d self.pointing = False - def forward(self, hidden_states, head_mask=None, output_attentions=False, output_hidden_states=False): - all_hidden_states = () if output_hidden_states else None + def forward(self, hidden_states, input_dimensions, head_mask=None, output_attentions=False): + height, width = input_dimensions for i, block_module in enumerate(self.blocks): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - layer_outputs = block_module( - 
hidden_states, - layer_head_mask, - output_attentions, - ) - - hidden_states = layer_outputs[0] + block_outputs = block_module(hidden_states, input_dimensions, layer_head_mask, output_attentions) if self.downsample is not None: - layer_outputs_list = list(layer_outputs) - layer_outputs_list[0] = self.downsample(layer_outputs[0]) - layer_outputs = tuple(layer_outputs_list) + height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2 + output_dimensions = (height, width, height_downsampled, width_downsampled) + hidden_states = self.downsample(block_outputs[0], input_dimensions) + else: + output_dimensions = (height, width, height, width) + layer_outputs = (hidden_states, output_dimensions) + + if output_attentions: + layer_outputs += (block_outputs[1:],) return layer_outputs @@ -573,18 +677,20 @@ def __init__(self, config, grid_size): def forward( self, hidden_states, + input_dimensions, head_mask=None, output_attentions=False, output_hidden_states=False, return_dict=True, ): + all_input_dimensions = () all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None - for i, layer_module in enumerate(self.layers): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + for i, layer_module in enumerate(self.layers): layer_head_mask = head_mask[i] if head_mask is not None else None if self.gradient_checkpointing and self.training: @@ -599,20 +705,33 @@ def custom_forward(*inputs): create_custom_forward(layer_module), hidden_states, layer_head_mask ) else: - layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions) + layer_outputs = layer_module( + hidden_states, + input_dimensions, + layer_head_mask, + output_attentions, + ) hidden_states = layer_outputs[0] - if output_attentions: - all_self_attentions = all_self_attentions + (layer_outputs[1],) + output_dimensions = layer_outputs[1] - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) + input_dimensions = (output_dimensions[-2], output_dimensions[-1]) + all_input_dimensions += (input_dimensions,) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if output_attentions: + all_self_attentions += (layer_outputs[2],) if not return_dict: return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) - return BaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions + return SwinBaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + hidden_states_spatial_dimensions=all_input_dimensions, + attentions=all_self_attentions, ) @@ -742,10 +861,11 @@ def forward( # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] head_mask = self.get_head_mask(head_mask, len(self.config.depths)) - embedding_output = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos) + embedding_output, input_dimensions = self.embeddings(pixel_values, bool_masked_pos=bool_masked_pos) encoder_outputs = self.encoder( embedding_output, + input_dimensions, head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, @@ -763,10 +883,13 @@ def forward( if not return_dict: return (sequence_output, pooled_output) + encoder_outputs[1:] - return BaseModelOutputWithPooling( + hidden_states_spatial_dimensions = 
(input_dimensions,) + encoder_outputs.hidden_states_spatial_dimensions + + return SwinModelOutputWithPooling( last_hidden_state=sequence_output, pooler_output=pooled_output, hidden_states=encoder_outputs.hidden_states, + hidden_states_spatial_dimensions=hidden_states_spatial_dimensions, attentions=encoder_outputs.attentions, ) From 9edd61d0f14f83499d84e908f88971a1a08a22fd Mon Sep 17 00:00:00 2001 From: fra Date: Tue, 8 Mar 2022 11:21:58 +0100 Subject: [PATCH 02/25] correctly return one attention per layer --- src/transformers/models/swin/modeling_swin.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index 97b889564cc9..d7125d503bfd 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -629,13 +629,15 @@ def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, d self.pointing = False def forward(self, hidden_states, input_dimensions, head_mask=None, output_attentions=False): - height, width = input_dimensions + block_attentions = () if output_attentions else None for i, block_module in enumerate(self.blocks): layer_head_mask = head_mask[i] if head_mask is not None else None block_outputs = block_module(hidden_states, input_dimensions, layer_head_mask, output_attentions) + if output_attentions: + block_attentions += (block_outputs[1:],) if self.downsample is not None: height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2 @@ -647,7 +649,7 @@ def forward(self, hidden_states, input_dimensions, head_mask=None, output_attent layer_outputs = (hidden_states, output_dimensions) if output_attentions: - layer_outputs += (block_outputs[1:],) + layer_outputs += block_attentions return layer_outputs @@ -722,7 +724,7 @@ def custom_forward(*inputs): all_hidden_states += (hidden_states,) if output_attentions: - all_self_attentions += (layer_outputs[2],) + all_self_attentions += (layer_outputs[2:],) if not return_dict: return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) From c2398485f229372f8d06cb4e8f63f80b5878c6f4 Mon Sep 17 00:00:00 2001 From: fra Date: Tue, 8 Mar 2022 14:01:28 +0100 Subject: [PATCH 03/25] almost correct, attentions are not flatten one tuple per stage --- src/transformers/models/swin/modeling_swin.py | 50 +++++++++++-------- tests/swin/test_modeling_swin.py | 11 ++-- 2 files changed, 34 insertions(+), 27 deletions(-) diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index d7125d503bfd..05601ac04443 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -87,8 +87,8 @@ class SwinBaseModelOutput(ModelOutput): last_hidden_state: torch.FloatTensor = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None - hidden_states_spatial_dimensions: Tuple[Tuple[int, int]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None + hidden_states_spatial_dimensions: Tuple[Tuple[int, int]] = None @dataclass @@ -121,8 +121,8 @@ class SwinModelOutputWithPooling(ModelOutput): last_hidden_state: torch.FloatTensor = None pooler_output: torch.FloatTensor = None hidden_states: Optional[Tuple[torch.FloatTensor]] = None - hidden_states_spatial_dimensions: Tuple[Tuple[int, int]] = None attentions: Optional[Tuple[torch.FloatTensor]] = None + hidden_states_spatial_dimensions: Tuple[Tuple[int, int]] = None # to_2tuple, drop_path, 
SwinPatchEmbeddings, SwinPatchMerging and SwinDropPath are from the timm library. @@ -457,9 +457,9 @@ def prune_heads(self, heads): self.pruned_heads = self.pruned_heads.union(heads) def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False): - self_outputs, self_attentions = self.self(hidden_states, attention_mask, head_mask, output_attentions) - attention_output = self.output(self_outputs, hidden_states) - outputs = (attention_output, self_attentions) if output_attentions else (attention_output) + self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions) + attention_output = self.output(self_outputs[0], hidden_states) + outputs = (attention_output, self_outputs[1]) if output_attentions else (attention_output,) return outputs @@ -497,12 +497,7 @@ def __init__(self, config, dim, input_resolution, num_heads, shift_size=0): self.shift_size = shift_size self.window_size = config.window_size self.input_resolution = input_resolution - - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(self.input_resolution) - + self.set_shift_and_window_size(input_resolution) self.layernorm_before = nn.LayerNorm(dim, eps=config.layer_norm_eps) self.attention = SwinAttention(config, dim, num_heads) self.drop_path = SwinDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity() @@ -510,6 +505,12 @@ def __init__(self, config, dim, input_resolution, num_heads, shift_size=0): self.intermediate = SwinIntermediate(config, dim) self.output = SwinOutput(config, dim) + def set_shift_and_window_size(self, input_resolution): + if min(input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(input_resolution) + def get_attn_mask(self, input_resolution): if self.shift_size > 0: # calculate attention mask for SW-MSA @@ -548,10 +549,11 @@ def maybe_pad(self, hidden_states, height, width): return hidden_states, pad_values def forward(self, hidden_states, input_dimensions, head_mask=None, output_attentions=False): + self.set_shift_and_window_size(input_dimensions) height, width = input_dimensions batch_size, _, channels = hidden_states.size() shortcut = hidden_states - + hidden_states = self.layernorm_before(hidden_states) hidden_states = hidden_states.view(batch_size, height, width, channels) # pad hidden_states to multiples of window size @@ -598,7 +600,7 @@ def forward(self, hidden_states, input_dimensions, head_mask=None, output_attent layer_output = self.intermediate(layer_output) layer_output = hidden_states + self.output(layer_output) - layer_outputs = (layer_output, attention_output[1:]) if output_attentions else (layer_output) + layer_outputs = (layer_output, attention_outputs[1]) if output_attentions else (layer_output,) return layer_outputs @@ -637,7 +639,9 @@ def forward(self, hidden_states, input_dimensions, head_mask=None, output_attent block_outputs = block_module(hidden_states, input_dimensions, layer_head_mask, output_attentions) if output_attentions: - block_attentions += (block_outputs[1:],) + block_attentions += block_outputs[1:] + + hidden_states = block_outputs[0] if self.downsample is not None: height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2 @@ -704,7 +708,7 @@ def custom_forward(*inputs): return custom_forward layer_outputs = torch.utils.checkpoint.checkpoint( - 
create_custom_forward(layer_module), hidden_states, layer_head_mask + create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask ) else: layer_outputs = layer_module( @@ -725,15 +729,15 @@ def custom_forward(*inputs): if output_attentions: all_self_attentions += (layer_outputs[2:],) - + if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions, all_input_dimensions] if v is not None) return SwinBaseModelOutput( last_hidden_state=hidden_states, hidden_states=all_hidden_states, - hidden_states_spatial_dimensions=all_input_dimensions, attentions=all_self_attentions, + hidden_states_spatial_dimensions=all_input_dimensions, ) @@ -882,16 +886,20 @@ def forward( pooled_output = self.pooler(sequence_output.transpose(1, 2)) pooled_output = torch.flatten(pooled_output, 1) + if not return_dict: - return (sequence_output, pooled_output) + encoder_outputs[1:] + output = (sequence_output, pooled_output) + encoder_outputs[1:-1] + # spatial hidden sizes is at the end + hidden_states_spatial_dimensions = (input_dimensions,) + encoder_outputs[-1] + output += (hidden_states_spatial_dimensions,) - hidden_states_spatial_dimensions = (input_dimensions,) + encoder_outputs.hidden_states_spatial_dimensions + return output return SwinModelOutputWithPooling( last_hidden_state=sequence_output, pooler_output=pooled_output, hidden_states=encoder_outputs.hidden_states, - hidden_states_spatial_dimensions=hidden_states_spatial_dimensions, + hidden_states_spatial_dimensions=(input_dimensions,) + encoder_outputs.hidden_states_spatial_dimensions, attentions=encoder_outputs.attentions, ) diff --git a/tests/swin/test_modeling_swin.py b/tests/swin/test_modeling_swin.py index 25acbb724f68..4b1ba4746a6b 100644 --- a/tests/swin/test_modeling_swin.py +++ b/tests/swin/test_modeling_swin.py @@ -265,12 +265,13 @@ def test_attention_outputs(self): if chunk_length is not None: self.assertListEqual( - list(attentions[0].shape[-4:]), + list(attentions[0][0].shape[-4:]), [self.model_tester.num_heads[0], window_size_squared, chunk_length, window_size_squared], ) else: + # attentions is a tuple of tuple, since we have one attention per layer self.assertListEqual( - list(attentions[0].shape[-3:]), + list(attentions[0][0].shape[-3:]), [self.model_tester.num_heads[0], window_size_squared, window_size_squared], ) out_len = len(outputs) @@ -297,12 +298,12 @@ def test_attention_outputs(self): self.assertEqual(len(self_attentions), len(self.model_tester.depths)) if chunk_length is not None: self.assertListEqual( - list(self_attentions[0].shape[-4:]), + list(self_attentions[0][0].shape[-4:]), [self.model_tester.num_heads[0], window_size_squared, chunk_length, window_size_squared], ) else: self.assertListEqual( - list(self_attentions[0].shape[-3:]), + list(self_attentions[0][0].shape[-3:]), [self.model_tester.num_heads[0], window_size_squared, window_size_squared], ) @@ -395,7 +396,5 @@ def test_inference_image_classification_head(self): # verify the logits expected_shape = torch.Size((1, 1000)) self.assertEqual(outputs.logits.shape, expected_shape) - expected_slice = torch.tensor([-0.0948, -0.6454, -0.0921]).to(torch_device) - self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4)) From 09e46619de63ac2e3218dcdee6a152bd6d656b3d Mon Sep 17 00:00:00 2001 From: fra Date: Tue, 8 Mar 2022 14:17:00 +0100 Subject: [PATCH 04/25] tests green --- 
docs/source/model_doc/swin.mdx | 1 + .../models/maskformer/modeling_maskformer.py | 2 +- src/transformers/models/swin/modeling_swin.py | 15 ++-- tests/swin/test_modeling_swin.py | 81 ++++++++++++++++++- 4 files changed, 88 insertions(+), 11 deletions(-) diff --git a/docs/source/model_doc/swin.mdx b/docs/source/model_doc/swin.mdx index c1fd4e86d3a4..0483f792bece 100644 --- a/docs/source/model_doc/swin.mdx +++ b/docs/source/model_doc/swin.mdx @@ -34,6 +34,7 @@ The hierarchical design and the shifted window approach also prove beneficial fo Tips: - One can use the [`AutoFeatureExtractor`] API to prepare images for the model. +- Swin pads the inputs thus it supports any input size (if divisible by `32`). drawing diff --git a/src/transformers/models/maskformer/modeling_maskformer.py b/src/transformers/models/maskformer/modeling_maskformer.py index 39af8a27ebc6..346eb36daee4 100644 --- a/src/transformers/models/maskformer/modeling_maskformer.py +++ b/src/transformers/models/maskformer/modeling_maskformer.py @@ -799,7 +799,7 @@ def prune_heads(self, heads): def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False): self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions) attention_output = self.output(self_outputs[0], hidden_states) - outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + outputs = (attention_output, self_outputs[1]) if output_attentions else (attention_output,) return outputs diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index 05601ac04443..d3837eaddb4b 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -553,7 +553,7 @@ def forward(self, hidden_states, input_dimensions, head_mask=None, output_attent height, width = input_dimensions batch_size, _, channels = hidden_states.size() shortcut = hidden_states - + hidden_states = self.layernorm_before(hidden_states) hidden_states = hidden_states.view(batch_size, height, width, channels) # pad hidden_states to multiples of window size @@ -728,10 +728,14 @@ def custom_forward(*inputs): all_hidden_states += (hidden_states,) if output_attentions: - all_self_attentions += (layer_outputs[2:],) - + all_self_attentions += layer_outputs[2:] + if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions, all_input_dimensions] if v is not None) + return tuple( + v + for v in [hidden_states, all_hidden_states, all_self_attentions, all_input_dimensions] + if v is not None + ) return SwinBaseModelOutput( last_hidden_state=hidden_states, @@ -886,9 +890,8 @@ def forward( pooled_output = self.pooler(sequence_output.transpose(1, 2)) pooled_output = torch.flatten(pooled_output, 1) - if not return_dict: - output = (sequence_output, pooled_output) + encoder_outputs[1:-1] + output = (sequence_output, pooled_output) + encoder_outputs[1:-1] # spatial hidden sizes is at the end hidden_states_spatial_dimensions = (input_dimensions,) + encoder_outputs[-1] output += (hidden_states_spatial_dimensions,) diff --git a/tests/swin/test_modeling_swin.py b/tests/swin/test_modeling_swin.py index 4b1ba4746a6b..a701ac77206f 100644 --- a/tests/swin/test_modeling_swin.py +++ b/tests/swin/test_modeling_swin.py @@ -17,6 +17,7 @@ import copy import inspect import unittest +from typing import Dict, List, Tuple from transformers import SwinConfig from transformers.file_utils import cached_property, is_torch_available, 
is_vision_available @@ -265,13 +266,13 @@ def test_attention_outputs(self): if chunk_length is not None: self.assertListEqual( - list(attentions[0][0].shape[-4:]), + list(attentions[0].shape[-4:]), [self.model_tester.num_heads[0], window_size_squared, chunk_length, window_size_squared], ) else: # attentions is a tuple of tuple, since we have one attention per layer self.assertListEqual( - list(attentions[0][0].shape[-3:]), + list(attentions[0].shape[-3:]), [self.model_tester.num_heads[0], window_size_squared, window_size_squared], ) out_len = len(outputs) @@ -298,15 +299,87 @@ def test_attention_outputs(self): self.assertEqual(len(self_attentions), len(self.model_tester.depths)) if chunk_length is not None: self.assertListEqual( - list(self_attentions[0][0].shape[-4:]), + list(self_attentions[0].shape[-4:]), [self.model_tester.num_heads[0], window_size_squared, chunk_length, window_size_squared], ) else: self.assertListEqual( - list(self_attentions[0][0].shape[-3:]), + list(self_attentions[0].shape[-3:]), [self.model_tester.num_heads[0], window_size_squared, window_size_squared], ) + def test_model_outputs_equivalence(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + def set_nan_tensor_to_zero(t): + + t[t != t] = 0 + return t + + def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}): + with torch.no_grad(): + tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs) + dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple() + + def recursive_check(tuple_object, dict_object): + # for spatial dimensions and exit condition for recursion + if type(tuple_object) is int: + self.assertEqual(tuple_object, dict_object) + elif isinstance(tuple_object, (List, Tuple)): + for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif isinstance(tuple_object, Dict): + for tuple_iterable_value, dict_iterable_value in zip( + tuple_object.values(), dict_object.values() + ): + recursive_check(tuple_iterable_value, dict_iterable_value) + elif tuple_object is None: + return + else: + self.assertTrue( + torch.allclose( + set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5 + ), + msg=f"Tuple and dict output are not equal. Difference: {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`: {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. 
Dict has `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}.", + ) + + recursive_check(tuple_output, dict_output) + + for model_class in self.all_model_classes: + model = model_class(config) + model.to(torch_device) + model.eval() + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class) + dict_inputs = self._prepare_for_class(inputs_dict, model_class) + check_equivalence(model, tuple_inputs, dict_inputs) + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + check_equivalence(model, tuple_inputs, dict_inputs) + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class) + dict_inputs = self._prepare_for_class(inputs_dict, model_class) + check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True}) + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class) + dict_inputs = self._prepare_for_class(inputs_dict, model_class) + check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True}) + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True}) + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True}) + + tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + check_equivalence( + model, tuple_inputs, dict_inputs, {"output_hidden_states": True, "output_attentions": True} + ) + def test_hidden_states_output(self): def check_hidden_states_output(inputs_dict, config, model_class): model = model_class(config) From 91be4b7083b80f7a580b29127c064745180b6450 Mon Sep 17 00:00:00 2001 From: fra Date: Tue, 8 Mar 2022 14:18:58 +0100 Subject: [PATCH 05/25] doc --- docs/source/model_doc/swin.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/model_doc/swin.mdx b/docs/source/model_doc/swin.mdx index 0483f792bece..3a4d69af65a0 100644 --- a/docs/source/model_doc/swin.mdx +++ b/docs/source/model_doc/swin.mdx @@ -34,7 +34,7 @@ The hierarchical design and the shifted window approach also prove beneficial fo Tips: - One can use the [`AutoFeatureExtractor`] API to prepare images for the model. -- Swin pads the inputs thus it supports any input size (if divisible by `32`). +- Swin pads the inputs supporting any input size (if divisible by `32`). 
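For illustration only — a minimal usage sketch that is not part of this patch series: once the padding logic introduced above is in place, a default `SwinModel` accepts inputs whose size differs from the configured 224×224. For an input divisible by 32 (the model's overall downsampling factor), the output sequence length is simply `(height // 32) * (width // 32)`.

```python
# Hedged sketch (not taken from the PR diff): exercise the new padding support
# with a non-square input size that differs from the configured 224x224.
import torch
from transformers import SwinConfig, SwinModel

config = SwinConfig()   # defaults: patch_size=4, window_size=7, embed_dim=96, 4 stages
model = SwinModel(config)
model.eval()

pixel_values = torch.randn(1, 3, 256, 320)   # divisible by 32, but not 224x224

with torch.no_grad():
    outputs = model(pixel_values)

# 256/32 * 320/32 = 8 * 10 = 80 tokens, final hidden size 8 * embed_dim = 768
print(outputs.last_hidden_state.shape)  # torch.Size([1, 80, 768])
print(outputs.pooler_output.shape)      # torch.Size([1, 768])
```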
drawing From 8d6022f57898838922d3cefeb2545aa28d712fe3 Mon Sep 17 00:00:00 2001 From: fra Date: Wed, 9 Mar 2022 11:45:58 +0100 Subject: [PATCH 06/25] conversations --- src/transformers/models/swin/modeling_swin.py | 14 ++++---------- tests/swin/test_modeling_swin.py | 1 - 2 files changed, 4 insertions(+), 11 deletions(-) diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index d3837eaddb4b..408a70c4b813 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -541,10 +541,9 @@ def get_attn_mask(self, input_resolution): return attn_mask def maybe_pad(self, hidden_states, height, width): - pad_left = pad_top = 0 - pad_rigth = (self.window_size - width % self.window_size) % self.window_size + pad_right = (self.window_size - width % self.window_size) % self.window_size pad_bottom = (self.window_size - height % self.window_size) % self.window_size - pad_values = (0, 0, pad_left, pad_rigth, pad_top, pad_bottom) + pad_values = (0, 0, 0, pad_right, 0, pad_bottom) hidden_states = nn.functional.pad(hidden_states, pad_values) return hidden_states, pad_values @@ -580,7 +579,7 @@ def forward(self, hidden_states, input_dimensions, head_mask=None, output_attent attention_output = attention_outputs[0] attention_windows = attention_output.view(-1, self.window_size, self.window_size, channels) - shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad) # B H' W' C + shifted_windows = window_reverse(attention_windows, self.window_size, height_pad, width_pad) # reverse cyclic shift if self.shift_size > 0: @@ -711,12 +710,7 @@ def custom_forward(*inputs): create_custom_forward(layer_module), hidden_states, input_dimensions, layer_head_mask ) else: - layer_outputs = layer_module( - hidden_states, - input_dimensions, - layer_head_mask, - output_attentions, - ) + layer_outputs = layer_module(hidden_states, input_dimensions, layer_head_mask, output_attentions) hidden_states = layer_outputs[0] output_dimensions = layer_outputs[1] diff --git a/tests/swin/test_modeling_swin.py b/tests/swin/test_modeling_swin.py index a701ac77206f..5a67af735bc7 100644 --- a/tests/swin/test_modeling_swin.py +++ b/tests/swin/test_modeling_swin.py @@ -312,7 +312,6 @@ def test_model_outputs_equivalence(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() def set_nan_tensor_to_zero(t): - t[t != t] = 0 return t From f2b0d3df06ad6738cc83a3e3bf172cb815a369d5 Mon Sep 17 00:00:00 2001 From: fra Date: Wed, 9 Mar 2022 16:24:23 +0100 Subject: [PATCH 07/25] reshaping hidden_states --- src/transformers/models/swin/modeling_swin.py | 101 +++--------------- tests/swin/test_modeling_swin.py | 4 +- 2 files changed, 19 insertions(+), 86 deletions(-) diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index 408a70c4b813..fd9d22581451 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -33,7 +33,7 @@ add_start_docstrings_to_model_forward, replace_return_docstrings, ) -from ...modeling_outputs import BaseModelOutputWithPooling, MaskedLMOutput, SequenceClassifierOutput +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, MaskedLMOutput, SequenceClassifierOutput from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import logging from .configuration_swin import SwinConfig @@ -59,72 
+59,6 @@ # See all Swin models at https://huggingface.co/models?filter=swin ] - -@dataclass -class SwinBaseModelOutput(ModelOutput): - """ - Class for SwinEncoder's outputs. - - Args: - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of - shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - hidden_states_spatial_dimensions (`tuple(tuple(int, int))`, *optional*): - A tuple containing the spatial dimension of each `hidden_state` needed to reshape the `hidden_states` to - `batch, channels, height, width`. Due to padding, their spatial size cannot inferred before the `forward` - method. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. - """ - - last_hidden_state: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[torch.FloatTensor]] = None - hidden_states_spatial_dimensions: Tuple[Tuple[int, int]] = None - - -@dataclass -class SwinModelOutputWithPooling(ModelOutput): - """ - Class for SwinModel's outputs that also contains the spatial dimensions of the hidden states. - - Args: - last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Sequence of hidden-states at the output of the last layer of the model. - pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): - Last layer hidden-state after a mean pooling operation. - hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of - shape `(batch_size, sequence_length, hidden_size)`. - - Hidden-states of the model at the output of each layer plus the initial embedding outputs. - hidden_states_spatial_dimensions (`tuple(tuple(int, int))`, *optional*): - A tuple containing the spatial dimension of each `hidden_state` needed to reshape the `hidden_states` to - `batch, channels, height, width`. Due to padding, their spatial size cannot be inferred before the - `forward` method. - attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, - sequence_length)`. - - Attentions weights after the attention softmax, used to compute the weighted average in the self-attention - heads. 
- """ - - last_hidden_state: torch.FloatTensor = None - pooler_output: torch.FloatTensor = None - hidden_states: Optional[Tuple[torch.FloatTensor]] = None - attentions: Optional[Tuple[torch.FloatTensor]] = None - hidden_states_spatial_dimensions: Tuple[Tuple[int, int]] = None - - # to_2tuple, drop_path, SwinPatchEmbeddings, SwinPatchMerging and SwinDropPath are from the timm library. @@ -693,7 +627,10 @@ def forward( all_self_attentions = () if output_attentions else None if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) + batch_size, _, hidden_size = hidden_states.shape + # rearrange b (h w) c -> b c h w + reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size).permute(0, 3, 1, 2) + all_hidden_states = all_hidden_states + (reshaped_hidden_state,) for i, layer_module in enumerate(self.layers): layer_head_mask = head_mask[i] if head_mask is not None else None @@ -719,23 +656,21 @@ def custom_forward(*inputs): all_input_dimensions += (input_dimensions,) if output_hidden_states: - all_hidden_states += (hidden_states,) + batch_size, _, hidden_size = hidden_states.shape + # rearrange b (h w) c -> b c h w + reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size).permute( + 0, 3, 1, 2 + ) + all_hidden_states += (reshaped_hidden_state,) if output_attentions: all_self_attentions += layer_outputs[2:] if not return_dict: - return tuple( - v - for v in [hidden_states, all_hidden_states, all_self_attentions, all_input_dimensions] - if v is not None - ) + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) - return SwinBaseModelOutput( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - hidden_states_spatial_dimensions=all_input_dimensions, + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions ) @@ -885,18 +820,14 @@ def forward( pooled_output = torch.flatten(pooled_output, 1) if not return_dict: - output = (sequence_output, pooled_output) + encoder_outputs[1:-1] - # spatial hidden sizes is at the end - hidden_states_spatial_dimensions = (input_dimensions,) + encoder_outputs[-1] - output += (hidden_states_spatial_dimensions,) + output = (sequence_output, pooled_output) + encoder_outputs[1:] return output - return SwinModelOutputWithPooling( + return BaseModelOutputWithPooling( last_hidden_state=sequence_output, pooler_output=pooled_output, hidden_states=encoder_outputs.hidden_states, - hidden_states_spatial_dimensions=(input_dimensions,) + encoder_outputs.hidden_states_spatial_dimensions, attentions=encoder_outputs.attentions, ) diff --git a/tests/swin/test_modeling_swin.py b/tests/swin/test_modeling_swin.py index 5a67af735bc7..96ed9c7596c6 100644 --- a/tests/swin/test_modeling_swin.py +++ b/tests/swin/test_modeling_swin.py @@ -400,8 +400,10 @@ def check_hidden_states_output(inputs_dict, config, model_class): patch_size = to_2tuple(self.model_tester.patch_size) num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + batch_size, num_channels, height, width = hidden_states[0].shape + hidden_states = hidden_states[0].view(batch_size, num_channels, height * width).permute(0, 2, 1) self.assertListEqual( - list(hidden_states[0].shape[-2:]), + list(hidden_states.shape[-2:]), [num_patches, self.model_tester.embed_dim], ) From bfa25a9d69cd92a3164a4968a6f0e1795dfa4dbe Mon Sep 17 00:00:00 2001 From: fra Date: Fri, 11 Mar 
2022 09:41:38 +0100 Subject: [PATCH 08/25] view in the test --- src/transformers/models/swin/modeling_swin.py | 5 ++-- tests/swin/test_modeling_swin.py | 29 +++++++++++++++++++ 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index fd9d22581451..adc9d81b1d9d 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -658,9 +658,8 @@ def custom_forward(*inputs): if output_hidden_states: batch_size, _, hidden_size = hidden_states.shape # rearrange b (h w) c -> b c h w - reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size).permute( - 0, 3, 1, 2 - ) + reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size) + reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2) all_hidden_states += (reshaped_hidden_state,) if output_attentions: diff --git a/tests/swin/test_modeling_swin.py b/tests/swin/test_modeling_swin.py index 96ed9c7596c6..137da957909d 100644 --- a/tests/swin/test_modeling_swin.py +++ b/tests/swin/test_modeling_swin.py @@ -308,6 +308,35 @@ def test_attention_outputs(self): [self.model_tester.num_heads[0], window_size_squared, window_size_squared], ) + def test_retain_grad_hidden_states_attentions(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.output_hidden_states = True + config.output_attentions = True + + # no need to test all models as different heads yield the same functionality + model_class = self.all_model_classes[0] + model = model_class(config) + model.to(torch_device) + + inputs = self._prepare_for_class(inputs_dict, model_class) + + outputs = model(**inputs) + + output = outputs[0] + + hidden_states = outputs.hidden_states[0] + batch_size, num_channels, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, num_channels, height * width).permute(0, 2, 1) + attentions = outputs.attentions[0] + + hidden_states.retain_grad() + attentions.retain_grad() + + output.flatten()[0].backward(retain_graph=True) + + self.assertIsNotNone(hidden_states.grad) + self.assertIsNotNone(attentions.grad) + def test_model_outputs_equivalence(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() From d7abc4d35145ee79d8c80df0f2ea0a54186e6bb8 Mon Sep 17 00:00:00 2001 From: fra Date: Fri, 11 Mar 2022 10:53:05 +0100 Subject: [PATCH 09/25] reshape_hidden_states in Encoder and Model --- src/transformers/models/swin/modeling_swin.py | 86 +++++++++++++++++-- tests/swin/test_modeling_swin.py | 44 +++------- 2 files changed, 92 insertions(+), 38 deletions(-) diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index adc9d81b1d9d..97276b2b52e7 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -61,6 +61,73 @@ # to_2tuple, drop_path, SwinPatchEmbeddings, SwinPatchMerging and SwinDropPath are from the timm library. +@dataclass +class SwinEncoderOutput(ModelOutput): + """ + Swin encoder's outputs, with potential hidden states and attentions. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. 
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, hidden_size, height, width)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to include the spatial dimensions. + """ + + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + +@dataclass +class SwinModelOutput(ModelOutput): + """ + Swin model's outputs that also contains a pooling of the last hidden states. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): + Last layer hidden-state of the first token of the sequence (classification token) after further processing + through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns + the classification token after processing through a linear layer and a tanh activation function. The linear + layer weights are trained from the next sentence prediction (classification) objective during pretraining. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, hidden_size, height, width)`. 
+ + Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to include the spatial dimensions. + """ + + last_hidden_state: torch.FloatTensor = None + pooler_output: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None # Copied from transformers.models.vit.modeling_vit.to_2tuple def to_2tuple(x): @@ -624,13 +691,16 @@ def forward( ): all_input_dimensions = () all_hidden_states = () if output_hidden_states else None + all_reshape_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None if output_hidden_states: batch_size, _, hidden_size = hidden_states.shape # rearrange b (h w) c -> b c h w - reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size).permute(0, 3, 1, 2) - all_hidden_states = all_hidden_states + (reshaped_hidden_state,) + reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size) + reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2) + all_hidden_states += (hidden_states,) + all_reshape_hidden_states += (reshaped_hidden_state,) for i, layer_module in enumerate(self.layers): layer_head_mask = head_mask[i] if head_mask is not None else None @@ -660,7 +730,8 @@ def custom_forward(*inputs): # rearrange b (h w) c -> b c h w reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size) reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2) - all_hidden_states += (reshaped_hidden_state,) + all_hidden_states += (hidden_states,) + all_reshape_hidden_states += (reshaped_hidden_state,) if output_attentions: all_self_attentions += layer_outputs[2:] @@ -668,8 +739,8 @@ def custom_forward(*inputs): if not return_dict: return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) - return BaseModelOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions + return SwinEncoderOutput( + last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions, reshaped_hidden_states=all_reshape_hidden_states ) @@ -769,7 +840,7 @@ class PreTrainedModel @add_code_sample_docstrings( processor_class=_FEAT_EXTRACTOR_FOR_DOC, checkpoint=_CHECKPOINT_FOR_DOC, - output_type=BaseModelOutputWithPooling, + output_type=SwinModelOutput, config_class=_CONFIG_FOR_DOC, modality="vision", expected_output=_EXPECTED_OUTPUT_SHAPE, @@ -823,11 +894,12 @@ def forward( return output - return BaseModelOutputWithPooling( + return SwinModelOutput( last_hidden_state=sequence_output, pooler_output=pooled_output, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, + reshaped_hidden_states=encoder_outputs.reshaped_hidden_states ) diff --git a/tests/swin/test_modeling_swin.py b/tests/swin/test_modeling_swin.py index 137da957909d..db550c8e863e 100644 --- a/tests/swin/test_modeling_swin.py +++ b/tests/swin/test_modeling_swin.py @@ -291,6 +291,7 @@ def test_attention_outputs(self): elif self.is_encoder_decoder: added_hidden_states = 2 else: + # TODO also another +1 for reshaped_hidden_states added_hidden_states = 1 self.assertEqual(out_len + added_hidden_states, len(outputs)) @@ -308,34 +309,6 @@ def test_attention_outputs(self): [self.model_tester.num_heads[0], window_size_squared, window_size_squared], ) - def 
test_retain_grad_hidden_states_attentions(self): - config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() - config.output_hidden_states = True - config.output_attentions = True - - # no need to test all models as different heads yield the same functionality - model_class = self.all_model_classes[0] - model = model_class(config) - model.to(torch_device) - - inputs = self._prepare_for_class(inputs_dict, model_class) - - outputs = model(**inputs) - - output = outputs[0] - - hidden_states = outputs.hidden_states[0] - batch_size, num_channels, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, num_channels, height * width).permute(0, 2, 1) - attentions = outputs.attentions[0] - - hidden_states.retain_grad() - attentions.retain_grad() - - output.flatten()[0].backward(retain_graph=True) - - self.assertIsNotNone(hidden_states.grad) - self.assertIsNotNone(attentions.grad) def test_model_outputs_equivalence(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -427,12 +400,21 @@ def check_hidden_states_output(inputs_dict, config, model_class): # Swin has a different seq_length image_size = to_2tuple(self.model_tester.image_size) patch_size = to_2tuple(self.model_tester.patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) - batch_size, num_channels, height, width = hidden_states[0].shape - hidden_states = hidden_states[0].view(batch_size, num_channels, height * width).permute(0, 2, 1) self.assertListEqual( - list(hidden_states.shape[-2:]), + list(hidden_states[0].shape[-2:]), + [num_patches, self.model_tester.embed_dim], + ) + + reshaped_hidden_states = outputs.reshaped_hidden_states + self.assertEqual(len(reshaped_hidden_states), expected_num_layers) + + batch_size, num_channels, height, width = reshaped_hidden_states[0].shape + reshaped_hidden_states = reshaped_hidden_states[0].view(batch_size, num_channels, height * width).permute(0, 2, 1) + self.assertListEqual( + list(reshaped_hidden_states.shape[-2:]), [num_patches, self.model_tester.embed_dim], ) From fbd07d33ccf1189b32466fb62ca49323f41f78ee Mon Sep 17 00:00:00 2001 From: fra Date: Fri, 11 Mar 2022 14:01:50 +0100 Subject: [PATCH 10/25] new outputs with reshaped_hidden_states --- src/transformers/models/swin/modeling_swin.py | 103 ++++++++++++++++-- tests/swin/test_modeling_swin.py | 9 +- 2 files changed, 97 insertions(+), 15 deletions(-) diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index 97276b2b52e7..e3485118344e 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -33,7 +33,6 @@ add_start_docstrings_to_model_forward, replace_return_docstrings, ) -from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, MaskedLMOutput, SequenceClassifierOutput from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import logging from .configuration_swin import SwinConfig @@ -61,6 +60,7 @@ # to_2tuple, drop_path, SwinPatchEmbeddings, SwinPatchMerging and SwinDropPath are from the timm library. + @dataclass class SwinEncoderOutput(ModelOutput): """ @@ -84,7 +84,8 @@ class SwinEncoderOutput(ModelOutput): Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of shape `(batch_size, hidden_size, height, width)`. 
- Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to include the spatial dimensions. + Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to + include the spatial dimensions. """ last_hidden_state: torch.FloatTensor = None @@ -92,7 +93,8 @@ class SwinEncoderOutput(ModelOutput): attentions: Optional[Tuple[torch.FloatTensor]] = None reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None -@dataclass + +@dataclass class SwinModelOutput(ModelOutput): """ Swin model's outputs that also contains a pooling of the last hidden states. @@ -120,7 +122,8 @@ class SwinModelOutput(ModelOutput): Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of shape `(batch_size, hidden_size, height, width)`. - Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to include the spatial dimensions. + Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to + include the spatial dimensions. """ last_hidden_state: torch.FloatTensor = None @@ -129,6 +132,79 @@ class SwinModelOutput(ModelOutput): attentions: Optional[Tuple[torch.FloatTensor]] = None reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + +@dataclass +class SwinMaskedLMOutput(ModelOutput): + """ + Swin masked language models outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Masked language modeling (MLM) loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, hidden_size, height, width)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to + include the spatial dimensions. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +class SwinImageClassifierOutput(ModelOutput): + """ + Swin outputs for image classification. 
+ + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Classification (or regression if config.num_labels==1) loss. + logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): + Classification (or regression if config.num_labels==1) scores (before SoftMax). + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of + shape `(batch_size, hidden_size, height, width)`. + + Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to + include the spatial dimensions. + """ + + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + reshaped_hidden_states: Optional[Tuple[torch.FloatTensor]] = None + + # Copied from transformers.models.vit.modeling_vit.to_2tuple def to_2tuple(x): if isinstance(x, collections.abc.Iterable): @@ -699,7 +775,7 @@ def forward( # rearrange b (h w) c -> b c h w reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size) reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2) - all_hidden_states += (hidden_states,) + all_hidden_states += (hidden_states,) all_reshape_hidden_states += (reshaped_hidden_state,) for i, layer_module in enumerate(self.layers): @@ -740,7 +816,10 @@ def custom_forward(*inputs): return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) return SwinEncoderOutput( - last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions, reshaped_hidden_states=all_reshape_hidden_states + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + reshaped_hidden_states=all_reshape_hidden_states, ) @@ -899,7 +978,7 @@ def forward( pooler_output=pooled_output, hidden_states=encoder_outputs.hidden_states, attentions=encoder_outputs.attentions, - reshaped_hidden_states=encoder_outputs.reshaped_hidden_states + reshaped_hidden_states=encoder_outputs.reshaped_hidden_states, ) @@ -923,7 +1002,7 @@ def __init__(self, config): self.post_init() @add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC) + @replace_return_docstrings(output_type=SwinMaskedLMOutput, config_class=_CONFIG_FOR_DOC) def forward( self, pixel_values=None, @@ -1001,11 +1080,12 @@ def forward( output = 
(reconstructed_pixel_values,) + outputs[2:] return ((masked_im_loss,) + output) if masked_im_loss is not None else output - return MaskedLMOutput( + return SwinMaskedLMOutput( loss=masked_im_loss, logits=reconstructed_pixel_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + reshaped_hidden_states=outputs.reshaped_hidden_states, ) @@ -1035,7 +1115,7 @@ def __init__(self, config): @add_code_sample_docstrings( processor_class=_FEAT_EXTRACTOR_FOR_DOC, checkpoint=_IMAGE_CLASS_CHECKPOINT, - output_type=SequenceClassifierOutput, + output_type=SwinImageClassifierOutput, config_class=_CONFIG_FOR_DOC, expected_output=_IMAGE_CLASS_EXPECTED_OUTPUT, ) @@ -1095,9 +1175,10 @@ def forward( output = (logits,) + outputs[2:] return ((loss,) + output) if loss is not None else output - return SequenceClassifierOutput( + return SwinImageClassifierOutput( loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions, + reshaped_hidden_states=outputs.reshaped_hidden_states, ) diff --git a/tests/swin/test_modeling_swin.py b/tests/swin/test_modeling_swin.py index db550c8e863e..e5e53cb80601 100644 --- a/tests/swin/test_modeling_swin.py +++ b/tests/swin/test_modeling_swin.py @@ -292,7 +292,7 @@ def test_attention_outputs(self): added_hidden_states = 2 else: # TODO also another +1 for reshaped_hidden_states - added_hidden_states = 1 + added_hidden_states = 2 self.assertEqual(out_len + added_hidden_states, len(outputs)) self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions @@ -309,7 +309,6 @@ def test_attention_outputs(self): [self.model_tester.num_heads[0], window_size_squared, window_size_squared], ) - def test_model_outputs_equivalence(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -408,11 +407,13 @@ def check_hidden_states_output(inputs_dict, config, model_class): [num_patches, self.model_tester.embed_dim], ) - reshaped_hidden_states = outputs.reshaped_hidden_states + reshaped_hidden_states = outputs.reshaped_hidden_states self.assertEqual(len(reshaped_hidden_states), expected_num_layers) batch_size, num_channels, height, width = reshaped_hidden_states[0].shape - reshaped_hidden_states = reshaped_hidden_states[0].view(batch_size, num_channels, height * width).permute(0, 2, 1) + reshaped_hidden_states = ( + reshaped_hidden_states[0].view(batch_size, num_channels, height * width).permute(0, 2, 1) + ) self.assertListEqual( list(reshaped_hidden_states.shape[-2:]), [num_patches, self.model_tester.embed_dim], From e4b4db912d2dc43d3dcb48e9edf46ad18d562b68 Mon Sep 17 00:00:00 2001 From: fra Date: Sat, 12 Mar 2022 13:50:59 +0100 Subject: [PATCH 11/25] conversations --- src/transformers/models/maskformer/modeling_maskformer.py | 2 +- src/transformers/models/swin/modeling_swin.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/maskformer/modeling_maskformer.py b/src/transformers/models/maskformer/modeling_maskformer.py index 346eb36daee4..39af8a27ebc6 100644 --- a/src/transformers/models/maskformer/modeling_maskformer.py +++ b/src/transformers/models/maskformer/modeling_maskformer.py @@ -799,7 +799,7 @@ def prune_heads(self, heads): def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False): self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions) attention_output = self.output(self_outputs[0], hidden_states) - outputs = (attention_output, self_outputs[1]) if output_attentions 
else (attention_output,) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them return outputs diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index e3485118344e..da144be4ef7e 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -536,7 +536,7 @@ def prune_heads(self, heads): def forward(self, hidden_states, attention_mask=None, head_mask=None, output_attentions=False): self_outputs = self.self(hidden_states, attention_mask, head_mask, output_attentions) attention_output = self.output(self_outputs[0], hidden_states) - outputs = (attention_output, self_outputs[1]) if output_attentions else (attention_output,) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them return outputs From 7719aafab496f17873d834b6196218aa92d02bee Mon Sep 17 00:00:00 2001 From: fra Date: Mon, 14 Mar 2022 09:17:01 +0100 Subject: [PATCH 12/25] doc --- docs/source/model_doc/swin.mdx | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/model_doc/swin.mdx b/docs/source/model_doc/swin.mdx index 3a4d69af65a0..528709d1284a 100644 --- a/docs/source/model_doc/swin.mdx +++ b/docs/source/model_doc/swin.mdx @@ -35,6 +35,7 @@ The hierarchical design and the shifted window approach also prove beneficial fo Tips: - One can use the [`AutoFeatureExtractor`] API to prepare images for the model. - Swin pads the inputs supporting any input size (if divisible by `32`). +- Swin can be used as a *backbone*, when when `output_hidden_states = True` it will outputs both `hidden_states` and `reshaped_hidden_states`. `reshaped_hidden_states` have a size of `batch, channels, height, width`. drawing From f0311dea7887dd2b6b6970d8f1b20ddc72d85f9c Mon Sep 17 00:00:00 2001 From: Francesco Saverio Zuppichini Date: Tue, 15 Mar 2022 10:37:54 +0100 Subject: [PATCH 13/25] Update docs/source/model_doc/swin.mdx Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> --- docs/source/model_doc/swin.mdx | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/model_doc/swin.mdx b/docs/source/model_doc/swin.mdx index 528709d1284a..f5b6463033a2 100644 --- a/docs/source/model_doc/swin.mdx +++ b/docs/source/model_doc/swin.mdx @@ -35,7 +35,7 @@ The hierarchical design and the shifted window approach also prove beneficial fo Tips: - One can use the [`AutoFeatureExtractor`] API to prepare images for the model. - Swin pads the inputs supporting any input size (if divisible by `32`). -- Swin can be used as a *backbone*, when when `output_hidden_states = True` it will outputs both `hidden_states` and `reshaped_hidden_states`. `reshaped_hidden_states` have a size of `batch, channels, height, width`. +- Swin can be used as a *backbone*. When `output_hidden_states = True`, it will output both `hidden_states` and `reshaped_hidden_states`. The `reshaped_hidden_states` have a shape of `(batch, channels, height, width)`. 
drawing From b7c5dd23fedf35ce1991cf2f8e7471ec5eae0a94 Mon Sep 17 00:00:00 2001 From: Francesco Saverio Zuppichini Date: Tue, 15 Mar 2022 10:39:17 +0100 Subject: [PATCH 14/25] Apply suggestions from code review Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com> --- src/transformers/models/swin/modeling_swin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index da144be4ef7e..ee00224532b1 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -134,9 +134,9 @@ class SwinModelOutput(ModelOutput): @dataclass -class SwinMaskedLMOutput(ModelOutput): +class SwinMaskedImageModelingOutput(ModelOutput): """ - Swin masked language models outputs. + Swin masked image model outputs. Args: loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): From 906df505bbcce54e3143662b0cbe620db641809a Mon Sep 17 00:00:00 2001 From: fra Date: Tue, 15 Mar 2022 15:04:02 +0100 Subject: [PATCH 15/25] conversations --- docs/source/model_doc/swin.mdx | 4 ++-- src/transformers/models/swin/modeling_swin.py | 11 ++++------- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/docs/source/model_doc/swin.mdx b/docs/source/model_doc/swin.mdx index f5b6463033a2..230fcecd5907 100644 --- a/docs/source/model_doc/swin.mdx +++ b/docs/source/model_doc/swin.mdx @@ -34,8 +34,8 @@ The hierarchical design and the shifted window approach also prove beneficial fo Tips: - One can use the [`AutoFeatureExtractor`] API to prepare images for the model. -- Swin pads the inputs supporting any input size (if divisible by `32`). -- Swin can be used as a *backbone*. When `output_hidden_states = True`, it will output both `hidden_states` and `reshaped_hidden_states`. The `reshaped_hidden_states` have a shape of `(batch, channels, height, width)`. +- Swin pads the inputs supporting any input height and width (if divisible by `32`). +- Swin can be used as a *backbone*. When `output_hidden_states = True`, it will output both `hidden_states` and `reshaped_hidden_states`. The `reshaped_hidden_states` have a shape of `(batch, num_channels, height, width)` rather than `(batch_size, sequence_length, num_channels)`.`. drawing diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index ee00224532b1..bc961a8cbda9 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -103,10 +103,7 @@ class SwinModelOutput(ModelOutput): last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): Sequence of hidden-states at the output of the last layer of the model. pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): - Last layer hidden-state of the first token of the sequence (classification token) after further processing - through the layers used for the auxiliary pretraining task. E.g. for BERT-family of models, this returns - the classification token after processing through a linear layer and a tanh activation function. The linear - layer weights are trained from the next sentence prediction (classification) objective during pretraining. + Average pooling of the last layer hidden-state. 
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. @@ -140,7 +137,7 @@ class SwinMaskedImageModelingOutput(ModelOutput): Args: loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): - Masked language modeling (MLM) loss. + Masked image modeling (MLM) loss. logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): @@ -1002,7 +999,7 @@ def __init__(self, config): self.post_init() @add_start_docstrings_to_model_forward(SWIN_INPUTS_DOCSTRING) - @replace_return_docstrings(output_type=SwinMaskedLMOutput, config_class=_CONFIG_FOR_DOC) + @replace_return_docstrings(output_type=SwinMaskedImageModelingOutput, config_class=_CONFIG_FOR_DOC) def forward( self, pixel_values=None, @@ -1080,7 +1077,7 @@ def forward( output = (reconstructed_pixel_values,) + outputs[2:] return ((masked_im_loss,) + output) if masked_im_loss is not None else output - return SwinMaskedLMOutput( + return SwinMaskedImageModelingOutput( loss=masked_im_loss, logits=reconstructed_pixel_values, hidden_states=outputs.hidden_states, From 72a07985e7abf3e4ed5bb324b20303774606b849 Mon Sep 17 00:00:00 2001 From: fra Date: Tue, 15 Mar 2022 15:19:23 +0100 Subject: [PATCH 16/25] fix tests --- tests/swin/test_modeling_swin.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/swin/test_modeling_swin.py b/tests/swin/test_modeling_swin.py index e5e53cb80601..a9ca5234ed00 100644 --- a/tests/swin/test_modeling_swin.py +++ b/tests/swin/test_modeling_swin.py @@ -250,7 +250,8 @@ def test_attention_outputs(self): with torch.no_grad(): outputs = model(**self._prepare_for_class(inputs_dict, model_class)) attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions - self.assertEqual(len(attentions), len(self.model_tester.depths)) + expected_num_attentions = sum(self.model_tester.depths) + self.assertEqual(len(attentions), expected_num_attentions) # check that output_attentions also work using config del inputs_dict["output_attentions"] @@ -262,7 +263,7 @@ def test_attention_outputs(self): with torch.no_grad(): outputs = model(**self._prepare_for_class(inputs_dict, model_class)) attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions - self.assertEqual(len(attentions), len(self.model_tester.depths)) + self.assertEqual(len(attentions), expected_num_attentions) if chunk_length is not None: self.assertListEqual( @@ -291,13 +292,13 @@ def test_attention_outputs(self): elif self.is_encoder_decoder: added_hidden_states = 2 else: - # TODO also another +1 for reshaped_hidden_states + # also another +1 for reshaped_hidden_states added_hidden_states = 2 self.assertEqual(out_len + added_hidden_states, len(outputs)) self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions - self.assertEqual(len(self_attentions), len(self.model_tester.depths)) + self.assertEqual(len(self_attentions), expected_num_attentions) if chunk_length is not None: self.assertListEqual( 
list(self_attentions[0].shape[-4:]), From 9ff078e8f2676f521d2bfcf145fd08ec8bca46f3 Mon Sep 17 00:00:00 2001 From: fra Date: Tue, 15 Mar 2022 15:21:54 +0100 Subject: [PATCH 17/25] minor changes --- src/transformers/models/swin/modeling_swin.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index bc961a8cbda9..93ff819ab63c 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -136,10 +136,10 @@ class SwinMaskedImageModelingOutput(ModelOutput): Swin masked image model outputs. Args: - loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `bool_masked_pos` is provided): Masked image modeling (MLM) loss. - logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): - Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). + logits (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Reconstructed pixel values. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. From 5a5fc6d24820908d886a9b77970cada358a6be51 Mon Sep 17 00:00:00 2001 From: fra Date: Wed, 16 Mar 2022 13:49:05 +0100 Subject: [PATCH 18/25] resolved conversations --- docs/source/model_doc/swin.mdx | 2 +- src/transformers/models/swin/modeling_swin.py | 24 +++++++++---------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/docs/source/model_doc/swin.mdx b/docs/source/model_doc/swin.mdx index 230fcecd5907..1f247284afbc 100644 --- a/docs/source/model_doc/swin.mdx +++ b/docs/source/model_doc/swin.mdx @@ -35,7 +35,7 @@ The hierarchical design and the shifted window approach also prove beneficial fo Tips: - One can use the [`AutoFeatureExtractor`] API to prepare images for the model. - Swin pads the inputs supporting any input height and width (if divisible by `32`). -- Swin can be used as a *backbone*. When `output_hidden_states = True`, it will output both `hidden_states` and `reshaped_hidden_states`. The `reshaped_hidden_states` have a shape of `(batch, num_channels, height, width)` rather than `(batch_size, sequence_length, num_channels)`.`. +- Swin can be used as a *backbone*. When `output_hidden_states = True`, it will output both `hidden_states` and `reshaped_hidden_states`. The `reshaped_hidden_states` have a shape of `(batch, num_channels, height, width)` rather than `(batch_size, sequence_length, num_channels)`. 
drawing diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index 93ff819ab63c..d1781e57d5db 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -348,7 +348,7 @@ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) self.norm = norm_layer(4 * dim) - def maybe_pad(self, input_feature, width, height): + def maybe_pad(self, input_feature, height, width): should_pad = (height % 2 == 1) or (width % 2 == 1) if should_pad: pad_values = (0, 0, 0, width % 2, 0, height % 2) @@ -564,7 +564,7 @@ def forward(self, hidden_states): return hidden_states -class SwinBlock(nn.Module): +class SwinLayer(nn.Module): def __init__(self, config, dim, input_resolution, num_heads, shift_size=0): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward @@ -677,14 +677,14 @@ def forward(self, hidden_states, input_dimensions, head_mask=None, output_attent return layer_outputs -class SwinLayer(nn.Module): +class SwinStage(nn.Module): def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, downsample): super().__init__() self.config = config self.dim = dim self.blocks = nn.ModuleList( [ - SwinBlock( + SwinLayer( config=config, dim=dim, input_resolution=input_resolution, @@ -705,28 +705,28 @@ def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, d def forward(self, hidden_states, input_dimensions, head_mask=None, output_attentions=False): height, width = input_dimensions - block_attentions = () if output_attentions else None - for i, block_module in enumerate(self.blocks): + layer_attentions = () if output_attentions else None + for i, layer_module in enumerate(self.blocks): layer_head_mask = head_mask[i] if head_mask is not None else None - block_outputs = block_module(hidden_states, input_dimensions, layer_head_mask, output_attentions) + layer_outputs = layer_module(hidden_states, input_dimensions, layer_head_mask, output_attentions) if output_attentions: - block_attentions += block_outputs[1:] + layer_attentions += layer_outputs[1:] - hidden_states = block_outputs[0] + hidden_states = layer_outputs[0] if self.downsample is not None: height_downsampled, width_downsampled = (height + 1) // 2, (width + 1) // 2 output_dimensions = (height, width, height_downsampled, width_downsampled) - hidden_states = self.downsample(block_outputs[0], input_dimensions) + hidden_states = self.downsample(layer_outputs[0], input_dimensions) else: output_dimensions = (height, width, height, width) layer_outputs = (hidden_states, output_dimensions) if output_attentions: - layer_outputs += block_attentions + layer_outputs += layer_attentions return layer_outputs @@ -738,7 +738,7 @@ def __init__(self, config, grid_size): dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, sum(config.depths))] self.layers = nn.ModuleList( [ - SwinLayer( + SwinStage( config=config, dim=int(config.embed_dim * 2**i_layer), input_resolution=(grid_size[0] // (2**i_layer), grid_size[1] // (2**i_layer)), From 1d36bb8e00c96da442ac220aac8e9297c3364f3b Mon Sep 17 00:00:00 2001 From: fra Date: Wed, 16 Mar 2022 14:25:46 +0100 Subject: [PATCH 19/25] attentions one per stage --- src/transformers/models/swin/modeling_swin.py | 49 +++++++++---------- tests/swin/test_modeling_swin.py | 2 +- 2 files changed, 24 insertions(+), 27 deletions(-) diff --git a/src/transformers/models/swin/modeling_swin.py 
b/src/transformers/models/swin/modeling_swin.py index d1781e57d5db..5df7919d5056 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -70,19 +70,19 @@ class SwinEncoderOutput(ModelOutput): last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): Sequence of hidden-states at the output of the last layer of the model. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of - shape `(batch_size, sequence_length, hidden_size)`. + Tuple of `torch.FloatTensor` (one for each(one for the output of the embeddings + one for the output of + each stage) stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs. attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of - shape `(batch_size, hidden_size, height, width)`. + Tuple of `torch.FloatTensor` (one for each(one for the output of the embeddings + one for the output of + each stage) stage) of shape `(batch_size, hidden_size, height, width)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to include the spatial dimensions. @@ -105,19 +105,19 @@ class SwinModelOutput(ModelOutput): pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): Average pooling of the last layer hidden-state. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of - shape `(batch_size, sequence_length, hidden_size)`. + Tuple of `torch.FloatTensor` (one for each(one for the output of the embeddings + one for the output of + each stage) stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs. attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. 
reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of - shape `(batch_size, hidden_size, height, width)`. + Tuple of `torch.FloatTensor` (one for each(one for the output of the embeddings + one for the output of + each stage) stage) of shape `(batch_size, hidden_size, height, width)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to include the spatial dimensions. @@ -141,19 +141,19 @@ class SwinMaskedImageModelingOutput(ModelOutput): logits (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): Reconstructed pixel values. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of - shape `(batch_size, sequence_length, hidden_size)`. + Tuple of `torch.FloatTensor` (one for each(one for the output of the embeddings + one for the output of + each stage) stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs. attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of - shape `(batch_size, hidden_size, height, width)`. + Tuple of `torch.FloatTensor` (one for each(one for the output of the embeddings + one for the output of + each stage) stage) of shape `(batch_size, hidden_size, height, width)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to include the spatial dimensions. @@ -177,19 +177,19 @@ class SwinImageClassifierOutput(ModelOutput): logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): Classification (or regression if config.num_labels==1) scores (before SoftMax). hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of - shape `(batch_size, sequence_length, hidden_size)`. + Tuple of `torch.FloatTensor` (one for each(one for the output of the embeddings + one for the output of + each stage) stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs. 
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): - Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + Tuple of `torch.FloatTensor` (one for each stage) of shape `(batch_size, num_heads, sequence_length, sequence_length)`. Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of - shape `(batch_size, hidden_size, height, width)`. + Tuple of `torch.FloatTensor` (one for each(one for the output of the embeddings + one for the output of + each stage) stage) of shape `(batch_size, hidden_size, height, width)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to include the spatial dimensions. @@ -705,14 +705,11 @@ def __init__(self, config, dim, input_resolution, depth, num_heads, drop_path, d def forward(self, hidden_states, input_dimensions, head_mask=None, output_attentions=False): height, width = input_dimensions - layer_attentions = () if output_attentions else None for i, layer_module in enumerate(self.blocks): layer_head_mask = head_mask[i] if head_mask is not None else None layer_outputs = layer_module(hidden_states, input_dimensions, layer_head_mask, output_attentions) - if output_attentions: - layer_attentions += layer_outputs[1:] hidden_states = layer_outputs[0] @@ -723,11 +720,11 @@ def forward(self, hidden_states, input_dimensions, head_mask=None, output_attent else: output_dimensions = (height, width, height, width) - layer_outputs = (hidden_states, output_dimensions) + stage_outputs = (hidden_states, output_dimensions) if output_attentions: - layer_outputs += layer_attentions - return layer_outputs + stage_outputs += layer_outputs[1:] + return stage_outputs class SwinEncoder(nn.Module): diff --git a/tests/swin/test_modeling_swin.py b/tests/swin/test_modeling_swin.py index a9ca5234ed00..b64532914f45 100644 --- a/tests/swin/test_modeling_swin.py +++ b/tests/swin/test_modeling_swin.py @@ -250,7 +250,7 @@ def test_attention_outputs(self): with torch.no_grad(): outputs = model(**self._prepare_for_class(inputs_dict, model_class)) attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions - expected_num_attentions = sum(self.model_tester.depths) + expected_num_attentions = len(self.model_tester.depths) self.assertEqual(len(attentions), expected_num_attentions) # check that output_attentions also work using config From ce005bebcaa136234551f0757fb1a9bd93417ee8 Mon Sep 17 00:00:00 2001 From: fra Date: Wed, 16 Mar 2022 14:27:29 +0100 Subject: [PATCH 20/25] typo --- src/transformers/models/swin/modeling_swin.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index 5df7919d5056..319248f4495a 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -761,7 +761,7 @@ def forward( ): all_input_dimensions = () all_hidden_states = () if output_hidden_states else None - all_reshape_hidden_states = () if output_hidden_states else None + all_reshaped_hidden_states = () if 
output_hidden_states else None all_self_attentions = () if output_attentions else None if output_hidden_states: @@ -770,7 +770,7 @@ def forward( reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size) reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2) all_hidden_states += (hidden_states,) - all_reshape_hidden_states += (reshaped_hidden_state,) + all_reshaped_hidden_states += (reshaped_hidden_state,) for i, layer_module in enumerate(self.layers): layer_head_mask = head_mask[i] if head_mask is not None else None @@ -801,7 +801,7 @@ def custom_forward(*inputs): reshaped_hidden_state = hidden_states.view(batch_size, *input_dimensions, hidden_size) reshaped_hidden_state = reshaped_hidden_state.permute(0, 3, 1, 2) all_hidden_states += (hidden_states,) - all_reshape_hidden_states += (reshaped_hidden_state,) + all_reshaped_hidden_states += (reshaped_hidden_state,) if output_attentions: all_self_attentions += layer_outputs[2:] @@ -813,7 +813,7 @@ def custom_forward(*inputs): last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_self_attentions, - reshaped_hidden_states=all_reshape_hidden_states, + reshaped_hidden_states=all_reshaped_hidden_states, ) From cfb365032231f0d93f4c5fd607141f7759911824 Mon Sep 17 00:00:00 2001 From: fra Date: Wed, 16 Mar 2022 14:49:38 +0100 Subject: [PATCH 21/25] typos --- src/transformers/models/swin/modeling_swin.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index 319248f4495a..dcbfb6817fe6 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -71,7 +71,7 @@ class SwinEncoderOutput(ModelOutput): Sequence of hidden-states at the output of the last layer of the model. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for each(one for the output of the embeddings + one for the output of - each stage) stage) of shape `(batch_size, sequence_length, hidden_size)`. + each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs. attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): @@ -82,7 +82,7 @@ class SwinEncoderOutput(ModelOutput): heads. reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for each(one for the output of the embeddings + one for the output of - each stage) stage) of shape `(batch_size, hidden_size, height, width)`. + each stage) of shape `(batch_size, hidden_size, height, width)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to include the spatial dimensions. @@ -106,7 +106,7 @@ class SwinModelOutput(ModelOutput): Average pooling of the last layer hidden-state. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for each(one for the output of the embeddings + one for the output of - each stage) stage) of shape `(batch_size, sequence_length, hidden_size)`. 
+ each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs. attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): @@ -117,7 +117,7 @@ class SwinModelOutput(ModelOutput): heads. reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for each(one for the output of the embeddings + one for the output of - each stage) stage) of shape `(batch_size, hidden_size, height, width)`. + each stage) of shape `(batch_size, hidden_size, height, width)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to include the spatial dimensions. @@ -142,7 +142,7 @@ class SwinMaskedImageModelingOutput(ModelOutput): Reconstructed pixel values. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for each(one for the output of the embeddings + one for the output of - each stage) stage) of shape `(batch_size, sequence_length, hidden_size)`. + each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs. attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): @@ -153,7 +153,7 @@ class SwinMaskedImageModelingOutput(ModelOutput): heads. reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for each(one for the output of the embeddings + one for the output of - each stage) stage) of shape `(batch_size, hidden_size, height, width)`. + each stage) of shape `(batch_size, hidden_size, height, width)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to include the spatial dimensions. @@ -178,7 +178,7 @@ class SwinImageClassifierOutput(ModelOutput): Classification (or regression if config.num_labels==1) scores (before SoftMax). hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for each(one for the output of the embeddings + one for the output of - each stage) stage) of shape `(batch_size, sequence_length, hidden_size)`. + each stage) of shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs. attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): @@ -189,7 +189,7 @@ class SwinImageClassifierOutput(ModelOutput): heads. reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): Tuple of `torch.FloatTensor` (one for each(one for the output of the embeddings + one for the output of - each stage) stage) of shape `(batch_size, hidden_size, height, width)`. + each stage) of shape `(batch_size, hidden_size, height, width)`. 
Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to include the spatial dimensions. From 66b23068f2836417614dc849bf29e9e9cfab7b9b Mon Sep 17 00:00:00 2001 From: fra Date: Wed, 16 Mar 2022 15:01:51 +0100 Subject: [PATCH 22/25] typos --- src/transformers/models/swin/modeling_swin.py | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py index dcbfb6817fe6..ef427df3cb4c 100644 --- a/src/transformers/models/swin/modeling_swin.py +++ b/src/transformers/models/swin/modeling_swin.py @@ -70,8 +70,8 @@ class SwinEncoderOutput(ModelOutput): last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): Sequence of hidden-states at the output of the last layer of the model. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for each(one for the output of the embeddings + one for the output of - each stage) of shape `(batch_size, sequence_length, hidden_size)`. + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs. attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): @@ -81,8 +81,8 @@ class SwinEncoderOutput(ModelOutput): Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for each(one for the output of the embeddings + one for the output of - each stage) of shape `(batch_size, hidden_size, height, width)`. + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, hidden_size, height, width)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to include the spatial dimensions. @@ -105,8 +105,8 @@ class SwinModelOutput(ModelOutput): pooler_output (`torch.FloatTensor` of shape `(batch_size, hidden_size)`): Average pooling of the last layer hidden-state. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for each(one for the output of the embeddings + one for the output of - each stage) of shape `(batch_size, sequence_length, hidden_size)`. + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs. attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): @@ -116,8 +116,8 @@ class SwinModelOutput(ModelOutput): Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. 
reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for each(one for the output of the embeddings + one for the output of - each stage) of shape `(batch_size, hidden_size, height, width)`. + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, hidden_size, height, width)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to include the spatial dimensions. @@ -141,8 +141,8 @@ class SwinMaskedImageModelingOutput(ModelOutput): logits (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): Reconstructed pixel values. hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for each(one for the output of the embeddings + one for the output of - each stage) of shape `(batch_size, sequence_length, hidden_size)`. + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs. attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): @@ -152,8 +152,8 @@ class SwinMaskedImageModelingOutput(ModelOutput): Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for each(one for the output of the embeddings + one for the output of - each stage) of shape `(batch_size, hidden_size, height, width)`. + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, hidden_size, height, width)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to include the spatial dimensions. @@ -177,8 +177,8 @@ class SwinImageClassifierOutput(ModelOutput): logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`): Classification (or regression if config.num_labels==1) scores (before SoftMax). hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): - Tuple of `torch.FloatTensor` (one for each(one for the output of the embeddings + one for the output of - each stage) of shape `(batch_size, sequence_length, hidden_size)`. + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states of the model at the output of each layer plus the initial embedding outputs. attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): @@ -188,8 +188,8 @@ class SwinImageClassifierOutput(ModelOutput): Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. 
         reshaped_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
-            Tuple of `torch.FloatTensor` (one for each(one for the output of the embeddings + one for the output of
-            each stage) of shape `(batch_size, hidden_size, height, width)`.
+            Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+            shape `(batch_size, hidden_size, height, width)`.
 
             Hidden-states of the model at the output of each layer plus the initial embedding outputs reshaped to
             include the spatial dimensions.

From 120488aa6123240f6ff43d3e256b51219db884f7 Mon Sep 17 00:00:00 2001
From: fra
Date: Wed, 16 Mar 2022 16:58:43 +0100
Subject: [PATCH 23/25] function signature

---
 src/transformers/models/swin/modeling_swin.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/src/transformers/models/swin/modeling_swin.py b/src/transformers/models/swin/modeling_swin.py
index ef427df3cb4c..45bf23d3cb65 100644
--- a/src/transformers/models/swin/modeling_swin.py
+++ b/src/transformers/models/swin/modeling_swin.py
@@ -585,10 +585,9 @@ def set_shift_and_window_size(self, input_resolution):
             self.shift_size = 0
             self.window_size = min(input_resolution)
 
-    def get_attn_mask(self, input_resolution):
+    def get_attn_mask(self, height, width):
         if self.shift_size > 0:
             # calculate attention mask for SW-MSA
-            height, width = input_resolution
             img_mask = torch.zeros((1, height, width, 1))
             height_slices = (
                 slice(0, -self.window_size),
@@ -642,7 +641,7 @@ def forward(self, hidden_states, input_dimensions, head_mask=None, output_attent
         # partition windows
         hidden_states_windows = window_partition(shifted_hidden_states, self.window_size)
         hidden_states_windows = hidden_states_windows.view(-1, self.window_size * self.window_size, channels)
-        attn_mask = self.get_attn_mask((height_pad, width_pad))
+        attn_mask = self.get_attn_mask(height_pad, width_pad)
         if attn_mask is not None:
             attn_mask = attn_mask.to(hidden_states_windows.device)
 

From 25e5105368ac80f54e770e7dd590dad992c81049 Mon Sep 17 00:00:00 2001
From: fra
Date: Wed, 16 Mar 2022 17:01:07 +0100
Subject: [PATCH 24/25] CI

From 669c1c7f1d97319c137349fd7fd1f51627a694da Mon Sep 17 00:00:00 2001
From: fra
Date: Wed, 16 Mar 2022 18:21:44 +0100
Subject: [PATCH 25/25] clean up tests

---
 tests/swin/test_modeling_swin.py | 117 +++----------------------------
 1 file changed, 11 insertions(+), 106 deletions(-)

diff --git a/tests/swin/test_modeling_swin.py b/tests/swin/test_modeling_swin.py
index b64532914f45..373e812d5fe7 100644
--- a/tests/swin/test_modeling_swin.py
+++ b/tests/swin/test_modeling_swin.py
@@ -17,7 +17,6 @@
 import copy
 import inspect
 import unittest
-from typing import Dict, List, Tuple
 
 from transformers import SwinConfig
 from transformers.file_utils import cached_property, is_torch_available, is_vision_available
@@ -231,15 +230,6 @@ def test_attention_outputs(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         config.return_dict = True
 
-        image_size = to_2tuple(self.model_tester.image_size)
-        patch_size = to_2tuple(self.model_tester.patch_size)
-        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
-        seq_len = num_patches
-        encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
-        chunk_length = getattr(self.model_tester, "chunk_length", None)
-        if chunk_length is not None and hasattr(self.model_tester, "num_hashes"):
-            encoder_seq_length = encoder_seq_length * self.model_tester.num_hashes
-
         for model_class in self.all_model_classes:
             inputs_dict["output_attentions"] = True
             inputs_dict["output_hidden_states"] = False
@@ -249,7 +239,7 @@ def test_attention_outputs(self):
             model.eval()
             with torch.no_grad():
                 outputs = model(**self._prepare_for_class(inputs_dict, model_class))
-            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+            attentions = outputs.attentions
             expected_num_attentions = len(self.model_tester.depths)
             self.assertEqual(len(attentions), expected_num_attentions)
 
@@ -262,20 +252,13 @@ def test_attention_outputs(self):
             model.eval()
             with torch.no_grad():
                 outputs = model(**self._prepare_for_class(inputs_dict, model_class))
-            attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+            attentions = outputs.attentions
             self.assertEqual(len(attentions), expected_num_attentions)
 
-            if chunk_length is not None:
-                self.assertListEqual(
-                    list(attentions[0].shape[-4:]),
-                    [self.model_tester.num_heads[0], window_size_squared, chunk_length, window_size_squared],
-                )
-            else:
-                # attentions is a tuple of tuple, since we have one attention per layer
-                self.assertListEqual(
-                    list(attentions[0].shape[-3:]),
-                    [self.model_tester.num_heads[0], window_size_squared, window_size_squared],
-                )
+            self.assertListEqual(
+                list(attentions[0].shape[-3:]),
+                [self.model_tester.num_heads[0], window_size_squared, window_size_squared],
+            )
             out_len = len(outputs)
 
             # Check attention is always last and order is fine
@@ -289,96 +272,18 @@ def test_attention_outputs(self):
 
             if hasattr(self.model_tester, "num_hidden_states_types"):
                 added_hidden_states = self.model_tester.num_hidden_states_types
-            elif self.is_encoder_decoder:
-                added_hidden_states = 2
             else:
                 # also another +1 for reshaped_hidden_states
                 added_hidden_states = 2
             self.assertEqual(out_len + added_hidden_states, len(outputs))
 
-            self_attentions = outputs.encoder_attentions if config.is_encoder_decoder else outputs.attentions
+            self_attentions = outputs.attentions
 
             self.assertEqual(len(self_attentions), expected_num_attentions)
 
-            if chunk_length is not None:
-                self.assertListEqual(
-                    list(self_attentions[0].shape[-4:]),
-                    [self.model_tester.num_heads[0], window_size_squared, chunk_length, window_size_squared],
-                )
-            else:
-                self.assertListEqual(
-                    list(self_attentions[0].shape[-3:]),
-                    [self.model_tester.num_heads[0], window_size_squared, window_size_squared],
-                )
-    def test_model_outputs_equivalence(self):
-        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-
-        def set_nan_tensor_to_zero(t):
-            t[t != t] = 0
-            return t
-
-        def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
-            with torch.no_grad():
-                tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
-                dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
-
-                def recursive_check(tuple_object, dict_object):
-                    # for spatial dimensions and exit condition for recursion
-                    if type(tuple_object) is int:
-                        self.assertEqual(tuple_object, dict_object)
-                    elif isinstance(tuple_object, (List, Tuple)):
-                        for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
-                            recursive_check(tuple_iterable_value, dict_iterable_value)
-                    elif isinstance(tuple_object, Dict):
-                        for tuple_iterable_value, dict_iterable_value in zip(
-                            tuple_object.values(), dict_object.values()
-                        ):
-                            recursive_check(tuple_iterable_value, dict_iterable_value)
-                    elif tuple_object is None:
-                        return
-                    else:
-                        self.assertTrue(
-                            torch.allclose(
-                                set_nan_tensor_to_zero(tuple_object), set_nan_tensor_to_zero(dict_object), atol=1e-5
-                            ),
-                            msg=f"Tuple and dict output are not equal. Difference: {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`: {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object)}. Dict has `nan`: {torch.isnan(dict_object).any()} and `inf`: {torch.isinf(dict_object)}.",
-                        )
-
-                recursive_check(tuple_output, dict_output)
-
-        for model_class in self.all_model_classes:
-            model = model_class(config)
-            model.to(torch_device)
-            model.eval()
-
-            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
-            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
-            check_equivalence(model, tuple_inputs, dict_inputs)
-
-            tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
-            dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
-            check_equivalence(model, tuple_inputs, dict_inputs)
-
-            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
-            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
-            check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
-
-            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
-            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
-            check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})
-
-            tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
-            dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
-            check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
-
-            tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
-            dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
-            check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})
-
-            tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
-            dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
-            check_equivalence(
-                model, tuple_inputs, dict_inputs, {"output_hidden_states": True, "output_attentions": True}
+            self.assertListEqual(
+                list(self_attentions[0].shape[-3:]),
+                [self.model_tester.num_heads[0], window_size_squared, window_size_squared],
             )
 
     def test_hidden_states_output(self):
@@ -390,7 +295,7 @@ def check_hidden_states_output(inputs_dict, config, model_class):
 
             with torch.no_grad():
                 outputs = model(**self._prepare_for_class(inputs_dict, model_class))
-            hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states
+            hidden_states = outputs.hidden_states
 
             expected_num_layers = getattr(
                 self.model_tester, "expected_num_hidden_layers", len(self.model_tester.depths) + 1