diff --git a/pytype/directors/directors.py b/pytype/directors/directors.py index 5dffdc32d..883034dbc 100644 --- a/pytype/directors/directors.py +++ b/pytype/directors/directors.py @@ -280,6 +280,7 @@ def __init__(self, src_tree, errorlog, filename, disable, code): # Store function ranges and return lines to distinguish explicit and # implicit returns (the bytecode has a `RETURN None` for implcit returns). self._return_lines = set() + self.block_returns = None self._function_ranges = _BlockRanges({}) # Parse the source code for directives. self._parse_src_tree(src_tree, code) @@ -313,7 +314,8 @@ def _parse_src_tree(self, src_tree, code): else: opcode_lines = None - self._return_lines = visitor.returns + self.block_returns = visitor.block_returns + self._return_lines = visitor.block_returns.all_returns() self._function_ranges = _BlockRanges(visitor.function_ranges) for line_range, group in visitor.structured_comment_groups.items(): diff --git a/pytype/directors/parser.py b/pytype/directors/parser.py index 3c8f02698..f105898ef 100644 --- a/pytype/directors/parser.py +++ b/pytype/directors/parser.py @@ -32,6 +32,9 @@ class LineRange: def from_node(cls, node): return cls(node.lineno, node.end_lineno) + def __contains__(self, line): + return self.start_line <= line <= self.end_line + @dataclasses.dataclass(frozen=True) class Call(LineRange): @@ -66,6 +69,44 @@ class _SourceTree: structured_comments: Mapping[int, Sequence[_StructuredComment]] +class BlockReturns: + """Tracks return statements in with/try blocks.""" + + def __init__(self): + self._block_ranges = [] + self._returns = [] + self._block_returns = {} + self._final = False + + def add_block(self, node): + line_range = LineRange.from_node(node) + self._block_ranges.append(line_range) + + def add_return(self, node): + self._returns.append(node.lineno) + + def finalize(self): + for br in self._block_ranges: + self._block_returns[br.start_line] = sorted( + r for r in self._returns if r in br + ) + self._final = True + + def all_returns(self): + return set(self._returns) + + def __iter__(self): + assert self._final + return iter(self._block_returns.items()) + + def __repr__(self): + return f""" + Blocks: {self._block_ranges} + Returns: {self._returns} + {self._block_returns} + """ + + class _ParseVisitor(visitor.BaseVisitor): """Visitor for parsing a source tree. @@ -97,8 +138,9 @@ def __init__(self, raw_structured_comments): self.variable_annotations = [] self.decorators = [] self.defs_start = None - self.returns = set() self.function_ranges = {} + self.block_returns = BlockReturns() + self.block_depth = 0 def _add_structured_comment_group(self, start_line, end_line, cls=LineRange): """Adds an empty _StructuredComment group with the given line range.""" @@ -171,6 +213,9 @@ def should_add(comment, group): if cls is not LineRange: group.extend(c for c in structured_comments if should_add(c, group)) + def leave_Module(self, node): + self.block_returns.finalize() + def visit_Call(self, node): self._process_structured_comments(LineRange.from_node(node), cls=Call) @@ -200,8 +245,22 @@ def visit_Try(self, node): def _visit_with(self, node): item = node.items[-1] end_lineno = (item.optional_vars or item.context_expr).end_lineno + if self.block_depth == 1: + self.block_returns.add_block(node) self._process_structured_comments(LineRange(node.lineno, end_lineno)) + def enter_With(self, node): + self.block_depth += 1 + + def leave_With(self, node): + self.block_depth -= 1 + + def enter_AsyncWith(self, node): + self.block_depth += 1 + + def leave_AsyncWith(self, node): + self.block_depth -= 1 + def visit_With(self, node): self._visit_with(node) @@ -226,8 +285,8 @@ def generic_visit(self, node): self._process_structured_comments(LineRange.from_node(node)) def visit_Return(self, node): + self.block_returns.add_return(node) self._process_structured_comments(LineRange.from_node(node)) - self.returns.add(node.lineno) def _visit_decorators(self, node): if not node.decorator_list: diff --git a/pytype/tests/CMakeLists.txt b/pytype/tests/CMakeLists.txt index 09c54d19d..6aec1f482 100644 --- a/pytype/tests/CMakeLists.txt +++ b/pytype/tests/CMakeLists.txt @@ -667,6 +667,15 @@ py_test( .test_base ) +py_test( + NAME + test_returns + SRCS + test_returns.py + DEPS + .test_base +) + py_test( NAME test_list1 diff --git a/pytype/tests/test_returns.py b/pytype/tests/test_returns.py new file mode 100644 index 000000000..362d81fd2 --- /dev/null +++ b/pytype/tests/test_returns.py @@ -0,0 +1,63 @@ +"""Tests for bad-return-type errors.""" + +from pytype.tests import test_base + + +class TestReturns(test_base.BaseTest): + """Tests for bad-return-type.""" + + def test_implicit_none(self): + self.CheckWithErrors(""" + def f(x) -> int: + pass # bad-return-type + """) + + def test_if(self): + # NOTE(b/233047104): The implict `return None` gets reported at the end of + # the function even though there is also a correct return on that line. + self.CheckWithErrors(""" + def f(x) -> int: + if x: + pass + else: + return 10 # bad-return-type + """) + + def test_nested_if(self): + self.CheckWithErrors(""" + def f(x) -> int: + if x: + if __random__: + pass + else: + return 'a' # bad-return-type + else: + return 10 + pass # bad-return-type + """) + + def test_with(self): + self.CheckWithErrors(""" + def f(x) -> int: + with open('foo'): + if __random__: + pass + else: + return 'a' # bad-return-type # bad-return-type + """) + + def test_nested_with(self): + self.CheckWithErrors(""" + def f(x) -> int: + with open('foo'): + if __random__: + with open('bar'): + if __random__: + pass + else: + return 'a' # bad-return-type # bad-return-type + """) + + +if __name__ == "__main__": + test_base.main() diff --git a/pytype/vm.py b/pytype/vm.py index 399158d7e..2d3823ebb 100644 --- a/pytype/vm.py +++ b/pytype/vm.py @@ -230,6 +230,7 @@ def run_frame(self, frame, node, annotated_locals=None): can_return = False return_nodes = [] finally_tracker = vm_utils.FinallyStateTracker() + vm_utils.adjust_block_returns(frame.f_code, self._director.block_returns) for block in frame.f_code.order: state = frame.states.get(block[0]) if not state: @@ -430,6 +431,7 @@ def run_program(self, src, filename, maximum_depth): self.ctx.errorlog.ignored_type_comment(self.filename, line, self._director.type_comments[line]) code = constant_folding.optimize(code) + vm_utils.adjust_block_returns(code, self._director.block_returns) node = self.ctx.root_node.ConnectNew("init") node, f_globals, f_locals, _ = self.run_bytecode(node, code) diff --git a/pytype/vm_utils.py b/pytype/vm_utils.py index 99c071d79..b0e23b8da 100644 --- a/pytype/vm_utils.py +++ b/pytype/vm_utils.py @@ -1137,3 +1137,17 @@ def to_coroutine(state, obj, top, ctx): for b in obj.bindings: state = _binding_to_coroutine(state, b, bad_bindings, ret, top, ctx) return state, ret + + +def adjust_block_returns(code, block_returns): + """Adjust line numbers for return statements in with blocks.""" + + rets = {k: iter(v) for k, v in block_returns} + for block in code.order: + for op in block: + if op.__class__.__name__ == "RETURN_VALUE": + if op.line in rets: + lines = rets[op.line] + new_line = next(lines, None) + if new_line: + op.line = new_line