Skip to content

Commit

Permalink
Merge c210a81 into 2a8e9bf
Browse files Browse the repository at this point in the history
  • Loading branch information
orsinium committed May 11, 2020
2 parents 2a8e9bf + c210a81 commit 98704a5
Show file tree
Hide file tree
Showing 6 changed files with 217 additions and 3 deletions.
2 changes: 2 additions & 0 deletions deal/_cli/_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@
# app
from ._lint import lint_command
from ._stub import stub_command
from ._test import test_command


CommandsType = Mapping[str, Callable[[Sequence[str]], int]]
COMMANDS: CommandsType = MappingProxyType(dict(
lint=lint_command,
stub=stub_command,
test=test_command,
))


Expand Down
105 changes: 105 additions & 0 deletions deal/_cli/_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# built-in
import sys
from argparse import ArgumentParser
from contextlib import contextmanager
from importlib import import_module
from pathlib import Path
from textwrap import indent
from traceback import format_exception
from typing import Iterator, Sequence, TextIO

# app
from .._testing import cases
from ..linter._contract import Category
from ..linter._extractors.pre import format_call_args
from ..linter._func import Func


COLORS = dict(
red='\033[91m',
green='\033[92m',
yellow='\033[93m',
blue='\033[94m',
magenta='\033[95m',
end='\033[0m',
)


@contextmanager
def sys_path(path: Path):
path = str(path)
sys.path.insert(0, path)
try:
yield
finally:
if sys.path[0] == path:
del sys.path[0]


def has_pure_contract(func: Func) -> bool:
for contract in func.contracts:
if contract.category == Category.PURE:
return True
return False


def get_func_names(path: Path) -> Iterator[str]:
for func in Func.from_path(path=path):
if has_pure_contract(func=func):
yield func.name


def print_exception(stream: TextIO) -> None:
lines = format_exception(*sys.exc_info())
text = indent(text=''.join(lines), prefix=' ')
text = '{red}{text}{end}'.format(text=text, **COLORS)
print(text, file=stream)


def run_tests(path: Path, root: Path, count: int, stream: TextIO = sys.stdout) -> int:
names = list(get_func_names(path=path))
if not names:
return 0
print('{magenta}running {path}{end}'.format(path=path, **COLORS), file=stream)
module_name = '.'.join(path.relative_to(root).with_suffix('').parts)
with sys_path(path=root):
module = import_module(module_name)
failed = 0
for func_name in names:
print(' {blue}running {name}{end}'.format(name=func_name, **COLORS), file=stream)
func = getattr(module, func_name)
for case in cases(func=func, count=count):
try:
case()
except Exception:
line = ' {yellow}{name}({args}){end}'.format(
name=func_name,
args=format_call_args(args=case.args, kwargs=case.kwargs),
**COLORS,
)
print(line, file=stream)
print_exception(stream=stream)
failed += 1
break
return failed


def test_command(
argv: Sequence[str], root: Path = None, stream: TextIO = sys.stdout,
) -> int:
if root is None: # pragma: no cover
root = Path()
parser = ArgumentParser(prog='python3 -m deal test')
parser.add_argument('--count', type=int, default=50)
parser.add_argument('paths', nargs='+')
args = parser.parse_args(argv)

failed = 0
for path in args.paths:
failed += run_tests(
path=Path(path),
root=root,
count=args.count,
stream=stream,
)
return failed
6 changes: 3 additions & 3 deletions deal/linter/_extractors/pre.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# built-in
import ast
from typing import Iterator
from typing import Iterator, Sequence

# external
import astroid
Expand Down Expand Up @@ -54,11 +54,11 @@ def handle_call(expr: astroid.Call) -> Iterator[Token]:
except NameError:
continue
if result is False or type(result) is str:
msg = _format_message(args, kwargs)
msg = format_call_args(args, kwargs)
yield Token(value=msg, line=expr.lineno, col=expr.col_offset)


def _format_message(args: list, kwargs: dict) -> str:
def format_call_args(args: Sequence, kwargs: dict) -> str:
sep = ', '
args_s = sep.join(map(repr, args))
kwargs_s = sep.join(['{}={!r}'.format(k, v) for k, v in kwargs.items()])
Expand Down
3 changes: 3 additions & 0 deletions docs/commands/test.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# **stub**: test pure functions

Extracts `@deal.pure` functions and runs [autogenerated tests](../testing) for it.
1 change: 1 addition & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
commands/lint
commands/stub
commands/test
.. toctree::
:maxdepth: 1
Expand Down
103 changes: 103 additions & 0 deletions tests/test_cli/test_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# built-in
import sys
from io import StringIO
from pathlib import Path
from textwrap import dedent

# project
from deal._cli._test import sys_path, test_command as command


def test_safe_violation(tmp_path: Path, capsys):
if 'example' in sys.modules:
del sys.modules['example']
text = """
import deal
@deal.pure
def func(a: int, b: int) -> float:
return a / b
"""
path = (tmp_path / 'example.py')
path.write_text(dedent(text))
stream = StringIO()
result = command(['--count', '1', str(path)], root=tmp_path, stream=stream)
assert result == 1

stream.seek(0)
captured = stream.read()
assert '/example.py' in captured
assert 'running func' in captured
assert 'func(a=0, b=0)' in captured
assert 'ZeroDivisionError' in captured
assert 'RaisesContractError' in captured


def test_no_violations(tmp_path: Path):
if 'example' in sys.modules:
del sys.modules['example']
text = """
import deal
@deal.pure
def func(a: int, b: int) -> float:
return a + b
def not_pure1(a: int, b: int) -> float:
return a / b
@deal.post(lambda result: result > 0)
def not_pure2(a: int, b: int) -> float:
return a / b
"""
path = (tmp_path / 'example.py')
path.write_text(dedent(text))
stream = StringIO()
result = command(['--count', '5', str(path)], root=tmp_path, stream=stream)
assert result == 0

stream.seek(0)
captured = stream.read()
assert '/example.py' in captured
assert 'running func' in captured
assert 'not_pure' not in captured
assert 'func(' not in captured


def test_no_matching_funcs(tmp_path: Path):
if 'example' in sys.modules:
del sys.modules['example']
text = """
import deal
def not_pure1(a: int, b: int) -> float:
return a / b
@deal.post(lambda result: result > 0)
def not_pure2(a: int, b: int) -> float:
return a / b
"""
path = (tmp_path / 'example.py')
path.write_text(dedent(text))
stream = StringIO()
result = command(['--count', '5', str(path)], root=tmp_path, stream=stream)
assert result == 0

stream.seek(0)
captured = stream.read()
assert '/example.py' not in captured


def test_sys_path():
path = Path('example')
size = len(sys.path)

assert sys.path[0] != 'example'
with sys_path(path):
assert sys.path[0] == 'example'
assert sys.path[0] != 'example'
assert len(sys.path) == size

with sys_path(path):
del sys.path[0]
assert len(sys.path) == size

0 comments on commit 98704a5

Please sign in to comment.