Skip to content

Conversation

@copybara-service
Copy link

Add Layout support to jax.jit.

jax.jit now accepts Layout instances to the in_shardings and out_shardings argument. Major changes are just plumbing in_layouts and out_layouts everywhere.

Note that public api is Layout(device_local_layout, sharding) which is how users will pass us the Layout but internally we split them apart into device_local_layout and sharding.

Docs are coming up on how to use the API and what Layouts mean and how to make sense of them (especially on TPU).

@copybara-service copybara-service bot force-pushed the test_621747200 branch 3 times, most recently from 171ae9f to 08d40f5 Compare April 6, 2024 02:40
`jax.jit` now accepts `Layout` instances to the `in_shardings` and `out_shardings` argument. Major changes are just plumbing `in_layouts` and `out_layouts` everywhere.

Note that public api is `Layout(device_local_layout, sharding)` which is how users will pass us the Layout but internally we split them apart into device_local_layout and sharding.

Docs are coming up on how to use the API and what Layouts mean and how to make sense of them (especially on TPU).

PiperOrigin-RevId: 622352537
@copybara-service copybara-service bot merged commit 2a1300d into main Apr 6, 2024
@copybara-service copybara-service bot deleted the test_621747200 branch April 6, 2024 03:09
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants