
Allow bi-directional attention for all models #43705

Merged

Cyrilvallez merged 9 commits into main from power-mask on Feb 4, 2026

Conversation

@Cyrilvallez (Member) commented Feb 3, 2026

What does this PR do?

Allow the is_causal kwarg and config attribute to make well-behaved decoder-only models act as encoders
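
For readers landing here, a minimal usage sketch of what the PR enables. The checkpoint is just an example, and the `is_causal` kwarg/config attribute is exactly what this PR adds, so details may differ in your installed version:

```python
# Minimal sketch, assuming this PR's `is_causal` kwarg / config attribute is available.
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
model = AutoModel.from_pretrained("Qwen/Qwen3-0.6B")

inputs = tokenizer("A sentence to embed with bidirectional attention", return_tensors="pt")

with torch.no_grad():
    causal_out = model(**inputs).last_hidden_state                  # default causal mask
    bidir_out = model(**inputs, is_causal=False).last_hidden_state  # per-call override

# Alternatively, make the override persistent through the config attribute
model.config.is_causal = False

# Earlier tokens can now attend to later ones, so the hidden states differ
print(torch.allclose(causal_out, bidir_out))  # expected: False
```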


```python
return_dict = getattr(self.config, "return_dict", True)

# Maybe temporarily overwrite config value to create the correct mask - kwarg takes precedence
is_causal = kwargs.get("is_causal", True) and getattr(self.config, "is_causal", True)
```
Contributor
If we default to True, it will break all encoder and encoder-decoder models.

Imo, we should properly add an is_causal flag to all models where it's obvious, and default to None, i.e. don't do anything.
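
A rough sketch of the tri-state resolution this proposal implies (a hypothetical helper, not code from the PR):

```python
from typing import Optional

def resolve_is_causal(kwarg_value: Optional[bool], config_value: Optional[bool], model_default: bool) -> bool:
    """Hypothetical sketch of a None-default: an explicit kwarg wins, then an explicit
    config value, and None means 'don't do anything', i.e. keep the model's native masking."""
    if kwarg_value is not None:
        return kwarg_value
    if config_value is not None:
        return config_value
    return model_default  # causal for decoders, bidirectional for encoders

# Example: a decoder keeps its causal mask unless someone explicitly opts out
assert resolve_is_causal(None, None, model_default=True) is True
assert resolve_is_causal(False, None, model_default=True) is False
```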

Member Author

We can default to None to be sure, but I was under the impression that we would only add this to a few models. Then, since we allow the kwarg, it's true that it becomes usable for all of them.

Contributor

Imo, it's just a bit brittle and we don't properly document it --> easy for users to encounter weird behavior

Member Author

Made the change! But in general encoder-decoder models will explicitly use the create_bi_directional_mask functions, so behavior will never be causal (though we could do the opposite in the mask functions as well, i.e. turn bi-directional -> causal as well).

Contributor

If we make one wrong move and pass is_causal as a kwarg, we already get that mess 👀 but yes, the mask will definitely not be affected in any case.

So the edge case is a user using BERT as a causal model (for whatever reason) and passing a custom mask along with the kwarg.

Member

> Imo, we should properly add an is_causal flag to all models where it's obvious

This might help vLLM as well; it makes it easier to recognize pure decoders vs encoders.

Contributor

Yes, I will try to open a PR about this. For example, it could allow users to change just the causality of the text model portion.
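
As a purely hypothetical sketch of what such a follow-up could look like (the sub-config attribute below is assumed, not something this PR adds):

```python
# Hypothetical: flip only the text tower of a vision-language model to bidirectional.
from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained("llava-hf/llava-1.5-7b-hf")
config.text_config.is_causal = False  # assumed attribute from the discussed follow-up PR

model = AutoModel.from_pretrained("llava-hf/llava-1.5-7b-hf", config=config)
```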

Comment on lines 926 to 934
```python
# Maybe temporarily overwrite config value to create the correct mask - kwarg takes precedence
is_causal = kwargs.get("is_causal", True) and getattr(self.config, "is_causal", True)
if not is_causal:
    is_causal_in_config = hasattr(self.config, "is_causal")
    if is_causal_in_config:
        is_causal_original_value = self.config.is_causal
    # Set it to both config and kwargs (it's needed in both, and can come from only 1 of the sources)
    self.config.is_causal = False
    kwargs["is_causal"] = False
```
Member

If I understand correctly, then the final is_causal is:

|  | kwargs["is_causal"] = True | kwargs["is_causal"] = False | kwargs["is_causal"] not defined |
|---|---|---|---|
| config.is_causal = True | True | False | True |
| config.is_causal = False | **False** | False | False |
| config.is_causal not defined | True | False | True |

The bold False here is an outlier: if the architecture is bidirectional in nature (i.e. config.is_causal=False), then the user can't override that to causal. Personally, I'm exclusively interested in the decoder -> encoder case, so this is not a problem for my use cases, but perhaps we want to allow the kwargs to always have priority?

Then we'd instead use something like this

Suggested change:

```diff
-# Maybe temporarily overwrite config value to create the correct mask - kwarg takes precedence
-is_causal = kwargs.get("is_causal", True) and getattr(self.config, "is_causal", True)
-if not is_causal:
-    is_causal_in_config = hasattr(self.config, "is_causal")
-    if is_causal_in_config:
-        is_causal_original_value = self.config.is_causal
-    # Set it to both config and kwargs (it's needed in both, and can come from only 1 of the sources)
-    self.config.is_causal = False
-    kwargs["is_causal"] = False
+# Maybe temporarily overwrite config value to create the correct mask - kwarg takes precedence
+is_causal = kwargs.get("is_causal", getattr(self.config, "is_causal", True))
+is_causal_in_config = hasattr(self.config, "is_causal")
+if is_causal_in_config:
+    is_causal_original_value = self.config.is_causal
+# Set it to both config and kwargs (it's needed in both, and can come from only 1 of the sources)
+self.config.is_causal = is_causal
+kwargs["is_causal"] = is_causal
```

And then also drop the `if not is_causal` a little later on:

```python
if is_causal_in_config:
    self.config.is_causal = is_causal_original_value
else:
    del self.config.is_causal
```
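
For readers skimming the thread, the suggested logic restated as a self-contained helper (a sketch only, not the PR's actual implementation):

```python
from contextlib import contextmanager

@contextmanager
def temporary_is_causal(config, kwargs):
    """Sketch of the suggested flow: the kwarg takes precedence over the config attribute,
    the resolved value is pushed to both sources, and the config is restored afterwards
    so a per-call override does not leak into later calls."""
    is_causal = kwargs.get("is_causal", getattr(config, "is_causal", True))
    is_causal_in_config = hasattr(config, "is_causal")
    if is_causal_in_config:
        is_causal_original_value = config.is_causal
    config.is_causal = is_causal
    kwargs["is_causal"] = is_causal
    try:
        yield is_causal
    finally:
        if is_causal_in_config:
            config.is_causal = is_causal_original_value
        else:
            del config.is_causal
```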

Member Author

This just got changed haha, sorry bad timing here 😬

Member

Edit: Looks like this is a bit dated now, and I see now that you were planning on only having this act on a few models.

Member

After another look at the new changes, it looks like you're now doing pretty much what I proposed re. also being able to go from encoder -> decoder.

@vasqu (Contributor) left a comment

Can we maybe add a fast test to llama to check that we get different logits? Otherwise lgtm, @tomaarsen does it work for you?
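
Something along these lines, presumably (a sketch of such a fast test, not the one that ended up in the PR; it assumes the `is_causal` forward kwarg added here):

```python
import torch
from transformers import LlamaConfig, LlamaModel

def test_bidirectional_attention_changes_hidden_states():
    # Tiny randomly-initialized Llama so the test stays fast
    config = LlamaConfig(
        vocab_size=64,
        hidden_size=32,
        intermediate_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=4,
    )
    model = LlamaModel(config).eval()
    input_ids = torch.randint(0, config.vocab_size, (1, 8))

    with torch.no_grad():
        causal = model(input_ids).last_hidden_state
        bidirectional = model(input_ids, is_causal=False).last_hidden_state

    # Letting earlier positions attend to later ones must change the outputs
    assert not torch.allclose(causal, bidirectional)
```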

@tomaarsen (Member) left a comment

> does it work for you?

Yes, I've run some tests with voyage-4-nano, and apart from some tiny numerical differences caused by having an all-True attention_mask (Voyage's custom code) vs a None attention_mask (this PR), the performance is identical. It should allow this model, as well as many others that are simply e.g. "Qwen3 but bidirectional", to work by setting is_causal in the config.

  • Tom Aarsen

@Cyrilvallez merged commit 83bce8d into main on Feb 4, 2026
23 of 26 checks passed
@Cyrilvallez deleted the power-mask branch on February 4, 2026 at 17:24
