diff --git a/deal/_cli/_base.py b/deal/_cli/_base.py index 4eeba100..d4a96235 100644 --- a/deal/_cli/_base.py +++ b/deal/_cli/_base.py @@ -1,4 +1,4 @@ -from argparse import ArgumentParser +from argparse import ArgumentParser, Namespace from pathlib import Path from typing import TextIO @@ -18,5 +18,5 @@ def print(self, *args) -> None: def init_parser(parser: ArgumentParser) -> None: raise NotImplementedError - def __call__(self, args) -> int: + def __call__(self, args: Namespace) -> int: raise NotImplementedError diff --git a/deal/_cli/_decorate.py b/deal/_cli/_decorate.py new file mode 100644 index 00000000..faa9b1e9 --- /dev/null +++ b/deal/_cli/_decorate.py @@ -0,0 +1,63 @@ +from argparse import ArgumentParser +from pathlib import Path + +from .._colors import get_colors +from ..linter import TransformationType, Transformer +from ._base import Command +from ._common import get_paths + + +class DecorateCommand(Command): + """Add decorators to your code. + + ```bash + python3 -m deal decorate project/ + ``` + + Options: + + `--types`: types of decorators to apply. All are enabled by default. + + `--double-quotes`: use double quotes. Single quotes are used by default. + + `--nocolor`: do not use colors in the console output. + + The exit code is always 0. If you want to test the code for missed decorators, + use the `lint` command instead. + """ + + @staticmethod + def init_parser(parser: ArgumentParser) -> None: + parser.add_argument( + '--types', + nargs='*', + choices=[tt.value for tt in TransformationType], + default=['has', 'raises', 'safe', 'import'], + help='types of decorators to apply', + ) + parser.add_argument( + '--double-quotes', + action='store_true', + help='use double quotes', + ) + parser.add_argument('--nocolor', action='store_true', help='colorless output') + parser.add_argument('paths', nargs='*', default='.') + + def __call__(self, args) -> int: + types = {TransformationType(t) for t in args.types} + colors = get_colors(args) + for arg in args.paths: + for path in get_paths(Path(arg)): + self.print('{magenta}{path}{end}'.format(path=path, **colors)) + original_code = path.read_text(encoding='utf8') + tr = Transformer( + content=original_code, + path=path, + types=types, + ) + if args.double_quotes: + tr = tr._replace(quote='"') + modified_code = tr.transform() + if original_code == modified_code: + self.print(' {blue}no changes{end}'.format(**colors)) + else: + path.write_text(modified_code) + self.print(' {green}decorated{end}'.format(**colors)) + return 0 diff --git a/deal/_cli/_main.py b/deal/_cli/_main.py index 329c0f3e..50a40974 100644 --- a/deal/_cli/_main.py +++ b/deal/_cli/_main.py @@ -1,33 +1,40 @@ import sys from argparse import ArgumentParser from pathlib import Path -from types import MappingProxyType from typing import Mapping, Sequence, TextIO, Type from ._base import Command -from ._lint import LintCommand -from ._memtest import MemtestCommand -from ._prove import ProveCommand -from ._stub import StubCommand -from ._test import TestCommand CommandsType = Mapping[str, Type[Command]] -COMMANDS: CommandsType = MappingProxyType(dict( - lint=LintCommand, - memtest=MemtestCommand, - prove=ProveCommand, - stub=StubCommand, - test=TestCommand, -)) + + +def get_commands() -> CommandsType: + from ._decorate import DecorateCommand + from ._lint import LintCommand + from ._memtest import MemtestCommand + from ._prove import ProveCommand + from ._stub import StubCommand + from ._test import TestCommand + + return dict( + decorate=DecorateCommand, + lint=LintCommand, + memtest=MemtestCommand, + prove=ProveCommand, + stub=StubCommand, + test=TestCommand, + ) def main( argv: Sequence[str], *, - commands: CommandsType = COMMANDS, + commands: CommandsType = None, root: Path = None, stream: TextIO = sys.stdout, ) -> int: + if commands is None: + commands = get_commands() if root is None: # pragma: no cover root = Path() parser = ArgumentParser(prog='python3 -m deal') diff --git a/deal/_cli/_prove.py b/deal/_cli/_prove.py index 7dbebe53..f82bade8 100644 --- a/deal/_cli/_prove.py +++ b/deal/_cli/_prove.py @@ -19,8 +19,11 @@ class DealTheorem(Theorem): @staticmethod def get_contracts(func: astroid.FunctionDef) -> Iterator[Contract]: - for name, args in get_contracts(func): - yield Contract(name=name, args=args) + for contract in get_contracts(func): + yield Contract( + name=contract.name, + args=contract.args, # type: ignore[arg-type] + ) def run_solver( diff --git a/deal/linter/__init__.py b/deal/linter/__init__.py index 76d57d0b..019b5e8a 100644 --- a/deal/linter/__init__.py +++ b/deal/linter/__init__.py @@ -1,5 +1,12 @@ from ._checker import Checker from ._stub import StubsManager, generate_stub +from ._transformer import TransformationType, Transformer -__all__ = ['Checker', 'StubsManager', 'generate_stub'] +__all__ = [ + 'Checker', + 'generate_stub', + 'StubsManager', + 'TransformationType', + 'Transformer', +] diff --git a/deal/linter/_contract.py b/deal/linter/_contract.py index 5aef275c..1b0f92c9 100644 --- a/deal/linter/_contract.py +++ b/deal/linter/_contract.py @@ -3,7 +3,7 @@ import enum from copy import copy from pathlib import Path -from typing import Dict, FrozenSet, Iterable, List +from typing import Dict, FrozenSet, Iterable, List, Type, Union import astroid @@ -24,6 +24,11 @@ class Category(enum.Enum): PURE = 'pure' RAISES = 'raises' SAFE = 'safe' + INHERIT = 'inherit' + + @property + def brackets_optional(self) -> bool: + return self in {Category.SAFE, Category.PURE} class Contract: @@ -31,6 +36,7 @@ class Contract: category: Category func_args: ast.arguments context: Dict[str, ast.stmt] + line: int def __init__( self, @@ -38,11 +44,13 @@ def __init__( category: Category, func_args: ast.arguments, context: Dict[str, ast.stmt] = None, + line: int = 0, ): self.args = tuple(args) self.category = category self.func_args = func_args self.context = context or dict() + self.line = line @cached_property def body(self) -> ast.AST: @@ -110,7 +118,7 @@ def _resolve_name(contract): return contract # pragma: no cover @cached_property - def exceptions(self) -> list: + def exceptions(self) -> List[Union[str, Type[Exception]]]: from ._extractors import get_name excs = [] diff --git a/deal/linter/_extractors/common.py b/deal/linter/_extractors/common.py index a72dd30d..305a6443 100644 --- a/deal/linter/_extractors/common.py +++ b/deal/linter/_extractors/common.py @@ -31,8 +31,7 @@ class Token(NamedTuple): line: int col: int value: Optional[object] = None - # marker name or error message: - marker: Optional[str] = None + marker: Optional[str] = None # marker name or error message def traverse(body: List) -> Iterator: diff --git a/deal/linter/_extractors/contracts.py b/deal/linter/_extractors/contracts.py index d1e06622..a2fa3ab0 100644 --- a/deal/linter/_extractors/contracts.py +++ b/deal/linter/_extractors/contracts.py @@ -1,5 +1,5 @@ import ast -from typing import Iterator, List, Optional, Tuple, Union +from typing import Iterator, List, NamedTuple, Optional, Union import astroid @@ -20,7 +20,13 @@ Attr = Union[ast.Attribute, astroid.Attribute] -def get_contracts(func) -> Iterator[Tuple[str, list]]: +class ContractInfo(NamedTuple): + name: str + args: List[Union[ast.expr, astroid.Expr]] + line: int + + +def get_contracts(func) -> Iterator[ContractInfo]: if isinstance(func, ast.FunctionDef): yield from _get_contracts(func.decorator_list) return @@ -32,13 +38,17 @@ def get_contracts(func) -> Iterator[Tuple[str, list]]: yield from _get_contracts(func.decorators.nodes) -def _get_contracts(decorators: list) -> Iterator[Tuple[str, list]]: +def _get_contracts(decorators: list) -> Iterator[ContractInfo]: for contract in decorators: if isinstance(contract, TOKENS.ATTR): name = get_name(contract) if name not in SUPPORTED_MARKERS: continue - yield name.split('.')[-1], [] + yield ContractInfo( + name=name.split('.')[-1], + args=[], + line=contract.lineno, + ) if name == 'deal.inherit': yield from _resolve_inherit(contract) @@ -50,7 +60,11 @@ def _get_contracts(decorators: list) -> Iterator[Tuple[str, list]]: yield from _get_contracts(contract.args) if name not in SUPPORTED_CONTRACTS: continue - yield name.split('.')[-1], contract.args + yield ContractInfo( + name=name.split('.')[-1], + args=contract.args, + line=contract.lineno, + ) # infer assigned value if isinstance(contract, astroid.Name): @@ -68,7 +82,7 @@ def _get_contracts(decorators: list) -> Iterator[Tuple[str, list]]: yield from _get_contracts([expr.value]) -def _resolve_inherit(contract: Attr) -> Iterator[Tuple[str, List[astroid.Expr]]]: +def _resolve_inherit(contract: Attr) -> Iterator[ContractInfo]: if not isinstance(contract, astroid.Attribute): return cls = _get_parent_class(contract) diff --git a/deal/linter/_extractors/exceptions.py b/deal/linter/_extractors/exceptions.py index 52feb242..a05524b5 100644 --- a/deal/linter/_extractors/exceptions.py +++ b/deal/linter/_extractors/exceptions.py @@ -104,10 +104,10 @@ def _exceptions_from_func(expr) -> Iterator[Token]: yield Token(value=error.value, line=expr.lineno, col=expr.col_offset) # get explicitly specified exceptions from `@deal.raises` - for category, args in get_contracts(value): - if category != 'raises': + for contract in get_contracts(value): + if contract.name != 'raises': continue - for arg in args: + for arg in contract.args: name = get_name(arg) if name is None: continue diff --git a/deal/linter/_extractors/markers.py b/deal/linter/_extractors/markers.py index 005fb33b..931886b6 100644 --- a/deal/linter/_extractors/markers.py +++ b/deal/linter/_extractors/markers.py @@ -273,10 +273,10 @@ def _markers_from_func(expr: astroid.NodeNG, inferred: tuple) -> Iterator[Token] ) # get explicitly specified markers from `@deal.has` - for category, args in get_contracts(value): - if category != 'has': + for contract in get_contracts(value): + if contract.name != 'has': continue - for arg in args: + for arg in contract.args: value = get_value(arg) if type(value) is not str: continue diff --git a/deal/linter/_extractors/pre.py b/deal/linter/_extractors/pre.py index 7caf3dbd..24c10469 100644 --- a/deal/linter/_extractors/pre.py +++ b/deal/linter/_extractors/pre.py @@ -34,12 +34,11 @@ def handle_call(expr: astroid.Call, context: Dict[str, ast.stmt] = None) -> Iter continue code = f'def f({func.args.as_string()}):0' func_args = ast.parse(code).body[0].args # type: ignore - for category, contract_args in get_contracts(func): - if category != 'pre': + for cinfo in get_contracts(func): + if cinfo.name != 'pre': continue - contract = Contract( - args=contract_args, + args=cinfo.args, category=Category.PRE, func_args=func_args, context=context, diff --git a/deal/linter/_func.py b/deal/linter/_func.py index 4347bd79..27ea03cd 100644 --- a/deal/linter/_func.py +++ b/deal/linter/_func.py @@ -36,12 +36,13 @@ def from_ast(cls, tree: ast.Module) -> List['Func']: if not isinstance(expr, ast.FunctionDef): continue contracts = [] - for category, args in get_contracts(expr): + for cinfo in get_contracts(expr): contract = Contract( - args=args, + args=cinfo.args, func_args=expr.args, - category=Category(category), + category=Category(cinfo.name), context=definitions, + line=cinfo.line, ) contracts.append(contract) funcs.append(cls( @@ -68,12 +69,13 @@ def from_astroid(cls, tree: astroid.Module) -> List['Func']: # collect contracts contracts = [] - for category, args in get_contracts(expr): + for cinfo in get_contracts(expr): contract = Contract( - args=args, + args=cinfo.args, func_args=func_args, - category=Category(category), + category=Category(cinfo.name), context=definitions, + line=cinfo.line, ) contracts.append(contract) funcs.append(cls( @@ -86,6 +88,12 @@ def from_astroid(cls, tree: astroid.Module) -> List['Func']: )) return funcs + def has_contract(self, *categories: Category) -> bool: + for contract in self.contracts: + if contract.category in categories: + return True + return False + def __repr__(self) -> str: cats = ', '.join(contract.category.value for contract in self.contracts) return f'{type(self).__name__}({cats})' diff --git a/deal/linter/_rules.py b/deal/linter/_rules.py index bbf9d7c6..cb9eb882 100644 --- a/deal/linter/_rules.py +++ b/deal/linter/_rules.py @@ -1,7 +1,7 @@ import ast from itertools import chain from types import MappingProxyType -from typing import Iterator, List, Optional, Type, TypeVar +from typing import Iterator, List, Optional, Set, Type, TypeVar, Union import astroid @@ -17,6 +17,7 @@ T = TypeVar('T', bound=Type['Rule']) +Exceptions = List[Union[str, Type[Exception]]] rules: List['Rule'] = [] @@ -239,20 +240,29 @@ class CheckRaises(FuncRule): def __call__(self, func: Func, stubs: StubsManager = None) -> Iterator[Error]: cats = {Category.RAISES, Category.SAFE, Category.PURE} + declared: Exceptions = [] + check = False for contract in func.contracts: if contract.category not in cats: continue - yield from self._check(func=func, contract=contract, stubs=stubs) - - def _check(self, func: Func, contract: Contract, stubs: StubsManager = None) -> Iterator[Error]: - allowed = contract.exceptions - allowed_types = tuple(exc for exc in allowed if type(exc) is not str) + declared.extend(contract.exceptions) + check = True + if check: + yield from self.get_undeclared(func, declared, stubs) + + def get_undeclared( + self, + func: Func, + declared: Exceptions, + stubs: Optional[StubsManager] = None, + ) -> Iterator[Error]: + declared_types = tuple(exc for exc in declared if not isinstance(exc, str)) for token in get_exceptions(body=func.body, stubs=stubs): - if token.value in allowed: + if token.value in declared: continue exc = token.value if isinstance(exc, type): - if issubclass(exc, allowed_types): + if issubclass(exc, declared_types): continue exc = exc.__name__ yield Error( @@ -309,18 +319,18 @@ class CheckMarkers(FuncRule): def __call__(self, func: Func, stubs: StubsManager = None) -> Iterator[Error]: for contract in func.contracts: - markers = None + markers: Optional[Set[str]] = None if contract.category == Category.HAS: - markers = [get_value(arg) for arg in contract.args] + markers = {get_value(arg) for arg in contract.args} elif contract.category == Category.PURE: - markers = [] + markers = set() if markers is None: continue - yield from self._check(func=func, markers=markers) + yield from self.get_undeclared(func=func, markers=markers) return @classmethod - def _check(cls, func: Func, markers: List[str]) -> Iterator[Error]: + def get_undeclared(cls, func: Func, markers: Set[str]) -> Iterator[Error]: has = HasPatcher(markers) # function without IO must return something if not has.has_io and not has_returns(body=func.body): diff --git a/deal/linter/_transformer.py b/deal/linter/_transformer.py new file mode 100644 index 00000000..ffd87042 --- /dev/null +++ b/deal/linter/_transformer.py @@ -0,0 +1,244 @@ +from enum import Enum +from pathlib import Path +from typing import Iterator, List, NamedTuple, Set, Tuple, Union + +import astroid + +from ._contract import Category +from ._extractors import get_value +from ._func import Func +from ._rules import CheckMarkers, CheckRaises + + +Priority = int + + +class TransformationType(Enum): + RAISES = 'raises' + HAS = 'has' + SAFE = 'safe' + IMPORT = 'import' + + +class InsertText(NamedTuple): + line: int + text: str + + def apply(self, lines: List[str]) -> None: + lines.insert(self.line - 1, f'{self}\n') + + @property + def key(self) -> Tuple[int, Priority]: + return (self.line, 1) + + def __str__(self) -> str: + return self.text + + +class InsertContract(NamedTuple): + line: int + contract: Category + args: List[str] + indent: int + + def apply(self, lines: List[str]) -> None: + lines.insert(self.line - 1, f'{self}\n') + + @property + def key(self) -> Tuple[int, Priority]: + return (self.line, 1) + + def __str__(self) -> str: + args = ', '.join(self.args) + if not self.args and self.contract.brackets_optional: + dec = f'@deal.{self.contract.value}' + else: + dec = f'@deal.{self.contract.value}({args})' + return ' ' * self.indent + dec + + +class Remove(NamedTuple): + line: int + + def apply(self, lines: List[str]) -> None: + lines.pop(self.line - 1) + + @property + def key(self) -> Tuple[int, Priority]: + return (self.line, 2) + + +Mutation = Union[InsertText, InsertContract, Remove] + + +class Transformer(NamedTuple): + """Transformer adds deal decorators into the given script. + """ + content: str + path: Path + types: Set[TransformationType] + mutations: List[Mutation] = [] + quote: str = "'" + + def transform(self) -> str: + self.mutations.clear() + tree = astroid.parse(self.content, path=self.path) + for func in Func.from_astroid(tree): + self._collect_mutations(func) + self.mutations.extend(self._mutations_import(tree)) + return self._apply_mutations(self.content) + + def _collect_mutations(self, func: Func) -> None: + self.mutations.extend(self._mutations_excs(func)) + self.mutations.extend(self._mutations_markers(func)) + + def _mutations_excs(self, func: Func) -> Iterator[Mutation]: + """Add @deal.raises or @deal.safe if needed. + """ + cats = {Category.RAISES, Category.SAFE, Category.PURE} + + # collect declared exceptions + declared: List[Union[str, type]] = [] + for contract in func.contracts: + if contract.category not in cats: + continue + declared.extend(contract.exceptions) + + # collect undeclared exceptions + excs: Set[str] = set() + for error in CheckRaises().get_undeclared(func, declared): + assert isinstance(error.value, str) + excs.add(error.value) + + # if no new exceptions found, add deal.safe + if not excs: + if declared: + return + if TransformationType.SAFE not in self.types: + return + if func.has_contract(Category.PURE, Category.SAFE): + return + yield InsertContract( + line=func.line, + indent=func.col, + contract=Category.SAFE, + args=[], + ) + return + + # if new exceptions detected, remove old contracts and add a new deal.raises + if TransformationType.RAISES not in self.types: + return + for contract in func.contracts: + if contract.category not in cats: + continue + yield Remove(contract.line) + if contract.category == Category.PURE: + yield InsertContract( + line=func.line, + indent=func.col, + contract=Category.HAS, + args=[], + ) + contract_args = [self._exc_as_str(exc) for exc in declared] + contract_args.extend(sorted(excs)) + yield InsertContract( + line=func.line, + indent=func.col, + contract=Category.RAISES, + args=contract_args, + ) + + @staticmethod + def _exc_as_str(exc) -> str: + if isinstance(exc, str): + return exc + return exc.__name__ + + def _mutations_markers(self, func: Func) -> Iterator[Mutation]: + """Add @deal.has if needed. + """ + if TransformationType.HAS not in self.types: + return + cats = {Category.HAS, Category.PURE} + + # collect declared markers + declared: List[str] = [] + for contract in func.contracts: + if contract.category not in cats: + continue + declared.extend(get_value(arg) for arg in contract.args) + + # collect undeclared markers + markers: Set[str] = set() + for error in CheckMarkers().get_undeclared(func, set(declared)): + assert isinstance(error.value, str) + markers.add(error.value) + + # if no new markers found, add deal.has() + if not markers: + if func.has_contract(Category.PURE, Category.HAS): + return + yield InsertContract( + line=func.line, + indent=func.col, + contract=Category.HAS, + args=[], + ) + return + + # if new markers detected, remove old contracts and add a new deal.raises + for contract in func.contracts: + if contract.category not in cats: + continue + yield Remove(contract.line) + if contract.category == Category.PURE: + yield InsertContract( + line=func.line, + indent=func.col, + contract=Category.SAFE, + args=[], + ) + contract_args = [self._exc_as_str(marker) for marker in declared] + contract_args.extend(sorted(markers)) + yield InsertContract( + line=func.line, + indent=func.col, + contract=Category.HAS, + args=[f'{self.quote}{arg}{self.quote}' for arg in contract_args], + ) + + def _mutations_import(self, tree: astroid.Module) -> Iterator[Mutation]: + """Add `import deal` if needed. + """ + if TransformationType.IMPORT not in self.types: + return + if not self.mutations: + return + # check if already imported + for stmt in tree.body: + if not isinstance(stmt, astroid.Import): + continue + for name, _ in stmt.names: + if name == 'deal': + return + + # We insert the import after `__future__` imports and module imports. + # We don't skip `from` imports, though, because they can be multiline. + line = 1 + for stmt in tree.body: + if isinstance(stmt, astroid.Import): + line = stmt.lineno + 1 + if isinstance(stmt, astroid.ImportFrom): + if stmt.modname == '__future__': + line = stmt.lineno + 1 + yield InsertText(line=line, text='import deal') + + def _apply_mutations(self, content: str) -> str: + if not self.mutations: + return content + lines = content.splitlines(keepends=True) + self.mutations.sort(key=lambda x: x.key, reverse=True) + for mutation in self.mutations: + mutation.apply(lines) + return ''.join(lines) diff --git a/docs/basic/refs.md b/docs/basic/refs.md index a8206dcd..f1d859c7 100644 --- a/docs/basic/refs.md +++ b/docs/basic/refs.md @@ -52,13 +52,14 @@ This page provides a quick navigation by the documentation in case if you're loo ## CLI commands -| command | reference | documentation | -| --------- | -------------------------- | ------------- | -| `lint` | {ref}`details/cli:lint` | {doc}`/basic/linter` / {ref}`basic/linter:built-in cli command` | -| `memtest` | {ref}`details/cli:memtest` | {doc}`/details/tests` / {ref}`details/tests:finding memory leaks` | -| `prove` | {ref}`details/cli:prove` | {doc}`/basic/verification` | -| `stub` | {ref}`details/cli:stub` | {doc}`/details/stubs` | -| `test` | {ref}`details/cli:test` | {doc}`/basic/tests` / {ref}`basic/tests:cli` | +| command | reference | documentation | +| ---------- | --------------------------- | ------------- | +| `decorate` | {ref}`details/cli:decorate` | {doc}`/details/contracts` / {ref}`details/contracts:generating contracts` | +| `lint` | {ref}`details/cli:lint` | {doc}`/basic/linter` / {ref}`basic/linter:built-in cli command` | +| `memtest` | {ref}`details/cli:memtest` | {doc}`/details/tests` / {ref}`details/tests:finding memory leaks` | +| `prove` | {ref}`details/cli:prove` | {doc}`/basic/verification` | +| `stub` | {ref}`details/cli:stub` | {doc}`/details/stubs` | +| `test` | {ref}`details/cli:test` | {doc}`/basic/tests` / {ref}`basic/tests:cli` | ## Integrations diff --git a/docs/details/cli.md b/docs/details/cli.md index 5ac3f266..44034f49 100644 --- a/docs/details/cli.md +++ b/docs/details/cli.md @@ -6,10 +6,10 @@ .. autofunction:: deal._cli._lint.LintCommand ``` -## stub +## decorate ```{eval-rst} -.. autofunction:: deal._cli._stub.StubCommand +.. autofunction:: deal._cli._decorate.DecorateCommand ``` ## test @@ -24,6 +24,12 @@ .. autofunction:: deal._cli._memtest.MemtestCommand ``` +## stub + +```{eval-rst} +.. autofunction:: deal._cli._stub.StubCommand +``` + ## prove ```{eval-rst} diff --git a/docs/details/contracts.md b/docs/details/contracts.md index 88060f21..5e7922db 100644 --- a/docs/details/contracts.md +++ b/docs/details/contracts.md @@ -1,5 +1,21 @@ # More on writing contracts +## Generating contracts + +The best way to get started with deal is to automatically generate some contracts using {ref}`details/cli:decorate` CLI command: + +```bash +python3 -m deal decorate my_project/ +``` + +It will run {doc}`/basic/linter` on your code and add some of the missed contracts. The rest of contacts is still on you, though. Also, you should carefully check the generated code for correctness, deal may miss something. + +The following contracts are supported by the command and will be added to your code: + ++ {py:func}`deal.has` ++ {py:func}`deal.raises` ++ {py:func}`deal.safe` + ## Simplified signature The main problem with contracts is that they have to duplicate the original function's signature, including default arguments. While it's not a problem for small examples, things become more complicated when the signature grows. In this case, you can specify a function that accepts only one `_` argument, and deal will pass there a container with arguments of the function call, including default ones: diff --git a/tests/test_cli/test_decorate.py b/tests/test_cli/test_decorate.py new file mode 100644 index 00000000..28327861 --- /dev/null +++ b/tests/test_cli/test_decorate.py @@ -0,0 +1,92 @@ +from io import StringIO +from pathlib import Path +from textwrap import dedent + +import pytest + +from deal._cli import main + + +@pytest.mark.parametrize('flags, given, expected', [ + ( + [], + """ + @deal.post(lambda x: x > 0) + def f(x): + print(1/0) + return -1 + """, + """ + import deal + + @deal.has('stdout') + @deal.raises(ZeroDivisionError) + @deal.post(lambda x: x > 0) + def f(x): + print(1/0) + return -1 + """, + ), + ( + ['--types', 'raises', 'safe'], + """ + import deal + @deal.post(lambda x: x > 0) + def f(x): + print(1/0) + return -1 + """, + """ + import deal + @deal.raises(ZeroDivisionError) + @deal.post(lambda x: x > 0) + def f(x): + print(1/0) + return -1 + """, + ), + ( + ['--types', 'has', '--double-quotes'], + """ + import deal + @deal.post(lambda x: x > 0) + def f(x): + print(1/0) + return -1 + """, + """ + import deal + @deal.has("stdout") + @deal.post(lambda x: x > 0) + def f(x): + print(1/0) + return -1 + """, + ), + ( + [], + """ + import deal + @deal.pure + def f(x): + return x + """, + """ + import deal + @deal.pure + def f(x): + return x + """, + ), +]) +def test_decorate_command(flags: list, given: str, expected: str, tmp_path: Path): + file_path = tmp_path / 'example.py' + file_path.write_text(dedent(given)) + stream = StringIO() + code = main(['decorate', *flags, '--', str(tmp_path)], stream=stream) + assert code == 0 + + stream.seek(0) + captured = stream.read() + assert str(file_path) in captured + assert file_path.read_text().lstrip('\n') == dedent(expected).lstrip('\n') diff --git a/tests/test_docs.py b/tests/test_docs.py index cc2c3b9e..4b1e51f3 100644 --- a/tests/test_docs.py +++ b/tests/test_docs.py @@ -5,7 +5,7 @@ import pytest import deal -from deal._cli._main import COMMANDS +from deal._cli._main import get_commands from deal.linter._rules import CheckMarkers, rules @@ -51,7 +51,7 @@ def test_all_marker_codes_listed(): def test_cli_included(): path = root / 'details' / 'cli.md' content = path.read_text() - for name, cmd in COMMANDS.items(): + for name, cmd in get_commands().items(): # has header tmpl = '## {n}\n\n' line = tmpl.format(n=name) @@ -106,7 +106,7 @@ def test_all_public_listed_in_refs(): def test_all_cli_commands_listed_in_refs(): path = root / 'basic' / 'refs.md' content = path.read_text() - for name, cmd in COMMANDS.items(): + for name, cmd in get_commands().items(): assert f'`{name}`' in content assert f'`details/cli:{name}`' in content diff --git a/tests/test_linter/test_contract.py b/tests/test_linter/test_contract.py index a6edf362..35172753 100644 --- a/tests/test_linter/test_contract.py +++ b/tests/test_linter/test_contract.py @@ -30,7 +30,11 @@ def test_exceptions(): def test_repr(): - c = Contract(category=Category.RAISES, args=[], func_args=None) + c = Contract( + category=Category.RAISES, + args=[], + func_args=None, # type: ignore[arg-type] + ) assert repr(c) == 'Contract(raises)' diff --git a/tests/test_linter/test_extractors/test_contracts.py b/tests/test_linter/test_extractors/test_contracts.py index fe77a8c1..1cbce1b7 100644 --- a/tests/test_linter/test_extractors/test_contracts.py +++ b/tests/test_linter/test_extractors/test_contracts.py @@ -4,11 +4,18 @@ import astroid import pytest +from deal.linter._contract import Category from deal.linter._extractors import get_contracts +from deal.linter._extractors.contracts import SUPPORTED_CONTRACTS, SUPPORTED_MARKERS + + +def test_supported_contracts_match_categories(): + cats = {f'deal.{c.value}' for c in Category} + assert cats == SUPPORTED_CONTRACTS | SUPPORTED_MARKERS def get_cats(target) -> tuple: - return tuple(cat for cat, _ in get_contracts(target)) + return tuple(contract.name for contract in get_contracts(target)) @pytest.mark.parametrize('text, expected', [ diff --git a/tests/test_linter/test_transformer.py b/tests/test_linter/test_transformer.py new file mode 100644 index 00000000..a68dfbc7 --- /dev/null +++ b/tests/test_linter/test_transformer.py @@ -0,0 +1,467 @@ +from pathlib import Path +from textwrap import dedent + +import pytest + +from deal.linter import TransformationType, Transformer + + +@pytest.mark.parametrize('content', [ + # add deal.safe + """ + def f(): + return 1 + --- + @deal.safe + def f(): + return 1 + """, + # preserve deal.safe + """ + @deal.safe + def f(): + return 1 + --- + @deal.safe + def f(): + return 1 + """, + # preserve deal.safe with comments + """ + @deal.safe # oh hi mark + def f(): + return 1 + --- + @deal.safe # oh hi mark + def f(): + return 1 + """, + # preserve deal.raises + """ + @deal.raises(KeyError) + def f(): + return 1 + --- + @deal.raises(KeyError) + def f(): + return 1 + """, + """ + @deal.raises(KeyError, UnknownError) + def f(): + return 1 + --- + @deal.raises(KeyError, UnknownError) + def f(): + return 1 + """, + # add a new deal.raises + """ + def f(): + raise ValueError + --- + @deal.raises(ValueError) + def f(): + raise ValueError + """, + # add deal.raises for unknown error + """ + def f(): + raise UnknownError + --- + @deal.raises(UnknownError) + def f(): + raise UnknownError + """, + # preserve unknown error + """ + @deal.raises(UnknownError) + def f(): + raise ValueError + --- + @deal.raises(UnknownError, ValueError) + def f(): + raise ValueError + """, + # remove deal.safe if adding deal.raises + """ + @deal.safe + def f(): + raise ValueError + --- + @deal.raises(ValueError) + def f(): + raise ValueError + """, + # remove deal.pure if adding deal.raises + """ + @deal.pure + def f(): + raise ValueError + --- + @deal.raises(ValueError) + @deal.has() + def f(): + raise ValueError + """, + # merge deal.raises + """ + @deal.raises(ZeroDivisionError) + def f(): + raise ValueError + --- + @deal.raises(ZeroDivisionError, ValueError) + def f(): + raise ValueError + """, + # preserve comments + """ + # oh hi mark + def f(): + raise ValueError + --- + # oh hi mark + @deal.raises(ValueError) + def f(): + raise ValueError + """, + # preserve comments + """ + # oh hi mark + + def f(): + # hello + raise ValueError + --- + # oh hi mark + + @deal.raises(ValueError) + def f(): + # hello + raise ValueError + """, + # preserve contracts + """ + @deal.safe + @deal.pre(lambda: True) + def f(): + return 1 + --- + @deal.safe + @deal.pre(lambda: True) + def f(): + return 1 + """, + """ + @deal.pre(lambda: True) + def f(): + return 1 + --- + @deal.safe + @deal.pre(lambda: True) + def f(): + return 1 + """, + """ + @deal.raises(ValueError) + @deal.pre(lambda: True) + def f(): + 1/0 + --- + @deal.raises(ValueError, ZeroDivisionError) + @deal.pre(lambda: True) + def f(): + 1/0 + """, + # # support methods + # """ + # class A: + # def f(self): + # 1/0 + # --- + # class A: + # @deal.raises(ZeroDivisionError) + # def f(self): + # 1/0 + # """, +]) +def test_transformer_raises(content: str, tmp_path: Path) -> None: + given, expected = content.split('---') + given = dedent(given) + expected = dedent(expected) + tr = Transformer( + content=given, + path=tmp_path / 'example.py', + types={TransformationType.RAISES, TransformationType.SAFE}, + ) + actual = tr.transform() + print(tr.mutations) + assert actual == expected + + +@pytest.mark.parametrize('content', [ + # add deal.has() + """ + def f(): + return 1 + --- + @deal.has() + def f(): + return 1 + """, + # add deal.has with markers + """ + def f(): + print("hi") + return 1 + --- + @deal.has('stdout') + def f(): + print("hi") + return 1 + """, + """ + import random + def f(): + print(random.choice([1,2])) + return 1 + --- + import random + @deal.has('random', 'stdout') + def f(): + print(random.choice([1,2])) + return 1 + """, + # merge deal.has + """ + @deal.has('random') + def f(): + print("hi") + return 1 + --- + @deal.has('random', 'stdout') + def f(): + print("hi") + return 1 + """, + """ + @deal.has() + def f(): + print("hi") + return 1 + --- + @deal.has('stdout') + def f(): + print("hi") + return 1 + """, + # replace deal.pure + """ + @deal.pure + def f(): + print("hi") + return 1 + --- + @deal.has('stdout') + @deal.safe + def f(): + print("hi") + return 1 + """, + # preserve contracts + """ + @deal.pre(lambda: True) + def f(): + print("hi") + return 1 + --- + @deal.has('stdout') + @deal.pre(lambda: True) + def f(): + print("hi") + return 1 + """, + """ + @deal.has('random') + def f(): + return 1 + --- + @deal.has('random') + def f(): + return 1 + """, + """ + @deal.has('random') + @deal.pre(lambda: True) + def f(): + return 1 + --- + @deal.has('random') + @deal.pre(lambda: True) + def f(): + return 1 + """, + """ + @deal.has() + @deal.pre(lambda: True) + def f(): + return 1 + --- + @deal.has() + @deal.pre(lambda: True) + def f(): + return 1 + """, + """ + @deal.pre(lambda: True) + def f(): + return 1 + --- + @deal.has() + @deal.pre(lambda: True) + def f(): + return 1 + """, + # do not add deal.raises if transformation type is disabled + """ + def f(): + raise ValueError + return 1 + --- + @deal.has() + def f(): + raise ValueError + return 1 + """, +]) +def test_transformer_has(content: str, tmp_path: Path) -> None: + given, expected = content.split('---') + given = dedent(given) + expected = dedent(expected) + tr = Transformer( + content=given, + path=tmp_path / 'example.py', + types={TransformationType.HAS}, + ) + actual = tr.transform() + assert actual == expected + + +@pytest.mark.parametrize('content', [ + # add import if needed + """ + def f(): + return 1 + --- + import deal + @deal.has() + def f(): + return 1 + """, + # skip imports + """ + import re + + def f(): + return 1 + --- + import re + import deal + + @deal.has() + def f(): + return 1 + """, + # skip import-from, do not skip consts + """ + import re + from textwrap import dedent + + HI = 1 + + def f(): + return 1 + --- + import re + import deal + from textwrap import dedent + + HI = 1 + + @deal.has() + def f(): + return 1 + """, + # do nothing if there are no mutations + """ + @deal.has() + def f(): + return 1 + --- + @deal.has() + def f(): + return 1 + """, + # do not duplicate existing imports + """ + import deal + def f(): + return 1 + --- + import deal + @deal.has() + def f(): + return 1 + """, + # support multiline imports + """ + from textwrap import ( + dedent, + ) + def f(): + return 1 + --- + import deal + from textwrap import ( + dedent, + ) + @deal.has() + def f(): + return 1 + """, + # skip __future__ imports + """ + from __future__ import annotations + def f(): + return 1 + --- + from __future__ import annotations + import deal + @deal.has() + def f(): + return 1 + """, + # skip from imports before module imports + """ + from __future__ import annotations + def f(): + return 1 + --- + from __future__ import annotations + import deal + @deal.has() + def f(): + return 1 + """, +]) +def test_transformer_import(content: str, tmp_path: Path) -> None: + given, expected = content.split('---') + given = dedent(given).lstrip('\n') + expected = dedent(expected) + tr = Transformer( + content=given, + path=tmp_path / 'example.py', + types={TransformationType.HAS, TransformationType.IMPORT}, + ) + actual = tr.transform() + assert actual.lstrip('\n') == expected.lstrip('\n')