Fix negative index handling in MultiHeadAttention attention_axes #21721
Conversation
Code Review

This pull request correctly fixes an issue with negative index handling in `MultiHeadAttention`'s `attention_axes`. The core logic change in `_build_attention` properly normalizes negative axes with respect to the input tensor's rank, and a new integration test verifies this behavior.

My review includes a suggestion to simplify the normalization logic using a list comprehension for better readability. More importantly, I've identified a redundant and incorrect piece of code in `_build_attention_equation` that appears to be a leftover from a different fix attempt. It incorrectly uses the projected rank for normalization, which is the original bug, and should be removed to avoid confusion and future issues.

The PR also includes unrelated (though correct) changes to the image saving utilities. It would be clearer to have these in a separate pull request.
```python
# Normalize negative indices to positive indices
if isinstance(attn_axes, (list, tuple)):
    attn_axes = tuple(ax % rank if ax < 0 else ax for ax in attn_axes)
else:
    attn_axes = (attn_axes % rank if attn_axes < 0 else attn_axes,)
```
This block for normalizing negative indices appears to be both redundant and incorrect.

- Redundant: the `_build_attention` method already normalizes `self._attention_axes` to positive indices before passing them to this function, so this block has no effect on the already-positive axes.
- Incorrect: if this block were to handle negative indices, its logic `ax % rank` normalizes based on the projected tensor's `rank`, which is the exact bug this PR aims to fix. The correct normalization should be relative to the input rank, as correctly implemented in `_build_attention`.

To avoid confusion and potential future bugs, this block should be removed.
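A quick standalone arithmetic check shows why the two normalizations disagree. The rank values below are taken from the example in the PR description (input rank 4, projected rank 5); this is illustrative code, not part of the patch:

```python
# Input tensor rank 4, e.g. shape (10, 5, 128, 16); after the query/key
# projection a num_heads dimension is inserted, so the projected rank is 5.
input_rank = 4
projected_rank = 5

axis = -2

# Normalizing against the projected rank (the buggy behavior):
wrong = axis % projected_rank   # -2 % 5 == 3

# Normalizing against the input rank (the intended behavior):
right = axis % input_rank       # -2 % 4 == 2

print(wrong, right)  # 3 2
```

Since the two bases give different positive indices for the same `axis`, the einsum equation ends up contracting the wrong dimension whenever the projected rank is used.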
Agree, please revert.
```python
input_rank = rank - 1
normalized_axes = []
for ax in self._attention_axes:
    if ax < 0:
        # Normalize relative to input rank
        normalized_ax = input_rank + ax
    else:
        normalized_ax = ax
    normalized_axes.append(normalized_ax)
self._attention_axes = tuple(normalized_axes)
```
Codecov Report

❌ Patch coverage is

```
@@            Coverage Diff             @@
##           master   #21721      +/-   ##
==========================================
- Coverage   82.59%   82.59%   -0.01%
==========================================
  Files         572      572
  Lines       58401    58413      +12
  Branches     9146     9150       +4
==========================================
+ Hits        48238    48244       +6
- Misses       7828     7832       +4
- Partials     2335     2337       +2
```

Flags with carried forward coverage won't be shown.
Force-pushed from 778c3fb to e704b28
Force-pushed from e704b28 to deebbc6
```python
# Normalize negative indices to positive indices
if isinstance(attn_axes, (list, tuple)):
    attn_axes = tuple(ax % rank if ax < 0 else ax for ax in attn_axes)
else:
    attn_axes = (attn_axes % rank if attn_axes < 0 else attn_axes,)
```
Agree, please revert.
```python
    """
    data_format = backend.standardize_data_format(data_format)
    # Normalize jpg → jpeg
    if file_format is not None and file_format.lower() == "jpg":
```
Unrelated changes, please revert.
```
@@ -0,0 +1,27 @@
import os
```
Unrelated changes, please revert.
```python
def test_attention_axes_negative_indexing_matches_positive():
    x = np.random.normal(size=(2, 3, 8, 4))
```
Move to `multi_head_attention_test.py` and use the unit test style, i.e. `self.assertEqual`, `self.assertAllClose`, ...
```python
else:
    self._attention_axes = tuple(self._attention_axes)
# Normalize negative indices relative to INPUT rank (rank - 1)
input_rank = rank - 1
```
Why `rank - 1`?

I think this would be enough instead of lines 381-391:

```python
self._attention_axes = tuple(
    axis + rank if axis < 0 else axis for axis in self._attention_axes
)
```
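Setting aside whether `rank` or `rank - 1` is the right base, the comprehension is behaviorally identical to the PR's loop for any fixed base. A minimal standalone sketch (the helper names are hypothetical, not code from this PR):

```python
def normalize_loop(attention_axes, base):
    # Loop-based normalization, structured like the PR's version.
    normalized = []
    for ax in attention_axes:
        normalized.append(base + ax if ax < 0 else ax)
    return tuple(normalized)

def normalize_comprehension(attention_axes, base):
    # The suggested one-liner, wrapped in a function for comparison.
    return tuple(axis + base if axis < 0 else axis for axis in attention_axes)

# The two forms agree for every base and axis combination tried here.
for base in (3, 4, 5):
    for axes in [(-2,), (-1, -2), (1, -2), (2,)]:
        assert normalize_loop(axes, base) == normalize_comprehension(axes, base)
print("equivalent")
```

So the review comment is purely about readability; the remaining question is only which rank value to pass in.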
Description

Fixes #21714 - Incorrect results when using `MultiHeadAttention` `attention_axes` with negative indexing.

Problem

When `attention_axes` was specified with negative indices (e.g., `-2`), the normalization was happening relative to the projected tensor rank (which includes the `num_heads` dimension) rather than the input tensor rank. This caused incorrect axis selection and wrong einsum equations.

For example, with input shape `(10, 5, 128, 16)` (rank 4):

- `attention_axes=2` correctly produced equation `abfde,abcde->abdcf`
- `attention_axes=-2` incorrectly produced equation `abcfe,abcde->abcdf`

Solution

Modified the `_build_attention` method in `MultiHeadAttention` to normalize negative indices relative to the input rank (`rank - 1`) before the `num_heads` dimension is added during projection. This ensures:

- `attention_axes=-2` normalizes to `input_rank + (-2) = 4 + (-2) = 2`
- `attention_axes=2` and `attention_axes=-2` now produce identical results

Changes

- `keras/src/layers/attention/multi_head_attention.py`: normalize negative indices in the `_build_attention` method
- `integration_tests/test_multi_head_attention_negative_axis.py`: new integration test

Testing
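The difference between the two einsum equations can be illustrated with plain NumPy. The dimension sizes below are arbitrary (chosen distinct so that an axis mix-up is visible in the resulting shapes); only the two equation strings come from this PR:

```python
import numpy as np

# Arbitrary distinct sizes for the einsum letters a..f.
a, b, c, d, e, f = 2, 3, 4, 5, 6, 7

key = np.random.normal(size=(a, b, c, d, e))

# Correct equation (produced for attention_axes=2): contracts over 'e',
# keeping 'c' as a separate attended dimension.
q_correct = np.random.normal(size=(a, b, f, d, e))
scores_correct = np.einsum("abfde,abcde->abdcf", q_correct, key)

# Incorrect equation (produced for attention_axes=-2 before the fix):
# the axes are shifted, so a different dimension is treated as attended.
q_wrong = np.random.normal(size=(a, b, c, f, e))
scores_wrong = np.einsum("abcfe,abcde->abcdf", q_wrong, key)

print(scores_correct.shape)  # (2, 3, 5, 4, 7)
print(scores_wrong.shape)    # (2, 3, 4, 5, 7)
```

The score tensors come out with different shapes, confirming that the two equations attend over different axes even though both were meant to describe the same `attention_axes` choice.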