Commit cfde6eb

fixup
Sebastien Ehrhardt committed May 3, 2024
1 parent e77df1e commit cfde6eb
Showing 1 changed file with 37 additions and 120 deletions.
157 changes: 37 additions & 120 deletions src/transformers/models/vit_msn/modeling_vit_msn.py
@@ -36,6 +36,7 @@
)
from .configuration_vit_msn import ViTMSNConfig


logger = logging.get_logger(__name__)


@@ -56,22 +57,14 @@ def __init__(self, config: ViTMSNConfig, use_mask_token: bool = False) -> None:
super().__init__()

self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
self.mask_token = (
nn.Parameter(torch.zeros(1, 1, config.hidden_size))
if use_mask_token
else None
)
self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
self.patch_embeddings = ViTMSNPatchEmbeddings(config)
num_patches = self.patch_embeddings.num_patches
self.position_embeddings = nn.Parameter(
torch.zeros(1, num_patches + 1, config.hidden_size)
)
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.config = config

def interpolate_pos_encoding(
self, embeddings: torch.Tensor, height: int, width: int
) -> torch.Tensor:
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
resolution images.
@@ -95,9 +88,7 @@ def interpolate_pos_encoding(
patch_window_height + 0.1,
patch_window_width + 0.1,
)
patch_pos_embed = patch_pos_embed.reshape(
1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
)
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
@@ -118,9 +109,7 @@ def forward(
interpolate_pos_encoding: bool = False,
) -> torch.Tensor:
batch_size, num_channels, height, width = pixel_values.shape
embeddings = self.patch_embeddings(
pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
)
embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)

if bool_masked_pos is not None:
seq_length = embeddings.shape[1]
@@ -135,9 +124,7 @@ def forward(

# add positional encoding to each token
if interpolate_pos_encoding:
embeddings = embeddings + self.interpolate_pos_encoding(
embeddings, height, width
)
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
else:
embeddings = embeddings + self.position_embeddings

@@ -159,31 +146,17 @@ def __init__(self, config):
image_size, patch_size = config.image_size, config.patch_size
num_channels, hidden_size = config.num_channels, config.hidden_size

image_size = (
image_size
if isinstance(image_size, collections.abc.Iterable)
else (image_size, image_size)
)
patch_size = (
patch_size
if isinstance(patch_size, collections.abc.Iterable)
else (patch_size, patch_size)
)
num_patches = (image_size[1] // patch_size[1]) * (
image_size[0] // patch_size[0]
)
image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.num_patches = num_patches

self.projection = nn.Conv2d(
num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
)
self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)

def forward(
self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False
) -> torch.Tensor:
def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
batch_size, num_channels, height, width = pixel_values.shape
if num_channels != self.num_channels:
raise ValueError(
@@ -204,9 +177,7 @@ def forward(
class ViTMSNSelfAttention(nn.Module):
def __init__(self, config: ViTMSNConfig) -> None:
super().__init__()
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
config, "embedding_size"
):
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
raise ValueError(
f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
f"heads {config.num_attention_heads}."
@@ -216,15 +187,9 @@ def __init__(self, config: ViTMSNConfig) -> None:
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
self.all_head_size = self.num_attention_heads * self.attention_head_size

self.query = nn.Linear(
config.hidden_size, self.all_head_size, bias=config.qkv_bias
)
self.key = nn.Linear(
config.hidden_size, self.all_head_size, bias=config.qkv_bias
)
self.value = nn.Linear(
config.hidden_size, self.all_head_size, bias=config.qkv_bias
)
self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)

self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

@@ -270,9 +235,7 @@ def forward(
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
context_layer = context_layer.view(new_context_layer_shape)

outputs = (
(context_layer, attention_probs) if output_attentions else (context_layer,)
)
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)

return outputs

@@ -324,9 +287,7 @@ def __init__(self, config: ViTMSNConfig) -> None:
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)

def forward(
self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
) -> torch.Tensor:
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)

@@ -358,12 +319,8 @@ def prune_heads(self, heads: Set[int]) -> None:
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

# Update hyper params and store pruned heads
self.attention.num_attention_heads = self.attention.num_attention_heads - len(
heads
)
self.attention.all_head_size = (
self.attention.attention_head_size * self.attention.num_attention_heads
)
self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
self.pruned_heads = self.pruned_heads.union(heads)

def forward(
@@ -376,9 +333,7 @@ def forward(

attention_output = self.output(self_outputs[0], hidden_states)

outputs = (attention_output,) + self_outputs[
1:
] # add attentions if we output them
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
return outputs


@@ -413,9 +368,7 @@ def __init__(self, config: ViTMSNConfig) -> None:
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)

def forward(
self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
) -> torch.Tensor:
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)

@@ -438,12 +391,8 @@ def __init__(self, config: ViTMSNConfig) -> None:
self.attention = VITMSN_ATTENTION_CLASSES[config._attn_implementation](config)
self.intermediate = ViTMSNIntermediate(config)
self.output = ViTMSNOutput(config)
self.layernorm_before = nn.LayerNorm(
config.hidden_size, eps=config.layer_norm_eps
)
self.layernorm_after = nn.LayerNorm(
config.hidden_size, eps=config.layer_norm_eps
)
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

def forward(
self,
Expand All @@ -452,16 +401,12 @@ def forward(
output_attentions: bool = False,
) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
self_attention_outputs = self.attention(
self.layernorm_before(
hidden_states
), # in ViTMSN, layernorm is applied before self-attention
self.layernorm_before(hidden_states), # in ViTMSN, layernorm is applied before self-attention
head_mask,
output_attentions=output_attentions,
)
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[
1:
] # add self attentions if we output attention weights
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights

# first residual connection
hidden_states = attention_output + hidden_states
@@ -483,9 +428,7 @@ class ViTMSNEncoder(nn.Module):
def __init__(self, config: ViTMSNConfig) -> None:
super().__init__()
self.config = config
self.layer = nn.ModuleList(
[ViTMSNLayer(config) for _ in range(config.num_hidden_layers)]
)
self.layer = nn.ModuleList([ViTMSNLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False

def forward(
Expand Down Expand Up @@ -513,9 +456,7 @@ def forward(
output_attentions,
)
else:
layer_outputs = layer_module(
hidden_states, layer_head_mask, output_attentions
)
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)

hidden_states = layer_outputs[0]

@@ -526,11 +467,7 @@ def forward(
all_hidden_states = all_hidden_states + (hidden_states,)

if not return_dict:
return tuple(
v
for v in [hidden_states, all_hidden_states, all_self_attentions]
if v is not None
)
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
return BaseModelOutput(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
@@ -631,9 +568,7 @@ class PreTrainedModel
self.encoder.layer[layer].attention.prune_heads(heads)

@add_start_docstrings_to_model_forward(VIT_MSN_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC
)
@replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values: Optional[torch.Tensor] = None,
@@ -668,19 +603,11 @@ def forward(
... outputs = model(**inputs)
>>> last_hidden_states = outputs.last_hidden_state
```"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if pixel_values is None:
raise ValueError("You have to specify pixel_values")
@@ -735,19 +662,13 @@ def __init__(self, config: ViTMSNConfig) -> None:
self.vit = ViTMSNModel(config)

# Classifier head
self.classifier = (
nn.Linear(config.hidden_size, config.num_labels)
if config.num_labels > 0
else nn.Identity()
)
self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()

# Initialize weights and apply final processing
self.post_init()

@add_start_docstrings_to_model_forward(VIT_MSN_INPUTS_DOCSTRING)
@replace_return_docstrings(
output_type=ImageClassifierOutput, config_class=_CONFIG_FOR_DOC
)
@replace_return_docstrings(output_type=ImageClassifierOutput, config_class=_CONFIG_FOR_DOC)
def forward(
self,
pixel_values: Optional[torch.Tensor] = None,
@@ -785,9 +706,7 @@ def forward(
>>> print(model.config.id2label[predicted_label])
tusker
```"""
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs = self.vit(
pixel_values,
@@ -807,9 +726,7 @@ def forward(
if self.config.problem_type is None:
if self.num_labels == 1:
self.config.problem_type = "regression"
elif self.num_labels > 1 and (
labels.dtype == torch.long or labels.dtype == torch.int
):
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
self.config.problem_type = "single_label_classification"
else:
self.config.problem_type = "multi_label_classification"
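For orientation, since the hunks above only show fragments of ViTMSNEmbeddings.interpolate_pos_encoding: that method resamples the pretrained patch position embeddings so the model can accept image resolutions other than the one it was pretrained at. Below is a minimal, self-contained sketch of the general technique (bicubic resampling of the patch position grid, leaving the class-token embedding untouched). The function name, shapes, and the square-grid assumption are illustrative, not copied from this file.

import math

import torch
import torch.nn as nn


def interpolate_patch_pos_embed(pos_embed: torch.Tensor, new_height: int, new_width: int) -> torch.Tensor:
    """Illustrative sketch: resample ViT patch position embeddings onto a new patch grid.

    pos_embed: (1, 1 + num_patches, dim), with the class-token embedding at index 0.
    new_height, new_width: target number of patches along each axis.
    """
    cls_pos = pos_embed[:, :1]             # class-token embedding is kept as-is
    patch_pos = pos_embed[:, 1:]           # (1, num_patches, dim)
    num_patches, dim = patch_pos.shape[1], patch_pos.shape[2]
    grid_size = int(math.sqrt(num_patches))  # assumes a square pretraining grid

    # (1, num_patches, dim) -> (1, dim, grid, grid) so it can be resampled like an image
    patch_pos = patch_pos.reshape(1, grid_size, grid_size, dim).permute(0, 3, 1, 2)
    patch_pos = nn.functional.interpolate(
        patch_pos, size=(new_height, new_width), mode="bicubic", align_corners=False
    )
    # back to (1, new_height * new_width, dim), then re-attach the class token
    patch_pos = patch_pos.permute(0, 2, 3, 1).reshape(1, new_height * new_width, dim)
    return torch.cat((cls_pos, patch_pos), dim=1)

For example, a checkpoint pretrained on 224-pixel images with 16-pixel patches has a 14x14 patch grid; calling the sketch with new_height = new_width = 24 would produce position embeddings suitable for 384-pixel inputs.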
