Add data_transform argument to nn.scan and preserve PartitionSpec attributes.#1411
Merged
Conversation
fe6a033 to
1f2e26c
Compare
Codecov Report
@@ Coverage Diff @@
## master #1411 +/- ##
==========================================
- Coverage 82.28% 82.27% -0.01%
==========================================
Files 65 65
Lines 5328 5332 +4
==========================================
+ Hits 4384 4387 +3
- Misses 944 945 +1
Continue to review full report at Codecov.
|
…ributes. pjit PartitionSpec Module attributes were being downcast to tuples during freezing. An extra data_transform kwarg to nn.scan is used to help fix an issue where XLA SPMD constraints don't propagate across XLA while loops. It allows us to use a workaround to re-apply SPMD constraints inside the scan body function. (Ultimately we hope to find a better upstream fix in JAX/XLA.) PiperOrigin-RevId: 383231061
1f2e26c to
5ffd000
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Add data_transform argument to nn.scan and preserve PartitionSpec attributes.
pjit PartitionSpec Module attributes were being downcast to tuples during freezing.
An extra data_transform kwarg to nn.scan is used to help fix an issue where XLA SPMD
constraints don't propagate across XLA while loops. It allows us to use a
workaround to re-apply SPMD constraints inside the scan body function.
(Ultimately we hope to find a better upstream fix in JAX/XLA.)