Closed
Description
Filtering a module with nnx.PathContains('') always returns an empty state.
Consider the following example:
```python
from flax import nnx

class CNN(nnx.Module):
    def __init__(self, *, rngs: nnx.Rngs):
        self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
        self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)
        ...

    def __call__(self, x):
        ...

model = CNN(rngs=nnx.Rngs(0))
graph, state, rest = nnx.split(model, nnx.PathContains(''), ...)
print(state == nnx.State({}))  # True
```

I would argue that the empty path `''` is contained in every path, much like `'' in any_string` is always `True`.
Therefore, I'd expect state to contain the full state and rest to be empty.
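The gap between these two intuitions can be shown in plain Python. The sketch below assumes that `nnx.PathContains` tests whether its key *equals* one of the elements of a parameter's path tuple (for example `('conv1', 'kernel')`), rather than whether the key is a *substring* of any element. That is an assumption about its behavior, not something stated in this issue. The example paths and both helper functions are hypothetical and exist only for illustration.

```python
# Hypothetical example paths, shaped like the key tuples that address
# parameters in the CNN above, e.g. ('conv1', 'kernel').
paths = [
    ("conv1", "kernel"),
    ("conv1", "bias"),
    ("conv2", "kernel"),
    ("conv2", "bias"),
]

def tuple_membership(key, path):
    """Element-wise membership: True only if some path element equals `key`."""
    return key in path

def substring_membership(key, path):
    """Hypothetical alternative: True if `key` is a substring of any path element."""
    return any(key in str(part) for part in path)

# The analogy from the issue: '' is a substring of every string.
print("" in "conv1")  # True

# Element-wise membership never matches '' because no element equals ''.
print([tuple_membership("", p) for p in paths])  # [False, False, False, False]

# Substring matching would select every path, which is what this issue expects.
print([substring_membership("", p) for p in paths])  # [True, True, True, True]
```

Under the element-wise assumption, `''` selects nothing, which would be consistent with the empty `state` observed above.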
What do you think?
System information
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 24.04
- Flax, jax, jaxlib versions (obtain with `pip show flax jax jaxlib`): Flax 0.10.2, jax 0.5.0
- Python version: 3.12