diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 2821b60728a0..6585e894a701 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -206,6 +206,7 @@ def is_local_dist_rank_0(): "qwen2_5_vl", "videollava", "vipllava", + "detr", ] diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index 89441a8b1246..1303d50b8214 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -16,24 +16,32 @@ import math from dataclasses import dataclass -from typing import Optional, Union +from typing import Callable, Optional, Union import torch -from torch import Tensor, nn +import torch.nn as nn from ...activations import ACT2FN -from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...masking_utils import create_bidirectional_mask from ...modeling_layers import GradientCheckpointingLayer -from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions, Seq2SeqModelOutput -from ...modeling_utils import PreTrainedModel +from ...modeling_outputs import ( + BaseModelOutput, + BaseModelOutputWithCrossAttentions, + Seq2SeqModelOutput, +) +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...pytorch_utils import compile_compatible_method_lru_cache from ...utils import ( ModelOutput, + TransformersKwargs, auto_docstring, is_timm_available, logging, requires_backends, ) from ...utils.backbone_utils import load_backbone +from ...utils.generic import OutputRecorder, can_return_tuple, check_model_inputs from .configuration_detr import DetrConfig @@ -316,55 +324,42 @@ def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor): return out -class DetrConvModel(nn.Module): - """ - This module adds 2D position embeddings to all intermediate feature maps of the convolutional encoder. - """ - - def __init__(self, conv_encoder, position_embedding): - super().__init__() - self.conv_encoder = conv_encoder - self.position_embedding = position_embedding - - def forward(self, pixel_values, pixel_mask): - # send pixel_values and pixel_mask through backbone to get list of (feature_map, pixel_mask) tuples - out = self.conv_encoder(pixel_values, pixel_mask) - pos = [] - for feature_map, mask in out: - # position encoding - pos.append(self.position_embedding(feature_map, mask).to(feature_map.dtype)) - - return out, pos - - class DetrSinePositionEmbedding(nn.Module): """ This is a more standard version of the position embedding, very similar to the one used by the Attention is all you need paper, generalized to work on images. 
""" - def __init__(self, embedding_dim=64, temperature=10000, normalize=False, scale=None): + def __init__( + self, num_pos_feats: int = 64, temperature: int = 10000, normalize: bool = False, scale: Optional[float] = None + ): super().__init__() - self.embedding_dim = embedding_dim - self.temperature = temperature - self.normalize = normalize if scale is not None and normalize is False: raise ValueError("normalize should be True if scale is passed") - if scale is None: - scale = 2 * math.pi - self.scale = scale + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + self.scale = 2 * math.pi if scale is None else scale - def forward(self, pixel_values, pixel_mask): - if pixel_mask is None: - raise ValueError("No pixel mask provided") - y_embed = pixel_mask.cumsum(1, dtype=torch.float32) - x_embed = pixel_mask.cumsum(2, dtype=torch.float32) + @compile_compatible_method_lru_cache(maxsize=1) + def forward( + self, + shape: torch.Size, + device: Union[torch.device, str], + dtype: torch.dtype, + mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if mask is None: + mask = torch.zeros((shape[0], shape[2], shape[3]), device=device, dtype=torch.bool) + y_embed = mask.cumsum(1, dtype=dtype) + x_embed = mask.cumsum(2, dtype=dtype) if self.normalize: - y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * self.scale - x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * self.scale + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale - dim_t = torch.arange(self.embedding_dim, dtype=torch.int64, device=pixel_values.device).float() - dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.embedding_dim) + dim_t = torch.arange(self.num_pos_feats, dtype=torch.int64, device=device).to(dtype) + dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats) pos_x = x_embed[:, :, :, None] / dim_t pos_y = y_embed[:, :, :, None] / dim_t @@ -384,207 +379,253 @@ def __init__(self, embedding_dim=256): self.row_embeddings = nn.Embedding(50, embedding_dim) self.column_embeddings = nn.Embedding(50, embedding_dim) - def forward(self, pixel_values, pixel_mask=None): - height, width = pixel_values.shape[-2:] - width_values = torch.arange(width, device=pixel_values.device) - height_values = torch.arange(height, device=pixel_values.device) + @compile_compatible_method_lru_cache(maxsize=1) + def forward( + self, + shape: torch.Size, + device: Union[torch.device, str], + dtype: torch.dtype, + mask: Optional[torch.Tensor] = None, + ): + height, width = shape[-2:] + width_values = torch.arange(width, device=device) + height_values = torch.arange(height, device=device) x_emb = self.column_embeddings(width_values) y_emb = self.row_embeddings(height_values) pos = torch.cat([x_emb.unsqueeze(0).repeat(height, 1, 1), y_emb.unsqueeze(1).repeat(1, width, 1)], dim=-1) pos = pos.permute(2, 0, 1) pos = pos.unsqueeze(0) - pos = pos.repeat(pixel_values.shape[0], 1, 1, 1) + pos = pos.repeat(shape[0], 1, 1, 1) return pos -def build_position_encoding(config): - n_steps = config.d_model // 2 - if config.position_embedding_type == "sine": - # TODO find a better way of exposing other arguments - position_embedding = DetrSinePositionEmbedding(n_steps, normalize=True) - elif config.position_embedding_type == "learned": - position_embedding = DetrLearnedPositionEmbedding(n_steps) - else: - raise ValueError(f"Not supported {config.position_embedding_type}") 
+def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: Optional[float] = None, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + if scaling is None: + scaling = query.size(-1) ** -0.5 + + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + + if attention_mask is not None: + attention_mask = attention_mask[:, :, :, : key.shape[-2]] + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) - return position_embedding + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights -class DetrAttention(nn.Module): + +class DetrSelfAttention(nn.Module): """ - Multi-headed attention from 'Attention Is All You Need' paper. + Multi-headed self-attention from 'Attention Is All You Need' paper. - Here, we add position embeddings to the queries and keys (as explained in the DETR paper). + In DETR, position embeddings are added to both queries and keys (but not values) in self-attention. """ def __init__( self, - embed_dim: int, - num_heads: int, + config: DetrConfig, + hidden_size: int, + num_attention_heads: int, dropout: float = 0.0, - bias: bool = True, ): super().__init__() - self.embed_dim = embed_dim - self.num_heads = num_heads - self.dropout = dropout - self.head_dim = embed_dim // num_heads - if self.head_dim * num_heads != self.embed_dim: - raise ValueError( - f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" - f" {num_heads})." - ) + self.config = config + self.head_dim = hidden_size // num_attention_heads self.scaling = self.head_dim**-0.5 + self.attention_dropout = dropout + self.is_causal = False - self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - - def _shape(self, tensor: torch.Tensor, seq_len: int, batch_size: int): - return tensor.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - - def with_pos_embed(self, tensor: torch.Tensor, object_queries: Optional[Tensor]): - return tensor if object_queries is None else tensor + object_queries + self.k_proj = nn.Linear(hidden_size, hidden_size) + self.v_proj = nn.Linear(hidden_size, hidden_size) + self.q_proj = nn.Linear(hidden_size, hidden_size) + self.o_proj = nn.Linear(hidden_size, hidden_size) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - object_queries: Optional[torch.Tensor] = None, - key_value_states: Optional[torch.Tensor] = None, - spatial_position_embeddings: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - batch_size, target_len, embed_dim = hidden_states.size() - - # add position embeddings to the hidden states before projecting to queries and keys - if object_queries is not None: - hidden_states_original = hidden_states - hidden_states = 
self.with_pos_embed(hidden_states, object_queries) - - # add key-value position embeddings to the key value states - if spatial_position_embeddings is not None: - key_value_states_original = key_value_states - key_value_states = self.with_pos_embed(key_value_states, spatial_position_embeddings) - - # get query proj - query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - if is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, batch_size) - value_states = self._shape(self.v_proj(key_value_states_original), -1, batch_size) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, batch_size) - value_states = self._shape(self.v_proj(hidden_states_original), -1, batch_size) + position_embeddings: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Position embeddings are added to both queries and keys (but not values). + """ + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_key_input = hidden_states + position_embeddings if position_embeddings is not None else hidden_states + + query_states = self.q_proj(query_key_input).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(query_key_input).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) - proj_shape = (batch_size * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, target_len, batch_size).view(*proj_shape) - key_states = key_states.view(*proj_shape) - value_states = value_states.view(*proj_shape) + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights - source_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) +class DetrCrossAttention(nn.Module): + """ + Multi-headed cross-attention from 'Attention Is All You Need' paper. - if attn_weights.size() != (batch_size * self.num_heads, target_len, source_len): - raise ValueError( - f"Attention weights should be of size {(batch_size * self.num_heads, target_len, source_len)}, but is" - f" {attn_weights.size()}" - ) + In DETR, queries get their own position embeddings, while keys get encoder position embeddings. + Values don't get any position embeddings. 
+ """ - if attention_mask is not None: - if attention_mask.size() != (batch_size, 1, target_len, source_len): - raise ValueError( - f"Attention mask should be of size {(batch_size, 1, target_len, source_len)}, but is" - f" {attention_mask.size()}" - ) - if attention_mask.dtype == torch.bool: - attention_mask = torch.zeros_like(attention_mask, dtype=attn_weights.dtype).masked_fill_( - attention_mask, -torch.inf - ) - attn_weights = attn_weights.view(batch_size, self.num_heads, target_len, source_len) + attention_mask - attn_weights = attn_weights.view(batch_size * self.num_heads, target_len, source_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(batch_size, self.num_heads, target_len, source_len) - attn_weights = attn_weights_reshaped.view(batch_size * self.num_heads, target_len, source_len) - else: - attn_weights_reshaped = None + def __init__( + self, + config: DetrConfig, + hidden_size: int, + num_attention_heads: int, + dropout: float = 0.0, + ): + super().__init__() + self.config = config + self.head_dim = hidden_size // num_attention_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = dropout + self.is_causal = False - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + self.k_proj = nn.Linear(hidden_size, hidden_size) + self.v_proj = nn.Linear(hidden_size, hidden_size) + self.q_proj = nn.Linear(hidden_size, hidden_size) + self.o_proj = nn.Linear(hidden_size, hidden_size) + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_embeddings: Optional[torch.Tensor] = None, + encoder_position_embeddings: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + """ + Position embeddings logic: + - Queries get position_embeddings + - Keys get encoder_position_embeddings + - Values don't get any position embeddings + """ + query_input_shape = hidden_states.shape[:-1] + query_hidden_shape = (*query_input_shape, -1, self.head_dim) - attn_output = torch.bmm(attn_probs, value_states) + kv_input_shape = key_value_states.shape[:-1] + kv_hidden_shape = (*kv_input_shape, -1, self.head_dim) - if attn_output.size() != (batch_size * self.num_heads, target_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(batch_size, self.num_heads, target_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) + query_input = hidden_states + position_embeddings if position_embeddings is not None else hidden_states + key_input = ( + key_value_states + encoder_position_embeddings + if encoder_position_embeddings is not None + else key_value_states + ) + + query_states = self.q_proj(query_input).view(query_hidden_shape).transpose(1, 2) + key_states = self.k_proj(key_input).view(kv_hidden_shape).transpose(1, 2) + value_states = self.v_proj(key_value_states).view(kv_hidden_shape).transpose(1, 2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + 
attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + **kwargs, + ) - attn_output = attn_output.view(batch_size, self.num_heads, target_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - attn_output = attn_output.reshape(batch_size, target_len, embed_dim) + attn_output = attn_output.reshape(*query_input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights - attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped +class DetrMLP(nn.Module): + def __init__(self, config: DetrConfig, hidden_size: int, intermediate_size: int): + super().__init__() + self.fc1 = nn.Linear(hidden_size, intermediate_size) + self.fc2 = nn.Linear(intermediate_size, hidden_size) + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.dropout = config.dropout + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + return hidden_states class DetrEncoderLayer(nn.Module): def __init__(self, config: DetrConfig): super().__init__() - self.embed_dim = config.d_model - self.self_attn = DetrAttention( - embed_dim=self.embed_dim, - num_heads=config.encoder_attention_heads, + self.hidden_size = config.d_model + self.self_attn = DetrSelfAttention( + config=config, + hidden_size=self.hidden_size, + num_attention_heads=config.encoder_attention_heads, dropout=config.attention_dropout, ) - self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size) self.dropout = config.dropout - self.activation_fn = ACT2FN[config.activation_function] - self.activation_dropout = config.activation_dropout - self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) - self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) - self.final_layer_norm = nn.LayerNorm(self.embed_dim) + self.mlp = DetrMLP(config, self.hidden_size, config.encoder_ffn_dim) + self.final_layer_norm = nn.LayerNorm(self.hidden_size) def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, - object_queries: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ): + spatial_position_embeddings: Optional[torch.Tensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: """ Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)` attention_mask (`torch.FloatTensor`): attention mask of size `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative values. - object_queries (`torch.FloatTensor`, *optional*): - Object queries (also called content embeddings), to be added to the hidden states. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. 
+ spatial_position_embeddings (`torch.FloatTensor`, *optional*): + Spatial position embeddings (2D positional encodings of image locations), to be added to both + the queries and keys in self-attention (but not to values). """ residual = hidden_states - hidden_states, attn_weights = self.self_attn( + hidden_states, _ = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, - object_queries=object_queries, - output_attentions=output_attentions, + position_embeddings=spatial_position_embeddings, + **kwargs, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) @@ -592,12 +633,7 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) residual = hidden_states - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) - - hidden_states = self.fc2(hidden_states) - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - + hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states hidden_states = self.final_layer_norm(hidden_states) @@ -606,78 +642,69 @@ def forward( clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs + return hidden_states class DetrDecoderLayer(GradientCheckpointingLayer): def __init__(self, config: DetrConfig): super().__init__() - self.embed_dim = config.d_model + self.hidden_size = config.d_model - self.self_attn = DetrAttention( - embed_dim=self.embed_dim, - num_heads=config.decoder_attention_heads, + self.self_attn = DetrSelfAttention( + config=config, + hidden_size=self.hidden_size, + num_attention_heads=config.decoder_attention_heads, dropout=config.attention_dropout, ) self.dropout = config.dropout - self.activation_fn = ACT2FN[config.activation_function] - self.activation_dropout = config.activation_dropout - self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = DetrAttention( - self.embed_dim, - config.decoder_attention_heads, + self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size) + self.encoder_attn = DetrCrossAttention( + config=config, + hidden_size=self.hidden_size, + num_attention_heads=config.decoder_attention_heads, dropout=config.attention_dropout, ) - self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) - self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) - self.final_layer_norm = nn.LayerNorm(self.embed_dim) + self.encoder_attn_layer_norm = nn.LayerNorm(self.hidden_size) + self.mlp = DetrMLP(config, self.hidden_size, config.decoder_ffn_dim) + self.final_layer_norm = nn.LayerNorm(self.hidden_size) def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - object_queries: Optional[torch.Tensor] = None, - query_position_embeddings: Optional[torch.Tensor] = None, + spatial_position_embeddings: Optional[torch.Tensor] = None, + object_queries_position_embeddings: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = False, - ): + **kwargs: Unpack[TransformersKwargs], + ) -> torch.Tensor: """ Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape 
`(batch, seq_len, embed_dim)` + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)` attention_mask (`torch.FloatTensor`): attention mask of size `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative values. - object_queries (`torch.FloatTensor`, *optional*): - object_queries that are added to the hidden states - in the cross-attention layer. - query_position_embeddings (`torch.FloatTensor`, *optional*): - position embeddings that are added to the queries and keys - in the self-attention layer. + spatial_position_embeddings (`torch.FloatTensor`, *optional*): + Spatial position embeddings (2D positional encodings from encoder) that are added to the keys only + in the cross-attention layer (not to values). + object_queries_position_embeddings (`torch.FloatTensor`, *optional*): + Position embeddings for the object query slots. In self-attention, these are added to both queries + and keys (not values). In cross-attention, these are added to queries only (not to keys or values). encoder_hidden_states (`torch.FloatTensor`): - cross attention input to the layer of shape `(batch, seq_len, embed_dim)` + cross attention input to the layer of shape `(batch, seq_len, hidden_size)` encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative values. - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. """ residual = hidden_states # Self Attention - hidden_states, self_attn_weights = self.self_attn( + hidden_states, _ = self.self_attn( hidden_states=hidden_states, - object_queries=query_position_embeddings, + position_embeddings=object_queries_position_embeddings, attention_mask=attention_mask, - output_attentions=output_attentions, + **kwargs, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) @@ -685,17 +712,16 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) # Cross-Attention Block - cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states - hidden_states, cross_attn_weights = self.encoder_attn( + hidden_states, _ = self.encoder_attn( hidden_states=hidden_states, - object_queries=query_position_embeddings, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, - spatial_position_embeddings=object_queries, - output_attentions=output_attentions, + position_embeddings=object_queries_position_embeddings, + encoder_position_embeddings=spatial_position_embeddings, + **kwargs, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) @@ -704,19 +730,11 @@ def forward( # Fully Connected residual = hidden_states - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) - hidden_states = self.fc2(hidden_states) - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states hidden_states = self.final_layer_norm(hidden_states) - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights, cross_attn_weights) - - return outputs + return hidden_states @auto_docstring @@ -725,16 +743,28 @@ class 
DetrPreTrainedModel(PreTrainedModel): base_model_prefix = "model" main_input_name = "pixel_values" _no_split_modules = [r"DetrConvEncoder", r"DetrEncoderLayer", r"DetrDecoderLayer"] + _supports_sdpa = True + _supports_flash_attn = True + _supports_attention_backend = True + _supports_flex_attn = True + _checkpoint_conversion_mapping = { + "model.backbone.conv_encoder": "model.backbone", + "out_proj": "o_proj", + "bbox_attention.q_linear": "bbox_attention.q_proj", + "bbox_attention.k_linear": "bbox_attention.k_proj", + r"(\d+)\.fc1": r"\1.mlp.fc1", + r"(\d+)\.fc2": r"\1.mlp.fc2", + } def _init_weights(self, module): std = self.config.init_std xavier_std = self.config.init_xavier_std if isinstance(module, DetrMHAttentionMap): - nn.init.zeros_(module.k_linear.bias) - nn.init.zeros_(module.q_linear.bias) - nn.init.xavier_uniform_(module.k_linear.weight, gain=xavier_std) - nn.init.xavier_uniform_(module.q_linear.weight, gain=xavier_std) + nn.init.zeros_(module.k_proj.bias) + nn.init.zeros_(module.q_proj.bias) + nn.init.xavier_uniform_(module.k_proj.weight, gain=xavier_std) + nn.init.xavier_uniform_(module.q_proj.weight, gain=xavier_std) elif isinstance(module, DetrLearnedPositionEmbedding): nn.init.uniform_(module.row_embeddings.weight) nn.init.uniform_(module.column_embeddings.weight) @@ -750,46 +780,40 @@ def _init_weights(self, module): class DetrEncoder(DetrPreTrainedModel): """ - Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a - [`DetrEncoderLayer`]. - - The encoder updates the flattened feature map through multiple self-attention layers. - - Small tweak for DETR: - - - object_queries are added to the forward pass. + Transformer encoder that processes a flattened feature map from a vision backbone, composed of a stack of + [`DetrEncoderLayer`] modules. Args: - config: DetrConfig + config (`DetrConfig`): Model configuration object. """ + _can_record_outputs = { + "hidden_states": DetrEncoderLayer, + "attentions": OutputRecorder(DetrSelfAttention, layer_name="self_attn", index=1), + } + def __init__(self, config: DetrConfig): super().__init__(config) self.dropout = config.dropout self.layerdrop = config.encoder_layerdrop - self.layers = nn.ModuleList([DetrEncoderLayer(config) for _ in range(config.encoder_layers)]) - # in the original DETR, no layernorm is used at the end of the encoder, as "normalize_before" is set to False by default - # Initialize weights and apply final processing self.post_init() + @check_model_inputs() def forward( self, inputs_embeds=None, attention_mask=None, - object_queries=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): + spatial_position_embeddings=None, + **kwargs: Unpack[TransformersKwargs], + ) -> BaseModelOutput: r""" Args: inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): Flattened feature map (output of the backbone + projection layer) that is passed to the encoder. - attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding pixel features. Mask values selected in `[0, 1]`: @@ -797,38 +821,22 @@ def forward( - 0 for pixel features that are padding (i.e. **masked**). [What are attention masks?](../glossary#attention-mask) - - object_queries (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): - Object queries that are added to the queries in each self-attention layer. 
- - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + spatial_position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Spatial position embeddings (2D positional encodings) that are added to the queries and keys in each self-attention layer. """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - hidden_states = inputs_embeds hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) # expand attention_mask if attention_mask is not None: # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + attention_mask = create_bidirectional_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, + ) - encoder_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None - for i, encoder_layer in enumerate(self.layers): - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) + for encoder_layer in self.layers: # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) to_drop = False if self.training: @@ -837,46 +845,31 @@ def forward( to_drop = True if to_drop: - layer_outputs = (None, None) + hidden_states = None else: - # we add object_queries as extra input to the encoder_layer - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - object_queries=object_queries, - output_attentions=output_attentions, + # we add spatial_position_embeddings as extra input to the encoder_layer + hidden_states = encoder_layer( + hidden_states, attention_mask, spatial_position_embeddings=spatial_position_embeddings, **kwargs ) - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) - return BaseModelOutput( - last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions - ) + return BaseModelOutput(last_hidden_state=hidden_states) class DetrDecoder(DetrPreTrainedModel): """ - Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`DetrDecoderLayer`]. - - The decoder updates the query embeddings through multiple self-attention and cross-attention layers. - - Some small tweaks for DETR: - - - object_queries and query_position_embeddings are added to the forward pass. - - if self.config.auxiliary_loss is set to True, also returns a stack of activations from all decoding layers. + Transformer decoder that refines a set of object queries. 
It is composed of a stack of [`DetrDecoderLayer`] modules, + which apply self-attention to the queries and cross-attention to the encoder's outputs. Args: - config: DetrConfig + config (`DetrConfig`): Model configuration object. """ + _can_record_outputs = { + "hidden_states": DetrDecoderLayer, + "attentions": OutputRecorder(DetrSelfAttention, layer_name="self_attn", index=1), + "cross_attentions": OutputRecorder(DetrCrossAttention, layer_name="encoder_attn", index=1), + } + def __init__(self, config: DetrConfig): super().__init__(config) self.dropout = config.dropout @@ -890,18 +883,17 @@ def __init__(self, config: DetrConfig): # Initialize weights and apply final processing self.post_init() + @check_model_inputs() def forward( self, inputs_embeds=None, attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, - object_queries=None, - query_position_embeddings=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, - ): + spatial_position_embeddings=None, + object_queries_position_embeddings=None, + **kwargs: Unpack[TransformersKwargs], + ) -> DetrDecoderOutput: r""" Args: inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): @@ -924,108 +916,68 @@ def forward( - 1 for pixels that are real (i.e. **not masked**), - 0 for pixels that are padding (i.e. **masked**). - object_queries (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Object queries that are added to the queries and keys in each cross-attention layer. - query_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`): - , *optional*): Position embeddings that are added to the values and keys in each self-attention layer. - - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - output_hidden_states (`bool`, *optional*): - Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors - for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + spatial_position_embeddings (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Spatial position embeddings (2D positional encodings from encoder) that are added to the keys in each cross-attention layer. + query_position_embeddings (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*): + Position embeddings for the object query slots that are added to the queries and keys in each self-attention layer. 
""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict if inputs_embeds is not None: hidden_states = inputs_embeds - input_shape = inputs_embeds.size()[:-1] - combined_attention_mask = None - - if attention_mask is not None and combined_attention_mask is not None: - # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len] - combined_attention_mask = combined_attention_mask + _prepare_4d_attention_mask( - attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + # expand decoder attention mask (for self-attention on object queries) + if attention_mask is not None: + # [batch_size, num_queries] -> [batch_size, 1, num_queries, num_queries] + attention_mask = create_bidirectional_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=attention_mask, ) - # expand encoder attention mask + # expand encoder attention mask (for cross-attention on encoder outputs) if encoder_hidden_states is not None and encoder_attention_mask is not None: # [batch_size, seq_len] -> [batch_size, 1, target_seq_len, source_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + encoder_attention_mask = create_bidirectional_mask( + config=self.config, + input_embeds=inputs_embeds, + attention_mask=encoder_attention_mask, + encoder_hidden_states=encoder_hidden_states, ) # optional intermediate hidden states intermediate = () if self.config.auxiliary_loss else None # decoder layers - all_hidden_states = () if output_hidden_states else None - all_self_attns = () if output_attentions else None - all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://huggingface.co/papers/1909.11556 for description) - if output_hidden_states: - all_hidden_states += (hidden_states,) if self.training: dropout_probability = torch.rand([]) if dropout_probability < self.layerdrop: continue - layer_outputs = decoder_layer( + hidden_states = decoder_layer( hidden_states, - combined_attention_mask, - object_queries, - query_position_embeddings, + attention_mask, + spatial_position_embeddings, + object_queries_position_embeddings, encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, - output_attentions=output_attentions, + **kwargs, ) - hidden_states = layer_outputs[0] - if self.config.auxiliary_loss: hidden_states = self.layernorm(hidden_states) intermediate += (hidden_states,) - if output_attentions: - all_self_attns += (layer_outputs[1],) - - if encoder_hidden_states is not None: - all_cross_attentions += (layer_outputs[2],) - # finally, apply layernorm hidden_states = self.layernorm(hidden_states) - # add hidden states from the last decoder layer - if output_hidden_states: - all_hidden_states += (hidden_states,) - # stack intermediate decoder activations if self.config.auxiliary_loss: intermediate = torch.stack(intermediate) - if not return_dict: - return tuple( - v - for v in [hidden_states, all_hidden_states, all_self_attns, all_cross_attentions, intermediate] - if v is not None - ) - return DetrDecoderOutput( - last_hidden_state=hidden_states, - 
hidden_states=all_hidden_states, - attentions=all_self_attns, - cross_attentions=all_cross_attentions, - intermediate_hidden_states=intermediate, - ) + return DetrDecoderOutput(last_hidden_state=hidden_states, intermediate_hidden_states=intermediate) @auto_docstring( @@ -1038,15 +990,16 @@ class DetrModel(DetrPreTrainedModel): def __init__(self, config: DetrConfig): super().__init__(config) - # Create backbone + positional encoding - backbone = DetrConvEncoder(config) - object_queries = build_position_encoding(config) - self.backbone = DetrConvModel(backbone, object_queries) - - # Create projection layer - self.input_projection = nn.Conv2d(backbone.intermediate_channel_sizes[-1], config.d_model, kernel_size=1) + self.backbone = DetrConvEncoder(config) + if config.position_embedding_type == "sine": + self.position_embedding = DetrSinePositionEmbedding(config.d_model // 2, normalize=True) + elif config.position_embedding_type == "learned": + self.position_embedding = DetrLearnedPositionEmbedding(config.d_model // 2) + else: + raise ValueError(f"Not supported {config.position_embedding_type}") self.query_position_embeddings = nn.Embedding(config.num_queries, config.d_model) + self.input_projection = nn.Conv2d(self.backbone.intermediate_channel_sizes[-1], config.d_model, kernel_size=1) self.encoder = DetrEncoder(config) self.decoder = DetrDecoder(config) @@ -1058,35 +1011,37 @@ def get_encoder(self): return self.encoder def freeze_backbone(self): - for name, param in self.backbone.conv_encoder.model.named_parameters(): + for _, param in self.backbone.model.named_parameters(): param.requires_grad_(False) def unfreeze_backbone(self): - for name, param in self.backbone.conv_encoder.model.named_parameters(): + for _, param in self.backbone.model.named_parameters(): param.requires_grad_(True) @auto_docstring + @can_return_tuple def forward( self, - pixel_values: torch.FloatTensor, + pixel_values: Optional[torch.FloatTensor] = None, pixel_mask: Optional[torch.LongTensor] = None, decoder_attention_mask: Optional[torch.FloatTensor] = None, encoder_outputs: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, decoder_inputs_embeds: Optional[torch.FloatTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + **kwargs: Unpack[TransformersKwargs], ) -> Union[tuple[torch.FloatTensor], DetrModelOutput]: r""" decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*): - Not used by default. Can be used to mask object queries. + Mask to avoid performing attention on certain object queries in the decoder. Mask values selected in `[0, 1]`: + + - 1 for queries that are **not masked**, + - 0 for queries that are **masked**. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you - can choose to directly pass a flattened representation of an image. + can choose to directly pass a flattened representation of an image. Useful for bypassing the vision backbone. decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*): Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an - embedded representation. + embedded representation. Useful for tasks that require custom query initialization. 
Examples: @@ -1113,79 +1068,83 @@ def forward( >>> list(last_hidden_states.shape) [1, 100, 256] ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - batch_size, num_channels, height, width = pixel_values.shape - device = pixel_values.device - - if pixel_mask is None: - pixel_mask = torch.ones(((batch_size, height, width)), device=device) - - # First, sent pixel_values + pixel_mask through Backbone to obtain the features - # pixel_values should be of shape (batch_size, num_channels, height, width) - # pixel_mask should be of shape (batch_size, height, width) - features, object_queries_list = self.backbone(pixel_values, pixel_mask) - - # get final feature map and downsampled mask - feature_map, mask = features[-1] - - if mask is None: - raise ValueError("Backbone does not return downsampled pixel mask") - - # Second, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default) - projected_feature_map = self.input_projection(feature_map) - - # Third, flatten the feature map + position embeddings of shape NxCxHxW to NxCxHW, and permute it to NxHWxC - # In other words, turn their shape into (batch_size, sequence_length, hidden_size) - flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1) - object_queries = object_queries_list[-1].flatten(2).permute(0, 2, 1) - - flattened_mask = mask.flatten(1) + if pixel_values is None and inputs_embeds is None: + raise ValueError("You have to specify either pixel_values or inputs_embeds") + + if inputs_embeds is None: + batch_size, num_channels, height, width = pixel_values.shape + device = pixel_values.device + + if pixel_mask is None: + pixel_mask = torch.ones(((batch_size, height, width)), device=device) + vision_features = self.backbone(pixel_values, pixel_mask) + feature_map, mask = vision_features[-1] + + # Apply 1x1 conv to map (N, C, H, W) -> (N, d_model, H, W), then flatten to (N, HW, d_model) + # (feature map and position embeddings are flattened and permuted to (batch_size, sequence_length, hidden_size)) + projected_feature_map = self.input_projection(feature_map) + flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1) + spatial_position_embeddings = ( + self.position_embedding(shape=feature_map.shape, device=device, dtype=pixel_values.dtype, mask=mask) + .flatten(2) + .permute(0, 2, 1) + ) + flattened_mask = mask.flatten(1) + else: + batch_size = inputs_embeds.shape[0] + device = inputs_embeds.device + flattened_features = inputs_embeds + # When using inputs_embeds, we need to infer spatial dimensions for position embeddings + # Assume square feature map + seq_len = inputs_embeds.shape[1] + feat_dim = int(seq_len**0.5) + # Create position embeddings for the inferred spatial size + spatial_position_embeddings = ( + self.position_embedding( + shape=torch.Size([batch_size, self.config.d_model, feat_dim, feat_dim]), + device=device, + dtype=inputs_embeds.dtype, + ) + .flatten(2) + .permute(0, 2, 1) + ) + # If a pixel_mask is provided with inputs_embeds, interpolate it to feat_dim, then flatten. 
+ if pixel_mask is not None: + mask = nn.functional.interpolate(pixel_mask[None].float(), size=(feat_dim, feat_dim)).to(torch.bool)[0] + flattened_mask = mask.flatten(1) + else: + # If no mask provided, assume all positions are valid + flattened_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.long) - # Fourth, sent flattened_features + flattened_mask + position embeddings through encoder - # flattened_features is a Tensor of shape (batch_size, height*width, hidden_size) - # flattened_mask is a Tensor of shape (batch_size, height*width) if encoder_outputs is None: encoder_outputs = self.encoder( inputs_embeds=flattened_features, attention_mask=flattened_mask, - object_queries=object_queries, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True - elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): - encoder_outputs = BaseModelOutput( - last_hidden_state=encoder_outputs[0], - hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, - attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + spatial_position_embeddings=spatial_position_embeddings, + **kwargs, ) - # Fifth, sent query embeddings + object_queries through the decoder (which is conditioned on the encoder output) - query_position_embeddings = self.query_position_embeddings.weight.unsqueeze(0).repeat(batch_size, 1, 1) - queries = torch.zeros_like(query_position_embeddings) + object_queries_position_embeddings = self.query_position_embeddings.weight.unsqueeze(0).repeat( + batch_size, 1, 1 + ) + + # Use decoder_inputs_embeds as queries if provided, otherwise initialize with zeros + if decoder_inputs_embeds is not None: + queries = decoder_inputs_embeds + else: + queries = torch.zeros_like(object_queries_position_embeddings) # decoder outputs consists of (dec_features, dec_hidden, dec_attn) decoder_outputs = self.decoder( inputs_embeds=queries, - attention_mask=None, - object_queries=object_queries, - query_position_embeddings=query_position_embeddings, - encoder_hidden_states=encoder_outputs[0], + attention_mask=decoder_attention_mask, + spatial_position_embeddings=spatial_position_embeddings, + object_queries_position_embeddings=object_queries_position_embeddings, + encoder_hidden_states=encoder_outputs.last_hidden_state, encoder_attention_mask=flattened_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, + **kwargs, ) - if not return_dict: - return decoder_outputs + encoder_outputs - return DetrModelOutput( last_hidden_state=decoder_outputs.last_hidden_state, decoder_hidden_states=decoder_outputs.hidden_states, @@ -1245,6 +1204,7 @@ def __init__(self, config: DetrConfig): self.post_init() @auto_docstring + @can_return_tuple def forward( self, pixel_values: torch.FloatTensor, @@ -1254,19 +1214,20 @@ def forward( inputs_embeds: Optional[torch.FloatTensor] = None, decoder_inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[list[dict]] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + **kwargs: Unpack[TransformersKwargs], ) -> Union[tuple[torch.FloatTensor], DetrObjectDetectionOutput]: r""" decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*): - Not used by default. 
Can be used to mask object queries. + Mask to avoid performing attention on certain object queries in the decoder. Mask values selected in `[0, 1]`: + + - 1 for queries that are **not masked**, + - 0 for queries that are **masked**. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you - can choose to directly pass a flattened representation of an image. + can choose to directly pass a flattened representation of an image. Useful for bypassing the vision backbone. decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*): Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an - embedded representation. + embedded representation. Useful for tasks that require custom query initialization. labels (`list[Dict]` of len `(batch_size,)`, *optional*): Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch @@ -1308,7 +1269,6 @@ def forward( Detected cat with confidence 0.999 at location [13.24, 52.05, 314.02, 470.93] Detected cat with confidence 0.999 at location [345.4, 23.85, 640.37, 368.72] ```""" - return_dict = return_dict if return_dict is not None else self.config.use_return_dict # First, sent images through DETR base model to obtain encoder + decoder outputs outputs = self.model( @@ -1318,9 +1278,7 @@ def forward( encoder_outputs=encoder_outputs, inputs_embeds=inputs_embeds, decoder_inputs_embeds=decoder_inputs_embeds, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, + **kwargs, ) sequence_output = outputs[0] @@ -1333,20 +1291,13 @@ def forward( if labels is not None: outputs_class, outputs_coord = None, None if self.config.auxiliary_loss: - intermediate = outputs.intermediate_hidden_states if return_dict else outputs[4] + intermediate = outputs.intermediate_hidden_states outputs_class = self.class_labels_classifier(intermediate) outputs_coord = self.bbox_predictor(intermediate).sigmoid() loss, loss_dict, auxiliary_outputs = self.loss_function( logits, labels, self.device, pred_boxes, self.config, outputs_class, outputs_coord ) - if not return_dict: - if auxiliary_outputs is not None: - output = (logits, pred_boxes) + auxiliary_outputs + outputs - else: - output = (logits, pred_boxes) + outputs - return ((loss, loss_dict) + output) if loss is not None else output - return DetrObjectDetectionOutput( loss=loss, loss_dict=loss_dict, @@ -1378,19 +1329,18 @@ def __init__(self, config: DetrConfig): # segmentation head hidden_size, number_of_heads = config.d_model, config.encoder_attention_heads - intermediate_channel_sizes = self.detr.model.backbone.conv_encoder.intermediate_channel_sizes + intermediate_channel_sizes = self.detr.model.backbone.intermediate_channel_sizes self.mask_head = DetrMaskHeadSmallConv( hidden_size + number_of_heads, intermediate_channel_sizes[::-1][-3:], hidden_size ) - self.bbox_attention = DetrMHAttentionMap( - hidden_size, hidden_size, number_of_heads, dropout=0.0, std=config.init_xavier_std - ) + self.bbox_attention = DetrMHAttentionMap(hidden_size, number_of_heads, dropout=0.0) # Initialize weights and apply final processing self.post_init() @auto_docstring + @can_return_tuple def forward( self, 
pixel_values: torch.FloatTensor, @@ -1400,19 +1350,20 @@ def forward( inputs_embeds: Optional[torch.FloatTensor] = None, decoder_inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[list[dict]] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, + **kwargs: Unpack[TransformersKwargs], ) -> Union[tuple[torch.FloatTensor], DetrSegmentationOutput]: r""" decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*): - Not used by default. Can be used to mask object queries. + Mask to avoid performing attention on certain object queries in the decoder. Mask values selected in `[0, 1]`: + + - 1 for queries that are **not masked**, + - 0 for queries that are **masked**. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): - Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you - can choose to directly pass a flattened representation of an image. + Kept for backward compatibility, but cannot be used for segmentation, as segmentation requires + multi-scale features from the backbone that are not available when bypassing it with inputs_embeds. decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*): Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an - embedded representation. + embedded representation. Useful for tasks that require custom query initialization. labels (`list[Dict]` of len `(batch_size,)`, *optional*): Labels for computing the bipartite matching loss, DICE/F-1 loss and Focal loss. List of dicts, each dictionary containing at least the following 3 keys: 'class_labels', 'boxes' and 'masks' (the class labels, @@ -1455,83 +1406,72 @@ def forward( >>> panoptic_segments_info = result[0]["segments_info"] ```""" - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - batch_size, num_channels, height, width = pixel_values.shape device = pixel_values.device if pixel_mask is None: pixel_mask = torch.ones((batch_size, height, width), device=device) - # First, get list of feature maps and position embeddings - features, object_queries_list = self.detr.model.backbone(pixel_values, pixel_mask=pixel_mask) + vision_features = self.detr.model.backbone(pixel_values, pixel_mask) + feature_map, mask = vision_features[-1] - # Second, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default) - feature_map, mask = features[-1] - batch_size, num_channels, height, width = feature_map.shape + # Apply 1x1 conv to map (N, C, H, W) -> (N, d_model, H, W), then flatten to (N, HW, d_model) projected_feature_map = self.detr.model.input_projection(feature_map) - - # Third, flatten the feature map + position embeddings of shape NxCxHxW to NxCxHW, and permute it to NxHWxC - # In other words, turn their shape into (batch_size, sequence_length, hidden_size) flattened_features = projected_feature_map.flatten(2).permute(0, 2, 1) - object_queries = object_queries_list[-1].flatten(2).permute(0, 2, 1) - + spatial_position_embeddings = ( + self.detr.model.position_embedding( + shape=feature_map.shape, device=device, dtype=pixel_values.dtype, mask=mask + ) + .flatten(2) + .permute(0, 2, 1) + ) flattened_mask = mask.flatten(1) - # Fourth, sent flattened_features + flattened_mask + position embeddings through encoder - # flattened_features 
is a Tensor of shape (batch_size, height*width, hidden_size) - # flattened_mask is a Tensor of shape (batch_size, height*width) if encoder_outputs is None: encoder_outputs = self.detr.model.encoder( inputs_embeds=flattened_features, attention_mask=flattened_mask, - object_queries=object_queries, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True - elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): - encoder_outputs = BaseModelOutput( - last_hidden_state=encoder_outputs[0], - hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, - attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + spatial_position_embeddings=spatial_position_embeddings, + **kwargs, ) - # Fifth, sent query embeddings + position embeddings through the decoder (which is conditioned on the encoder output) - query_position_embeddings = self.detr.model.query_position_embeddings.weight.unsqueeze(0).repeat( + object_queries_position_embeddings = self.detr.model.query_position_embeddings.weight.unsqueeze(0).repeat( batch_size, 1, 1 ) - queries = torch.zeros_like(query_position_embeddings) - # decoder outputs consists of (dec_features, dec_hidden, dec_attn) + # Use decoder_inputs_embeds as queries if provided, otherwise initialize with zeros + if decoder_inputs_embeds is not None: + queries = decoder_inputs_embeds + else: + queries = torch.zeros_like(object_queries_position_embeddings) + decoder_outputs = self.detr.model.decoder( inputs_embeds=queries, - attention_mask=None, - object_queries=object_queries, - query_position_embeddings=query_position_embeddings, - encoder_hidden_states=encoder_outputs[0], + attention_mask=decoder_attention_mask, + spatial_position_embeddings=spatial_position_embeddings, + object_queries_position_embeddings=object_queries_position_embeddings, + encoder_hidden_states=encoder_outputs.last_hidden_state, encoder_attention_mask=flattened_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, + **kwargs, ) sequence_output = decoder_outputs[0] - # Sixth, compute logits, pred_boxes and pred_masks logits = self.detr.class_labels_classifier(sequence_output) pred_boxes = self.detr.bbox_predictor(sequence_output).sigmoid() - memory = encoder_outputs[0].permute(0, 2, 1).view(batch_size, self.config.d_model, height, width) - mask = flattened_mask.view(batch_size, height, width) + height, width = feature_map.shape[-2:] + memory = encoder_outputs.last_hidden_state.permute(0, 2, 1).view( + batch_size, self.config.d_model, height, width + ) + attention_mask = flattened_mask.view(batch_size, height, width) - # FIXME h_boxes takes the last one computed, keep this in mind - # important: we need to reverse the mask, since in the original implementation the mask works reversed - # bbox_mask is of shape (batch_size, num_queries, number_of_attention_heads in bbox_attention, height/32, width/32) - bbox_mask = self.bbox_attention(sequence_output, memory, mask=~mask) + # Note: mask is reversed because the original DETR implementation works with reversed masks + bbox_mask = self.bbox_attention(sequence_output, memory, attention_mask=~attention_mask) - seg_masks = self.mask_head(projected_feature_map, bbox_mask, [features[2][0], features[1][0], features[0][0]]) + seg_masks = self.mask_head( + projected_feature_map, bbox_mask, 
[vision_features[2][0], vision_features[1][0], vision_features[0][0]] + ) pred_masks = seg_masks.view(batch_size, self.detr.config.num_queries, seg_masks.shape[-2], seg_masks.shape[-1]) @@ -1539,20 +1479,13 @@ def forward( if labels is not None: outputs_class, outputs_coord = None, None if self.config.auxiliary_loss: - intermediate = decoder_outputs.intermediate_hidden_states if return_dict else decoder_outputs[-1] + intermediate = decoder_outputs.intermediate_hidden_states outputs_class = self.detr.class_labels_classifier(intermediate) outputs_coord = self.detr.bbox_predictor(intermediate).sigmoid() loss, loss_dict, auxiliary_outputs = self.loss_function( logits, labels, device, pred_boxes, pred_masks, self.config, outputs_class, outputs_coord ) - if not return_dict: - if auxiliary_outputs is not None: - output = (logits, pred_boxes, pred_masks) + auxiliary_outputs + decoder_outputs + encoder_outputs - else: - output = (logits, pred_boxes, pred_masks) + decoder_outputs + encoder_outputs - return ((loss, loss_dict) + output) if loss is not None else output - return DetrSegmentationOutput( loss=loss, loss_dict=loss_dict, @@ -1614,7 +1547,7 @@ def __init__(self, dim, fpn_dims, context_dim): nn.init.kaiming_uniform_(m.weight, a=1) nn.init.constant_(m.bias, 0) - def forward(self, x: Tensor, bbox_mask: Tensor, fpns: list[Tensor]): + def forward(self, x: torch.Tensor, bbox_mask: torch.Tensor, fpns: list[torch.Tensor]): # here we concatenate x, the projected feature map, of shape (batch_size, d_model, height/32, width/32) with # the bbox_mask = the attention maps of shape (batch_size, n_queries, n_heads, height/32, width/32). # We expand the projected feature map to match the number of heads. @@ -1658,29 +1591,44 @@ def forward(self, x: Tensor, bbox_mask: Tensor, fpns: list[Tensor]): class DetrMHAttentionMap(nn.Module): """This is a 2D attention module, which only returns the attention softmax (no multiplication by value)""" - def __init__(self, query_dim, hidden_dim, num_heads, dropout=0.0, bias=True, std=None): + def __init__( + self, + hidden_size: int, + num_attention_heads: int, + dropout: float = 0.0, + bias: bool = True, + ): super().__init__() - self.num_heads = num_heads - self.hidden_dim = hidden_dim + self.head_dim = hidden_size // num_attention_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = dropout self.dropout = nn.Dropout(dropout) - self.q_linear = nn.Linear(query_dim, hidden_dim, bias=bias) - self.k_linear = nn.Linear(query_dim, hidden_dim, bias=bias) + self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias) + self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias) + + def forward( + self, query_states: torch.Tensor, key_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None + ): + query_hidden_shape = (*query_states.shape[:-1], -1, self.head_dim) + key_hidden_shape = (key_states.shape[0], -1, self.head_dim, *key_states.shape[-2:]) + + query_states = self.q_proj(query_states).view(query_hidden_shape) + key_states = nn.functional.conv2d( + key_states, self.k_proj.weight.unsqueeze(-1).unsqueeze(-1), self.k_proj.bias + ).view(key_hidden_shape) - self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5 + attn_weights = torch.einsum("bqnc,bnchw->bqnhw", query_states * self.scaling, key_states) + + if attention_mask is not None: + attn_weights = attn_weights.masked_fill( + attention_mask.unsqueeze(1).unsqueeze(1), torch.finfo(attn_weights.dtype).min + ) - def forward(self, q, k, mask: Optional[Tensor] = None): - q = self.q_linear(q) - 
k = nn.functional.conv2d(k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias) - queries_per_head = q.view(q.shape[0], q.shape[1], self.num_heads, self.hidden_dim // self.num_heads) - keys_per_head = k.view(k.shape[0], self.num_heads, self.hidden_dim // self.num_heads, k.shape[-2], k.shape[-1]) - weights = torch.einsum("bqnc,bnchw->bqnhw", queries_per_head * self.normalize_fact, keys_per_head) + attn_weights = nn.functional.softmax(attn_weights.flatten(2), dim=-1).view(attn_weights.size()) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) - if mask is not None: - weights = weights.masked_fill(mask.unsqueeze(1).unsqueeze(1), torch.finfo(weights.dtype).min) - weights = nn.functional.softmax(weights.flatten(2), dim=-1).view(weights.size()) - weights = self.dropout(weights) - return weights + return attn_weights __all__ = [ diff --git a/src/transformers/safetensors_conversion.py b/src/transformers/safetensors_conversion.py index 397240cadc9f..b4751774f868 100644 --- a/src/transformers/safetensors_conversion.py +++ b/src/transformers/safetensors_conversion.py @@ -28,7 +28,8 @@ def spawn_conversion(token: str, private: bool, model_id: str): def start(_sse_connection): for line in _sse_connection.iter_lines(): - line = line.decode() + if not isinstance(line, str): + line = line.decode() if line.startswith("event:"): status = line[7:] logger.debug(f"Safetensors conversion status: {status}") diff --git a/tests/models/detr/test_modeling_detr.py b/tests/models/detr/test_modeling_detr.py index bfa9575771b1..dadd12629d85 100644 --- a/tests/models/detr/test_modeling_detr.py +++ b/tests/models/detr/test_modeling_detr.py @@ -18,11 +18,18 @@ import unittest from functools import cached_property +from parameterized import parameterized + from transformers import DetrConfig, ResNetConfig, is_torch_available, is_vision_available from transformers.testing_utils import Expectations, require_timm, require_torch, require_vision, slow, torch_device from ...test_configuration_common import ConfigTester -from ...test_modeling_common import ModelTesterMixin, floats_tensor +from ...test_modeling_common import ( + TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION, + ModelTesterMixin, + _test_eager_matches_sdpa_inference, + floats_tensor, +) from ...test_pipeline_mixin import PipelineTesterMixin @@ -460,13 +467,13 @@ def test_different_timm_backbone(self): ) self.assertEqual(outputs.logits.shape, expected_shape) # Confirm out_indices was propagated to backbone - self.assertEqual(len(model.model.backbone.conv_encoder.intermediate_channel_sizes), 3) + self.assertEqual(len(model.model.backbone.intermediate_channel_sizes), 3) elif model_class.__name__ == "DetrForSegmentation": # Confirm out_indices was propagated to backbone - self.assertEqual(len(model.detr.model.backbone.conv_encoder.intermediate_channel_sizes), 3) + self.assertEqual(len(model.detr.model.backbone.intermediate_channel_sizes), 3) else: # Confirm out_indices was propagated to backbone - self.assertEqual(len(model.backbone.conv_encoder.intermediate_channel_sizes), 3) + self.assertEqual(len(model.backbone.intermediate_channel_sizes), 3) self.assertTrue(outputs) @@ -495,13 +502,13 @@ def test_hf_backbone(self): ) self.assertEqual(outputs.logits.shape, expected_shape) # Confirm out_indices was propagated to backbone - self.assertEqual(len(model.model.backbone.conv_encoder.intermediate_channel_sizes), 3) + self.assertEqual(len(model.model.backbone.intermediate_channel_sizes), 3) elif 
model_class.__name__ == "DetrForSegmentation": # Confirm out_indices was propagated to backbone - self.assertEqual(len(model.detr.model.backbone.conv_encoder.intermediate_channel_sizes), 3) + self.assertEqual(len(model.detr.model.backbone.intermediate_channel_sizes), 3) else: # Confirm out_indices was propagated to backbone - self.assertEqual(len(model.backbone.conv_encoder.intermediate_channel_sizes), 3) + self.assertEqual(len(model.backbone.intermediate_channel_sizes), 3) self.assertTrue(outputs) @@ -526,6 +533,18 @@ def test_greyscale_images(self): self.assertTrue(outputs) + # Override test_eager_matches_sdpa_inference to force use_attention_mask=False, + # since the attention masks built by the common test do not match the 2D masks DETR expects + @parameterized.expand(TEST_EAGER_MATCHES_SDPA_INFERENCE_PARAMETERIZATION) + def test_eager_matches_sdpa_inference( + self, name, dtype, padding_side, use_attention_mask, output_attentions, enable_kernels + ): + if use_attention_mask: + self.skipTest( + "This test uses attention masks which are not compatible with DETR. Skipping when use_attention_mask is True." + ) + _test_eager_matches_sdpa_inference(self, name, dtype, padding_side, False, output_attentions, enable_kernels) + TOLERANCE = 1e-4
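
Reviewer note — a minimal inference sketch for the reworked DetrForSegmentation.forward above, which routes output_attentions / output_hidden_states / return_dict through **kwargs (TransformersKwargs) and always builds a DetrSegmentationOutput via @can_return_tuple. The checkpoint name and image path below are illustrative assumptions, not part of this diff:

import torch
from PIL import Image
from transformers import DetrForSegmentation, DetrImageProcessor

# "facebook/detr-resnet-50-panoptic" and "example.jpg" are placeholders; substitute your own.
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50-panoptic")
model = DetrForSegmentation.from_pretrained("facebook/detr-resnet-50-panoptic")

inputs = processor(images=Image.open("example.jpg"), return_tensors="pt")
with torch.no_grad():
    # Extra options now travel through **kwargs; there is no return_dict plumbing in forward itself.
    outputs = model(**inputs, output_attentions=True)

print(outputs.logits.shape)      # (batch_size, num_queries, num_labels + 1)
print(outputs.pred_masks.shape)  # (batch_size, num_queries, mask_height, mask_width)

A custom query initialization can also be supplied via decoder_inputs_embeds of shape (batch_size, num_queries, d_model), which the new forward uses in place of the all-zeros queries.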
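Reviewer note — the segmentation head now calls the shared position embedding with (shape, device, dtype, mask) so the result can be memoized by compile_compatible_method_lru_cache. A small sketch of that call convention; the num_pos_feats value and all sizes are purely illustrative:

import torch
from transformers.models.detr.modeling_detr import DetrSinePositionEmbedding

# Assumed setup: d_model = 256, so each spatial axis gets 128 sine/cosine features.
position_embedding = DetrSinePositionEmbedding(num_pos_feats=128, normalize=True)

feature_map = torch.randn(2, 256, 32, 32)       # (batch, d_model, H, W) from the backbone projection
mask = torch.ones(2, 32, 32, dtype=torch.bool)  # True marks valid (non-padded) pixels

pos = position_embedding(
    shape=feature_map.shape, device=feature_map.device, dtype=feature_map.dtype, mask=mask
)
print(pos.shape)  # expected (2, 256, 32, 32): y/x sine features concatenated on the channel dim

# Flattened exactly as in DetrForSegmentation.forward before feeding the encoder/decoder:
spatial_position_embeddings = pos.flatten(2).permute(0, 2, 1)  # (2, 1024, 256)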
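Reviewer note — DetrMHAttentionMap now follows the q_proj/k_proj naming and returns only the per-query, per-head attention maps (there is no value projection). A shape sketch under assumed toy dimensions:

import torch
from transformers.models.detr.modeling_detr import DetrMHAttentionMap

hidden_size, num_heads = 256, 8
attention_map = DetrMHAttentionMap(hidden_size, num_heads, dropout=0.0)

queries = torch.randn(2, 100, hidden_size)         # decoder output: (batch, num_queries, hidden_size)
memory = torch.randn(2, hidden_size, 32, 32)       # encoder memory reshaped to (batch, hidden_size, H, W)
ignore = torch.zeros(2, 32, 32, dtype=torch.bool)  # True marks positions to mask out (hence the ~mask above)

weights = attention_map(queries, memory, attention_mask=ignore)
print(weights.shape)  # (2, 100, 8, 32, 32): one H x W map per object query and attention head

These maps are what DetrMaskHeadSmallConv consumes together with the projected feature map and the FPN features.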