From 9b28a7f182c5bef22186c20275cc5f7ec4817474 Mon Sep 17 00:00:00 2001 From: Syrus Date: Sat, 14 Mar 2020 13:40:30 -0700 Subject: [PATCH 1/8] Allow fast ObjectType creation based on dataclasses --- graphene/pyutils/dataclasses.py | 1181 +++++++++++++++++++++++ graphene/types/base.py | 4 +- graphene/types/objecttype.py | 23 +- graphene/types/tests/test_objecttype.py | 10 +- 4 files changed, 1212 insertions(+), 6 deletions(-) create mode 100644 graphene/pyutils/dataclasses.py diff --git a/graphene/pyutils/dataclasses.py b/graphene/pyutils/dataclasses.py new file mode 100644 index 000000000..3f2e45c8b --- /dev/null +++ b/graphene/pyutils/dataclasses.py @@ -0,0 +1,1181 @@ +# This is a polyfill for dataclasses +# https://docs.python.org/3/library/dataclasses.html +# Original PEP proposal: PEP 557 +# https://www.python.org/dev/peps/pep-0557/ +import re +import sys +import copy +import types +import inspect +import keyword + +__all__ = ['dataclass', + 'field', + 'Field', + 'FrozenInstanceError', + 'InitVar', + 'MISSING', + + # Helper functions. + 'fields', + 'asdict', + 'astuple', + 'make_dataclass', + 'replace', + 'is_dataclass', + ] + +# Conditions for adding methods. The boxes indicate what action the +# dataclass decorator takes. For all of these tables, when I talk +# about init=, repr=, eq=, order=, unsafe_hash=, or frozen=, I'm +# referring to the arguments to the @dataclass decorator. When +# checking if a dunder method already exists, I mean check for an +# entry in the class's __dict__. I never check to see if an attribute +# is defined in a base class. + +# Key: +# +=========+=========================================+ +# + Value | Meaning | +# +=========+=========================================+ +# | | No action: no method is added. | +# +---------+-----------------------------------------+ +# | add | Generated method is added. | +# +---------+-----------------------------------------+ +# | raise | TypeError is raised. 
| +# +---------+-----------------------------------------+ +# | None | Attribute is set to None. | +# +=========+=========================================+ + +# __init__ +# +# +--- init= parameter +# | +# v | | | +# | no | yes | <--- class has __init__ in __dict__? +# +=======+=======+=======+ +# | False | | | +# +-------+-------+-------+ +# | True | add | | <- the default +# +=======+=======+=======+ + +# __repr__ +# +# +--- repr= parameter +# | +# v | | | +# | no | yes | <--- class has __repr__ in __dict__? +# +=======+=======+=======+ +# | False | | | +# +-------+-------+-------+ +# | True | add | | <- the default +# +=======+=======+=======+ + + +# __setattr__ +# __delattr__ +# +# +--- frozen= parameter +# | +# v | | | +# | no | yes | <--- class has __setattr__ or __delattr__ in __dict__? +# +=======+=======+=======+ +# | False | | | <- the default +# +-------+-------+-------+ +# | True | add | raise | +# +=======+=======+=======+ +# Raise because not adding these methods would break the "frozen-ness" +# of the class. + +# __eq__ +# +# +--- eq= parameter +# | +# v | | | +# | no | yes | <--- class has __eq__ in __dict__? +# +=======+=======+=======+ +# | False | | | +# +-------+-------+-------+ +# | True | add | | <- the default +# +=======+=======+=======+ + +# __lt__ +# __le__ +# __gt__ +# __ge__ +# +# +--- order= parameter +# | +# v | | | +# | no | yes | <--- class has any comparison method in __dict__? +# +=======+=======+=======+ +# | False | | | <- the default +# +-------+-------+-------+ +# | True | add | raise | +# +=======+=======+=======+ +# Raise because to allow this case would interfere with using +# functools.total_ordering. 
# __hash__

#    +------------------- unsafe_hash= parameter
#    |       +----------- eq= parameter
#    |       |       +--- frozen= parameter
#    |       |       |
#    v       v       v    |        |        |
#                         |   no   |  yes   |  <--- class has explicitly defined __hash__
# +=======+=======+=======+========+========+
# | False | False | False |        |        | No __eq__, use the base class __hash__
# +-------+-------+-------+--------+--------+
# | False | False | True  |        |        | No __eq__, use the base class __hash__
# +-------+-------+-------+--------+--------+
# | False | True  | False | None   |        | <-- the default, not hashable
# +-------+-------+-------+--------+--------+
# | False | True  | True  | add    |        | Frozen, so hashable, allows override
# +-------+-------+-------+--------+--------+
# | True  | False | False | add    | raise  | Has no __eq__, but hashable
# +-------+-------+-------+--------+--------+
# | True  | False | True  | add    | raise  | Has no __eq__, but hashable
# +-------+-------+-------+--------+--------+
# | True  | True  | False | add    | raise  | Not frozen, but hashable
# +-------+-------+-------+--------+--------+
# | True  | True  | True  | add    | raise  | Frozen, so hashable
# +=======+=======+=======+========+========+
# For boxes that are blank, __hash__ is untouched and therefore
# inherited from the base class.  If the base is object, then
# id-based hashing is used.
#
# Note that a class may already have __hash__=None if it specified an
# __eq__ method in the class body (not one that was created by
# @dataclass).
#
# See _hash_action (below) for a coded version of this table.


class FrozenInstanceError(AttributeError):
    """Raised when an attempt is made to modify a frozen class."""


# A sentinel object for default values to signal that a default
# factory will be used.  This is given a nice repr() which will appear
# in the function signature of dataclasses' constructors.
class _HAS_DEFAULT_FACTORY_CLASS:
    def __repr__(self):
        # This text is what shows up as the default value of a
        # default_factory field in the generated __init__ signature
        # (and therefore in the auto-generated class docstring), so it
        # must not be empty.
        return '<factory>'
_HAS_DEFAULT_FACTORY = _HAS_DEFAULT_FACTORY_CLASS()

# A sentinel object to detect if a parameter is supplied or not.  Use
# a class to give it a better repr.
class _MISSING_TYPE:
    pass
MISSING = _MISSING_TYPE()

# Since most per-field metadata will be unused, create an empty
# read-only proxy that can be shared among all fields.
_EMPTY_METADATA = types.MappingProxyType({})

# Markers for the various kinds of fields and pseudo-fields.
class _FIELD_BASE:
    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return self.name
_FIELD = _FIELD_BASE('_FIELD')
_FIELD_CLASSVAR = _FIELD_BASE('_FIELD_CLASSVAR')
_FIELD_INITVAR = _FIELD_BASE('_FIELD_INITVAR')

# The name of an attribute on the class where we store the Field
# objects.  Also used to check if a class is a Data Class.
_FIELDS = '__dataclass_fields__'

# The name of an attribute on the class that stores the parameters to
# @dataclass.
_PARAMS = '__dataclass_params__'

# The name of the function, that if it exists, is called at the end of
# __init__.
_POST_INIT_NAME = '__post_init__'

# String regex that string annotations for ClassVar or InitVar must match.
# Allows "identifier.identifier[" or "identifier[".
# https://bugs.python.org/issue33453 for details.
_MODULE_IDENTIFIER_RE = re.compile(r'^(?:\s*(\w+)\s*\.)?\s*(\w+)')

class _InitVarMeta(type):
    # Subscripting (InitVar[int]) hands back InitVar itself, so the
    # parameter is discarded and identity checks against InitVar work.
    def __getitem__(self, params):
        return self

class InitVar(metaclass=_InitVarMeta):
    pass


# Instances of Field are only ever created from within this module,
# and only from the field() function, although Field instances are
# exposed externally as (conceptually) read-only objects.
#
# name and type are filled in after the fact, not in __init__.
# They're not known at the time this class is instantiated, but it's
# convenient if they're available later.
#
# When cls._FIELDS is filled in with a list of Field objects, the name
# and type fields will have been populated.
class Field:
    """Description of a single dataclass field.

    name, type and _field_type start out as None and are filled in by
    _get_field() once the owning class is processed.
    """

    __slots__ = ('name',
                 'type',
                 'default',
                 'default_factory',
                 'repr',
                 'hash',
                 'init',
                 'compare',
                 'metadata',
                 '_field_type',  # Private: not to be used by user code.
                 )

    def __init__(self, default, default_factory, init, repr, hash, compare,
                 metadata):
        self.name = None
        self.type = None
        self.default = default
        self.default_factory = default_factory
        self.init = init
        self.repr = repr
        self.hash = hash
        self.compare = compare
        # Share one empty read-only mapping when no metadata is given;
        # otherwise expose the caller's mapping through a read-only proxy.
        self.metadata = (_EMPTY_METADATA
                         if metadata is None or len(metadata) == 0 else
                         types.MappingProxyType(metadata))
        self._field_type = None

    def __repr__(self):
        return ('Field('
                f'name={self.name!r},'
                f'type={self.type!r},'
                f'default={self.default!r},'
                f'default_factory={self.default_factory!r},'
                f'init={self.init!r},'
                f'repr={self.repr!r},'
                f'hash={self.hash!r},'
                f'compare={self.compare!r},'
                f'metadata={self.metadata!r},'
                f'_field_type={self._field_type}'
                ')')

    # This is used to support the PEP 487 __set_name__ protocol in the
    # case where we're using a field that contains a descriptor as a
    # default value.  For details on __set_name__, see
    # https://www.python.org/dev/peps/pep-0487/#implementation-details.
    #
    # Note that in _process_class, this Field object is overwritten
    # with the default value, so the end result is a descriptor that
    # had __set_name__ called on it at the right time.
    def __set_name__(self, owner, name):
        func = getattr(type(self.default), '__set_name__', None)
        if func:
            # There is a __set_name__ method on the descriptor, call
            # it.
            func(self.default, owner, name)


class _DataclassParams:
    """Record of the arguments that were passed to @dataclass."""

    __slots__ = ('init',
                 'repr',
                 'eq',
                 'order',
                 'unsafe_hash',
                 'frozen',
                 )

    def __init__(self, init, repr, eq, order, unsafe_hash, frozen):
        self.init = init
        self.repr = repr
        self.eq = eq
        self.order = order
        self.unsafe_hash = unsafe_hash
        self.frozen = frozen

    def __repr__(self):
        return ('_DataclassParams('
                f'init={self.init!r},'
                f'repr={self.repr!r},'
                f'eq={self.eq!r},'
                f'order={self.order!r},'
                f'unsafe_hash={self.unsafe_hash!r},'
                f'frozen={self.frozen!r}'
                ')')


# This function is used instead of exposing Field creation directly,
# so that a type checker can be told (via overloads) that this is a
# function whose type depends on its parameters.
def field(*, default=MISSING, default_factory=MISSING, init=True, repr=True,
          hash=None, compare=True, metadata=None):
    """Return an object to identify dataclass fields.

    default is the default value of the field.  default_factory is a
    0-argument function called to initialize a field's value.  If init
    is True, the field will be a parameter to the class's __init__()
    function.  If repr is True, the field will be included in the
    object's repr().  If hash is True, the field will be included in
    the object's hash().  If compare is True, the field will be used
    in comparison functions.  metadata, if specified, must be a
    mapping which is stored but not otherwise examined by dataclass.

    It is an error to specify both default and default_factory.
    """

    if default is not MISSING and default_factory is not MISSING:
        raise ValueError('cannot specify both default and default_factory')
    return Field(default, default_factory, init, repr, hash, compare,
                 metadata)


def _tuple_str(obj_name, fields):
    # Return a string representing each field of obj_name as a tuple
    # member.  So, if fields is ['x', 'y'] and obj_name is "self",
    # return "(self.x,self.y)".

    # Special case for the 0-tuple.
    if not fields:
        return '()'
    # Note the trailing comma, needed if this turns out to be a 1-tuple.
    return f'({",".join([f"{obj_name}.{f.name}" for f in fields])},)'


def _create_fn(name, args, body, *, globals=None, locals=None,
               return_type=MISSING):
    # Build the source of a function called `name` from the parameter
    # strings in `args` and the statement strings in `body`, exec() it,
    # and return the resulting function object.
    #
    # Note that we mutate locals when exec() is called.  Caller
    # beware!  The only callers are internal to this module, so no
    # worries about external callers.
    if locals is None:
        locals = {}
    return_annotation = ''
    if return_type is not MISSING:
        locals['_return_type'] = return_type
        return_annotation = '->_return_type'
    args = ','.join(args)
    body = '\n'.join(f' {b}' for b in body)

    # Compute the text of the entire function.
    txt = f'def {name}({args}){return_annotation}:\n{body}'

    exec(txt, globals, locals)
    return locals[name]


def _field_assign(frozen, name, value, self_name):
    # If we're a frozen class, then assign to our fields in __init__
    # via object.__setattr__.  Otherwise, just use a simple
    # assignment.
    #
    # self_name is what "self" is called in this function: don't
    # hard-code "self", since that might be a field name.
    if frozen:
        return f'object.__setattr__({self_name},{name!r},{value})'
    return f'{self_name}.{name}={value}'


def _field_init(f, frozen, globals, self_name):
    # Return the text of the line in the body of __init__ that will
    # initialize this field.

    default_name = f'_dflt_{f.name}'
    if f.default_factory is not MISSING:
        if f.init:
            # This field has a default factory.  If a parameter is
            # given, use it.  If not, call the factory.
            globals[default_name] = f.default_factory
            value = (f'{default_name}() '
                     f'if {f.name} is _HAS_DEFAULT_FACTORY '
                     f'else {f.name}')
        else:
            # This is a field that's not in the __init__ params, but
            # has a default factory function.  It needs to be
            # initialized here by calling the factory function,
            # because there's no other way to initialize it.

            # For a field initialized with a default=defaultvalue, the
            # class dict just has the default value
            # (cls.fieldname=defaultvalue).  But that won't work for a
            # default factory, the factory must be called in __init__
            # and we must assign that to self.fieldname.  We can't
            # fall back to the class dict's value, both because it's
            # not set, and because it might be different per-class
            # (which, after all, is why we have a factory function!).

            globals[default_name] = f.default_factory
            value = f'{default_name}()'
    else:
        # No default factory.
        if f.init:
            if f.default is MISSING:
                # There's no default, just do an assignment.
                value = f.name
            elif f.default is not MISSING:
                globals[default_name] = f.default
                value = f.name
        else:
            # This field does not need initialization.  Signify that
            # to the caller by returning None.
            return None

    # Only test this now, so that we can create variables for the
    # default.  However, return None to signify that we're not going
    # to actually do the assignment statement for InitVars.
    if f._field_type == _FIELD_INITVAR:
        return None

    # Now, actually generate the field assignment.
    return _field_assign(frozen, f.name, value, self_name)


def _init_param(f):
    # Return the __init__ parameter string for this field.  For
    # example, the equivalent of 'x:int=3' (except instead of 'int',
    # reference a variable set to int, and instead of '3', reference a
    # variable set to 3).
    if f.default is MISSING and f.default_factory is MISSING:
        # There's no default, and no default_factory, just output the
        # variable name and type.
        default = ''
    elif f.default is not MISSING:
        # There's a default, this will be the name that's used to look
        # it up.
        default = f'=_dflt_{f.name}'
    elif f.default_factory is not MISSING:
        # There's a factory function.  Set a marker.
        default = '=_HAS_DEFAULT_FACTORY'
    return f'{f.name}:_type_{f.name}{default}'


def _init_fn(fields, frozen, has_post_init, self_name):
    # Build and return the __init__ function for a dataclass.
    #
    # fields contains both real fields and InitVar pseudo-fields.

    # Make sure we don't have fields without defaults following fields
    # with defaults.  This actually would be caught when exec-ing the
    # function source code, but catching it here gives a better error
    # message, and future-proofs us in case we build up the function
    # using ast.
    seen_default = False
    for f in fields:
        # Only consider fields in the __init__ call.
        if f.init:
            if not (f.default is MISSING and f.default_factory is MISSING):
                seen_default = True
            elif seen_default:
                raise TypeError(f'non-default argument {f.name!r} '
                                'follows default argument')

    globals = {'MISSING': MISSING,
               '_HAS_DEFAULT_FACTORY': _HAS_DEFAULT_FACTORY}

    body_lines = []
    for f in fields:
        line = _field_init(f, frozen, globals, self_name)
        # line is None means that this field doesn't require
        # initialization (it's a pseudo-field).  Just skip it.
        if line:
            body_lines.append(line)

    # Does this class have a post-init function?
    if has_post_init:
        params_str = ','.join(f.name for f in fields
                              if f._field_type is _FIELD_INITVAR)
        body_lines.append(f'{self_name}.{_POST_INIT_NAME}({params_str})')

    # If no body lines, use 'pass'.
    if not body_lines:
        body_lines = ['pass']

    locals = {f'_type_{f.name}': f.type for f in fields}
    return _create_fn('__init__',
                      [self_name] + [_init_param(f) for f in fields if f.init],
                      body_lines,
                      locals=locals,
                      globals=globals,
                      return_type=None)


def _repr_fn(fields):
    # Build a __repr__ that renders "QualName(name=value, ...)" for the
    # given fields.
    return _create_fn('__repr__',
                      ('self',),
                      ['return self.__class__.__qualname__ + f"(' +
                       ', '.join([f"{f.name}={{self.{f.name}!r}}"
                                  for f in fields]) +
                       ')"'])


def _frozen_get_del_attr(cls, fields):
    # Return a (__setattr__, __delattr__) pair that raises
    # FrozenInstanceError for instances of cls itself or for any of
    # the named fields, and otherwise defers to the next class in the
    # MRO.
    #
    # XXX: globals is modified on the first call to _create_fn, then
    # the modified version is used in the second call.  Is this okay?
    globals = {'cls': cls,
               'FrozenInstanceError': FrozenInstanceError}
    if fields:
        fields_str = '(' + ','.join(repr(f.name) for f in fields) + ',)'
    else:
        # Special case for the zero-length tuple.
        fields_str = '()'
    return (_create_fn('__setattr__',
                       ('self', 'name', 'value'),
                       (f'if type(self) is cls or name in {fields_str}:',
                        ' raise FrozenInstanceError(f"cannot assign to field {name!r}")',
                        'super(cls, self).__setattr__(name, value)'),
                       globals=globals),
            _create_fn('__delattr__',
                       ('self', 'name'),
                       (f'if type(self) is cls or name in {fields_str}:',
                        ' raise FrozenInstanceError(f"cannot delete field {name!r}")',
                        'super(cls, self).__delattr__(name)'),
                       globals=globals),
            )


def _cmp_fn(name, op, self_tuple, other_tuple):
    # Create a comparison function.  If the fields in the object are
    # named 'x' and 'y', then self_tuple is the string
    # '(self.x,self.y)' and other_tuple is the string
    # '(other.x,other.y)'.

    return _create_fn(name,
                      ('self', 'other'),
                      ['if other.__class__ is self.__class__:',
                       f' return {self_tuple}{op}{other_tuple}',
                       'return NotImplemented'])


def _hash_fn(fields):
    # Build a __hash__ that hashes the tuple of the given fields.
    self_tuple = _tuple_str('self', fields)
    return _create_fn('__hash__',
                      ('self',),
                      [f'return hash({self_tuple})'])


def _is_classvar(a_type, typing):
    # Return True if a_type is typing.ClassVar or a subscripted
    # ClassVar[...].
    #
    # Python 3.6 represents both forms as instances of the private
    # class typing._ClassVar.  That class does not exist on newer
    # versions, where the bare form is typing.ClassVar itself and the
    # subscripted form is an alias whose __origin__ is typing.ClassVar.
    # Handle both so that this module does not raise AttributeError
    # when it is imported on a newer interpreter.
    legacy_classvar = getattr(typing, '_ClassVar', None)
    if legacy_classvar is not None:
        return type(a_type) is legacy_classvar
    return (a_type is typing.ClassVar
            or getattr(a_type, '__origin__', None) is typing.ClassVar)


def _is_initvar(a_type, dataclasses):
    # The module we're checking against is the module we're
    # currently in (dataclasses.py).
    return a_type is dataclasses.InitVar


def _is_type(annotation, cls, a_module, a_type, is_type_predicate):
    # Given a type annotation string, does it refer to a_type in
    # a_module?  For example, when checking that annotation denotes a
    # ClassVar, then a_module is typing, and a_type is
    # typing.ClassVar.

    # It's possible to look up a_module given a_type, but it involves
    # looking in sys.modules (again!), and seems like a waste since
    # the caller already knows a_module.

    # - annotation is a string type annotation
    # - cls is the class that this annotation was found in
    # - a_module is the module we want to match
    # - a_type is the type in that module we want to match
    # - is_type_predicate is a function called with (obj, a_module)
    #   that determines if obj is of the desired type.

    # Since this test does not do a local namespace lookup (and
    # instead only a module (global) lookup), there are some things it
    # gets wrong.

    # With string annotations, cv0 will be detected as a ClassVar:
    #   CV = ClassVar
    #   @dataclass
    #   class C0:
    #     cv0: CV

    # But in this example cv1 will not be detected as a ClassVar:
    #   @dataclass
    #   class C1:
    #     CV = ClassVar
    #     cv1: CV

    # In C1, the code in this function (_is_type) will look up "CV" in
    # the module and not find it, so it will not consider cv1 as a
    # ClassVar.  This is a fairly obscure corner case, and the best
    # way to fix it would be to eval() the string "CV" with the
    # correct global and local namespaces.  However that would involve
    # a eval() penalty for every single field of every dataclass
    # that's defined.  It was judged not worth it.

    match = _MODULE_IDENTIFIER_RE.match(annotation)
    if match:
        ns = None
        module_name = match.group(1)
        # The class's module may be absent from sys.modules (for
        # example a class built under a made-up __module__); treat
        # that as "no match" instead of failing on None.__dict__.
        module = sys.modules.get(cls.__module__)
        if not module_name:
            # No module name, assume the class's module did
            # "from dataclasses import InitVar".
            if module is not None:
                ns = module.__dict__
        else:
            # Look up module_name in the class's module.
            if module and module.__dict__.get(module_name) is a_module:
                type_module = sys.modules.get(a_type.__module__)
                if type_module is not None:
                    ns = type_module.__dict__
        if ns and is_type_predicate(ns.get(match.group(2)), a_module):
            return True
    return False


def _get_field(cls, a_name, a_type):
    # Return a Field object for this field name and type.  ClassVars
    # and InitVars are also returned, but marked as such (see
    # f._field_type).

    # If the default value isn't derived from Field, then it's only a
    # normal default value.  Convert it to a Field().
    default = getattr(cls, a_name, MISSING)
    if isinstance(default, Field):
        f = default
    else:
        if isinstance(default, types.MemberDescriptorType):
            # This is a field in __slots__, so it has no default value.
            default = MISSING
        f = field(default=default)

    # Only at this point do we know the name and the type.  Set them.
    f.name = a_name
    f.type = a_type

    # Assume it's a normal field until proven otherwise.  We're next
    # going to decide if it's a ClassVar or InitVar, everything else
    # is just a normal field.
    f._field_type = _FIELD

    # In addition to checking for actual types here, also check for
    # string annotations.  get_type_hints() won't always work for us
    # (see https://github.com/python/typing/issues/508 for example),
    # plus it's expensive and would require an eval for every string
    # annotation.  So, make a best effort to see if this is a ClassVar
    # or InitVar using regex's and checking that the thing referenced
    # is actually of the correct type.

    # For the complete discussion, see https://bugs.python.org/issue33453

    # If typing has not been imported, then it's impossible for any
    # annotation to be a ClassVar.  So, only look for ClassVar if
    # typing has been imported by any module (not necessarily cls's
    # module).
    typing = sys.modules.get('typing')
    if typing:
        if (_is_classvar(a_type, typing)
            or (isinstance(f.type, str)
                and _is_type(f.type, cls, typing, typing.ClassVar,
                             _is_classvar))):
            f._field_type = _FIELD_CLASSVAR

    # If the type is InitVar, or if it's a matching string annotation,
    # then it's an InitVar.
    if f._field_type is _FIELD:
        # The module we're checking against is the module we're
        # currently in (dataclasses.py).
        dataclasses = sys.modules[__name__]
        if (_is_initvar(a_type, dataclasses)
            or (isinstance(f.type, str)
                and _is_type(f.type, cls, dataclasses, dataclasses.InitVar,
                             _is_initvar))):
            f._field_type = _FIELD_INITVAR

    # Validations for individual fields.  This is delayed until now,
    # instead of in the Field() constructor, since only here do we
    # know the field name, which allows for better error reporting.

    # Special restrictions for ClassVar and InitVar.
    if f._field_type in (_FIELD_CLASSVAR, _FIELD_INITVAR):
        if f.default_factory is not MISSING:
            raise TypeError(f'field {f.name} cannot have a '
                            'default factory')
        # Should I check for other field settings?  default_factory
        # seems the most serious to check for.  Maybe add others.  For
        # example, how about init=False (or really, any non-default
        # init value)?  It makes no sense for ClassVar and InitVar to
        # specify init=.

    # For real fields, disallow mutable defaults for known types.
    if f._field_type is _FIELD and isinstance(f.default, (list, dict, set)):
        raise ValueError(f'mutable default {type(f.default)} for field '
                         f'{f.name} is not allowed: use default_factory')

    return f


def _set_new_attribute(cls, name, value):
    # Never overwrites an existing attribute.  Returns True if the
    # attribute already exists.
    if name in cls.__dict__:
        return True
    setattr(cls, name, value)
    return False


# Decide if/how we're going to create a hash function.  Key is
# (unsafe_hash, eq, frozen, does-hash-exist).  Value is the action to
# take.
The common case is to do nothing, so instead of providing a +# function that is a no-op, use None to signify that. + +def _hash_set_none(cls, fields): + return None + +def _hash_add(cls, fields): + flds = [f for f in fields if (f.compare if f.hash is None else f.hash)] + return _hash_fn(flds) + +def _hash_exception(cls, fields): + # Raise an exception. + raise TypeError(f'Cannot overwrite attribute __hash__ ' + f'in class {cls.__name__}') + +# +# +-------------------------------------- unsafe_hash? +# | +------------------------------- eq? +# | | +------------------------ frozen? +# | | | +---------------- has-explicit-hash? +# | | | | +# | | | | +------- action +# | | | | | +# v v v v v +_hash_action = {(False, False, False, False): None, + (False, False, False, True ): None, + (False, False, True, False): None, + (False, False, True, True ): None, + (False, True, False, False): _hash_set_none, + (False, True, False, True ): None, + (False, True, True, False): _hash_add, + (False, True, True, True ): None, + (True, False, False, False): _hash_add, + (True, False, False, True ): _hash_exception, + (True, False, True, False): _hash_add, + (True, False, True, True ): _hash_exception, + (True, True, False, False): _hash_add, + (True, True, False, True ): _hash_exception, + (True, True, True, False): _hash_add, + (True, True, True, True ): _hash_exception, + } +# See https://bugs.python.org/issue32929#msg312829 for an if-statement +# version of this table. + + +def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen): + # Now that dicts retain insertion order, there's no reason to use + # an ordered dict. I am leveraging that ordering here, because + # derived class fields overwrite base class fields, but the order + # is defined by the base class, which is found first. + fields = {} + + setattr(cls, _PARAMS, _DataclassParams(init, repr, eq, order, + unsafe_hash, frozen)) + + # Find our base classes in reverse MRO order, and exclude + # ourselves. 
In reversed order so that more derived classes + # override earlier field definitions in base classes. As long as + # we're iterating over them, see if any are frozen. + any_frozen_base = False + has_dataclass_bases = False + for b in cls.__mro__[-1:0:-1]: + # Only process classes that have been processed by our + # decorator. That is, they have a _FIELDS attribute. + base_fields = getattr(b, _FIELDS, None) + if base_fields: + has_dataclass_bases = True + for f in base_fields.values(): + fields[f.name] = f + if getattr(b, _PARAMS).frozen: + any_frozen_base = True + + # Annotations that are defined in this class (not in base + # classes). If __annotations__ isn't present, then this class + # adds no new annotations. We use this to compute fields that are + # added by this class. + # + # Fields are found from cls_annotations, which is guaranteed to be + # ordered. Default values are from class attributes, if a field + # has a default. If the default value is a Field(), then it + # contains additional info beyond (and possibly including) the + # actual default value. Pseudo-fields ClassVars and InitVars are + # included, despite the fact that they're not real fields. That's + # dealt with later. + cls_annotations = cls.__dict__.get('__annotations__', {}) + + # Now find fields in our class. While doing so, validate some + # things, and set the default values (as class attributes) where + # we can. + cls_fields = [_get_field(cls, name, type) + for name, type in cls_annotations.items()] + for f in cls_fields: + fields[f.name] = f + + # If the class attribute (which is the default value for this + # field) exists and is of type 'Field', replace it with the + # real default. This is so that normal class introspection + # sees a real default value, not a Field. + if isinstance(getattr(cls, f.name, None), Field): + if f.default is MISSING: + # If there's no default, delete the class attribute. 
+ # This happens if we specify field(repr=False), for + # example (that is, we specified a field object, but + # no default value). Also if we're using a default + # factory. The class attribute should not be set at + # all in the post-processed class. + delattr(cls, f.name) + else: + setattr(cls, f.name, f.default) + + # Do we have any Field members that don't also have annotations? + for name, value in cls.__dict__.items(): + if isinstance(value, Field) and not name in cls_annotations: + raise TypeError(f'{name!r} is a field but has no type annotation') + + # Check rules that apply if we are derived from any dataclasses. + if has_dataclass_bases: + # Raise an exception if any of our bases are frozen, but we're not. + if any_frozen_base and not frozen: + raise TypeError('cannot inherit non-frozen dataclass from a ' + 'frozen one') + + # Raise an exception if we're frozen, but none of our bases are. + if not any_frozen_base and frozen: + raise TypeError('cannot inherit frozen dataclass from a ' + 'non-frozen one') + + # Remember all of the fields on our class (including bases). This + # also marks this class as being a dataclass. + setattr(cls, _FIELDS, fields) + + # Was this class defined with an explicit __hash__? Note that if + # __eq__ is defined in this class, then python will automatically + # set __hash__ to None. This is a heuristic, as it's possible + # that such a __hash__ == None was not auto-generated, but it + # close enough. + class_hash = cls.__dict__.get('__hash__', MISSING) + has_explicit_hash = not (class_hash is MISSING or + (class_hash is None and '__eq__' in cls.__dict__)) + + # If we're generating ordering methods, we must be generating the + # eq methods. + if order and not eq: + raise ValueError('eq must be true if order is true') + + if init: + # Does this class have a post-init function? + has_post_init = hasattr(cls, _POST_INIT_NAME) + + # Include InitVars and regular fields (so, not ClassVars). 
+ flds = [f for f in fields.values() + if f._field_type in (_FIELD, _FIELD_INITVAR)] + _set_new_attribute(cls, '__init__', + _init_fn(flds, + frozen, + has_post_init, + # The name to use for the "self" + # param in __init__. Use "self" + # if possible. + '__dataclass_self__' if 'self' in fields + else 'self', + )) + + # Get the fields as a list, and include only real fields. This is + # used in all of the following methods. + field_list = [f for f in fields.values() if f._field_type is _FIELD] + + if repr: + flds = [f for f in field_list if f.repr] + _set_new_attribute(cls, '__repr__', _repr_fn(flds)) + + if eq: + # Create _eq__ method. There's no need for a __ne__ method, + # since python will call __eq__ and negate it. + flds = [f for f in field_list if f.compare] + self_tuple = _tuple_str('self', flds) + other_tuple = _tuple_str('other', flds) + _set_new_attribute(cls, '__eq__', + _cmp_fn('__eq__', '==', + self_tuple, other_tuple)) + + if order: + # Create and set the ordering methods. + flds = [f for f in field_list if f.compare] + self_tuple = _tuple_str('self', flds) + other_tuple = _tuple_str('other', flds) + for name, op in [('__lt__', '<'), + ('__le__', '<='), + ('__gt__', '>'), + ('__ge__', '>='), + ]: + if _set_new_attribute(cls, name, + _cmp_fn(name, op, self_tuple, other_tuple)): + raise TypeError(f'Cannot overwrite attribute {name} ' + f'in class {cls.__name__}. Consider using ' + 'functools.total_ordering') + + if frozen: + for fn in _frozen_get_del_attr(cls, field_list): + if _set_new_attribute(cls, fn.__name__, fn): + raise TypeError(f'Cannot overwrite attribute {fn.__name__} ' + f'in class {cls.__name__}') + + # Decide if/how we're going to create a hash function. + hash_action = _hash_action[bool(unsafe_hash), + bool(eq), + bool(frozen), + has_explicit_hash] + if hash_action: + # No need to call _set_new_attribute here, since by the time + # we're here the overwriting is unconditional. 
+ cls.__hash__ = hash_action(cls, field_list) + + if not getattr(cls, '__doc__'): + # Create a class doc-string. + cls.__doc__ = (cls.__name__ + + str(inspect.signature(cls)).replace(' -> None', '')) + + return cls + + +# _cls should never be specified by keyword, so start it with an +# underscore. The presence of _cls is used to detect if this +# decorator is being called with parameters or not. +def dataclass(_cls=None, *, init=True, repr=True, eq=True, order=False, + unsafe_hash=False, frozen=False): + """Returns the same class as was passed in, with dunder methods + added based on the fields defined in the class. + + Examines PEP 526 __annotations__ to determine fields. + + If init is true, an __init__() method is added to the class. If + repr is true, a __repr__() method is added. If order is true, rich + comparison dunder methods are added. If unsafe_hash is true, a + __hash__() method function is added. If frozen is true, fields may + not be assigned to after instance creation. + """ + + def wrap(cls): + return _process_class(cls, init, repr, eq, order, unsafe_hash, frozen) + + # See if we're being called as @dataclass or @dataclass(). + if _cls is None: + # We're called with parens. + return wrap + + # We're called as @dataclass without parens. + return wrap(_cls) + + +def fields(class_or_instance): + """Return a tuple describing the fields of this dataclass. + + Accepts a dataclass or an instance of one. Tuple elements are of + type Field. + """ + + # Might it be worth caching this, per class? + try: + fields = getattr(class_or_instance, _FIELDS) + except AttributeError: + raise TypeError('must be called with a dataclass type or instance') + + # Exclude pseudo-fields. Note that fields is sorted by insertion + # order, so the order of the tuple is as the fields were defined. 
+ return tuple(f for f in fields.values() if f._field_type is _FIELD) + + +def _is_dataclass_instance(obj): + """Returns True if obj is an instance of a dataclass.""" + return not isinstance(obj, type) and hasattr(obj, _FIELDS) + + +def is_dataclass(obj): + """Returns True if obj is a dataclass or an instance of a + dataclass.""" + return hasattr(obj, _FIELDS) + + +def asdict(obj, *, dict_factory=dict): + """Return the fields of a dataclass instance as a new dictionary mapping + field names to field values. + + Example usage: + + @dataclass + class C: + x: int + y: int + + c = C(1, 2) + assert asdict(c) == {'x': 1, 'y': 2} + + If given, 'dict_factory' will be used instead of built-in dict. + The function applies recursively to field values that are + dataclass instances. This will also look into built-in containers: + tuples, lists, and dicts. + """ + if not _is_dataclass_instance(obj): + raise TypeError("asdict() should be called on dataclass instances") + return _asdict_inner(obj, dict_factory) + + +def _asdict_inner(obj, dict_factory): + if _is_dataclass_instance(obj): + result = [] + for f in fields(obj): + value = _asdict_inner(getattr(obj, f.name), dict_factory) + result.append((f.name, value)) + return dict_factory(result) + elif isinstance(obj, (list, tuple)): + return type(obj)(_asdict_inner(v, dict_factory) for v in obj) + elif isinstance(obj, dict): + return type(obj)((_asdict_inner(k, dict_factory), _asdict_inner(v, dict_factory)) + for k, v in obj.items()) + else: + return copy.deepcopy(obj) + + +def astuple(obj, *, tuple_factory=tuple): + """Return the fields of a dataclass instance as a new tuple of field values. + + Example usage:: + + @dataclass + class C: + x: int + y: int + + c = C(1, 2) + assert astuple(c) == (1, 2) + + If given, 'tuple_factory' will be used instead of built-in tuple. + The function applies recursively to field values that are + dataclass instances. This will also look into built-in containers: + tuples, lists, and dicts. 
+ """ + + if not _is_dataclass_instance(obj): + raise TypeError("astuple() should be called on dataclass instances") + return _astuple_inner(obj, tuple_factory) + + +def _astuple_inner(obj, tuple_factory): + if _is_dataclass_instance(obj): + result = [] + for f in fields(obj): + value = _astuple_inner(getattr(obj, f.name), tuple_factory) + result.append(value) + return tuple_factory(result) + elif isinstance(obj, (list, tuple)): + return type(obj)(_astuple_inner(v, tuple_factory) for v in obj) + elif isinstance(obj, dict): + return type(obj)((_astuple_inner(k, tuple_factory), _astuple_inner(v, tuple_factory)) + for k, v in obj.items()) + else: + return copy.deepcopy(obj) + + +def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True, + repr=True, eq=True, order=False, unsafe_hash=False, + frozen=False): + """Return a new dynamically created dataclass. + + The dataclass name will be 'cls_name'. 'fields' is an iterable + of either (name), (name, type) or (name, type, Field) objects. If type is + omitted, use the string 'typing.Any'. Field objects are created by + the equivalent of calling 'field(name, type [, Field-info])'. + + C = make_dataclass('C', ['x', ('y', int), ('z', int, field(init=False))], bases=(Base,)) + + is equivalent to: + + @dataclass + class C(Base): + x: 'typing.Any' + y: int + z: int = field(init=False) + + For the bases and namespace parameters, see the builtin type() function. + + The parameters init, repr, eq, order, unsafe_hash, and frozen are passed to + dataclass(). + """ + + if namespace is None: + namespace = {} + else: + # Copy namespace since we're going to mutate it. + namespace = namespace.copy() + + # While we're looking through the field names, validate that they + # are identifiers, are not keywords, and not duplicates. 
+ seen = set() + anns = {} + for item in fields: + if isinstance(item, str): + name = item + tp = 'typing.Any' + elif len(item) == 2: + name, tp, = item + elif len(item) == 3: + name, tp, spec = item + namespace[name] = spec + else: + raise TypeError(f'Invalid field: {item!r}') + + if not isinstance(name, str) or not name.isidentifier(): + raise TypeError(f'Field names must be valid identifers: {name!r}') + if keyword.iskeyword(name): + raise TypeError(f'Field names must not be keywords: {name!r}') + if name in seen: + raise TypeError(f'Field name duplicated: {name!r}') + + seen.add(name) + anns[name] = tp + + namespace['__annotations__'] = anns + # We use `types.new_class()` instead of simply `type()` to allow dynamic creation + # of generic dataclassses. + cls = types.new_class(cls_name, bases, {}, lambda ns: ns.update(namespace)) + return dataclass(cls, init=init, repr=repr, eq=eq, order=order, + unsafe_hash=unsafe_hash, frozen=frozen) + + +def replace(obj, **changes): + """Return a new object replacing specified fields with new values. + + This is especially useful for frozen classes. Example usage: + + @dataclass(frozen=True) + class C: + x: int + y: int + + c = C(1, 2) + c1 = replace(c, x=3) + assert c1.x == 3 and c1.y == 2 + """ + + # We're going to mutate 'changes', but that's okay because it's a + # new dict, even if called with 'replace(obj, **my_changes)'. + + if not _is_dataclass_instance(obj): + raise TypeError("replace() should be called on dataclass instances") + + # It's an error to have init=False fields in 'changes'. + # If a field is not in 'changes', read its value from the provided obj. + + for f in getattr(obj, _FIELDS).values(): + if not f.init: + # Error if this field is specified in changes. 
+ if f.name in changes: + raise ValueError(f'field {f.name} is declared with ' + 'init=False, it cannot be specified with ' + 'replace()') + continue + + if f.name not in changes: + changes[f.name] = getattr(obj, f.name) + + # Create the new object, which calls __init__() and + # __post_init__() (if defined), using all of the init fields we've + # added and/or left in 'changes'. If there are values supplied in + # changes that aren't fields, this will correctly raise a + # TypeError. + return obj.__class__(**changes) diff --git a/graphene/types/base.py b/graphene/types/base.py index 79907b4d9..9ab9badb5 100644 --- a/graphene/types/base.py +++ b/graphene/types/base.py @@ -1,6 +1,6 @@ from typing import Type -from ..utils.subclass_with_meta import SubclassWithMeta +from ..utils.subclass_with_meta import SubclassWithMeta, SubclassWithMeta_Meta from ..utils.trim_docstring import trim_docstring @@ -26,6 +26,8 @@ def __repr__(self): return "<{} name={}>".format(self.__class__.__name__, repr(self.name)) +BaseTypeMeta = SubclassWithMeta_Meta + class BaseType(SubclassWithMeta): @classmethod def create_type(cls, class_name, **options): diff --git a/graphene/types/objecttype.py b/graphene/types/objecttype.py index cca23d111..c2f5fbd02 100644 --- a/graphene/types/objecttype.py +++ b/graphene/types/objecttype.py @@ -1,8 +1,10 @@ -from .base import BaseOptions, BaseType +from .base import BaseOptions, BaseType, BaseTypeMeta from .field import Field from .interface import Interface from .utils import yank_fields_from_attrs +from ..pyutils.dataclasses import make_dataclass, field + # For static type checking with Mypy MYPY = False if MYPY: @@ -14,7 +16,24 @@ class ObjectTypeOptions(BaseOptions): interfaces = () # type: Iterable[Type[Interface]] -class ObjectType(BaseType): +class ObjectTypeMeta(BaseTypeMeta): + def __new__(cls, name, bases, namespace): + # We create this type, to then overload it with the dataclass attrs + class InterObjectType: + pass + + base_cls = 
super().__new__(cls, name, (InterObjectType, ) + bases, namespace) + if base_cls._meta: + fields = [ + (key, 'typing.Any', field(default=field_value.default_value if isinstance(field_value, Field) else None)) + for key, field_value in base_cls._meta.fields.items() + ] + dataclass = make_dataclass(name, fields, bases=()) + InterObjectType.__init__ = dataclass.__init__ + return base_cls + + +class ObjectType(BaseType, metaclass=ObjectTypeMeta): """ Object Type Definition diff --git a/graphene/types/tests/test_objecttype.py b/graphene/types/tests/test_objecttype.py index 25025e4d1..f2f3fcdf3 100644 --- a/graphene/types/tests/test_objecttype.py +++ b/graphene/types/tests/test_objecttype.py @@ -83,6 +83,10 @@ class MyObjectType(ObjectType): def test_generate_objecttype_with_private_attributes(): class MyObjectType(ObjectType): + def __init__(self, _private_state=None, **kwargs): + self._private_state = _private_state + super().__init__(**kwargs) + _private_state = None assert "_private_state" not in MyObjectType._meta.fields @@ -173,17 +177,17 @@ def test_objecttype_as_container_all_kwargs(): def test_objecttype_as_container_extra_args(): - with raises(IndexError) as excinfo: + with raises(TypeError) as excinfo: Container("1", "2", "3") - assert "Number of args exceeds number of fields" == str(excinfo.value) + assert "__init__() takes from 1 to 3 positional arguments but 4 were given" == str(excinfo.value) def test_objecttype_as_container_invalid_kwargs(): with raises(TypeError) as excinfo: Container(unexisting_field="3") - assert "'unexisting_field' is an invalid keyword argument for Container" == str( + assert "__init__() got an unexpected keyword argument 'unexisting_field'" == str( excinfo.value ) From 2b3356fd989edcb8e56af08f5a75f72b88ce571a Mon Sep 17 00:00:00 2001 From: Syrus Date: Sat, 14 Mar 2020 13:48:26 -0700 Subject: [PATCH 2/8] Fixed Python 3.8 integration --- graphene/types/objecttype.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git 
a/graphene/types/objecttype.py b/graphene/types/objecttype.py index c2f5fbd02..bf2d0b36f 100644 --- a/graphene/types/objecttype.py +++ b/graphene/types/objecttype.py @@ -3,7 +3,10 @@ from .interface import Interface from .utils import yank_fields_from_attrs -from ..pyutils.dataclasses import make_dataclass, field +try: + from dataclasses import make_dataclass, field +except ImportError: + from ..pyutils.dataclasses import make_dataclass, field # For static type checking with Mypy MYPY = False From 34560ca5b32fcd611ad971ba58ec47c4f316dabb Mon Sep 17 00:00:00 2001 From: Syrus Date: Sat, 14 Mar 2020 14:06:34 -0700 Subject: [PATCH 3/8] Added repr and eq methods to ObjectType containers --- graphene/types/objecttype.py | 2 ++ graphene/types/tests/test_objecttype.py | 11 +++++++++++ 2 files changed, 13 insertions(+) diff --git a/graphene/types/objecttype.py b/graphene/types/objecttype.py index bf2d0b36f..7c2a3f2d9 100644 --- a/graphene/types/objecttype.py +++ b/graphene/types/objecttype.py @@ -33,6 +33,8 @@ class InterObjectType: ] dataclass = make_dataclass(name, fields, bases=()) InterObjectType.__init__ = dataclass.__init__ + InterObjectType.__eq__ = dataclass.__eq__ + InterObjectType.__repr__ = dataclass.__repr__ return base_cls diff --git a/graphene/types/tests/test_objecttype.py b/graphene/types/tests/test_objecttype.py index f2f3fcdf3..7346fd0e6 100644 --- a/graphene/types/tests/test_objecttype.py +++ b/graphene/types/tests/test_objecttype.py @@ -158,6 +158,17 @@ def test_objecttype_as_container_only_args(): assert container.field1 == "1" assert container.field2 == "2" +def test_objecttype_repr(): + container = Container("1", "2") + assert repr(container) == "Container(field1='1', field2='2')" + +def test_objecttype_eq(): + container1 = Container("1", "2") + container2 = Container("1", "2") + container3 = Container("2", "3") + assert container1 == container1 + assert container1 == container2 + assert container2 != container3 def 
test_objecttype_as_container_args_kwargs(): container = Container("1", field2="2") From 05d08df95d44c969bf7858f5f62f1e9a8e1578d4 Mon Sep 17 00:00:00 2001 From: Syrus Date: Sat, 14 Mar 2020 14:07:26 -0700 Subject: [PATCH 4/8] Reformatted code --- graphene/types/base.py | 1 + graphene/types/objecttype.py | 12 ++++++++++-- graphene/types/tests/test_objecttype.py | 7 ++++++- 3 files changed, 17 insertions(+), 3 deletions(-) diff --git a/graphene/types/base.py b/graphene/types/base.py index 9ab9badb5..64041b7e4 100644 --- a/graphene/types/base.py +++ b/graphene/types/base.py @@ -28,6 +28,7 @@ def __repr__(self): BaseTypeMeta = SubclassWithMeta_Meta + class BaseType(SubclassWithMeta): @classmethod def create_type(cls, class_name, **options): diff --git a/graphene/types/objecttype.py b/graphene/types/objecttype.py index 7c2a3f2d9..0d0999e30 100644 --- a/graphene/types/objecttype.py +++ b/graphene/types/objecttype.py @@ -25,10 +25,18 @@ def __new__(cls, name, bases, namespace): class InterObjectType: pass - base_cls = super().__new__(cls, name, (InterObjectType, ) + bases, namespace) + base_cls = super().__new__(cls, name, (InterObjectType,) + bases, namespace) if base_cls._meta: fields = [ - (key, 'typing.Any', field(default=field_value.default_value if isinstance(field_value, Field) else None)) + ( + key, + "typing.Any", + field( + default=field_value.default_value + if isinstance(field_value, Field) + else None + ), + ) for key, field_value in base_cls._meta.fields.items() ] dataclass = make_dataclass(name, fields, bases=()) diff --git a/graphene/types/tests/test_objecttype.py b/graphene/types/tests/test_objecttype.py index 7346fd0e6..85a46867e 100644 --- a/graphene/types/tests/test_objecttype.py +++ b/graphene/types/tests/test_objecttype.py @@ -158,10 +158,12 @@ def test_objecttype_as_container_only_args(): assert container.field1 == "1" assert container.field2 == "2" + def test_objecttype_repr(): container = Container("1", "2") assert repr(container) == 
"Container(field1='1', field2='2')" + def test_objecttype_eq(): container1 = Container("1", "2") container2 = Container("1", "2") @@ -170,6 +172,7 @@ def test_objecttype_eq(): assert container1 == container2 assert container2 != container3 + def test_objecttype_as_container_args_kwargs(): container = Container("1", field2="2") assert container.field1 == "1" @@ -191,7 +194,9 @@ def test_objecttype_as_container_extra_args(): with raises(TypeError) as excinfo: Container("1", "2", "3") - assert "__init__() takes from 1 to 3 positional arguments but 4 were given" == str(excinfo.value) + assert "__init__() takes from 1 to 3 positional arguments but 4 were given" == str( + excinfo.value + ) def test_objecttype_as_container_invalid_kwargs(): From cfff16977cb6c3131a8aaa2aea5564141d4744ca Mon Sep 17 00:00:00 2001 From: Syrus Date: Sat, 14 Mar 2020 14:09:42 -0700 Subject: [PATCH 5/8] Fixed mypy issue --- graphene/types/objecttype.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphene/types/objecttype.py b/graphene/types/objecttype.py index 0d0999e30..de11d77cd 100644 --- a/graphene/types/objecttype.py +++ b/graphene/types/objecttype.py @@ -6,7 +6,7 @@ try: from dataclasses import make_dataclass, field except ImportError: - from ..pyutils.dataclasses import make_dataclass, field + from ..pyutils.dataclasses import make_dataclass, field # type: ignore # For static type checking with Mypy MYPY = False From 015bfa3452e75c346beb9033a9546db9c44427ef Mon Sep 17 00:00:00 2001 From: Syrus Date: Sat, 14 Mar 2020 14:22:47 -0700 Subject: [PATCH 6/8] Removed unused __init__ for ObjectType containers --- graphene/types/objecttype.py | 42 ------------------------------------ 1 file changed, 42 deletions(-) diff --git a/graphene/types/objecttype.py b/graphene/types/objecttype.py index de11d77cd..1e4d5b0ea 100644 --- a/graphene/types/objecttype.py +++ b/graphene/types/objecttype.py @@ -159,45 +159,3 @@ def __init_subclass_with_meta__( super(ObjectType, 
cls).__init_subclass_with_meta__(_meta=_meta, **options) is_type_of = None - - def __init__(self, *args, **kwargs): - # ObjectType acting as container - args_len = len(args) - fields = self._meta.fields.items() - if args_len > len(fields): - # Daft, but matches old exception sans the err msg. - raise IndexError("Number of args exceeds number of fields") - fields_iter = iter(fields) - - if not kwargs: - for val, (name, field) in zip(args, fields_iter): - setattr(self, name, val) - else: - for val, (name, field) in zip(args, fields_iter): - setattr(self, name, val) - kwargs.pop(name, None) - - for name, field in fields_iter: - try: - val = kwargs.pop( - name, field.default_value if isinstance(field, Field) else None - ) - setattr(self, name, val) - except KeyError: - pass - - if kwargs: - for prop in list(kwargs): - try: - if isinstance( - getattr(self.__class__, prop), property - ) or prop.startswith("_"): - setattr(self, prop, kwargs.pop(prop)) - except AttributeError: - pass - if kwargs: - raise TypeError( - "'{}' is an invalid keyword argument for {}".format( - list(kwargs)[0], self.__class__.__name__ - ) - ) From c17103005c7c3a210a4b738b015d3aba4a57dc41 Mon Sep 17 00:00:00 2001 From: Syrus Date: Sat, 14 Mar 2020 14:25:22 -0700 Subject: [PATCH 7/8] Use black in dataclasses --- graphene/pyutils/dataclasses.py | 573 ++++++++++++++++++-------------- 1 file changed, 326 insertions(+), 247 deletions(-) diff --git a/graphene/pyutils/dataclasses.py b/graphene/pyutils/dataclasses.py index 3f2e45c8b..61f0ea388 100644 --- a/graphene/pyutils/dataclasses.py +++ b/graphene/pyutils/dataclasses.py @@ -9,21 +9,21 @@ import inspect import keyword -__all__ = ['dataclass', - 'field', - 'Field', - 'FrozenInstanceError', - 'InitVar', - 'MISSING', - - # Helper functions. - 'fields', - 'asdict', - 'astuple', - 'make_dataclass', - 'replace', - 'is_dataclass', - ] +__all__ = [ + "dataclass", + "field", + "Field", + "FrozenInstanceError", + "InitVar", + "MISSING", + # Helper functions. 
+ "fields", + "asdict", + "astuple", + "make_dataclass", + "replace", + "is_dataclass", +] # Conditions for adding methods. The boxes indicate what action the # dataclass decorator takes. For all of these tables, when I talk @@ -152,20 +152,26 @@ # Raised when an attempt is made to modify a frozen class. -class FrozenInstanceError(AttributeError): pass +class FrozenInstanceError(AttributeError): + pass + # A sentinel object for default values to signal that a default # factory will be used. This is given a nice repr() which will appear # in the function signature of dataclasses' constructors. class _HAS_DEFAULT_FACTORY_CLASS: def __repr__(self): - return '' + return "" + + _HAS_DEFAULT_FACTORY = _HAS_DEFAULT_FACTORY_CLASS() # A sentinel object to detect if a parameter is supplied or not. Use # a class to give it a better repr. class _MISSING_TYPE: pass + + MISSING = _MISSING_TYPE() # Since most per-field metadata will be unused, create an empty @@ -176,33 +182,38 @@ class _MISSING_TYPE: class _FIELD_BASE: def __init__(self, name): self.name = name + def __repr__(self): return self.name -_FIELD = _FIELD_BASE('_FIELD') -_FIELD_CLASSVAR = _FIELD_BASE('_FIELD_CLASSVAR') -_FIELD_INITVAR = _FIELD_BASE('_FIELD_INITVAR') + + +_FIELD = _FIELD_BASE("_FIELD") +_FIELD_CLASSVAR = _FIELD_BASE("_FIELD_CLASSVAR") +_FIELD_INITVAR = _FIELD_BASE("_FIELD_INITVAR") # The name of an attribute on the class where we store the Field # objects. Also used to check if a class is a Data Class. -_FIELDS = '__dataclass_fields__' +_FIELDS = "__dataclass_fields__" # The name of an attribute on the class that stores the parameters to # @dataclass. -_PARAMS = '__dataclass_params__' +_PARAMS = "__dataclass_params__" # The name of the function, that if it exists, is called at the end of # __init__. -_POST_INIT_NAME = '__post_init__' +_POST_INIT_NAME = "__post_init__" # String regex that string annotations for ClassVar or InitVar must match. # Allows "identifier.identifier[" or "identifier[". 
# https://bugs.python.org/issue33453 for details. -_MODULE_IDENTIFIER_RE = re.compile(r'^(?:\s*(\w+)\s*\.)?\s*(\w+)') +_MODULE_IDENTIFIER_RE = re.compile(r"^(?:\s*(\w+)\s*\.)?\s*(\w+)") + class _InitVarMeta(type): def __getitem__(self, params): return self + class InitVar(metaclass=_InitVarMeta): pass @@ -218,20 +229,20 @@ class InitVar(metaclass=_InitVarMeta): # When cls._FIELDS is filled in with a list of Field objects, the name # and type fields will have been populated. class Field: - __slots__ = ('name', - 'type', - 'default', - 'default_factory', - 'repr', - 'hash', - 'init', - 'compare', - 'metadata', - '_field_type', # Private: not to be used by user code. - ) - - def __init__(self, default, default_factory, init, repr, hash, compare, - metadata): + __slots__ = ( + "name", + "type", + "default", + "default_factory", + "repr", + "hash", + "init", + "compare", + "metadata", + "_field_type", # Private: not to be used by user code. + ) + + def __init__(self, default, default_factory, init, repr, hash, compare, metadata): self.name = None self.type = None self.default = default @@ -240,24 +251,28 @@ def __init__(self, default, default_factory, init, repr, hash, compare, self.repr = repr self.hash = hash self.compare = compare - self.metadata = (_EMPTY_METADATA - if metadata is None or len(metadata) == 0 else - types.MappingProxyType(metadata)) + self.metadata = ( + _EMPTY_METADATA + if metadata is None or len(metadata) == 0 + else types.MappingProxyType(metadata) + ) self._field_type = None def __repr__(self): - return ('Field(' - f'name={self.name!r},' - f'type={self.type!r},' - f'default={self.default!r},' - f'default_factory={self.default_factory!r},' - f'init={self.init!r},' - f'repr={self.repr!r},' - f'hash={self.hash!r},' - f'compare={self.compare!r},' - f'metadata={self.metadata!r},' - f'_field_type={self._field_type}' - ')') + return ( + "Field(" + f"name={self.name!r}," + f"type={self.type!r}," + f"default={self.default!r}," + 
f"default_factory={self.default_factory!r}," + f"init={self.init!r}," + f"repr={self.repr!r}," + f"hash={self.hash!r}," + f"compare={self.compare!r}," + f"metadata={self.metadata!r}," + f"_field_type={self._field_type}" + ")" + ) # This is used to support the PEP 487 __set_name__ protocol in the # case where we're using a field that contains a descriptor as a @@ -268,7 +283,7 @@ def __repr__(self): # with the default value, so the end result is a descriptor that # had __set_name__ called on it at the right time. def __set_name__(self, owner, name): - func = getattr(type(self.default), '__set_name__', None) + func = getattr(type(self.default), "__set_name__", None) if func: # There is a __set_name__ method on the descriptor, call # it. @@ -276,13 +291,14 @@ def __set_name__(self, owner, name): class _DataclassParams: - __slots__ = ('init', - 'repr', - 'eq', - 'order', - 'unsafe_hash', - 'frozen', - ) + __slots__ = ( + "init", + "repr", + "eq", + "order", + "unsafe_hash", + "frozen", + ) def __init__(self, init, repr, eq, order, unsafe_hash, frozen): self.init = init @@ -293,21 +309,31 @@ def __init__(self, init, repr, eq, order, unsafe_hash, frozen): self.frozen = frozen def __repr__(self): - return ('_DataclassParams(' - f'init={self.init!r},' - f'repr={self.repr!r},' - f'eq={self.eq!r},' - f'order={self.order!r},' - f'unsafe_hash={self.unsafe_hash!r},' - f'frozen={self.frozen!r}' - ')') + return ( + "_DataclassParams(" + f"init={self.init!r}," + f"repr={self.repr!r}," + f"eq={self.eq!r}," + f"order={self.order!r}," + f"unsafe_hash={self.unsafe_hash!r}," + f"frozen={self.frozen!r}" + ")" + ) # This function is used instead of exposing Field creation directly, # so that a type checker can be told (via overloads) that this is a # function whose type depends on its parameters. 
-def field(*, default=MISSING, default_factory=MISSING, init=True, repr=True, - hash=None, compare=True, metadata=None): +def field( + *, + default=MISSING, + default_factory=MISSING, + init=True, + repr=True, + hash=None, + compare=True, + metadata=None, +): """Return an object to identify dataclass fields. default is the default value of the field. default_factory is a @@ -323,9 +349,8 @@ def field(*, default=MISSING, default_factory=MISSING, init=True, repr=True, """ if default is not MISSING and default_factory is not MISSING: - raise ValueError('cannot specify both default and default_factory') - return Field(default, default_factory, init, repr, hash, compare, - metadata) + raise ValueError("cannot specify both default and default_factory") + return Field(default, default_factory, init, repr, hash, compare, metadata) def _tuple_str(obj_name, fields): @@ -335,27 +360,26 @@ def _tuple_str(obj_name, fields): # Special case for the 0-tuple. if not fields: - return '()' + return "()" # Note the trailing comma, needed if this turns out to be a 1-tuple. return f'({",".join([f"{obj_name}.{f.name}" for f in fields])},)' -def _create_fn(name, args, body, *, globals=None, locals=None, - return_type=MISSING): +def _create_fn(name, args, body, *, globals=None, locals=None, return_type=MISSING): # Note that we mutate locals when exec() is called. Caller # beware! The only callers are internal to this module, so no # worries about external callers. if locals is None: locals = {} - return_annotation = '' + return_annotation = "" if return_type is not MISSING: - locals['_return_type'] = return_type - return_annotation = '->_return_type' - args = ','.join(args) - body = '\n'.join(f' {b}' for b in body) + locals["_return_type"] = return_type + return_annotation = "->_return_type" + args = ",".join(args) + body = "\n".join(f" {b}" for b in body) # Compute the text of the entire function. 
- txt = f'def {name}({args}){return_annotation}:\n{body}' + txt = f"def {name}({args}){return_annotation}:\n{body}" exec(txt, globals, locals) return locals[name] @@ -369,23 +393,25 @@ def _field_assign(frozen, name, value, self_name): # self_name is what "self" is called in this function: don't # hard-code "self", since that might be a field name. if frozen: - return f'object.__setattr__({self_name},{name!r},{value})' - return f'{self_name}.{name}={value}' + return f"object.__setattr__({self_name},{name!r},{value})" + return f"{self_name}.{name}={value}" def _field_init(f, frozen, globals, self_name): # Return the text of the line in the body of __init__ that will # initialize this field. - default_name = f'_dflt_{f.name}' + default_name = f"_dflt_{f.name}" if f.default_factory is not MISSING: if f.init: # This field has a default factory. If a parameter is # given, use it. If not, call the factory. globals[default_name] = f.default_factory - value = (f'{default_name}() ' - f'if {f.name} is _HAS_DEFAULT_FACTORY ' - f'else {f.name}') + value = ( + f"{default_name}() " + f"if {f.name} is _HAS_DEFAULT_FACTORY " + f"else {f.name}" + ) else: # This is a field that's not in the __init__ params, but # has a default factory function. It needs to be @@ -402,7 +428,7 @@ def _field_init(f, frozen, globals, self_name): # (which, after all, is why we have a factory function!). globals[default_name] = f.default_factory - value = f'{default_name}()' + value = f"{default_name}()" else: # No default factory. if f.init: @@ -435,15 +461,15 @@ def _init_param(f): if f.default is MISSING and f.default_factory is MISSING: # There's no default, and no default_factory, just output the # variable name and type. - default = '' + default = "" elif f.default is not MISSING: # There's a default, this will be the name that's used to look # it up. - default = f'=_dflt_{f.name}' + default = f"=_dflt_{f.name}" elif f.default_factory is not MISSING: # There's a factory function. Set a marker. 
- default = '=_HAS_DEFAULT_FACTORY' - return f'{f.name}:_type_{f.name}{default}' + default = "=_HAS_DEFAULT_FACTORY" + return f"{f.name}:_type_{f.name}{default}" def _init_fn(fields, frozen, has_post_init, self_name): @@ -461,11 +487,11 @@ def _init_fn(fields, frozen, has_post_init, self_name): if not (f.default is MISSING and f.default_factory is MISSING): seen_default = True elif seen_default: - raise TypeError(f'non-default argument {f.name!r} ' - 'follows default argument') + raise TypeError( + f"non-default argument {f.name!r} " "follows default argument" + ) - globals = {'MISSING': MISSING, - '_HAS_DEFAULT_FACTORY': _HAS_DEFAULT_FACTORY} + globals = {"MISSING": MISSING, "_HAS_DEFAULT_FACTORY": _HAS_DEFAULT_FACTORY} body_lines = [] for f in fields: @@ -477,55 +503,67 @@ def _init_fn(fields, frozen, has_post_init, self_name): # Does this class have a post-init function? if has_post_init: - params_str = ','.join(f.name for f in fields - if f._field_type is _FIELD_INITVAR) - body_lines.append(f'{self_name}.{_POST_INIT_NAME}({params_str})') + params_str = ",".join(f.name for f in fields if f._field_type is _FIELD_INITVAR) + body_lines.append(f"{self_name}.{_POST_INIT_NAME}({params_str})") # If no body lines, use 'pass'. 
if not body_lines: - body_lines = ['pass'] + body_lines = ["pass"] - locals = {f'_type_{f.name}': f.type for f in fields} - return _create_fn('__init__', - [self_name] + [_init_param(f) for f in fields if f.init], - body_lines, - locals=locals, - globals=globals, - return_type=None) + locals = {f"_type_{f.name}": f.type for f in fields} + return _create_fn( + "__init__", + [self_name] + [_init_param(f) for f in fields if f.init], + body_lines, + locals=locals, + globals=globals, + return_type=None, + ) def _repr_fn(fields): - return _create_fn('__repr__', - ('self',), - ['return self.__class__.__qualname__ + f"(' + - ', '.join([f"{f.name}={{self.{f.name}!r}}" - for f in fields]) + - ')"']) + return _create_fn( + "__repr__", + ("self",), + [ + 'return self.__class__.__qualname__ + f"(' + + ", ".join([f"{f.name}={{self.{f.name}!r}}" for f in fields]) + + ')"' + ], + ) def _frozen_get_del_attr(cls, fields): # XXX: globals is modified on the first call to _create_fn, then # the modified version is used in the second call. Is this okay? - globals = {'cls': cls, - 'FrozenInstanceError': FrozenInstanceError} + globals = {"cls": cls, "FrozenInstanceError": FrozenInstanceError} if fields: - fields_str = '(' + ','.join(repr(f.name) for f in fields) + ',)' + fields_str = "(" + ",".join(repr(f.name) for f in fields) + ",)" else: # Special case for the zero-length tuple. 
- fields_str = '()' - return (_create_fn('__setattr__', - ('self', 'name', 'value'), - (f'if type(self) is cls or name in {fields_str}:', - ' raise FrozenInstanceError(f"cannot assign to field {name!r}")', - f'super(cls, self).__setattr__(name, value)'), - globals=globals), - _create_fn('__delattr__', - ('self', 'name'), - (f'if type(self) is cls or name in {fields_str}:', - ' raise FrozenInstanceError(f"cannot delete field {name!r}")', - f'super(cls, self).__delattr__(name)'), - globals=globals), - ) + fields_str = "()" + return ( + _create_fn( + "__setattr__", + ("self", "name", "value"), + ( + f"if type(self) is cls or name in {fields_str}:", + ' raise FrozenInstanceError(f"cannot assign to field {name!r}")', + f"super(cls, self).__setattr__(name, value)", + ), + globals=globals, + ), + _create_fn( + "__delattr__", + ("self", "name"), + ( + f"if type(self) is cls or name in {fields_str}:", + ' raise FrozenInstanceError(f"cannot delete field {name!r}")', + f"super(cls, self).__delattr__(name)", + ), + globals=globals, + ), + ) def _cmp_fn(name, op, self_tuple, other_tuple): @@ -534,18 +572,20 @@ def _cmp_fn(name, op, self_tuple, other_tuple): # '(self.x,self.y)' and other_tuple is the string # '(other.x,other.y)'. - return _create_fn(name, - ('self', 'other'), - [ 'if other.__class__ is self.__class__:', - f' return {self_tuple}{op}{other_tuple}', - 'return NotImplemented']) + return _create_fn( + name, + ("self", "other"), + [ + "if other.__class__ is self.__class__:", + f" return {self_tuple}{op}{other_tuple}", + "return NotImplemented", + ], + ) def _hash_fn(fields): - self_tuple = _tuple_str('self', fields) - return _create_fn('__hash__', - ('self',), - [f'return hash({self_tuple})']) + self_tuple = _tuple_str("self", fields) + return _create_fn("__hash__", ("self",), [f"return hash({self_tuple})"]) def _is_classvar(a_type, typing): @@ -658,12 +698,12 @@ def _get_field(cls, a_name, a_type): # annotation to be a ClassVar. 
So, only look for ClassVar if # typing has been imported by any module (not necessarily cls's # module). - typing = sys.modules.get('typing') + typing = sys.modules.get("typing") if typing: - if (_is_classvar(a_type, typing) - or (isinstance(f.type, str) - and _is_type(f.type, cls, typing, typing.ClassVar, - _is_classvar))): + if _is_classvar(a_type, typing) or ( + isinstance(f.type, str) + and _is_type(f.type, cls, typing, typing.ClassVar, _is_classvar) + ): f._field_type = _FIELD_CLASSVAR # If the type is InitVar, or if it's a matching string annotation, @@ -672,10 +712,10 @@ def _get_field(cls, a_name, a_type): # The module we're checking against is the module we're # currently in (dataclasses.py). dataclasses = sys.modules[__name__] - if (_is_initvar(a_type, dataclasses) - or (isinstance(f.type, str) - and _is_type(f.type, cls, dataclasses, dataclasses.InitVar, - _is_initvar))): + if _is_initvar(a_type, dataclasses) or ( + isinstance(f.type, str) + and _is_type(f.type, cls, dataclasses, dataclasses.InitVar, _is_initvar) + ): f._field_type = _FIELD_INITVAR # Validations for individual fields. This is delayed until now, @@ -685,8 +725,7 @@ def _get_field(cls, a_name, a_type): # Special restrictions for ClassVar and InitVar. if f._field_type in (_FIELD_CLASSVAR, _FIELD_INITVAR): if f.default_factory is not MISSING: - raise TypeError(f'field {f.name} cannot have a ' - 'default factory') + raise TypeError(f"field {f.name} cannot have a " "default factory") # Should I check for other field settings? default_factory # seems the most serious to check for. Maybe add others. For # example, how about init=False (or really, @@ -695,8 +734,10 @@ def _get_field(cls, a_name, a_type): # For real fields, disallow mutable defaults for known types. 
if f._field_type is _FIELD and isinstance(f.default, (list, dict, set)): - raise ValueError(f'mutable default {type(f.default)} for field ' - f'{f.name} is not allowed: use default_factory') + raise ValueError( + f"mutable default {type(f.default)} for field " + f"{f.name} is not allowed: use default_factory" + ) return f @@ -715,17 +756,20 @@ def _set_new_attribute(cls, name, value): # take. The common case is to do nothing, so instead of providing a # function that is a no-op, use None to signify that. + def _hash_set_none(cls, fields): return None + def _hash_add(cls, fields): flds = [f for f in fields if (f.compare if f.hash is None else f.hash)] return _hash_fn(flds) + def _hash_exception(cls, fields): # Raise an exception. - raise TypeError(f'Cannot overwrite attribute __hash__ ' - f'in class {cls.__name__}') + raise TypeError(f"Cannot overwrite attribute __hash__ " f"in class {cls.__name__}") + # # +-------------------------------------- unsafe_hash? @@ -736,23 +780,24 @@ def _hash_exception(cls, fields): # | | | | +------- action # | | | | | # v v v v v -_hash_action = {(False, False, False, False): None, - (False, False, False, True ): None, - (False, False, True, False): None, - (False, False, True, True ): None, - (False, True, False, False): _hash_set_none, - (False, True, False, True ): None, - (False, True, True, False): _hash_add, - (False, True, True, True ): None, - (True, False, False, False): _hash_add, - (True, False, False, True ): _hash_exception, - (True, False, True, False): _hash_add, - (True, False, True, True ): _hash_exception, - (True, True, False, False): _hash_add, - (True, True, False, True ): _hash_exception, - (True, True, True, False): _hash_add, - (True, True, True, True ): _hash_exception, - } +_hash_action = { + (False, False, False, False): None, + (False, False, False, True): None, + (False, False, True, False): None, + (False, False, True, True): None, + (False, True, False, False): _hash_set_none, + (False, True, False, 
True): None, + (False, True, True, False): _hash_add, + (False, True, True, True): None, + (True, False, False, False): _hash_add, + (True, False, False, True): _hash_exception, + (True, False, True, False): _hash_add, + (True, False, True, True): _hash_exception, + (True, True, False, False): _hash_add, + (True, True, False, True): _hash_exception, + (True, True, True, False): _hash_add, + (True, True, True, True): _hash_exception, +} # See https://bugs.python.org/issue32929#msg312829 for an if-statement # version of this table. @@ -764,8 +809,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen): # is defined by the base class, which is found first. fields = {} - setattr(cls, _PARAMS, _DataclassParams(init, repr, eq, order, - unsafe_hash, frozen)) + setattr(cls, _PARAMS, _DataclassParams(init, repr, eq, order, unsafe_hash, frozen)) # Find our base classes in reverse MRO order, and exclude # ourselves. In reversed order so that more derived classes @@ -796,13 +840,12 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen): # actual default value. Pseudo-fields ClassVars and InitVars are # included, despite the fact that they're not real fields. That's # dealt with later. - cls_annotations = cls.__dict__.get('__annotations__', {}) + cls_annotations = cls.__dict__.get("__annotations__", {}) # Now find fields in our class. While doing so, validate some # things, and set the default values (as class attributes) where # we can. - cls_fields = [_get_field(cls, name, type) - for name, type in cls_annotations.items()] + cls_fields = [_get_field(cls, name, type) for name, type in cls_annotations.items()] for f in cls_fields: fields[f.name] = f @@ -825,19 +868,17 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen): # Do we have any Field members that don't also have annotations? 
for name, value in cls.__dict__.items(): if isinstance(value, Field) and not name in cls_annotations: - raise TypeError(f'{name!r} is a field but has no type annotation') + raise TypeError(f"{name!r} is a field but has no type annotation") # Check rules that apply if we are derived from any dataclasses. if has_dataclass_bases: # Raise an exception if any of our bases are frozen, but we're not. if any_frozen_base and not frozen: - raise TypeError('cannot inherit non-frozen dataclass from a ' - 'frozen one') + raise TypeError("cannot inherit non-frozen dataclass from a " "frozen one") # Raise an exception if we're frozen, but none of our bases are. if not any_frozen_base and frozen: - raise TypeError('cannot inherit frozen dataclass from a ' - 'non-frozen one') + raise TypeError("cannot inherit frozen dataclass from a " "non-frozen one") # Remember all of the fields on our class (including bases). This # also marks this class as being a dataclass. @@ -848,32 +889,35 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen): # set __hash__ to None. This is a heuristic, as it's possible # that such a __hash__ == None was not auto-generated, but it # close enough. - class_hash = cls.__dict__.get('__hash__', MISSING) - has_explicit_hash = not (class_hash is MISSING or - (class_hash is None and '__eq__' in cls.__dict__)) + class_hash = cls.__dict__.get("__hash__", MISSING) + has_explicit_hash = not ( + class_hash is MISSING or (class_hash is None and "__eq__" in cls.__dict__) + ) # If we're generating ordering methods, we must be generating the # eq methods. if order and not eq: - raise ValueError('eq must be true if order is true') + raise ValueError("eq must be true if order is true") if init: # Does this class have a post-init function? has_post_init = hasattr(cls, _POST_INIT_NAME) # Include InitVars and regular fields (so, not ClassVars). 
- flds = [f for f in fields.values() - if f._field_type in (_FIELD, _FIELD_INITVAR)] - _set_new_attribute(cls, '__init__', - _init_fn(flds, - frozen, - has_post_init, - # The name to use for the "self" - # param in __init__. Use "self" - # if possible. - '__dataclass_self__' if 'self' in fields - else 'self', - )) + flds = [f for f in fields.values() if f._field_type in (_FIELD, _FIELD_INITVAR)] + _set_new_attribute( + cls, + "__init__", + _init_fn( + flds, + frozen, + has_post_init, + # The name to use for the "self" + # param in __init__. Use "self" + # if possible. + "__dataclass_self__" if "self" in fields else "self", + ), + ) # Get the fields as a list, and include only real fields. This is # used in all of the following methods. @@ -881,54 +925,58 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen): if repr: flds = [f for f in field_list if f.repr] - _set_new_attribute(cls, '__repr__', _repr_fn(flds)) + _set_new_attribute(cls, "__repr__", _repr_fn(flds)) if eq: # Create _eq__ method. There's no need for a __ne__ method, # since python will call __eq__ and negate it. flds = [f for f in field_list if f.compare] - self_tuple = _tuple_str('self', flds) - other_tuple = _tuple_str('other', flds) - _set_new_attribute(cls, '__eq__', - _cmp_fn('__eq__', '==', - self_tuple, other_tuple)) + self_tuple = _tuple_str("self", flds) + other_tuple = _tuple_str("other", flds) + _set_new_attribute( + cls, "__eq__", _cmp_fn("__eq__", "==", self_tuple, other_tuple) + ) if order: # Create and set the ordering methods. flds = [f for f in field_list if f.compare] - self_tuple = _tuple_str('self', flds) - other_tuple = _tuple_str('other', flds) - for name, op in [('__lt__', '<'), - ('__le__', '<='), - ('__gt__', '>'), - ('__ge__', '>='), - ]: - if _set_new_attribute(cls, name, - _cmp_fn(name, op, self_tuple, other_tuple)): - raise TypeError(f'Cannot overwrite attribute {name} ' - f'in class {cls.__name__}. 
Consider using ' - 'functools.total_ordering') + self_tuple = _tuple_str("self", flds) + other_tuple = _tuple_str("other", flds) + for name, op in [ + ("__lt__", "<"), + ("__le__", "<="), + ("__gt__", ">"), + ("__ge__", ">="), + ]: + if _set_new_attribute( + cls, name, _cmp_fn(name, op, self_tuple, other_tuple) + ): + raise TypeError( + f"Cannot overwrite attribute {name} " + f"in class {cls.__name__}. Consider using " + "functools.total_ordering" + ) if frozen: for fn in _frozen_get_del_attr(cls, field_list): if _set_new_attribute(cls, fn.__name__, fn): - raise TypeError(f'Cannot overwrite attribute {fn.__name__} ' - f'in class {cls.__name__}') + raise TypeError( + f"Cannot overwrite attribute {fn.__name__} " + f"in class {cls.__name__}" + ) # Decide if/how we're going to create a hash function. - hash_action = _hash_action[bool(unsafe_hash), - bool(eq), - bool(frozen), - has_explicit_hash] + hash_action = _hash_action[ + bool(unsafe_hash), bool(eq), bool(frozen), has_explicit_hash + ] if hash_action: # No need to call _set_new_attribute here, since by the time # we're here the overwriting is unconditional. cls.__hash__ = hash_action(cls, field_list) - if not getattr(cls, '__doc__'): + if not getattr(cls, "__doc__"): # Create a class doc-string. - cls.__doc__ = (cls.__name__ + - str(inspect.signature(cls)).replace(' -> None', '')) + cls.__doc__ = cls.__name__ + str(inspect.signature(cls)).replace(" -> None", "") return cls @@ -936,8 +984,16 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen): # _cls should never be specified by keyword, so start it with an # underscore. The presence of _cls is used to detect if this # decorator is being called with parameters or not. 
-def dataclass(_cls=None, *, init=True, repr=True, eq=True, order=False, - unsafe_hash=False, frozen=False): +def dataclass( + _cls=None, + *, + init=True, + repr=True, + eq=True, + order=False, + unsafe_hash=False, + frozen=False, +): """Returns the same class as was passed in, with dunder methods added based on the fields defined in the class. @@ -973,7 +1029,7 @@ def fields(class_or_instance): try: fields = getattr(class_or_instance, _FIELDS) except AttributeError: - raise TypeError('must be called with a dataclass type or instance') + raise TypeError("must be called with a dataclass type or instance") # Exclude pseudo-fields. Note that fields is sorted by insertion # order, so the order of the tuple is as the fields were defined. @@ -1025,8 +1081,10 @@ def _asdict_inner(obj, dict_factory): elif isinstance(obj, (list, tuple)): return type(obj)(_asdict_inner(v, dict_factory) for v in obj) elif isinstance(obj, dict): - return type(obj)((_asdict_inner(k, dict_factory), _asdict_inner(v, dict_factory)) - for k, v in obj.items()) + return type(obj)( + (_asdict_inner(k, dict_factory), _asdict_inner(v, dict_factory)) + for k, v in obj.items() + ) else: return copy.deepcopy(obj) @@ -1065,15 +1123,27 @@ def _astuple_inner(obj, tuple_factory): elif isinstance(obj, (list, tuple)): return type(obj)(_astuple_inner(v, tuple_factory) for v in obj) elif isinstance(obj, dict): - return type(obj)((_astuple_inner(k, tuple_factory), _astuple_inner(v, tuple_factory)) - for k, v in obj.items()) + return type(obj)( + (_astuple_inner(k, tuple_factory), _astuple_inner(v, tuple_factory)) + for k, v in obj.items() + ) else: return copy.deepcopy(obj) -def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True, - repr=True, eq=True, order=False, unsafe_hash=False, - frozen=False): +def make_dataclass( + cls_name, + fields, + *, + bases=(), + namespace=None, + init=True, + repr=True, + eq=True, + order=False, + unsafe_hash=False, + frozen=False, +): """Return a new 
dynamically created dataclass. The dataclass name will be 'cls_name'. 'fields' is an iterable @@ -1110,31 +1180,38 @@ class C(Base): for item in fields: if isinstance(item, str): name = item - tp = 'typing.Any' + tp = "typing.Any" elif len(item) == 2: name, tp, = item elif len(item) == 3: name, tp, spec = item namespace[name] = spec else: - raise TypeError(f'Invalid field: {item!r}') + raise TypeError(f"Invalid field: {item!r}") if not isinstance(name, str) or not name.isidentifier(): - raise TypeError(f'Field names must be valid identifers: {name!r}') + raise TypeError(f"Field names must be valid identifers: {name!r}") if keyword.iskeyword(name): - raise TypeError(f'Field names must not be keywords: {name!r}') + raise TypeError(f"Field names must not be keywords: {name!r}") if name in seen: - raise TypeError(f'Field name duplicated: {name!r}') + raise TypeError(f"Field name duplicated: {name!r}") seen.add(name) anns[name] = tp - namespace['__annotations__'] = anns + namespace["__annotations__"] = anns # We use `types.new_class()` instead of simply `type()` to allow dynamic creation # of generic dataclassses. cls = types.new_class(cls_name, bases, {}, lambda ns: ns.update(namespace)) - return dataclass(cls, init=init, repr=repr, eq=eq, order=order, - unsafe_hash=unsafe_hash, frozen=frozen) + return dataclass( + cls, + init=init, + repr=repr, + eq=eq, + order=order, + unsafe_hash=unsafe_hash, + frozen=frozen, + ) def replace(obj, **changes): @@ -1165,9 +1242,11 @@ class C: if not f.init: # Error if this field is specified in changes. 
if f.name in changes: - raise ValueError(f'field {f.name} is declared with ' - 'init=False, it cannot be specified with ' - 'replace()') + raise ValueError( + f"field {f.name} is declared with " + "init=False, it cannot be specified with " + "replace()" + ) continue if f.name not in changes: From 5549927f5207d443a92c92c4acf1030cea2b601e Mon Sep 17 00:00:00 2001 From: Syrus Date: Sat, 14 Mar 2020 14:37:24 -0700 Subject: [PATCH 8/8] Use latest black verison on precommit --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7aa720015..c9ffc21ed 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -18,7 +18,7 @@ repos: hooks: - id: pyupgrade - repo: https://github.com/ambv/black - rev: 19.3b0 + rev: 19.10b0 hooks: - id: black language_version: python3