Skip to content

Commit

Permalink
Add support for kw_only for chex.dataclass.
Browse files Browse the repository at this point in the history
In practice, chex.dataclasses is keyword only, but isn't marked as such for `dataclasses`, which e.g. prevents doing inheritance across chex.dataclasses with a mix of default and not-defaulted arguments.

The feature is a no-op in Python 3.9 and lower because `kw_only` was added to `dataclasses.dataclass` only in Python 3.10.

PiperOrigin-RevId: 549313777
  • Loading branch information
jblespiau authored and ChexDev committed Jul 19, 2023
1 parent dfdbce6 commit bd01ae1
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 2 deletions.
16 changes: 14 additions & 2 deletions chex/_src/dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import collections
import dataclasses
import functools
import sys

from absl import logging
import jax
Expand Down Expand Up @@ -93,6 +94,7 @@ def dataclass(
order=False,
unsafe_hash=False,
frozen=False,
kw_only: bool = False,
mappable_dataclass=True, # pylint: disable=redefined-outer-name
):
"""JAX-friendly wrapper for :py:func:`dataclasses.dataclass`.
Expand All @@ -109,6 +111,7 @@ def dataclass(
order: See :py:func:`dataclasses.dataclass`.
unsafe_hash: See :py:func:`dataclasses.dataclass`.
frozen: See :py:func:`dataclasses.dataclass`.
kw_only: See :py:func:`dataclasses.dataclass`.
mappable_dataclass: If True (the default), methods to make the class
implement the :py:class:`collections.abc.Mapping` interface will be
generated and the class will include :py:class:`collections.abc.Mapping`
Expand All @@ -126,7 +129,7 @@ def dataclass(
def dcls(cls):
# Make sure to create a separate _Dataclass instance for each `cls`.
return _Dataclass(
init, repr, eq, order, unsafe_hash, frozen, mappable_dataclass
init, repr, eq, order, unsafe_hash, frozen, kw_only, mappable_dataclass
)(cls)

if cls is None:
Expand All @@ -145,6 +148,7 @@ def __init__(
order=False,
unsafe_hash=False,
frozen=False,
kw_only=False,
mappable_dataclass=True, # pylint: disable=redefined-outer-name
):
self.init = init
Expand All @@ -153,6 +157,7 @@ def __init__(
self.order = order
self.unsafe_hash = unsafe_hash
self.frozen = frozen
self.kw_only = kw_only
self.mappable_dataclass = mappable_dataclass
self.registered = False

Expand All @@ -165,6 +170,11 @@ def __call__(self, cls):
getattr(base, "__dataclass_params__").frozen and not self.frozen):
raise TypeError("cannot inherit non-frozen dataclass from a frozen one")

# `kw_only` is only available starting from 3.10.
version_dependent_args = {}
version = sys.version_info
if version.major == 3 and version.minor >= 10:
version_dependent_args = {"kw_only": self.kw_only}
# pytype: disable=wrong-keyword-args
dcls = dataclasses.dataclass(
cls,
Expand All @@ -173,7 +183,9 @@ def __call__(self, cls):
eq=self.eq,
order=self.order,
unsafe_hash=self.unsafe_hash,
frozen=self.frozen)
frozen=self.frozen,
**version_dependent_args,
)
# pytype: enable=wrong-keyword-args

fields_names = set(f.name for f in dataclasses.fields(dcls))
Expand Down
45 changes: 45 additions & 0 deletions chex/_src/dataclass_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import copy
import dataclasses
import pickle
import sys
from typing import Any, Generic, Mapping, TypeVar

from absl.testing import absltest
Expand Down Expand Up @@ -383,6 +384,50 @@ def test_dataclass_replace(self, frozen):
target_obj = dummy_dataclass(factor=factor, frozen=frozen)
asserts.assert_trees_all_close(obj, target_obj)

def test_dataclass_requires_kwargs_by_default(self):
factor = 1.0
with self.assertRaisesRegex(
ValueError,
"Mappable dataclass constructor doesn't support positional args.",
):
Dataclass(
NestedDataclass(
c=factor * np.ones((3,), dtype=np.float32),
d=factor * np.ones((4,), dtype=np.float32),
),
factor * 2 * np.ones((5,), dtype=np.float32),
)

def test_dataclass_mappable_dataclass_false(self):
factor = 1.0

@chex_dataclass(mappable_dataclass=False)
class NonMappableDataclass:
a: NestedDataclass
b: pytypes.ArrayDevice

NonMappableDataclass(
NestedDataclass(
c=factor * np.ones((3,), dtype=np.float32),
d=factor * np.ones((4,), dtype=np.float32),
),
factor * 2 * np.ones((5,), dtype=np.float32),
)

def test_inheritance_is_possible_thanks_to_kw_only(self):
if sys.version_info.minor < 10: # Feature only available for Python >= 3.10
return

@chex_dataclass(kw_only=True)
class Base:
default: int = 1

@chex_dataclass(kw_only=True)
class Child(Base):
non_default: int

Child(non_default=2)

def test_unfrozen_dataclass_is_mutable(self):
factor = 5.
obj = dummy_dataclass(frozen=False)
Expand Down

0 comments on commit bd01ae1

Please sign in to comment.