From 4dd61948d4403d5a58d365f05ecaf01689de9943 Mon Sep 17 00:00:00 2001 From: mohsinm-dev Date: Wed, 11 Feb 2026 00:18:04 +0500 Subject: [PATCH] Fix PyTreeNode + Generic losing __parameters__ when Generic is last in bases (#5233) --- flax/struct.py | 1 + tests/struct_test.py | 35 ++++++++++++++++++++++++++++++++++- 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/flax/struct.py b/flax/struct.py index 53827f3db..903b10dba 100644 --- a/flax/struct.py +++ b/flax/struct.py @@ -228,6 +228,7 @@ class PyTreeNode: """ def __init_subclass__(cls, **kwargs): + super().__init_subclass__() dataclass(cls, **kwargs) # pytype: disable=wrong-arg-types def __init__(self, *args, **kwargs): diff --git a/tests/struct_test.py b/tests/struct_test.py index 9bb4986c9..304723039 100644 --- a/tests/struct_test.py +++ b/tests/struct_test.py @@ -15,7 +15,7 @@ """Tests for flax.struct.""" import dataclasses -from typing import Any +from typing import Any, Generic, TypeVar import jax from absl.testing import absltest, parameterized @@ -157,5 +157,38 @@ class B(struct.PyTreeNode, frozen=False): self.assertEqual(obj.b, 2) + def test_generic_pytreenode_base_order(self): + # PyTreeNode + Generic should work regardless of base order (#5233). + T = TypeVar('T') + U = TypeVar('U') + + # Generic after PyTreeNode. + class A(struct.PyTreeNode, Generic[T, U]): + x: int = 0 + + self.assertEqual(A.__parameters__, (T, U)) + A[int, int] # should not raise + + # Generic before PyTreeNode. + class B(Generic[T, U], struct.PyTreeNode): + x: int = 0 + + self.assertEqual(B.__parameters__, (T, U)) + B[int, int] # should not raise + + # Subclassing a parameterized generic PyTreeNode. + class Base(struct.PyTreeNode, Generic[T, U]): + x: int = 0 + + class Sub(Base[int, str]): + y: int = 1 + + obj = Sub(x=1, y=2) + self.assertEqual(obj.x, 1) + self.assertEqual(obj.y, 2) + leaves = jax.tree_util.tree_leaves(obj) + self.assertEqual(leaves, [1, 2]) + + if __name__ == '__main__': absltest.main()