
Commit

vit hybrid
Sebastien Ehrhardt committed May 3, 2024
1 parent 66f9190 commit 7d4ba86
Showing 1 changed file with 49 additions and 1 deletion.
50 changes: 49 additions & 1 deletion src/transformers/models/vit_hybrid/modeling_vit_hybrid.py
@@ -251,6 +251,41 @@ def forward(
        return outputs


# Copied from transformers.models.vit.modeling_vit.ViTSdpaSelfAttention with ViT->ViTHybrid
class ViTHybridSdpaSelfAttention(ViTHybridSelfAttention):
    def __init__(self, config: ViTHybridConfig) -> None:
        super().__init__(config)
        self.attention_probs_dropout_prob = config.attention_probs_dropout_prob

    def forward(
        self,
        hidden_states,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
        mixed_query_layer = self.query(hidden_states)

        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))
        query_layer = self.transpose_for_scores(mixed_query_layer)

        context_layer = torch.nn.functional.scaled_dot_product_attention(
            query_layer,
            key_layer,
            value_layer,
            head_mask,
            self.attention_probs_dropout_prob if self.training else 0.0,
            is_causal=False,
            scale=None,
        )

        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        return context_layer, None


# Copied from transformers.models.vit.modeling_vit.ViTSelfOutput with ViT->ViTHybrid
class ViTHybridSelfOutput(nn.Module):
"""
@@ -310,6 +345,13 @@ def forward(
        return outputs


# Copied from transformers.models.vit.modeling_vit.ViTSdpaAttention with ViT->ViTHybrid
class ViTHybridSdpaAttention(ViTHybridAttention):
    def __init__(self, config: ViTHybridConfig) -> None:
        super().__init__(config)
        self.attention = ViTHybridSdpaSelfAttention(config)


# Copied from transformers.models.vit.modeling_vit.ViTIntermediate with ViT->ViTHybrid
class ViTHybridIntermediate(nn.Module):
    def __init__(self, config: ViTHybridConfig) -> None:
@@ -343,14 +385,20 @@ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        return hidden_states


VIT_HYBRID_ATTENTION_CLASSES = {
    "eager": ViTHybridAttention,
    "sdpa": ViTHybridSdpaAttention,
}


class ViTHybridLayer(nn.Module):
    """This corresponds to the Block class in the timm implementation."""

    def __init__(self, config: ViTHybridConfig) -> None:
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
-        self.attention = ViTHybridAttention(config)
+        self.attention = VIT_HYBRID_ATTENTION_CLASSES[config._attn_implementation](config)
        self.intermediate = ViTHybridIntermediate(config)
        self.output = ViTHybridOutput(config)
        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
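
The SDPA classes above do not change the attention math: torch.nn.functional.scaled_dot_product_attention computes the same softmax(QK^T / sqrt(d))V that the eager ViTHybridSelfAttention spells out, it just dispatches to fused kernels and never materialises the attention probabilities (which is why the SDPA path returns None for them). A minimal standalone sketch of that equivalence, assuming PyTorch 2.1+ (for the scale argument) and using plain tensors rather than the actual ViTHybrid modules:

# Standalone sketch (not part of the commit): check that F.scaled_dot_product_attention
# matches the eager softmax(QK^T / sqrt(d))V computation used by ViTHybridSelfAttention.
# Shapes follow transpose_for_scores: (batch, num_heads, seq_len, head_dim).
import math

import torch
import torch.nn.functional as F

batch, num_heads, seq_len, head_dim = 2, 12, 197, 64
query = torch.randn(batch, num_heads, seq_len, head_dim)
key = torch.randn(batch, num_heads, seq_len, head_dim)
value = torch.randn(batch, num_heads, seq_len, head_dim)

# Eager path: explicit attention probabilities (dropout disabled for the comparison).
scores = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(head_dim)
probs = F.softmax(scores, dim=-1)
eager_out = torch.matmul(probs, value)

# SDPA path, mirroring the call added in ViTHybridSdpaSelfAttention
# (attn_mask=None stands in for head_mask, dropout_p=0.0 as in eval mode).
sdpa_out = F.scaled_dot_product_attention(
    query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
)

# Expected: True (small numerical differences are possible depending on the backend).
print(torch.allclose(eager_out, sdpa_out, atol=1e-5))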

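With VIT_HYBRID_ATTENTION_CLASSES wired into ViTHybridLayer, the implementation is chosen through config._attn_implementation, which from_pretrained exposes as the attn_implementation argument. A usage sketch, assuming the full commit also marks the model as SDPA-capable; the checkpoint name is illustrative and any ViT Hybrid checkpoint should work:

# Sketch (not part of the commit): loading ViT Hybrid with the new SDPA attention path.
import torch
from transformers import ViTHybridModel

model = ViTHybridModel.from_pretrained(
    "google/vit-hybrid-base-bit-384",  # assumed checkpoint name
    attn_implementation="sdpa",        # selects ViTHybridSdpaAttention via VIT_HYBRID_ATTENTION_CLASSES
)
model.eval()

pixel_values = torch.randn(1, 3, 384, 384)  # (batch, channels, height, width)
with torch.no_grad():
    outputs = model(pixel_values=pixel_values)
print(outputs.last_hidden_state.shape)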