
Conversation

@serach24 serach24 commented Dec 5, 2025

📝 Summary of Changes
Fix ScatterDeterminismExpander to correctly handle scatter operations that have been normalized from batched form by BatchedGatherScatterNormalizer.
The key fix is in FlattenIndices: when computing scalar indices for sorting, we now use scatter_dims_to_operand_dims to select the correct stride for each index column, rather than assuming index columns map to operand dimensions in order.
🎯 Justification
Previously, ScatterDeterminismExpander would produce incorrect results for batched scatter operations because:
1. BatchedGatherScatterNormalizer runs first and transforms a batched scatter into a normalized form in which the batch indices are concatenated into the index tensor.
2. After normalization, scatter_dims_to_operand_dims can be a non-contiguous mapping such as {0, 2} rather than {0, 1, 2}.
3. FlattenIndices assumed a direct column→dimension mapping, so indices from different batches could collide and their updates would be mixed incorrectly (see the sketch below).
This enables deterministic scatter for models using batched scatter operations (e.g., batched attention, batched embedding lookups).
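
For illustration, here is a minimal, self-contained C++ sketch of the stride-selection idea, assuming row-major operand strides. FlattenIndexRow and its vector-based signature are hypothetical simplifications, not the actual pass code:

```cpp
#include <cstdint>
#include <vector>

// Hypothetical, simplified sketch (not the XLA pass code): flatten one row of
// scatter indices into a single scalar key for sorting. The stride for index
// column j must come from the operand dimension named by
// scatter_dims_to_operand_dims[j], not from operand dimension j.
int64_t FlattenIndexRow(const std::vector<int64_t>& index_row,
                        const std::vector<int64_t>& operand_shape,
                        const std::vector<int64_t>& scatter_dims_to_operand_dims) {
  // Row-major strides of the operand, e.g. shape {2, 3, 4} -> strides {12, 4, 1}.
  std::vector<int64_t> strides(operand_shape.size(), 1);
  for (int64_t d = static_cast<int64_t>(operand_shape.size()) - 2; d >= 0; --d) {
    strides[d] = strides[d + 1] * operand_shape[d + 1];
  }
  int64_t flat = 0;
  for (size_t j = 0; j < index_row.size(); ++j) {
    // Correct: column j addresses operand dim scatter_dims_to_operand_dims[j].
    // The previous behavior effectively used strides[j], which collides when
    // the normalized mapping is non-contiguous, e.g. {0, 2}.
    flat += index_row[j] * strides[scatter_dims_to_operand_dims[j]];
  }
  return flat;
}
```

With operand shape {2, 3, 4} and scatter_dims_to_operand_dims = {0, 2}, the correct strides are {12, 1}; the old in-order assumption would use {12, 4}, so distinct index rows such as (1, 0) and (0, 3) would both flatten to 12 and collide.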

🚀 Kind of Contribution
🐛 Bug Fix

🧪 Unit Tests
Existing test ScatterTest.Scatter_Add_F32 in xla/tests/scatter_test.cc now passes; it exercises batched scatter with input_batching_dims={0} and validates correctness against the interpreter reference.
🧪 Execution Tests
ScatterTest.Scatter_Add_F32 runs end-to-end on GPU, triggering the BatchedGatherScatterNormalizer → ScatterDeterminismExpander pipeline and asserting correct output.

@serach24 serach24 force-pushed the cj/fix_scatter_issue branch from 5b1666d to b110942 on December 5, 2025 04:27
@serach24 serach24 requested a review from akuegel December 5, 2025 04:27
@akuegel akuegel (Member) left a comment

Thank you :)

copybara-service bot pushed a commit that referenced this pull request Dec 5, 2025
Imported from GitHub PR #34870

Copybara import of the project:

--
b110942 by Chenhao Jiang <chenhaoj@nvidia.com>:

[XLA:GPU] Fix the issue of scatter determinism expander for scatter op with batch dims

Merging this change closes #34870

FUTURE_COPYBARA_INTEGRATE_REVIEW=#34870 from serach24:cj/fix_scatter_issue b110942
PiperOrigin-RevId: 840559332
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Dec 5, 2025
Imported from GitHub PR openxla/xla#34870

Copybara import of the project:

--
b110942e013047c90b19a49bef0aa487061753f6 by Chenhao Jiang <chenhaoj@nvidia.com>:

[XLA:GPU] Fix the issue of scatter determinism expander for scatter op with batch dims

Merging this change closes #34870

FUTURE_COPYBARA_INTEGRATE_REVIEW=openxla/xla#34870 from serach24:cj/fix_scatter_issue b110942e013047c90b19a49bef0aa487061753f6
PiperOrigin-RevId: 840559332
@copybara-service copybara-service bot closed this in d1c2d75 Dec 5, 2025
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request Dec 5, 2025
Imported from GitHub PR openxla/xla#34870

Copybara import of the project:

--
b110942e013047c90b19a49bef0aa487061753f6 by Chenhao Jiang <chenhaoj@nvidia.com>:

[XLA:GPU] Fix the issue of scatter determinism expander for scatter op with batch dims

Merging this change closes #34870

PiperOrigin-RevId: 840630461
hariprasadravi pushed a commit to hariprasadravi/tensorflow that referenced this pull request Dec 5, 2025
…ed operations

Imported from GitHub PR openxla/xla#34870

Copybara import of the project:

--
b110942e013047c90b19a49bef0aa487061753f6 by Chenhao Jiang <chenhaoj@nvidia.com>:

[XLA:GPU] Fix the issue of scatter determinism expander for scatter op with batch dims

Merging this change closes tensorflow#34870

PiperOrigin-RevId: 840630461