Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions flax/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,8 @@ class PyTreeNode:
>>> model_grad = jax.grad(loss_fn)(model)
"""

def __init_subclass__(cls):
dataclass(cls) # pytype: disable=wrong-arg-types
def __init_subclass__(cls, **kwargs):
dataclass(cls, **kwargs) # pytype: disable=wrong-arg-types

def __init__(self, *args, **kwargs):
# stub for pytype
Expand Down
66 changes: 55 additions & 11 deletions tests/struct_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from typing import Any

import jax
from absl.testing import absltest
from absl.testing import absltest, parameterized
from jax._src.tree_util import prefix_errors

from flax import struct
Expand All @@ -34,7 +34,7 @@ class Point:
meta: Any = struct.field(pytree_node=False)


class StructTest(absltest.TestCase):
class StructTest(parameterized.TestCase):
def test_no_extra_fields(self):
p = Point(x=1, y=2, meta={})
with self.assertRaises(dataclasses.FrozenInstanceError):
Expand Down Expand Up @@ -93,24 +93,68 @@ class A(struct.PyTreeNode):
a: int

# TODO(marcuschiam): Uncomment when Flax upgrades to Python 3.10.
# def test_kw_only(self):
# @struct.dataclass
# class A:
# a: int = 1

# with self.assertRaisesRegex(TypeError, "non-default argument 'b' follows default argument"):
# @parameterized.parameters(
# {'mode': 'dataclass'},
# {'mode': 'pytreenode'},
# )
# def test_kw_only(self, mode):
# if mode == 'dataclass':
# @struct.dataclass
# class A:
# a: int = 1

# @functools.partial(struct.dataclass, kw_only=True)
# class B(A):
# b: int
# elif mode == 'pytreenode':
# class A(struct.PyTreeNode):
# a: int = 1

# @functools.partial(struct.dataclass, kw_only=True)
# class B(A):
# b: int
# class B(A, struct.PyTreeNode, kw_only=True):
# b: int

# obj = B(b=2)
# self.assertEqual(obj.a, 1)
# self.assertEqual(obj.b, 2)

# with self.assertRaisesRegex(TypeError, "non-default argument 'b' follows default argument"):
# if mode == 'dataclass':
# @struct.dataclass
# class B(A):
# b: int
# elif mode == 'pytreenode':
# class B(A, struct.PyTreeNode):
# b: int

# TODO(marcuschiam): Uncomment when Flax upgrades to Python 3.10.
# @parameterized.parameters(
# {'mode': 'dataclass'},
# {'mode': 'pytreenode'},
# )
# def test_mutable(self, mode):
# if mode == 'dataclass':
# @struct.dataclass
# class A:
# a: int = 1

# @functools.partial(struct.dataclass, frozen=False)
# class B:
# b: int = 1
# elif mode == 'pytreenode':
# class A(struct.PyTreeNode):
# a: int = 1

# class B(struct.PyTreeNode, frozen=False):
# b: int = 1

# obj = A()
# with self.assertRaisesRegex(dataclasses.FrozenInstanceError, "cannot assign to field 'a'"):
# obj.a = 2

# obj = B()
# obj.b = 2
# self.assertEqual(obj.b, 2)


if __name__ == '__main__':
absltest.main()