Commit

DETR added, problem with position_embedding_type=sine -> not supported
mszsorondo authored and awinml committed May 20, 2023
1 parent a6951c1 commit 24209cb
Showing 2 changed files with 227 additions and 3 deletions.
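Note on the issue flagged in the commit title: DETR checkpoints configured with position_embedding_type="sine" are not handled by this conversion. As a rough, hypothetical illustration only (this pre-check is not part of the commit; it assumes DetrConfig exposes a position_embedding_type attribute, which defaults to "sine" in transformers), a caller could screen the config before attempting a transform:

from transformers import DetrConfig

config = DetrConfig()  # the default DETR config uses sine position embeddings

# Hypothetical guard, not part of this commit: refuse the conversion up front
# instead of failing inside the converted encoder layer.
if getattr(config, "position_embedding_type", None) == "sine":
    raise NotImplementedError(
        "DETR models using sine position embeddings are not supported by this conversion yet."
    )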
106 changes: 103 additions & 3 deletions optimum/bettertransformer/__init__.py
@@ -1,4 +1,4 @@
# Copyright 2022 The HuggingFace and Meta Team. All rights reserved.
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,5 +11,105 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .models import BetterTransformerManager
from .transformation import BetterTransformer
import warnings

from .encoder_models import (
    AlbertLayerBetterTransformer,
    BartEncoderLayerBetterTransformer,
    BertLayerBetterTransformer,
    CLIPLayerBetterTransformer,
    DistilBertLayerBetterTransformer,
    FSMTEncoderLayerBetterTransformer,
    MBartEncoderLayerBetterTransformer,
    ViltLayerBetterTransformer,
    ViTLayerBetterTransformer,
    Wav2Vec2EncoderLayerBetterTransformer,
    WhisperEncoderLayerBetterTransformer,
    DetrEncoderLayerBetterTransformer,
)


class BetterTransformerManager:
    MODEL_MAPPING = {
        "albert": ("AlbertLayer", AlbertLayerBetterTransformer),
        "bart": ("BartEncoderLayer", BartEncoderLayerBetterTransformer),
        "bert": ("BertLayer", BertLayerBetterTransformer),
        "bert-generation": ("BertGenerationLayer", BertLayerBetterTransformer),
        "camembert": ("CamembertLayer", BertLayerBetterTransformer),
        "clip": ("CLIPEncoderLayer", CLIPLayerBetterTransformer),
        "data2vec-text": ("Data2VecTextLayer", BertLayerBetterTransformer),
        "deit": ("DeiTLayer", ViTLayerBetterTransformer),
        "detr": ("DetrEncoderLayer", DetrEncoderLayerBetterTransformer),
        "distilbert": ("TransformerBlock", DistilBertLayerBetterTransformer),
        "electra": ("ElectraLayer", BertLayerBetterTransformer),
        "ernie": ("ErnieLayer", BertLayerBetterTransformer),
        "fsmt": ("EncoderLayer", FSMTEncoderLayerBetterTransformer),
        "hubert": ("HubertEncoderLayer", Wav2Vec2EncoderLayerBetterTransformer),
        "layoutlm": ("LayoutLMLayer", BertLayerBetterTransformer),
        "m2m_100": ("M2M100EncoderLayer", MBartEncoderLayerBetterTransformer),
        "markuplm": ("MarkupLMLayer", BertLayerBetterTransformer),
        "mbart": ("MBartEncoderLayer", MBartEncoderLayerBetterTransformer),
        "rembert": ("RemBertLayer", BertLayerBetterTransformer),
        "roberta": ("RobertaLayer", BertLayerBetterTransformer),
        "splinter": ("SplinterLayer", BertLayerBetterTransformer),
        "tapas": ("TapasLayer", BertLayerBetterTransformer),
        "vilt": ("ViltLayer", ViltLayerBetterTransformer),
        "vit": ("ViTLayer", ViTLayerBetterTransformer),
        "vit_mae": ("ViTMAELayer", ViTLayerBetterTransformer),
        "vit_msn": ("ViTMSNLayer", ViTLayerBetterTransformer),
        "wav2vec2": ("Wav2Vec2EncoderLayer", Wav2Vec2EncoderLayerBetterTransformer),
        "whisper": ("WhisperEncoderLayer", WhisperEncoderLayerBetterTransformer),
        "xlm-roberta": ("XLMRobertaLayer", BertLayerBetterTransformer),
        "yolos": ("YolosLayer", ViTLayerBetterTransformer),
    }

    EXCLUDE_FROM_TRANSFORM = {
        # clip's text model uses causal attention, which is most likely not supported in BetterTransformer
        "clip": ["text_model"],
    }

    CAN_NOT_BE_SUPPORTED = {
        "deberta-v2": "DeBERTa v2 does not use a regular attention mechanism, which is not supported in PyTorch's BetterTransformer.",
        "glpn": "GLPN has a convolutional layer in its FFN network, which is not supported in PyTorch's BetterTransformer.",
        "t5": "T5 uses attention bias, which is not supported in PyTorch's BetterTransformer.",
    }

    @staticmethod
    def cannot_support(model_type: str) -> bool:
        """
        Returns True if a given model type cannot be supported by PyTorch's Better Transformer.

        Args:
            model_type (`str`):
                The model type to check.
        """
        return model_type in BetterTransformerManager.CAN_NOT_BE_SUPPORTED

    @staticmethod
    def supports(model_type: str) -> bool:
        """
        Returns True if a given model type is supported by PyTorch's Better Transformer and integrated in Optimum.

        Args:
            model_type (`str`):
                The model type to check.
        """
        return model_type in BetterTransformerManager.MODEL_MAPPING


class warn_uncompatible_save(object):
    def __init__(self, callback):
        self.callback = callback

    def __enter__(self):
        return self

    def __exit__(self, ex_typ, ex_val, traceback):
        return True

    def __call__(self, *args, **kwargs):
        warnings.warn(
            "You are calling `save_pretrained` on a `BetterTransformer`-converted model; you may encounter unexpected behaviors.",
            UserWarning,
        )
        return self.callback(*args, **kwargs)
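Below is a minimal usage sketch, not part of the diff, of how the helpers above might be queried. The import path assumes this module is importable as optimum.bettertransformer, and the model types passed in are arbitrary examples:

from optimum.bettertransformer import BetterTransformerManager  # import path assumed

for model_type in ("detr", "t5", "gpt2"):
    if BetterTransformerManager.cannot_support(model_type):
        # Architectures that PyTorch's BetterTransformer fundamentally cannot handle
        print(f"{model_type}: {BetterTransformerManager.CAN_NOT_BE_SUPPORTED[model_type]}")
    elif BetterTransformerManager.supports(model_type):
        layer_name, bt_class = BetterTransformerManager.MODEL_MAPPING[model_type]
        print(f"{model_type}: would replace {layer_name} with {bt_class.__name__}")
    else:
        print(f"{model_type}: no BetterTransformer integration in Optimum yet")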
124 changes: 124 additions & 0 deletions optimum/bettertransformer/models/encoder_models.py
@@ -544,6 +544,130 @@ def forward(self, hidden_states, attention_mask, position_bias=None, *_, **__):
        hidden_states = hidden_states.to_padded_tensor(0.0, original_shape)
        return (hidden_states,)

class DetrEncoderLayerBetterTransformer(BetterTransformerBaseLayer):
    def __init__(self, detr_encoder_layer, config):
        super().__init__(config)
        # In_proj layer
        self.in_proj_weight = nn.Parameter(
            torch.cat(
                [
                    detr_encoder_layer.self_attn.q_proj.weight,
                    detr_encoder_layer.self_attn.k_proj.weight,
                    detr_encoder_layer.self_attn.v_proj.weight,
                ]
            )
        )
        self.in_proj_bias = nn.Parameter(
            torch.cat(
                [
                    detr_encoder_layer.self_attn.q_proj.bias,
                    detr_encoder_layer.self_attn.k_proj.bias,
                    detr_encoder_layer.self_attn.v_proj.bias,
                ]
            )
        )
        self.out_proj_weight = detr_encoder_layer.self_attn.out_proj.weight
        self.out_proj_bias = detr_encoder_layer.self_attn.out_proj.bias
        self.linear1_weight = detr_encoder_layer.fc1.weight
        self.linear1_bias = detr_encoder_layer.fc1.bias
        self.linear2_weight = detr_encoder_layer.fc2.weight
        self.linear2_bias = detr_encoder_layer.fc2.bias
        # Layer norm 1
        self.norm1_eps = detr_encoder_layer.self_attn_layer_norm.eps
        self.norm1_weight = detr_encoder_layer.self_attn_layer_norm.weight
        self.norm1_bias = detr_encoder_layer.self_attn_layer_norm.bias

        # Layer norm 2
        self.norm2_eps = detr_encoder_layer.final_layer_norm.eps
        self.norm2_weight = detr_encoder_layer.final_layer_norm.weight
        self.norm2_bias = detr_encoder_layer.final_layer_norm.bias

        self.num_heads = detr_encoder_layer.self_attn.num_heads
        self.embed_dim = detr_encoder_layer.self_attn.embed_dim

        self.is_last_layer = False
        self.norm_first = True
        self.validate_bettertransformer()

    def forward(self, hidden_states, attention_mask, *_, **__):
        r"""
        This is just a wrapper around the forward function proposed in:
        https://github.com/huggingface/transformers/pull/19553
        """
        super().forward_checker()

        # we expect attention_mask to be None in the vision model
        if attention_mask is not None:
            raise ValueError(
                "Please do not use attention masks when using `BetterTransformer` converted vision models"
            )

        hidden_states = torch._transformer_encoder_layer_fwd(
            hidden_states,
            self.embed_dim,
            self.num_heads,
            self.in_proj_weight,
            self.in_proj_bias,
            self.out_proj_weight,
            self.out_proj_bias,
            self.use_gelu,
            self.norm_first,
            self.norm1_eps,
            self.norm1_weight,
            self.norm1_bias,
            self.norm2_weight,
            self.norm2_bias,
            self.linear1_weight,
            self.linear1_bias,
            self.linear2_weight,
            self.linear2_bias,
            attention_mask,
        )

        return (hidden_states,)

    def _get_activation_function(self, config: "PretrainedConfig"):
        if hasattr(config, "vision_config") and hasattr(config, "text_config"):
            assert config.vision_config.hidden_act == config.text_config.hidden_act
            return config.vision_config.hidden_act
        else:
            return config.hidden_act


class DistilBertLayerBetterTransformer(BetterTransformerBaseLayer, nn.Module):
    def __init__(self, bert_layer, config):

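For reference, here is a standalone sketch of the q/k/v weight packing performed in DetrEncoderLayerBetterTransformer.__init__ above: the three projection matrices are concatenated along dim 0 into the single in_proj tensor that torch._transformer_encoder_layer_fwd consumes. The embedding size below is arbitrary, and the fused op itself is a private PyTorch API whose signature may change between releases.

import torch
import torch.nn as nn

embed_dim = 256  # illustrative size; DETR's hidden size depends on the checkpoint

q_proj = nn.Linear(embed_dim, embed_dim)
k_proj = nn.Linear(embed_dim, embed_dim)
v_proj = nn.Linear(embed_dim, embed_dim)

# Same packing as in the converted layer: stack q, k, v row-wise so the fused
# kernel receives one (3 * embed_dim, embed_dim) projection matrix.
in_proj_weight = nn.Parameter(torch.cat([q_proj.weight, k_proj.weight, v_proj.weight]))
in_proj_bias = nn.Parameter(torch.cat([q_proj.bias, k_proj.bias, v_proj.bias]))

assert in_proj_weight.shape == (3 * embed_dim, embed_dim)
assert in_proj_bias.shape == (3 * embed_dim,)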