Skip to content

Commit

Permalink
Merge 27992b8 into d5aaa6c
Browse files Browse the repository at this point in the history
  • Loading branch information
orsinium committed May 4, 2020
2 parents d5aaa6c + 27992b8 commit 9a8c0bf
Show file tree
Hide file tree
Showing 8 changed files with 85 additions and 26 deletions.
5 changes: 3 additions & 2 deletions deal/linter/_extractors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,18 @@
from .globals import get_globals
from .imports import get_imports
from .prints import get_prints
from .returns import get_returns
from .returns import get_returns, has_returns


__all__ = [
'get_asserts',
'get_contracts',
'get_exceptions',
'get_exceptions_stubs',
'get_exceptions',
'get_globals',
'get_imports',
'get_name',
'get_prints',
'get_returns',
'has_returns',
]
8 changes: 3 additions & 5 deletions deal/linter/_extractors/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,9 @@

@get_asserts.register(*TOKENS.ASSERT)
def handle_assert(expr) -> Optional[Token]:
handler = inner_extractor.handlers.get(type(expr.test))
if handler:
token = handler(expr=expr.test)
if token is not None:
return token
# inner_extractor
for token in inner_extractor.handle(expr=expr.test):
return token

# astroid inference
if hasattr(expr.test, 'infer'):
Expand Down
24 changes: 14 additions & 10 deletions deal/linter/_extractors/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
TRY=(ast.Try, astroid.TryExcept, astroid.TryFinally),
UNARY_OP=(ast.UnaryOp, astroid.UnaryOp),
WITH=(ast.With, astroid.With),
YIELD=(ast.Yield, astroid.Yield),
)


Expand Down Expand Up @@ -104,15 +105,18 @@ def _register(self, types: Tuple[type], handler: Callable) -> Callable:
def register(self, *types):
return partial(self._register, types)

def handle(self, expr, **kwargs):
handler = self.handlers.get(type(expr))
if not handler:
return
token = handler(expr=expr, **kwargs)
if token is None:
return
if type(token) is Token:
yield token
return
yield from token

def __call__(self, body: List, **kwargs) -> Iterator[Token]:
for expr in traverse(body=body):
handler = self.handlers.get(type(expr))
if not handler:
continue
token = handler(expr=expr, **kwargs)
if token is None:
continue
if type(token) is Token:
yield token
continue
yield from token
yield from self.handle(expr=expr, **kwargs)
17 changes: 11 additions & 6 deletions deal/linter/_extractors/returns.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,25 @@
import astroid

# app
from .common import TOKENS, Extractor, Token, infer
from .common import TOKENS, Extractor, Token, infer, traverse


get_returns = Extractor()
inner_extractor = Extractor()


def has_returns(body: list) -> bool:
for expr in traverse(body=body):
if isinstance(expr, TOKENS.RETURN + TOKENS.YIELD):
return True
return False


@get_returns.register(*TOKENS.RETURN)
def handle_returns(expr) -> Optional[Token]:
handler = inner_extractor.handlers.get(type(expr.value))
if handler:
token = handler(expr=expr.value)
if token is not None:
return token
# inner_extractor
for token in inner_extractor.handle(expr=expr.value):
return token

# astroid inference
if hasattr(expr.value, 'infer'):
Expand Down
10 changes: 8 additions & 2 deletions deal/linter/_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@


class Func:
__slots__ = ('body', 'contracts', 'name')
__slots__ = ('body', 'contracts', 'name', 'line', 'col')

def __init__(self, *, body: list, contracts: Iterable[Contract], name: str):
def __init__(self, *, body: list, contracts: Iterable[Contract], name: str, line: int, col: int):
self.body = body
self.contracts = contracts
self.name = name
self.line = line
self.col = col

@classmethod
def from_path(cls, path: Path) -> List['Func']:
Expand All @@ -50,6 +52,8 @@ def from_ast(cls, tree: ast.Module) -> List['Func']:
name=expr.name,
body=expr.body,
contracts=contracts,
line=expr.lineno,
col=expr.col_offset,
))
return funcs

Expand All @@ -68,6 +72,8 @@ def from_astroid(cls, tree: astroid.Module) -> List['Func']:
name=expr.name,
body=expr.body,
contracts=contracts,
line=expr.lineno,
col=expr.col_offset,
))
return funcs

Expand Down
9 changes: 9 additions & 0 deletions deal/linter/_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from ._extractors import (
get_exceptions, get_exceptions_stubs, get_globals,
get_imports, get_prints, get_returns, get_asserts,
has_returns,
)
from ._func import Func
from ._stub import StubsManager
Expand Down Expand Up @@ -158,6 +159,14 @@ def __call__(self, func: Func, stubs: StubsManager = None) -> Iterator[Error]:
return

def _check(self, func: Func, stubs: StubsManager = None) -> Iterator[Error]:
if not has_returns(body=func.body):
yield Error(
code=self.code,
text=self.message,
value='no return',
row=func.line,
col=func.col,
)
for token in get_globals(body=func.body):
yield Error(
code=self.code,
Expand Down
20 changes: 19 additions & 1 deletion tests/test_linter/test_extractors/test_returns.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytest

# project
from deal.linter._extractors import get_returns
from deal.linter._extractors import get_returns, has_returns


@pytest.mark.parametrize('text, expected', [
Expand Down Expand Up @@ -57,3 +57,21 @@ def test_get_returns_inference(text, expected):
print(tree.repr_tree())
returns = tuple(r.value for r in get_returns(body=tree.body))
assert returns == expected


@pytest.mark.parametrize('text, expected', [
('return', True),
('return 1', True),
('if b:\n return 1', True),
('yield 1', True),
('if b:\n yield 1', True),
('1 + 2', False),
])
def test_has_returns(text, expected):
tree = ast.parse(text)
print(ast.dump(tree))
assert has_returns(body=tree.body) is expected

tree = astroid.parse(text)
print(tree.repr_tree())
assert has_returns(body=tree.body) is expected
18 changes: 18 additions & 0 deletions tests/test_linter/test_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def test_check_pure():
@deal.pure
def test(a):
global b
return b
"""
text = dedent(text).strip()
funcs1 = Func.from_ast(ast.parse(text))
Expand All @@ -157,6 +158,23 @@ def test(a):
assert actual == expected


def test_check_pure_no_returns():
checker = CheckPure()
text = """
@deal.pure
def test(a):
a + 3
"""
text = dedent(text).strip()
funcs1 = Func.from_ast(ast.parse(text))
funcs2 = Func.from_astroid(astroid.parse(text))
for func in (funcs1[0], funcs2[0]):
actual = [tuple(err) for err in checker(func)]
assert len(actual) == 1
expected = 'DEAL014 pure contract error (no return)'
assert actual[0][2] == expected


def test_check_asserts():
checker = CheckAsserts()
text = """
Expand Down

0 comments on commit 9a8c0bf

Please sign in to comment.