
Filtering with nnx.PathContains('') returns an empty state #4660

@NKlug

Description

Filtering a module with nnx.PathContains('') always returns an empty state.
Consider the following example:

from flax import nnx

class CNN(nnx.Module):

  def __init__(self, *, rngs: nnx.Rngs):
    self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
    self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)
    ...

  def __call__(self, x):
    ...

model = CNN(rngs=nnx.Rngs(0))

graph, state, rest = nnx.split(model, nnx.PathContains(''), ...)  # `...` collects everything the first filter does not match
print(state == nnx.State({})) # True
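
To see why `rest` ends up holding everything, here is a pure-Python sketch of first-match partitioning. It assumes that `nnx.split` assigns each leaf to the first filter that matches it and treats `...` as a catch-all. `partition`, `matches`, `path_contains_empty` and the toy `flat_state` are hypothetical stand-ins, not Flax APIs, and the values are placeholder strings rather than real arrays.

```python
from typing import Any, Dict, Hashable, List, Tuple

Path = Tuple[Hashable, ...]


def matches(flt: Any, path: Path, value: Any) -> bool:
    # Ellipsis acts as a catch-all, mirroring its role as the last filter in nnx.split
    if flt is Ellipsis:
        return True
    return flt(path, value)


def partition(flat_state: Dict[Path, Any], *filters: Any) -> List[Dict[Path, Any]]:
    """Assign each leaf to the first filter that matches it (hypothetical sketch)."""
    groups: List[Dict[Path, Any]] = [{} for _ in filters]
    for path, value in flat_state.items():
        for i, flt in enumerate(filters):
            if matches(flt, path, value):
                groups[i][path] = value
                break
        else:
            raise ValueError(f"no filter matched {path}")
    return groups


def path_contains_empty(path: Path, value: Any) -> bool:
    # Stand-in for nnx.PathContains(''), assuming it tests tuple membership
    return "" in path


# Toy flat state keyed by tuple paths, loosely modeled on the CNN above
flat_state = {
    ("conv1", "kernel"): "k1", ("conv1", "bias"): "b1",
    ("conv2", "kernel"): "k2", ("conv2", "bias"): "b2",
}

state, rest = partition(flat_state, path_contains_empty, ...)
print(state)      # {}
print(len(rest))  # 4
```

Because no leaf satisfies the first filter, every leaf falls through to `...`, which matches the behavior reported above.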

I would argue that the empty string '' is contained in every path, just as '' in any_string is always True.
Therefore, I'd expect state to contain the full state and rest to be empty.
What do you think?
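
The disagreement comes down to what "contains" means. If `PathContains` checks whether its key equals one of the elements of a path tuple (Python's `in` on a tuple) rather than searching for a substring, then `''` never matches a real path. That reading of the implementation is an assumption here. The pure-Python sketch contrasts the two interpretations. `tuple_membership` and `substring_of_joined` are hypothetical helpers, not Flax APIs.

```python
from typing import Hashable, Tuple

# Paths as tuples of keys, assumed to mirror how NNX addresses submodules and variables
paths: list[Tuple[Hashable, ...]] = [
    ("conv1", "kernel"), ("conv1", "bias"),
    ("conv2", "kernel"), ("conv2", "bias"),
]


def tuple_membership(key: Hashable, path: Tuple[Hashable, ...]) -> bool:
    # `in` on a tuple compares whole elements with ==, so '' only matches a key that is ''
    return key in path


def substring_of_joined(key: str, path: Tuple[Hashable, ...]) -> bool:
    # Hypothetical string-based alternative: '' is a substring of every string
    return key in "/".join(str(part) for part in path)


print([tuple_membership("", p) for p in paths])        # [False, False, False, False]
print([substring_of_joined("", p) for p in paths])     # [True, True, True, True]
print([tuple_membership("kernel", p) for p in paths])  # [True, False, True, False]
```

Under the tuple-membership reading, `PathContains('kernel')` still selects both kernels, while `PathContains('')` selects nothing. Under the substring reading, `''` would select everything, which is the behavior the issue argues for.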

System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Ubuntu 24.04
  • Flax, jax, jaxlib versions (obtain with pip show flax jax jaxlib): Flax 0.10.2, jax 0.5.0
  • Python version: 3.12
