1 change: 1 addition & 0 deletions .gitignore
@@ -6,6 +6,7 @@ __pycache__
 **/.vscode test/**
 **/.vscode-smoke/**
 **/.venv*/
+venv
 bin/**
 build/**
 obj/**
5 changes: 4 additions & 1 deletion keras/src/layers/attention/multi_head_attention.py
@@ -378,7 +378,10 @@ def _build_attention(self, rank):
         if self._attention_axes is None:
             self._attention_axes = tuple(range(1, rank - 2))
         else:
-            self._attention_axes = tuple(self._attention_axes)
+            self._attention_axes = tuple(
+                axis if axis >= 0 else (rank - 1) + axis
+                for axis in self._attention_axes
+            )
         (
             self._dot_product_equation,
             self._combine_equation,
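Note on the change above: negative entries in attention_axes are now resolved to their positive equivalents before the einsum equations are built. A minimal standalone sketch of the same normalization (the helper name is mine, and I am assuming rank here is the rank of the per-head projected query, i.e. input rank + 1, so 5 for the (2, 3, 8, 4) input used in the test below):

    def _normalize_attention_axes(attention_axes, rank):
        # Same expression as the patched branch: a negative axis is offset
        # by rank - 1, so axis -2 with rank=5 resolves to axis 2.
        return tuple(
            axis if axis >= 0 else (rank - 1) + axis for axis in attention_axes
        )

    assert _normalize_attention_axes((-2,), rank=5) == (2,)
    assert _normalize_attention_axes((2,), rank=5) == (2,)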
32 changes: 32 additions & 0 deletions keras/src/layers/attention/multi_head_attention_test.py
@@ -203,6 +203,38 @@ def test_high_dim_attention(
             run_training_check=False,
         )
 
+    def test_attention_axes_negative_indexing(self):
+        """Test that negative attention_axes indexing matches
+        positive indexing."""
+        x = np.random.normal(size=(2, 3, 8, 4))
+
+        # Create two layers with equivalent positive and negative indices
+        mha_pos = layers.MultiHeadAttention(
+            num_heads=2, key_dim=4, attention_axes=2
+        )
+        mha_neg = layers.MultiHeadAttention(
+            num_heads=2, key_dim=4, attention_axes=-2
+        )
+
+        # Initialize both layers
+        _ = mha_pos(x, x)
+        _ = mha_neg(x, x)
+
+        # Set same weights for fair comparison
+        mha_neg.set_weights(mha_pos.get_weights())
+
+        # Get outputs and attention scores
+        z_pos, a_pos = mha_pos(x, x, return_attention_scores=True)
+        z_neg, a_neg = mha_neg(x, x, return_attention_scores=True)
+
+        # Verify shapes match
+        self.assertEqual(z_pos.shape, z_neg.shape)
+        self.assertEqual(a_pos.shape, a_neg.shape)
+
+        # Verify outputs are identical
+        self.assertAllClose(z_pos, z_neg, rtol=1e-5, atol=1e-5)
+        self.assertAllClose(a_pos, a_neg, rtol=1e-5, atol=1e-5)
+
     @parameterized.named_parameters(
         ("without_key_same_proj", (4, 8), (2, 8), None, None),
         ("with_key_same_proj", (4, 8), (2, 8), (2, 3), None),
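With the patch applied, a negative attention_axes value behaves like a negative Python index on the query's dimensions, which is what the new test exercises. A small usage sketch (shapes and parameters copied from the test; the public keras import path is assumed):

    import numpy as np
    from keras import layers

    x = np.random.normal(size=(2, 3, 8, 4)).astype("float32")
    # attention_axes=-2 now resolves to axis 2, so this layer attends over the
    # same dimension as one built with attention_axes=2.
    mha = layers.MultiHeadAttention(num_heads=2, key_dim=4, attention_axes=-2)
    out, scores = mha(x, x, return_attention_scores=True)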