Add Layout support to jax.jit.
#72
Merged
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Add
Layoutsupport tojax.jit.jax.jitnow acceptsLayoutinstances to thein_shardingsandout_shardingsargument. Major changes are just plumbingin_layoutsandout_layoutseverywhere.Note that public api is
Layout(device_local_layout, sharding)which is how users will pass us the Layout but internally we split them apart into device_local_layout and sharding.Docs are coming up on how to use the API and what Layouts mean and how to make sense of them (especially on TPU).