Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support match_args in cdef dataclasses #5381

Merged
merged 5 commits into from Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
26 changes: 24 additions & 2 deletions Cython/Compiler/Dataclass.py
Expand Up @@ -299,7 +299,7 @@ def handle_cclass_dataclass(node, dataclass_args, analyse_decs_transform):
# default argument values from https://docs.python.org/3/library/dataclasses.html
kwargs = dict(init=True, repr=True, eq=True,
order=False, unsafe_hash=False,
frozen=False, kw_only=False)
frozen=False, kw_only=False, match_args=True)
if dataclass_args is not None:
if dataclass_args[0]:
error(node.pos, "cython.dataclasses.dataclass takes no positional arguments")
Expand Down Expand Up @@ -330,7 +330,7 @@ def handle_cclass_dataclass(node, dataclass_args, analyse_decs_transform):
for k, v in kwargs.items() ] +
[ (ExprNodes.IdentifierStringNode(node.pos, value=EncodedString(k)),
ExprNodes.BoolNode(node.pos, value=v))
for k, v in [('kw_only', kw_only), ('match_args', False),
for k, v in [('kw_only', kw_only),
('slots', False), ('weakref_slot', False)]
])
dataclass_params = make_dataclass_call_helper(
Expand All @@ -347,6 +347,7 @@ def handle_cclass_dataclass(node, dataclass_args, analyse_decs_transform):

code = TemplateCode()
generate_init_code(code, kwargs['init'], node, fields, kw_only)
generate_match_args(code, kwargs['match_args'], node, fields, kw_only)
generate_repr_code(code, kwargs['repr'], node, fields)
generate_eq_code(code, kwargs['eq'], node, fields)
generate_order_code(code, kwargs['order'], node, fields)
Expand Down Expand Up @@ -469,6 +470,27 @@ def generate_init_code(code, init, node, fields, kw_only):
function_start_point.add_code_line(u"def __init__(%s):" % args)


def generate_match_args(code, match_args, node, fields, global_kw_only):
"""
Generates a tuple containing what would be the positional args to __init__

Note that this is generated even if the user overrides init
"""
if not match_args or node.scope.lookup_here("__match_args__"):
return
positional_arg_names = []
for field_name, field in fields.items():
# TODO hasattr and global_kw_only can be removed once full kw_only support is added
field_is_kw_only = global_kw_only or (
hasattr(field, 'kw_only') and field.kw_only.value
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This hasattr is only in to anticipate adding a kw_only attribute to field, so that it behaves correctly when that happens. Obviously when that does happen we won't need it.

It's a bit clunky, but it's just future-proofing, and can be removed later

)
if not field_is_kw_only:
positional_arg_names.append(repr(field_name))
positional_arg_names.append("") # finish the tuple with a comma
args = u", ".join(positional_arg_names)
code.add_code_line("__match_args__ = (%s)" % args)


def generate_repr_code(code, repr, node, fields):
"""
The core of the CPython implementation is just:
Expand Down
2 changes: 1 addition & 1 deletion Tools/make_dataclass_tests.py
Expand Up @@ -58,6 +58,7 @@
("TestCase", "test_class_attrs"),
("TestCase", "test_hash_field_rules"),
("TestStringAnnotations",), # almost all the texts here use local variables
("TestMatchArgs", "test_explicit_match_args"),
# Currently unsupported
# =====================
(
Expand All @@ -68,7 +69,6 @@
("TestCase", "test_missing_default"), # MISSING
("TestCase", "test_missing_repr"), # MISSING
("TestSlots",), # __slots__ isn't understood
("TestMatchArgs",),
("TestKeywordArgs", "test_field_marked_as_kwonly"),
("TestKeywordArgs", "test_match_args"),
("TestKeywordArgs", "test_KW_ONLY"),
Expand Down
60 changes: 60 additions & 0 deletions tests/run/test_dataclasses.pyx
Expand Up @@ -607,6 +607,45 @@ class C_TestReplace_test_recursive_repr_misc_attrs:
f: object
g: int

@dataclass
@cclass
class C_TestMatchArgs_test_match_args:
a: int

@dataclass(repr=False, eq=False, init=False)
@cclass
class X_TestMatchArgs_test_bpo_43764:
a: int
b: int
c: int

@dataclass(match_args=False)
@cclass
class X_TestMatchArgs_test_match_args_argument:
a: int

@dataclass(match_args=False)
@cclass
class Y_TestMatchArgs_test_match_args_argument:
a: int
__match_args__ = ('b',)

@dataclass(match_args=False)
@cclass
class Z_TestMatchArgs_test_match_args_argument(Y_TestMatchArgs_test_match_args_argument):
z: int

@dataclass
@cclass
class A_TestMatchArgs_test_match_args_argument:
a: int
z: int

@dataclass(match_args=False)
@cclass
class B_TestMatchArgs_test_match_args_argument(A_TestMatchArgs_test_match_args_argument):
b: int

class CustomError(Exception):
pass

Expand Down Expand Up @@ -1180,6 +1219,27 @@ class TestReplace(unittest.TestCase):
class TestAbstract(unittest.TestCase):
pass

class TestMatchArgs(unittest.TestCase):

def test_match_args(self):
C = C_TestMatchArgs_test_match_args
self.assertEqual(C(42).__match_args__, ('a',))

def test_bpo_43764(self):
X = X_TestMatchArgs_test_bpo_43764
self.assertEqual(X.__match_args__, ('a', 'b', 'c'))

def test_match_args_argument(self):
X = X_TestMatchArgs_test_match_args_argument
self.assertNotIn('__match_args__', X.__dict__)
Y = Y_TestMatchArgs_test_match_args_argument
self.assertEqual(Y.__match_args__, ('b',))
Z = Z_TestMatchArgs_test_match_args_argument
self.assertEqual(Z.__match_args__, ('b',))
A = A_TestMatchArgs_test_match_args_argument
B = B_TestMatchArgs_test_match_args_argument
self.assertEqual(B.__match_args__, ('a', 'z'))

class TestKeywordArgs(unittest.TestCase):
pass
if __name__ == '__main__':
Expand Down