diff --git a/src/transformers/models/aimv2/modeling_aimv2.py b/src/transformers/models/aimv2/modeling_aimv2.py index 472eccd0b575..64485ffa958f 100644 --- a/src/transformers/models/aimv2/modeling_aimv2.py +++ b/src/transformers/models/aimv2/modeling_aimv2.py @@ -22,7 +22,7 @@ import math from dataclasses import dataclass -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Union import torch import torch.nn.functional as F @@ -35,6 +35,7 @@ from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...utils import ModelOutput, auto_docstring, can_return_tuple +from ...utils.generic import check_model_inputs from .configuration_aimv2 import Aimv2Config, Aimv2TextConfig, Aimv2VisionConfig @@ -289,7 +290,7 @@ def forward( class Aimv2EncoderLayer(GradientCheckpointingLayer): - def __init__(self, config: Aimv2VisionConfig): + def __init__(self, config: Union[Aimv2VisionConfig, Aimv2TextConfig]): super().__init__() self.attention = Aimv2Attention(config) self.ffn = Aimv2MLP(config) @@ -300,8 +301,7 @@ def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = False, - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: norm_hidden_states = self.rms_norm1(hidden_states) attn_output, attn_weights = self.attention(hidden_states=norm_hidden_states, attention_mask=attention_mask) @@ -310,7 +310,7 @@ def forward( mlp_output = self.ffn(norm_hidden_states) hidden_states = hidden_states + mlp_output - return (hidden_states, attn_weights) if output_attentions else (hidden_states, None) + return hidden_states, attn_weights class Aimv2Encoder(nn.Module): @@ -322,19 +322,16 @@ class Aimv2Encoder(nn.Module): config: Aimv2Config """ - def __init__(self, config: Aimv2Config): + def __init__(self, config: Union[Aimv2VisionConfig, Aimv2TextConfig]): super().__init__() self.config = config self.layers = nn.ModuleList([Aimv2EncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False - # Ignore copy - @can_return_tuple def forward( self, - inputs_embeds, + inputs_embeds: torch.FloatTensor, attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, ) -> BaseModelOutput: r""" @@ -350,46 +347,21 @@ def forward( - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - - encoder_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None + all_hidden_states = [inputs_embeds] if output_hidden_states else None hidden_states = inputs_embeds for encoder_layer in self.layers: - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - output_attentions=output_attentions, - ) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) + hidden_states, _ = encoder_layer(hidden_states, attention_mask) + if all_hidden_states: + all_hidden_states.append(hidden_states) return BaseModelOutput( last_hidden_state=hidden_states, - hidden_states=encoder_states, - attentions=all_attentions, + hidden_states=tuple(all_hidden_states) if all_hidden_states else None, ) @@ -446,6 +418,9 @@ class Aimv2PreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flash_attn = True _supports_flex_attn = True + _can_record_outputs = { + "attentions": Aimv2Attention, + } def _init_weights(self, module): super()._init_weights(module) @@ -482,14 +457,14 @@ def __init__(self, config: Aimv2VisionConfig): def get_input_embeddings(self) -> nn.Module: return self.embeddings.patch_embed - @can_return_tuple + @check_model_inputs @auto_docstring def forward( self, pixel_values, attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + **kwargs, ) -> BaseModelOutputWithPooling: r""" Examples: @@ -511,20 +486,16 @@ def forward( >>> last_hidden_state = outputs.last_hidden_state >>> pooled_output = outputs.pooler_output # pooled features ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) + if output_hidden_states is None: + output_hidden_states = self.config.output_hidden_states hidden_states = self.embeddings(pixel_values) - encoder_outputs = self.encoder( - inputs_embeds=hidden_states, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, + encoder_outputs: BaseModelOutput = self.encoder( + inputs_embeds=hidden_states, output_hidden_states=output_hidden_states ) - last_hidden_state = encoder_outputs[0] + last_hidden_state = encoder_outputs.last_hidden_state last_hidden_state = self.rms_norm(last_hidden_state) pooler_output = self.head(last_hidden_state) if self.use_head else None @@ -533,7 +504,6 @@ def forward( last_hidden_state=last_hidden_state, pooler_output=pooler_output, hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, ) @@ -543,6 +513,7 @@ def forward( """ ) class Aimv2TextModel(Aimv2PreTrainedModel): + config: Aimv2TextConfig main_input_name = "input_ids" def __init__(self, config: Aimv2TextConfig): @@ -562,19 +533,17 @@ def get_input_embeddings(self) -> nn.Module: def set_input_embeddings(self, value): self.embeddings.token_embedding = value - @can_return_tuple + @check_model_inputs @auto_docstring def forward( self, input_ids, attention_mask: 
Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + **kwargs, ) -> BaseModelOutputWithPooling: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) + if output_hidden_states is None: + output_hidden_states = self.config.output_hidden_states hidden_states = self.embeddings(input_ids) batch_size, seq_len, _ = hidden_states.shape @@ -591,14 +560,13 @@ def forward( past_key_values=None, ) - encoder_outputs = self.encoder( + encoder_outputs: BaseModelOutput = self.encoder( inputs_embeds=hidden_states, attention_mask=attention_mask, - output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) - last_hidden_state = encoder_outputs[0] + last_hidden_state = encoder_outputs.last_hidden_state last_hidden_state = self.rms_norm(last_hidden_state) # Get pooled output @@ -749,8 +717,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, pixel_values: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, + **kwargs, ) -> Aimv2Output: r""" Examples: @@ -775,22 +742,10 @@ def forward( >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - - vision_outputs: BaseModelOutputWithPooling = self.vision_model( - pixel_values=pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) + vision_outputs: BaseModelOutputWithPooling = self.vision_model(pixel_values=pixel_values, **kwargs) text_outputs: BaseModelOutputWithPooling = self.text_model( - input_ids=input_ids, - attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, + input_ids=input_ids, attention_mask=attention_mask, **kwargs ) image_embeds = vision_outputs.pooler_output diff --git a/src/transformers/models/aimv2/modular_aimv2.py b/src/transformers/models/aimv2/modular_aimv2.py index 5991b928a2f0..ff1a97c53e1e 100644 --- a/src/transformers/models/aimv2/modular_aimv2.py +++ b/src/transformers/models/aimv2/modular_aimv2.py @@ -24,12 +24,10 @@ from ...masking_utils import create_causal_mask from ...modeling_layers import GradientCheckpointingLayer -from ...modeling_outputs import BaseModelOutputWithPooling +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling from ...modeling_utils import PreTrainedModel -from ...utils import ( - auto_docstring, - can_return_tuple, -) +from ...utils import auto_docstring, can_return_tuple +from ...utils.generic import check_model_inputs from ..clip.modeling_clip import CLIPModel, CLIPTextEmbeddings, _get_vector_norm from ..llama.modeling_llama import LlamaMLP, LlamaRMSNorm from ..siglip.configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig @@ -362,7 +360,7 @@ def __init__(self, config): class Aimv2EncoderLayer(GradientCheckpointingLayer): - def __init__(self, config: Aimv2VisionConfig): + def __init__(self, config: Union[Aimv2VisionConfig, Aimv2TextConfig]): super().__init__() self.attention = 
Aimv2Attention(config) self.ffn = Aimv2MLP(config) @@ -373,8 +371,7 @@ def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = False, - ) -> tuple[torch.Tensor, torch.Tensor]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: norm_hidden_states = self.rms_norm1(hidden_states) attn_output, attn_weights = self.attention(hidden_states=norm_hidden_states, attention_mask=attention_mask) @@ -383,7 +380,7 @@ def forward( mlp_output = self.ffn(norm_hidden_states) hidden_states = hidden_states + mlp_output - return (hidden_states, attn_weights) if output_attentions else (hidden_states, None) + return hidden_states, attn_weights class Aimv2Encoder(SiglipEncoder): @@ -443,6 +440,9 @@ class Aimv2PreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flash_attn = True _supports_flex_attn = True + _can_record_outputs = { + "attentions": Aimv2Attention, + } def _init_weights(self, module): super()._init_weights(module) @@ -479,14 +479,14 @@ def __init__(self, config: Aimv2VisionConfig): def get_input_embeddings(self) -> nn.Module: return self.embeddings.patch_embed - @can_return_tuple + @check_model_inputs @auto_docstring def forward( self, pixel_values, attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + **kwargs, ) -> BaseModelOutputWithPooling: r""" Examples: @@ -508,20 +508,16 @@ def forward( >>> last_hidden_state = outputs.last_hidden_state >>> pooled_output = outputs.pooler_output # pooled features ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) + if output_hidden_states is None: + output_hidden_states = self.config.output_hidden_states hidden_states = self.embeddings(pixel_values) - encoder_outputs = self.encoder( - inputs_embeds=hidden_states, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, + encoder_outputs: BaseModelOutput = self.encoder( + inputs_embeds=hidden_states, output_hidden_states=output_hidden_states ) - last_hidden_state = encoder_outputs[0] + last_hidden_state = encoder_outputs.last_hidden_state last_hidden_state = self.rms_norm(last_hidden_state) pooler_output = self.head(last_hidden_state) if self.use_head else None @@ -530,7 +526,6 @@ def forward( last_hidden_state=last_hidden_state, pooler_output=pooler_output, hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, ) @@ -540,6 +535,7 @@ def forward( """ ) class Aimv2TextModel(Aimv2PreTrainedModel): + config: Aimv2TextConfig main_input_name = "input_ids" def __init__(self, config: Aimv2TextConfig): @@ -559,19 +555,17 @@ def get_input_embeddings(self) -> nn.Module: def set_input_embeddings(self, value): self.embeddings.token_embedding = value - @can_return_tuple + @check_model_inputs @auto_docstring def forward( self, input_ids, attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + **kwargs, ) -> BaseModelOutputWithPooling: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) + if output_hidden_states is None: + output_hidden_states = 
self.config.output_hidden_states hidden_states = self.embeddings(input_ids) batch_size, seq_len, _ = hidden_states.shape @@ -588,14 +582,13 @@ def forward( past_key_values=None, ) - encoder_outputs = self.encoder( + encoder_outputs: BaseModelOutput = self.encoder( inputs_embeds=hidden_states, attention_mask=attention_mask, - output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) - last_hidden_state = encoder_outputs[0] + last_hidden_state = encoder_outputs.last_hidden_state last_hidden_state = self.rms_norm(last_hidden_state) # Get pooled output @@ -641,8 +634,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, pixel_values: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, + **kwargs, ) -> Aimv2Output: r""" Examples: @@ -667,22 +659,10 @@ def forward( >>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - - vision_outputs: BaseModelOutputWithPooling = self.vision_model( - pixel_values=pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - ) + vision_outputs: BaseModelOutputWithPooling = self.vision_model(pixel_values=pixel_values, **kwargs) text_outputs: BaseModelOutputWithPooling = self.text_model( - input_ids=input_ids, - attention_mask=attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, + input_ids=input_ids, attention_mask=attention_mask, **kwargs ) image_embeds = vision_outputs.pooler_output diff --git a/src/transformers/models/eomt/modeling_eomt.py b/src/transformers/models/eomt/modeling_eomt.py index ff04a55614e6..d90ccbd296cb 100644 --- a/src/transformers/models/eomt/modeling_eomt.py +++ b/src/transformers/models/eomt/modeling_eomt.py @@ -736,7 +736,7 @@ def eager_attention_forward( class EomtAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config): + def __init__(self, config: EomtConfig): super().__init__() self.config = config self.embed_dim = config.hidden_size diff --git a/src/transformers/models/eomt/modular_eomt.py b/src/transformers/models/eomt/modular_eomt.py index 7d9a6100a082..d9220ddc23a3 100644 --- a/src/transformers/models/eomt/modular_eomt.py +++ b/src/transformers/models/eomt/modular_eomt.py @@ -286,7 +286,8 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: class EomtAttention(SiglipAttention): - pass + def __init__(self, config: EomtConfig): + super().__init__() class EomtLayerScale(Dinov2LayerScale): diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index 00b8b18cd3d1..953a7399d7e7 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -25,13 +25,12 @@ from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask -from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, ModelOutput from 
...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging -from ...utils.deprecation import deprecate_kwarg +from ...utils.generic import check_model_inputs from ..auto import AutoModel from .configuration_idefics2 import Idefics2Config, Idefics2PerceiverConfig, Idefics2VisionConfig @@ -313,7 +312,7 @@ def __init__(self, config: Idefics2VisionConfig): output_size=config.hidden_size, ) - def forward(self, hidden_state): + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: batch_size = hidden_state.shape[0] probe = self.probe.repeat(batch_size, 1, 1) @@ -322,8 +321,9 @@ def forward(self, hidden_state): residual = hidden_state hidden_state = self.layernorm(hidden_state) hidden_state = residual + self.mlp(hidden_state) + pooled_output = hidden_state[:, 0] - return hidden_state[:, 0] + return pooled_output class Idefics2EncoderLayer(GradientCheckpointingLayer): @@ -341,7 +341,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: torch.Tensor, output_attentions: Optional[bool] = False, - ) -> tuple[torch.FloatTensor]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """ Args: hidden_states (`torch.FloatTensor`): @@ -367,15 +367,9 @@ def forward( hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) + return hidden_states, attn_weights - return outputs - -# Copied from transformers.models.siglip.modeling_siglip.SiglipEncoder with Siglip->Idefics2 class Idefics2Encoder(nn.Module): """ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a @@ -385,21 +379,19 @@ class Idefics2Encoder(nn.Module): config: Idefics2Config """ - def __init__(self, config: Idefics2Config): + def __init__(self, config: Idefics2VisionConfig): super().__init__() self.config = config self.layers = nn.ModuleList([Idefics2EncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False - # Ignore copy def forward( self, - inputs_embeds, + inputs_embeds: torch.FloatTensor, attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[tuple, BaseModelOutput]: + **kwargs, + ) -> BaseModelOutput: r""" Args: inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): @@ -413,46 +405,21 @@ def forward( - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - encoder_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None + all_hidden_states = [inputs_embeds] if output_hidden_states else None hidden_states = inputs_embeds for encoder_layer in self.layers: - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - output_attentions=output_attentions, - ) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) + hidden_states, _ = encoder_layer(hidden_states, attention_mask) + if all_hidden_states: + all_hidden_states.append(hidden_states) - if not return_dict: - return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) return BaseModelOutput( - last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + last_hidden_state=hidden_states, + hidden_states=tuple(all_hidden_states) if all_hidden_states else None, ) @@ -503,15 +470,17 @@ class Idefics2VisionTransformer(Idefics2PreTrainedModel): _supports_sdpa = True _supports_flash_attn = True _supports_flex_attn = True + _can_record_outputs = { + "hidden_states": Idefics2EncoderLayer, + "attentions": Idefics2VisionAttention, + } def __init__(self, config: Idefics2VisionConfig): super().__init__(config) - embed_dim = config.hidden_size - self.config = config self.embeddings = Idefics2VisionEmbeddings(config) self.encoder = Idefics2Encoder(config) - self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" def get_input_embeddings(self): @@ -520,36 +489,26 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.embeddings = value + @check_model_inputs @auto_docstring def forward( self, - pixel_values, + pixel_values: torch.FloatTensor, patch_attention_mask: Optional[torch.BoolTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[tuple, BaseModelOutput]: + **kwargs, + ) -> BaseModelOutput: r""" patch_attention_mask (`torch.BoolTensor` of shape `(batch_size, num_patches_height, num_patches_width)`, *optional*): The attention mask for the patches. 
""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - batch_size = pixel_values.size(0) + batch_size, _, height, width = pixel_values.shape if patch_attention_mask is None: - patch_size = self.config.patch_size + num_patches_height = height // self.config.patch_size + num_patches_width = width // self.config.patch_size patch_attention_mask = torch.ones( - ( - batch_size, - pixel_values.size(2) // patch_size, - pixel_values.size(3) // patch_size, - ) + (batch_size, num_patches_height, num_patches_width), dtype=torch.bool, device=pixel_values.device ) - patch_attention_mask = patch_attention_mask.to(dtype=torch.bool, device=pixel_values.device) hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask) @@ -562,25 +521,14 @@ def forward( elif not self._use_flash_attention_2: patch_attention_mask = _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype) - encoder_outputs = self.encoder( - inputs_embeds=hidden_states, - attention_mask=patch_attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, + encoder_outputs: BaseModelOutput = self.encoder( + inputs_embeds=hidden_states, attention_mask=patch_attention_mask, **kwargs ) - last_hidden_state = encoder_outputs[0] + last_hidden_state = encoder_outputs.last_hidden_state last_hidden_state = self.post_layernorm(last_hidden_state) - if not return_dict: - return (last_hidden_state,) + encoder_outputs[1:] - - return BaseModelOutput( - last_hidden_state=last_hidden_state, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) + return BaseModelOutput(last_hidden_state=last_hidden_state) # Copied from transformers.models.llama.modeling_llama.repeat_kv @@ -618,7 +566,7 @@ def extra_repr(self): class Idefics2PerceiverAttention(nn.Module): - def __init__(self, config, layer_idx: Optional[int] = None) -> None: + def __init__(self, config: Idefics2PerceiverConfig, layer_idx: Optional[int] = None): """Perceiver Cross-Attention Module --> let long-form inputs be `context`, resampled embeddings be `latents`""" super().__init__() self.config = config @@ -638,17 +586,13 @@ def __init__(self, config, layer_idx: Optional[int] = None) -> None: self.is_causal = False - @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, latents: torch.Tensor, context: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: + **kwargs, + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """ Runs Perceiver Self-Attention, with special (context, latents) appended along the `seq` dimension! 
@@ -674,20 +618,9 @@ def forward( keys = keys.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) values = values.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - past_key_values = getattr(self, "past_key_values", past_key_values) - - if past_key_values is not None: - keys, values = past_key_values.update(keys, values, self.layer_idx) - attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": - if self.config._attn_implementation == "sdpa" and output_attentions: - logger.warning_once( - "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to " - 'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] attn_output, attn_weights = attention_interface( self, @@ -698,19 +631,17 @@ def forward( is_causal=self.is_causal, scaling=self.scaling, dropout=0.0 if not self.training else self.attention_dropout, + **kwargs, ) attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim) attn_output = self.o_proj(attn_output) - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_values + return attn_output, attn_weights class Idefics2PerceiverLayer(nn.Module): - def __init__(self, config, layer_idx: int): + def __init__(self, config: Idefics2PerceiverConfig, layer_idx: int): super().__init__() self.hidden_size = config.hidden_size self.n_latents = config.resampler_n_latents @@ -728,18 +659,13 @@ def __init__(self, config, layer_idx: int): hidden_act=config.hidden_act, ) - @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, latents: torch.Tensor, context: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = False, **kwargs, - ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: + ) -> tuple[torch.FloatTensor, Optional[torch.Tensor]]: """ Args: latents (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` @@ -759,7 +685,7 @@ def forward( latents = self.input_latents_norm(latents) context = self.input_context_norm(context) - latents, self_attn_weights, present_key_value = self.self_attn( + latents, attn_weights = self.self_attn( latents=latents, context=context, attention_mask=attention_mask, @@ -771,15 +697,7 @@ def forward( latents = self.mlp(latents) latents = residual + latents - outputs = (latents,) - - if output_attentions: - outputs += (self_attn_weights,) - - if use_cache: - outputs += (present_key_value,) - - return outputs + return latents, attn_weights @auto_docstring( @@ -793,7 +711,7 @@ class Idefics2PerceiverResampler(Idefics2PreTrainedModel): _supports_flash_attention_2 = True _supports_flex_attn = True - def __init__(self, config) -> None: + def __init__(self, config: Idefics2PerceiverConfig) -> None: super().__init__(config) self.hidden_size = config.hidden_size self.hidden_act = config.hidden_act @@ -835,25 +753,14 @@ def forward( compressed_context = latents for perceiver_layer in self.layers: - layer_outputs = perceiver_layer( - compressed_context, - context, - 
attention_mask=attention_mask, - position_ids=None, - past_key_values=None, - output_attentions=False, - use_cache=False, - ) - - compressed_context = layer_outputs[0] - + compressed_context, _ = perceiver_layer(compressed_context, context, attention_mask=attention_mask) compressed_context = self.norm(compressed_context) return compressed_context class Idefics2Connector(nn.Module): - def __init__(self, config): + def __init__(self, config: Idefics2Config): super().__init__() self.modality_projection = Idefics2MLP( hidden_size=config.vision_config.hidden_size, @@ -1025,30 +932,15 @@ def forward( pixel_attention_mask: Optional[torch.BoolTensor] = None, image_hidden_states: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - return_dict: Optional[bool] = None, - **kwargs: Unpack[FlashAttentionKwargs], - ) -> Union[tuple, Idefics2BaseModelOutputWithPast]: + **kwargs: Unpack[TransformersKwargs], + ) -> Idefics2BaseModelOutputWithPast: r""" pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*): Mask to avoid performing attention on padding pixel indices. image_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): The hidden states of the image encoder after modality projection and perceiver resampling. """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if self.training and self.text_model.gradient_checkpointing and use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False # retrieve input_ids and inputs_embeds if input_ids is not None: @@ -1091,8 +983,6 @@ def forward( position_ids=position_ids, past_key_values=past_key_values, use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, cache_position=cache_position, return_dict=True, **kwargs, @@ -1115,7 +1005,7 @@ def forward( class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] - def __init__(self, config): + def __init__(self, config: Idefics2Config): super().__init__(config) self.model = Idefics2Model(config) self.image_token_id = self.config.image_token_id @@ -1167,13 +1057,10 @@ def forward( image_hidden_states: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[TransformersKwargs], - ) -> Union[tuple, Idefics2CausalLMOutputWithPast]: + ) -> Idefics2CausalLMOutputWithPast: r""" pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*): Mask to avoid performing attention on padding pixel indices. 
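For context on the `**kwargs: Unpack[TransformersKwargs]` signatures introduced above: the outer forwards stop enumerating `output_attentions`/`output_hidden_states` and simply forward whatever the caller passed. The sketch below uses a hypothetical `ToyKwargs` TypedDict (not the real `TransformersKwargs` definition) to show that forwarding pattern in isolation.

```python
# A minimal sketch of the `**kwargs: Unpack[...]` forwarding pattern used above.
# `ToyKwargs` and both functions are illustrative; only the forwarding idea mirrors the diff.
from typing_extensions import TypedDict, Unpack


class ToyKwargs(TypedDict, total=False):
    output_attentions: bool
    output_hidden_states: bool


def text_model_forward(inputs_embeds, **kwargs: Unpack[ToyKwargs]):
    # The inner model (or the @check_model_inputs decorator wrapping it) is the only
    # place that actually inspects the flags.
    return {"inputs_embeds": inputs_embeds, **kwargs}


def model_forward(input_ids, **kwargs: Unpack[ToyKwargs]):
    # The outer forward no longer names the flags; it just passes them along.
    return text_model_forward(inputs_embeds=input_ids, **kwargs)


print(model_forward([1, 2, 3], output_attentions=True))
# {'inputs_embeds': [1, 2, 3], 'output_attentions': True}
```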
@@ -1223,14 +1110,8 @@ def forward( ['In this image, we can see the city of New York, and more specifically the Statue of Liberty. In this image, we can see the city of New York, and more specifically the Statue of Liberty.\n\n', 'In which city is that bridge located?\n\nThe bridge is located in the city of Pittsburgh, Pennsylvania.\n\n\nThe bridge is'] ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( + outputs: Idefics2BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -1240,14 +1121,11 @@ def forward( pixel_attention_mask=pixel_attention_mask, image_hidden_states=image_hidden_states, use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, cache_position=cache_position, - return_dict=True, **kwargs, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) diff --git a/src/transformers/models/idefics3/modeling_idefics3.py b/src/transformers/models/idefics3/modeling_idefics3.py index 3eafc992540c..0c5de287032a 100644 --- a/src/transformers/models/idefics3/modeling_idefics3.py +++ b/src/transformers/models/idefics3/modeling_idefics3.py @@ -25,12 +25,12 @@ from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask -from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer -from ...modeling_outputs import BaseModelOutput, ModelOutput +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, ModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils.generic import check_model_inputs from ..auto import AutoModel from .configuration_idefics3 import Idefics3Config, Idefics3VisionConfig @@ -300,7 +300,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: torch.Tensor, output_attentions: Optional[bool] = False, - ) -> tuple[torch.FloatTensor]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """ Args: hidden_states (`torch.FloatTensor`): @@ -326,15 +326,10 @@ def forward( hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs + return hidden_states, attn_weights -# Copied from transformers.models.siglip.modeling_siglip.SiglipEncoder with Siglip->Idefics3 +# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2Encoder with Idefics2->Idefics3 class Idefics3Encoder(nn.Module): """ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. 
Each layer is a @@ -344,21 +339,19 @@ class Idefics3Encoder(nn.Module): config: Idefics3Config """ - def __init__(self, config: Idefics3Config): + def __init__(self, config: Idefics3VisionConfig): super().__init__() self.config = config self.layers = nn.ModuleList([Idefics3EncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False - # Ignore copy def forward( self, - inputs_embeds, + inputs_embeds: torch.FloatTensor, attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[tuple, BaseModelOutput]: + **kwargs, + ) -> BaseModelOutput: r""" Args: inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): @@ -372,46 +365,21 @@ def forward( - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - encoder_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None + all_hidden_states = [inputs_embeds] if output_hidden_states else None hidden_states = inputs_embeds for encoder_layer in self.layers: - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - output_attentions=output_attentions, - ) + hidden_states, _ = encoder_layer(hidden_states, attention_mask) + if all_hidden_states: + all_hidden_states.append(hidden_states) - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) return BaseModelOutput( - last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + last_hidden_state=hidden_states, + hidden_states=tuple(all_hidden_states) if all_hidden_states else None, ) @@ -508,55 +476,51 @@ def _init_weights(self, module): The Idefics3 Vision Transformer Model outputting raw image embedding. 
""" ) +# Copied from transformers.models.idefics2.modeling_idefics2.Idefics2VisionTransformer with Idefics2->Idefics3 class Idefics3VisionTransformer(Idefics3PreTrainedModel): config: Idefics3VisionConfig _supports_sdpa = True _supports_flash_attn = True _supports_flex_attn = True + _can_record_outputs = { + "hidden_states": Idefics3EncoderLayer, + "attentions": Idefics3VisionAttention, + } def __init__(self, config: Idefics3VisionConfig): super().__init__(config) - embed_dim = config.hidden_size - + self.config = config self.embeddings = Idefics3VisionEmbeddings(config) self.encoder = Idefics3Encoder(config) - self.patch_size = config.patch_size - self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.post_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" - # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2VisionTransformer.get_input_embeddings def get_input_embeddings(self): return self.embeddings - # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2VisionTransformer.set_input_embeddings def set_input_embeddings(self, value): self.embeddings = value + @check_model_inputs + @auto_docstring def forward( self, - pixel_values, + pixel_values: torch.FloatTensor, patch_attention_mask: Optional[torch.BoolTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[tuple, BaseModelOutput]: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict + **kwargs, + ) -> BaseModelOutput: + r""" + patch_attention_mask (`torch.BoolTensor` of shape `(batch_size, num_patches_height, num_patches_width)`, *optional*): + The attention mask for the patches. + """ - batch_size = pixel_values.size(0) + batch_size, _, height, width = pixel_values.shape if patch_attention_mask is None: - patch_size = self.patch_size + num_patches_height = height // self.config.patch_size + num_patches_width = width // self.config.patch_size patch_attention_mask = torch.ones( - ( - batch_size, - pixel_values.size(2) // patch_size, - pixel_values.size(3) // patch_size, - ) + (batch_size, num_patches_height, num_patches_width), dtype=torch.bool, device=pixel_values.device ) - patch_attention_mask = patch_attention_mask.to(dtype=torch.bool, device=pixel_values.device) hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask) @@ -564,30 +528,19 @@ def forward( # The call to `_upad_input` in `_flash_attention_forward` is expensive # So when the `patch_attention_mask` is full of 1s (i.e. 
attending to the whole sequence), # avoiding passing the attention_mask, which is equivalent to attending to the full sequence - if not self._use_flash_attention_2: - patch_attention_mask = _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype) - elif not torch.any(~patch_attention_mask): + if not torch.any(~patch_attention_mask): patch_attention_mask = None + elif not self._use_flash_attention_2: + patch_attention_mask = _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype) - encoder_outputs = self.encoder( - inputs_embeds=hidden_states, - attention_mask=patch_attention_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, + encoder_outputs: BaseModelOutput = self.encoder( + inputs_embeds=hidden_states, attention_mask=patch_attention_mask, **kwargs ) - last_hidden_state = encoder_outputs[0] + last_hidden_state = encoder_outputs.last_hidden_state last_hidden_state = self.post_layernorm(last_hidden_state) - if not return_dict: - return (last_hidden_state,) + encoder_outputs[1:] - - return BaseModelOutput( - last_hidden_state=last_hidden_state, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) + return BaseModelOutput(last_hidden_state=last_hidden_state) @auto_docstring( @@ -749,30 +702,15 @@ def forward( pixel_attention_mask: Optional[torch.BoolTensor] = None, image_hidden_states: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - return_dict: Optional[bool] = None, - **kwargs: Unpack[FlashAttentionKwargs], - ) -> Union[tuple, Idefics3BaseModelOutputWithPast]: + **kwargs: Unpack[TransformersKwargs], + ) -> Idefics3BaseModelOutputWithPast: r""" pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*): Mask to avoid performing attention on padding pixel indices. image_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`): The hidden states of the image encoder after modality projection. """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - use_cache = use_cache if use_cache is not None else self.config.use_cache - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if self.training and self.text_model.gradient_checkpointing and use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False # retrieve input_ids and inputs_embeds if input_ids is not None: @@ -805,16 +743,13 @@ def forward( image_hidden_states=image_hidden_states, ) - outputs = self.text_model( + outputs: BaseModelOutputWithPast = self.text_model( inputs_embeds=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, cache_position=cache_position, - return_dict=True, **kwargs, ) @@ -836,7 +771,7 @@ class Idefics3ForConditionalGeneration(Idefics3PreTrainedModel, GenerationMixin) _tied_weights_keys = ["lm_head.weight"] # Copied from transformers.models.idefics2.modeling_idefics2.Idefics2ForConditionalGeneration.__init__ with Idefics2->Idefics3 - def __init__(self, config): + def __init__(self, config: Idefics3Config): super().__init__(config) self.model = Idefics3Model(config) self.image_token_id = self.config.image_token_id @@ -892,13 +827,10 @@ def forward( image_hidden_states: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, - return_dict: Optional[bool] = None, logits_to_keep: Union[int, torch.Tensor] = 0, **kwargs: Unpack[TransformersKwargs], - ) -> Union[tuple, Idefics3CausalLMOutputWithPast]: + ) -> Idefics3CausalLMOutputWithPast: r""" pixel_attention_mask (`torch.Tensor` of shape `(batch_size, image_size, image_size)`, *optional*): Mask to avoid performing attention on padding pixel indices. @@ -963,14 +895,8 @@ def forward( >>> print(generated_texts[1]) Assistant: The bridge is in San Francisco. 
```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.model( + outputs: Idefics3BaseModelOutputWithPast = self.model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -980,14 +906,11 @@ def forward( pixel_attention_mask=pixel_attention_mask, image_hidden_states=image_hidden_states, use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, cache_position=cache_position, - return_dict=True, **kwargs, ) - hidden_states = outputs[0] + hidden_states = outputs.last_hidden_state # Only compute necessary logits, and do not upcast them to float if we are not computing the loss slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep logits = self.lm_head(hidden_states[:, slice_indices, :]) diff --git a/src/transformers/models/siglip/modeling_siglip.py b/src/transformers/models/siglip/modeling_siglip.py index 74e5d7fd5a6f..c65bd5af4819 100644 --- a/src/transformers/models/siglip/modeling_siglip.py +++ b/src/transformers/models/siglip/modeling_siglip.py @@ -22,7 +22,6 @@ import numpy as np import torch from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from torch.nn.init import _calculate_fan_in_and_fan_out from ...activations import ACT2FN @@ -30,7 +29,8 @@ from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...utils import ModelOutput, auto_docstring, can_return_tuple, torch_int +from ...utils import ModelOutput, auto_docstring, torch_int +from ...utils.generic import check_model_inputs from .configuration_siglip import SiglipConfig, SiglipTextConfig, SiglipVisionConfig @@ -344,7 +344,7 @@ def eager_attention_forward( class SiglipAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" - def __init__(self, config): + def __init__(self, config: Union[SiglipTextConfig, SiglipVisionConfig]): super().__init__() self.config = config self.embed_dim = config.hidden_size @@ -433,7 +433,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: torch.Tensor, output_attentions: Optional[bool] = False, - ) -> tuple[torch.FloatTensor]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: """ Args: hidden_states (`torch.FloatTensor`): @@ -459,12 +459,7 @@ def forward( hidden_states = self.mlp(hidden_states) hidden_states = residual + hidden_states - outputs = (hidden_states,) - - if output_attentions: - outputs += (attn_weights,) - - return outputs + return hidden_states, attn_weights @auto_docstring @@ -483,6 +478,7 @@ class SiglipPreTrainedModel(PreTrainedModel): _supports_sdpa = True _supports_flex_attn = True _supports_attention_backend = True + _can_record_outputs = {"attentions": SiglipAttention} def _init_weights(self, module): """Initialize the weights""" @@ -531,7 +527,6 @@ def _init_weights(self, module): module.weight.data.fill_(1.0) -# Copied from transformers.models.altclip.modeling_altclip.AltCLIPEncoder with AltCLIP->Siglip class SiglipEncoder(nn.Module): 
""" Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a @@ -541,19 +536,16 @@ class SiglipEncoder(nn.Module): config: SiglipConfig """ - def __init__(self, config: SiglipConfig): + def __init__(self, config: Union[SiglipVisionConfig, SiglipTextConfig]): super().__init__() self.config = config self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False - # Ignore copy - @can_return_tuple def forward( self, - inputs_embeds, + inputs_embeds: torch.FloatTensor, attention_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, ) -> BaseModelOutput: r""" @@ -569,46 +561,21 @@ def forward( - 0 for tokens that are **masked**. [What are attention masks?](../glossary#attention-mask) - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. output_hidden_states (`bool`, *optional*): Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for more detail. - return_dict (`bool`, *optional*): - Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - - encoder_states = () if output_hidden_states else None - all_attentions = () if output_attentions else None + all_hidden_states = [inputs_embeds] if output_hidden_states else None hidden_states = inputs_embeds for encoder_layer in self.layers: - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) - - layer_outputs = encoder_layer( - hidden_states, - attention_mask, - output_attentions=output_attentions, - ) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_attentions = all_attentions + (layer_outputs[1],) - - if output_hidden_states: - encoder_states = encoder_states + (hidden_states,) + hidden_states, _ = encoder_layer(hidden_states, attention_mask) + if all_hidden_states: + all_hidden_states.append(hidden_states) return BaseModelOutput( last_hidden_state=hidden_states, - hidden_states=encoder_states, - attentions=all_attentions, + hidden_states=tuple(all_hidden_states) if all_hidden_states else None, ) @@ -620,27 +587,16 @@ def __init__(self, config: SiglipTextConfig): self.embeddings = SiglipTextEmbeddings(config) self.encoder = SiglipEncoder(config) self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) - self.head = nn.Linear(embed_dim, config.projection_size) - @can_return_tuple @auto_docstring def forward( self, - input_ids: Optional[torch.Tensor] = None, + input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, ) -> BaseModelOutputWithPooling: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - - if input_ids is None: - raise ValueError("You have to specify input_ids") - input_shape = input_ids.size() input_ids = input_ids.view(-1, input_shape[-1]) 
@@ -658,7 +614,6 @@ def forward( encoder_outputs: BaseModelOutput = self.encoder( inputs_embeds=hidden_states, attention_mask=attention_mask, - output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) @@ -673,7 +628,6 @@ def forward( last_hidden_state=last_hidden_state, pooler_output=pooled_output, hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, ) @@ -697,15 +651,15 @@ def get_input_embeddings(self) -> nn.Module: def set_input_embeddings(self, value): self.text_model.embeddings.token_embedding = value - @can_return_tuple + @check_model_inputs @auto_docstring def forward( self, - input_ids: Optional[torch.Tensor] = None, + input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, + **kwargs, ) -> BaseModelOutputWithPooling: r""" Examples: @@ -723,12 +677,13 @@ def forward( >>> last_hidden_state = outputs.last_hidden_state >>> pooled_output = outputs.pooler_output # pooled (EOS token) states ```""" + if output_hidden_states is None: + output_hidden_states = self.config.output_hidden_states return self.text_model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, - output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) @@ -746,28 +701,18 @@ def __init__(self, config: SiglipVisionConfig): if self.use_head: self.head = SiglipMultiheadAttentionPoolingHead(config) - @can_return_tuple @auto_docstring def forward( self, - pixel_values, - output_attentions: Optional[bool] = None, + pixel_values: torch.FloatTensor, output_hidden_states: Optional[bool] = None, interpolate_pos_encoding: Optional[bool] = False, ) -> BaseModelOutputWithPooling: - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) encoder_outputs: BaseModelOutput = self.encoder( - inputs_embeds=hidden_states, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, + inputs_embeds=hidden_states, output_hidden_states=output_hidden_states ) - last_hidden_state = encoder_outputs.last_hidden_state last_hidden_state = self.post_layernorm(last_hidden_state) @@ -777,7 +722,6 @@ def forward( last_hidden_state=last_hidden_state, pooler_output=pooler_output, hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, ) @@ -792,7 +736,7 @@ def __init__(self, config: SiglipVisionConfig): self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.mlp = SiglipMLP(config) - def forward(self, hidden_state): + def forward(self, hidden_state: torch.Tensor) -> torch.Tensor: batch_size = hidden_state.shape[0] probe = self.probe.repeat(batch_size, 1, 1) @@ -801,8 +745,9 @@ def forward(self, hidden_state): residual = hidden_state hidden_state = self.layernorm(hidden_state) hidden_state = residual + self.mlp(hidden_state) + pooled_output = hidden_state[:, 0] - return hidden_state[:, 0] + return pooled_output @auto_docstring( @@ -825,14 +770,14 @@ def __init__(self, config: SiglipVisionConfig): def get_input_embeddings(self) -> nn.Module: return self.vision_model.embeddings.patch_embedding - @can_return_tuple + @check_model_inputs @auto_docstring 
def forward( self, - pixel_values, - output_attentions: Optional[bool] = None, + pixel_values: torch.FloatTensor, output_hidden_states: Optional[bool] = None, interpolate_pos_encoding: bool = False, + **kwargs, ) -> BaseModelOutputWithPooling: r""" Examples: @@ -854,10 +799,11 @@ def forward( >>> last_hidden_state = outputs.last_hidden_state >>> pooled_output = outputs.pooler_output # pooled features ```""" + if output_hidden_states is None: + output_hidden_states = self.config.output_hidden_states return self.vision_model( pixel_values=pixel_values, - output_attentions=output_attentions, output_hidden_states=output_hidden_states, interpolate_pos_encoding=interpolate_pos_encoding, ) @@ -905,8 +851,7 @@ def get_text_features( input_ids: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, + **kwargs, ) -> torch.FloatTensor: r""" Returns: @@ -927,31 +872,20 @@ def get_text_features( >>> with torch.no_grad(): ... text_features = model.get_text_features(**inputs) ```""" - # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components. - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - text_outputs: BaseModelOutputWithPooling = self.text_model( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, ) - pooled_output = text_outputs.pooler_output - return pooled_output @auto_docstring def get_image_features( self, - pixel_values: Optional[torch.FloatTensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, + pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False, + **kwargs, ) -> torch.FloatTensor: r""" Returns: @@ -977,24 +911,14 @@ def get_image_features( >>> with torch.no_grad(): ... image_features = model.get_image_features(**inputs) ```""" - # Use SiglipModel's config for some fields (if specified) instead of those of vision & text components. 
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - vision_outputs: BaseModelOutputWithPooling = self.vision_model( pixel_values=pixel_values, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, interpolate_pos_encoding=interpolate_pos_encoding, ) - pooled_output = vision_outputs.pooler_output - return pooled_output - @can_return_tuple + @check_model_inputs @auto_docstring def forward( self, @@ -1003,9 +927,9 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, return_loss: Optional[bool] = None, - output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, interpolate_pos_encoding: bool = False, + **kwargs, ) -> SiglipOutput: r""" return_loss (`bool`, *optional*): @@ -1037,15 +961,11 @@ def forward( >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'") 31.9% that image 0 is 'a photo of 2 cats' ```""" - # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components. - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) + if output_hidden_states is None: + output_hidden_states = self.config.output_hidden_states vision_outputs: BaseModelOutputWithPooling = self.vision_model( pixel_values=pixel_values, - output_attentions=output_attentions, output_hidden_states=output_hidden_states, interpolate_pos_encoding=interpolate_pos_encoding, ) @@ -1054,7 +974,6 @@ def forward( input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, - output_attentions=output_attentions, output_hidden_states=output_hidden_states, ) @@ -1113,22 +1032,23 @@ def __init__(self, config: SiglipConfig) -> None: self.vision_model = vision_model.vision_model # Classifier head - self.classifier = ( - nn.Linear(config.vision_config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity() - ) + if config.num_labels > 0: + self.classifier = nn.Linear(config.vision_config.hidden_size, config.num_labels) + else: + self.classifier = nn.Identity() # Initialize weights and apply final processing self.post_init() - @can_return_tuple + @check_model_inputs @auto_docstring def forward( self, - pixel_values: Optional[torch.Tensor] = None, + pixel_values: torch.FloatTensor, labels: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, interpolate_pos_encoding: bool = False, + **kwargs, ) -> ImageClassifierOutput: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -1161,18 +1081,14 @@ def forward( >>> print("Predicted class:", model.config.id2label[predicted_class_idx]) Predicted class: LABEL_1 ```""" - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) + if output_hidden_states is None: + output_hidden_states = self.config.output_hidden_states outputs: BaseModelOutputWithPooling = self.vision_model( pixel_values, - output_attentions=output_attentions, output_hidden_states=output_hidden_states, 
interpolate_pos_encoding=interpolate_pos_encoding, ) - sequence_output = outputs.last_hidden_state # average pool the patch tokens @@ -1182,28 +1098,7 @@ def forward( loss = None if labels is not None: - # move labels to correct device to enable model parallelism - labels = labels.to(logits.device) - if self.config.problem_type is None: - if self.num_labels == 1: - self.config.problem_type = "regression" - elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.num_labels == 1: - loss = loss_fct(logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(logits, labels) + loss = self.loss_function(labels, logits, self.config) return ImageClassifierOutput( loss=loss,
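
Illustrative only, not part of the patch: a minimal usage sketch of the refactored SigLIP entry points shown above. It assumes the public `google/siglip-base-patch16-224` checkpoint, a standard COCO test image URL, and the usual `AutoProcessor`/`SiglipModel` API (none of which appear in this hunk). After this change, `output_attentions` is no longer an explicit forward argument; `output_hidden_states` still falls back to the config value, and any extra flags travel through `**kwargs`.

```python
# Hypothetical sketch, for illustration only -- this snippet is not part of the diff.
# Checkpoint name and image URL are placeholders; swap in whatever you actually use.
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, SiglipModel

checkpoint = "google/siglip-base-patch16-224"
model = SiglipModel.from_pretrained(checkpoint)
processor = AutoProcessor.from_pretrained(checkpoint)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
texts = ["a photo of 2 cats", "a photo of 2 dogs"]
inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt")

with torch.no_grad():
    # `output_attentions` is no longer an explicit parameter of these forwards;
    # hidden states are still requested the usual way (default: config value).
    outputs = model(**inputs, output_hidden_states=True)

probs = torch.sigmoid(outputs.logits_per_image)  # SigLIP uses a sigmoid, not a softmax
print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'")

# With the new encoder loop, hidden_states holds the embedding output plus one entry per layer.
print(len(outputs.vision_model_output.hidden_states))
```

The same calling pattern applies to `SiglipTextModel`, `SiglipVisionModel`, and `SiglipForImageClassification`, whose forwards now accept `**kwargs` in place of the removed `output_attentions` parameter.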