
Adding BetterTransformer support for ProphetNet #648

Draft
wants to merge 11 commits into main

Conversation

adit299 (Contributor) commented Dec 27, 2022

What does this PR do?

Opening up a draft PR to start the discussion on how to add BetterTransformer support for ProphetNet.

Fixes #488

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

adit299 (Contributor, Author) commented Dec 29, 2022

Hello, I had a couple of questions about how to add BetterTransformer support for ProphetNet. FYI, this is my first open source contribution, so a little hand-holding may be required:

(1) Just so that I am clear on what is going on within this ticket: we are attempting to replace the transformer encoder layers used in models so that they integrate with "Better Transformer", an API released by PyTorch that helps speed up models through CPU- and GPU-based techniques such as "sparsity" and "fused kernels" <-- I am still a little unsure what these terms mean. If my understanding of this task is incorrect, any clarification would be appreciated.

(2) I am having a hard time understanding how to identify which layers within ProphetNet need to be replaced. For example, within the documentation, I see that the lines:

from transformers import AutoModel

model = AutoModel.from_pretrained("bert-base-uncased")

print(model)

are used to identify the layers that need to be replaced. If I understood correctly, the layers which contain (Attention) are the ones that need to be replaced? My question is: how do you know what to pass to the .from_pretrained method? What would I pass to get the layers of ProphetNet? I had a look at https://huggingface.co/transformers/v3.0.2/model_doc/auto.html, but couldn't find the answer.

Thanks again for your patience!

@michaelbenayoun @fxmarty @younesbelkada

fxmarty (Collaborator) commented Dec 30, 2022

Hi, that's great! To answer your questions:

  1. Yes, exactly! You can have a look at https://pytorch.org/blog/a-better-transformer-for-fast-transformer-encoder-inference/ and https://medium.com/pytorch/bettertransformer-out-of-the-box-performance-for-huggingface-transformers-3fbe27d50ab2 for more details. To understand the "fused kernel" thing, this blog post is an excellent reference: https://horace.io/brrr_intro.html
  2. It may be a good idea to have a look at https://github.com/huggingface/transformers/blob/main/src/transformers/models/prophetnet/modeling_prophetnet.py as well, and maybe compare with Bert or other already-supported models to get an understanding: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py . Now if you do:
from transformers import ProphetNetModel, ProphetNetConfig

cfg = ProphetNetConfig()
model = ProphetNetModel(cfg)

print(model)

you can find submodules like ProphetNetEncoderLayer and ProphetNetDecoderLayer. It's likely that the module to replace will be ProphetNetEncoderLayer, because it includes the attention and feed forward blocks. It's probably a bit unclear in the doc https://huggingface.co/docs/optimum/main/en/bettertransformer/tutorials/contribute , but we actually aim to match nn.TransformerEncoderLayer. In that last link, you can click on "[SOURCE]" and see which operations it includes (spoiler: the attention and the feed forward).
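For instance, a quick way to list the candidate submodule classes (just an illustration, not code from the PR) is:

from transformers import ProphetNetConfig, ProphetNetModel

model = ProphetNetModel(ProphetNetConfig())

# Print each distinct module class whose name contains "Layer",
# to spot blocks such as ProphetNetEncoderLayer / ProphetNetDecoderLayer.
seen = set()
for _, module in model.named_modules():
    cls = module.__class__.__name__
    if "Layer" in cls and cls not in seen:
        seen.add(cls)
        print(cls)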

fxmarty (Collaborator) commented Jan 4, 2023

Hi @adit299, your branch is out of sync with main, especially optimum/bettertransformer/models/__init__.py. Are you familiar with git?

adit299 (Contributor, Author) commented Jan 4, 2023

Hi @fxmarty, I assume this is referring to the merge conflicts? Yes, I am familiar with Git, will resolve. Will let you know if I am unable to.

adit299 (Contributor, Author) commented Jan 5, 2023

Hello @fxmarty ,

Thanks for your detailed response to my initial question! Those links were super helpful! Also, please do let me know if there are any issues with the way I resolved the merge conflicts. I had some additional questions:

(1) Just a quick question about how these attributes relate to the Transformer Encoder module mentioned in the paper:

in_proj_weight
in_proj_bias
out_proj_weight
out_proj_bias

For example, does in_proj_weight refer to the weights applied to the input embedding? (And a similar idea for out_proj_weight?) Perhaps this question is out of scope for the task, but any clarification is appreciated!

(2) In Step 3 of https://huggingface.co/docs/optimum/bettertransformer/tutorials/contribute, it says:

After the first forward pass, the hidden states needs to be nested using the attention mask.
Once they are nested, the attention mask is not needed anymore, therefore can be set to None.
This is how the forward pass is built for Bert, these lines should remain pretty much similar across models,
but sometimes the shapes of the attention masks are different across models.

I am confused about what is meant by nesting. Also, how do we check the shape of the attention mask?

@younesbelkada @michaelbenayoun

fxmarty (Collaborator) commented Jan 5, 2023

Hi @adit299, the merge is good, thanks! For your questions:

  1. The four
in_proj_weight
in_proj_bias
out_proj_weight
out_proj_bias

correspond to

self.self_attn.in_proj_weight
self.self_attn.in_proj_bias
self.self_attn.out_proj.weight
self.self_attn.out_proj.bias

in PyTorch's nn.TransformerEncoderLayer, self_attn itself being an nn.MultiheadAttention. They should correspond to the linear layers used to obtain Q, K and V before the multi-head attention, and to the one applied after the attention (see the small sketch further below).

  2. Yes, the doc is not super clear; we'll improve it! What is meant by nested is something like:
    hidden_states = torch._nested_tensor_from_mask(hidden_states, ~attention_mask)
    This line basically creates a nested tensor from a mask (https://pytorch.org/docs/stable/nested.html). It allows getting rid of the values at padding indexes for each item in a batch.

You can get a feel for it:

import torch

hidden_states = torch.rand(4, 6, 8)  # batch size, sequence length, hidden size

# True marks real tokens, False marks padding
attention_mask = torch.ones(4, 6, dtype=torch.bool)

attention_mask[2][4:] = False
attention_mask[3][3:] = False
attention_mask[1][2:] = False

# A no-op for this (batch_size, seq_len) mask; kept as in the original snippet
attention_mask = torch.reshape(attention_mask, (attention_mask.shape[0], attention_mask.shape[-1]))

# res is a nested tensor: each batch item keeps only its non-padded positions (lengths 6, 2, 4, 3 here)
res = torch._nested_tensor_from_mask(hidden_states, attention_mask)

I'm not sure about the shape of the attention mask.
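On point 1, here is a minimal sketch (purely illustrative, not ProphetNet's actual code) of how the fused in_proj_weight / in_proj_bias relate to separate query/key/value projections:

import torch
import torch.nn as nn

hidden_size = 16
q_proj = nn.Linear(hidden_size, hidden_size)
k_proj = nn.Linear(hidden_size, hidden_size)
v_proj = nn.Linear(hidden_size, hidden_size)

# nn.MultiheadAttention fuses the three input projections into a single
# (3 * hidden_size, hidden_size) weight and a (3 * hidden_size,) bias.
in_proj_weight = torch.cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=0)
in_proj_bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0)

print(in_proj_weight.shape)  # torch.Size([48, 16])
print(in_proj_bias.shape)    # torch.Size([48])

# out_proj.weight / out_proj.bias are simply the parameters of the linear
# layer applied to the attention output.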

Just a note: could you add the model in tests/bettertransformer/test_bettertransformer_encoder.py? You can search for a tiny one on the Hub. I'll trigger the tests afterwards!

adit299 (Contributor, Author) commented Jan 7, 2023

Hello @fxmarty,

Thanks again for the detailed response; it made things clear. I just added the model to the tests. A couple of things:

(1) I'm noticing that I get this error when I run the tests locally:

test_failure

Currently, I am debugging this. Any advice/insight you could give would be appreciated! I have also attached the full stack trace:

full_stacktrace.txt

@younesbelkada @michaelbenayoun

Comment on lines +1274 to +1275
if hidden_states.shape[0] != attention_mask.shape[0]:
    hidden_states = hidden_states.transpose(1, 0)
fxmarty (Collaborator) commented Jan 7, 2023

Maybe the error you get is due to those lines; you could try removing them, but I'm not sure.

adit299 (Contributor, Author) commented Jan 8, 2023

I just tried that, but no luck; those tests are still failing.

adit299 (Contributor, Author) commented:

By using print statements, I can see that hidden_states is of shape: torch.Size([2, 4, 16]) and attention_mask is of shape: torch.Size([8, 4]). So there appears to be a discrepancy in the batch_size value?

fxmarty (Collaborator) commented:

Hi @adit299, I'll look ASAP, probably on Friday. In the meantime, you can compare with the shapes for e.g. Bert to see how it differs.

@adit299 adit299 changed the title Added mapping for prophetnet ProphetNet BetterTransformer Support Jan 8, 2023
@adit299 adit299 changed the title ProphetNet BetterTransformer Support Adding BetterTransformer support for ProphetNet Jan 8, 2023
adit299 (Contributor, Author) commented Jan 9, 2023

Hello,

The ProphetNet documentation (https://huggingface.co/docs/transformers/model_doc/prophetnet) states that "ProphetNet is an encoder-decoder model..", which is why I added "hf-internal-testing/tiny-random-ProphetNetModel" under the ALL_ENCODER_DECODER_MODELS_TO_TEST list within test_bettertransformer_encoder.py. When I instead place "hf-internal-testing/tiny-random-ProphetNetModel" under the ALL_ENCODER_MODELS_TO_TEST list and run the tests, one test fails with this failure (I have attached the full stack trace):

encoder_model_error_trace.txt

This error seems similar to the one referenced here: #494 (comment)

So, I guess my questions are:

(1) Since the layer being replaced within ProphetNet is an encoder layer, does this mean that we just need to add it under the ALL_ENCODER_MODELS_TO_TEST list? Or would it go under the ALL_ENCODER_DECODER_MODELS_TO_TEST list? (Both lists are within test_bettertransformer_encoder.py.) What differentiates the models in these two lists?

(2) I had a look at the test_bettertransformer_encoder.py file and was having a hard time understanding how the tests differ between models in the ALL_ENCODER_MODELS_TO_TEST list and the ALL_ENCODER_DECODER_MODELS_TO_TEST list. Any clarification on how the tests differ would be appreciated.

Thanks again for the speedy responses!

@fxmarty @younesbelkada @michaelbenayoun

michaelbenayoun (Member) commented:

Hi @adit299

You are right that the naming of the test file might not be the best; basically, this file tests the BetterTransformer feature for text models.

About your questions, ALL_ENCODER_MODELS_TO_TEST is for "regular" models while ALL_ENCODER_DECODER_MODELS_TO_TEST is for seq2seq models.

The tests are different because things can vary a bit for seq2seq models, for instance the inputs. In your case, the test fails because no inputs are provided for the decoder. You should add the model to the ALL_ENCODER_DECODER_MODELS_TO_TEST list instead.
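To make the seq2seq point concrete, here is a small sketch (not code from the PR; it uses a randomly initialized, default-config model purely for illustration) showing that ProphetNetModel needs decoder inputs in addition to encoder inputs:

import torch
from transformers import ProphetNetConfig, ProphetNetModel

model = ProphetNetModel(ProphetNetConfig())  # randomly initialized weights
model.eval()

input_ids = torch.tensor([[1, 2, 3, 4]])    # encoder inputs
decoder_input_ids = torch.tensor([[0, 1]])  # decoder inputs, arbitrary token ids

# Calling model(input_ids=input_ids) alone is expected to fail for this
# encoder-decoder model, which is why the seq2seq tests also build decoder inputs.
with torch.no_grad():
    outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)

print(outputs.last_hidden_state.shape)  # (batch, decoder sequence length, hidden size)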

adit299 (Contributor, Author) commented Jan 15, 2023

Hello @michaelbenayoun,

Thank you for the clarification! I guess this confirms that the attention mask in ProphetNet is constructed differently from BERT's. I believe I have found the lines of code that are causing this difference:

https://github.com/huggingface/transformers/blob/5db9abde439bc02c3791da2a4fefee80d94d5b96/src/transformers/models/prophetnet/modeling_prophetnet.py#L1334-L1340

Prior to these lines executing, the attention_mask has shape [2, 4], which lines up with the dimensions of the hidden_states. However, after these lines execute, its shape is [8, 1, 4], so some sort of padding/reshaping is happening here. If I can figure out what the purpose of this reshaping is and how to deal with it, I think the issues in the tests will be resolved. This is what I am looking into now (mainly using git blame). Any clarification you can provide on this matter would be appreciated.
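For reference, a toy illustration (this only mimics the observed shapes, not the exact ProphetNet code) of how a (batch_size, seq_len) padding mask can end up as a (batch_size * num_heads, 1, seq_len) additive mask:

import torch

batch_size, seq_len, num_heads = 2, 4, 4

attention_mask = torch.ones(batch_size, seq_len)  # 1 = attend, 0 = padding
attention_mask[1, 2:] = 0

# Convert to an additive mask and broadcast one copy per attention head.
extended_mask = (1.0 - attention_mask[:, None, :]) * torch.finfo(torch.float32).min
extended_mask = extended_mask.repeat_interleave(num_heads, dim=0)

print(attention_mask.shape)  # torch.Size([2, 4])
print(extended_mask.shape)   # torch.Size([8, 1, 4])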

@fxmarty @younesbelkada

adit299 (Contributor, Author) commented Jan 19, 2023

Hello,

I have been having issues with the model not producing an attention mask of the correct dimension and not producing the same logits as the original model. I have attached the full stack trace. I am not sure what the issue is; I suspect that one of the parameters within ProphetNetEncoderLayerBetterTransformer is not set correctly, but I'm not certain and am a bit lost on how to proceed.

full_stacktrace-2.txt
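For reference, one way to compare the transformed model against the original (a sketch, not the project's test code; it assumes the mapping added in this branch is in place and uses the tiny checkpoint mentioned above) is:

import torch
from transformers import AutoModel
from optimum.bettertransformer import BetterTransformer

model_id = "hf-internal-testing/tiny-random-ProphetNetModel"
model = AutoModel.from_pretrained(model_id)
model_bt = BetterTransformer.transform(model, keep_original_model=True)

input_ids = torch.tensor([[1, 2, 3, 4]])
attention_mask = torch.ones_like(input_ids)
decoder_input_ids = torch.tensor([[0]])

with torch.no_grad():
    out_ref = model(input_ids=input_ids, attention_mask=attention_mask,
                    decoder_input_ids=decoder_input_ids)
    out_bt = model_bt(input_ids=input_ids, attention_mask=attention_mask,
                      decoder_input_ids=decoder_input_ids)

# The converted encoder should produce (near-)identical hidden states.
print(torch.max(torch.abs(out_ref.last_hidden_state - out_bt.last_hidden_state)))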

@fxmarty @younesbelkada @michaelbenayoun

fxmarty (Collaborator) commented Feb 14, 2023

Hi @adit299, I'll try to have a look soon!

jayant-yadav commented:

@adit299 what's the status of this task?

Successfully merging this pull request may close these issues.

Community contribution - BetterTransformer integration for more models!