[XLA:GPU] Enable deterministic scatter for batched operations #34870
Closed
…p with batch dims
Force-pushed from 5b1666d to b110942
akuegel (Member) approved these changes on Dec 5, 2025, commenting:
Thank you :)
copybara-service bot pushed a commit that referenced this pull request on Dec 5, 2025:
Copybara import of PR #34870 — b110942 by Chenhao Jiang <chenhaoj@nvidia.com>: [XLA:GPU] Fix the issue of scatter determinism expander for scatter op with batch dims. Merging this change closes #34870. PiperOrigin-RevId: 840559332
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request on Dec 5, 2025:
Copybara import of PR openxla/xla#34870 — b110942e013047c90b19a49bef0aa487061753f6 by Chenhao Jiang <chenhaoj@nvidia.com>. Merging this change closes #34870. PiperOrigin-RevId: 840559332
copybara-service bot pushed a commit to tensorflow/tensorflow that referenced this pull request on Dec 5, 2025:
Copybara import of PR openxla/xla#34870 — b110942e013047c90b19a49bef0aa487061753f6 by Chenhao Jiang <chenhaoj@nvidia.com>. Merging this change closes #34870. PiperOrigin-RevId: 840630461
hariprasadravi pushed a commit to hariprasadravi/tensorflow that referenced this pull request on Dec 5, 2025:
…ed operations — Copybara import of PR openxla/xla#34870 — b110942e013047c90b19a49bef0aa487061753f6 by Chenhao Jiang <chenhaoj@nvidia.com>. Merging this change closes tensorflow#34870. PiperOrigin-RevId: 840630461
📝 Summary of Changes
Fix ScatterDeterminismExpander to correctly handle scatter operations that have been normalized from batched form by BatchedGatherScatterNormalizer.
The key fix is in FlattenIndices: when computing scalar indices for sorting, we now use scatter_dims_to_operand_dims to select the correct stride for each index column, rather than assuming index columns map to operand dimensions in order.
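The stride-selection fix can be illustrated with a small Python sketch (a simplified model of the C++ pass, not the actual XLA code; the function name and shapes here are illustrative):

```python
import numpy as np

def flatten_indices(indices, operand_shape, scatter_dims_to_operand_dims):
    """Collapse multi-dimensional scatter indices into scalar sort keys.

    indices: [num_updates, index_vector_dim] array, where column i
    addresses operand dimension scatter_dims_to_operand_dims[i].
    """
    # Row-major strides of the operand: stride[d] is the product of the
    # sizes of all dimensions after d.
    strides = np.ones(len(operand_shape), dtype=np.int64)
    for d in range(len(operand_shape) - 2, -1, -1):
        strides[d] = strides[d + 1] * operand_shape[d + 1]
    # The fix: column i uses strides[scatter_dims_to_operand_dims[i]],
    # not strides[i], so non-contiguous dimension maps stay correct.
    column_strides = strides[list(scatter_dims_to_operand_dims)]
    return indices @ column_strides
```

For an operand of shape (2, 3, 4) with scatter_dims_to_operand_dims = (0, 2), the strides are (12, 4, 1) and the two index columns pick strides 12 and 1, so each index row hashes to a unique flat position in the operand.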
🎯 Justification
Previously, ScatterDeterminismExpander would produce incorrect results for batched scatter operations because:
- BatchedGatherScatterNormalizer runs first and transforms batched scatter into a normalized form where batch indices are concatenated into the index tensor.
- After normalization, scatter_dims_to_operand_dims could be {0, 2} (not {0, 1, 2}).
- FlattenIndices assumed a direct column→dimension mapping, causing indices from different batches to collide and mix updates incorrectly.
This enables deterministic scatter for models using batched scatter operations (e.g., batched attention, batched embedding lookups).
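A concrete numeric example of the collision (a hypothetical sketch with made-up shapes, mirroring the {0, 2} case above):

```python
import numpy as np

operand_shape = (2, 3, 4)              # row-major strides: 12, 4, 1
strides = (12, 4, 1)
scatter_dims_to_operand_dims = (0, 2)  # after normalization: dim 1 is skipped

a = np.array([1, 0])  # batch 1, position 0 along operand dim 2
b = np.array([0, 3])  # batch 0, position 3 along operand dim 2

# Buggy mapping: column i uses strides[i], i.e. (12, 4).
naive = lambda row: row @ np.array(strides[:2])
# Fixed mapping: column i uses strides[scatter_dims_to_operand_dims[i]], i.e. (12, 1).
fixed = lambda row: row @ np.array([strides[d] for d in scatter_dims_to_operand_dims])

assert naive(a) == naive(b) == 12        # collision: two distinct elements, one key
assert fixed(a) == 12 and fixed(b) == 3  # distinct keys, no mixing of updates
```

With the buggy mapping, the sort-and-segment step of the expander would treat the two updates as targeting the same element and combine them.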
🚀 Kind of Contribution
🐛 Bug Fix
🧪 Unit Tests
Existing test ScatterTest.Scatter_Add_F32 in xla/tests/scatter_test.cc now passes: it exercises batched scatter with input_batching_dims={0} and validates correctness against the interpreter reference.
🧪 Execution Tests
ScatterTest.Scatter_Add_F32 runs end-to-end on GPU, triggering the BatchedGatherScatterNormalizer → ScatterDeterminismExpander pipeline and asserting correct output.
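The semantics that the test validates can be modeled with a minimal batched scatter-add reference (a simplified Python/NumPy sketch assuming input_batching_dims={0} and rank-2 shapes; not the actual XLA test):

```python
import numpy as np

def batched_scatter_add(operand, indices, updates):
    """Reference batched scatter-add: indices[b, j] addresses
    operand[b, indices[b, j]], and duplicate indices accumulate."""
    out = operand.copy()
    for b in range(operand.shape[0]):          # iterate over the batch dim
        for j in range(indices.shape[1]):      # apply each update in order
            out[b, indices[b, j]] += updates[b, j]
    return out

operand = np.zeros((2, 4), dtype=np.float32)
indices = np.array([[1, 1, 3], [0, 2, 2]])
updates = np.ones((2, 3), dtype=np.float32)
result = batched_scatter_add(operand, indices, updates)
# Duplicates within a batch accumulate: result[0, 1] == 2.0, result[1, 2] == 2.0
```

A deterministic expansion must reproduce this reference exactly, including the accumulation of duplicate indices, regardless of the order in which the GPU processes updates.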