Skip to content

Add data_transform argument to nn.scan and preserve PartitionSpec attributes.#1411

Merged
copybara-service[bot] merged 1 commit into
masterfrom
test_383161650
Jul 6, 2021
Merged

Add data_transform argument to nn.scan and preserve PartitionSpec attributes.#1411
copybara-service[bot] merged 1 commit into
masterfrom
test_383161650

Conversation

@copybara-service
Copy link
Copy Markdown

Add data_transform argument to nn.scan and preserve PartitionSpec attributes.

pjit PartitionSpec Module attributes were being downcast to tuples during freezing.

An extra data_transform kwarg to nn.scan is used to help fix an issue where XLA SPMD
constraints don't propagate across XLA while loops. It allows us to use a
workaround to re-apply SPMD constraints inside the scan body function.
(Ultimately we hope to find a better upstream fix in JAX/XLA.)

@google-cla google-cla Bot added the cla: yes label Jul 6, 2021
@copybara-service copybara-service Bot force-pushed the test_383161650 branch 2 times, most recently from fe6a033 to 1f2e26c Compare July 6, 2021 08:59
@codecov-commenter
Copy link
Copy Markdown

Codecov Report

Merging #1411 (1f2e26c) into master (03f8e7b) will decrease coverage by 0.00%.
The diff coverage is 85.71%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master    #1411      +/-   ##
==========================================
- Coverage   82.28%   82.27%   -0.01%     
==========================================
  Files          65       65              
  Lines        5328     5332       +4     
==========================================
+ Hits         4384     4387       +3     
- Misses        944      945       +1     
Impacted Files Coverage Δ
flax/linen/transforms.py 91.45% <ø> (ø)
flax/core/lift.py 96.15% <50.00%> (-0.26%) ⬇️
flax/linen/module.py 94.21% <100.00%> (+0.02%) ⬆️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 03f8e7b...1f2e26c. Read the comment docs.

…ributes.

pjit PartitionSpec Module attributes were being downcast to tuples during freezing.

An extra data_transform kwarg to nn.scan is used to help fix an issue where XLA SPMD
constraints don't propagate across XLA while loops.  It allows us to use a
workaround to re-apply SPMD constraints inside the scan body function.
(Ultimately we hope to find a better upstream fix in JAX/XLA.)

PiperOrigin-RevId: 383231061
@copybara-service copybara-service Bot merged commit 5ffd000 into master Jul 6, 2021
@copybara-service copybara-service Bot deleted the test_383161650 branch July 6, 2021 09:17
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants