109 changes: 32 additions & 77 deletions src/transformers/models/aimv2/modeling_aimv2.py
@@ -22,7 +22,7 @@

import math
from dataclasses import dataclass
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, Union

import torch
import torch.nn.functional as F
@@ -35,6 +35,7 @@
from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...utils import ModelOutput, auto_docstring, can_return_tuple
from ...utils.generic import check_model_inputs
from .configuration_aimv2 import Aimv2Config, Aimv2TextConfig, Aimv2VisionConfig


@@ -289,7 +290,7 @@ def forward(


class Aimv2EncoderLayer(GradientCheckpointingLayer):
def __init__(self, config: Aimv2VisionConfig):
def __init__(self, config: Union[Aimv2VisionConfig, Aimv2TextConfig]):
super().__init__()
self.attention = Aimv2Attention(config)
self.ffn = Aimv2MLP(config)
@@ -300,8 +301,7 @@ def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
) -> tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
norm_hidden_states = self.rms_norm1(hidden_states)
attn_output, attn_weights = self.attention(hidden_states=norm_hidden_states, attention_mask=attention_mask)

@@ -310,7 +310,7 @@ def forward(
mlp_output = self.ffn(norm_hidden_states)

hidden_states = hidden_states + mlp_output
return (hidden_states, attn_weights) if output_attentions else (hidden_states, None)
return hidden_states, attn_weights


class Aimv2Encoder(nn.Module):
@@ -322,19 +322,16 @@ class Aimv2Encoder(nn.Module):
config: Aimv2Config
"""

def __init__(self, config: Aimv2Config):
def __init__(self, config: Union[Aimv2VisionConfig, Aimv2TextConfig]):
super().__init__()
self.config = config
self.layers = nn.ModuleList([Aimv2EncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False

# Ignore copy
@can_return_tuple
def forward(
self,
inputs_embeds,
inputs_embeds: torch.FloatTensor,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
) -> BaseModelOutput:
r"""
@@ -350,46 +347,21 @@ def forward(
- 0 for tokens that are **masked**.

[What are attention masks?](../glossary#attention-mask)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)

encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
all_hidden_states = [inputs_embeds] if output_hidden_states else None

hidden_states = inputs_embeds
for encoder_layer in self.layers:
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)

layer_outputs = encoder_layer(
hidden_states,
attention_mask,
output_attentions=output_attentions,
)

hidden_states = layer_outputs[0]

if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)

if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
hidden_states, _ = encoder_layer(hidden_states, attention_mask)
if all_hidden_states:
all_hidden_states.append(hidden_states)

return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=encoder_states,
attentions=all_attentions,
hidden_states=tuple(all_hidden_states) if all_hidden_states else None,
)
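A note on where the per-layer attentions went: the rewritten encoder no longer threads `output_attentions` through the loop. With this refactor they are instead captured by forward hooks that the `check_model_inputs` machinery registers on the module classes listed in `_can_record_outputs` (next hunk). Below is a minimal, hypothetical sketch of that hook idea — not the actual implementation in `transformers/utils/generic.py` — assuming only that each attention module returns an `(attn_output, attn_weights)` tuple, as `Aimv2Attention` does here:

```python
from torch import nn

def collect_attentions(model: nn.Module, attn_cls: type, *args, **kwargs):
    """Run `model` while recording every `attn_cls` module's attention weights
    via forward hooks, instead of threading an output_attentions flag through
    every layer. Simplified stand-in for transformers' hook-based recording."""
    recorded = []
    handles = [
        # out[1] is attn_weights under the (attn_output, attn_weights) convention
        m.register_forward_hook(lambda mod, inputs, out: recorded.append(out[1]))
        for m in model.modules()
        if isinstance(m, attn_cls)
    ]
    try:
        result = model(*args, **kwargs)
    finally:
        for h in handles:
            h.remove()
    return result, tuple(recorded)
```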


@@ -446,6 +418,9 @@ class Aimv2PreTrainedModel(PreTrainedModel):
_supports_sdpa = True
_supports_flash_attn = True
_supports_flex_attn = True
_can_record_outputs = {
"attentions": Aimv2Attention,
}

def _init_weights(self, module):
super()._init_weights(module)
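Taken together with the `@check_model_inputs` decorator on the model `forward`s below, this mapping preserves the old caller-facing behavior: the decorator resolves flags from `**kwargs` (falling back to config defaults) and fills the recorded tensors into the returned output. A hedged usage sketch, assuming `model` is an `Aimv2VisionModel` and `inputs` comes from its processor as in the docstring example further down:

```python
# Assumed behavior of the decorator-based plumbing, not verified against every path:
# output_attentions is accepted via **kwargs even though forward() no longer lists it.
outputs = model(**inputs, output_attentions=True, output_hidden_states=True)
assert len(outputs.attentions) == model.config.num_hidden_layers
assert len(outputs.hidden_states) == model.config.num_hidden_layers + 1  # embeddings + layers
```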
@@ -482,14 +457,14 @@ def __init__(self, config: Aimv2VisionConfig):
def get_input_embeddings(self) -> nn.Module:
return self.embeddings.patch_embed

@can_return_tuple
@check_model_inputs
@auto_docstring
def forward(
self,
pixel_values,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> BaseModelOutputWithPooling:
r"""
Examples:
@@ -511,20 +486,16 @@ def forward(
>>> last_hidden_state = outputs.last_hidden_state
>>> pooled_output = outputs.pooler_output # pooled features
```"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
if output_hidden_states is None:
output_hidden_states = self.config.output_hidden_states

hidden_states = self.embeddings(pixel_values)

encoder_outputs = self.encoder(
inputs_embeds=hidden_states,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
encoder_outputs: BaseModelOutput = self.encoder(
inputs_embeds=hidden_states, output_hidden_states=output_hidden_states
)

last_hidden_state = encoder_outputs[0]
last_hidden_state = encoder_outputs.last_hidden_state
last_hidden_state = self.rms_norm(last_hidden_state)

pooler_output = self.head(last_hidden_state) if self.use_head else None
@@ -533,7 +504,6 @@ def forward(
last_hidden_state=last_hidden_state,
pooler_output=pooler_output,
hidden_states=encoder_outputs.hidden_states,
attentions=encoder_outputs.attentions,
)


@@ -543,6 +513,7 @@ def forward(
"""
)
class Aimv2TextModel(Aimv2PreTrainedModel):
config: Aimv2TextConfig
main_input_name = "input_ids"

def __init__(self, config: Aimv2TextConfig):
Expand All @@ -562,19 +533,17 @@ def get_input_embeddings(self) -> nn.Module:
def set_input_embeddings(self, value):
self.embeddings.token_embedding = value

@can_return_tuple
@check_model_inputs
@auto_docstring
def forward(
self,
input_ids,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> BaseModelOutputWithPooling:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
if output_hidden_states is None:
output_hidden_states = self.config.output_hidden_states

hidden_states = self.embeddings(input_ids)
batch_size, seq_len, _ = hidden_states.shape
@@ -591,14 +560,13 @@ def forward(
past_key_values=None,
)

encoder_outputs = self.encoder(
encoder_outputs: BaseModelOutput = self.encoder(
inputs_embeds=hidden_states,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)

last_hidden_state = encoder_outputs[0]
last_hidden_state = encoder_outputs.last_hidden_state
last_hidden_state = self.rms_norm(last_hidden_state)

# Get pooled output
@@ -749,8 +717,7 @@ def forward(
input_ids: Optional[torch.LongTensor] = None,
pixel_values: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
**kwargs,
) -> Aimv2Output:
r"""
Examples:
@@ -775,22 +742,10 @@
>>> probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities
```"""

output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)

vision_outputs: BaseModelOutputWithPooling = self.vision_model(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
vision_outputs: BaseModelOutputWithPooling = self.vision_model(pixel_values=pixel_values, **kwargs)

text_outputs: BaseModelOutputWithPooling = self.text_model(
input_ids=input_ids,
attention_mask=attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
input_ids=input_ids, attention_mask=attention_mask, **kwargs
)

image_embeds = vision_outputs.pooler_output
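Because both tower calls in this last hunk forward `**kwargs`, a single flag now reaches both the vision and the text encoder. A small sketch of the assumed end-to-end effect (output field names taken from the CLIP-style convention these combined models follow; treat as illustrative):

```python
# Illustrative only: one output_hidden_states flag applies to both towers.
outputs = model(**inputs, output_hidden_states=True)
assert outputs.vision_model_output.hidden_states is not None
assert outputs.text_model_output.hidden_states is not None
```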