From 38408b78194e312fc9ec4dffeb9bea816c6a8a9b Mon Sep 17 00:00:00 2001 From: gram Date: Fri, 12 Nov 2021 11:58:07 +0100 Subject: [PATCH 01/20] transformer: generate new deal.raises --- deal/linter/__init__.py | 3 +- deal/linter/_contract.py | 4 +- deal/linter/_extractors/common.py | 3 +- deal/linter/_rules.py | 26 +++++++--- deal/linter/_transformer.py | 72 +++++++++++++++++++++++++++ tests/test_linter/test_transformer.py | 44 ++++++++++++++++ 6 files changed, 139 insertions(+), 13 deletions(-) create mode 100644 deal/linter/_transformer.py create mode 100644 tests/test_linter/test_transformer.py diff --git a/deal/linter/__init__.py b/deal/linter/__init__.py index 76d57d0b..443fe4d5 100644 --- a/deal/linter/__init__.py +++ b/deal/linter/__init__.py @@ -1,5 +1,6 @@ from ._checker import Checker from ._stub import StubsManager, generate_stub +from ._transformer import Transformer -__all__ = ['Checker', 'StubsManager', 'generate_stub'] +__all__ = ['Checker', 'Transformer', 'StubsManager', 'generate_stub'] diff --git a/deal/linter/_contract.py b/deal/linter/_contract.py index 5aef275c..204bcd68 100644 --- a/deal/linter/_contract.py +++ b/deal/linter/_contract.py @@ -3,7 +3,7 @@ import enum from copy import copy from pathlib import Path -from typing import Dict, FrozenSet, Iterable, List +from typing import Dict, FrozenSet, Iterable, List, Type, Union import astroid @@ -110,7 +110,7 @@ def _resolve_name(contract): return contract # pragma: no cover @cached_property - def exceptions(self) -> list: + def exceptions(self) -> List[Union[str, Type[Exception]]]: from ._extractors import get_name excs = [] diff --git a/deal/linter/_extractors/common.py b/deal/linter/_extractors/common.py index a72dd30d..305a6443 100644 --- a/deal/linter/_extractors/common.py +++ b/deal/linter/_extractors/common.py @@ -31,8 +31,7 @@ class Token(NamedTuple): line: int col: int value: Optional[object] = None - # marker name or error message: - marker: Optional[str] = None + marker: Optional[str] = None # marker name or error message def traverse(body: List) -> Iterator: diff --git a/deal/linter/_rules.py b/deal/linter/_rules.py index bbf9d7c6..f10e9c80 100644 --- a/deal/linter/_rules.py +++ b/deal/linter/_rules.py @@ -1,7 +1,7 @@ import ast from itertools import chain from types import MappingProxyType -from typing import Iterator, List, Optional, Type, TypeVar +from typing import Iterator, List, Optional, Type, TypeVar, Union import astroid @@ -17,6 +17,7 @@ T = TypeVar('T', bound=Type['Rule']) +Exceptions = List[Union[str, Type[Exception]]] rules: List['Rule'] = [] @@ -239,20 +240,29 @@ class CheckRaises(FuncRule): def __call__(self, func: Func, stubs: StubsManager = None) -> Iterator[Error]: cats = {Category.RAISES, Category.SAFE, Category.PURE} + declared: Exceptions = [] + check = False for contract in func.contracts: if contract.category not in cats: continue - yield from self._check(func=func, contract=contract, stubs=stubs) - - def _check(self, func: Func, contract: Contract, stubs: StubsManager = None) -> Iterator[Error]: - allowed = contract.exceptions - allowed_types = tuple(exc for exc in allowed if type(exc) is not str) + declared.extend(contract.exceptions) + check = True + if check: + yield from self.get_undeclared(func, declared, stubs) + + def get_undeclared( + self, + func: Func, + declared: Exceptions, + stubs: Optional[StubsManager] = None, + ) -> Iterator[Error]: + declared_types = tuple(exc for exc in declared if not isinstance(exc, str)) for token in get_exceptions(body=func.body, stubs=stubs): - if token.value in allowed: + if token.value in declared: continue exc = token.value if isinstance(exc, type): - if issubclass(exc, allowed_types): + if issubclass(exc, declared_types): continue exc = exc.__name__ yield Error( diff --git a/deal/linter/_transformer.py b/deal/linter/_transformer.py new file mode 100644 index 00000000..edd76bc9 --- /dev/null +++ b/deal/linter/_transformer.py @@ -0,0 +1,72 @@ +from pathlib import Path +from typing import List, NamedTuple, Optional, Set +from .._cached_property import cached_property +from ._contract import Category +from ._func import Func +from ._rules import CheckRaises + + +class Mutation(NamedTuple): + line: int + contract: Category + args: List[str] + indent: int + + def __str__(self): + args = ', '.join(self.args) + dec = f'@deal.{self.contract.value}({args})' + return ' ' * self.indent + dec + + +class Transformer: + path: Path + mutations: List[Mutation] + + def __init__(self, path: Path) -> None: + self.path = path + self.mutations = [] + + @cached_property + def funcs(self) -> List[Func]: + return Func.from_path(self.path) + + def transform(self) -> str: + for func in self.funcs: + self._collect_mutations(func) + content = self.path.read_text() + return self._apply_mutations(content) + + def _collect_mutations(self, func: Func) -> None: + mut = self._mutation_excs(func) + if mut is not None: + self.mutations.append(mut) + + def _mutation_excs(self, func: Func) -> Optional[Mutation]: + checker = CheckRaises() + excs: Set[str] = set() + cats = {Category.RAISES, Category.SAFE, Category.PURE} + declared = [] + for contract in func.contracts: + if contract.category not in cats: + continue + declared.extend(contract.exceptions) + for error in checker.get_undeclared(func, declared): + assert isinstance(error.value, str) + excs.add(error.value) + if not excs: + return None + return Mutation( + line=func.line, + indent=func.col, + contract=Category.RAISES, + args=sorted(excs), + ) + + def _apply_mutations(self, content: str) -> str: + if not self.mutations: + return content + lines = content.splitlines(keepends=True) + self.mutations.sort(key=lambda x: x.line, reverse=True) + for mutation in self.mutations: + lines.insert(mutation.line - 1, f'{mutation}\n') + return ''.join(lines) diff --git a/tests/test_linter/test_transformer.py b/tests/test_linter/test_transformer.py new file mode 100644 index 00000000..cb1958b2 --- /dev/null +++ b/tests/test_linter/test_transformer.py @@ -0,0 +1,44 @@ +from pathlib import Path +from textwrap import dedent +import pytest +from deal.linter import Transformer + + +@pytest.mark.parametrize('content', [ + # no-op + """ + def f(): + return 1 + --- + def f(): + return 1 + """, + # preserve contracts + """ + @deal.pre(lambda: True) + def f(): + return 1 + --- + @deal.pre(lambda: True) + def f(): + return 1 + """, + # add a new deal.raises + """ + def f(): + raise ValueError + --- + @deal.raises(ValueError) + def f(): + raise ValueError + """, +]) +def test_transformer(content: str, tmp_path: Path) -> None: + given, expected = content.split('---') + given = dedent(given) + expected = dedent(expected) + path = tmp_path / "example.py" + path.write_text(given) + tr = Transformer(path=path) + actual = tr.transform() + assert actual == expected From d013d128278e161ac6e227702dc70e783a9201df Mon Sep 17 00:00:00 2001 From: gram Date: Fri, 12 Nov 2021 13:42:01 +0100 Subject: [PATCH 02/20] transformer: remove old decorators when adding deal.raises --- deal/linter/_contract.py | 3 ++ deal/linter/_extractors/contracts.py | 30 +++++++++++++++---- deal/linter/_func.py | 14 +++++---- deal/linter/_transformer.py | 42 +++++++++++++++++++++------ tests/test_linter/test_transformer.py | 21 ++++++++++++++ 5 files changed, 89 insertions(+), 21 deletions(-) diff --git a/deal/linter/_contract.py b/deal/linter/_contract.py index 204bcd68..1caaf1f3 100644 --- a/deal/linter/_contract.py +++ b/deal/linter/_contract.py @@ -31,6 +31,7 @@ class Contract: category: Category func_args: ast.arguments context: Dict[str, ast.stmt] + line: int def __init__( self, @@ -38,11 +39,13 @@ def __init__( category: Category, func_args: ast.arguments, context: Dict[str, ast.stmt] = None, + line: int = 0, ): self.args = tuple(args) self.category = category self.func_args = func_args self.context = context or dict() + self.line = line @cached_property def body(self) -> ast.AST: diff --git a/deal/linter/_extractors/contracts.py b/deal/linter/_extractors/contracts.py index d1e06622..cca5efb7 100644 --- a/deal/linter/_extractors/contracts.py +++ b/deal/linter/_extractors/contracts.py @@ -1,5 +1,5 @@ import ast -from typing import Iterator, List, Optional, Tuple, Union +from typing import Iterator, List, NamedTuple, Optional, Union import astroid @@ -20,7 +20,17 @@ Attr = Union[ast.Attribute, astroid.Attribute] -def get_contracts(func) -> Iterator[Tuple[str, list]]: +class ContractInfo(NamedTuple): + name: str + args: List[Union[ast.expr, astroid.Expr]] + line: int + + def __iter__(self) -> Iterator: + yield self.name + yield self.args + + +def get_contracts(func) -> Iterator[ContractInfo]: if isinstance(func, ast.FunctionDef): yield from _get_contracts(func.decorator_list) return @@ -32,13 +42,17 @@ def get_contracts(func) -> Iterator[Tuple[str, list]]: yield from _get_contracts(func.decorators.nodes) -def _get_contracts(decorators: list) -> Iterator[Tuple[str, list]]: +def _get_contracts(decorators: list) -> Iterator[ContractInfo]: for contract in decorators: if isinstance(contract, TOKENS.ATTR): name = get_name(contract) if name not in SUPPORTED_MARKERS: continue - yield name.split('.')[-1], [] + yield ContractInfo( + name=name.split('.')[-1], + args=[], + line=contract.lineno + ) if name == 'deal.inherit': yield from _resolve_inherit(contract) @@ -50,7 +64,11 @@ def _get_contracts(decorators: list) -> Iterator[Tuple[str, list]]: yield from _get_contracts(contract.args) if name not in SUPPORTED_CONTRACTS: continue - yield name.split('.')[-1], contract.args + yield ContractInfo( + name=name.split('.')[-1], + args=contract.args, + line=contract.lineno, + ) # infer assigned value if isinstance(contract, astroid.Name): @@ -68,7 +86,7 @@ def _get_contracts(decorators: list) -> Iterator[Tuple[str, list]]: yield from _get_contracts([expr.value]) -def _resolve_inherit(contract: Attr) -> Iterator[Tuple[str, List[astroid.Expr]]]: +def _resolve_inherit(contract: Attr) -> Iterator[ContractInfo]: if not isinstance(contract, astroid.Attribute): return cls = _get_parent_class(contract) diff --git a/deal/linter/_func.py b/deal/linter/_func.py index 4347bd79..ebf4a1ab 100644 --- a/deal/linter/_func.py +++ b/deal/linter/_func.py @@ -36,12 +36,13 @@ def from_ast(cls, tree: ast.Module) -> List['Func']: if not isinstance(expr, ast.FunctionDef): continue contracts = [] - for category, args in get_contracts(expr): + for cinfo in get_contracts(expr): contract = Contract( - args=args, + args=cinfo.args, func_args=expr.args, - category=Category(category), + category=Category(cinfo.name), context=definitions, + line=cinfo.line, ) contracts.append(contract) funcs.append(cls( @@ -68,12 +69,13 @@ def from_astroid(cls, tree: astroid.Module) -> List['Func']: # collect contracts contracts = [] - for category, args in get_contracts(expr): + for cinfo in get_contracts(expr): contract = Contract( - args=args, + args=cinfo.args, func_args=func_args, - category=Category(category), + category=Category(cinfo.name), context=definitions, + line=cinfo.line, ) contracts.append(contract) funcs.append(cls( diff --git a/deal/linter/_transformer.py b/deal/linter/_transformer.py index edd76bc9..a5a49a7d 100644 --- a/deal/linter/_transformer.py +++ b/deal/linter/_transformer.py @@ -1,23 +1,36 @@ from pathlib import Path -from typing import List, NamedTuple, Optional, Set +from typing import Iterator, List, NamedTuple, Set, Union from .._cached_property import cached_property from ._contract import Category from ._func import Func from ._rules import CheckRaises -class Mutation(NamedTuple): +class Insert(NamedTuple): line: int contract: Category args: List[str] indent: int - def __str__(self): + def apply(self, lines: List[str]) -> None: + lines.insert(self.line - 1, f'{self}\n') + + def __str__(self) -> str: args = ', '.join(self.args) dec = f'@deal.{self.contract.value}({args})' return ' ' * self.indent + dec +class Remove(NamedTuple): + line: int + + def apply(self, lines: List[str]) -> None: + lines.pop(self.line - 1) + + +Mutation = Union[Insert, Remove] + + class Transformer: path: Path mutations: List[Mutation] @@ -37,11 +50,10 @@ def transform(self) -> str: return self._apply_mutations(content) def _collect_mutations(self, func: Func) -> None: - mut = self._mutation_excs(func) - if mut is not None: + for mut in self._mutations_excs(func): self.mutations.append(mut) - def _mutation_excs(self, func: Func) -> Optional[Mutation]: + def _mutations_excs(self, func: Func) -> Iterator[Mutation]: checker = CheckRaises() excs: Set[str] = set() cats = {Category.RAISES, Category.SAFE, Category.PURE} @@ -53,9 +65,21 @@ def _mutation_excs(self, func: Func) -> Optional[Mutation]: for error in checker.get_undeclared(func, declared): assert isinstance(error.value, str) excs.add(error.value) + if not excs: - return None - return Mutation( + return + for contract in func.contracts: + if contract.category not in cats: + continue + yield Remove(contract.line) + if contract.category == Category.PURE: + yield Insert( + line=func.line, + indent=func.col, + contract=Category.HAS, + args=[], + ) + yield Insert( line=func.line, indent=func.col, contract=Category.RAISES, @@ -68,5 +92,5 @@ def _apply_mutations(self, content: str) -> str: lines = content.splitlines(keepends=True) self.mutations.sort(key=lambda x: x.line, reverse=True) for mutation in self.mutations: - lines.insert(mutation.line - 1, f'{mutation}\n') + mutation.apply(lines) return ''.join(lines) diff --git a/tests/test_linter/test_transformer.py b/tests/test_linter/test_transformer.py index cb1958b2..b3ac9d4d 100644 --- a/tests/test_linter/test_transformer.py +++ b/tests/test_linter/test_transformer.py @@ -32,6 +32,27 @@ def f(): def f(): raise ValueError """, + # remove deal.safe if adding deal.raises + """ + @deal.safe + def f(): + raise ValueError + --- + @deal.raises(ValueError) + def f(): + raise ValueError + """, + # remove deal.pure if adding deal.raises + """ + @deal.pure + def f(): + raise ValueError + --- + @deal.raises(ValueError) + @deal.has() + def f(): + raise ValueError + """, ]) def test_transformer(content: str, tmp_path: Path) -> None: given, expected = content.split('---') From c070c0be4facd08568b13fdf3d33175df6ea22d4 Mon Sep 17 00:00:00 2001 From: gram Date: Fri, 12 Nov 2021 14:01:48 +0100 Subject: [PATCH 03/20] transformer: preserve original deal.raises content --- deal/linter/_transformer.py | 25 ++++++++++++++++++++++--- tests/test_linter/test_transformer.py | 10 ++++++++++ 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/deal/linter/_transformer.py b/deal/linter/_transformer.py index a5a49a7d..7a2b3d98 100644 --- a/deal/linter/_transformer.py +++ b/deal/linter/_transformer.py @@ -1,11 +1,14 @@ from pathlib import Path -from typing import Iterator, List, NamedTuple, Set, Union +from typing import Iterator, List, NamedTuple, Set, Tuple, Union from .._cached_property import cached_property from ._contract import Category from ._func import Func from ._rules import CheckRaises +Priority = int + + class Insert(NamedTuple): line: int contract: Category @@ -15,6 +18,10 @@ class Insert(NamedTuple): def apply(self, lines: List[str]) -> None: lines.insert(self.line - 1, f'{self}\n') + @property + def key(self) -> Tuple[int, Priority]: + return (self.line, 1) + def __str__(self) -> str: args = ', '.join(self.args) dec = f'@deal.{self.contract.value}({args})' @@ -27,6 +34,10 @@ class Remove(NamedTuple): def apply(self, lines: List[str]) -> None: lines.pop(self.line - 1) + @property + def key(self) -> Tuple[int, Priority]: + return (self.line, 2) + Mutation = Union[Insert, Remove] @@ -79,18 +90,26 @@ def _mutations_excs(self, func: Func) -> Iterator[Mutation]: contract=Category.HAS, args=[], ) + contract_args = [self._exc_as_str(exc) for exc in declared] + contract_args.extend(sorted(excs)) yield Insert( line=func.line, indent=func.col, contract=Category.RAISES, - args=sorted(excs), + args=contract_args, ) + @staticmethod + def _exc_as_str(exc) -> str: + if isinstance(exc, str): + return exc + return exc.__name__ + def _apply_mutations(self, content: str) -> str: if not self.mutations: return content lines = content.splitlines(keepends=True) - self.mutations.sort(key=lambda x: x.line, reverse=True) + self.mutations.sort(key=lambda x: x.key, reverse=True) for mutation in self.mutations: mutation.apply(lines) return ''.join(lines) diff --git a/tests/test_linter/test_transformer.py b/tests/test_linter/test_transformer.py index b3ac9d4d..004797b6 100644 --- a/tests/test_linter/test_transformer.py +++ b/tests/test_linter/test_transformer.py @@ -53,6 +53,16 @@ def f(): def f(): raise ValueError """, + # merge deal.raises + """ + @deal.raises(ZeroDivisionError) + def f(): + raise ValueError + --- + @deal.raises(ZeroDivisionError, ValueError) + def f(): + raise ValueError + """, ]) def test_transformer(content: str, tmp_path: Path) -> None: given, expected = content.split('---') From 60e28c9affa327886127b7ad335810b95da1b2f7 Mon Sep 17 00:00:00 2001 From: gram Date: Fri, 12 Nov 2021 14:37:30 +0100 Subject: [PATCH 04/20] transformer: add deal.safe --- deal/linter/_contract.py | 5 +++ deal/linter/_transformer.py | 19 ++++++++- tests/test_linter/test_transformer.py | 60 +++++++++++++++++++++++++-- 3 files changed, 79 insertions(+), 5 deletions(-) diff --git a/deal/linter/_contract.py b/deal/linter/_contract.py index 1caaf1f3..1b0f92c9 100644 --- a/deal/linter/_contract.py +++ b/deal/linter/_contract.py @@ -24,6 +24,11 @@ class Category(enum.Enum): PURE = 'pure' RAISES = 'raises' SAFE = 'safe' + INHERIT = 'inherit' + + @property + def brackets_optional(self) -> bool: + return self in {Category.SAFE, Category.PURE} class Contract: diff --git a/deal/linter/_transformer.py b/deal/linter/_transformer.py index 7a2b3d98..132af59e 100644 --- a/deal/linter/_transformer.py +++ b/deal/linter/_transformer.py @@ -24,7 +24,10 @@ def key(self) -> Tuple[int, Priority]: def __str__(self) -> str: args = ', '.join(self.args) - dec = f'@deal.{self.contract.value}({args})' + if not self.args and self.contract.brackets_optional: + dec = f'@deal.{self.contract.value}' + else: + dec = f'@deal.{self.contract.value}({args})' return ' ' * self.indent + dec @@ -77,8 +80,22 @@ def _mutations_excs(self, func: Func) -> Iterator[Mutation]: assert isinstance(error.value, str) excs.add(error.value) + # if no new exceptions found, add deal.safe if not excs: + if declared: + return + for contract in func.contracts: + if contract.category in {Category.PURE, Category.SAFE}: + return + yield Insert( + line=func.line, + indent=func.col, + contract=Category.SAFE, + args=[], + ) return + + # if new exceptions detected, remove old contracts and add a new deal.raises for contract in func.contracts: if contract.category not in cats: continue diff --git a/tests/test_linter/test_transformer.py b/tests/test_linter/test_transformer.py index 004797b6..7c225a72 100644 --- a/tests/test_linter/test_transformer.py +++ b/tests/test_linter/test_transformer.py @@ -5,21 +5,31 @@ @pytest.mark.parametrize('content', [ - # no-op + # add deal.safe """ def f(): return 1 --- + @deal.safe def f(): return 1 """, - # preserve contracts + # preserve deal.raises """ - @deal.pre(lambda: True) + @deal.raises(KeyError) def f(): return 1 --- - @deal.pre(lambda: True) + @deal.raises(KeyError) + def f(): + return 1 + """, + """ + @deal.raises(KeyError, UnknownError) + def f(): + return 1 + --- + @deal.raises(KeyError, UnknownError) def f(): return 1 """, @@ -32,6 +42,15 @@ def f(): def f(): raise ValueError """, + # add deal.raises for unknown error + """ + def f(): + raise UnknownError + --- + @deal.raises(UnknownError) + def f(): + raise UnknownError + """, # remove deal.safe if adding deal.raises """ @deal.safe @@ -63,6 +82,39 @@ def f(): def f(): raise ValueError """, + # preserve contracts + """ + @deal.safe + @deal.pre(lambda: True) + def f(): + return 1 + --- + @deal.safe + @deal.pre(lambda: True) + def f(): + return 1 + """, + """ + @deal.pre(lambda: True) + def f(): + return 1 + --- + @deal.safe + @deal.pre(lambda: True) + def f(): + return 1 + """, + """ + @deal.raises(ValueError) + @deal.pre(lambda: True) + def f(): + 1/0 + --- + @deal.raises(ValueError, ZeroDivisionError) + @deal.pre(lambda: True) + def f(): + 1/0 + """, ]) def test_transformer(content: str, tmp_path: Path) -> None: given, expected = content.split('---') From 4d479b00a43c1fdf6f850127ebe05007fe41fbb6 Mon Sep 17 00:00:00 2001 From: gram Date: Fri, 12 Nov 2021 15:55:39 +0100 Subject: [PATCH 05/20] transformer: add deal.has --- deal/linter/__init__.py | 3 +- deal/linter/_rules.py | 12 +- deal/linter/_transformer.py | 98 +++++++++++--- tests/test_linter/test_transformer.py | 187 +++++++++++++++++++++++++- 4 files changed, 270 insertions(+), 30 deletions(-) diff --git a/deal/linter/__init__.py b/deal/linter/__init__.py index 443fe4d5..bfae4d1f 100644 --- a/deal/linter/__init__.py +++ b/deal/linter/__init__.py @@ -1,6 +1,7 @@ +from ._contract import Category from ._checker import Checker from ._stub import StubsManager, generate_stub from ._transformer import Transformer -__all__ = ['Checker', 'Transformer', 'StubsManager', 'generate_stub'] +__all__ = ['Category', 'Checker', 'Transformer', 'StubsManager', 'generate_stub'] diff --git a/deal/linter/_rules.py b/deal/linter/_rules.py index f10e9c80..cb9eb882 100644 --- a/deal/linter/_rules.py +++ b/deal/linter/_rules.py @@ -1,7 +1,7 @@ import ast from itertools import chain from types import MappingProxyType -from typing import Iterator, List, Optional, Type, TypeVar, Union +from typing import Iterator, List, Optional, Set, Type, TypeVar, Union import astroid @@ -319,18 +319,18 @@ class CheckMarkers(FuncRule): def __call__(self, func: Func, stubs: StubsManager = None) -> Iterator[Error]: for contract in func.contracts: - markers = None + markers: Optional[Set[str]] = None if contract.category == Category.HAS: - markers = [get_value(arg) for arg in contract.args] + markers = {get_value(arg) for arg in contract.args} elif contract.category == Category.PURE: - markers = [] + markers = set() if markers is None: continue - yield from self._check(func=func, markers=markers) + yield from self.get_undeclared(func=func, markers=markers) return @classmethod - def _check(cls, func: Func, markers: List[str]) -> Iterator[Error]: + def get_undeclared(cls, func: Func, markers: Set[str]) -> Iterator[Error]: has = HasPatcher(markers) # function without IO must return something if not has.has_io and not has_returns(body=func.body): diff --git a/deal/linter/_transformer.py b/deal/linter/_transformer.py index 132af59e..c15d20a1 100644 --- a/deal/linter/_transformer.py +++ b/deal/linter/_transformer.py @@ -1,9 +1,9 @@ from pathlib import Path -from typing import Iterator, List, NamedTuple, Set, Tuple, Union -from .._cached_property import cached_property +from typing import FrozenSet, Iterator, List, NamedTuple, Set, Tuple, Union from ._contract import Category from ._func import Func -from ._rules import CheckRaises +from ._rules import CheckRaises, CheckMarkers +from ._extractors import get_value Priority = int @@ -45,38 +45,40 @@ def key(self) -> Tuple[int, Priority]: Mutation = Union[Insert, Remove] -class Transformer: +class Transformer(NamedTuple): path: Path - mutations: List[Mutation] - - def __init__(self, path: Path) -> None: - self.path = path - self.mutations = [] - - @cached_property - def funcs(self) -> List[Func]: - return Func.from_path(self.path) + mutations: List[Mutation] = [] + quote: str = "'" + categories: FrozenSet[Category] = frozenset({ + Category.RAISES, + Category.SAFE, + Category.HAS, + }) def transform(self) -> str: - for func in self.funcs: + for func in Func.from_path(self.path): self._collect_mutations(func) content = self.path.read_text() return self._apply_mutations(content) def _collect_mutations(self, func: Func) -> None: - for mut in self._mutations_excs(func): - self.mutations.append(mut) + self.mutations.clear() + self.mutations.extend(self._mutations_excs(func)) + self.mutations.extend(self._mutations_markers(func)) def _mutations_excs(self, func: Func) -> Iterator[Mutation]: - checker = CheckRaises() - excs: Set[str] = set() cats = {Category.RAISES, Category.SAFE, Category.PURE} - declared = [] + + # collect declared exceptions + declared: List[Union[str, type]] = [] for contract in func.contracts: if contract.category not in cats: continue declared.extend(contract.exceptions) - for error in checker.get_undeclared(func, declared): + + # collect undeclared exceptions + excs: Set[str] = set() + for error in CheckRaises().get_undeclared(func, declared): assert isinstance(error.value, str) excs.add(error.value) @@ -84,6 +86,8 @@ def _mutations_excs(self, func: Func) -> Iterator[Mutation]: if not excs: if declared: return + if Category.SAFE not in self.categories: + return for contract in func.contracts: if contract.category in {Category.PURE, Category.SAFE}: return @@ -96,6 +100,8 @@ def _mutations_excs(self, func: Func) -> Iterator[Mutation]: return # if new exceptions detected, remove old contracts and add a new deal.raises + if Category.RAISES not in self.categories: + return for contract in func.contracts: if contract.category not in cats: continue @@ -122,6 +128,58 @@ def _exc_as_str(exc) -> str: return exc return exc.__name__ + def _mutations_markers(self, func: Func) -> Iterator[Mutation]: + if Category.HAS not in self.categories: + return + cats = {Category.HAS, Category.PURE} + + # collect declared markers + declared: List[str] = [] + for contract in func.contracts: + if contract.category not in cats: + continue + declared.extend(get_value(arg) for arg in contract.args) + + # collect undeclared markers + markers: Set[str] = set() + for error in CheckMarkers().get_undeclared(func, set(declared)): + assert isinstance(error.value, str) + markers.add(error.value) + + # if no new markers found, add deal.has() + if not markers: + for contract in func.contracts: + if contract.category in {Category.PURE, Category.HAS}: + return + yield Insert( + line=func.line, + indent=func.col, + contract=Category.HAS, + args=[], + ) + return + + # if new exceptions detected, remove old contracts and add a new deal.raises + for contract in func.contracts: + if contract.category not in cats: + continue + yield Remove(contract.line) + if contract.category == Category.PURE: + yield Insert( + line=func.line, + indent=func.col, + contract=Category.SAFE, + args=[], + ) + contract_args = [self._exc_as_str(marker) for marker in declared] + contract_args.extend(sorted(markers)) + yield Insert( + line=func.line, + indent=func.col, + contract=Category.HAS, + args=[f'{self.quote}{arg}{self.quote}' for arg in contract_args], + ) + def _apply_mutations(self, content: str) -> str: if not self.mutations: return content diff --git a/tests/test_linter/test_transformer.py b/tests/test_linter/test_transformer.py index 7c225a72..ecd07a0b 100644 --- a/tests/test_linter/test_transformer.py +++ b/tests/test_linter/test_transformer.py @@ -1,7 +1,7 @@ from pathlib import Path from textwrap import dedent import pytest -from deal.linter import Transformer +from deal.linter import Transformer, Category @pytest.mark.parametrize('content', [ @@ -14,6 +14,16 @@ def f(): def f(): return 1 """, + # preserve deal.safe + """ + @deal.safe + def f(): + return 1 + --- + @deal.safe + def f(): + return 1 + """, # preserve deal.raises """ @deal.raises(KeyError) @@ -51,6 +61,16 @@ def f(): def f(): raise UnknownError """, + # preserve unknown error + """ + @deal.raises(UnknownError) + def f(): + raise ValueError + --- + @deal.raises(UnknownError, ValueError) + def f(): + raise ValueError + """, # remove deal.safe if adding deal.raises """ @deal.safe @@ -82,6 +102,32 @@ def f(): def f(): raise ValueError """, + # preserve comments + """ + # oh hi mark + def f(): + raise ValueError + --- + # oh hi mark + @deal.raises(ValueError) + def f(): + raise ValueError + """, + # preserve comments + """ + # oh hi mark + + def f(): + # hello + raise ValueError + --- + # oh hi mark + + @deal.raises(ValueError) + def f(): + # hello + raise ValueError + """, # preserve contracts """ @deal.safe @@ -116,12 +162,147 @@ def f(): 1/0 """, ]) -def test_transformer(content: str, tmp_path: Path) -> None: +def test_transformer_raises(content: str, tmp_path: Path) -> None: + given, expected = content.split('---') + given = dedent(given) + expected = dedent(expected) + path = tmp_path / "example.py" + path.write_text(given) + tr = Transformer(path=path, categories=frozenset({Category.RAISES, Category.SAFE})) + actual = tr.transform() + assert actual == expected + + +@pytest.mark.parametrize('content', [ + # add deal.has() + """ + def f(): + return 1 + --- + @deal.has() + def f(): + return 1 + """, + # add deal.has with markers + """ + def f(): + print("hi") + return 1 + --- + @deal.has('stdout') + def f(): + print("hi") + return 1 + """, + """ + import random + def f(): + print(random.choice([1,2])) + return 1 + --- + import random + @deal.has('random', 'stdout') + def f(): + print(random.choice([1,2])) + return 1 + """, + # merge deal.has + """ + @deal.has('random') + def f(): + print("hi") + return 1 + --- + @deal.has('random', 'stdout') + def f(): + print("hi") + return 1 + """, + """ + @deal.has() + def f(): + print("hi") + return 1 + --- + @deal.has('stdout') + def f(): + print("hi") + return 1 + """, + # replace deal.pure + """ + @deal.pure + def f(): + print("hi") + return 1 + --- + @deal.has('stdout') + @deal.safe + def f(): + print("hi") + return 1 + """, + # preserve contracts + """ + @deal.pre(lambda: True) + def f(): + print("hi") + return 1 + --- + @deal.has('stdout') + @deal.pre(lambda: True) + def f(): + print("hi") + return 1 + """, + """ + @deal.has('random') + def f(): + return 1 + --- + @deal.has('random') + def f(): + return 1 + """, + """ + @deal.has('random') + @deal.pre(lambda: True) + def f(): + return 1 + --- + @deal.has('random') + @deal.pre(lambda: True) + def f(): + return 1 + """, + """ + @deal.has() + @deal.pre(lambda: True) + def f(): + return 1 + --- + @deal.has() + @deal.pre(lambda: True) + def f(): + return 1 + """, + """ + @deal.pre(lambda: True) + def f(): + return 1 + --- + @deal.has() + @deal.pre(lambda: True) + def f(): + return 1 + """, +]) +def test_transformer_has(content: str, tmp_path: Path) -> None: given, expected = content.split('---') given = dedent(given) expected = dedent(expected) path = tmp_path / "example.py" path.write_text(given) - tr = Transformer(path=path) + tr = Transformer(path=path, categories=frozenset({Category.HAS})) actual = tr.transform() assert actual == expected From da65fb65d4c491e547b5fe06c24073a37e89c580 Mon Sep 17 00:00:00 2001 From: gram Date: Fri, 12 Nov 2021 16:08:02 +0100 Subject: [PATCH 06/20] transformer: TransformationType --- deal/linter/__init__.py | 11 +++++++--- deal/linter/_transformer.py | 21 +++++++++++-------- tests/test_linter/test_transformer.py | 30 ++++++++++++++++++++++++--- 3 files changed, 47 insertions(+), 15 deletions(-) diff --git a/deal/linter/__init__.py b/deal/linter/__init__.py index bfae4d1f..2ec1b228 100644 --- a/deal/linter/__init__.py +++ b/deal/linter/__init__.py @@ -1,7 +1,12 @@ -from ._contract import Category from ._checker import Checker from ._stub import StubsManager, generate_stub -from ._transformer import Transformer +from ._transformer import Transformer, TransformationType -__all__ = ['Category', 'Checker', 'Transformer', 'StubsManager', 'generate_stub'] +__all__ = [ + 'Checker', + 'generate_stub', + 'StubsManager', + 'TransformationType', + 'Transformer', +] diff --git a/deal/linter/_transformer.py b/deal/linter/_transformer.py index c15d20a1..7cab64f9 100644 --- a/deal/linter/_transformer.py +++ b/deal/linter/_transformer.py @@ -1,5 +1,6 @@ +from enum import Enum from pathlib import Path -from typing import FrozenSet, Iterator, List, NamedTuple, Set, Tuple, Union +from typing import Iterator, List, NamedTuple, Set, Tuple, Union from ._contract import Category from ._func import Func from ._rules import CheckRaises, CheckMarkers @@ -9,6 +10,12 @@ Priority = int +class TransformationType(Enum): + RAISES = 'raises' + HAS = 'has' + SAFE = 'safe' + + class Insert(NamedTuple): line: int contract: Category @@ -47,13 +54,9 @@ def key(self) -> Tuple[int, Priority]: class Transformer(NamedTuple): path: Path + types: Set[TransformationType] mutations: List[Mutation] = [] quote: str = "'" - categories: FrozenSet[Category] = frozenset({ - Category.RAISES, - Category.SAFE, - Category.HAS, - }) def transform(self) -> str: for func in Func.from_path(self.path): @@ -86,7 +89,7 @@ def _mutations_excs(self, func: Func) -> Iterator[Mutation]: if not excs: if declared: return - if Category.SAFE not in self.categories: + if TransformationType.SAFE not in self.types: return for contract in func.contracts: if contract.category in {Category.PURE, Category.SAFE}: @@ -100,7 +103,7 @@ def _mutations_excs(self, func: Func) -> Iterator[Mutation]: return # if new exceptions detected, remove old contracts and add a new deal.raises - if Category.RAISES not in self.categories: + if TransformationType.RAISES not in self.types: return for contract in func.contracts: if contract.category not in cats: @@ -129,7 +132,7 @@ def _exc_as_str(exc) -> str: return exc.__name__ def _mutations_markers(self, func: Func) -> Iterator[Mutation]: - if Category.HAS not in self.categories: + if TransformationType.HAS not in self.types: return cats = {Category.HAS, Category.PURE} diff --git a/tests/test_linter/test_transformer.py b/tests/test_linter/test_transformer.py index ecd07a0b..13522e78 100644 --- a/tests/test_linter/test_transformer.py +++ b/tests/test_linter/test_transformer.py @@ -1,7 +1,7 @@ from pathlib import Path from textwrap import dedent import pytest -from deal.linter import Transformer, Category +from deal.linter import Transformer, TransformationType @pytest.mark.parametrize('content', [ @@ -24,6 +24,16 @@ def f(): def f(): return 1 """, + # preserve deal.safe with comments + """ + @deal.safe # oh hi mark + def f(): + return 1 + --- + @deal.safe # oh hi mark + def f(): + return 1 + """, # preserve deal.raises """ @deal.raises(KeyError) @@ -168,7 +178,10 @@ def test_transformer_raises(content: str, tmp_path: Path) -> None: expected = dedent(expected) path = tmp_path / "example.py" path.write_text(given) - tr = Transformer(path=path, categories=frozenset({Category.RAISES, Category.SAFE})) + tr = Transformer( + path=path, + types={TransformationType.RAISES, TransformationType.SAFE}, + ) actual = tr.transform() assert actual == expected @@ -296,6 +309,17 @@ def f(): def f(): return 1 """, + # do not add deal.raises if transformation type is disabled + """ + def f(): + raise ValueError + return 1 + --- + @deal.has() + def f(): + raise ValueError + return 1 + """, ]) def test_transformer_has(content: str, tmp_path: Path) -> None: given, expected = content.split('---') @@ -303,6 +327,6 @@ def test_transformer_has(content: str, tmp_path: Path) -> None: expected = dedent(expected) path = tmp_path / "example.py" path.write_text(given) - tr = Transformer(path=path, categories=frozenset({Category.HAS})) + tr = Transformer(path=path, types={TransformationType.HAS}) actual = tr.transform() assert actual == expected From f381913201621d755313361be2e85fccc7c02369 Mon Sep 17 00:00:00 2001 From: gram Date: Fri, 12 Nov 2021 16:15:16 +0100 Subject: [PATCH 07/20] lazy import for CLI commands --- deal/_cli/_main.py | 33 +++++++++++++++++++-------------- tests/test_docs.py | 6 +++--- 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/deal/_cli/_main.py b/deal/_cli/_main.py index 329c0f3e..535cf804 100644 --- a/deal/_cli/_main.py +++ b/deal/_cli/_main.py @@ -1,33 +1,38 @@ import sys from argparse import ArgumentParser from pathlib import Path -from types import MappingProxyType from typing import Mapping, Sequence, TextIO, Type from ._base import Command -from ._lint import LintCommand -from ._memtest import MemtestCommand -from ._prove import ProveCommand -from ._stub import StubCommand -from ._test import TestCommand CommandsType = Mapping[str, Type[Command]] -COMMANDS: CommandsType = MappingProxyType(dict( - lint=LintCommand, - memtest=MemtestCommand, - prove=ProveCommand, - stub=StubCommand, - test=TestCommand, -)) + + +def get_commands() -> CommandsType: + from ._lint import LintCommand + from ._memtest import MemtestCommand + from ._prove import ProveCommand + from ._stub import StubCommand + from ._test import TestCommand + + return dict( + lint=LintCommand, + memtest=MemtestCommand, + prove=ProveCommand, + stub=StubCommand, + test=TestCommand, + ) def main( argv: Sequence[str], *, - commands: CommandsType = COMMANDS, + commands: CommandsType = None, root: Path = None, stream: TextIO = sys.stdout, ) -> int: + if commands is None: + commands = get_commands() if root is None: # pragma: no cover root = Path() parser = ArgumentParser(prog='python3 -m deal') diff --git a/tests/test_docs.py b/tests/test_docs.py index cc2c3b9e..4b1e51f3 100644 --- a/tests/test_docs.py +++ b/tests/test_docs.py @@ -5,7 +5,7 @@ import pytest import deal -from deal._cli._main import COMMANDS +from deal._cli._main import get_commands from deal.linter._rules import CheckMarkers, rules @@ -51,7 +51,7 @@ def test_all_marker_codes_listed(): def test_cli_included(): path = root / 'details' / 'cli.md' content = path.read_text() - for name, cmd in COMMANDS.items(): + for name, cmd in get_commands().items(): # has header tmpl = '## {n}\n\n' line = tmpl.format(n=name) @@ -106,7 +106,7 @@ def test_all_public_listed_in_refs(): def test_all_cli_commands_listed_in_refs(): path = root / 'basic' / 'refs.md' content = path.read_text() - for name, cmd in COMMANDS.items(): + for name, cmd in get_commands().items(): assert f'`{name}`' in content assert f'`details/cli:{name}`' in content From 1ddabbf9043dfd2491674ed1d58456a290041ad9 Mon Sep 17 00:00:00 2001 From: gram Date: Sat, 13 Nov 2021 10:14:27 +0100 Subject: [PATCH 08/20] `deal decorate` CLI command --- deal/_cli/_base.py | 4 ++-- deal/_cli/_decorate.py | 40 +++++++++++++++++++++++++++++++++++++ deal/_cli/_main.py | 2 ++ deal/linter/_transformer.py | 7 +++++-- 4 files changed, 49 insertions(+), 4 deletions(-) create mode 100644 deal/_cli/_decorate.py diff --git a/deal/_cli/_base.py b/deal/_cli/_base.py index 4eeba100..d4a96235 100644 --- a/deal/_cli/_base.py +++ b/deal/_cli/_base.py @@ -1,4 +1,4 @@ -from argparse import ArgumentParser +from argparse import ArgumentParser, Namespace from pathlib import Path from typing import TextIO @@ -18,5 +18,5 @@ def print(self, *args) -> None: def init_parser(parser: ArgumentParser) -> None: raise NotImplementedError - def __call__(self, args) -> int: + def __call__(self, args: Namespace) -> int: raise NotImplementedError diff --git a/deal/_cli/_decorate.py b/deal/_cli/_decorate.py new file mode 100644 index 00000000..f9cde2a8 --- /dev/null +++ b/deal/_cli/_decorate.py @@ -0,0 +1,40 @@ +from argparse import ArgumentParser +from pathlib import Path + +from ..linter import Transformer, TransformationType +from ._base import Command +from ._common import get_paths + + +class DecorateCommand(Command): + """Add decorators to your code. + + ```bash + python3 -m deal decorate project/ + ``` + + Options: + + `--types`: types of decorators to apply. All are enabled by default. + + """ + + @staticmethod + def init_parser(parser: ArgumentParser) -> None: + parser.add_argument( + '--types', + nargs='*', + choices=[tt.value for tt in TransformationType], + default=['has', 'raises', 'safe'], + help='types of decorators to apply', + ) + parser.add_argument('paths', nargs='*', default='.') + + def __call__(self, args) -> int: + types = {TransformationType(t) for t in args.types} + for arg in args.paths: + for path in get_paths(Path(arg)): + self.print(path) + tr = Transformer(path=path, types=types) + content = tr.transform() + path.write_text(content) + return 0 diff --git a/deal/_cli/_main.py b/deal/_cli/_main.py index 535cf804..50a40974 100644 --- a/deal/_cli/_main.py +++ b/deal/_cli/_main.py @@ -10,6 +10,7 @@ def get_commands() -> CommandsType: + from ._decorate import DecorateCommand from ._lint import LintCommand from ._memtest import MemtestCommand from ._prove import ProveCommand @@ -17,6 +18,7 @@ def get_commands() -> CommandsType: from ._test import TestCommand return dict( + decorate=DecorateCommand, lint=LintCommand, memtest=MemtestCommand, prove=ProveCommand, diff --git a/deal/linter/_transformer.py b/deal/linter/_transformer.py index 7cab64f9..4ef71c6f 100644 --- a/deal/linter/_transformer.py +++ b/deal/linter/_transformer.py @@ -1,6 +1,8 @@ from enum import Enum from pathlib import Path from typing import Iterator, List, NamedTuple, Set, Tuple, Union + +import astroid from ._contract import Category from ._func import Func from ._rules import CheckRaises, CheckMarkers @@ -59,9 +61,10 @@ class Transformer(NamedTuple): quote: str = "'" def transform(self) -> str: - for func in Func.from_path(self.path): - self._collect_mutations(func) content = self.path.read_text() + tree = astroid.parse(content, path=self.path) + for func in Func.from_astroid(tree): + self._collect_mutations(func) return self._apply_mutations(content) def _collect_mutations(self, func: Func) -> None: From f040d946b539813db07f86c7f6ec28c6e2eacad9 Mon Sep 17 00:00:00 2001 From: gram Date: Sat, 13 Nov 2021 10:50:24 +0100 Subject: [PATCH 09/20] test decorate --- deal/_cli/_decorate.py | 10 +++++ tests/test_cli/test_decorate.py | 75 +++++++++++++++++++++++++++++++++ 2 files changed, 85 insertions(+) create mode 100644 tests/test_cli/test_decorate.py diff --git a/deal/_cli/_decorate.py b/deal/_cli/_decorate.py index f9cde2a8..c4693696 100644 --- a/deal/_cli/_decorate.py +++ b/deal/_cli/_decorate.py @@ -15,7 +15,10 @@ class DecorateCommand(Command): Options: + `--types`: types of decorators to apply. All are enabled by default. + + `--double-quotes`: use double quotes. Single quotes are used by default. + The exit code is always 0. If you want to test the code for missed decorators, + use the `lint` command instead. """ @staticmethod @@ -27,6 +30,11 @@ def init_parser(parser: ArgumentParser) -> None: default=['has', 'raises', 'safe'], help='types of decorators to apply', ) + parser.add_argument( + '--double-quotes', + action='store_true', + help='use double quotes', + ) parser.add_argument('paths', nargs='*', default='.') def __call__(self, args) -> int: @@ -35,6 +43,8 @@ def __call__(self, args) -> int: for path in get_paths(Path(arg)): self.print(path) tr = Transformer(path=path, types=types) + if args.double_quotes: + tr = tr._replace(quote='"') content = tr.transform() path.write_text(content) return 0 diff --git a/tests/test_cli/test_decorate.py b/tests/test_cli/test_decorate.py new file mode 100644 index 00000000..da1bed90 --- /dev/null +++ b/tests/test_cli/test_decorate.py @@ -0,0 +1,75 @@ +from io import StringIO +from pathlib import Path +from textwrap import dedent +import pytest +from deal._cli import main + + +@pytest.mark.parametrize('flags, given, expected', [ + ( + [], + """ + import deal + @deal.post(lambda x: x > 0) + def f(x): + print(1/0) + return -1 + """, + """ + import deal + @deal.has('stdout') + @deal.raises(ZeroDivisionError) + @deal.post(lambda x: x > 0) + def f(x): + print(1/0) + return -1 + """ + ), + ( + ['--types', 'raises', 'safe'], + """ + import deal + @deal.post(lambda x: x > 0) + def f(x): + print(1/0) + return -1 + """, + """ + import deal + @deal.raises(ZeroDivisionError) + @deal.post(lambda x: x > 0) + def f(x): + print(1/0) + return -1 + """ + ), + ( + ['--types', 'has', '--double-quotes'], + """ + import deal + @deal.post(lambda x: x > 0) + def f(x): + print(1/0) + return -1 + """, + """ + import deal + @deal.has("stdout") + @deal.post(lambda x: x > 0) + def f(x): + print(1/0) + return -1 + """ + ), +]) +def test_decorate_command(flags: list, given: str, expected: str, tmp_path: Path): + file_path = tmp_path / 'example.py' + file_path.write_text(dedent(given)) + stream = StringIO() + code = main(['decorate', *flags, '--', str(tmp_path)], stream=stream) + assert code == 0 + + stream.seek(0) + captured = stream.read() + assert captured.strip() == str(file_path) + assert file_path.read_text() == dedent(expected) From b4a1150c8d43f3c74dede6cd6ffa0141dbb8e20a Mon Sep 17 00:00:00 2001 From: gram Date: Sat, 13 Nov 2021 11:11:56 +0100 Subject: [PATCH 10/20] docs for decorate command --- docs/basic/refs.md | 15 ++++++++------- docs/details/cli.md | 10 ++++++++-- docs/details/contracts.md | 16 ++++++++++++++++ 3 files changed, 32 insertions(+), 9 deletions(-) diff --git a/docs/basic/refs.md b/docs/basic/refs.md index a8206dcd..f1d859c7 100644 --- a/docs/basic/refs.md +++ b/docs/basic/refs.md @@ -52,13 +52,14 @@ This page provides a quick navigation by the documentation in case if you're loo ## CLI commands -| command | reference | documentation | -| --------- | -------------------------- | ------------- | -| `lint` | {ref}`details/cli:lint` | {doc}`/basic/linter` / {ref}`basic/linter:built-in cli command` | -| `memtest` | {ref}`details/cli:memtest` | {doc}`/details/tests` / {ref}`details/tests:finding memory leaks` | -| `prove` | {ref}`details/cli:prove` | {doc}`/basic/verification` | -| `stub` | {ref}`details/cli:stub` | {doc}`/details/stubs` | -| `test` | {ref}`details/cli:test` | {doc}`/basic/tests` / {ref}`basic/tests:cli` | +| command | reference | documentation | +| ---------- | --------------------------- | ------------- | +| `decorate` | {ref}`details/cli:decorate` | {doc}`/details/contracts` / {ref}`details/contracts:generating contracts` | +| `lint` | {ref}`details/cli:lint` | {doc}`/basic/linter` / {ref}`basic/linter:built-in cli command` | +| `memtest` | {ref}`details/cli:memtest` | {doc}`/details/tests` / {ref}`details/tests:finding memory leaks` | +| `prove` | {ref}`details/cli:prove` | {doc}`/basic/verification` | +| `stub` | {ref}`details/cli:stub` | {doc}`/details/stubs` | +| `test` | {ref}`details/cli:test` | {doc}`/basic/tests` / {ref}`basic/tests:cli` | ## Integrations diff --git a/docs/details/cli.md b/docs/details/cli.md index 5ac3f266..44034f49 100644 --- a/docs/details/cli.md +++ b/docs/details/cli.md @@ -6,10 +6,10 @@ .. autofunction:: deal._cli._lint.LintCommand ``` -## stub +## decorate ```{eval-rst} -.. autofunction:: deal._cli._stub.StubCommand +.. autofunction:: deal._cli._decorate.DecorateCommand ``` ## test @@ -24,6 +24,12 @@ .. autofunction:: deal._cli._memtest.MemtestCommand ``` +## stub + +```{eval-rst} +.. autofunction:: deal._cli._stub.StubCommand +``` + ## prove ```{eval-rst} diff --git a/docs/details/contracts.md b/docs/details/contracts.md index 88060f21..5e7922db 100644 --- a/docs/details/contracts.md +++ b/docs/details/contracts.md @@ -1,5 +1,21 @@ # More on writing contracts +## Generating contracts + +The best way to get started with deal is to automatically generate some contracts using {ref}`details/cli:decorate` CLI command: + +```bash +python3 -m deal decorate my_project/ +``` + +It will run {doc}`/basic/linter` on your code and add some of the missed contracts. The rest of contacts is still on you, though. Also, you should carefully check the generated code for correctness, deal may miss something. + +The following contracts are supported by the command and will be added to your code: + ++ {py:func}`deal.has` ++ {py:func}`deal.raises` ++ {py:func}`deal.safe` + ## Simplified signature The main problem with contracts is that they have to duplicate the original function's signature, including default arguments. While it's not a problem for small examples, things become more complicated when the signature grows. In this case, you can specify a function that accepts only one `_` argument, and deal will pass there a container with arguments of the function call, including default ones: From d72c4ad608b0aa4df7a8c2e58a3308c0b796113a Mon Sep 17 00:00:00 2001 From: gram Date: Sat, 13 Nov 2021 11:20:01 +0100 Subject: [PATCH 11/20] do not touch files if no modifications needed --- deal/_cli/_decorate.py | 12 +++++++++--- deal/linter/_transformer.py | 6 +++--- tests/test_cli/test_decorate.py | 15 +++++++++++++++ tests/test_linter/test_transformer.py | 13 +++++++------ 4 files changed, 34 insertions(+), 12 deletions(-) diff --git a/deal/_cli/_decorate.py b/deal/_cli/_decorate.py index c4693696..97768589 100644 --- a/deal/_cli/_decorate.py +++ b/deal/_cli/_decorate.py @@ -42,9 +42,15 @@ def __call__(self, args) -> int: for arg in args.paths: for path in get_paths(Path(arg)): self.print(path) - tr = Transformer(path=path, types=types) + original_code = path.read_text(encoding='utf8') + tr = Transformer( + content=original_code, + path=path, + types=types, + ) if args.double_quotes: tr = tr._replace(quote='"') - content = tr.transform() - path.write_text(content) + modified_code = tr.transform() + if original_code != modified_code: + path.write_text(modified_code) return 0 diff --git a/deal/linter/_transformer.py b/deal/linter/_transformer.py index 4ef71c6f..a31ddeef 100644 --- a/deal/linter/_transformer.py +++ b/deal/linter/_transformer.py @@ -55,17 +55,17 @@ def key(self) -> Tuple[int, Priority]: class Transformer(NamedTuple): + content: str path: Path types: Set[TransformationType] mutations: List[Mutation] = [] quote: str = "'" def transform(self) -> str: - content = self.path.read_text() - tree = astroid.parse(content, path=self.path) + tree = astroid.parse(self.content, path=self.path) for func in Func.from_astroid(tree): self._collect_mutations(func) - return self._apply_mutations(content) + return self._apply_mutations(self.content) def _collect_mutations(self, func: Func) -> None: self.mutations.clear() diff --git a/tests/test_cli/test_decorate.py b/tests/test_cli/test_decorate.py index da1bed90..4547ff20 100644 --- a/tests/test_cli/test_decorate.py +++ b/tests/test_cli/test_decorate.py @@ -61,6 +61,21 @@ def f(x): return -1 """ ), + ( + [], + """ + import deal + @deal.pure + def f(x): + return x + """, + """ + import deal + @deal.pure + def f(x): + return x + """ + ), ]) def test_decorate_command(flags: list, given: str, expected: str, tmp_path: Path): file_path = tmp_path / 'example.py' diff --git a/tests/test_linter/test_transformer.py b/tests/test_linter/test_transformer.py index 13522e78..30303a03 100644 --- a/tests/test_linter/test_transformer.py +++ b/tests/test_linter/test_transformer.py @@ -176,10 +176,9 @@ def test_transformer_raises(content: str, tmp_path: Path) -> None: given, expected = content.split('---') given = dedent(given) expected = dedent(expected) - path = tmp_path / "example.py" - path.write_text(given) tr = Transformer( - path=path, + content=given, + path=tmp_path / "example.py", types={TransformationType.RAISES, TransformationType.SAFE}, ) actual = tr.transform() @@ -325,8 +324,10 @@ def test_transformer_has(content: str, tmp_path: Path) -> None: given, expected = content.split('---') given = dedent(given) expected = dedent(expected) - path = tmp_path / "example.py" - path.write_text(given) - tr = Transformer(path=path, types={TransformationType.HAS}) + tr = Transformer( + content=given, + path=tmp_path / "example.py", + types={TransformationType.HAS}, + ) actual = tr.transform() assert actual == expected From 341e86727f5cced8f135384b6f31e7d04c0fc982 Mon Sep 17 00:00:00 2001 From: gram Date: Sat, 13 Nov 2021 11:26:22 +0100 Subject: [PATCH 12/20] deal decorate: use colors --- deal/_cli/_decorate.py | 11 +++++++++-- tests/test_cli/test_decorate.py | 2 +- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/deal/_cli/_decorate.py b/deal/_cli/_decorate.py index 97768589..e625c572 100644 --- a/deal/_cli/_decorate.py +++ b/deal/_cli/_decorate.py @@ -4,6 +4,7 @@ from ..linter import Transformer, TransformationType from ._base import Command from ._common import get_paths +from .._colors import get_colors class DecorateCommand(Command): @@ -16,6 +17,7 @@ class DecorateCommand(Command): Options: + `--types`: types of decorators to apply. All are enabled by default. + `--double-quotes`: use double quotes. Single quotes are used by default. + + `--nocolor`: do not use colors in the console output. The exit code is always 0. If you want to test the code for missed decorators, use the `lint` command instead. @@ -35,13 +37,15 @@ def init_parser(parser: ArgumentParser) -> None: action='store_true', help='use double quotes', ) + parser.add_argument('--nocolor', action='store_true', help='colorless output') parser.add_argument('paths', nargs='*', default='.') def __call__(self, args) -> int: types = {TransformationType(t) for t in args.types} + colors = get_colors(args) for arg in args.paths: for path in get_paths(Path(arg)): - self.print(path) + self.print('{magenta}{path}{end}'.format(path=path, **colors)) original_code = path.read_text(encoding='utf8') tr = Transformer( content=original_code, @@ -51,6 +55,9 @@ def __call__(self, args) -> int: if args.double_quotes: tr = tr._replace(quote='"') modified_code = tr.transform() - if original_code != modified_code: + if original_code == modified_code: + self.print(' {blue}no changes{end}'.format(**colors)) + else: path.write_text(modified_code) + self.print(' {green}decorated{end}'.format(**colors)) return 0 diff --git a/tests/test_cli/test_decorate.py b/tests/test_cli/test_decorate.py index 4547ff20..1992b3a2 100644 --- a/tests/test_cli/test_decorate.py +++ b/tests/test_cli/test_decorate.py @@ -86,5 +86,5 @@ def test_decorate_command(flags: list, given: str, expected: str, tmp_path: Path stream.seek(0) captured = stream.read() - assert captured.strip() == str(file_path) + assert str(file_path) in captured assert file_path.read_text() == dedent(expected) From a6989f408ae609fe5f1ba23b12aed7a0894c5555 Mon Sep 17 00:00:00 2001 From: gram Date: Wed, 17 Nov 2021 14:05:02 +0100 Subject: [PATCH 13/20] simplify transformer a bit --- deal/linter/_func.py | 6 ++++++ deal/linter/_transformer.py | 14 ++++++-------- tests/test_linter/test_transformer.py | 12 ++++++++++++ 3 files changed, 24 insertions(+), 8 deletions(-) diff --git a/deal/linter/_func.py b/deal/linter/_func.py index ebf4a1ab..27ea03cd 100644 --- a/deal/linter/_func.py +++ b/deal/linter/_func.py @@ -88,6 +88,12 @@ def from_astroid(cls, tree: astroid.Module) -> List['Func']: )) return funcs + def has_contract(self, *categories: Category) -> bool: + for contract in self.contracts: + if contract.category in categories: + return True + return False + def __repr__(self) -> str: cats = ', '.join(contract.category.value for contract in self.contracts) return f'{type(self).__name__}({cats})' diff --git a/deal/linter/_transformer.py b/deal/linter/_transformer.py index a31ddeef..f59918ed 100644 --- a/deal/linter/_transformer.py +++ b/deal/linter/_transformer.py @@ -62,13 +62,13 @@ class Transformer(NamedTuple): quote: str = "'" def transform(self) -> str: + self.mutations.clear() tree = astroid.parse(self.content, path=self.path) for func in Func.from_astroid(tree): self._collect_mutations(func) return self._apply_mutations(self.content) def _collect_mutations(self, func: Func) -> None: - self.mutations.clear() self.mutations.extend(self._mutations_excs(func)) self.mutations.extend(self._mutations_markers(func)) @@ -94,9 +94,8 @@ def _mutations_excs(self, func: Func) -> Iterator[Mutation]: return if TransformationType.SAFE not in self.types: return - for contract in func.contracts: - if contract.category in {Category.PURE, Category.SAFE}: - return + if func.has_contract(Category.PURE, Category.SAFE): + return yield Insert( line=func.line, indent=func.col, @@ -154,9 +153,8 @@ def _mutations_markers(self, func: Func) -> Iterator[Mutation]: # if no new markers found, add deal.has() if not markers: - for contract in func.contracts: - if contract.category in {Category.PURE, Category.HAS}: - return + if func.has_contract(Category.PURE, Category.HAS): + return yield Insert( line=func.line, indent=func.col, @@ -165,7 +163,7 @@ def _mutations_markers(self, func: Func) -> Iterator[Mutation]: ) return - # if new exceptions detected, remove old contracts and add a new deal.raises + # if new markers detected, remove old contracts and add a new deal.raises for contract in func.contracts: if contract.category not in cats: continue diff --git a/tests/test_linter/test_transformer.py b/tests/test_linter/test_transformer.py index 30303a03..a8378908 100644 --- a/tests/test_linter/test_transformer.py +++ b/tests/test_linter/test_transformer.py @@ -171,6 +171,17 @@ def f(): def f(): 1/0 """, + # # support methods + # """ + # class A: + # def f(self): + # 1/0 + # --- + # class A: + # @deal.raises(ZeroDivisionError) + # def f(self): + # 1/0 + # """, ]) def test_transformer_raises(content: str, tmp_path: Path) -> None: given, expected = content.split('---') @@ -182,6 +193,7 @@ def test_transformer_raises(content: str, tmp_path: Path) -> None: types={TransformationType.RAISES, TransformationType.SAFE}, ) actual = tr.transform() + print(tr.mutations) assert actual == expected From fb376928707bdc1bd2e14f162ece87041518f852 Mon Sep 17 00:00:00 2001 From: gram Date: Wed, 17 Nov 2021 14:05:15 +0100 Subject: [PATCH 14/20] sort imports --- deal/_cli/_decorate.py | 4 ++-- deal/linter/__init__.py | 2 +- deal/linter/_transformer.py | 5 +++-- tests/test_cli/test_decorate.py | 2 ++ tests/test_linter/test_transformer.py | 4 +++- 5 files changed, 11 insertions(+), 6 deletions(-) diff --git a/deal/_cli/_decorate.py b/deal/_cli/_decorate.py index e625c572..42164f79 100644 --- a/deal/_cli/_decorate.py +++ b/deal/_cli/_decorate.py @@ -1,10 +1,10 @@ from argparse import ArgumentParser from pathlib import Path -from ..linter import Transformer, TransformationType +from .._colors import get_colors +from ..linter import TransformationType, Transformer from ._base import Command from ._common import get_paths -from .._colors import get_colors class DecorateCommand(Command): diff --git a/deal/linter/__init__.py b/deal/linter/__init__.py index 2ec1b228..019b5e8a 100644 --- a/deal/linter/__init__.py +++ b/deal/linter/__init__.py @@ -1,6 +1,6 @@ from ._checker import Checker from ._stub import StubsManager, generate_stub -from ._transformer import Transformer, TransformationType +from ._transformer import TransformationType, Transformer __all__ = [ diff --git a/deal/linter/_transformer.py b/deal/linter/_transformer.py index f59918ed..42209d7a 100644 --- a/deal/linter/_transformer.py +++ b/deal/linter/_transformer.py @@ -3,10 +3,11 @@ from typing import Iterator, List, NamedTuple, Set, Tuple, Union import astroid + from ._contract import Category -from ._func import Func -from ._rules import CheckRaises, CheckMarkers from ._extractors import get_value +from ._func import Func +from ._rules import CheckMarkers, CheckRaises Priority = int diff --git a/tests/test_cli/test_decorate.py b/tests/test_cli/test_decorate.py index 1992b3a2..f71ae2a7 100644 --- a/tests/test_cli/test_decorate.py +++ b/tests/test_cli/test_decorate.py @@ -1,7 +1,9 @@ from io import StringIO from pathlib import Path from textwrap import dedent + import pytest + from deal._cli import main diff --git a/tests/test_linter/test_transformer.py b/tests/test_linter/test_transformer.py index a8378908..59413828 100644 --- a/tests/test_linter/test_transformer.py +++ b/tests/test_linter/test_transformer.py @@ -1,7 +1,9 @@ from pathlib import Path from textwrap import dedent + import pytest -from deal.linter import Transformer, TransformationType + +from deal.linter import TransformationType, Transformer @pytest.mark.parametrize('content', [ From 08da584112673be34cbdfe662e509f50807765b1 Mon Sep 17 00:00:00 2001 From: gram Date: Wed, 17 Nov 2021 14:23:11 +0100 Subject: [PATCH 15/20] test more --- tests/test_linter/test_contract.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/tests/test_linter/test_contract.py b/tests/test_linter/test_contract.py index a6edf362..617b857e 100644 --- a/tests/test_linter/test_contract.py +++ b/tests/test_linter/test_contract.py @@ -5,6 +5,7 @@ import pytest from deal.linter._contract import Category, Contract +from deal.linter._extractors.contracts import SUPPORTED_CONTRACTS, SUPPORTED_MARKERS from deal.linter._func import Func @@ -18,6 +19,16 @@ def f(x): """ +@pytest.mark.parametrize('marker', SUPPORTED_MARKERS) +def test_all_markers_are_contracts(marker): + assert marker in SUPPORTED_CONTRACTS + + +def test_supported_contracts_match_categories(marker): + cats = {c.value for c in Category} + assert cats == SUPPORTED_CONTRACTS + + def test_exceptions(): funcs1 = Func.from_ast(ast.parse(TEXT)) assert len(funcs1) == 1 @@ -30,7 +41,11 @@ def test_exceptions(): def test_repr(): - c = Contract(category=Category.RAISES, args=[], func_args=None) + c = Contract( + category=Category.RAISES, + args=[], + func_args=None, # type: ignore[arg-type] + ) assert repr(c) == 'Contract(raises)' From cb5ab9a26a6d5801b2b9d0958b2f4eaf3dc3f411 Mon Sep 17 00:00:00 2001 From: gram Date: Wed, 17 Nov 2021 15:56:25 +0100 Subject: [PATCH 16/20] add import statement if needed --- deal/linter/_transformer.py | 61 ++++++++++++-- tests/test_linter/test_contract.py | 11 --- .../test_extractors/test_contracts.py | 9 ++- tests/test_linter/test_transformer.py | 81 +++++++++++++++++++ 4 files changed, 142 insertions(+), 20 deletions(-) diff --git a/deal/linter/_transformer.py b/deal/linter/_transformer.py index 42209d7a..cf4db6ea 100644 --- a/deal/linter/_transformer.py +++ b/deal/linter/_transformer.py @@ -17,9 +17,25 @@ class TransformationType(Enum): RAISES = 'raises' HAS = 'has' SAFE = 'safe' + IMPORT = 'import' -class Insert(NamedTuple): +class InsertText(NamedTuple): + line: int + text: str + + def apply(self, lines: List[str]) -> None: + lines.insert(self.line - 1, f'{self}\n') + + @property + def key(self) -> Tuple[int, Priority]: + return (self.line, 1) + + def __str__(self) -> str: + return self.text + + +class InsertContract(NamedTuple): line: int contract: Category args: List[str] @@ -52,10 +68,12 @@ def key(self) -> Tuple[int, Priority]: return (self.line, 2) -Mutation = Union[Insert, Remove] +Mutation = Union[InsertText, InsertContract, Remove] class Transformer(NamedTuple): + """Transformer adds deal decorators into the given script. + """ content: str path: Path types: Set[TransformationType] @@ -67,6 +85,7 @@ def transform(self) -> str: tree = astroid.parse(self.content, path=self.path) for func in Func.from_astroid(tree): self._collect_mutations(func) + self.mutations.extend(self._mutations_import(tree)) return self._apply_mutations(self.content) def _collect_mutations(self, func: Func) -> None: @@ -74,6 +93,8 @@ def _collect_mutations(self, func: Func) -> None: self.mutations.extend(self._mutations_markers(func)) def _mutations_excs(self, func: Func) -> Iterator[Mutation]: + """Add @deal.raises or @deal.safe if needed. + """ cats = {Category.RAISES, Category.SAFE, Category.PURE} # collect declared exceptions @@ -97,7 +118,7 @@ def _mutations_excs(self, func: Func) -> Iterator[Mutation]: return if func.has_contract(Category.PURE, Category.SAFE): return - yield Insert( + yield InsertContract( line=func.line, indent=func.col, contract=Category.SAFE, @@ -113,7 +134,7 @@ def _mutations_excs(self, func: Func) -> Iterator[Mutation]: continue yield Remove(contract.line) if contract.category == Category.PURE: - yield Insert( + yield InsertContract( line=func.line, indent=func.col, contract=Category.HAS, @@ -121,7 +142,7 @@ def _mutations_excs(self, func: Func) -> Iterator[Mutation]: ) contract_args = [self._exc_as_str(exc) for exc in declared] contract_args.extend(sorted(excs)) - yield Insert( + yield InsertContract( line=func.line, indent=func.col, contract=Category.RAISES, @@ -135,6 +156,8 @@ def _exc_as_str(exc) -> str: return exc.__name__ def _mutations_markers(self, func: Func) -> Iterator[Mutation]: + """Add @deal.has if needed. + """ if TransformationType.HAS not in self.types: return cats = {Category.HAS, Category.PURE} @@ -156,7 +179,7 @@ def _mutations_markers(self, func: Func) -> Iterator[Mutation]: if not markers: if func.has_contract(Category.PURE, Category.HAS): return - yield Insert( + yield InsertContract( line=func.line, indent=func.col, contract=Category.HAS, @@ -170,7 +193,7 @@ def _mutations_markers(self, func: Func) -> Iterator[Mutation]: continue yield Remove(contract.line) if contract.category == Category.PURE: - yield Insert( + yield InsertContract( line=func.line, indent=func.col, contract=Category.SAFE, @@ -178,13 +201,35 @@ def _mutations_markers(self, func: Func) -> Iterator[Mutation]: ) contract_args = [self._exc_as_str(marker) for marker in declared] contract_args.extend(sorted(markers)) - yield Insert( + yield InsertContract( line=func.line, indent=func.col, contract=Category.HAS, args=[f'{self.quote}{arg}{self.quote}' for arg in contract_args], ) + def _mutations_import(self, tree: astroid.Module) -> Iterator[Mutation]: + """Add `import deal` if needed. + """ + if TransformationType.IMPORT not in self.types: + return + if not self.mutations: + return + # check if already imported + for stmt in tree.body: + if not isinstance(stmt, astroid.Import): + continue + for name, _ in stmt.names: + if name == 'deal': + return + line = 1 + skip = (astroid.ImportFrom, astroid.Import, astroid.Const) + for stmt in tree.body: # pragma: no cover + if not isinstance(stmt, skip): + break + line = stmt.lineno + 1 + yield InsertText(line=line, text='import deal') + def _apply_mutations(self, content: str) -> str: if not self.mutations: return content diff --git a/tests/test_linter/test_contract.py b/tests/test_linter/test_contract.py index 617b857e..35172753 100644 --- a/tests/test_linter/test_contract.py +++ b/tests/test_linter/test_contract.py @@ -5,7 +5,6 @@ import pytest from deal.linter._contract import Category, Contract -from deal.linter._extractors.contracts import SUPPORTED_CONTRACTS, SUPPORTED_MARKERS from deal.linter._func import Func @@ -19,16 +18,6 @@ def f(x): """ -@pytest.mark.parametrize('marker', SUPPORTED_MARKERS) -def test_all_markers_are_contracts(marker): - assert marker in SUPPORTED_CONTRACTS - - -def test_supported_contracts_match_categories(marker): - cats = {c.value for c in Category} - assert cats == SUPPORTED_CONTRACTS - - def test_exceptions(): funcs1 = Func.from_ast(ast.parse(TEXT)) assert len(funcs1) == 1 diff --git a/tests/test_linter/test_extractors/test_contracts.py b/tests/test_linter/test_extractors/test_contracts.py index fe77a8c1..1cbce1b7 100644 --- a/tests/test_linter/test_extractors/test_contracts.py +++ b/tests/test_linter/test_extractors/test_contracts.py @@ -4,11 +4,18 @@ import astroid import pytest +from deal.linter._contract import Category from deal.linter._extractors import get_contracts +from deal.linter._extractors.contracts import SUPPORTED_CONTRACTS, SUPPORTED_MARKERS + + +def test_supported_contracts_match_categories(): + cats = {f'deal.{c.value}' for c in Category} + assert cats == SUPPORTED_CONTRACTS | SUPPORTED_MARKERS def get_cats(target) -> tuple: - return tuple(cat for cat, _ in get_contracts(target)) + return tuple(contract.name for contract in get_contracts(target)) @pytest.mark.parametrize('text, expected', [ diff --git a/tests/test_linter/test_transformer.py b/tests/test_linter/test_transformer.py index 59413828..496612e0 100644 --- a/tests/test_linter/test_transformer.py +++ b/tests/test_linter/test_transformer.py @@ -345,3 +345,84 @@ def test_transformer_has(content: str, tmp_path: Path) -> None: ) actual = tr.transform() assert actual == expected + + +@pytest.mark.parametrize('content', [ + # add import if needed + """ + def f(): + return 1 + --- + import deal + + @deal.has() + def f(): + return 1 + """, + # skip imports + """ + import re + + def f(): + return 1 + --- + import re + import deal + + @deal.has() + def f(): + return 1 + """, + # skip import-from, do not skip consts + """ + import re + from textwrap import dedent + + HI = 1 + + def f(): + return 1 + --- + import re + from textwrap import dedent + import deal + + HI = 1 + + @deal.has() + def f(): + return 1 + """, + # do nothing if there are no mutations + """ + @deal.has() + def f(): + return 1 + --- + @deal.has() + def f(): + return 1 + """, + # do not duplicate existing imports + """ + import deal + def f(): + return 1 + --- + import deal + @deal.has() + def f(): + return 1 + """, +]) +def test_transformer_import(content: str, tmp_path: Path) -> None: + given, expected = content.split('---') + given = dedent(given) + expected = dedent(expected) + tr = Transformer( + content=given, + path=tmp_path / "example.py", + types={TransformationType.HAS, TransformationType.IMPORT}, + ) + actual = tr.transform() + assert actual.strip() == expected.strip() From 1bafce311dc4ef62b449f463f33b364cdb756c35 Mon Sep 17 00:00:00 2001 From: gram Date: Wed, 17 Nov 2021 16:11:14 +0100 Subject: [PATCH 17/20] do not break multiline imports --- deal/_cli/_decorate.py | 2 +- deal/linter/_transformer.py | 7 ++++--- tests/test_cli/test_decorate.py | 3 +-- tests/test_linter/test_transformer.py | 25 +++++++++++++++++++++---- 4 files changed, 27 insertions(+), 10 deletions(-) diff --git a/deal/_cli/_decorate.py b/deal/_cli/_decorate.py index 42164f79..faa9b1e9 100644 --- a/deal/_cli/_decorate.py +++ b/deal/_cli/_decorate.py @@ -29,7 +29,7 @@ def init_parser(parser: ArgumentParser) -> None: '--types', nargs='*', choices=[tt.value for tt in TransformationType], - default=['has', 'raises', 'safe'], + default=['has', 'raises', 'safe', 'import'], help='types of decorators to apply', ) parser.add_argument( diff --git a/deal/linter/_transformer.py b/deal/linter/_transformer.py index cf4db6ea..76bd1294 100644 --- a/deal/linter/_transformer.py +++ b/deal/linter/_transformer.py @@ -225,9 +225,10 @@ def _mutations_import(self, tree: astroid.Module) -> Iterator[Mutation]: line = 1 skip = (astroid.ImportFrom, astroid.Import, astroid.Const) for stmt in tree.body: # pragma: no cover - if not isinstance(stmt, skip): - break - line = stmt.lineno + 1 + if isinstance(stmt, skip): + continue + line = getattr(stmt, 'lineno', line) + break yield InsertText(line=line, text='import deal') def _apply_mutations(self, content: str) -> str: diff --git a/tests/test_cli/test_decorate.py b/tests/test_cli/test_decorate.py index f71ae2a7..de30a40f 100644 --- a/tests/test_cli/test_decorate.py +++ b/tests/test_cli/test_decorate.py @@ -11,7 +11,6 @@ ( [], """ - import deal @deal.post(lambda x: x > 0) def f(x): print(1/0) @@ -89,4 +88,4 @@ def test_decorate_command(flags: list, given: str, expected: str, tmp_path: Path stream.seek(0) captured = stream.read() assert str(file_path) in captured - assert file_path.read_text() == dedent(expected) + assert file_path.read_text().lstrip('\n') == dedent(expected).lstrip('\n') diff --git a/tests/test_linter/test_transformer.py b/tests/test_linter/test_transformer.py index 496612e0..108ae617 100644 --- a/tests/test_linter/test_transformer.py +++ b/tests/test_linter/test_transformer.py @@ -354,7 +354,6 @@ def f(): return 1 --- import deal - @deal.has() def f(): return 1 @@ -367,8 +366,8 @@ def f(): return 1 --- import re - import deal + import deal @deal.has() def f(): return 1 @@ -385,8 +384,8 @@ def f(): --- import re from textwrap import dedent - import deal + import deal HI = 1 @deal.has() @@ -414,6 +413,24 @@ def f(): def f(): return 1 """, + # support multiline imports + """ + from textwrap import ( + dedent, + ) + + def f(): + return 1 + --- + from textwrap import ( + dedent, + ) + + import deal + @deal.has() + def f(): + return 1 + """, ]) def test_transformer_import(content: str, tmp_path: Path) -> None: given, expected = content.split('---') @@ -425,4 +442,4 @@ def test_transformer_import(content: str, tmp_path: Path) -> None: types={TransformationType.HAS, TransformationType.IMPORT}, ) actual = tr.transform() - assert actual.strip() == expected.strip() + assert actual.lstrip('\n') == expected.lstrip('\n') From 5a564cb7e20f98ee9728ea3797f301ecfe979d30 Mon Sep 17 00:00:00 2001 From: gram Date: Thu, 18 Nov 2021 15:08:05 +0100 Subject: [PATCH 18/20] skip module imports --- deal/linter/_transformer.py | 15 ++++++++----- tests/test_cli/test_decorate.py | 1 + tests/test_linter/test_transformer.py | 32 ++++++++++++++++++++++----- 3 files changed, 37 insertions(+), 11 deletions(-) diff --git a/deal/linter/_transformer.py b/deal/linter/_transformer.py index 76bd1294..ffd87042 100644 --- a/deal/linter/_transformer.py +++ b/deal/linter/_transformer.py @@ -222,13 +222,16 @@ def _mutations_import(self, tree: astroid.Module) -> Iterator[Mutation]: for name, _ in stmt.names: if name == 'deal': return + + # We insert the import after `__future__` imports and module imports. + # We don't skip `from` imports, though, because they can be multiline. line = 1 - skip = (astroid.ImportFrom, astroid.Import, astroid.Const) - for stmt in tree.body: # pragma: no cover - if isinstance(stmt, skip): - continue - line = getattr(stmt, 'lineno', line) - break + for stmt in tree.body: + if isinstance(stmt, astroid.Import): + line = stmt.lineno + 1 + if isinstance(stmt, astroid.ImportFrom): + if stmt.modname == '__future__': + line = stmt.lineno + 1 yield InsertText(line=line, text='import deal') def _apply_mutations(self, content: str) -> str: diff --git a/tests/test_cli/test_decorate.py b/tests/test_cli/test_decorate.py index de30a40f..9295dfec 100644 --- a/tests/test_cli/test_decorate.py +++ b/tests/test_cli/test_decorate.py @@ -18,6 +18,7 @@ def f(x): """, """ import deal + @deal.has('stdout') @deal.raises(ZeroDivisionError) @deal.post(lambda x: x > 0) diff --git a/tests/test_linter/test_transformer.py b/tests/test_linter/test_transformer.py index 108ae617..aabf2acc 100644 --- a/tests/test_linter/test_transformer.py +++ b/tests/test_linter/test_transformer.py @@ -366,8 +366,8 @@ def f(): return 1 --- import re - import deal + @deal.has() def f(): return 1 @@ -383,9 +383,9 @@ def f(): return 1 --- import re + import deal from textwrap import dedent - import deal HI = 1 @deal.has() @@ -418,14 +418,36 @@ def f(): from textwrap import ( dedent, ) - def f(): return 1 --- + import deal from textwrap import ( dedent, ) - + @deal.has() + def f(): + return 1 + """, + # skip __future__ imports + """ + from __future__ import annotations + def f(): + return 1 + --- + from __future__ import annotations + import deal + @deal.has() + def f(): + return 1 + """, + # skip from imports before module imports + """ + from __future__ import annotations + def f(): + return 1 + --- + from __future__ import annotations import deal @deal.has() def f(): @@ -434,7 +456,7 @@ def f(): ]) def test_transformer_import(content: str, tmp_path: Path) -> None: given, expected = content.split('---') - given = dedent(given) + given = dedent(given).lstrip('\n') expected = dedent(expected) tr = Transformer( content=given, From 3fa57b1be67b88a4aadf2594571ab75cdeb2d22f Mon Sep 17 00:00:00 2001 From: gram Date: Thu, 18 Nov 2021 15:16:11 +0100 Subject: [PATCH 19/20] fix type errors for get_contracts usage --- deal/_cli/_prove.py | 7 +++++-- deal/linter/_extractors/contracts.py | 4 ---- deal/linter/_extractors/exceptions.py | 6 +++--- deal/linter/_extractors/markers.py | 6 +++--- deal/linter/_extractors/pre.py | 7 +++---- 5 files changed, 14 insertions(+), 16 deletions(-) diff --git a/deal/_cli/_prove.py b/deal/_cli/_prove.py index 7dbebe53..f82bade8 100644 --- a/deal/_cli/_prove.py +++ b/deal/_cli/_prove.py @@ -19,8 +19,11 @@ class DealTheorem(Theorem): @staticmethod def get_contracts(func: astroid.FunctionDef) -> Iterator[Contract]: - for name, args in get_contracts(func): - yield Contract(name=name, args=args) + for contract in get_contracts(func): + yield Contract( + name=contract.name, + args=contract.args, # type: ignore[arg-type] + ) def run_solver( diff --git a/deal/linter/_extractors/contracts.py b/deal/linter/_extractors/contracts.py index cca5efb7..f0ebef10 100644 --- a/deal/linter/_extractors/contracts.py +++ b/deal/linter/_extractors/contracts.py @@ -25,10 +25,6 @@ class ContractInfo(NamedTuple): args: List[Union[ast.expr, astroid.Expr]] line: int - def __iter__(self) -> Iterator: - yield self.name - yield self.args - def get_contracts(func) -> Iterator[ContractInfo]: if isinstance(func, ast.FunctionDef): diff --git a/deal/linter/_extractors/exceptions.py b/deal/linter/_extractors/exceptions.py index 52feb242..a05524b5 100644 --- a/deal/linter/_extractors/exceptions.py +++ b/deal/linter/_extractors/exceptions.py @@ -104,10 +104,10 @@ def _exceptions_from_func(expr) -> Iterator[Token]: yield Token(value=error.value, line=expr.lineno, col=expr.col_offset) # get explicitly specified exceptions from `@deal.raises` - for category, args in get_contracts(value): - if category != 'raises': + for contract in get_contracts(value): + if contract.name != 'raises': continue - for arg in args: + for arg in contract.args: name = get_name(arg) if name is None: continue diff --git a/deal/linter/_extractors/markers.py b/deal/linter/_extractors/markers.py index 005fb33b..931886b6 100644 --- a/deal/linter/_extractors/markers.py +++ b/deal/linter/_extractors/markers.py @@ -273,10 +273,10 @@ def _markers_from_func(expr: astroid.NodeNG, inferred: tuple) -> Iterator[Token] ) # get explicitly specified markers from `@deal.has` - for category, args in get_contracts(value): - if category != 'has': + for contract in get_contracts(value): + if contract.name != 'has': continue - for arg in args: + for arg in contract.args: value = get_value(arg) if type(value) is not str: continue diff --git a/deal/linter/_extractors/pre.py b/deal/linter/_extractors/pre.py index 7caf3dbd..24c10469 100644 --- a/deal/linter/_extractors/pre.py +++ b/deal/linter/_extractors/pre.py @@ -34,12 +34,11 @@ def handle_call(expr: astroid.Call, context: Dict[str, ast.stmt] = None) -> Iter continue code = f'def f({func.args.as_string()}):0' func_args = ast.parse(code).body[0].args # type: ignore - for category, contract_args in get_contracts(func): - if category != 'pre': + for cinfo in get_contracts(func): + if cinfo.name != 'pre': continue - contract = Contract( - args=contract_args, + args=cinfo.args, category=Category.PRE, func_args=func_args, context=context, From c30eda62326931c5c40fee5114e48b3a64be0e4b Mon Sep 17 00:00:00 2001 From: gram Date: Thu, 18 Nov 2021 15:17:55 +0100 Subject: [PATCH 20/20] fix flake8 --- deal/linter/_extractors/contracts.py | 2 +- tests/test_cli/test_decorate.py | 8 ++++---- tests/test_linter/test_transformer.py | 6 +++--- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/deal/linter/_extractors/contracts.py b/deal/linter/_extractors/contracts.py index f0ebef10..a2fa3ab0 100644 --- a/deal/linter/_extractors/contracts.py +++ b/deal/linter/_extractors/contracts.py @@ -47,7 +47,7 @@ def _get_contracts(decorators: list) -> Iterator[ContractInfo]: yield ContractInfo( name=name.split('.')[-1], args=[], - line=contract.lineno + line=contract.lineno, ) if name == 'deal.inherit': yield from _resolve_inherit(contract) diff --git a/tests/test_cli/test_decorate.py b/tests/test_cli/test_decorate.py index 9295dfec..28327861 100644 --- a/tests/test_cli/test_decorate.py +++ b/tests/test_cli/test_decorate.py @@ -25,7 +25,7 @@ def f(x): def f(x): print(1/0) return -1 - """ + """, ), ( ['--types', 'raises', 'safe'], @@ -43,7 +43,7 @@ def f(x): def f(x): print(1/0) return -1 - """ + """, ), ( ['--types', 'has', '--double-quotes'], @@ -61,7 +61,7 @@ def f(x): def f(x): print(1/0) return -1 - """ + """, ), ( [], @@ -76,7 +76,7 @@ def f(x): @deal.pure def f(x): return x - """ + """, ), ]) def test_decorate_command(flags: list, given: str, expected: str, tmp_path: Path): diff --git a/tests/test_linter/test_transformer.py b/tests/test_linter/test_transformer.py index aabf2acc..a68dfbc7 100644 --- a/tests/test_linter/test_transformer.py +++ b/tests/test_linter/test_transformer.py @@ -191,7 +191,7 @@ def test_transformer_raises(content: str, tmp_path: Path) -> None: expected = dedent(expected) tr = Transformer( content=given, - path=tmp_path / "example.py", + path=tmp_path / 'example.py', types={TransformationType.RAISES, TransformationType.SAFE}, ) actual = tr.transform() @@ -340,7 +340,7 @@ def test_transformer_has(content: str, tmp_path: Path) -> None: expected = dedent(expected) tr = Transformer( content=given, - path=tmp_path / "example.py", + path=tmp_path / 'example.py', types={TransformationType.HAS}, ) actual = tr.transform() @@ -460,7 +460,7 @@ def test_transformer_import(content: str, tmp_path: Path) -> None: expected = dedent(expected) tr = Transformer( content=given, - path=tmp_path / "example.py", + path=tmp_path / 'example.py', types={TransformationType.HAS, TransformationType.IMPORT}, ) actual = tr.transform()