[BT] Add Bettertransformer support for FSMT #494

Merged · 13 commits · Nov 24, 2022
3 changes: 3 additions & 0 deletions optimum/bettertransformer/models/__init__.py
@@ -18,6 +18,7 @@
    BartEncoderLayerBetterTransformer,
    BertLayerBetterTransformer,
    DistilBertLayerBetterTransformer,
    FSMTEncoderLayerBetterTransformer,
    ViltLayerBetterTransformer,
    ViTLayerBetterTransformer,
    Wav2Vec2EncoderLayerBetterTransformer,
@@ -66,6 +67,8 @@
    "ViTMAELayer": ViTLayerBetterTransformer,
    "ViTMSNLayer": ViTLayerBetterTransformer,
    "YolosLayer": ViTLayerBetterTransformer,
    # FSMTModel:
    "EncoderLayer": FSMTEncoderLayerBetterTransformer,
    "ViltLayer": ViltLayerBetterTransformer,
}

112 changes: 112 additions & 0 deletions optimum/bettertransformer/models/encoder_models.py
@@ -847,3 +847,115 @@ def forward(self, hidden_states, attention_mask, **__):
        if hidden_states.is_nested and self.is_last_layer:
            hidden_states = hidden_states.to_padded_tensor(0.0)
        return (hidden_states,)


class FSMTEncoderLayerBetterTransformer(BetterTransformerBaseLayer):
    def __init__(self, fsmt_layer, config):
        r"""
        A simple conversion of the FSMT encoder layer to its `BetterTransformer` implementation.

        Args:
            fsmt_layer (`torch.nn.Module`):
                The original FSMT layer whose weights need to be retrieved.
        """
        super().__init__(config)
        # In_proj layer
        self.in_proj_weight = nn.Parameter(
            torch.cat(
                [
                    fsmt_layer.self_attn.q_proj.weight,
                    fsmt_layer.self_attn.k_proj.weight,
                    fsmt_layer.self_attn.v_proj.weight,
                ]
            )
        )
        self.in_proj_bias = nn.Parameter(
            torch.cat(
                [
                    fsmt_layer.self_attn.q_proj.bias,
                    fsmt_layer.self_attn.k_proj.bias,
                    fsmt_layer.self_attn.v_proj.bias,
                ]
            )
        )

        # Out proj layer
        self.out_proj_weight = fsmt_layer.self_attn.out_proj.weight
        self.out_proj_bias = fsmt_layer.self_attn.out_proj.bias

        # Linear layer 1
        self.linear1_weight = fsmt_layer.fc1.weight
        self.linear1_bias = fsmt_layer.fc1.bias

        # Linear layer 2
        self.linear2_weight = fsmt_layer.fc2.weight
        self.linear2_bias = fsmt_layer.fc2.bias

        # Layer norm 1
        self.norm1_eps = fsmt_layer.self_attn_layer_norm.eps
        self.norm1_weight = fsmt_layer.self_attn_layer_norm.weight
        self.norm1_bias = fsmt_layer.self_attn_layer_norm.bias

        # Layer norm 2
        self.norm2_eps = fsmt_layer.final_layer_norm.eps
        self.norm2_weight = fsmt_layer.final_layer_norm.weight
        self.norm2_bias = fsmt_layer.final_layer_norm.bias

        # Model hyperparameters
        self.num_heads = fsmt_layer.self_attn.num_heads
        self.embed_dim = fsmt_layer.self_attn.embed_dim

        # Last step: set `is_last_layer` to `False`; it will be set to `True` for the final layer when converting the model
        self.is_last_layer = False

        self.validate_bettertransformer()

    def forward(self, hidden_states, attention_mask, position_bias=None, *_, **__):
        r"""
        This is just a wrapper around the forward function proposed in:
        https://github.com/huggingface/transformers/pull/19553
        """
        super().forward_checker()

        if hidden_states.is_nested:
            attention_mask = None

        if attention_mask is not None:
            # The attention mask comes in with values 0 and -inf; we convert it to a torch.nn.TransformerEncoder-style bool mask:
            # 0 -> False -> keep this token, -inf -> True -> mask this token
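            # Illustration (not part of the original diff): a hypothetical additive mask
            #   [[0., 0., -inf]] becomes the bool mask [[False, False, True]], so
            #   lengths = [2] and batches with uneven lengths are repacked as nested tensors below.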
            attention_mask = attention_mask.bool()
            attention_mask = torch.reshape(attention_mask, (attention_mask.shape[0], attention_mask.shape[-1]))
            seqlen = attention_mask.shape[1]
            lengths = torch.sum(~attention_mask, 1)

            if hidden_states.shape[0] != attention_mask.shape[0]:
                hidden_states = hidden_states.transpose(1, 0)

            if not all([l == seqlen for l in lengths]):
                hidden_states = torch._nested_tensor_from_mask(hidden_states, ~attention_mask)
                attention_mask = None

        hidden_states = torch._transformer_encoder_layer_fwd(
            hidden_states,
            self.embed_dim,
            self.num_heads,
            self.in_proj_weight,
            self.in_proj_bias,
            self.out_proj_weight,
            self.out_proj_bias,
            self.use_gelu,
            self.norm_first,
            self.norm1_eps,
            self.norm1_weight,
            self.norm1_bias,
            self.norm2_weight,
            self.norm2_bias,
            self.linear1_weight,
            self.linear1_bias,
            self.linear2_weight,
            self.linear2_bias,
            attention_mask,
        )
        if hidden_states.is_nested and self.is_last_layer:
            hidden_states = hidden_states.to_padded_tensor(0.0)
        return (hidden_states, attention_mask)
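
As a usage note (not part of the diff): the sketch below shows how the new FSMT path could be exercised end to end. It assumes optimum's public `BetterTransformer.transform` API and uses the tiny test checkpoint referenced in the tests added by this PR.

import torch
from transformers import AutoModel
from optimum.bettertransformer import BetterTransformer

# Load a small FSMT checkpoint and convert it; keep_original_model=True returns a
# converted copy so the two models can be compared.
hf_model = AutoModel.from_pretrained("hf-internal-testing/tiny-random-FSMTModel").eval()
bt_model = BetterTransformer.transform(hf_model, keep_original_model=True)

inputs = {
    "input_ids": torch.LongTensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]]),
    "attention_mask": torch.LongTensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 0, 0, 0]]),
    "decoder_input_ids": torch.LongTensor([[0], [0]]),
}

with torch.no_grad():
    hf_hidden = hf_model(**inputs)[0]
    bt_hidden = bt_model(**inputs)[0]

# Outputs at non-padded positions are expected to agree within a small tolerance.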
35 changes: 33 additions & 2 deletions tests/bettertransformer/test_bettertransformer_encoder.py
@@ -32,7 +32,6 @@

ALL_ENCODER_MODELS_TO_TEST = [
"hf-internal-testing/tiny-random-DistilBertModel",
"hf-internal-testing/tiny-random-BartModel",
"hf-internal-testing/tiny-random-AlbertModel",
"hf-internal-testing/tiny-random-RobertaModel",
"hf-internal-testing/tiny-xlm-roberta",
@@ -47,6 +46,11 @@
"ybelkada/random-tiny-BertGenerationModel",
]

ALL_ENCODER_DECODER_MODELS_TO_TEST = [
"hf-internal-testing/tiny-random-FSMTModel",
"hf-internal-testing/tiny-random-BartModel",
]


class BetterTransformersEncoderTest(BetterTransformersTestMixin, unittest.TestCase):
r"""
@@ -77,7 +81,10 @@ def _loop_all_classes(self):
        model to test.
        """
        for layer_class in BETTER_TRANFORMER_LAYERS_MAPPING_DICT.keys():
            if layer_class == "TransformerBlock":
            if layer_class == "EncoderLayer":
                # Hardcode it for FSMT - see https://github.com/huggingface/optimum/pull/494
                class_name = "FSMT"
            elif layer_class == "TransformerBlock":
                # Hardcode it for distilbert - see https://github.com/huggingface/transformers/pull/19966
                class_name = "DistilBert"
            elif "EncoderLayer" in layer_class:
@@ -252,6 +259,30 @@ def test_accelerate_compatibility_single_gpu_without_keeping(self):
        self.check_accelerate_compatibility_cpu_gpu(keep_original_model=False, max_memory=max_memory)


class BetterTransformersEncoderDecoderTest(BetterTransformersTestMixin, unittest.TestCase):
r"""
Full testing suite of the `BetterTransformers` integration into Hugging Face
`transformers` ecosystem. Check the docstring of each test to understand the
purpose of each test. Basically we test:
- if the conversion dictionnary is consistent, ie if the converted model exists
in HuggingFace `transformers` library.
- if the converted model produces the same logits as the original model.
- if the converted model is faster than the original model.
"""
    all_models_to_test = ALL_ENCODER_DECODER_MODELS_TO_TEST

    def tearDown(self):
        gc.collect()

    def prepare_inputs_for_class(self, model_id=None):
        input_dict = {
            "input_ids": torch.LongTensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]]),
            "attention_mask": torch.LongTensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 0, 0, 0]]),
            "decoder_input_ids": torch.LongTensor([[0], [0]]),
        }
        return input_dict


def get_batch(batch_size, avg_seqlen, max_sequence_length, seqlen_stdev, vocab_size, pad_idx=0):
r"""
Utility function to generate a batch of random sequences, together with their
2 changes: 1 addition & 1 deletion tests/bettertransformer/testing_bettertransformer_utils.py
@@ -70,7 +70,7 @@ def test_logits(self):
            torch.manual_seed(0)
            bt_hidden_states = converted_model(**inputs)[0]

            if "gelu_new" in vars(random_config).values():
            if "gelu_new" in list(random_config.to_dict().values()):
Member: Great change, but no need to cast to list here!

Contributor: Thanks! But I think the `to_dict()` is necessary (otherwise I get a VERY weird error from transformers).

Member: Yes exactly, but accessing the values is enough here.
                # Since `gelu_new` is a slightly modified version of `GeLU` we expect a small
                # discrepancy.
                tol = 4e-2
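
On the review exchange above, a small hypothetical snippet (not part of the PR) showing that the membership test works on the `dict_values` view directly, so the `list()` cast is redundant while `to_dict()` still does the work:

from transformers import BertConfig

# `to_dict()` returns a plain dict of the config fields; `in` works directly on
# its values() view, so no list() conversion is needed.
config = BertConfig(hidden_act="gelu_new")
assert "gelu_new" in config.to_dict().values()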