Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Sebastien Ehrhardt committed May 3, 2024
1 parent cfde6eb commit 1fcc0a0
Showing 1 changed file with 4 additions and 16 deletions.
20 changes: 4 additions & 16 deletions src/transformers/models/vit_msn/modeling_vit_msn.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,18 +194,12 @@ def __init__(self, config: ViTMSNConfig) -> None:
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (
self.num_attention_heads,
self.attention_head_size,
)
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)

def forward(
self,
hidden_states,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)

Expand Down Expand Up @@ -247,10 +241,7 @@ def __init__(self, config: ViTMSNConfig) -> None:
self.attention_probs_dropout_prob = config.attention_probs_dropout_prob

def forward(
self,
hidden_states,
head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
mixed_query_layer = self.query(hidden_states)

Expand Down Expand Up @@ -306,10 +297,7 @@ def prune_heads(self, heads: Set[int]) -> None:
if len(heads) == 0:
return
heads, index = find_pruneable_heads_and_indices(
heads,
self.attention.num_attention_heads,
self.attention.attention_head_size,
self.pruned_heads,
heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
)

# Prune linear layers
Expand Down

0 comments on commit 1fcc0a0

Please sign in to comment.