Skip to content

Commit

Permalink
Merge 54d36e5 into fe27139
Browse files Browse the repository at this point in the history
  • Loading branch information
orsinium committed Apr 24, 2020
2 parents fe27139 + 54d36e5 commit e34bad9
Show file tree
Hide file tree
Showing 10 changed files with 136 additions and 9 deletions.
1 change: 1 addition & 0 deletions deal/linter/_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class Category(enum.Enum):
POST = 'post'
RAISES = 'raises'
SILENT = 'silent'
PURE = 'pure'


class Contract:
Expand Down
2 changes: 2 additions & 0 deletions deal/linter/_extractors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from .common import get_name
from .contracts import get_contracts
from .exceptions import get_exceptions
from .globals import get_globals
from .imports import get_imports
from .prints import get_prints
from .returns import get_returns
Expand All @@ -10,6 +11,7 @@
__all__ = [
'get_contracts',
'get_exceptions',
'get_globals',
'get_imports',
'get_name',
'get_prints',
Expand Down
6 changes: 4 additions & 2 deletions deal/linter/_extractors/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,22 @@


TOKENS = SimpleNamespace(
ATTR=(ast.Attribute, astroid.Attribute),
ASSERT=(ast.Assert, astroid.Assert),
ATTR=(ast.Attribute, astroid.Attribute),
BIN_OP=(ast.BinOp, astroid.BinOp),
CALL=(ast.Call, astroid.Call),
EXPR=(ast.Expr, astroid.Expr),
FOR=(ast.For, astroid.For),
FUNC=(ast.FunctionDef, astroid.FunctionDef),
GLOBAL=(ast.Global, astroid.Global),
IF=(ast.If, astroid.If),
NAME=(ast.Name, astroid.Name),
NONLOCAL=(ast.Nonlocal, astroid.Nonlocal),
RAISE=(ast.Raise, astroid.Raise),
RETURN=(ast.Return, astroid.Return),
TRY=(ast.Try, astroid.TryExcept, astroid.TryFinally),
UNARY_OP=(ast.UnaryOp, astroid.UnaryOp),
WITH=(ast.With, astroid.With),
FUNC=(ast.FunctionDef, astroid.FunctionDef),
)


Expand Down
4 changes: 2 additions & 2 deletions deal/linter/_extractors/contracts.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from .common import TOKENS, get_name


SUPPORTED_CONTRACTS = {'deal.post', 'deal.raises', 'deal.silent'}
SUPPORTED_MARKERS = {'deal.silent'}
SUPPORTED_CONTRACTS = {'deal.post', 'deal.raises', 'deal.silent', 'deal.pure'}
SUPPORTED_MARKERS = {'deal.silent', 'deal.pure'}


def get_contracts(decorators: list) -> Iterator[Tuple[str, list]]:
Expand Down
36 changes: 36 additions & 0 deletions deal/linter/_extractors/globals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# built-in
import ast
from typing import Iterator

# external
import astroid

# app
from .common import TOKENS, Token, traverse


def get_globals(body: list) -> Iterator[Token]:
for expr in traverse(body):
if isinstance(expr, TOKENS.GLOBAL):
yield Token(value='global', line=expr.lineno, col=expr.col_offset)
continue

if isinstance(expr, TOKENS.NONLOCAL):
yield Token(value='nonlocal', line=expr.lineno, col=expr.col_offset)
continue

if type(expr) is ast.Import:
yield Token(value='import', line=expr.lineno, col=expr.col_offset)
continue

if type(expr) is astroid.Import:
yield Token(value='import', line=expr.lineno, col=expr.col_offset)
continue

if type(expr) is ast.ImportFrom:
yield Token(value='import', line=expr.lineno, col=expr.col_offset)
continue

if type(expr) is astroid.ImportFrom:
yield Token(value='import', line=expr.lineno, col=expr.col_offset)
continue
30 changes: 27 additions & 3 deletions deal/linter/_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# app
from ._contract import Category, Contract
from ._error import Error
from ._extractors import get_exceptions, get_imports, get_prints, get_returns
from ._extractors import get_exceptions, get_imports, get_prints, get_returns, get_globals
from ._func import Func


Expand Down Expand Up @@ -116,11 +116,11 @@ def __call__(self, func: Func) -> Iterator[Error]:
for contract in func.contracts:
if contract.category != Category.SILENT:
continue
yield from self._check(func=func, contract=contract)
yield from self._check(func=func)
# if `@deal.silent` is duplicated, check the function only once
return

def _check(self, func: Func, contract: Contract) -> Iterator[Error]:
def _check(self, func: Func) -> Iterator[Error]:
for token in get_prints(body=func.body):
yield Error(
code=self.code,
Expand All @@ -129,3 +129,27 @@ def _check(self, func: Func, contract: Contract) -> Iterator[Error]:
row=token.line,
col=token.col,
)


@register
class CheckPure:
code = 14
message = 'pure contract error'
required = Required.FUNC

def __call__(self, func: Func) -> Iterator[Error]:
for contract in func.contracts:
if contract.category != Category.PURE:
continue
yield from self._check(func=func)
return

def _check(self, func: Func) -> Iterator[Error]:
for token in get_globals(body=func.body):
yield Error(
code=self.code,
text=self.message,
value=str(token.value),
row=token.line,
col=token.col,
)
2 changes: 1 addition & 1 deletion docs/decorators/pure.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# pure

Pure function cannot do network requests, write anything into stdout or raise any exceptions. It gets some parameters and returns some result. That's all. This is alias for `chain(safe, silent, offline)`.
Pure function cannot do network requests, write anything into stdout or raise any exceptions. It gets some parameters and returns some result. That's all. In runtime, it does the same checks as `chain(safe, silent, offline)`. However, [linter](../linter.html) checks a bit more, like no `import`, `global`, `nonlocal`, etc.

```python
@deal.pure
Expand Down
1 change: 1 addition & 0 deletions docs/linter.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ Another option is to use built-in CLI from deal: `python3 -m deal.linter`. I has
| DEAL011 | post contract error |
| DEAL012 | raises contract error |
| DEAL013 | silent contract error |
| DEAL014 | pure contract error |
35 changes: 35 additions & 0 deletions tests/test_linter/test_extractors/test_globals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# built-in
import ast

# external
import astroid
import pytest

# project
from deal.linter._extractors import get_globals


@pytest.mark.parametrize('text, expected', [
('global a', ('global', )),
('global a, b, c', ('global', )),
('nonlocal a', ('nonlocal', )),
('nonlocal a, b, c', ('nonlocal', )),
('import a', ('import', )),
('import a as b', ('import', )),
('import a as b, c', ('import', )),
('from a import b', ('import', )),
('from a import b as c', ('import', )),
])
def test_get_globals_simple(text, expected):
tree = astroid.parse(text)
print(tree.repr_tree())
globals = tuple(r.value for r in get_globals(body=tree.body))
assert globals == expected

tree = ast.parse(text)
print(ast.dump(tree))
globals = tuple(r.value for r in get_globals(body=tree.body))
assert globals == expected
28 changes: 27 additions & 1 deletion tests/test_linter/test_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,17 @@

# project
from deal.linter._func import Func
from deal.linter._rules import CheckImports, CheckPrints, CheckRaises, CheckReturns
from deal.linter._rules import rules, CheckImports, CheckPrints, CheckRaises, CheckReturns, CheckPure


def test_error_codes():
codes = [rule.code for rule in rules]
assert len(codes) == len(set(codes))


def test_error_messages():
messages = [rule.message for rule in rules]
assert len(messages) == len(set(messages))


def test_check_returns():
Expand Down Expand Up @@ -129,6 +139,22 @@ def test(a):
assert actual == expected


def test_check_pure():
checker = CheckPure()
text = """
@deal.pure
def test(a):
global b
"""
text = dedent(text).strip()
funcs1 = Func.from_ast(ast.parse(text))
funcs2 = Func.from_astroid(astroid.parse(text))
for func in (funcs1[0], funcs2[0]):
actual = [tuple(err) for err in checker(func)]
expected = [(3, 4, 'DEAL014: pure contract error (global)')]
assert actual == expected


def test_check_imports():
checker = CheckImports()
text = """
Expand Down

0 comments on commit e34bad9

Please sign in to comment.