Allow bi-directional attention for all models #43705
Conversation
src/transformers/utils/generic.py
Outdated
```python
return_dict = getattr(self.config, "return_dict", True)

# Maybe temporarily overwrite config value to create the correct mask - kwarg takes precedence
is_causal = kwargs.get("is_causal", True) and getattr(self.config, "is_causal", True)
```
If we default to `True`, it will break all encoder and encoder-decoder models.
Imo, we should properly add an `is_causal` flag to all models where it's obvious and default to `None`, i.e. don't do anything.
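A minimal sketch of what "default to `None`" could look like (illustrative only, not the code in this PR; `resolve_is_causal` is a hypothetical helper name):

```python
# Illustrative sketch, not the PR's actual code: resolve `is_causal` with a
# None default so models that never set it keep their existing mask behavior.
def resolve_is_causal(kwargs: dict, config):
    is_causal = kwargs.get("is_causal", getattr(config, "is_causal", None))
    if is_causal is not None:
        # Only propagate when the kwarg or the config explicitly set a value;
        # encoder and encoder-decoder models that never define it stay untouched.
        kwargs["is_causal"] = is_causal
    return is_causal
```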
We can default to `None` to be sure, but I was under the impression that we would only add this to a few models. Then, as we allow the kwarg, it's true that it becomes usable for all.
Imo, it's just a bit brittle and we don't properly document it --> easy for users to encounter weird behavior
Made the change! But in general encoder-decoder models will explicitly use the `create_bi_directional_mask` functions, so behavior will never be causal (though we could do the opposite in the mask functions as well, i.e. turn bi-directional -> causal).
If we make one wrong move and pass `is_causal` as a kwarg, we already get that mess 👀 but yes, the mask will definitely not be affected in any case.
So the edge case is a user using BERT as a causal model (for whatever reason) and passing a custom mask along with the kwarg.
> Imo, we should properly add an `is_causal` flag to all models where it's obvious
This might help vLLM as well, as it makes it easier to recognize pure decoders vs. encoders.
Yes, I will try to open a PR about this. For example, it could allow users to change the causality of just the text model portion.
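As a rough illustration of that idea (hypothetical, since such a follow-up PR doesn't exist yet, and it assumes the text sub-model would honor `is_causal`; the checkpoint is only an example):

```python
# Hypothetical sketch: flip only the text portion of a multimodal model to
# bidirectional attention, assuming its config attribute is honored.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("llava-hf/llava-1.5-7b-hf")
config.text_config.is_causal = False  # vision tower left unchanged
```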
src/transformers/utils/generic.py
Outdated
```python
# Maybe temporarily overwrite config value to create the correct mask - kwarg takes precedence
is_causal = kwargs.get("is_causal", True) and getattr(self.config, "is_causal", True)
if not is_causal:
    is_causal_in_config = hasattr(self.config, "is_causal")
    if is_causal_in_config:
        is_causal_original_value = self.config.is_causal
    # Set it to both config and kwargs (it's needed in both, and can come from only 1 of the sources)
    self.config.is_causal = False
    kwargs["is_causal"] = False
```
If I understand correctly, then the final `is_causal` is:

| | `kwargs["is_causal"] = True` | `kwargs["is_causal"] = False` | `kwargs["is_causal"]` not defined |
|---|---|---|---|
| `config.is_causal = True` | True | False | True |
| `config.is_causal = False` | **False** | False | False |
| `config.is_causal` not defined | True | False | True |
The bold False here is an outlier: if the architecture is bidirectional in nature (i.e. config.is_causal=False), then the user can't override that to causal. Personally, I'm exclusively interested in the decoder -> encoder case, so this is not a problem for my use cases, but perhaps we want to allow the kwargs to always have priority?
Then we'd instead use something like this:

```diff
- # Maybe temporarily overwrite config value to create the correct mask - kwarg takes precedence
- is_causal = kwargs.get("is_causal", True) and getattr(self.config, "is_causal", True)
- if not is_causal:
-     is_causal_in_config = hasattr(self.config, "is_causal")
-     if is_causal_in_config:
-         is_causal_original_value = self.config.is_causal
-     # Set it to both config and kwargs (it's needed in both, and can come from only 1 of the sources)
-     self.config.is_causal = False
-     kwargs["is_causal"] = False
+ # Maybe temporarily overwrite config value to create the correct mask - kwarg takes precedence
+ is_causal = kwargs.get("is_causal", getattr(self.config, "is_causal", True))
+ is_causal_in_config = hasattr(self.config, "is_causal")
+ if is_causal_in_config:
+     is_causal_original_value = self.config.is_causal
+ # Set it to both config and kwargs (it's needed in both, and can come from only 1 of the sources)
+ self.config.is_causal = is_causal
+ kwargs["is_causal"] = is_causal
```
And then also drop the `if not is_causal` a little later on:

```python
if is_causal_in_config:
    self.config.is_causal = is_causal_original_value
else:
    del self.config.is_causal
```
This just got changed haha, sorry bad timing here 😬
Edit: Looks like this is a bit dated now, and I see now that you were planning on only having this act on a few models.
After another look at the new changes, it looks like you're now doing pretty much what I proposed re. also being able to go from encoder -> decoder.
vasqu left a comment:
Can we maybe add a fast test to llama to check that we get different logits? Otherwise lgtm, @tomaarsen does it work for you?
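A rough sketch of the kind of fast test being asked for (illustrative only; the config values are arbitrary and the exact `is_causal` plumbing may differ in the final PR):

```python
# Illustrative sketch of such a test, not the test added to the PR.
import torch
from transformers import LlamaConfig, LlamaModel

def test_bidirectional_mask_changes_hidden_states():
    config = LlamaConfig(
        vocab_size=100,
        hidden_size=32,
        intermediate_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=4,
    )
    model = LlamaModel(config).eval()
    input_ids = torch.randint(0, config.vocab_size, (1, 8))

    with torch.no_grad():
        causal = model(input_ids).last_hidden_state
        bidirectional = model(input_ids, is_causal=False).last_hidden_state

    # With a bidirectional mask, earlier tokens can attend to later ones,
    # so the hidden states should differ from the causal run.
    assert not torch.allclose(causal, bidirectional)
```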
tomaarsen left a comment:
> does it work for you?
Yes, I've run some tests with voyage-4-nano, and apart from some tiny numerical differences caused by having an all-True attention_mask (Voyage's custom code) vs. a None attention_mask (this PR), the performance is identical. It should allow this model, as well as many others that are simply e.g. "Qwen3 but bidirectional", to work by setting `is_causal` in the config.
- Tom Aarsen
What does this PR do?
Allow the `is_causal` kwarg and config attribute to make well-behaved decoder-only models act as encoders.
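A minimal usage sketch of what the description implies (the checkpoint is just an example of a decoder-only model, and the exact flag semantics may differ from what ships):

```python
# Sketch only: relies on the `is_causal` config attribute added by this PR;
# the checkpoint is an arbitrary decoder-only example.
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
model = AutoModel.from_pretrained("Qwen/Qwen3-0.6B").eval()
model.config.is_causal = False  # treat the decoder-only model as an encoder

inputs = tokenizer("Bidirectional attention lets every token see the full sentence.", return_tensors="pt")
with torch.no_grad():
    embeddings = model(**inputs).last_hidden_state.mean(dim=1)  # simple mean pooling
print(embeddings.shape)
```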