Skip to content

Commit

Permalink
Merge 2cd7407 into 6298739
Browse files Browse the repository at this point in the history
  • Loading branch information
orsinium committed Nov 15, 2019
2 parents 6298739 + 2cd7407 commit 55a0a33
Show file tree
Hide file tree
Showing 11 changed files with 381 additions and 163 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
__pycache__/
*.egg-info/
dist/
.coverage
.coverage*
htmlcov/
README.rst
docs/build/
Expand Down
91 changes: 91 additions & 0 deletions deal/linter/_contract.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import ast
import builtins
import enum

import astroid

from ._extractors import get_name


TEMPLATE = """
contract = PLACEHOLDER
result = contract(*args, **kwargs)
"""


class Category(enum.Enum):
POST = 'post'
RAISES = 'raises'
SILENT = 'silent'


class Contract:
def __init__(self, args, category: Category):
self.args = args
self.category = category

@property
def body(self):
contract = self.args[0]
# convert astroid node to ast node
if hasattr(contract, 'as_string'):
contract = self._resolve_name(contract)
contract = ast.parse(contract.as_string()).body[0]
return contract

@staticmethod
def _resolve_name(contract):
if not isinstance(contract, astroid.Name):
return contract
definitions = contract.lookup(contract.name)[1]
if not definitions:
return contract
definition = definitions[0]
if isinstance(definition, astroid.FunctionDef):
return definition
if isinstance(definition, astroid.AssignName):
return definition.parent.value
# resolved into something tricky, live with it
return contract # pragma: no cover

@property
def exceptions(self) -> list:
excs = []
for expr in self.args:
name = get_name(expr)
if not name:
continue
exc = getattr(builtins, name, name)
excs.append(exc)
return excs

@property
def bytecode(self):
module = ast.parse(TEMPLATE)
contract = self.body
if isinstance(contract, ast.FunctionDef):
# if contract is function, add it's definition and assign it's name
# to `contract` variable.
module.body = [contract] + module.body
module.body[1].value = ast.Name(
id=contract.name,
lineno=1,
col_offset=1,
ctx=ast.Load(),
)
else:
if isinstance(contract, ast.Expr):
contract = contract.value
module.body[0].value = contract
return compile(module, filename='<ast>', mode='exec')

def run(self, *args, **kwargs):
globals = dict(args=args, kwargs=kwargs)
exec(self.bytecode, globals)
return globals['result']

def __repr__(self) -> str:
return '{name}({category})'.format(
name=type(self).__name__,
category=self.category.value,
)
1 change: 1 addition & 0 deletions deal/linter/_extractors/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
EXPR=(ast.Expr, astroid.Expr),
FOR=(ast.For, astroid.For),
IF=(ast.If, astroid.If),
NAME=(ast.Name, astroid.Name),
RAISE=(ast.Raise, astroid.Raise),
RETURN=(ast.Return, astroid.Return),
TRY=(ast.Try, astroid.TryExcept, astroid.TryFinally),
Expand Down
34 changes: 33 additions & 1 deletion deal/linter/_extractors/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .common import traverse, Token, TOKENS, get_name


def get_exceptions(body: list) -> Iterator[Token]:
def get_exceptions(body: list, *, dive: bool = True) -> Iterator[Token]:
for expr in traverse(body):
token_info = dict(line=expr.lineno, col=expr.col_offset)

Expand Down Expand Up @@ -53,3 +53,35 @@ def get_exceptions(body: list) -> Iterator[Token]:
if name and name == 'sys.exit':
yield Token(value=SystemExit, **token_info)
continue

# infer function call and check the function body for raises
if not dive:
continue
for name_node in get_names(expr):
if not isinstance(name_node, astroid.Name):
continue
try:
guesses = tuple(name_node.infer())
except astroid.exceptions.NameInferenceError:
continue
for value in guesses:
if not isinstance(value, astroid.FunctionDef):
continue
for error in get_exceptions(body=value.body, dive=False):
yield Token(
value=error.value,
line=name_node.lineno,
col=name_node.col_offset,
)


def get_names(expr):
if isinstance(expr, astroid.Assign):
yield from get_names(expr.value)
if isinstance(expr, TOKENS.CALL):
if isinstance(expr.func, TOKENS.NAME):
yield expr.func
for subnode in expr.args:
yield from get_names(subnode)
for subnode in (expr.keywords or ()):
yield from get_names(subnode.value)
8 changes: 6 additions & 2 deletions deal/linter/_extractors/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ def get_imports(body: list) -> Iterator[Token]:
for expr in traverse(body):
token_info = dict(line=expr.lineno, col=expr.col_offset)
if isinstance(expr, astroid.ImportFrom):
yield Token(value=expr.modname, **token_info)
dots = '.' * (expr.level or 0)
name = expr.modname or ''
yield Token(value=dots + name, **token_info)
if isinstance(expr, ast.ImportFrom):
yield Token(value=expr.module, **token_info)
dots = '.' * expr.level
name = expr.module or ''
yield Token(value=dots + name, **token_info)
108 changes: 21 additions & 87 deletions deal/linter/_func.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import ast
import builtins
import enum
from pathlib import Path
from typing import List
from typing import Iterable, List

import astroid

from ._extractors import get_name, get_contracts
from ._contract import Category, Contract
from ._extractors import get_contracts


TEMPLATE = """
Expand All @@ -15,17 +14,10 @@
"""


class Category(enum.Enum):
POST = 'post'
RAISES = 'raises'
SILENT = 'silent'


class Func:
def __init__(self, body: list, args, category: Category):
def __init__(self, body: list, contracts: Iterable[Contract]):
self.body = body
self.args = args
self.category = category
self.contracts = contracts

@classmethod
def from_path(cls, path: Path) -> List['Func']:
Expand All @@ -44,12 +36,13 @@ def from_ast(cls, tree: ast.Module) -> List['Func']:
for expr in tree.body:
if not isinstance(expr, ast.FunctionDef):
continue
for cat, args in get_contracts(expr.decorator_list):
funcs.append(cls(
body=expr.body,
category=Category(cat),
args=args,
))
contracts = []
for category, args in get_contracts(expr.decorator_list):
contract = Contract(args=args, category=Category(category))
contracts.append(contract)
if not contracts:
continue
funcs.append(cls(body=expr.body, contracts=contracts))
return funcs

@classmethod
Expand All @@ -60,76 +53,17 @@ def from_astroid(cls, tree: astroid.Module) -> List['Func']:
continue
if not expr.decorators:
continue
for cat, args in get_contracts(expr.decorators.nodes):
funcs.append(cls(
body=expr.body,
category=Category(cat),
args=args,
))
return funcs

@property
def contract(self):
contract = self.args[0]
# convert astroid node to ast node
if hasattr(contract, 'as_string'):
contract = self._resolve_name(contract)
contract = ast.parse(contract.as_string()).body[0]
return contract

@staticmethod
def _resolve_name(contract):
if not isinstance(contract, astroid.Name):
return contract
definitions = contract.lookup(contract.name)[1]
if not definitions:
return contract
definition = definitions[0]
if isinstance(definition, astroid.FunctionDef):
return definition
if isinstance(definition, astroid.AssignName):
return definition.parent.value
# resolved into something tricky, live with it
return contract # pragma: no cover

@property
def exceptions(self) -> list:
excs = []
for expr in self.args:
name = get_name(expr)
if not name:
contracts = []
for category, args in get_contracts(expr.decorators.nodes):
contract = Contract(args=args, category=Category(category))
contracts.append(contract)
if not contracts:
continue
exc = getattr(builtins, name, name)
excs.append(exc)
return excs

@property
def bytecode(self):
module = ast.parse(TEMPLATE)
contract = self.contract
if isinstance(contract, ast.FunctionDef):
# if contract is function, add it's definition and assign it's name
# to `contract` variable.
module.body = [contract] + module.body
module.body[1].value = ast.Name(
id=contract.name,
lineno=1,
col_offset=1,
ctx=ast.Load(),
)
else:
if isinstance(contract, ast.Expr):
contract = contract.value
module.body[0].value = contract
return compile(module, filename='<ast>', mode='exec')

def run(self, *args, **kwargs):
globals = dict(args=args, kwargs=kwargs)
exec(self.bytecode, globals)
return globals['result']
funcs.append(cls(body=expr.body, contracts=contracts))
return funcs

def __repr__(self) -> str:
return '{name}({category})'.format(
return '{name}({cats})'.format(
name=type(self).__name__,
category=self.category.value,
cats=', '.join(contract.category.value for contract in self.contracts),
)
46 changes: 31 additions & 15 deletions deal/linter/_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

from ._error import Error
from ._extractors import get_exceptions, get_returns, get_imports, get_prints
from ._func import Func, Category
from ._func import Func
from ._contract import Category, Contract


rules = []
Expand All @@ -27,13 +28,14 @@ class CheckImports:

def __call__(self, tree) -> Iterator[Error]:
for token in get_imports(tree.body):
if token.value == 'deal':
yield Error(
code=self.code,
text=self.message,
row=token.line,
col=token.col,
)
if token.value != 'deal':
continue
yield Error(
code=self.code,
text=self.message,
row=token.line,
col=token.col,
)


@register
Expand All @@ -43,11 +45,15 @@ class CheckReturns:
required = Required.FUNC

def __call__(self, func: Func) -> Iterator[Error]:
if func.category != Category.POST:
return
for contract in func.contracts:
if contract.category != Category.POST:
continue
yield from self._check(func=func, contract=contract)

def _check(self, func: Func, contract: Contract) -> Iterator[Error]:
for token in get_returns(body=func.body):
try:
result = func.run(token.value)
result = contract.run(token.value)
except NameError:
# cannot resolve contract dependencies
return
Expand All @@ -72,9 +78,13 @@ class CheckRaises:
required = Required.FUNC

def __call__(self, func: Func) -> Iterator[Error]:
if func.category != Category.RAISES:
return
allowed = func.exceptions
for contract in func.contracts:
if contract.category != Category.RAISES:
continue
yield from self._check(func=func, contract=contract)

def _check(self, func: Func, contract: Contract) -> Iterator[Error]:
allowed = contract.exceptions
allowed_types = tuple(exc for exc in allowed if type(exc) is not str)
for token in get_exceptions(body=func.body):
if token.value in allowed:
Expand All @@ -100,8 +110,14 @@ class CheckPrints:
required = Required.FUNC

def __call__(self, func: Func) -> Iterator[Error]:
if func.category != Category.SILENT:
for contract in func.contracts:
if contract.category != Category.SILENT:
continue
yield from self._check(func=func, contract=contract)
# if `@deal.silent` is duplicated, check the function only once
return

def _check(self, func: Func, contract: Contract) -> Iterator[Error]:
for token in get_prints(body=func.body):
yield Error(
code=self.code,
Expand Down

0 comments on commit 55a0a33

Please sign in to comment.