Skip to content

Add tree-mode support for nnx.shard_map#5238

Merged
copybara-service[bot] merged 1 commit intomainfrom
test_868100583
Feb 12, 2026
Merged

Add tree-mode support for nnx.shard_map#5238
copybara-service[bot] merged 1 commit intomainfrom
test_868100583

Conversation

@copybara-service
Copy link

Add tree-mode support for nnx.shard_map

Extends nnx.shard_map with a graph parameter (defaulting to True for
backward compatibility). When graph=False, shard_map operates in tree-mode
using pytree-based state propagation, mirroring the existing tree-mode
for nnx.jit. Tree-mode assumes referential transparency, doesn't propagate graph updates (only Variable updates are propagated), StateSharding is not supported / needed.

Tree-mode offers a simpler, more JAX-native usage pattern that treats Modules as simple stateless pytrees, and only Variables state is automatically handled.

@copybara-service copybara-service bot force-pushed the test_868100583 branch 3 times, most recently from dd29b8e to 7a6fc17 Compare February 12, 2026 06:00
Extends nnx.shard_map with a `graph` parameter (defaulting to True for
backward compatibility). When graph=False, shard_map operates in tree-mode
using pytree-based state propagation, mirroring the existing tree-mode
for nnx.jit. Tree-mode assumes referential transparency, doesn't propagate graph updates (only Variable updates are propagated), StateSharding is not supported / needed.

Tree-mode offers a simpler, more JAX-native usage pattern that treats Modules as simple stateless pytrees, and only Variables state is automatically handled.

PiperOrigin-RevId: 869026486
@copybara-service copybara-service bot merged commit 0bfe50a into main Feb 12, 2026
@copybara-service copybara-service bot deleted the test_868100583 branch February 12, 2026 06:14
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

0 participants