diff --git a/CHANGELOG b/CHANGELOG index ddf962f05..f9788e981 100644 --- a/CHANGELOG +++ b/CHANGELOG @@ -1,3 +1,13 @@ +Version 2021.10.04 + +Bug fixes: +* Call init_class instead of instantiate when adding assertIsInstance bindings. +* Use the namedtuple 'defaults' argument when generating __new__ signature. +* Do not raise a parser error for unrecognised decorators. +* Merge BaseValue.cls and BaseValue.get_class(). +* Check Instance.maybe_missing_members earlier during attribute access. +* Fix a bug in matching callables with TypeVar parameters. + Version 2021.09.27 New features and updates: diff --git a/docs/developers/annotations.md b/docs/developers/annotations.md index 8640b18c6..9b5c4d75f 100644 --- a/docs/developers/annotations.md +++ b/docs/developers/annotations.md @@ -150,7 +150,7 @@ the `param`s are types (possibly parametrised themselves) or lists of types. ### Conversion to abstract types The main annotation processing code lives in the -`annotations_util.AnnotationsUtil` class (instantiated as a member of the VM). +`annotation_utils.AnnotationUtils` class (instantiated as a member of the VM). This code has several entry points, for various annotation contexts, but the bulk of the conversion work is done in the internal method `_process_one_annotation()`. @@ -170,9 +170,9 @@ various kinds of annotations, and calling itself recursively to deal with nested annotations. The return value of `_process_one_annotation` is an `abstract.*` object that can be applied as the python type of a variable. -The various public methods in `AnnotationsUtil` cover different contexts in +The various public methods in `AnnotationUtils` cover different contexts in which we can encounter variable annotations while processing bytecode; search -for `self.annotations_util` in `vm.py` to see where each one is used. +for `self.annotation_utils` in `vm.py` to see where each one is used. ## Tracking local operations diff --git a/docs/developers/special_builtins.md b/docs/developers/special_builtins.md index 63f39ea2a..075850648 100644 --- a/docs/developers/special_builtins.md +++ b/docs/developers/special_builtins.md @@ -15,7 +15,7 @@ freshness: { owner: 'mdemello' reviewed: '2020-09-18' } * [Instances](#instances) * [Variables and data](#variables-and-data) - + @@ -181,7 +181,7 @@ Pytype replicates this behaviour by providing a `StaticMethod` class, whose `call` method takes in a function (specifically a variable whose binding is an `abstract.InterpreterFunction` object), and returns a `StaticMethodInstance` that wraps the original variable. `StaticMethodInstance` in turn wraps the -underlying function and provides an object whose `get_class()` method returns +underlying function and provides an object whose `cls` attribute is `special_builtins.StaticMethod` and whose `__get__` slot returns the original function. (The details of `StaticMethodInstance` don't matter too much for now, but note the two-stage process by which we have achieved the desired method diff --git a/pytype/CMakeLists.txt b/pytype/CMakeLists.txt index 89c06e194..f9bba8e33 100644 --- a/pytype/CMakeLists.txt +++ b/pytype/CMakeLists.txt @@ -28,7 +28,7 @@ py_library( .abstract .abstract_utils .analyze - .annotations_util + .annotation_utils .attribute .blocks .class_mixin @@ -122,9 +122,9 @@ py_library( py_library( NAME - annotations_util + annotation_utils SRCS - annotations_util.py + annotation_utils.py DEPS ._utils .abstract @@ -510,7 +510,7 @@ py_library( ._utils .abstract .abstract_utils - .annotations_util + .annotation_utils .attribute .blocks .class_mixin diff --git a/pytype/__version__.py b/pytype/__version__.py index f11bf175b..64870bd11 100644 --- a/pytype/__version__.py +++ b/pytype/__version__.py @@ -1,2 +1,2 @@ # pylint: skip-file -__version__ = '2021.09.27' +__version__ = '2021.10.04' diff --git a/pytype/abstract.py b/pytype/abstract.py index 69246cf07..421d65130 100644 --- a/pytype/abstract.py +++ b/pytype/abstract.py @@ -55,7 +55,10 @@ def __init__(self, name, vm): """Basic initializer for all BaseValues.""" super().__init__(vm) assert hasattr(vm, "program"), type(self) - self.cls = None + # This default cls value is used by things like Unsolvable that inherit + # directly from BaseValue. Classes and instances overwrite the default with + # a more sensible value. + self.cls = self self.name = name self.mro = self.compute_mro() self.module = None @@ -199,7 +202,7 @@ def get_special_attribute(self, unused_node, name, unused_valself): abc.load_lazy_attribute("ABCMeta") return abc.members["ABCMeta"] else: - return self.get_class().to_variable(self.vm.root_node) + return self.cls.to_variable(self.vm.root_node) return None def get_own_new(self, node, value): @@ -244,10 +247,6 @@ def register_instance(self, instance): # pylint: disable=unused-arg instance: An instance of this class (as a BaseValue) """ - def get_class(self): - """Return the class of this object. Equivalent of x.__class__ in Python.""" - raise NotImplementedError(self.__class__.__name__) - def get_instance_type(self, node=None, instance=None, seen=None, view=None): """Get the type an instance of us would have.""" return self.vm.convert.pytd_convert.value_instance_to_pytd_type( @@ -516,9 +515,6 @@ def call(self, node, func, args, alias_map=None): del func, args return node, self.to_variable(node) - def get_class(self): - return self - def instantiate(self, node, container=None): return self.to_variable(node) @@ -633,9 +629,6 @@ def update_official_name(self, name): self.name, self.name, name) self.vm.errorlog.invalid_typevar(self.vm.frames, message) - def get_class(self): - return self - def call(self, node, func, args, alias_map=None): return node, self.instantiate(node) @@ -645,13 +638,10 @@ class TypeParameterInstance(BaseValue): def __init__(self, param, instance, vm): super().__init__(param.name, vm) - self.param = param + self.cls = self.param = param self.instance = instance self.module = param.module - def get_class(self): - return self.param - def call(self, node, func, args, alias_map=None): var = self.instance.get_instance_type_parameter(self.name) if var.bindings: @@ -692,6 +682,7 @@ def __init__(self, name, vm): vm: The TypegraphVirtualMachine to use. """ super().__init__(name, vm) + self._cls = None # lazily loaded 'cls' attribute self.members = datatypes.MonitorDict() # Lazily loaded to handle recursive types. # See Instance._load_instance_type_parameters(). @@ -777,18 +768,29 @@ def argcount(self, node): return 0 def __repr__(self): - cls = " [%r]" % self.cls if self.cls else "" - return "<%s%s>" % (self.name, cls) - - def get_class(self): - # See Py_TYPE() in Include/object.h - if self.cls: - return self.cls - elif isinstance(self, InterpreterClass): + return "<%s [%r]>" % (self.name, self.cls) + + def _get_class(self): + if isinstance(self, InterpreterClass): return ParameterizedClass( self.vm.convert.type_type, {abstract_utils.T: self}, self.vm) elif isinstance(self, (AnnotationClass, class_mixin.Class)): return self.vm.convert.type_type + else: + return self.vm.convert.unsolvable + + @property + def cls(self): + if not self.vm.convert.minimally_initialized: + return self.vm.convert.unsolvable + if not self._cls: + self._cls = self.vm.convert.unsolvable # prevent infinite recursion + self._cls = self._get_class() + return self._cls + + @cls.setter + def cls(self, cls): + self._cls = cls def set_class(self, node, var): """Set the __class__ of an instance, for code that does "x.__class__ = y.""" @@ -800,11 +802,8 @@ def set_class(self, node, var): except abstract_utils.ConversionError: self.cls = self.vm.convert.unsolvable else: - if self.cls and self.cls != new_cls: + if self.cls != new_cls: self.cls = self.vm.convert.unsolvable - else: - self.cls = new_cls - new_cls.register_instance(self) return node def get_type_key(self, seen=None): @@ -816,9 +815,7 @@ def get_type_key(self, seen=None): if not seen: seen = set() seen.add(self) - key = set() - if self.cls: - key.add(self.cls) + key = {self.cls} for name, var in self.instance_type_parameters.items(): subkey = frozenset( value.data.get_default_type_key() # pylint: disable=g-long-ternary @@ -872,7 +869,7 @@ def _load_instance_type_parameters(self): @property def full_name(self): - return self.get_class().full_name + return self.cls.full_name @property def instance_type_parameters(self): @@ -1287,7 +1284,7 @@ def resolve(self, node, f_globals, f_locals): self.vm, node, f_globals, f_locals, self.expr) if errorlog: self.vm.errorlog.copy_from(errorlog.errors, self.stack) - self._type = self.vm.annotations_util.extract_annotation( + self._type = self.vm.annotation_utils.extract_annotation( node, var, None, self.stack) if self._type != self.vm.convert.unsolvable: # We may have tried to call __init__ on instances of this annotation. @@ -1386,7 +1383,7 @@ def __init__(self, name, vm, base_cls): def _sub_annotation( self, annot: BaseValue, subst: Mapping[str, BaseValue]) -> BaseValue: """Apply type parameter substitutions to an annotation.""" - # This is very similar to annotations_util.sub_one_annotation, but a couple + # This is very similar to annotation_utils.sub_one_annotation, but a couple # differences make it more convenient to maintain two separate methods: # - subst here is a str->BaseValue mapping rather than str->Variable, and it # would be wasteful to create variables just to match sub_one_annotation's @@ -1449,7 +1446,7 @@ def _get_value_info(self, inner, ellipses, allowed_ellipses=frozenset()): for k in template: v = self.base_cls.formal_type_parameters[k] if v.formal: - params = self.vm.annotations_util.get_type_parameters(v) + params = self.vm.annotation_utils.get_type_parameters(v) for param in params: # If there are too few parameters, we ignore the problem for now; # it'll be reported when _build_value checks that the lengths of @@ -1480,7 +1477,7 @@ def _validate_inner(self, template, inner, raw_inner): # For a generic type alias, we check that the number of typevars in the # alias matches the number of raw parameters provided. template_length = raw_template_length = len(set( - self.vm.annotations_util.get_type_parameters(self.base_cls))) + self.vm.annotation_utils.get_type_parameters(self.base_cls))) inner_length = raw_inner_length = len(raw_inner) base_cls = self.base_cls.base_cls else: @@ -1569,7 +1566,7 @@ def _build_value(self, node, inner, ellipses): actual = params[formal.name].instantiate(root_node) bad = self.vm.matcher(root_node).bad_matches(actual, formal) if bad: - formal = self.vm.annotations_util.sub_one_annotation( + formal = self.vm.annotation_utils.sub_one_annotation( root_node, formal, [{}]) self.vm.errorlog.bad_concrete_type( self.vm.frames, root_node, formal, actual, bad) @@ -1618,6 +1615,7 @@ def __init__(self, options, vm): super().__init__("Union", vm) assert options self.options = list(options) + self.cls = self._get_class() # TODO(rechen): Don't allow a mix of formal and non-formal types self.formal = any(t.formal for t in self.options) mixin.NestedAnnotation.init_mixin(self) @@ -1642,10 +1640,17 @@ def _unique_parameters(self): return [o.to_variable(self.vm.root_node) for o in self.options] def _get_type_params(self): - params = self.vm.annotations_util.get_type_parameters(self) + params = self.vm.annotation_utils.get_type_parameters(self) params = [x.full_name for x in params] return utils.unique_list(params) + def _get_class(self): + classes = {o.cls for o in self.options} + if len(classes) > 1: + return self.vm.convert.unsolvable + else: + return classes.pop() + def getitem_slot(self, node, slice_var): """Custom __getitem__ implementation.""" slice_content = abstract_utils.maybe_extract_tuple(slice_var) @@ -1665,7 +1670,7 @@ def getitem_slot(self, node, slice_var): else: concrete.append(value.instantiate(node)) substs = [dict(zip(params, concrete))] - new = self.vm.annotations_util.sub_one_annotation(node, self, substs) + new = self.vm.annotation_utils.sub_one_annotation(node, self, substs) return node, new.to_variable(node) def instantiate(self, node, container=None): @@ -1674,13 +1679,6 @@ def instantiate(self, node, container=None): var.PasteVariable(option.instantiate(node, container), node) return var - def get_class(self): - classes = {o.get_class() for o in self.options} - if len(classes) > 1: - return self.vm.convert.unsolvable - else: - return classes.pop() - def call(self, node, func, args, alias_map=None): var = self.vm.program.NewVariable(self.options, [], node) return self.vm.call_function(node, var, args) @@ -1820,6 +1818,7 @@ class ClassMethod(BaseValue): def __init__(self, name, method, callself, vm): super().__init__(name, vm) + self.cls = self.vm.convert.function_type self.method = method self.method.is_attribute_of_class = True # Rename to callcls to make clear that callself is the cls parameter. @@ -1830,9 +1829,6 @@ def call(self, node, func, args, alias_map=None): return self.method.call( node, func, args.replace(posargs=(self._callcls,) + args.posargs)) - def get_class(self): - return self.vm.convert.function_type - def to_bound_function(self): return BoundPyTDFunction(self._callcls, self.method) @@ -1842,15 +1838,13 @@ class StaticMethod(BaseValue): def __init__(self, name, method, _, vm): super().__init__(name, vm) + self.cls = self.vm.convert.function_type self.method = method self.signatures = self.method.signatures def call(self, *args, **kwargs): return self.method.call(*args, **kwargs) - def get_class(self): - return self.vm.convert.function_type - class Property(BaseValue): """Implements @property methods in pyi. @@ -1861,6 +1855,7 @@ class Property(BaseValue): def __init__(self, name, method, callself, vm): super().__init__(name, vm) + self.cls = self.vm.convert.function_type self.method = method self._callself = callself self.signatures = self.method.signatures @@ -1870,9 +1865,6 @@ def call(self, node, func, args, alias_map=None): args = args or function.Args(posargs=(self._callself,)) return self.method.call(node, func, args.replace(posargs=(self._callself,))) - def get_class(self): - return self.vm.convert.function_type - class PyTDFunction(Function): """A PyTD function (name + list of signatures). @@ -1947,7 +1939,7 @@ def property_get(self, callself, is_class=False): callself.instance.get_instance_type_parameter(callself.name), default=self.vm.convert.unsolvable) # callself is the instance, and we want to bind to its class. - callself = callself.get_class().to_variable(self.vm.root_node) + callself = callself.cls.to_variable(self.vm.root_node) return ClassMethod(self.name, self, callself, self.vm) elif self.kind == pytd.MethodTypes.PROPERTY and not is_class: return Property(self.name, self, callself, self.vm) @@ -2021,7 +2013,7 @@ def call(self, node, func, args, alias_map=None): # - An annotation has a type param that is not ambigious or empty # - The mutation adds a type that is not ambiguous or empty def should_check(value): - return not value.isinstance_AMBIGUOUS_OR_EMPTY() and value.cls + return not value.isinstance_AMBIGUOUS_OR_EMPTY() def compatible_with(new, existing, view): """Check whether a new type can be added to a container.""" @@ -2306,6 +2298,7 @@ def __init__(self, base_cls, formal_type_parameters, vm, template=None): assert isinstance(base_cls, (PyTDClass, InterpreterClass)) self.base_cls = base_cls super().__init__(base_cls.name, vm) + self._cls = None # lazily loaded 'cls' attribute self.module = base_cls.module # Lazily loaded to handle recursive types. # See the formal_type_parameters() property. @@ -2415,7 +2408,7 @@ def _load_formal_type_parameters(self): # that imports, etc., are visible. The last created node is usually the # active one. self._formal_type_parameters = ( - self.vm.annotations_util.convert_class_annotations( + self.vm.annotation_utils.convert_class_annotations( self.vm.program.cfg_nodes[-1], self._formal_type_parameters)) self._formal_type_parameters_loaded = True @@ -2425,7 +2418,7 @@ def compute_mro(self): def instantiate(self, node, container=None): if self.full_name == "builtins.type": # deformalize removes TypeVars. - instance = self.vm.annotations_util.deformalize( + instance = self.vm.annotation_utils.deformalize( self.formal_type_parameters[abstract_utils.T]) return instance.to_variable(node) elif self.full_name == "typing.ClassVar": @@ -2436,8 +2429,17 @@ def instantiate(self, node, container=None): else: return super().instantiate(node, container) - def get_class(self): - return self.base_cls.get_class() + @property + def cls(self): + if not self.vm.convert.minimally_initialized: + return self.vm.convert.unsolvable + if not self._cls: + self._cls = self.base_cls.cls + return self._cls + + @cls.setter + def cls(self, cls): + self._cls = cls def set_class(self, node, var): self.base_cls.set_class(node, var) @@ -2648,7 +2650,7 @@ def call_slot(self, node, *args, **kwargs): raise function.WrongArgTypes( function.Signature.from_callable(self), function.Args(posargs=args), self.vm, bad_param=bad_param) - ret = self.vm.annotations_util.sub_one_annotation( + ret = self.vm.annotation_utils.sub_one_annotation( node, self.formal_type_parameters[abstract_utils.RET], substs) node, retvar = self.vm.init_class(node, ret) return node, retvar @@ -2665,7 +2667,7 @@ class LiteralClass(ParameterizedClass): def __init__(self, instance, vm, template=None): base_cls = vm.convert.name_to_value("typing.Literal") - formal_type_parameters = {abstract_utils.T: instance.get_class()} + formal_type_parameters = {abstract_utils.T: instance.cls} super().__init__(base_cls, formal_type_parameters, vm, template) self._instance = instance @@ -3014,7 +3016,10 @@ def get_own_attributes(self): def get_own_abstract_methods(self): def _can_be_abstract(var): - return any((isinstance(v, Function) or v.isinstance_PropertyInstance()) + return any((isinstance(v, Function) or # pylint: disable=g-complex-comprehension + v.isinstance_PropertyInstance() or + v.isinstance_ClassMethodInstance() or + v.isinstance_StaticMethodInstance()) and v.is_abstract for v in var.data) return {name for name, var in self.members.items() if _can_be_abstract(var)} @@ -3073,7 +3078,8 @@ def bases(self): return self._bases def metaclass(self, node): - if self.cls and self.cls is not self._get_inherited_metaclass(): + if (self.cls.full_name != "builtins.type" and + self.cls is not self._get_inherited_metaclass()): return self.vm.convert.merge_classes([self]) else: return None @@ -3526,7 +3532,7 @@ def _inner_cls_check(self, last_frame): # get all type parameters from function annotations all_type_parameters = [] for annot in self.signature.annotations.values(): - params = self.vm.annotations_util.get_type_parameters(annot) + params = self.vm.annotation_utils.get_type_parameters(annot) all_type_parameters.extend(itm.with_module(None) for itm in params) if all_type_parameters: @@ -3610,7 +3616,7 @@ def call(self, node, func, args, new_locals=False, alias_map=None, frame.substs, annotation_substs) # Keep type parameters without substitutions, as they may be needed for # type-checking down the road. - annotations = self.vm.annotations_util.sub_annotations( + annotations = self.vm.annotation_utils.sub_annotations( node, sig.annotations, annotation_substs, instantiate_unbound=False) if sig.has_param_annotations: if first_arg and sig.param_names[0] == "self": @@ -3633,7 +3639,7 @@ def call(self, node, func, args, new_locals=False, alias_map=None, self.argcount(node) == 0 or name != sig.param_names[0])): extra_key = (self.get_first_opcode(), name) - node, callargs[name] = self.vm.annotations_util.init_annotation( + node, callargs[name] = self.vm.annotation_utils.init_annotation( node, name, annotations[name], container=container, extra_key=extra_key) mutations = self._mutations_generator(node, first_arg, substs) @@ -3884,7 +3890,7 @@ def call(self, node, _, args, alias_map=None): callargs = self._map_args(node, args.simplify(node, self.vm)) substs = self.match_args(node, args, alias_map) # Substitute type parameters in the signature's annotations. - annotations = self.vm.annotations_util.sub_annotations( + annotations = self.vm.annotation_utils.sub_annotations( node, self.signature.annotations, substs, instantiate_unbound=False) if self.signature.has_return_annotation: ret_type = annotations["return"] @@ -3905,6 +3911,7 @@ class BoundFunction(BaseValue): def __init__(self, callself, underlying): super().__init__(underlying.name, underlying.vm) + self.cls = underlying.cls self._callself = callself self.underlying = underlying self.is_attribute_of_class = False @@ -3916,7 +3923,8 @@ def __init__(self, callself, underlying): inst = abstract_utils.get_atomic_value( self._callself, default=self.vm.convert.unsolvable) if self._should_replace_self_annot(): - if isinstance(inst.cls, class_mixin.Class): + if (isinstance(inst.cls, class_mixin.Class) and + inst.cls.full_name != "builtins.type"): for cls in inst.cls.mro: if isinstance(cls, ParameterizedClass): base_cls = cls.base_cls @@ -3999,9 +4007,6 @@ def has_varargs(self): def has_kwargs(self): return self.underlying.has_kwargs() - def get_class(self): - return self.underlying.get_class() - @property def is_abstract(self): return self.underlying.is_abstract @@ -4200,17 +4205,15 @@ class Splat(BaseValue): def __init__(self, vm, iterable): super().__init__("splat", vm) - self.iterable = iterable - - def get_class(self): # When building a tuple for a function call, we preserve splats as elements # in a concrete tuple (e.g. f(x, *ys, z) gets called with the concrete tuple # (x, *ys, z) in starargs) and let the arg matcher in function.py unpack - # them. Constructing the tuple invokes get_class() as a side effect; ideally + # them. Constructing the tuple accesses its class as a side effect; ideally # we would specialise abstract.Tuple for function calls and not bother # constructing an associated TupleClass for a function call tuple, but for # now we just set the class to Any here. - return self.vm.convert.unsolvable + self.cls = vm.convert.unsolvable + self.iterable = iterable def __repr__(self): return "splat(%r)" % self.iterable.data @@ -4532,10 +4535,6 @@ def _make_sig(args, ret): slots=None, template=()) - def get_class(self): - # We treat instances of an Unknown as the same as the class. - return self - def instantiate(self, node, container=None): return self.to_variable(node) diff --git a/pytype/abstract_test.py b/pytype/abstract_test.py index 8f80e120b..565a59440 100644 --- a/pytype/abstract_test.py +++ b/pytype/abstract_test.py @@ -1075,7 +1075,7 @@ def test_instantiate_type_parameter_type(self): def test_super_type(self): supercls = special_builtins.Super(self._vm) - self.assertEqual(supercls.get_class(), self._vm.convert.type_type) + self.assertEqual(supercls.cls, self._vm.convert.type_type) def test_instantiate_interpreter_class(self): cls = abstract.InterpreterClass("X", [], {}, None, self._vm) @@ -1154,7 +1154,7 @@ def test_instantiate_tuple_class_for_sub(self): subst_value = cls.instantiate(self._vm.root_node, abstract_utils.DUMMY_CONTAINER) # Recover the class from the instance. - subbed_cls = self._vm.annotations_util.sub_one_annotation( + subbed_cls = self._vm.annotation_utils.sub_one_annotation( self._vm.root_node, type_param, [{ abstract_utils.K: subst_value }]) diff --git a/pytype/abstract_utils.py b/pytype/abstract_utils.py index e55bc965d..335592bf7 100644 --- a/pytype/abstract_utils.py +++ b/pytype/abstract_utils.py @@ -241,7 +241,7 @@ def get_template(val): base = get_atomic_value(base, default=val.vm.convert.unsolvable) res.update(get_template(base)) return res - elif val.cls: + elif val.cls != val: return get_template(val.cls) else: return set() @@ -329,7 +329,7 @@ def merge(t0, t1, name): formal_type_parameters.merge_from( base.base_cls.all_formal_type_parameters, merge) params = base.get_formal_type_parameters() - if getattr(container, "cls", None): + if hasattr(container, "cls"): container_template = container.cls.template else: container_template = () @@ -589,8 +589,16 @@ def check_classes(var, check): Returns: Whether the check passes. """ - return var and all( - v.cls.isinstance_Class() and check(v.cls) for v in var.data if v.cls) + if not var: + return False + for v in var.data: + if v.isinstance_Class(): + if not check(v): + return False + elif v.cls.isinstance_Class() and v.cls != v: + if not check(v.cls): + return False + return True def match_type_container(typ, container_type_name: Union[str, Tuple[str, ...]]): @@ -704,7 +712,7 @@ def is_indefinite_iterable(val: _BaseValue): """True if val is a non-concrete instance of typing.Iterable.""" instance = val.isinstance_Instance() concrete = is_concrete(val) - cls_instance = val.cls and val.cls.isinstance_Class() + cls_instance = val.cls.isinstance_Class() if not (instance and cls_instance and not concrete): return False for cls in val.cls.mro: @@ -755,7 +763,7 @@ def is_callable(value: _BaseValue): value.isinstance_StaticMethod() or value.isinstance_StaticMethodInstance()): return True - if not value.cls or not value.cls.isinstance_Class(): + if not value.cls.isinstance_Class(): return False _, attr = value.vm.attribute_handler.get_attribute( value.vm.root_node, value.cls, "__call__") diff --git a/pytype/analyze.py b/pytype/analyze.py index 5e72bcb5d..3153c9691 100644 --- a/pytype/analyze.py +++ b/pytype/analyze.py @@ -332,7 +332,7 @@ def init_class(self, node, cls, container=None, extra_key=None): def _call_method(self, node, binding, method_name): node, method = self.attribute_handler.get_attribute( - node, binding.data.get_class(), method_name, binding) + node, binding.data.cls, method_name, binding) if method: bound_method = self.bind_method( node, method, binding.AssignToNewVariable()) @@ -344,7 +344,7 @@ def _call_init_on_binding(self, node, b): for param in b.data.instance_type_parameters.values(): node = self.call_init(node, param) node = self._call_method(node, b, "__init__") - cls = b.data.get_class() + cls = b.data.cls if isinstance(cls, abstract.InterpreterClass): # Call any additional initalizers the class has registered. for method in cls.additional_init_methods: @@ -587,7 +587,7 @@ def pytd_classes_for_call_traces(self): # We don't need to record call signatures that don't involve # unknowns - there's nothing to solve for. continue - cls = args[0].data.get_class() + cls = args[0].data.cls if isinstance(cls, abstract.PyTDClass): class_to_records[cls].append(call_record) classes = [] diff --git a/pytype/annotations_util.py b/pytype/annotation_utils.py similarity index 98% rename from pytype/annotations_util.py rename to pytype/annotation_utils.py index 97f303df0..b338f5340 100644 --- a/pytype/annotations_util.py +++ b/pytype/annotation_utils.py @@ -12,7 +12,7 @@ from pytype.overlays import typing_overlay -class AnnotationsUtil(utils.VirtualMachineWeakrefMixin): +class AnnotationUtils(utils.VirtualMachineWeakrefMixin): """Utility class for inline type annotations.""" def sub_annotations(self, node, annotations, substs, instantiate_unbound): @@ -220,9 +220,8 @@ def extract_and_init_annotation(self, node, name, var): if self_var: type_params = [] for v in self_var.data: - if v.cls: - # Normalize type parameter names by dropping the scope. - type_params.extend(p.with_module(None) for p in v.cls.template) + # Normalize type parameter names by dropping the scope. + type_params.extend(p.with_module(None) for p in v.cls.template) self_substs = tuple( abstract_utils.get_type_parameter_substitutions(v, type_params) for v in self_var.data) diff --git a/pytype/attribute.py b/pytype/attribute.py index d7d36476a..4a8bb6eba 100644 --- a/pytype/attribute.py +++ b/pytype/attribute.py @@ -165,7 +165,7 @@ def set_attribute(self, node, obj, name, value): def _check_writable(self, obj, name): """Verify that a given attribute is writable. Log an error if not.""" - if obj.cls is None: + if not obj.cls.mro: # "Any" etc. return True for baseclass in obj.cls.mro: @@ -228,13 +228,14 @@ def _get_class_attribute(self, node, cls, name, valself=None): # instance, if we're analyzing int.mro(), we want to retrieve the mro # method on the type class, but for (3).mro(), we want to report that the # method does not exist.) - meta = cls.get_class() + meta = cls.cls return self._get_attribute(node, cls, meta, name, valself) def _get_instance_attribute(self, node, obj, name, valself=None): """Get an attribute from an instance.""" assert isinstance(obj, abstract.SimpleValue) - return self._get_attribute(node, obj, obj.cls, name, valself) + cls = None if obj.cls.full_name == "builtins.type" else obj.cls + return self._get_attribute(node, obj, cls, name, valself) def _get_attribute(self, node, obj, cls, name, valself): """Get an attribute from an object or its class. @@ -268,6 +269,12 @@ def _get_attribute(self, node, obj, cls, name, valself): node, obj, name, valself, skip=()) else: node, attr = self._get_member(node, obj, name, valself) + if attr is None and obj.maybe_missing_members: + # The VM hit maximum depth while initializing this instance, so it may + # have attributes that we don't know about. These attributes take + # precedence over class attributes and __getattr__, so we set `attr` to + # Any immediately. + attr = self.vm.new_unsolvable(node) if attr is None and cls: # Check for the attribute on the class. node, attr = self.get_attribute(node, cls, name, valself) @@ -288,22 +295,18 @@ def _get_attribute(self, node, obj, cls, name, valself): # reinitialize it with the current instance's parameter values. subst = abstract_utils.get_type_parameter_substitutions( valself.data, - self.vm.annotations_util.get_type_parameters(typ)) - typ = self.vm.annotations_util.sub_one_annotation( + self.vm.annotation_utils.get_type_parameters(typ)) + typ = self.vm.annotation_utils.sub_one_annotation( node, typ, [subst], instantiate_unbound=False) - _, attr = self.vm.annotations_util.init_annotation(node, name, typ) + _, attr = self.vm.annotation_utils.init_annotation(node, name, typ) elif attr is None: # An attribute has been declared but not defined, e.g., # class Foo: # bar: int - _, attr = self.vm.annotations_util.init_annotation(node, name, typ) + _, attr = self.vm.annotation_utils.init_annotation(node, name, typ) break if attr is not None: attr = self._filter_var(node, attr) - if attr is None and obj.maybe_missing_members: - # The VM hit maximum depth while initializing this instance, so it may - # have attributes that we don't know about. - attr = self.vm.new_unsolvable(node) return node, attr def _get_attribute_from_super_instance( @@ -326,8 +329,8 @@ def _get_attribute_from_super_instance( # super().__init__() # line 6 # if we're looking up super.__init__ in line 6 as part of analyzing the # super call in line 3, then starting_cls=Foo, current_cls=Bar. - if (isinstance(obj.super_obj.cls, - (type(None), abstract.AMBIGUOUS_OR_EMPTY)) or + if (obj.super_obj.cls.full_name == "builtins.type" or + isinstance(obj.super_obj.cls, abstract.AMBIGUOUS_OR_EMPTY) or isinstance(obj.super_cls, abstract.AMBIGUOUS_OR_EMPTY)): # Setting starting_cls to the current class when either of them is # ambiguous is technically incorrect but behaves correctly in the common diff --git a/pytype/class_mixin.py b/pytype/class_mixin.py index e983bc062..0939e08fd 100644 --- a/pytype/class_mixin.py +++ b/pytype/class_mixin.py @@ -106,8 +106,8 @@ def __new__(cls, *unused_args, **unused_kwds): def init_mixin(self, metaclass): """Mix-in equivalent of __init__.""" if metaclass is None: - self.cls = self._get_inherited_metaclass() - else: + metaclass = self._get_inherited_metaclass() + if metaclass: # TODO(rechen): Check that the metaclass is a (non-strict) subclass of the # metaclasses of the base classes. self.cls = metaclass @@ -239,8 +239,7 @@ def _init_abstract_methods(self): self.abstract_methods = abstract_methods def _has_explicit_abcmeta(self): - return self.cls and any( - parent.full_name == "abc.ABCMeta" for parent in self.cls.mro) + return any(parent.full_name == "abc.ABCMeta" for parent in self.cls.mro) def _has_implicit_abcmeta(self): """Whether the class should be considered implicitly abstract.""" @@ -268,11 +267,7 @@ def is_test_class(self): @property def is_enum(self): - if self.cls: - return any(cls.full_name == "enum.EnumMeta" for cls in self.cls.mro) - else: - return any(base.cls and base.cls.full_name == "enum.Enum" - for base in self.mro) + return any(cls.full_name == "enum.EnumMeta" for cls in self.cls.mro) @property def is_protocol(self): @@ -280,13 +275,14 @@ def is_protocol(self): def _get_inherited_metaclass(self): for base in self.mro[1:]: - if isinstance(base, Class) and base.cls is not None: + if (isinstance(base, Class) and base.cls != self.vm.convert.unsolvable and + base.cls.full_name != "builtins.type"): return base.cls return None def call_metaclass_init(self, node): """Call the metaclass's __init__ method if it does anything interesting.""" - if not self.cls: + if self.cls.full_name == "builtins.type": return node node, init = self.vm.attribute_handler.get_attribute( node, self.cls, "__init__") @@ -394,7 +390,7 @@ def get_special_attribute(self, node, name, valself): if name == "__getitem__" and valself is None: # See vm._call_binop_on_bindings: valself == None is a special value that # indicates an annotation. - if self.cls: + if self.cls.full_name != "builtins.type": # This class has a custom metaclass; check if it defines __getitem__. _, att = self.vm.attribute_handler.get_attribute( node, self, name, self.to_binding(node)) diff --git a/pytype/convert.py b/pytype/convert.py index 6031fafa2..af55f3a5a 100644 --- a/pytype/convert.py +++ b/pytype/convert.py @@ -66,6 +66,7 @@ def __init__(self, type_param_name): def __init__(self, vm): super().__init__(vm) + self.minimally_initialized = False self.vm.convert = self # to make constant_to_value calls below work self.pytd_convert = output.Converter(vm) @@ -83,6 +84,9 @@ def __init__(self, vm): self.object_type = self.constant_to_value(object) self.unsolvable = abstract.Unsolvable(self.vm) + self.type_type = self.constant_to_value(type) + self.minimally_initialized = True + self.empty = abstract.Empty(self.vm) self.no_return = typing_overlay.NoReturn(self.vm) @@ -133,7 +137,6 @@ def __init__(self, vm): self.set_type = self.constant_to_value(set) self.frozenset_type = self.constant_to_value(frozenset) self.dict_type = self.constant_to_value(dict) - self.type_type = self.constant_to_value(type) self.module_type = self.constant_to_value(types.ModuleType) self.function_type = self.constant_to_value(types.FunctionType) self.tuple_type = self.constant_to_value(tuple) @@ -218,6 +221,10 @@ def build_string(self, node, s): del node return self.constant_to_var(s) + def build_nonatomic_string(self, node): + s = self.primitive_class_instances[str] + return s.to_variable(node) + def build_content(self, elements): if len(elements) == 1: return next(iter(elements)) @@ -367,11 +374,7 @@ def merge_classes(self, instances): Returns: An abstract.BaseValue created by merging the instances' classes. """ - classes = set() - for v in instances: - cls = v.get_class() - if cls and cls != self.empty: - classes.add(cls) + classes = {v.cls for v in instances if v.cls != self.empty} return self.vm.merge_values(classes) def constant_to_var(self, pyval, subst=None, node=None, source_sets=None, @@ -725,7 +728,7 @@ def _constant_to_value(self, pyval, subst, get_node): raise self.TypeParameterError(c.full_name) # deformalize gets rid of any unexpected TypeVars, which can appear # if something is annotated as Type[T]. - return self.vm.annotations_util.deformalize( + return self.vm.annotation_utils.deformalize( self.merge_classes(subst[c.full_name].data)) else: return self.constant_to_value(c, subst, self.vm.root_node) diff --git a/pytype/convert_test.py b/pytype/convert_test.py index 7d94e636b..646ea1d6d 100644 --- a/pytype/convert_test.py +++ b/pytype/convert_test.py @@ -44,12 +44,12 @@ class C(B): ... self.assertEqual(meta, cls_meta) self.assertEqual(meta, subcls_meta) - def test_convert_no_metaclass(self): + def test_convert_default_metaclass(self): ast = self._load_ast("a", """ class A: ... """) cls = self._convert_class("a.A", ast) - self.assertIsNone(cls.cls) + self.assertEqual(cls.cls, self._vm.convert.type_type) def test_convert_metaclass_with_generic(self): ast = self._load_ast("a", """ diff --git a/pytype/debug.py b/pytype/debug.py index 967cb66f9..bf6c659ca 100644 --- a/pytype/debug.py +++ b/pytype/debug.py @@ -1,7 +1,10 @@ """Debugging helper functions.""" import collections +import contextlib +import inspect import io +import logging import re import traceback @@ -321,9 +324,9 @@ def stack_trace(indent_level=0, limit=100): indent = " " * indent_level stack = [frame for frame in traceback.extract_stack() if "/errors.py" not in frame[0] and "/debug.py" not in frame[0]] - trace = traceback.format_list(stack[-limit:]) - trace = [indent + re.sub(r"/usr/.*/pytype/", "", x) for x in trace] - return "\n ".join(trace) + tb = traceback.format_list(stack[-limit:]) + tb = [indent + re.sub(r"/usr/.*/pytype/", "", x) for x in tb] + return "\n ".join(tb) def _setup_tabulate(): @@ -384,3 +387,75 @@ def show_ordered_code(code, extra_col=None): tab.append([blk]) tab.append(["\n".join(block_table[start:end])]) print(tabulate.tabulate(tab, tablefmt="fancy_grid")) + + +# Tracing logger +def tracer(name=None): + name = f"trace.{name}" if name else "trace" + return logging.getLogger(name) + + +def set_trace_level(level): + logging.getLogger("trace").setLevel(level) + + +@contextlib.contextmanager +def tracing(level=logging.DEBUG): + log = logging.getLogger("trace") + current_level = log.getEffectiveLevel() + log.setLevel(level) + try: + yield + finally: + log.setLevel(current_level) + + +def trace(name, *trace_args): + """Record args and return value for a function call. + + The trace is of the form + function name: { + function name: arg = value + function name: arg = value + ... + function name: -> return + function name: } + + This will let us write tools to pretty print the traces with indentation etc. + + Args: + name: module name, usually `__name__` + *trace_args: function arguments to log + Returns: + a decorator + """ + def decorator(f): + def wrapper(*args, **kwargs): + t = tracer(name) + if t.getEffectiveLevel() < logging.DEBUG: + return f(*args, **kwargs) + argspec = inspect.getfullargspec(f) + t.debug("%s: {", f.__name__) + for arg in trace_args: + if isinstance(arg, int): + argname = argspec.args[arg] + val = args[arg] + else: + argname = arg + val = kwargs[arg] + t.debug("%s: %s = %s", f.__name__, argname, show(val)) + ret = f(*args, **kwargs) + t.debug("%s: -> %s", f.__name__, show(ret)) + t.debug("%s: }", f.__name__) + return ret + return wrapper + return decorator + + +def show(x): + """Pretty print values for debugging.""" + typename = x.__class__.__name__ + if typename == "Variable": + return f"{x!r} {x.data}" + else: + return f"{x!r} <{typename}>" diff --git a/pytype/errors.py b/pytype/errors.py index e02e0202c..2915d018e 100644 --- a/pytype/errors.py +++ b/pytype/errors.py @@ -503,7 +503,7 @@ def _print_as_expected_type(self, t: abstract.BaseValue, instance=None): elif abstract_utils.is_concrete(t): return re.sub(r"(\\n|\s)+", " ", t.str_of_constant(self._print_as_expected_type)) - elif isinstance(t, abstract.AnnotationClass) or not t.cls: + elif isinstance(t, abstract.AnnotationClass) or t.cls == t: return t.name else: return "" % self._print_as_expected_type(t.cls, t) @@ -1123,7 +1123,7 @@ def assert_type(self, stack, node, var, typ=None): # NOTE: Converting types to strings is provided as a fallback, but is not # really supported, since there are issues around name resolution. vm = typ.data[0].vm - typ = vm.annotations_util.extract_annotation( + typ = vm.annotation_utils.extract_annotation( node, typ, "assert_type", vm.simple_stack()) node, typ = vm.init_class(node, typ) wanted = [ diff --git a/pytype/function.py b/pytype/function.py index 644ddb652..30f9602fa 100644 --- a/pytype/function.py +++ b/pytype/function.py @@ -31,7 +31,7 @@ def get_signatures(func): return get_signatures(func.method) elif func.isinstance_SimpleFunction(): return [func.signature] - elif func.cls and func.cls.isinstance_CallableClass(): + elif func.cls.isinstance_CallableClass(): return [Signature.from_callable(func.cls)] else: if func.isinstance_Instance(): @@ -84,7 +84,7 @@ def __init__(self, name, param_names, varargs_name, kwonly_params, self.type_params = set() for annot in self.annotations.values(): self.type_params.update( - p.name for p in annot.vm.annotations_util.get_type_parameters(annot)) + p.name for p in annot.vm.annotation_utils.get_type_parameters(annot)) @property def has_return_annotation(self): @@ -98,7 +98,7 @@ def add_scope(self, module): """Add scope for type parameters in annotations.""" annotations = {} for key, val in self.annotations.items(): - annotations[key] = val.vm.annotations_util.add_scope( + annotations[key] = val.vm.annotation_utils.add_scope( val, self.excluded_types, module) self.annotations = annotations @@ -120,7 +120,7 @@ def check_type_parameter_count(self, stack): """Check the count of type parameters in function.""" c = collections.Counter() for annot in self.annotations.values(): - c.update(annot.vm.annotations_util.get_type_parameters(annot)) + c.update(annot.vm.annotation_utils.get_type_parameters(annot)) for param, count in c.items(): if param.name in self.excluded_types: # skip all the type parameters in `excluded_types` @@ -828,8 +828,7 @@ def _collect_mutated_parameters(cls, typ, mutated_type): if (not isinstance(typ, pytd.GenericType) or not isinstance(mutated_type, pytd.GenericType) or typ.base_type != mutated_type.base_type or - not isinstance(typ.base_type, pytd.ClassType) or - not typ.base_type.cls): + not isinstance(typ.base_type, pytd.ClassType)): raise ValueError("Unsupported mutation:\n%r ->\n%r" % (typ, mutated_type)) return [zip(mutated_type.base_type.cls.template, mutated_type.parameters)] @@ -890,7 +889,7 @@ def append_float(x: list[int]): # This is a constructor, so check whether the constructed instance needs # to be mutated. for ret in retvar.data: - if ret.cls: + if ret.cls.full_name != "builtins.type": for t in ret.cls.template: if t.full_name in subst: mutations.append(Mutation(ret, t.full_name, subst[t.full_name])) diff --git a/pytype/matcher.py b/pytype/matcher.py index 52693eba7..6bbda05b0 100644 --- a/pytype/matcher.py +++ b/pytype/matcher.py @@ -123,7 +123,7 @@ def compute_subst(self, formal_args, arg_dict, view, alias_map=None): actual = arg_dict[name] subst = self._match_value_against_type(actual, formal, subst, view) if subst is None: - formal = self.vm.annotations_util.sub_one_annotation( + formal = self.vm.annotation_utils.sub_one_annotation( self._node, formal, [self._error_subst or {}]) return None, function.BadParam( @@ -283,10 +283,10 @@ def _match_value_against_type(self, value, other_type, subst, view): # some sort of runtime processing of type annotations. We replace all type # parameters with 'object' so that they don't match concrete types like # 'int' but still match things like 'Any'. - type_params = self.vm.annotations_util.get_type_parameters(left) + type_params = self.vm.annotation_utils.get_type_parameters(left) obj_var = self.vm.convert.primitive_class_instances[object].to_variable( self._node) - left = self.vm.annotations_util.sub_one_annotation( + left = self.vm.annotation_utils.sub_one_annotation( self._node, left, [{p.full_name: obj_var for p in type_params}]) assert not left.formal, left @@ -424,7 +424,7 @@ def _match_value_against_type(self, value, other_type, subst, view): # Since options without type parameters do not modify subst, we can # break after the first match rather than finding all matches. We still # need to fill in subst with *something* so that - # annotations_util.sub_one_annotation can tell that all annotations have + # annotation_utils.sub_one_annotation can tell that all annotations have # been fully matched. subst = self._subst_with_type_parameters_from(subst, other_type) break @@ -456,7 +456,7 @@ def _match_type_against_type(self, left, other_type, subst, view): isinstance(other_type, abstract.Empty)): return subst elif isinstance(left, abstract.AMBIGUOUS_OR_EMPTY): - params = self.vm.annotations_util.get_type_parameters(other_type) + params = self.vm.annotation_utils.get_type_parameters(other_type) if isinstance(left, abstract.Empty): value = self.vm.convert.empty else: @@ -479,7 +479,7 @@ def _match_type_against_type(self, left, other_type, subst, view): elif _is_callback_protocol(other_type): return self._match_type_against_callback_protocol( left, other_type, subst, view) - elif left.cls: + else: return self._match_instance_against_type(left, other_type, subst, view) elif isinstance(left, abstract.Module): if other_type.full_name in [ @@ -517,11 +517,9 @@ def _match_type_against_type(self, left, other_type, subst, view): elif _is_callback_protocol(other_type): return self._match_type_against_callback_protocol( left, other_type, subst, view) - elif left.cls: + else: return self._match_type_against_type( abstract.Instance(left.cls, self.vm), other_type, subst, view) - else: - return None elif isinstance(left, dataclass_overlay.FieldInstance) and left.default: return self._match_all_bindings(left.default, other_type, subst, view) elif isinstance(left, abstract.SimpleValue): @@ -590,7 +588,21 @@ def _mutate_type_parameters(self, params, value, subst): return self._merge_substs(subst, [new_subst]) def _get_param_matcher(self, callable_type): - """Helper for _match_signature_against_callable.""" + """Helper for matching the parameters of a callable. + + Args: + callable_type: The callable being matched against. + + Returns: + A special param matcher: (left, right, subst) -> Optional[subst]. + left: An argument to be matched against a parameter of callable_type. + right: A parameter of callable_type. + subst: The current substitution dictionary. + If the matcher returns a non-None subst dict, then the match has succeeded + via special matching rules for single TypeVars. Otherwise, the caller + should next attempt normal matching on the inputs. (See + _match_signature_against_callable for a usage example.) + """ # Any type parameter should match an unconstrained, unbounded type parameter # that appears exactly once in a callable, in order for matching to succeed # in cases like: @@ -601,12 +613,12 @@ def _get_param_matcher(self, callable_type): # the callable must accept any argument, but here, it means that the # argument must be the same type as `x`. callable_param_count = collections.Counter( - self.vm.annotations_util.get_type_parameters(callable_type)) + self.vm.annotation_utils.get_type_parameters(callable_type)) if isinstance(callable_type, abstract.CallableClass): # In CallableClass, type parameters in arguments are double-counted # because ARGS contains the union of the individual arguments. callable_param_count.subtract( - self.vm.annotations_util.get_type_parameters( + self.vm.annotation_utils.get_type_parameters( callable_type.get_formal_type_parameter(abstract_utils.ARGS))) def match(left, right, subst): if (not isinstance(left, abstract.TypeParameter) or @@ -720,11 +732,10 @@ def _match_instance_against_type(self, left, other_type, subst, view): left, other_type.formal_type_parameters[abstract_utils.T], subst, view) elif isinstance(other_type, class_mixin.Class): - if not self._satisfies_noniterable_str(left.get_class(), other_type): - self._noniterable_str_error = NonIterableStrError(left.get_class(), - other_type) + if not self._satisfies_noniterable_str(left.cls, other_type): + self._noniterable_str_error = NonIterableStrError(left.cls, other_type) return None - base = self.match_from_mro(left.get_class(), other_type) + base = self.match_from_mro(left.cls, other_type) if base is None: if other_type.is_protocol: with self._track_partially_matched_protocols(): @@ -889,11 +900,17 @@ def _match_callable_instance(self, left, instance, other_type, subst, view): return subst if left.num_args != other_type.num_args: return None + param_match = self._get_param_matcher(other_type) for i in range(left.num_args): - # Flip actual and expected to enforce contravariance of argument types. - subst = self._instantiate_and_match( - other_type.formal_type_parameters[i], left.formal_type_parameters[i], - subst, view, container=other_type) + left_arg = left.formal_type_parameters[i] + right_arg = other_type.formal_type_parameters[i] + new_subst = param_match(left_arg, right_arg, subst) + if new_subst is None: + # Flip actual and expected to enforce contravariance of argument types. + subst = self._instantiate_and_match( + right_arg, left_arg, subst, view, container=other_type) + else: + subst = new_subst if subst is None: return None return subst @@ -905,10 +922,8 @@ def _get_attribute_names(self, left): _ = left.items() # loads all attributes into members if isinstance(left, abstract.SimpleValue): left_attributes.update(left.members) - left_cls = left.get_class() - if left_cls: - left_attributes.update(*(cls.get_own_attributes() for cls in left_cls.mro - if isinstance(cls, class_mixin.Class))) + left_attributes.update(*(cls.get_own_attributes() for cls in left.cls.mro + if isinstance(cls, class_mixin.Class))) if "__getitem__" in left_attributes and "__iter__" not in left_attributes: # If a class has a __getitem__ method, it also (implicitly) has a # __iter__: Python will emulate __iter__ by calling __getitem__ with @@ -927,18 +942,17 @@ def _match_against_protocol(self, left, other_type, subst, view): Returns: A new type parameter assignment if the matching succeeded, None otherwise. """ - left_cls = left.get_class() - if isinstance(left_cls, abstract.AMBIGUOUS_OR_EMPTY): + if isinstance(left.cls, abstract.AMBIGUOUS_OR_EMPTY): return subst - elif left_cls.is_dynamic: + elif left.cls.is_dynamic: return self._subst_with_type_parameters_from(subst, other_type) left_attributes = self._get_attribute_names(left) missing = other_type.protocol_attributes - left_attributes if missing: # not all protocol attributes are implemented by 'left' self._protocol_error = ProtocolMissingAttributesError( - left_cls, other_type, missing) + left.cls, other_type, missing) return None - key = (left_cls, other_type) + key = (left.cls, other_type) if key in self._protocol_cache: return subst self._protocol_cache.add(key) @@ -980,9 +994,7 @@ def _resolve_property_attribute(self, cls, attribute, instance): return resolved_attribute def _get_type(self, value): - cls = value.get_class() - if not cls: - return None + cls = value.cls if (not isinstance(cls, (abstract.PyTDClass, abstract.InterpreterClass)) or not cls.template): return cls @@ -1021,7 +1033,7 @@ def _get_attribute_types(self, other_type, attribute): for (param, value) in other_type.get_formal_type_parameters().items(): annotation_subst[param] = value.instantiate( self._node, abstract_utils.DUMMY_CONTAINER) - callable_signature = self.vm.annotations_util.sub_one_annotation( + callable_signature = self.vm.annotation_utils.sub_one_annotation( self._node, callable_signature, [annotation_subst]) yield callable_signature @@ -1037,9 +1049,8 @@ def _match_protocol_attribute(self, left, other_type, attribute, subst, view): Returns: A new type parameter assignment if the matching succeeded, None otherwise. """ - left_cls = left.get_class() left_attribute = self._get_attribute_for_protocol_matching( - left_cls, attribute, left) + left.cls, attribute, left) if left_attribute is None: if attribute == "__iter__": # See _get_attribute_names: left has an implicit __iter__ method @@ -1088,7 +1099,7 @@ def _match_protocol_attribute(self, left, other_type, attribute, subst, view): # protocol_attribute_var. bad_left, bad_right = zip(*bad_matches) self._protocol_error = ProtocolTypeError( - left_cls, other_type, attribute, self.vm.merge_values(bad_left), + left.cls, other_type, attribute, self.vm.merge_values(bad_left), self.vm.merge_values(bad_right)) return None return self._merge_substs(subst, new_substs) @@ -1099,16 +1110,17 @@ def _discard_ambiguous_values(self, values): # value altogether. concrete_values = [] for v in values: - if not isinstance(v, (abstract.AMBIGUOUS_OR_EMPTY, + # TODO(b/200220895): This is probably wrong; we should expand unions + # instead of ignoring them. + if not isinstance(v, (abstract.AMBIGUOUS_OR_EMPTY, abstract.Union, abstract.TypeParameterInstance)): - cls = v.get_class() - if not isinstance(cls, abstract.AMBIGUOUS_OR_EMPTY): + if not isinstance(v.cls, abstract.AMBIGUOUS_OR_EMPTY): concrete_values.append(v) return concrete_values def _satisfies_single_type(self, values): """Enforce that the variable contains only one concrete type.""" - class_names = {v.get_class().full_name for v in values} + class_names = {v.cls.full_name for v in values} for compat_name, name in _COMPATIBLE_BUILTINS: if {compat_name, name} <= class_names: class_names.remove(compat_name) @@ -1120,9 +1132,8 @@ def _satisfies_common_superclass(self, values): common_classes = None object_in_values = False for v in values: - cls = v.get_class() - object_in_values |= cls == self.vm.convert.object_type - superclasses = {c.full_name for c in cls.mro} + object_in_values |= v.cls == self.vm.convert.object_type + superclasses = {c.full_name for c in v.cls.mro} for compat_name, name in _COMPATIBLE_BUILTINS: if compat_name in superclasses: superclasses.add(name) @@ -1159,7 +1170,7 @@ def _satisfies_noniterable_str(self, left, other_type): def _subst_with_type_parameters_from(self, subst, typ): subst = subst.copy() - for param in self.vm.annotations_util.get_type_parameters(typ): + for param in self.vm.annotation_utils.get_type_parameters(typ): if param.name not in subst: subst[param.name] = self.vm.convert.empty.to_variable(self._node) return subst diff --git a/pytype/metaclass.py b/pytype/metaclass.py index 8ffa2b02f..19d13d151 100644 --- a/pytype/metaclass.py +++ b/pytype/metaclass.py @@ -66,9 +66,6 @@ def __init__(self, vm, cls, bases): class_mixin.Class.init_mixin(self, cls) self.bases = bases - def get_class(self): - return self.cls - def get_own_attributes(self): if isinstance(self.cls, class_mixin.Class): return self.cls.get_own_attributes() diff --git a/pytype/output.py b/pytype/output.py index 17e9370cd..2b230a477 100644 --- a/pytype/output.py +++ b/pytype/output.py @@ -277,19 +277,11 @@ def value_to_pytd_type(self, node, v, seen, view): # inner value rather than properly converting it. return pytd.Literal(repr(v.pyval)) elif isinstance(v, abstract.SimpleValue): - if v.cls: - ret = self.value_instance_to_pytd_type( - node, v.cls, v, seen=seen, view=view) - ret.Visit(visitors.FillInLocalPointers( - {"builtins": self.vm.loader.builtins})) - return ret - else: - # We don't know this type's __class__, so return AnythingType to - # indicate that we don't know anything about what this is. - # This happens e.g. for locals / globals, which are returned from the - # code in class declarations. - log.info("Using Any for %s", v.name) - return pytd.AnythingType() + ret = self.value_instance_to_pytd_type( + node, v.cls, v, seen=seen, view=view) + ret.Visit(visitors.FillInLocalPointers( + {"builtins": self.vm.loader.builtins})) + return ret elif isinstance(v, abstract.Union): return pytd_utils.JoinTypes(self.value_to_pytd_type(node, o, seen, view) for o in v.options) diff --git a/pytype/overlay_utils.py b/pytype/overlay_utils.py index 24248a6d3..fd72f52d5 100644 --- a/pytype/overlay_utils.py +++ b/pytype/overlay_utils.py @@ -66,13 +66,12 @@ def _process_annotation(param): if not param.typ: return elif isinstance(param.typ, cfg.Variable): - if all(t.cls for t in param.typ.data): - types = param.typ.data - if len(types) == 1: - annotations[param.name] = types[0].cls - else: - t = abstract.Union([t.cls for t in types], vm) - annotations[param.name] = t + types = param.typ.data + if len(types) == 1: + annotations[param.name] = types[0].cls + else: + t = abstract.Union([t.cls for t in types], vm) + annotations[param.name] = t else: annotations[param.name] = param.typ diff --git a/pytype/overlays/attr_overlay.py b/pytype/overlays/attr_overlay.py index 9b5a34d4e..04639cabb 100644 --- a/pytype/overlays/attr_overlay.py +++ b/pytype/overlays/attr_overlay.py @@ -229,8 +229,8 @@ def call(self, node, unused_func, args): type_source = TypeSource.TYPE allowed_type_params = ( self.vm.frame.type_params | - self.vm.annotations_util.get_callable_type_parameter_names(type_var)) - typ = self.vm.annotations_util.extract_annotation( + self.vm.annotation_utils.get_callable_type_parameter_names(type_var)) + typ = self.vm.annotation_utils.extract_annotation( node, type_var, "attr.ib", self.vm.simple_stack(), allowed_type_params=allowed_type_params) elif default_var: diff --git a/pytype/overlays/enum_overlay.py b/pytype/overlays/enum_overlay.py index 1f354f716..70edca5b6 100644 --- a/pytype/overlays/enum_overlay.py +++ b/pytype/overlays/enum_overlay.py @@ -196,7 +196,7 @@ def instantiate(self, node, container=None): # TODO(tsudol): Use the types of other members to set `value`. del container instance = abstract.Instance(self, self.vm) - instance.members["name"] = self.vm.convert.build_string(node, "") + instance.members["name"] = self.vm.convert.build_nonatomic_string(node) if self.member_type: value = self.member_type.instantiate(node) else: @@ -398,7 +398,7 @@ def _mark_dynamic_enum(self, cls): # The most typical use of custom subclasses of EnumMeta is to add more # members to the enum, or to (for example) make attribute access # case-insensitive. Treat such enums as having dynamic attributes. - if cls.cls and cls.cls.full_name != "enum.EnumMeta": + if cls.cls.full_name != "enum.EnumMeta": cls.maybe_missing_members = True return for base_var in cls.bases(): @@ -408,7 +408,7 @@ def _mark_dynamic_enum(self, cls): # Interpreter classes don't have "maybe_missing_members" set even if # they have _HAS_DYNAMIC_ATTRIBUTES. But for enums, those markers should # apply to the whole class. - if ((base.cls and base.cls.full_name != "enum.EnumMeta") or + if ((base.cls.full_name != "enum.EnumMeta") or base.maybe_missing_members or base.has_dynamic_attributes()): cls.maybe_missing_members = True return @@ -472,7 +472,6 @@ def _setup_interpreterclass(self, node, cls): cls.members["__new_member__"] = saved_new self._mark_dynamic_enum(cls) cls.members["__new__"] = self._make_new(node, member_type, cls) - cls.members["__eq__"] = EnumCmpEQ(self.vm).to_variable(node) # _generate_next_value_ is used as a static method of the enum, not a class # method. We need to rebind it here to make pytype analyze it correctly. # However, we skip this if it's already a staticmethod. @@ -529,7 +528,6 @@ def _setup_pytdclass(self, node, cls): member_type = self.vm.convert.constant_to_value( pytd_utils.JoinTypes(member_types)) cls.members["__new__"] = self._make_new(node, member_type, cls) - cls.members["__eq__"] = EnumCmpEQ(self.vm).to_variable(node) return node def call(self, node, func, args, alias_map=None): diff --git a/pytype/overlays/typing_overlay.py b/pytype/overlays/typing_overlay.py index a41a3827e..c5966ed45 100644 --- a/pytype/overlays/typing_overlay.py +++ b/pytype/overlays/typing_overlay.py @@ -101,7 +101,7 @@ def getitem_slot(self, node, slice_var): self.vm.errorlog.invalid_ellipses( self.vm.frames, inner_ellipses, args.name) else: - if args.cls and args.cls.full_name == "builtins.list": + if args.cls.full_name == "builtins.list": self.vm.errorlog.ambiguous_annotation(self.vm.frames, [args]) elif 0 not in ellipses or not isinstance(args, abstract.Unsolvable): self.vm.errorlog.invalid_annotation( @@ -145,7 +145,7 @@ def _get_constant(self, var, name, arg_type, arg_type_desc=None): def _get_annotation(self, node, var, name): with self.vm.errorlog.checkpoint() as record: - annot = self.vm.annotations_util.extract_annotation( + annot = self.vm.annotation_utils.extract_annotation( node, var, name, self.vm.simple_stack()) if record.errors: raise TypeVarError("\n".join(error.message for error in record.errors)) @@ -211,7 +211,7 @@ class Cast(abstract.PyTDFunction): def call(self, node, func, args): if args.posargs: - _, value = self.vm.annotations_util.extract_and_init_annotation( + _, value = self.vm.annotation_utils.extract_and_init_annotation( node, "typing.cast", args.posargs[0]) return node, value return super().call(node, func, args) @@ -288,8 +288,8 @@ def _getargs(self, node, args, functional): names.append(name_py_constant) if functional: allowed_type_params = ( - self.vm.annotations_util.get_callable_type_parameter_names(typ)) - annot = self.vm.annotations_util.extract_annotation( + self.vm.annotation_utils.get_callable_type_parameter_names(typ)) + annot = self.vm.annotation_utils.extract_annotation( node, typ, name_py_constant, self.vm.simple_stack(), allowed_type_params=allowed_type_params) else: @@ -484,7 +484,7 @@ def call(self, node, _, args, bases=None): self.vm.errorlog.invalid_namedtuple_arg(self.vm.frames, utils.message(e)) return node, self.vm.new_unsolvable(node) - annots = self.vm.annotations_util.convert_annotations_list( + annots = self.vm.annotation_utils.convert_annotations_list( node, zip(field_names, field_types)) field_types = [annots.get(field_name, self.vm.convert.unsolvable) for field_name in field_names] @@ -687,7 +687,7 @@ def _build_value(self, node, inner, ellipses): values = [] errors = [] for i, param in enumerate(inner): - # TODO(b/173742489): Once pytype has proper support for enums, we should + # TODO(b/173742489): Once the enum overlay is enabled, we should # stop allowing unsolvable and handle enums here. if (param == self.vm.convert.none or isinstance(param, abstract.LiteralClass) or @@ -696,6 +696,8 @@ def _build_value(self, node, inner, ellipses): elif (isinstance(param, abstract.ConcreteValue) and isinstance(param.pyval, (int, str, bytes))): value = abstract.LiteralClass(param, self.vm) + elif isinstance(param, abstract.Instance) and param.cls.is_enum: + value = abstract.LiteralClass(param, self.vm) else: if i in ellipses: invalid_param = "..." diff --git a/pytype/special_builtins.py b/pytype/special_builtins.py index 3fdc9664e..4630d9070 100644 --- a/pytype/special_builtins.py +++ b/pytype/special_builtins.py @@ -53,7 +53,7 @@ def call(self, node, func, args): # Removes TypeVars from the return value. ret = self.vm.program.NewVariable() for b in raw_ret.bindings: - value = self.vm.annotations_util.deformalize(b.data) + value = self.vm.annotation_utils.deformalize(b.data) ret.AddBinding(value, {b}, node) return node, ret @@ -255,8 +255,8 @@ def _is_instance(self, obj, class_spec): True if the object is derived from a class in the class_spec, False if it is not, and None if it is ambiguous whether obj matches class_spec. """ - cls = obj.get_class() - if (isinstance(obj, abstract.AMBIGUOUS_OR_EMPTY) or cls is None or + cls = obj.cls + if (isinstance(obj, abstract.AMBIGUOUS_OR_EMPTY) or isinstance(cls, abstract.AMBIGUOUS_OR_EMPTY)): return None return abstract_utils.check_against_mro(self.vm, cls, class_spec) @@ -376,9 +376,6 @@ def get_special_attribute(self, node, name, valself): else: return super().get_special_attribute(node, name, valself) - def get_class(self): - return self.cls - def call(self, node, _, args): self.vm.errorlog.not_callable(self.vm.frames, self) return node, self.vm.new_unsolvable(node) @@ -559,6 +556,12 @@ def call(self, node, funcv, args): raise NotImplementedError() +def _is_fn_abstract(func_var): + if func_var is None: + return False + return any(getattr(d, "is_abstract", None) for d in func_var.data) + + class PropertyInstance(abstract.SimpleValue, mixin.HasSlots): """Property instance (constructed by Property.call()).""" @@ -577,15 +580,7 @@ def __init__(self, vm, name, cls, fget=None, fset=None, fdel=None, doc=None): self.set_slot("getter", self.getter_slot) self.set_slot("setter", self.setter_slot) self.set_slot("deleter", self.deleter_slot) - self.is_abstract = any(self._is_fn_abstract(x) for x in [fget, fset, fdel]) - - def _is_fn_abstract(self, func_var): - if func_var is None: - return False - return any(getattr(d, "is_abstract", None) for d in func_var.data) - - def get_class(self): - return self.cls + self.is_abstract = any(_is_fn_abstract(x) for x in [fget, fset, fdel]) def fget_slot(self, node, obj, objtype): return self.vm.call_function(node, self.fget, function.Args((obj,))) @@ -641,9 +636,7 @@ def __init__(self, vm, cls, func): self.func = func self.cls = cls self.set_slot("__get__", self.func_slot) - - def get_class(self): - return self.cls + self.is_abstract = _is_fn_abstract(func) def func_slot(self, node, obj, objtype): return node, self.func @@ -681,9 +674,7 @@ def __init__(self, vm, cls, func): self.cls = cls self.func = func self.set_slot("__get__", self.func_slot) - - def get_class(self): - return self.cls + self.is_abstract = _is_fn_abstract(func) def func_slot(self, node, obj, objtype): results = [ClassMethodCallable(objtype, b.data) for b in self.func.bindings] diff --git a/pytype/stubs/stdlib/enum.pytd b/pytype/stubs/stdlib/enum.pytd index 6b3d416c6..7a83920e4 100644 --- a/pytype/stubs/stdlib/enum.pytd +++ b/pytype/stubs/stdlib/enum.pytd @@ -26,6 +26,7 @@ class Enum(metaclass=EnumMeta): _value_: Any name: str value: Any + def __eq__(self, other: Any) -> bool: ... def __new__(cls, value: str, names: Union[str, Iterable[str], Iterable[Tuple[str, Any]], Dict[str, Any]], module = ..., type: type = ..., start: complex = ...) -> Type[Enum]: ... def __new__(cls: Type[_T], value) -> _T: ... def __getattribute__(self, name) -> Any: ... diff --git a/pytype/tests/test_abc2.py b/pytype/tests/test_abc2.py index da0281037..f06459951 100644 --- a/pytype/tests/test_abc2.py +++ b/pytype/tests/test_abc2.py @@ -135,6 +135,24 @@ def f(x: Type[A[int]]): return x() """) + def test_abstract_classmethod(self): + self.Check(""" + import abc + class Foo(abc.ABC): + @classmethod + @abc.abstractmethod + def f(cls) -> str: ... + """) + + def test_bad_abstract_classmethod(self): + self.CheckWithErrors(""" + import abc + class Foo: + @classmethod + @abc.abstractmethod + def f(cls) -> str: ... # bad-return-type + """) + if __name__ == "__main__": test_base.main() diff --git a/pytype/tests/test_dataclasses.py b/pytype/tests/test_dataclasses.py index ec9358112..1b31a90c9 100644 --- a/pytype/tests/test_dataclasses.py +++ b/pytype/tests/test_dataclasses.py @@ -712,6 +712,31 @@ def __init__(self, x: T) -> None: x2: str """) + def test_dataclass_attribute_with_getattr(self): + # Tests that the type of the 'x' attribute is correct in Child.__init__ + # (i.e., the __getattr__ return type shouldn't be used). + self.Check(""" + import dataclasses + from typing import Dict, Sequence + + class Base: + def __init__(self, x: str): + self.x = x + def __getattr__(self, name: str) -> 'Base': + return self + + class Child(Base): + def __init__(self, x: str, children: Sequence['Child']): + super().__init__(x) + self._children: Dict[str, Child] = {} + for child in children: + self._children[child.x] = child + + @dataclasses.dataclass + class Container: + child: Child + """) + class TestPyiDataclass(test_base.BaseTest): """Tests for @dataclasses in pyi files.""" diff --git a/pytype/tests/test_enums.py b/pytype/tests/test_enums.py index c88479be4..19bc9c5f8 100644 --- a/pytype/tests/test_enums.py +++ b/pytype/tests/test_enums.py @@ -233,6 +233,16 @@ class M(enum.Enum): _ = e.M["C"] # attribute-error """, pythonpath=[d.path]) + def test_name_lookup_from_canonical(self): + # Canonical enum members should have non-atomic names. + self.Check(""" + import enum + class M(enum.Enum): + A = 1 + def get(m: M): + m = M[m.name] + """) + def test_bad_name_lookup(self): self.CheckWithErrors(""" import enum @@ -338,6 +348,7 @@ class M(enum.Enum): y = foo.M(1) """, pythonpath=[d.path]) + @test_base.skip("Stricter equality disabled due to b/195136939") def test_enum_eq(self): # Note that this test only checks __eq__'s behavior. Though enums support # comparisons using `is`, pytype doesn't check `is` the same way as __eq__. @@ -368,6 +379,7 @@ class N(enum.Enum): assert_type(c, "bool") """) + @test_base.skip("Stricter equality disabled due to b/195136939") def test_enum_pytd_eq(self): with file_utils.Tempdir() as d: d.create_file("m.pyi", """ diff --git a/pytype/tests/test_typevar2.py b/pytype/tests/test_typevar2.py index 7385f64ed..0e1643741 100644 --- a/pytype/tests/test_typevar2.py +++ b/pytype/tests/test_typevar2.py @@ -500,7 +500,7 @@ def f(cls: Type[T]) -> Type[T]: return cls """) - @test_base.skip("Requires completing TODO in annotations_util.deformalize") + @test_base.skip("Requires completing TODO in annotation_utils.deformalize") def test_type_of_typevar(self): self.Check(""" from typing import Type, TypeVar @@ -564,6 +564,28 @@ def g(x: Union[int, str]): f(g, [0, '']) """) + def test_callable_instance_against_callable(self): + self.CheckWithErrors(""" + from typing import Any, Callable, TypeVar + T1 = TypeVar('T1') + T2 = TypeVar('T2', bound=int) + + def f() -> Callable[[T2], T2]: + return __any_object__ + + # Passing f() to g is an error because g expects a callable with an + # unconstrained parameter type. + def g(x: Callable[[T1], T1]): + pass + g(f()) # wrong-arg-types + + # Passing f() to h is okay because T1 in this Callable is just being used + # to save the parameter type for h's return type. + def h(x: Callable[[T1], Any]) -> T1: + return __any_object__ + h(f()) + """) + class GenericTypeAliasTest(test_base.BaseTest): """Tests for generic type aliases ("type macros").""" diff --git a/pytype/tests/test_typing1.py b/pytype/tests/test_typing1.py index f9c3e1ee0..3f2979f90 100644 --- a/pytype/tests/test_typing1.py +++ b/pytype/tests/test_typing1.py @@ -208,7 +208,7 @@ def test_reuse_name(self): import typing from typing import Any Sequence = typing.Sequence[int] - Sequence_: Any + Sequence_: type """) def test_type_checking_local(self): diff --git a/pytype/tests/test_typing2.py b/pytype/tests/test_typing2.py index 05defffcf..dc6fe2976 100644 --- a/pytype/tests/test_typing2.py +++ b/pytype/tests/test_typing2.py @@ -900,6 +900,15 @@ def f(x: Literal["x", "y"]): f(x) """) + def test_enum(self): + # Requires the enum overlay + self.Check(""" + import enum + from typing import Literal + class M(enum.Enum): + A = 1 + x: Literal[M.A] + """) if __name__ == "__main__": test_base.main() diff --git a/pytype/vm.py b/pytype/vm.py index 4edc7490e..accd6b9e2 100644 --- a/pytype/vm.py +++ b/pytype/vm.py @@ -25,7 +25,7 @@ from pytype import abstract from pytype import abstract_utils -from pytype import annotations_util +from pytype import annotation_utils from pytype import attribute from pytype import blocks from pytype import class_mixin @@ -230,7 +230,7 @@ def __init__(self, self.program = cfg.Program() self.root_node = self.program.NewCFGNode("root") self.program.entrypoint = self.root_node - self.annotations_util = annotations_util.AnnotationsUtil(self) + self.annotation_utils = annotation_utils.AnnotationUtils(self) self.attribute_handler = attribute.AbstractAttributeHandler(self) self.loaded_overlays = {} # memoize which overlays are loaded self.convert = convert.Converter(self) @@ -491,7 +491,7 @@ def _update_excluded_types(self, node): typ = local.get_type(node, name) if typ: func.signature.excluded_types.update( - p.name for p in self.annotations_util.get_type_parameters(typ)) + p.name for p in self.annotation_utils.get_type_parameters(typ)) if local.orig: for v in local.orig.data: if isinstance(v, abstract.BoundFunction): @@ -582,7 +582,7 @@ def _process_base_class(self, node, base): # other late annotations in order to support things like: # class Foo(List["Bar"]): ... # class Bar: ... - base_val = self.annotations_util.remove_late_annotations(base_val) + base_val = self.annotation_utils.remove_late_annotations(base_val) if isinstance(base_val, abstract.Union): # Union[A,B,...] is a valid base class, but we need to flatten it into a # single base variable. @@ -618,7 +618,7 @@ def _filter_out_metaclasses(self, bases): with_metaclass = True if not meta: # Only the first metaclass gets applied. - meta = b.get_class().to_variable(self.root_node) + meta = b.cls.to_variable(self.root_node) non_meta.extend(b.bases) if not with_metaclass: non_meta.append(base) @@ -1405,7 +1405,7 @@ def _load_annotation(self, node, name): if annots: typ = annots.get_type(node, name) if typ: - _, ret = self.annotations_util.init_annotation(node, name, typ) + _, ret = self.annotation_utils.init_annotation(node, name, typ) return ret raise KeyError(name) @@ -1467,24 +1467,24 @@ def _remove_recursion(self, node, name, value): for v in value.data): return value stack = self.simple_stack() - typ = self.annotations_util.extract_annotation(node, value, name, stack) + typ = self.annotation_utils.extract_annotation(node, value, name, stack) if self.late_annotations: recursive_annots = set(self.late_annotations[name]) else: recursive_annots = set() - for late_annot in self.annotations_util.get_late_annotations(typ): + for late_annot in self.annotation_utils.get_late_annotations(typ): if late_annot in recursive_annots: self.errorlog.not_supported_yet( stack, "Recursive type annotations", details="In annotation %r on %s" % (late_annot.expr, name)) - typ = self.annotations_util.remove_late_annotations(typ) + typ = self.annotation_utils.remove_late_annotations(typ) break return typ.to_variable(node) def _apply_annotation( self, state, op, name, orig_val, annotations_dict, check_types): """Applies the type annotation, if any, associated with this object.""" - typ, value = self.annotations_util.apply_annotation( + typ, value = self.annotation_utils.apply_annotation( state.node, op, name, orig_val) if annotations_dict is not None: if annotations_dict is self.current_annotated_locals: @@ -2604,7 +2604,8 @@ def byte_STORE_ATTR(self, state, op): maybe_cls.members) if annotations_dict: annotations_dict = annotations_dict.annotated_locals - elif isinstance(maybe_cls, abstract.PyTDClass): + elif (isinstance(maybe_cls, abstract.PyTDClass) and + maybe_cls != self.convert.type_type): node, attr = self.attribute_handler.get_attribute( state.node, obj_val, name) if attr: @@ -2652,8 +2653,8 @@ def byte_STORE_SUBSCR(self, state, op): else: allowed_type_params = ( self.frame.type_params | - self.annotations_util.get_callable_type_parameter_names(val)) - typ = self.annotations_util.extract_annotation( + self.annotation_utils.get_callable_type_parameter_names(val)) + typ = self.annotation_utils.extract_annotation( state.node, val, name, self.simple_stack(), allowed_type_params=allowed_type_params) self._record_annotation(state.node, op, name, typ) @@ -2698,15 +2699,13 @@ def _get_literal_sequence(self, data): try: return tuple(self.convert.value_to_constant(data, list)) except abstract_utils.ConversionError: - if data.cls: - for base in data.cls.mro: - if isinstance(base, abstract.TupleClass) and not base.formal: - # We've found a TupleClass with concrete parameters, which means - # we're a subclass of a heterogenous tuple (usually a - # typing.NamedTuple instance). - new_data = self.merge_values( - base.instantiate(self.root_node).data) - return self._get_literal_sequence(new_data) + for base in data.cls.mro: + if isinstance(base, abstract.TupleClass) and not base.formal: + # We've found a TupleClass with concrete parameters, which means + # we're a subclass of a heterogeneous tuple (usually a + # typing.NamedTuple instance). + new_data = self.merge_values(base.instantiate(self.root_node).data) + return self._get_literal_sequence(new_data) return None def _restructure_tuple(self, state, tup, pre, post): @@ -3178,7 +3177,7 @@ def _get_extra_function_args(self, state, arg): state, pos_defaults = state.popn(num_pos_defaults) free_vars = None # Python < 3.6 does not handle closure vars here. kw_defaults = self._convert_kw_defaults(kw_defaults) - annot = self.annotations_util.convert_function_annotations( + annot = self.annotation_utils.convert_function_annotations( state.node, raw_annotations) return state, pos_defaults, kw_defaults, annot, free_vars @@ -3194,7 +3193,7 @@ def _get_extra_function_args_3_6(self, state, arg): state, packed_annot = state.pop() annot = abstract_utils.get_atomic_python_constant(packed_annot, dict) for k in annot.keys(): - annot[k] = self.annotations_util.convert_function_type_annotation( + annot[k] = self.annotation_utils.convert_function_type_annotation( k, annot[k]) if arg & loadmarshal.MAKE_FUNCTION_HAS_KW_DEFAULTS: state, packed_kw_def = state.pop() @@ -3204,7 +3203,7 @@ def _get_extra_function_args_3_6(self, state, arg): state, packed_pos_def = state.pop() pos_defaults = abstract_utils.get_atomic_python_constant( packed_pos_def, tuple) - annot = self.annotations_util.convert_annotations_list( + annot = self.annotation_utils.convert_annotations_list( state.node, annot.items()) return state, pos_defaults, kw_defaults, annot, free_vars @@ -3243,7 +3242,7 @@ def _process_function_type_comment(self, node, op, func): if args != "...": annot = args.strip() try: - self.annotations_util.eval_multi_arg_annotation( + self.annotation_utils.eval_multi_arg_annotation( node, func, annot, fake_stack) except abstract_utils.ConversionError: self.errorlog.invalid_function_type_comment( @@ -3251,7 +3250,7 @@ def _process_function_type_comment(self, node, op, func): ret = self.convert.build_string(None, return_type) func.signature.set_annotation( - "return", self.annotations_util.extract_annotation( + "return", self.annotation_utils.extract_annotation( node, ret, "return", fake_stack)) def byte_MAKE_FUNCTION(self, state, op): @@ -3459,8 +3458,8 @@ def byte_STORE_ANNOTATION(self, state, op): state, value = state.pop() allowed_type_params = ( self.frame.type_params | - self.annotations_util.get_callable_type_parameter_names(value)) - typ = self.annotations_util.extract_annotation( + self.annotation_utils.get_callable_type_parameter_names(value)) + typ = self.annotation_utils.extract_annotation( state.node, value, name, self.simple_stack(), allowed_type_params=allowed_type_params) self._record_annotation(state.node, op, name, typ) @@ -3840,7 +3839,7 @@ def _check_test_assert(self, state, func, args): if not (f.isinstance_BoundFunction() and len(f.callself.data) == 1): return state cls = f.callself.data[0].cls - if not (cls and cls.isinstance_Class() and cls.is_test_class()): + if not (cls.isinstance_Class() and cls.is_test_class()): return state if f.name == "assertIsNotNone": if len(args) == 1: