Description
System information
- OS Platform and Distribution: Linux Ubuntu 24.04
- Flax 0.12.5, jax 0.9.1, jaxlib 0.9.1
- Python 3.12.3
- CPU only (not GPU-specific)
Problem you have encountered:
nnx.Rngs.__getattr__ falls back to returning the default RngStream for any attribute name that isn't an explicitly defined stream — including dunder names like __pydantic_serializer__, __json__, __reduce_ex__, etc. This causes confusing breakage when Rngs objects are embedded in third-party containers that probe for protocol support via getattr.
Concrete example: placing an nnx.Rngs instance inside a Pydantic BaseModel with arbitrary_types_allowed=True crashes on model_dump(). Pydantic's Rust serializer checks getattr(obj, '__pydantic_serializer__') to see if the object provides its own serialization. Instead of an AttributeError, it receives the default RngStream and then fails with:
TypeError: 'RngStream' object cannot be converted to 'SchemaSerializer'
The same class of problem can appear with any library or protocol that uses hasattr/getattr on dunder names to detect capabilities (pickle customisation, copy protocol, dataclass introspection, etc.).
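To make the mechanism concrete without depending on Flax or Pydantic, here is a minimal sketch. `FallbackRngs`, `FakeStream`, and the `__json__` hook convention checked by `to_jsonable` are all hypothetical stand-ins, not Flax's implementation or any real library's API. Python calls `__getattr__` only after normal lookup fails, so dunders that `object` already defines (for example `__reduce_ex__`) still resolve normally. Names that `object` lacks, such as `__json__` or `__pydantic_serializer__`, are the ones that reach the fallback and fool capability probes.

```python
class FakeStream:
    """Hypothetical stand-in for an RNG stream; only carries a tag and seed."""

    def __init__(self, tag, seed):
        self.tag = tag
        self.seed = seed


class FallbackRngs:
    """Hypothetical stand-in mimicking the reported fallback behaviour."""

    def __init__(self, **streams):
        # Each keyword becomes an instance attribute holding a stream.
        for name, seed in streams.items():
            setattr(self, name, FakeStream(name, seed))

    def __getattr__(self, name):
        # Runs only when normal lookup fails; returns the 'default'
        # stream for EVERY missing name, dunders included.
        try:
            return self.__dict__['default']
        except KeyError:
            raise AttributeError(name) from None


def to_jsonable(obj):
    """Hypothetical serializer that honours an optional __json__ hook."""
    hook = getattr(obj, '__json__', None)  # capability probe
    if hook is None:
        return repr(obj)
    if not callable(hook):
        raise TypeError(f"{type(hook).__name__!r} object is not a valid __json__ hook")
    return hook()


r = FallbackRngs(default=42)
print(r.__json__.tag)                           # default
print(hasattr(r, '__pydantic_serializer__'))    # True
print(isinstance(r.__reduce_ex__, FakeStream))  # False: found on object, no fallback
try:
    to_jsonable(r)
except TypeError as e:
    print(e)  # 'FakeStream' object is not a valid __json__ hook
```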
What you expected to happen:
nnx.Rngs.__getattr__ should raise AttributeError for dunder names (__*__) rather than silently returning the default stream. Only user-chosen stream names (like params, dropout) should trigger the fallback.
Logs, error messages, etc:
TypeError: 'RngStream' object cannot be converted to 'SchemaSerializer'
Steps to reproduce:
```python
from flax import nnx

# 1. Any dunder getattr returns the default RngStream instead of raising AttributeError
r = nnx.Rngs(default=42)
print(type(r.__pydantic_serializer__))  # <class 'flax.nnx.rnglib.RngStream'>
print(r.__pydantic_serializer__.tag)  # 'default'
print(r.__totally_fake__.tag)  # 'default'

# 2. Breaks Pydantic serialization (and likely other protocol probes)
from pydantic import BaseModel, ConfigDict

class State(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)
    rngs: nnx.Rngs

s = State(rngs=nnx.Rngs(default=42))
s.model_dump()  # TypeError: 'RngStream' object cannot be converted to 'SchemaSerializer'
```

Suggested fix: in Rngs.__getattr__, add an early guard:
```python
# At the top of Rngs.__getattr__(self, name):
if name.startswith('__') and name.endswith('__'):
    raise AttributeError(name)
```
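A minimal sketch of the suggested guard in context, using a hypothetical `GuardedRngs` stand-in rather than Flax's real class so it runs without Flax installed. It assumes the fallback target is a stream stored under the name `default`, as in the reproduction; explicitly defined streams are found by normal lookup and never reach `__getattr__`.

```python
class FakeStream:
    """Hypothetical stand-in for an RNG stream; only carries a tag and seed."""

    def __init__(self, tag, seed):
        self.tag = tag
        self.seed = seed


class GuardedRngs:
    """Hypothetical stand-in showing the proposed dunder guard."""

    def __init__(self, **streams):
        for name, seed in streams.items():
            setattr(self, name, FakeStream(name, seed))

    def __getattr__(self, name):
        # Proposed guard: protocol/dunder probes get a normal AttributeError.
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)
        # User-chosen stream names (e.g. 'params', 'dropout') still fall back.
        try:
            return self.__dict__['default']
        except KeyError:
            raise AttributeError(name) from None


r = GuardedRngs(default=42)
print(r.params.tag)                           # default
print(hasattr(r, '__pydantic_serializer__'))  # False
print(getattr(r, '__json__', None))           # None
```

Because hasattr and getattr(obj, name, None) treat AttributeError as "not supported", probes such as Pydantic's __pydantic_serializer__ check should then see the attribute as absent.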