Skip to content

Conversation

@copybara-service
Copy link

revise axes_scan to flatten argument pytrees only once

A user has a custom pytree node with the unusual behavior that it introduces
new arrays when flattening. That is, it's as if we had:

# a custom object with two leaf arrays
custom_tree_object = SomeObject(jax_arrray1, jax_array2)
# convert leaves to ShapedArrays
custom_tree_object2 = jax.tree.map(core.typeof, custom_tree_object)
# flatten, should only see ShapedArrays, right?
leaves, treedef = jax.tree.flatten(custom_tree_object2)
print(leaves)
# [ShapedArray(...), ShapedArray(...), np.array(...)]

This change makes the flax.nn.scan function robust to such behavior. Without it, we were passing non-AbstractValues into JAX where JAX required AbstractValues.

I don't think we want to support this in general, but this fix seemed like the most
expedient way to roll fowrard jax-ml/jax#29273

A user has a custom pytree node with the unusual behavior that it introduces
new arrays when flattening. That is, it's as if we had:

```python
# a custom object with two leaf arrays
custom_tree_object = SomeObject(jax_arrray1, jax_array2)
# convert leaves to ShapedArrays
custom_tree_object2 = jax.tree.map(core.typeof, custom_tree_object)
# flatten, should only see ShapedArrays, right?
leaves, treedef = jax.tree.flatten(custom_tree_object2)
print(leaves)
# [ShapedArray(...), ShapedArray(...), np.array(...)]
```

This change makes the `flax.nn.scan` function robust to such behavior. Without it, we were passing non-AbstractValues into JAX where JAX required AbstractValues.

I don't think we want to support this in general, but this fix seemed like the most
expedient way to roll fowrard jax-ml/jax#29273

PiperOrigin-RevId: 768175118
@copybara-service copybara-service bot merged commit 893a660 into main Jun 6, 2025
@copybara-service copybara-service bot deleted the test_768154467 branch June 6, 2025 19:29
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants