nnx.Rngs.__*__ returns the default stream, which causes confusing breakage in combination with other packages #5375

@manulari

System information

  • OS Platform and Distribution: Linux Ubuntu 24.04
  • Flax 0.12.5, jax 0.9.1, jaxlib 0.9.1
  • Python 3.12.3
  • CPU only (not GPU-specific)

Problem you have encountered:

nnx.Rngs.__getattr__ falls back to returning the default RngStream for any attribute name that isn't an explicitly defined stream, including dunder names such as __pydantic_serializer__ or __json__. (Names that object already defines, such as __reduce_ex__, are resolved by normal lookup and never reach __getattr__.) This causes confusing breakage when Rngs objects are embedded in third-party containers that probe for protocol support via getattr.

Concrete example: placing an nnx.Rngs instance inside a Pydantic BaseModel with arbitrary_types_allowed=True crashes on model_dump(). Pydantic's Rust serializer checks getattr(obj, '__pydantic_serializer__') to see if the object provides its own serialization. Instead of getting AttributeError, it receives the default RngStream, then fails with:

TypeError: 'RngStream' object cannot be converted to 'SchemaSerializer'

The same class of problem will appear with any library or protocol that uses hasattr/getattr on dunder names to detect capabilities (pickle customisation, copy protocol, dataclass introspection, etc.).
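The mechanism can be reproduced without Flax or Pydantic. The following is a minimal sketch, not the Flax implementation: `Stream` and `CatchAllRngs` are hypothetical stand-ins that only mimic the fallback-to-default behaviour described above.

```python
class Stream:
    """Hypothetical stand-in for an RNG stream (not flax.nnx.RngStream)."""

    def __init__(self, tag):
        self.tag = tag


class CatchAllRngs:
    """Hypothetical stand-in whose __getattr__ falls back to 'default'."""

    def __init__(self):
        self.default = Stream("default")

    def __getattr__(self, name):
        # Called only when normal attribute lookup fails.
        # Every unknown name, dunder or not, resolves to the default stream.
        return self.__dict__["default"]


r = CatchAllRngs()

# A capability probe sees an attribute that was never defined.
probe_hit = hasattr(r, "__pydantic_serializer__")
fake_tag = r.__totally_fake__.tag

print(probe_hit)  # True
print(fake_tag)   # default
```

Any library that treats a successful `getattr` as "this object implements the protocol" will then try to use the returned stream as if it were the protocol object.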

What you expected to happen:

nnx.Rngs.__getattr__ should raise AttributeError for dunder names (__*__) rather than silently returning the default stream. Only user-chosen stream names (like params, dropout) should trigger the fallback.

Logs, error messages, etc:

TypeError: 'RngStream' object cannot be converted to 'SchemaSerializer'

Steps to reproduce:

from flax import nnx

# 1. Any dunder getattr returns the default RngStream instead of raising AttributeError
r = nnx.Rngs(default=42)
print(type(r.__pydantic_serializer__))  # <class 'flax.nnx.rnglib.RngStream'>
print(r.__pydantic_serializer__.tag)     # 'default'
print(r.__totally_fake__.tag)            # 'default'

# 2. Breaks Pydantic serialization (and likely other protocol probes)
from pydantic import BaseModel, ConfigDict

class State(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)
    rngs: nnx.Rngs

s = State(rngs=nnx.Rngs(default=42))
s.model_dump()  # TypeError: 'RngStream' object cannot be converted to 'SchemaSerializer'

Suggested fix — in Rngs.__getattr__, add an early guard:

if name.startswith('__') and name.endswith('__'):
    raise AttributeError(name)
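To illustrate the effect of the guard, here is a self-contained sketch on a stand-in class. `Stream` and `GuardedRngs` are hypothetical names for illustration only; the real Rngs.__getattr__ may be structured differently, so this shows the idea rather than a drop-in patch.

```python
class Stream:
    """Hypothetical stand-in for an RNG stream (not flax.nnx.RngStream)."""

    def __init__(self, tag):
        self.tag = tag


class GuardedRngs:
    """Hypothetical stand-in: falls back to 'default', except for dunders."""

    def __init__(self, *stream_names):
        for name in stream_names:
            setattr(self, name, Stream(name))

    def __getattr__(self, name):
        # Early guard: protocol probes on dunder names must fail normally.
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)
        streams = self.__dict__
        if 'default' not in streams:
            raise AttributeError(f"no stream named {name!r} and no default")
        # User-chosen stream names still fall back to the default stream.
        return streams['default']


r = GuardedRngs('default', 'dropout')

dunder_visible = hasattr(r, '__pydantic_serializer__')
params_tag = r.params.tag     # not defined, falls back to default
dropout_tag = r.dropout.tag   # defined explicitly, normal lookup

print(dunder_visible)  # False
print(params_tag)      # default
print(dropout_tag)     # dropout
```

With the guard in place, `hasattr` and `getattr(obj, name, None)` probes on dunder names behave as they would for any ordinary object, while the fallback for stream names such as params is unchanged.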
