diff --git a/flax/struct.py b/flax/struct.py index 7a8283a9d..29dbb9c2f 100644 --- a/flax/struct.py +++ b/flax/struct.py @@ -227,8 +227,8 @@ class PyTreeNode: >>> model_grad = jax.grad(loss_fn)(model) """ - def __init_subclass__(cls): - dataclass(cls) # pytype: disable=wrong-arg-types + def __init_subclass__(cls, **kwargs): + dataclass(cls, **kwargs) # pytype: disable=wrong-arg-types def __init__(self, *args, **kwargs): # stub for pytype diff --git a/tests/struct_test.py b/tests/struct_test.py index da517c739..8ab3119d0 100644 --- a/tests/struct_test.py +++ b/tests/struct_test.py @@ -18,7 +18,7 @@ from typing import Any import jax -from absl.testing import absltest +from absl.testing import absltest, parameterized from jax._src.tree_util import prefix_errors from flax import struct @@ -34,7 +34,7 @@ class Point: meta: Any = struct.field(pytree_node=False) -class StructTest(absltest.TestCase): +class StructTest(parameterized.TestCase): def test_no_extra_fields(self): p = Point(x=1, y=2, meta={}) with self.assertRaises(dataclasses.FrozenInstanceError): @@ -93,24 +93,68 @@ class A(struct.PyTreeNode): a: int # TODO(marcuschiam): Uncomment when Flax upgrades to Python 3.10. - # def test_kw_only(self): - # @struct.dataclass - # class A: - # a: int = 1 - - # with self.assertRaisesRegex(TypeError, "non-default argument 'b' follows default argument"): + # @parameterized.parameters( + # {'mode': 'dataclass'}, + # {'mode': 'pytreenode'}, + # ) + # def test_kw_only(self, mode): + # if mode == 'dataclass': # @struct.dataclass + # class A: + # a: int = 1 + + # @functools.partial(struct.dataclass, kw_only=True) # class B(A): # b: int + # elif mode == 'pytreenode': + # class A(struct.PyTreeNode): + # a: int = 1 - # @functools.partial(struct.dataclass, kw_only=True) - # class B(A): - # b: int + # class B(A, struct.PyTreeNode, kw_only=True): + # b: int # obj = B(b=2) # self.assertEqual(obj.a, 1) # self.assertEqual(obj.b, 2) + # with self.assertRaisesRegex(TypeError, "non-default argument 'b' follows default argument"): + # if mode == 'dataclass': + # @struct.dataclass + # class B(A): + # b: int + # elif mode == 'pytreenode': + # class B(A, struct.PyTreeNode): + # b: int + + # TODO(marcuschiam): Uncomment when Flax upgrades to Python 3.10. + # @parameterized.parameters( + # {'mode': 'dataclass'}, + # {'mode': 'pytreenode'}, + # ) + # def test_mutable(self, mode): + # if mode == 'dataclass': + # @struct.dataclass + # class A: + # a: int = 1 + + # @functools.partial(struct.dataclass, frozen=False) + # class B: + # b: int = 1 + # elif mode == 'pytreenode': + # class A(struct.PyTreeNode): + # a: int = 1 + + # class B(struct.PyTreeNode, frozen=False): + # b: int = 1 + + # obj = A() + # with self.assertRaisesRegex(dataclasses.FrozenInstanceError, "cannot assign to field 'a'"): + # obj.a = 2 + + # obj = B() + # obj.b = 2 + # self.assertEqual(obj.b, 2) + if __name__ == '__main__': absltest.main()