Proposal for separate passing of causal flag and other types of bias #640
Comments
Thanks for opening this issue! Now I totally agree that we might want to combine multiple bias types together:
We're not totally clear on how this would work out in terms of the API, but we definitely want to support combinations of those, and ideally we would pass it through the
Also cc @fmassa
Perhaps we could use a set of biases passed via `attn_bias`? That would also serve the performance purposes mentioned above.
At least for causal + arbitrary bias, implementations that can't handle the two separately could just sum them in the Python layer, while implementations like CUTLASS would retain the information needed to skip computation.
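A minimal sketch of that Python-layer fallback, assuming biases are additive (0 keeps a position, -inf masks it). The names `causal_bias` and `merge_biases` are hypothetical and are not part of the xformers API; plain lists stand in for tensors.

```python
import math

def causal_bias(n_queries, n_keys):
    """Additive causal mask: 0 where key index <= query index, -inf elsewhere."""
    return [[0.0 if k <= q else -math.inf for k in range(n_keys)]
            for q in range(n_queries)]

def merge_biases(bias, causal):
    """Hypothetical fallback: fold the causal flag into one dense additive bias.

    A backend that can exploit the flag would keep `bias` and `causal`
    separate instead of calling this.
    """
    if not causal:
        return bias
    mask = causal_bias(len(bias), len(bias[0]))
    return [[b + m for b, m in zip(b_row, m_row)]
            for b_row, m_row in zip(bias, mask)]

# Arbitrary example bias for 3 queries and 4 keys.
bias = [[0.1 * (q - k) for k in range(4)] for q in range(3)]
merged = merge_biases(bias, causal=True)
```

Once merged, the causal structure is no longer visible to the kernel, which is why a backend that can skip work would want the flag passed separately.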
We were thinking about how we could make that obvious for the user, maybe using Python operators (eg
See also #640

Adds support for combinations of different sorts of biases:
- Causal
- Bias (coming with #587)
- Block-diagonal (used for different seqlen per batch element)

We need to rename "LowerTriangularMask" because when added with a block-diagonal mask it's no longer causal:
```
# A (block-diagonal)
0 0 0 * *
0 0 0 * *
* * * 0 0
* * * 0 0
# B (lower triangular)
0 * * * *
0 0 * * *
0 0 0 * *
0 0 0 0 *
# A + B
0 * * * *
0 0 * * *
* * * * *
* * * 0 *
# A + causal (what most ppl want)
0 * * * *
0 0 * * *
* * * 0 *
* * * 0 0
```
ghstack-source-id: 44740f71132fa76226fd4c559cc3f09732ff139b
Pull Request resolved: https://github.com/fairinternal/xformers/pull/435
__original_commit__ = fairinternal/xformers@be55fcd21c5dd621831245c5995e1c6fb49d9b77
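The mask combinations in the commit message above can be reproduced with a small sketch. It assumes the same 4x5 layout (two batch elements with 2 queries each, and 3 and 2 keys respectively) and uses booleans, where `True` means "attend" and is printed as `0`. The helper names are illustrative only and are not the xformers classes.

```python
def block_diagonal(q_lens, k_lens, causal=False):
    """Boolean keep-mask with one block per batch element.

    With causal=True, the triangle is applied inside each block,
    relative to the block's own first query and key.
    """
    total_keys = sum(k_lens)
    rows, k_start = [], 0
    for q_len, k_len in zip(q_lens, k_lens):
        for q in range(q_len):
            row = [False] * total_keys
            for k in range(k_len):
                if not causal or k <= q:
                    row[k_start + k] = True
            rows.append(row)
        k_start += k_len
    return rows

def lower_triangular(n_queries, n_keys):
    """Global lower-triangular keep-mask over the whole matrix."""
    return [[k <= q for k in range(n_keys)] for q in range(n_queries)]

def combine(a, b):
    """Summing two additive masks keeps a position only if both keep it."""
    return [[x and y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]

def render(mask):
    return [" ".join("0" if keep else "*" for keep in row) for row in mask]

A = block_diagonal([2, 2], [3, 2])
B = lower_triangular(4, 5)
a_plus_b = render(combine(A, B))
a_causal = render(block_diagonal([2, 2], [3, 2], causal=True))
print("\n".join(a_plus_b))
print("\n".join(a_causal))
```

In `A + B` the third query row ends up fully masked, because the global triangle and the second block do not overlap on that row; applying the triangle per block avoids this.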
🚀 Feature
There are performance advantages to treating causal masking as a special case, since the kernel can use it as a signal to skip computation. This is already supported by the CUTLASS implementation; it's just not exposed in the Python API.
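As a rough illustration of why the flag helps, the sketch below counts the score tiles a tiled attention kernel would have to visit with and without it. It assumes square tiles and equal query and key lengths; `tiles_to_compute` is a hypothetical helper and does not describe how the CUTLASS kernel is actually implemented.

```python
def tiles_to_compute(seq_len, tile, causal):
    """Count (query tile, key tile) pairs that contain at least one unmasked score."""
    n_tiles = -(-seq_len // tile)  # ceiling division
    count = 0
    for qi in range(n_tiles):
        last_query = min((qi + 1) * tile, seq_len) - 1
        for ki in range(n_tiles):
            first_key = ki * tile
            # Under causal masking a tile is entirely masked when even its
            # first key lies after the last query of the tile, so skip it.
            if causal and first_key > last_query:
                continue
            count += 1
    return count

full = tiles_to_compute(1024, 64, causal=False)
skipping = tiles_to_compute(1024, 64, causal=True)
print(full, skipping)
```

With these assumed sizes, a little over half of the tiles remain when the kernel knows the mask is causal. A dense additive bias carrying the same -inf values gives the kernel no cheap way to know which tiles it can skip.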
Motivation
Here are some benchmarks showing additive bias and causal masking used simultaneously; it is generally about 2x faster when causal is enabled.
Pitch
Here's a commit with a rough draft of what it may look like: 7df2df0

For now, it continues to allow causal masks to be passed via the `attn_bias` parameter to preserve backward compatibility.

One of the primary purposes of this issue is to discuss the API, though.
Alternatives
Additional context
Also see PR #587.