From 9d2a869365a720ab6cbef4d6d71b4dcc91f8a4db Mon Sep 17 00:00:00 2001 From: rechen Date: Mon, 1 Jun 2020 17:14:17 -0700 Subject: [PATCH 01/17] Enable checking of assignments on the same line as a PEP 526-style annotation. With this change, pytype will report an error for: x: str = 0 by default. Erroring on: x = 0 # type: str or x: str x = 0 will still require the --check-variable-types flag. PiperOrigin-RevId: 314234094 --- pytype/overlays/dataclass_overlay.py | 5 ++++- pytype/vm.py | 4 +++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/pytype/overlays/dataclass_overlay.py b/pytype/overlays/dataclass_overlay.py index 0c11330f7..d996b752a 100644 --- a/pytype/overlays/dataclass_overlay.py +++ b/pytype/overlays/dataclass_overlay.py @@ -82,7 +82,10 @@ def decorate(self, node, cls): else: init = True - if (not self.vm.options.check_variable_types or + # TODO(b/74434237): The first check can be removed once + # --check-variable-types is on by default. + if ((not self.vm.options.check_variable_types and + local.last_op.line not in self.vm.director._variable_annotations) or # pylint: disable=protected-access orig and orig.data == [self.vm.convert.none]): # vm._apply_annotation mostly takes care of checking that the default # matches the declared type. However, it allows None defaults, and diff --git a/pytype/vm.py b/pytype/vm.py index f3f1d1406..1ed1f35a9 100644 --- a/pytype/vm.py +++ b/pytype/vm.py @@ -1293,7 +1293,9 @@ def _apply_annotation(self, state, op, name, orig_val, local): # An Any annotation can be used to essentially turn off inference in # cases where it is causing false positives or other issues. value = self.new_unsolvable(state.node) - if self.options.check_variable_types: + # TODO(b/74434237): Enable --check-variable-types by default. + if (self.options.check_variable_types or + op.line in self.director._variable_annotations): # pylint: disable=protected-access self.check_annotation_type_mismatch( state.node, name, typ, orig_val, self.frames, allow_none=True) return value From 41a828ef6e3881b6590acfb234ae14847fa815cb Mon Sep 17 00:00:00 2001 From: mdemello Date: Mon, 1 Jun 2020 17:48:54 -0700 Subject: [PATCH 02/17] Check that mutating a container does not violate its annotated type. If options.check_variable_types is set, the following code should fail: a: List[int] = [] a.append("hello") Also adds some test cases in bugs/mutation.py for TODOs. PiperOrigin-RevId: 314239043 --- pytype/abstract.py | 38 ++++++++++++++++++++++++++++ pytype/annotations_util.py | 2 ++ pytype/errors.py | 15 +++++++++++ pytype/tests/py3/test_annotations.py | 25 ++++++++++++++++++ 4 files changed, 80 insertions(+) diff --git a/pytype/abstract.py b/pytype/abstract.py index e5a1b20d0..63b284830 100644 --- a/pytype/abstract.py +++ b/pytype/abstract.py @@ -74,6 +74,9 @@ def __init__(self, name, vm): self._all_template_names = None self._instance = None + # true for instances created to apply type annotations + self.from_annotation = False + @property def all_template_names(self): if self._all_template_names is None: @@ -1771,6 +1774,41 @@ def call(self, node, func, args, alias_map=None): retvar.PasteVariable(result, node) all_mutations.update(mutations) + if all_mutations and self.vm.options.check_variable_types: + # Raise an error if: + # - An annotation has a type param that is not ambigious or empty + # - The mutation adds a type that is not ambiguous or empty + # TODO(mdemello): This does not check annotations in function args. + def filter_contents(var): + # reduces the work compatible_with has to do. + return set(x for x in var.data + if not x.isinstance_AMBIGUOUS_OR_EMPTY()) + + def compatible_with(existing, new): + """Check whether a new type can be added to a container.""" + for data in existing: + if self.vm.matcher.match_from_mro(new.cls, data.cls): + return True + return False + + for obj, name, values in all_mutations: + if obj.from_annotation: + params = obj.get_instance_type_parameter(name) + ps = filter_contents(params) + if ps: + # check if the container type is being broadened. + vs = filter_contents(values) + new = [x for x in (vs - ps) if not compatible_with(ps, x)] + if new: + # TODO(mdemello): Can we get the variable name of the container + # object from the opcode traces? + # TODO(mdemello): If the same object has several violations, e.g. + # a: Dict[str, int] = {} + # a[1] = 'a' + # we will print each mutation as a separate error. + self.vm.errorlog.container_type_mismatch( + self.vm.frames, obj.cls, params, values, None) + node = abstract_utils.apply_mutations(node, all_mutations.__iter__) return node, retvar diff --git a/pytype/annotations_util.py b/pytype/annotations_util.py index 2be6d69da..4d6c6d253 100644 --- a/pytype/annotations_util.py +++ b/pytype/annotations_util.py @@ -203,6 +203,8 @@ def apply_annotation(self, state, op, name, value): typ = self.extract_annotation( state.node, var, name, self.vm.simple_stack(), is_var=True) _, value = self.vm.init_class(state.node, typ) + for d in value.data: + d.from_annotation = True return typ, value def extract_annotation(self, node, var, name, stack, is_var=False): diff --git a/pytype/errors.py b/pytype/errors.py index a61c2126c..f94d89e48 100644 --- a/pytype/errors.py +++ b/pytype/errors.py @@ -1023,6 +1023,21 @@ def annotation_type_mismatch(self, stack, annot, binding, name): err_msg = "Type annotation%s does not match type of assignment" % suffix self.error(stack, err_msg, details=details) + @_error_name("container-type-mismatch") + def container_type_mismatch(self, stack, obj, params, values, name): + """Invalid combination of annotation and mutation.""" + annot_string = self._print_as_expected_type(obj) + old_content = self._join_printed_types( + set(self._print_as_actual_type(v) for v in params.data)) + new_content = self._join_printed_types( + set(self._print_as_actual_type(v) for v in values.data)) + details = ("Annotation: %s\n" % annot_string + + "Contained type: %s\n" % old_content + + "New contained type: %s" % new_content) + suffix = "" if name is None else " for " + name + err_msg = "New container type%s does not match type annotation" % suffix + self.error(stack, err_msg, details=details) + @_error_name("invalid-function-definition") def invalid_function_definition(self, stack, msg): """Invalid function constructed via metaprogramming.""" diff --git a/pytype/tests/py3/test_annotations.py b/pytype/tests/py3/test_annotations.py index cc0a2fc82..9ca3a4cac 100644 --- a/pytype/tests/py3/test_annotations.py +++ b/pytype/tests/py3/test_annotations.py @@ -1041,6 +1041,31 @@ def test_variable_annotations(self): a: int """) + def test_container_mutation(self): + errors = self.CheckWithErrors(""" + from typing import List + x: List[int] = [] + x.append("hello") # container-type-mismatch[e] + """) + pattern = r"Annot.*List\[int\].*Contained.*int.*New.*Union\[int, str\]" + self.assertErrorRegexes(errors, {"e": pattern}) + + def test_allowed_container_mutation_subclass(self): + self.Check(""" + from typing import List + class A: pass + class B(A): pass + x: List[A] = [] + x.append(B()) + """) + + def test_allowed_container_mutation_builtins(self): + self.Check(""" + from typing import List + x: List[float] = [] + x.append(0) + """) + @test_utils.skipUnlessPy((3, 7), reason="__future__.annotations is 3.7+ and " "is the default behavior in 3.8+") def test_postponed_evaluation(self): From becf9f89704e5d5d2f63f3374ecc204841375327 Mon Sep 17 00:00:00 2001 From: mdemello Date: Tue, 2 Jun 2020 13:12:30 -0700 Subject: [PATCH 03/17] Treat objects as True in a boolean context, unless explicitly overridden. Exceptions (treated as ambiguous) are: - booleans - numeric types - classes overriding __bool__ (__nonzero__ in python2) - classes overriding __len__ Provably empty builtin containers continue to be treated as False. PiperOrigin-RevId: 314390311 --- pytype/compare.py | 12 +++- pytype/compare_test.py | 104 +++++++++++++++----------------- pytype/mixin.py | 16 +++++ pytype/tests/py3/test_splits.py | 15 +++++ pytype/tests/test_splits.py | 23 +++++++ 5 files changed, 111 insertions(+), 59 deletions(-) diff --git a/pytype/compare.py b/pytype/compare.py index 68d39915b..5c10925f7 100644 --- a/pytype/compare.py +++ b/pytype/compare.py @@ -119,14 +119,22 @@ def compatible_with(value, logical_value): elif isinstance(value, mixin.PythonConstant): return bool(value.pyval) == logical_value elif isinstance(value, abstract.Instance): - # Containers with unset parameters and NoneType instances cannot match True. name = value.full_name if logical_value and name in _CONTAINER_NAMES: - return ( + # Containers with unset parameters cannot match True. + ret = ( value.has_instance_type_parameter(abstract_utils.T) and bool(value.get_instance_type_parameter(abstract_utils.T).bindings)) + return ret elif name == "__builtin__.NoneType": + # NoneType instances cannot match True. return not logical_value + elif name in NUMERIC: + # Numeric types can match both True and False + return True + elif not value.cls.overrides_bool: + # Objects evaluate to True unless explicitly overridden. + return logical_value return True elif isinstance(value, (abstract.Function, mixin.Class)): # Functions and classes always evaluate to True. diff --git a/pytype/compare_test.py b/pytype/compare_test.py index a769981aa..32e5d8cc9 100644 --- a/pytype/compare_test.py +++ b/pytype/compare_test.py @@ -25,45 +25,56 @@ def setUp(self): self._program = self._vm.program self._node = self._vm.root_cfg_node.ConnectNew("test_node") + def assertTruthy(self, value): + self.assertIs(True, compare.compatible_with(value, True)) + self.assertIs(False, compare.compatible_with(value, False)) + + def assertFalsy(self, value): + self.assertIs(False, compare.compatible_with(value, True)) + self.assertIs(True, compare.compatible_with(value, False)) + + def assertAmbiguous(self, value): + self.assertIs(True, compare.compatible_with(value, True)) + self.assertIs(True, compare.compatible_with(value, False)) + class InstanceTest(CompareTestBase): - def test_compatible_with_non_container(self): - # Compatible with either True or False. + def test_compatible_with_object(self): + # object() is not compatible with False i = abstract.Instance(self._vm.convert.object_type, self._vm) - self.assertIs(True, compare.compatible_with(i, True)) - self.assertIs(True, compare.compatible_with(i, False)) + self.assertTruthy(i) + + def test_compatible_with_numeric(self): + # Numbers can evaluate to True or False + i = abstract.Instance(self._vm.convert.int_type, self._vm) + self.assertAmbiguous(i) def test_compatible_with_list(self): i = abstract.List([], self._vm) # Empty list is not compatible with True. - self.assertIs(False, compare.compatible_with(i, True)) - self.assertIs(True, compare.compatible_with(i, False)) + self.assertFalsy(i) # Once a type parameter is set, list is compatible with True and False. i.merge_instance_type_parameter( self._node, abstract_utils.T, self._vm.convert.object_type.to_variable(self._vm.root_cfg_node)) - self.assertIs(True, compare.compatible_with(i, True)) - self.assertIs(True, compare.compatible_with(i, False)) + self.assertAmbiguous(i) def test_compatible_with_set(self): i = abstract.Instance(self._vm.convert.set_type, self._vm) - # Empty list is not compatible with True. - self.assertIs(False, compare.compatible_with(i, True)) - self.assertIs(True, compare.compatible_with(i, False)) + # Empty set is not compatible with True. + self.assertFalsy(i) # Once a type parameter is set, list is compatible with True and False. i.merge_instance_type_parameter( self._node, abstract_utils.T, self._vm.convert.object_type.to_variable(self._vm.root_cfg_node)) - self.assertIs(True, compare.compatible_with(i, True)) - self.assertIs(True, compare.compatible_with(i, False)) + self.assertAmbiguous(i) def test_compatible_with_none(self): # This test is specifically for abstract.Instance, so we don't use # self._vm.convert.none, which is an AbstractOrConcreteValue. i = abstract.Instance(self._vm.convert.none_type, self._vm) - self.assertIs(False, compare.compatible_with(i, True)) - self.assertIs(True, compare.compatible_with(i, False)) + self.assertFalsy(i) def test_compare_frozensets(self): """Test that two frozensets can be compared for equality.""" @@ -82,13 +93,11 @@ def setUp(self): def test_compatible_with__not_empty(self): t = abstract.Tuple((self._var,), self._vm) - self.assertIs(True, compare.compatible_with(t, True)) - self.assertIs(False, compare.compatible_with(t, False)) + self.assertTruthy(t) def test_compatible_with__empty(self): t = abstract.Tuple((), self._vm) - self.assertIs(False, compare.compatible_with(t, True)) - self.assertIs(True, compare.compatible_with(t, False)) + self.assertFalsy(t) def test_getitem__concrete_index(self): t = abstract.Tuple((self._var,), self._vm) @@ -116,39 +125,33 @@ def setUp(self): self._var.AddBinding(abstract.Unknown(self._vm), [], self._node) def test_compatible_with__when_empty(self): - self.assertIs(False, compare.compatible_with(self._d, True)) - self.assertIs(True, compare.compatible_with(self._d, False)) + self.assertFalsy(self._d) def test_compatible_with__after_setitem(self): # Once a slot is added, dict is ambiguous. self._d.setitem_slot(self._node, self._var, self._var) - self.assertIs(True, compare.compatible_with(self._d, True)) - self.assertIs(True, compare.compatible_with(self._d, False)) + self.assertAmbiguous(self._d) def test_compatible_with__after_set_str_item(self): self._d.set_str_item(self._node, "key", self._var) - self.assertIs(True, compare.compatible_with(self._d, True)) - self.assertIs(False, compare.compatible_with(self._d, False)) + self.assertTruthy(self._d) def test_compatible_with__after_unknown_update(self): # Updating an empty dict with an unknown value makes the former ambiguous. self._d.update(self._node, abstract.Unknown(self._vm)) - self.assertIs(True, compare.compatible_with(self._d, True)) - self.assertIs(True, compare.compatible_with(self._d, False)) + self.assertAmbiguous(self._d) def test_compatible_with__after_empty_update(self): empty_dict = abstract.Dict(self._vm) self._d.update(self._node, empty_dict) - self.assertIs(False, compare.compatible_with(self._d, True)) - self.assertIs(True, compare.compatible_with(self._d, False)) + self.assertFalsy(self._d) def test_compatible_with__after_unambiguous_update(self): unambiguous_dict = abstract.Dict(self._vm) unambiguous_dict.set_str_item( self._node, "a", self._vm.new_unsolvable(self._node)) self._d.update(self._node, unambiguous_dict) - self.assertIs(True, compare.compatible_with(self._d, True)) - self.assertIs(False, compare.compatible_with(self._d, False)) + self.assertTruthy(self._d) def test_compatible_with__after_ambiguous_update(self): ambiguous_dict = abstract.Dict(self._vm) @@ -156,23 +159,19 @@ def test_compatible_with__after_ambiguous_update(self): self._node, abstract_utils.K, self._vm.new_unsolvable(self._node)) ambiguous_dict.could_contain_anything = True self._d.update(self._node, ambiguous_dict) - self.assertIs(True, compare.compatible_with(self._d, True)) - self.assertIs(True, compare.compatible_with(self._d, False)) + self.assertAmbiguous(self._d) def test_compatible_with__after_concrete_update(self): self._d.update(self._node, {}) - self.assertIs(False, compare.compatible_with(self._d, True)) - self.assertIs(True, compare.compatible_with(self._d, False)) + self.assertFalsy(self._d) self._d.update(self._node, {"a": self._vm.new_unsolvable(self._node)}) - self.assertIs(True, compare.compatible_with(self._d, True)) - self.assertIs(False, compare.compatible_with(self._d, False)) + self.assertTruthy(self._d) def test_pop(self): self._d.set_str_item(self._node, "a", self._var) node, ret = self._d.pop_slot( self._node, self._vm.convert.build_string(self._node, "a")) - self.assertIs(False, compare.compatible_with(self._d, True)) - self.assertIs(True, compare.compatible_with(self._d, False)) + self.assertFalsy(self._d) self.assertIs(node, self._node) self.assertIs(ret, self._var) @@ -181,8 +180,7 @@ def test_pop_with_default(self): node, ret = self._d.pop_slot( self._node, self._vm.convert.build_string(self._node, "a"), self._vm.convert.none.to_variable(self._node)) # default is ignored - self.assertIs(False, compare.compatible_with(self._d, True)) - self.assertIs(True, compare.compatible_with(self._d, False)) + self.assertFalsy(self._d) self.assertIs(node, self._node) self.assertIs(ret, self._var) @@ -190,8 +188,7 @@ def test_bad_pop(self): self._d.set_str_item(self._node, "a", self._var) self.assertRaises(function.DictKeyMissing, self._d.pop_slot, self._node, self._vm.convert.build_string(self._node, "b")) - self.assertIs(True, compare.compatible_with(self._d, True)) - self.assertIs(False, compare.compatible_with(self._d, False)) + self.assertTruthy(self._d) def test_bad_pop_with_default(self): val = self._vm.convert.primitive_class_instances[int] @@ -199,8 +196,7 @@ def test_bad_pop_with_default(self): node, ret = self._d.pop_slot( self._node, self._vm.convert.build_string(self._node, "b"), self._vm.convert.none.to_variable(self._node)) - self.assertIs(True, compare.compatible_with(self._d, True)) - self.assertIs(False, compare.compatible_with(self._d, False)) + self.assertTruthy(self._d) self.assertIs(node, self._node) self.assertListEqual(ret.data, [self._vm.convert.none]) @@ -210,8 +206,7 @@ def test_ambiguous_pop(self): ambiguous_key = self._vm.convert.primitive_class_instances[str] node, ret = self._d.pop_slot( self._node, ambiguous_key.to_variable(self._node)) - self.assertIs(True, compare.compatible_with(self._d, True)) - self.assertIs(True, compare.compatible_with(self._d, False)) + self.assertAmbiguous(self._d) self.assertIs(node, self._node) self.assertListEqual(ret.data, [val]) @@ -222,8 +217,7 @@ def test_ambiguous_pop_with_default(self): default_var = self._vm.convert.none.to_variable(self._node) node, ret = self._d.pop_slot( self._node, ambiguous_key.to_variable(self._node), default_var) - self.assertIs(True, compare.compatible_with(self._d, True)) - self.assertIs(True, compare.compatible_with(self._d, False)) + self.assertAmbiguous(self._d) self.assertIs(node, self._node) self.assertSetEqual(set(ret.data), {val, self._vm.convert.none}) @@ -234,8 +228,7 @@ def test_ambiguous_dict_after_pop(self): self._node, ambiguous_key.to_variable(self._node), val.to_variable(self._node)) _, ret = self._d.pop_slot(node, self._vm.convert.build_string(node, "a")) - self.assertIs(True, compare.compatible_with(self._d, True)) - self.assertIs(True, compare.compatible_with(self._d, False)) + self.assertAmbiguous(self._d) self.assertListEqual(ret.data, [val]) def test_ambiguous_dict_after_pop_with_default(self): @@ -246,8 +239,7 @@ def test_ambiguous_dict_after_pop_with_default(self): val.to_variable(self._node)) _, ret = self._d.pop_slot(node, self._vm.convert.build_string(node, "a"), self._vm.convert.none.to_variable(node)) - self.assertIs(True, compare.compatible_with(self._d, True)) - self.assertIs(True, compare.compatible_with(self._d, False)) + self.assertAmbiguous(self._d) self.assertSetEqual(set(ret.data), {val, self._vm.convert.none}) @@ -257,16 +249,14 @@ def test_compatible_with(self): pytd_sig = pytd.Signature((), None, None, pytd.AnythingType(), (), ()) sig = function.PyTDSignature("f", pytd_sig, self._vm) f = abstract.PyTDFunction("f", (sig,), pytd.METHOD, self._vm) - self.assertIs(True, compare.compatible_with(f, True)) - self.assertIs(False, compare.compatible_with(f, False)) + self.assertTruthy(f) class ClassTest(CompareTestBase): def test_compatible_with(self): cls = abstract.InterpreterClass("X", [], {}, None, self._vm) - self.assertIs(True, compare.compatible_with(cls, True)) - self.assertIs(False, compare.compatible_with(cls, False)) + self.assertTruthy(cls) if __name__ == "__main__": diff --git a/pytype/mixin.py b/pytype/mixin.py index 841da8127..d3a9bd998 100644 --- a/pytype/mixin.py +++ b/pytype/mixin.py @@ -158,6 +158,7 @@ def init_mixin(self, metaclass): self._instance_cache = {} self._init_abstract_methods() self._init_protocol_methods() + self._init_overrides_bool() self._all_formal_type_parameters = datatypes.AliasingMonitorDict() self._all_formal_type_parameters_loaded = False @@ -232,6 +233,21 @@ def _init_protocol_methods(self): protocol_methods = {m for m in protocol_methods if m not in cls} self.protocol_methods = protocol_methods + def _init_overrides_bool(self): + """Compute and cache whether the class sets its own boolean value.""" + # A class's instances can evaluate to False if it defines __bool__ or + # __len__. Python2 used __nonzero__ rather than __bool__. + bool_override = "__bool__" if self.vm.PY3 else "__nonzero__" + if self.isinstance_ParameterizedClass(): + self.overrides_bool = self.base_cls.overrides_bool + return + for cls in self.mro: + if isinstance(cls, Class): + if any(x in cls.get_own_methods() for x in (bool_override, "__len__")): + self.overrides_bool = True + return + self.overrides_bool = False + def get_own_abstract_methods(self): """Get the abstract methods defined by this class.""" raise NotImplementedError(self.__class__.__name__) diff --git a/pytype/tests/py3/test_splits.py b/pytype/tests/py3/test_splits.py index 43c084e08..4e3ea7055 100644 --- a/pytype/tests/py3/test_splits.py +++ b/pytype/tests/py3/test_splits.py @@ -256,5 +256,20 @@ def f(x: Optional[Union[str, bytes]]): return x.upper() """) + def test_override_bool(self): + ty = self.Infer(""" + class A: + def __bool__(self): + return __random__ + + x = A() and True + """) + self.assertTypesMatchPytd(ty, """ + from typing import Union + class A: + def __bool__(self) -> bool: ... + x: Union[A, bool] + """) + test_base.main(globals(), __name__ == "__main__") diff --git a/pytype/tests/test_splits.py b/pytype/tests/test_splits.py index 0c71fdeff..a0df3f1ad 100644 --- a/pytype/tests/test_splits.py +++ b/pytype/tests/test_splits.py @@ -794,5 +794,28 @@ def h(): return f(object).values() """) + def test_object_truthiness(self): + ty = self.Infer(""" + x = object() and True + """) + self.assertTypesMatchPytd(ty, """ + x: bool + """) + + def test_override_len(self): + ty = self.Infer(""" + class A: + def __len__(self): + return 42 + + x = A() and True + """) + self.assertTypesMatchPytd(ty, """ + from typing import Union + class A: + def __len__(self) -> int: ... + x: Union[A, bool] + """) + test_base.main(globals(), __name__ == "__main__") From 52b641d67c531d57d7e05a7e419740d6e7ae2891 Mon Sep 17 00:00:00 2001 From: mdemello Date: Tue, 2 Jun 2020 16:03:51 -0700 Subject: [PATCH 04/17] Group multiple container type errors for a single object into one error. New output format example: File "third_party/py/pytype/bugs/mutation.py", line 19, in : New container type does not match type annotation [container-type-mismatch] Annotation: Dict[str, str] (type parameters Dict[_K, _V]) Contained types: _K: str _V: str New contained types: _K: Union[int, str] _V: Union[float, str] PiperOrigin-RevId: 314422487 --- pytype/abstract.py | 14 +++++---- pytype/errors.py | 45 +++++++++++++++++++++------- pytype/tests/py3/test_annotations.py | 11 +++++++ 3 files changed, 54 insertions(+), 16 deletions(-) diff --git a/pytype/abstract.py b/pytype/abstract.py index 63b284830..e22139f5c 100644 --- a/pytype/abstract.py +++ b/pytype/abstract.py @@ -1791,6 +1791,8 @@ def compatible_with(existing, new): return True return False + errors = collections.defaultdict(dict) + for obj, name, values in all_mutations: if obj.from_annotation: params = obj.get_instance_type_parameter(name) @@ -1802,12 +1804,12 @@ def compatible_with(existing, new): if new: # TODO(mdemello): Can we get the variable name of the container # object from the opcode traces? - # TODO(mdemello): If the same object has several violations, e.g. - # a: Dict[str, int] = {} - # a[1] = 'a' - # we will print each mutation as a separate error. - self.vm.errorlog.container_type_mismatch( - self.vm.frames, obj.cls, params, values, None) + formal = name.split(".")[-1] + errors[obj][formal] = (params, values) + + for obj, errs in errors.items(): + self.vm.errorlog.container_type_mismatch( + self.vm.frames, obj, errs, None) node = abstract_utils.apply_mutations(node, all_mutations.__iter__) return node, retvar diff --git a/pytype/errors.py b/pytype/errors.py index f94d89e48..7073605eb 100644 --- a/pytype/errors.py +++ b/pytype/errors.py @@ -510,6 +510,14 @@ def _print_as_actual_type(self, t): with t.vm.convert.pytd_convert.produce_detailed_output(): return self._pytd_print(t.to_type()) + def _print_as_generic_type(self, t): + generic = pytd_utils.MakeClassOrContainerType( + t.get_instance_type().base_type, + t.formal_type_parameters.keys(), + False) + with t.vm.convert.pytd_convert.produce_detailed_output(): + return self._pytd_print(generic) + def _print_as_return_type(self, t): ret = self._pytd_print(t) # typing.NoReturn is a prettier alias for nothing. @@ -1024,16 +1032,33 @@ def annotation_type_mismatch(self, stack, annot, binding, name): self.error(stack, err_msg, details=details) @_error_name("container-type-mismatch") - def container_type_mismatch(self, stack, obj, params, values, name): - """Invalid combination of annotation and mutation.""" - annot_string = self._print_as_expected_type(obj) - old_content = self._join_printed_types( - set(self._print_as_actual_type(v) for v in params.data)) - new_content = self._join_printed_types( - set(self._print_as_actual_type(v) for v in values.data)) - details = ("Annotation: %s\n" % annot_string + - "Contained type: %s\n" % old_content + - "New contained type: %s" % new_content) + def container_type_mismatch(self, stack, obj, mutations, name): + """Invalid combination of annotation and mutation. + + Args: + stack: the frame stack + obj: the container instance being mutated + mutations: a dict of {parameter name: (annotated types, new types)} + name: the variable name (or None) + """ + cls = obj.cls + annot_string = "%s (type parameters %s)" % ( + self._print_as_expected_type(cls), + self._print_as_generic_type(cls)) + details = "Annotation: %s\n" % annot_string + contained = "" + new_contained = "" + for formal in cls.formal_type_parameters.keys(): + if formal in mutations: + params, values = mutations[formal] + old_content = self._join_printed_types( + set(self._print_as_actual_type(v) for v in params.data)) + new_content = self._join_printed_types( + set(self._print_as_actual_type(v) for v in values.data)) + contained += " %s: %s\n" % (formal, old_content) + new_contained += " %s: %s\n" % (formal, new_content) + details += ("Contained types:\n" + contained + + "New contained types:\n" + new_contained) suffix = "" if name is None else " for " + name err_msg = "New container type%s does not match type annotation" % suffix self.error(stack, err_msg, details=details) diff --git a/pytype/tests/py3/test_annotations.py b/pytype/tests/py3/test_annotations.py index 9ca3a4cac..eccc1478b 100644 --- a/pytype/tests/py3/test_annotations.py +++ b/pytype/tests/py3/test_annotations.py @@ -1050,6 +1050,17 @@ def test_container_mutation(self): pattern = r"Annot.*List\[int\].*Contained.*int.*New.*Union\[int, str\]" self.assertErrorRegexes(errors, {"e": pattern}) + def test_container_multiple_mutations(self): + errors = self.CheckWithErrors(""" + from typing import Dict + x: Dict[int, str] = {} + x["hello"] = 1.0 # container-type-mismatch[e] + """) + pattern = (r"Annot.*Dict\[int, str\].*Dict\[_K, _V\].*" + + r"Contained.*_K.*int.*_V.*str.*" + r"New.*_K.*Union\[int, str\].*_V.*Union\[float, str\]") + self.assertErrorRegexes(errors, {"e": pattern}) + def test_allowed_container_mutation_subclass(self): self.Check(""" from typing import List From 770c2b5cdaccbadeed06a5bbc071bbb354a28ac0 Mon Sep 17 00:00:00 2001 From: rechen Date: Wed, 3 Jun 2020 00:38:15 -0700 Subject: [PATCH 05/17] Type-check class attributes against their annotations. Adds a new option, --check-attribute-types, that will make pytype type-check attribute assignments against annotations, e.g., class Foo: x: int def foo(self): self.x = 'hello, world' # annotation-type-mismatch I've left a couple of features for future CLs: * Making annotations take precedence over instance attribute values when generating pyi files. * Handling attribute type comments in __new__ and __init__. Also puts container mutation checking behind its own flag, --check-container-types, rather than reusing --check-variable-types. Having this many flags is rather ugly and redundant, but they'll go away as we release various things, and they make the release process a bit easier. PiperOrigin-RevId: 314483072 --- pytype/abstract.py | 6 +-- pytype/config.py | 18 ++++++-- pytype/tests/py3/test_attributes.py | 9 ++++ pytype/tests/test_attributes.py | 28 +++++++++++++ pytype/tests/test_base.py | 2 + pytype/tools/analyze_project/config.py | 4 ++ pytype/vm.py | 58 ++++++++++++++++++++------ 7 files changed, 105 insertions(+), 20 deletions(-) diff --git a/pytype/abstract.py b/pytype/abstract.py index e22139f5c..6bd34361e 100644 --- a/pytype/abstract.py +++ b/pytype/abstract.py @@ -1130,9 +1130,9 @@ def update(self, node, other_dict, omit=()): class AnnotationsDict(Dict): """__annotations__ dict.""" - def __init__(self, vm): + def __init__(self, annotated_locals, vm): super().__init__(vm) - self.annotated_locals = vm.current_annotated_locals + self.annotated_locals = annotated_locals def get_type(self, node, name): if name not in self.annotated_locals: @@ -1774,7 +1774,7 @@ def call(self, node, func, args, alias_map=None): retvar.PasteVariable(result, node) all_mutations.update(mutations) - if all_mutations and self.vm.options.check_variable_types: + if all_mutations and self.vm.options.check_container_types: # Raise an error if: # - An annotation has a type param that is not ambigious or empty # - The mutation adds a type that is not ambiguous or empty diff --git a/pytype/config.py b/pytype/config.py index aeec0f03c..72b00ad99 100644 --- a/pytype/config.py +++ b/pytype/config.py @@ -134,14 +134,24 @@ def add_basic_options(o): "--strict-import", action="store_true", dest="strict_import", default=False, help="Experimental: Only load submodules that are explicitly imported.") - o.add_argument( - "--check-variable-types", action="store_true", - dest="check_variable_types", default=False, - help="Experimental: Check variable values against their annotations.") o.add_argument( "--precise-return", action="store_true", dest="precise_return", default=False, help=("Experimental: Infer precise return types even for " "invalid function calls.")) + temporary = ("This flag is temporary and will be removed once this behavior " + "is enabled by default.") + o.add_argument( + "--check-attribute-types", action="store_true", + dest="check_attribute_types", default=False, + help="Check attribute values against their annotations. " + temporary) + o.add_argument( + "--check-container-types", action="store_true", + dest="check_container_types", default=False, + help="Check container mutations against their annotations. " + temporary) + o.add_argument( + "--check-variable-types", action="store_true", + dest="check_variable_types", default=False, + help="Check variable values against their annotations. " + temporary) def add_subtools(o): diff --git a/pytype/tests/py3/test_attributes.py b/pytype/tests/py3/test_attributes.py index 8d6548344..9429864a8 100644 --- a/pytype/tests/py3/test_attributes.py +++ b/pytype/tests/py3/test_attributes.py @@ -172,5 +172,14 @@ def __getitem__(self, x): pass """) + def test_check_variable_annotation(self): + errors = self.CheckWithErrors(""" + class Foo: + x: int + def foo(self): + self.x = 'hello, world' # annotation-type-mismatch[e] + """) + self.assertErrorRegexes(errors, {"e": r"Annotation: int.*Assignment: str"}) + test_base.main(globals(), __name__ == "__main__") diff --git a/pytype/tests/test_attributes.py b/pytype/tests/test_attributes.py index 1228ea5bd..e5566f630 100644 --- a/pytype/tests/test_attributes.py +++ b/pytype/tests/test_attributes.py @@ -789,5 +789,33 @@ def f(): def f() -> Optional[str]: ... """) + def test_bad_instance_assignment(self): + errors = self.CheckWithErrors(""" + class Foo: + x = None # type: int + def foo(self): + self.x = 'hello, world' # annotation-type-mismatch[e] + """) + self.assertErrorRegexes(errors, {"e": r"Annotation: int.*Assignment: str"}) + + def test_bad_cls_assignment(self): + errors = self.CheckWithErrors(""" + class Foo: + x = None # type: int + Foo.x = 'hello, world' # annotation-type-mismatch[e] + """) + self.assertErrorRegexes(errors, {"e": r"Annotation: int.*Assignment: str"}) + + def test_any_annotation(self): + self.Check(""" + from typing import Any + class Foo: + x = None # type: Any + def foo(self): + print(self.x.some_attr) + self.x = 0 + print(self.x.some_attr) + """) + test_base.main(globals(), __name__ == "__main__") diff --git a/pytype/tests/test_base.py b/pytype/tests/test_base.py index a9284b6d4..ba373d975 100644 --- a/pytype/tests/test_base.py +++ b/pytype/tests/test_base.py @@ -157,6 +157,8 @@ def t(name): # pylint: disable=invalid-name def setUp(self): super(BaseTest, self).setUp() self.options = config.Options.create(python_version=self.python_version, + check_attribute_types=True, + check_container_types=True, check_variable_types=True) @property diff --git a/pytype/tools/analyze_project/config.py b/pytype/tools/analyze_project/config.py index 8656c1488..4c20799f6 100644 --- a/pytype/tools/analyze_project/config.py +++ b/pytype/tools/analyze_project/config.py @@ -62,6 +62,10 @@ # The missing fields will be filled in by generate_sample_config_or_die. _PYTYPE_SINGLE_ITEMS = { + 'check_attribute_types': Item( + None, 'False', ArgInfo('--check-attribute-types', None), None), + 'check_container_types': Item( + None, 'False', ArgInfo('--check-container-types', None), None), 'check_variable_types': Item( None, 'False', ArgInfo('--check-variable-types', None), None), 'disable': Item( diff --git a/pytype/vm.py b/pytype/vm.py index 1ed1f35a9..910316500 100644 --- a/pytype/vm.py +++ b/pytype/vm.py @@ -635,6 +635,15 @@ def make_class(self, node, name_var, bases, class_dict_var, cls_var, # pylint: disable=g-long-ternary cls = abstract_utils.get_atomic_value( cls_var, default=self.convert.unsolvable) if cls_var else None + if ("__annotations__" not in class_dict.members and + name in self.annotated_locals): + # Stores type comments in an __annotations__ member as if they were + # PEP 526-style variable annotations, so that we can type-check + # attribute assignments. + annotations_dict = self.annotated_locals[name] + if any(local.typ for local in annotations_dict.values()): + class_dict.members["__annotations__"] = abstract.AnnotationsDict( + annotations_dict, self).to_variable(node) try: val = abstract.InterpreterClass( name, @@ -1281,21 +1290,21 @@ def _check_aliased_type_params(self, value): self.errorlog.not_supported_yet( self.frames, "aliases of Unions with type parameters") - def _apply_annotation(self, state, op, name, orig_val, local): + def _apply_annotation( + self, state, op, name, orig_val, annotations_dict, check_types): """Applies the type annotation, if any, associated with this object.""" typ, value = self.annotations_util.apply_annotation( state, op, name, orig_val) - if local: - self._record_local(state.node, op, name, typ, orig_val) - if typ is None and name in self.current_annotated_locals: - typ = self.current_annotated_locals[name].get_type(state.node, name) + if annotations_dict is not None: + if annotations_dict is self.current_annotated_locals: + self._record_local(state.node, op, name, typ, orig_val) + if typ is None and name in annotations_dict: + typ = annotations_dict[name].get_type(state.node, name) if typ == self.convert.unsolvable: # An Any annotation can be used to essentially turn off inference in # cases where it is causing false positives or other issues. value = self.new_unsolvable(state.node) - # TODO(b/74434237): Enable --check-variable-types by default. - if (self.options.check_variable_types or - op.line in self.director._variable_annotations): # pylint: disable=protected-access + if check_types: self.check_annotation_type_mismatch( state.node, name, typ, orig_val, self.frames, allow_none=True) return value @@ -1329,7 +1338,12 @@ def check_annotation_type_mismatch( def _pop_and_store(self, state, op, name, local): """Pop a value off the stack and store it in a variable.""" state, orig_val = state.pop() - value = self._apply_annotation(state, op, name, orig_val, local) + annotations_dict = self.current_annotated_locals if local else None + # TODO(b/74434237): Enable --check-variable-types by default. + check_types = (self.options.check_variable_types or + op.line in self.director._variable_annotations) # pylint: disable=protected-access + value = self._apply_annotation( + state, op, name, orig_val, annotations_dict, check_types) self._check_aliased_type_params(value) state = state.forward_cfg_node() state = self._store_value(state, name, value, local) @@ -1946,7 +1960,8 @@ def byte_STORE_DEREF(self, state, op): state, value = state.pop() assert isinstance(value, cfg.Variable) name = self.get_closure_var_name(op.arg) - value = self._apply_annotation(state, op, name, value, True) + value = self._apply_annotation( + state, op, name, value, self.current_annotated_locals, check_types=True) state = state.forward_cfg_node() self.frame.cells[op.arg].PasteVariable(value, state.node) state = state.forward_cfg_node() @@ -2156,8 +2171,24 @@ def byte_STORE_ATTR(self, state, op): """Store an attribute.""" name = self.frame.f_code.co_names[op.arg] state, (val, obj) = state.popn(2) - # We do not want to record attributes as local variables. - val = self._apply_annotation(state, op, name, val, False) + # If `obj` is a single InterpreterClass or an instance of one, then grab its + # __annotations__ dict so we can type-check the new attribute value. + try: + maybe_cls = abstract_utils.get_atomic_value(obj) + except abstract_utils.ConversionError: + annotations_dict = None + else: + if not isinstance(maybe_cls, abstract.InterpreterClass): + maybe_cls = maybe_cls.cls + if isinstance(maybe_cls, abstract.InterpreterClass): + annotations_dict = abstract_utils.get_annotations_dict( + maybe_cls.members) + if annotations_dict: + annotations_dict = annotations_dict.annotated_locals + else: + annotations_dict = None + val = self._apply_annotation(state, op, name, val, annotations_dict, + self.options.check_attribute_types) state = state.forward_cfg_node() state = self.store_attr(state, obj, name, val) state = state.forward_cfg_node() @@ -2960,7 +2991,8 @@ def byte_DELETE_SLICE_3(self, state, op): def byte_SETUP_ANNOTATIONS(self, state, op): """Sets up variable annotations in locals().""" - annotations = abstract.AnnotationsDict(self).to_variable(state.node) + annotations = abstract.AnnotationsDict( + self.current_annotated_locals, self).to_variable(state.node) return self.store_local(state, "__annotations__", annotations) def _record_annotation(self, node, op, name, typ): From 7c209d107b445638a154e476812727e5c54381e0 Mon Sep 17 00:00:00 2001 From: mdemello Date: Thu, 4 Jun 2020 11:06:24 -0700 Subject: [PATCH 06/17] Check for function signature annotations when mutating an argument. A corner case where an alias can point to two separate annotated args via different code paths has been left as a TODO PiperOrigin-RevId: 314763912 --- pytype/abstract.py | 2 ++ pytype/tests/py3/test_annotations.py | 7 +++++-- pytype/tests/py3/test_functions.py | 8 ++++---- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/pytype/abstract.py b/pytype/abstract.py index 6bd34361e..15cd856af 100644 --- a/pytype/abstract.py +++ b/pytype/abstract.py @@ -3179,6 +3179,8 @@ def call(self, node, func, args, new_locals=False, alias_map=None): extra_key = (self.get_first_opcode(), name) node, callargs[name] = self.vm.init_class( node, annotations[name], extra_key=extra_key) + for d in callargs[name].data: + d.from_annotation = True try: frame = self.vm.make_frame( node, self.code, self.f_globals, self.f_locals, callargs, diff --git a/pytype/tests/py3/test_annotations.py b/pytype/tests/py3/test_annotations.py index eccc1478b..74f33673c 100644 --- a/pytype/tests/py3/test_annotations.py +++ b/pytype/tests/py3/test_annotations.py @@ -90,6 +90,9 @@ def foo(x: int): self.assertErrorRegexes(errors, {"e": r"upper.*int"}) def test_list(self): + # TODO(mdemello): Do not check variables with bindings from multiple + # annotations. + self.options.tweak(check_container_types=False) ty = self.Infer(""" from typing import List @@ -819,10 +822,10 @@ def __init__(self) -> None: ... """) def test_change_annotated_arg(self): - ty = self.Infer(""" + ty, _ = self.InferWithErrors(""" from typing import Dict def f(x: Dict[str, str]): - x[True] = 42 + x[True] = 42 # container-type-mismatch return x v = f({"a": "b"}) """, deep=False) diff --git a/pytype/tests/py3/test_functions.py b/pytype/tests/py3/test_functions.py index 3c238699f..26c3da727 100644 --- a/pytype/tests/py3/test_functions.py +++ b/pytype/tests/py3/test_functions.py @@ -206,10 +206,10 @@ def __init__(self, x: int): """) def test_argument_name_conflict(self): - ty = self.Infer(""" + ty, _ = self.InferWithErrors(""" from typing import Dict def f(x: Dict[str, int]): - x[""] = "" + x[""] = "" # container-type-mismatch return x def g(x: Dict[str, int]): return x @@ -221,10 +221,10 @@ def g(x: Dict[str, int]) -> Dict[str, int] """) def test_argument_type_conflict(self): - ty = self.Infer(""" + ty, _ = self.InferWithErrors(""" from typing import Dict def f(x: Dict[str, int], y: Dict[str, int]): - x[""] = "" + x[""] = "" # container-type-mismatch return x, y """) self.assertTypesMatchPytd(ty, """ From d4d0252f2e05298fcb247070c0abc761221a77b0 Mon Sep 17 00:00:00 2001 From: mdemello Date: Thu, 4 Jun 2020 13:21:19 -0700 Subject: [PATCH 07/17] Preserve the variable name to which a type annotation was applied. Used in error messages when reporting container type annotation mismatches. PiperOrigin-RevId: 314791344 --- pytype/abstract.py | 18 ++++++++++-------- pytype/annotations_util.py | 2 +- pytype/errors.py | 2 +- pytype/tests/py3/test_annotations.py | 4 ++-- 4 files changed, 14 insertions(+), 12 deletions(-) diff --git a/pytype/abstract.py b/pytype/abstract.py index 15cd856af..8d38b1381 100644 --- a/pytype/abstract.py +++ b/pytype/abstract.py @@ -74,8 +74,11 @@ def __init__(self, name, vm): self._all_template_names = None self._instance = None - # true for instances created to apply type annotations - self.from_annotation = False + # The variable or function arg name with the type annotation that this + # instance was created from. For example, + # x: str = "hello" + # would create an instance of str with from_annotation = 'x' + self.from_annotation = None @property def all_template_names(self): @@ -1778,7 +1781,6 @@ def call(self, node, func, args, alias_map=None): # Raise an error if: # - An annotation has a type param that is not ambigious or empty # - The mutation adds a type that is not ambiguous or empty - # TODO(mdemello): This does not check annotations in function args. def filter_contents(var): # reduces the work compatible_with has to do. return set(x for x in var.data @@ -1802,14 +1804,14 @@ def compatible_with(existing, new): vs = filter_contents(values) new = [x for x in (vs - ps) if not compatible_with(ps, x)] if new: - # TODO(mdemello): Can we get the variable name of the container - # object from the opcode traces? formal = name.split(".")[-1] - errors[obj][formal] = (params, values) + errors[obj][formal] = (params, values, obj.from_annotation) for obj, errs in errors.items(): + names = {name for _, _, name in errs.values()} + name = list(names)[0] if len(names) == 1 else None self.vm.errorlog.container_type_mismatch( - self.vm.frames, obj, errs, None) + self.vm.frames, obj, errs, name) node = abstract_utils.apply_mutations(node, all_mutations.__iter__) return node, retvar @@ -3180,7 +3182,7 @@ def call(self, node, func, args, new_locals=False, alias_map=None): node, callargs[name] = self.vm.init_class( node, annotations[name], extra_key=extra_key) for d in callargs[name].data: - d.from_annotation = True + d.from_annotation = name try: frame = self.vm.make_frame( node, self.code, self.f_globals, self.f_locals, callargs, diff --git a/pytype/annotations_util.py b/pytype/annotations_util.py index 4d6c6d253..3c0c82fec 100644 --- a/pytype/annotations_util.py +++ b/pytype/annotations_util.py @@ -204,7 +204,7 @@ def apply_annotation(self, state, op, name, value): state.node, var, name, self.vm.simple_stack(), is_var=True) _, value = self.vm.init_class(state.node, typ) for d in value.data: - d.from_annotation = True + d.from_annotation = name return typ, value def extract_annotation(self, node, var, name, stack, is_var=False): diff --git a/pytype/errors.py b/pytype/errors.py index 7073605eb..15199a309 100644 --- a/pytype/errors.py +++ b/pytype/errors.py @@ -1050,7 +1050,7 @@ def container_type_mismatch(self, stack, obj, mutations, name): new_contained = "" for formal in cls.formal_type_parameters.keys(): if formal in mutations: - params, values = mutations[formal] + params, values, _ = mutations[formal] old_content = self._join_printed_types( set(self._print_as_actual_type(v) for v in params.data)) new_content = self._join_printed_types( diff --git a/pytype/tests/py3/test_annotations.py b/pytype/tests/py3/test_annotations.py index 74f33673c..f9440670b 100644 --- a/pytype/tests/py3/test_annotations.py +++ b/pytype/tests/py3/test_annotations.py @@ -825,7 +825,7 @@ def test_change_annotated_arg(self): ty, _ = self.InferWithErrors(""" from typing import Dict def f(x: Dict[str, str]): - x[True] = 42 # container-type-mismatch + x[True] = 42 # container-type-mismatch[e] return x v = f({"a": "b"}) """, deep=False) @@ -1059,7 +1059,7 @@ def test_container_multiple_mutations(self): x: Dict[int, str] = {} x["hello"] = 1.0 # container-type-mismatch[e] """) - pattern = (r"Annot.*Dict\[int, str\].*Dict\[_K, _V\].*" + + pattern = (r"New container.*for x.*Dict\[int, str\].*Dict\[_K, _V\].*" + r"Contained.*_K.*int.*_V.*str.*" r"New.*_K.*Union\[int, str\].*_V.*Union\[float, str\]") self.assertErrorRegexes(errors, {"e": pattern}) From 6456c2f89848da7c07a9ebb0fbdf6d678463a841 Mon Sep 17 00:00:00 2001 From: rechen Date: Thu, 4 Jun 2020 15:22:45 -0700 Subject: [PATCH 08/17] Fix a crash caused by overrides_bool not existing on non-class objects. Showed up when attempting to analyze //apps/intelligence/cody/tools/codylab:testit. PiperOrigin-RevId: 314814112 --- pytype/compare.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytype/compare.py b/pytype/compare.py index 5c10925f7..f554f07fa 100644 --- a/pytype/compare.py +++ b/pytype/compare.py @@ -132,7 +132,7 @@ def compatible_with(value, logical_value): elif name in NUMERIC: # Numeric types can match both True and False return True - elif not value.cls.overrides_bool: + elif isinstance(value.cls, mixin.Class) and not value.cls.overrides_bool: # Objects evaluate to True unless explicitly overridden. return logical_value return True From 53d0670dd4fdc53ad39b1fa4dca31a984979976b Mon Sep 17 00:00:00 2001 From: rechen Date: Mon, 8 Jun 2020 12:15:12 -0700 Subject: [PATCH 09/17] Fix: bool(Iterable[X]) can be either True or False. PiperOrigin-RevId: 315326106 --- pytype/compare.py | 4 ++++ pytype/tests/py3/test_splits.py | 27 +++++++++++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/pytype/compare.py b/pytype/compare.py index f554f07fa..d7ee7f561 100644 --- a/pytype/compare.py +++ b/pytype/compare.py @@ -133,6 +133,10 @@ def compatible_with(value, logical_value): # Numeric types can match both True and False return True elif isinstance(value.cls, mixin.Class) and not value.cls.overrides_bool: + if getattr(value.cls, "template", None): + # A parameterized class can match both True and False, since it might be + # an empty container. + return True # Objects evaluate to True unless explicitly overridden. return logical_value return True diff --git a/pytype/tests/py3/test_splits.py b/pytype/tests/py3/test_splits.py index 4e3ea7055..7e0ff5302 100644 --- a/pytype/tests/py3/test_splits.py +++ b/pytype/tests/py3/test_splits.py @@ -212,6 +212,33 @@ def nested() -> None: print(arg.upper()) """) + def test_iterable_truthiness(self): + ty = self.Infer(""" + from typing import Iterable + def f(x: Iterable[int]): + return 0 if x else '' + """) + self.assertTypesMatchPytd(ty, """ + from typing import Iterable, Union + def f(x: Iterable[int]) -> Union[int, str]: ... + """) + + def test_custom_container_truthiness(self): + ty = self.Infer(""" + from typing import Iterable, TypeVar + T = TypeVar('T') + class MyIterable(Iterable[T]): + pass + def f(x: MyIterable[int]): + return 0 if x else '' + """) + self.assertTypesMatchPytd(ty, """ + from typing import Iterable, TypeVar, Union + T = TypeVar('T') + class MyIterable(Iterable[T]): ... + def f(x: MyIterable[int]) -> Union[int, str]: ... + """) + class SplitTestPy3(test_base.TargetPython3FeatureTest): """Tests for if-splitting in Python 3.""" From 44da154d69b8888559131f0330bdd86ae3f12c40 Mon Sep 17 00:00:00 2001 From: mdemello Date: Tue, 9 Jun 2020 12:39:16 -0700 Subject: [PATCH 10/17] If cls is the class argument of Foo.__new__, treat `cls is Foo` as ambiguous. Allows for the fact that __new__ might be called from a subclass of Foo, and that checking that we are not in the base class is a common idiom. PiperOrigin-RevId: 315538618 --- pytype/tests/test_cmp.py | 16 ++++++++++++++++ pytype/vm.py | 21 +++++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/pytype/tests/test_cmp.py b/pytype/tests/test_cmp.py index d096182a3..f59adaf6d 100644 --- a/pytype/tests/test_cmp.py +++ b/pytype/tests/test_cmp.py @@ -139,6 +139,22 @@ def f(x, y): """, show_library_calls=True) self.assertOnlyHasReturnType(ty.Lookup("f"), self.bool) + def test_class_new(self): + # The assert should not block inference of the return type, since cls could + # be a subclass of Foo + ty = self.Infer(""" + class Foo(object): + def __new__(cls, *args, **kwargs): + assert(cls is not Foo) + return object.__new__(cls) + """) + self.assertTypesMatchPytd(ty, """ + from typing import Type, TypeVar + _TFoo = TypeVar('_TFoo', bound=Foo) + class Foo: + def __new__(cls: Type[_TFoo], *args, **kwargs) -> _TFoo: ... + """) + class LtTest(test_base.TargetIndependentTest): """Test for "x < y". Also test overloading.""" diff --git a/pytype/vm.py b/pytype/vm.py index 910316500..bc5dbddb0 100644 --- a/pytype/vm.py +++ b/pytype/vm.py @@ -1587,7 +1587,28 @@ def unary_operator(self, state, name): state = state.push(result) return state + def _is_classmethod_cls_arg(self, var): + """True if var is the first arg of a class method in the current frame.""" + if not (self.frame.func and self.frame.first_posarg): + return False + + func = self.frame.func.data + # TODO(b/158525984): right now the only classmethod we infer a bound cls + # type for is cls.__new__ + if func.name.rsplit(".")[-1] == "__new__": + is_cls = not set(var.data) - set(self.frame.first_posarg.data) + return is_cls + return False + def expand_bool_result(self, node, left, right, name, maybe_predicate): + """Common functionality for 'is' and 'is not'.""" + if (self._is_classmethod_cls_arg(left) or + self._is_classmethod_cls_arg(right)): + # If cls is the first argument of a classmethod, it could be bound to + # either the defining class or one of its subclasses, so `is` is + # ambiguous. + return self.new_unsolvable(node) + result = self.program.NewVariable() for x in left.bindings: for y in right.bindings: From d4263e4b22d5e276805f2c07652b7acea8fb784e Mon Sep 17 00:00:00 2001 From: mdemello Date: Tue, 9 Jun 2020 14:45:47 -0700 Subject: [PATCH 11/17] FIX: Test that we are in a function when we add an attr to the callgraph. PiperOrigin-RevId: 315564137 --- pytype/tools/xref/callgraph.py | 4 ++++ pytype/tools/xref/callgraph_test.py | 27 +++++++++++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/pytype/tools/xref/callgraph.py b/pytype/tools/xref/callgraph.py index 126341427..05905aad7 100644 --- a/pytype/tools/xref/callgraph.py +++ b/pytype/tools/xref/callgraph.py @@ -133,6 +133,10 @@ def add_attr(self, ref, defn): """Add an attr access within a function body.""" attrib = ref.name scope = ref.ref_scope + if scope not in self.fmap: + # This call was not within a function body. + return + try: d = self.index.envs[scope].env[ref.target] except KeyError: diff --git a/pytype/tools/xref/callgraph_test.py b/pytype/tools/xref/callgraph_test.py index 8851a8b42..83808a099 100644 --- a/pytype/tools/xref/callgraph_test.py +++ b/pytype/tools/xref/callgraph_test.py @@ -160,5 +160,32 @@ def h(b): f = fns["module.%s" % fn] self.assertParamsEqual(f.params, params) + def test_toplevel_calls(self): + """Don't index calls outside a function.""" + ix = self.index_code(""" + def f(x: int): + return "hello" + + a = f(10) + a.upcase() + """) + fns = ix.function_map + # we should only have f in fns, despite function calls at module scope + self.assertHasFunctions(fns, ["f"]) + + def test_class_level_calls(self): + """Don't index calls outside a function.""" + ix = self.index_code(""" + def f(x: int): + return "hello" + + class A: + a = f(10) + b = a.upcase() + """) + fns = ix.function_map + # we should only have f in fns, despite function calls at class scope + self.assertHasFunctions(fns, ["f"]) + test_base.main(globals(), __name__ == "__main__") From 4027dcfeac50498246cf36d2979253001c484244 Mon Sep 17 00:00:00 2001 From: slebedev Date: Thu, 11 Jun 2020 13:26:22 -0700 Subject: [PATCH 12/17] Renamed testdata/import.py->testdata/imports.py import is a keyword and therefore is not a valid module name in Python. PiperOrigin-RevId: 315965683 --- pytype/tools/xref/testdata/{import.py => imports.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename pytype/tools/xref/testdata/{import.py => imports.py} (100%) diff --git a/pytype/tools/xref/testdata/import.py b/pytype/tools/xref/testdata/imports.py similarity index 100% rename from pytype/tools/xref/testdata/import.py rename to pytype/tools/xref/testdata/imports.py From 03e281f2051dd4a4ef1c0cbbba90d967d97a0582 Mon Sep 17 00:00:00 2001 From: mdemello Date: Thu, 11 Jun 2020 13:42:23 -0700 Subject: [PATCH 13/17] Populate the `cls` arg in classmethods with the class type. PiperOrigin-RevId: 315968839 --- pytype/abstract.py | 5 +++++ pytype/analyze.py | 34 ++++++++++++++++++++++++---------- pytype/special_builtins.py | 2 ++ pytype/tests/test_classes.py | 30 ++++++++++++++++++++++++++++++ pytype/tests/test_cmp.py | 20 +++++++++++++++++++- pytype/vm.py | 4 +--- 6 files changed, 81 insertions(+), 14 deletions(-) diff --git a/pytype/abstract.py b/pytype/abstract.py index 8d38b1381..02e85d8d6 100644 --- a/pytype/abstract.py +++ b/pytype/abstract.py @@ -1492,6 +1492,7 @@ def __init__(self, name, vm): super(Function, self).__init__(name, vm) self.cls = FunctionPyTDClass(self, vm) self.is_attribute_of_class = False + self.is_classmethod = False self.is_abstract = False self.members["func_name"] = self.vm.convert.build_string( self.vm.root_cfg_node, name) @@ -3488,6 +3489,10 @@ def is_abstract(self): def is_abstract(self, value): self.underlying.is_abstract = value + @property + def is_classmethod(self): + return self.underlying.is_classmethod + def repr_names(self, callself_repr=None): """Names to use in the bound function's string representation. diff --git a/pytype/analyze.py b/pytype/analyze.py index 5e99b3bc0..a1c653ea9 100644 --- a/pytype/analyze.py +++ b/pytype/analyze.py @@ -12,6 +12,7 @@ from pytype import function from pytype import metrics from pytype import output +from pytype import special_builtins from pytype import state as frame_state from pytype import vm from pytype.overlays import typing_overlay @@ -138,7 +139,17 @@ def call_function_in_frame(self, node, var, args, kwargs, self.pop_frame(frame) return state.node, ret - def maybe_analyze_method(self, node, val): + def _maybe_fix_classmethod_cls_arg(self, node, cls, func, args): + sig = func.signature + if (args.posargs and sig.param_names and + (sig.param_names[0] not in sig.annotations)): + # fix "cls" parameter + return args._replace( + posargs=(cls.AssignToNewVariable(node),) + args.posargs[1:]) + else: + return args + + def maybe_analyze_method(self, node, val, cls=None): method = val.data fname = val.data.name if isinstance(method, abstract.INTERPRETER_FUNCTION_TYPES): @@ -150,6 +161,8 @@ def maybe_analyze_method(self, node, val): else: for f in method.iter_signature_functions(): node, args = self.create_method_arguments(node, f) + if f.is_classmethod and cls: + args = self._maybe_fix_classmethod_cls_arg(node, cls, f, args) node, _ = self.call_function_with_args(node, val, args) return node @@ -191,18 +204,23 @@ def _call_with_fake_args(self, node0, funcv): log.info("Unable to generate fake arguments for %s", funcv) return node, self.new_unsolvable(node) - def analyze_method_var(self, node0, name, var): + def analyze_method_var(self, node0, name, var, cls=None): log.info("Analyzing %s", name) node1 = node0.ConnectNew(name) for val in var.bindings: - node2 = self.maybe_analyze_method(node1, val) + node2 = self.maybe_analyze_method(node1, val, cls) node2.ConnectTo(node0) return node0 def bind_method(self, node, name, methodvar, instance_var): bound = self.program.NewVariable() for m in methodvar.Data(node): - bound.AddBinding(m.property_get(instance_var), [], node) + if isinstance(m, special_builtins.ClassMethodInstance): + m = m.func.data[0] + is_cls = True + else: + is_cls = (m.isinstance_InterpreterFunction() and m.is_classmethod) + bound.AddBinding(m.property_get(instance_var, is_cls), [], node) return bound def _instantiate_binding(self, node0, cls): @@ -218,11 +236,7 @@ def _instantiate_binding(self, node0, cls): for b in new.bindings: self._analyzed_functions.add(b.data.get_first_opcode()) node2, args = self.create_method_arguments(node1, b.data) - if args.posargs and ( - b.data.signature.param_names[0] not in b.data.signature.annotations): - # fix "cls" parameter - args = args._replace( - posargs=(cls.AssignToNewVariable(node0),) + args.posargs[1:]) + args = self._maybe_fix_classmethod_cls_arg(node0, cls, b.data, args) node3 = node2.ConnectNew() node4, ret = self.call_function_with_args(node3, b, args) instance.PasteVariable(ret) @@ -345,7 +359,7 @@ def analyze_class(self, node, val): if name in self._CONSTRUCTORS: continue # We already called this method during initialization. b = self.bind_method(node, name, methodvar, instance) - node = self.analyze_method_var(node, name, b) + node = self.analyze_method_var(node, name, b, val) return node def analyze_function(self, node0, val): diff --git a/pytype/special_builtins.py b/pytype/special_builtins.py index bf9196c2e..4540fa7aa 100644 --- a/pytype/special_builtins.py +++ b/pytype/special_builtins.py @@ -797,4 +797,6 @@ def call(self, node, funcv, args): if len(args.posargs) != 1: raise function.WrongArgCount(self._SIGNATURE, args, self.vm) arg = args.posargs[0] + for d in arg.data: + d.is_classmethod = True return node, ClassMethodInstance(self.vm, self, arg).to_variable(node) diff --git a/pytype/tests/test_classes.py b/pytype/tests/test_classes.py index 019d72ff2..f5f2171c9 100644 --- a/pytype/tests/test_classes.py +++ b/pytype/tests/test_classes.py @@ -165,6 +165,36 @@ class Foo(object): def bar(cls) -> None: ... """) + def test_factory_classmethod(self): + ty = self.Infer(""" + class Foo(object): + @classmethod + def factory(cls, *args, **kwargs): + return object.__new__(cls) + """) + self.assertTypesMatchPytd(ty, """ + from typing import Type, TypeVar + _TFoo = TypeVar('_TFoo', bound=Foo) + class Foo: + @classmethod + def factory(cls: Type[_TFoo], *args, **kwargs) -> _TFoo: ... + """) + + def test_classmethod_return_inference(self): + ty = self.Infer(""" + class Foo(object): + A = 10 + @classmethod + def bar(cls): + return cls.A + """) + self.assertTypesMatchPytd(ty, """ + class Foo(object): + A: int + @classmethod + def bar(cls) -> int: ... + """) + def test_inherit_from_unknown_attributes(self): ty = self.Infer(""" class Foo(__any_object__): diff --git a/pytype/tests/test_cmp.py b/pytype/tests/test_cmp.py index f59adaf6d..229bc45e3 100644 --- a/pytype/tests/test_cmp.py +++ b/pytype/tests/test_cmp.py @@ -152,7 +152,25 @@ def __new__(cls, *args, **kwargs): from typing import Type, TypeVar _TFoo = TypeVar('_TFoo', bound=Foo) class Foo: - def __new__(cls: Type[_TFoo], *args, **kwargs) -> _TFoo: ... + def __new__(cls: Type[_TFoo], *args, **kwargs) -> _TFoo: ... + """) + + def test_class_factory(self): + # The assert should not block inference of the return type, since cls could + # be a subclass of Foo + ty = self.Infer(""" + class Foo(object): + @classmethod + def factory(cls, *args, **kwargs): + assert(cls is not Foo) + return object.__new__(cls) + """) + self.assertTypesMatchPytd(ty, """ + from typing import Type, TypeVar + _TFoo = TypeVar('_TFoo', bound=Foo) + class Foo: + @classmethod + def factory(cls: Type[_TFoo], *args, **kwargs) -> _TFoo: ... """) diff --git a/pytype/vm.py b/pytype/vm.py index bc5dbddb0..eda114d88 100644 --- a/pytype/vm.py +++ b/pytype/vm.py @@ -1593,9 +1593,7 @@ def _is_classmethod_cls_arg(self, var): return False func = self.frame.func.data - # TODO(b/158525984): right now the only classmethod we infer a bound cls - # type for is cls.__new__ - if func.name.rsplit(".")[-1] == "__new__": + if func.is_classmethod or func.name.rsplit(".")[-1] == "__new__": is_cls = not set(var.data) - set(self.frame.first_posarg.data) return is_cls return False From 4acdac9f997397743a9f90cde1aeab93417c435f Mon Sep 17 00:00:00 2001 From: slebedev Date: Fri, 12 Jun 2020 10:24:16 -0700 Subject: [PATCH 14/17] Fixed Import/ImportFrom location matching Prior to this change _get_match_location produced false-positive match when the name being matched occurred as a prefix of some other imported name, e.g. import foo as f PiperOrigin-RevId: 316128150 --- pytype/tools/traces/traces.py | 8 +++----- pytype/tools/traces/traces_test.py | 6 +++--- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/pytype/tools/traces/traces.py b/pytype/tools/traces/traces.py index f3b7c69fb..38e0de811 100644 --- a/pytype/tools/traces/traces.py +++ b/pytype/tools/traces/traces.py @@ -297,11 +297,9 @@ def _get_match_location(self, node, name=None): return loc if isinstance(node, (self._ast.Import, self._ast.ImportFrom)): # Search for imported module names - text = self.source.line(node.lineno) - c = text.find(" " + name) - if c == -1: - c = text.find("," + name) - if c != -1: + m = re.search("[ ,]" + name + r"\b", self.source.line(node.lineno)) + if m is not None: + c, _ = m.span() return source.Location(node.lineno, c + 1) elif isinstance(node, self._ast.Attribute): attr_loc, _ = self.source.get_attr_location(name, loc) diff --git a/pytype/tools/traces/traces_test.py b/pytype/tools/traces/traces_test.py index 72bb5e012..fa489a61b 100644 --- a/pytype/tools/traces/traces_test.py +++ b/pytype/tools/traces/traces_test.py @@ -124,10 +124,10 @@ def test_import(self): def test_import_from(self): matches = self._get_traces( - "from os import path as _path, environ", ast.ImportFrom) + "from os import path as p, environ", ast.ImportFrom) self.assertTracesEqual(matches, [ - ((1, 23), "STORE_NAME", "_path", ("module",)), - ((1, 30), "STORE_NAME", "environ", ("os._Environ[str]",))]) + ((1, 23), "STORE_NAME", "p", ("module",)), + ((1, 26), "STORE_NAME", "environ", ("os._Environ[str]",))]) class MatchAttributeTest(MatchAstTestCase): From c62829d9e158d455395a8d08af84c130f3abfdf4 Mon Sep 17 00:00:00 2001 From: mdemello Date: Fri, 12 Jun 2020 11:16:12 -0700 Subject: [PATCH 15/17] Basic support for flax dataclasses. PiperOrigin-RevId: 316139475 --- pytype/CMakeLists.txt | 1 + pytype/overlay_dict.py | 2 ++ pytype/overlays/dataclass_overlay.py | 4 ++-- pytype/overlays/flax_overlay.py | 34 ++++++++++++++++++++++++++++ pytype/tests/py3/test_dataclasses.py | 27 ++++++++++++++++++++++ 5 files changed, 66 insertions(+), 2 deletions(-) create mode 100644 pytype/overlays/flax_overlay.py diff --git a/pytype/CMakeLists.txt b/pytype/CMakeLists.txt index 6c2a55e05..bd2a80be3 100644 --- a/pytype/CMakeLists.txt +++ b/pytype/CMakeLists.txt @@ -174,6 +174,7 @@ py_library( overlays/classgen.py overlays/collections_overlay.py overlays/dataclass_overlay.py + overlays/flax_overlay.py overlays/future_overlay.py overlays/six_overlay.py overlays/subprocess_overlay.py diff --git a/pytype/overlay_dict.py b/pytype/overlay_dict.py index 876fa80c4..1d8c58ca1 100644 --- a/pytype/overlay_dict.py +++ b/pytype/overlay_dict.py @@ -13,6 +13,7 @@ from pytype.overlays import attr_overlay from pytype.overlays import collections_overlay from pytype.overlays import dataclass_overlay +from pytype.overlays import flax_overlay from pytype.overlays import future_overlay from pytype.overlays import six_overlay from pytype.overlays import subprocess_overlay @@ -27,6 +28,7 @@ "attr": attr_overlay.AttrOverlay, "collections": collections_overlay.CollectionsOverlay, "dataclasses": dataclass_overlay.DataclassOverlay, + "flax.struct": flax_overlay.DataclassOverlay, "future.utils": future_overlay.FutureUtilsOverlay, "six": six_overlay.SixOverlay, "subprocess": subprocess_overlay.SubprocessOverlay, diff --git a/pytype/overlays/dataclass_overlay.py b/pytype/overlays/dataclass_overlay.py index d996b752a..4597e4aa1 100644 --- a/pytype/overlays/dataclass_overlay.py +++ b/pytype/overlays/dataclass_overlay.py @@ -34,8 +34,8 @@ class Dataclass(classgen.Decorator): """Implements the @dataclass decorator.""" @classmethod - def make(cls, name, vm): - return super(Dataclass, cls).make(name, vm, "dataclasses") + def make(cls, name, vm, mod="dataclasses"): + return super(Dataclass, cls).make(name, vm, mod) def _handle_initvar(self, node, cls, name, typ, orig): """Unpack or delete an initvar in the class annotations.""" diff --git a/pytype/overlays/flax_overlay.py b/pytype/overlays/flax_overlay.py new file mode 100644 index 000000000..a88c19267 --- /dev/null +++ b/pytype/overlays/flax_overlay.py @@ -0,0 +1,34 @@ +"""Support for flax.struct dataclasses.""" + +# Flax is a high-performance neural network library for JAX +# see //third_party/py/flax +# +# Since flax.struct.dataclass uses dataclass.dataclass internally, we can simply +# reuse the dataclass overlay with some subclassed constructors to change the +# module name. +# +# NOTE: flax.struct.dataclasses set frozen=True, but since we don't support +# frozen anyway we needn't bother about that for now. + + +from pytype import overlay +from pytype.overlays import dataclass_overlay + + +class DataclassOverlay(overlay.Overlay): + """A custom overlay for the 'flax.struct' module.""" + + def __init__(self, vm): + member_map = { + "dataclass": Dataclass.make, + } + ast = vm.loader.import_name("flax.struct") + super(DataclassOverlay, self).__init__(vm, "flax.struct", member_map, ast) + + +class Dataclass(dataclass_overlay.Dataclass): + """Implements the @dataclass decorator.""" + + @classmethod + def make(cls, name, vm): + return super(Dataclass, cls).make(name, vm, "flax.struct") diff --git a/pytype/tests/py3/test_dataclasses.py b/pytype/tests/py3/test_dataclasses.py index a5682c80b..a3cbfdee0 100644 --- a/pytype/tests/py3/test_dataclasses.py +++ b/pytype/tests/py3/test_dataclasses.py @@ -1,6 +1,7 @@ # Lint as: python3 """Tests for the dataclasses overlay.""" +from pytype import file_utils from pytype.tests import test_base @@ -438,4 +439,30 @@ class NHNetConfig: """) +class TestFlaxDataclass(test_base.TargetPython3FeatureTest): + """Tests for flax.struct.dataclass.""" + + def test_basic(self): + with file_utils.Tempdir() as d: + d.create_file("flax/struct.pyi", """ + from typing import Type + def dataclass(_cls: Type[_T]) -> Type[_T]: ... + """) + ty = self.Infer(""" + import flax + @flax.struct.dataclass + class Foo(object): + x: bool + y: int + z: str + """, pythonpath=[d.path], module_name="foo") + self.assertTypesMatchPytd(ty, """ + flax: module + class Foo(object): + x: bool + y: int + z: str + def __init__(self, x: bool, y: int, z: str) -> None: ... + """) + test_base.main(globals(), __name__ == "__main__") From f8f80789702b9af660a0db8821bc59201bcf7e1d Mon Sep 17 00:00:00 2001 From: rechen Date: Fri, 12 Jun 2020 13:48:43 -0700 Subject: [PATCH 16/17] Enable more of --check-variable-types. With this change, v: int = 0 v = '' will also be an error. I already cleaned up all new findings introduced by this change as of about a week ago. PiperOrigin-RevId: 316170169 --- pytype/overlays/dataclass_overlay.py | 2 +- pytype/vm.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytype/overlays/dataclass_overlay.py b/pytype/overlays/dataclass_overlay.py index 4597e4aa1..6bc7be6dd 100644 --- a/pytype/overlays/dataclass_overlay.py +++ b/pytype/overlays/dataclass_overlay.py @@ -85,7 +85,7 @@ def decorate(self, node, cls): # TODO(b/74434237): The first check can be removed once # --check-variable-types is on by default. if ((not self.vm.options.check_variable_types and - local.last_op.line not in self.vm.director._variable_annotations) or # pylint: disable=protected-access + local.last_op.line in self.vm.director.type_comments) or orig and orig.data == [self.vm.convert.none]): # vm._apply_annotation mostly takes care of checking that the default # matches the declared type. However, it allows None defaults, and diff --git a/pytype/vm.py b/pytype/vm.py index eda114d88..eaaf4a0d8 100644 --- a/pytype/vm.py +++ b/pytype/vm.py @@ -1341,7 +1341,7 @@ def _pop_and_store(self, state, op, name, local): annotations_dict = self.current_annotated_locals if local else None # TODO(b/74434237): Enable --check-variable-types by default. check_types = (self.options.check_variable_types or - op.line in self.director._variable_annotations) # pylint: disable=protected-access + op.line not in self.director.type_comments) value = self._apply_annotation( state, op, name, orig_val, annotations_dict, check_types) self._check_aliased_type_params(value) From f6af86e0a0c388738a2e293fe7386d6d228989d4 Mon Sep 17 00:00:00 2001 From: rechen Date: Fri, 12 Jun 2020 13:55:46 -0700 Subject: [PATCH 17/17] Allow # type: ignore after the opening parenthesis in a function def. For https://github.com/python/typeshed/pull/4224. We really need a better pyi parser... PiperOrigin-RevId: 316171517 --- pytype/pyi/parser.yy | 5 +++-- pytype/pyi/parser_test.py | 6 ++++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/pytype/pyi/parser.yy b/pytype/pyi/parser.yy index 1610daeb4..666116266 100644 --- a/pytype/pyi/parser.yy +++ b/pytype/pyi/parser.yy @@ -479,8 +479,9 @@ typevar_kwarg ; funcdef - : decorators maybe_async DEF funcname '(' params ')' return maybe_body { - $$ = ctx->Call(kNewFunction, "(NONNNN)", $1, $2, $4, $6, $8, $9); + : decorators maybe_async DEF funcname '(' maybe_type_ignore params ')' return + maybe_body { + $$ = ctx->Call(kNewFunction, "(NONNNN)", $1, $2, $4, $7, $9, $10); // Decorators is nullable and messes up the location tracking by // using the previous symbol as the start location for this production, // which is very misleading. It is better to ignore decorators and diff --git a/pytype/pyi/parser_test.py b/pytype/pyi/parser_test.py index 097d301ca..74d87334b 100644 --- a/pytype/pyi/parser_test.py +++ b/pytype/pyi/parser_test.py @@ -1014,6 +1014,12 @@ class Foo: class Foo: bar: str """) + self.check(""" + def f( # type: ignore + x: int) -> None: ... + """, """ + def f(x: int) -> None: ... + """) def test_decorators(self): # These tests are a bit questionable because most of the decorators only