diff --git a/pytype/abstract/_classes.py b/pytype/abstract/_classes.py index 9fb4284db..c9ecfc759 100644 --- a/pytype/abstract/_classes.py +++ b/pytype/abstract/_classes.py @@ -308,9 +308,12 @@ def has_protocol_base(self): return True return False - def get_undecorated_method(self, name, node): + def get_undecorated_method( + self, name: str, node: cfg.CFGNode) -> Optional[cfg.Variable]: + if name not in self._undecorated_methods: + return None return self.ctx.program.NewVariable( - self._undecorated_methods.get(name, ()), (), node) + self._undecorated_methods[name], (), node) class PyTDClass( diff --git a/pytype/config.py b/pytype/config.py index 5a1391e71..ea9a330bf 100644 --- a/pytype/config.py +++ b/pytype/config.py @@ -234,6 +234,8 @@ def add_options(o, arglist): FEATURE_FLAGS = [ + _flag("--bind-decorated-methods", False, + "Bind 'self' in methods with non-transparent decorators."), _flag("--overriding-renamed-parameter-count-checks", False, "Enable parameter count checks for overriding methods with " "renamed arguments."), diff --git a/pytype/tests/test_base.py b/pytype/tests/test_base.py index c754dfdea..3c8d03a86 100644 --- a/pytype/tests/test_base.py +++ b/pytype/tests/test_base.py @@ -80,6 +80,7 @@ def setUp(self): super().setUp() self.options = config.Options.create( python_version=self.python_version, + bind_decorated_methods=True, overriding_renamed_parameter_count_checks=True, strict_parameter_checks=True, strict_undefined_checks=True, diff --git a/pytype/tests/test_decorators2.py b/pytype/tests/test_decorators2.py index dda82ba7f..2fc20f6b8 100644 --- a/pytype/tests/test_decorators2.py +++ b/pytype/tests/test_decorators2.py @@ -344,6 +344,17 @@ def f(self, x: T): pass """) + def test_self_in_decorated_method(self): + self.Check(""" + from typing import Any + def decorate(f) -> Any: + return f + class C: + @decorate + def f(self): + assert_type(self, C) + """) + if __name__ == "__main__": test_base.main() diff --git a/pytype/tracer_vm.py b/pytype/tracer_vm.py index 76559ced6..304a4fd57 100644 --- a/pytype/tracer_vm.py +++ b/pytype/tracer_vm.py @@ -407,7 +407,7 @@ def bind(cur_node, m): # which can happen if the method is decorated, for example - then we look up # the method before any decorators were applied and use that instead. undecorated_method = cls.get_undecorated_method(method_name, node) - if undecorated_method.data: + if undecorated_method: return node, bind(node, undecorated_method) else: return node, bound_method @@ -477,6 +477,13 @@ def analyze_class(self, node, val): name = unwrapped.data[0].name if unwrapped else v.name self.ctx.errorlog.ignored_abstractmethod( self.ctx.vm.simple_stack(cls.get_first_opcode()), cls.name, name) + is_method_or_nested_class = any( + isinstance(m, (abstract.FUNCTION_TYPES, abstract.InterpreterClass)) + for m in methodvar.data) + if (self.ctx.options.bind_decorated_methods and + not is_method_or_nested_class and + (undecorated_method := cls.get_undecorated_method(name, node))): + methodvar = undecorated_method b = self._bind_method(node, methodvar, instance) node = self.analyze_method_var(node, name, b, val) return node