Conditionally passing and_mask_function arg to create_causal_mask #44641

Draft
kmbhattt-aws wants to merge 1 commit into huggingface:main from
kmbhattt-aws:fix/correctly-create-causal-mask-falcon

Conversation

@kmbhattt-aws

What does this PR do?

Issue: A full 4D attention mask of shape [1, 1, seq_len, seq_len] is created during attention even when ALiBi positional embeddings are not in use.

  • This occupies extra memory during training.

Root Cause:
The create_causal_mask call in the Falcon model was passing an and_mask_function argument:

causal_mask = create_causal_mask(
    ...
    and_mask_function=lambda *args: torch.tensor(True, dtype=torch.bool),  # Forces mask creation
)

This was added with the comment "Force mask creation for alibi", since ALiBi positional encoding requires a materialized mask so that linear biases can be applied to the attention scores.

However, in the Falcon model the and_mask_function argument was passed unconditionally, so even when ALiBi is not used it sets allow_is_causal_skip = False and materializes a full seq_len x seq_len mask.
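The cost of that materialized mask grows quadratically with sequence length. A rough back-of-the-envelope estimate (assuming a float32 mask at 4 bytes per element; the exact dtype depends on the attention implementation):

```python
def mask_bytes(seq_len: int, bytes_per_element: int = 4) -> int:
    """Size in bytes of a materialized [1, 1, seq_len, seq_len] mask."""
    return 1 * 1 * seq_len * seq_len * bytes_per_element

for n in (2048, 4096, 8192):
    print(f"seq_len={n}: {mask_bytes(n) / 2**20:.0f} MiB")
# seq_len=2048: 16 MiB
# seq_len=4096: 64 MiB
# seq_len=8192: 256 MiB
```

Skipping the mask entirely (the is_causal fast path) avoids this allocation and lets fused attention kernels handle causality internally.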

Fix:
Only pass and_mask_function to create_causal_mask when ALiBi is used.
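A minimal sketch of the gating logic, with the torch.bool tensor replaced by a plain True to keep the sketch dependency-free (use_alibi and the kwarg name follow the PR description; the real create_causal_mask signature is not reproduced here):

```python
def build_mask_kwargs(use_alibi: bool) -> dict:
    """Extra kwargs to forward to create_causal_mask.

    When ALiBi is active, an always-True and_mask_function forces the mask
    to be materialized so linear biases can be added to attention scores.
    Otherwise the kwarg is omitted so the is_causal fast path stays available.
    """
    if use_alibi:
        return {"and_mask_function": lambda *args: True}
    return {}
```

Passing the kwargs dict via `create_causal_mask(..., **build_mask_kwargs(self.use_alibi))` keeps the call site unconditional while the behavior stays conditional.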

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: falcon

  past_key_values=past_key_values,
  # Force mask creation for alibi
- and_mask_function=lambda *args: torch.tensor(True, dtype=torch.bool),
+ and_mask_function=(lambda *args: torch.tensor(True, dtype=torch.bool)) if self.use_alibi else None,
Contributor


Yea, valid point. However, we should probably check whether the alibi tensor is None; not sure how reliable that flag is, tbh 👀
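The reviewer's alternative, gating on the alibi tensor itself rather than a config flag, could look like this sketch (the `alibi` argument stands for whatever tensor forward() computed; again, plain True stands in for the torch.bool tensor):

```python
def mask_kwargs_for(alibi) -> dict:
    """Gate create_causal_mask kwargs on the computed alibi tensor.

    If forward() produced an alibi tensor, force mask materialization;
    if alibi is None, return no extra kwargs so the fast path applies.
    """
    if alibi is not None:
        return {"and_mask_function": lambda *args: True}
    return {}
```

Checking the tensor directly sidesteps any drift between the config flag and what the model actually computed in this forward pass.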

