diff --git a/README.md b/README.md index 450fb1ce..4e9d85a7 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,7 @@ * [Can be enabled or disabled on production.][runtime] * [Colorless](colorless): annotate only what you want. Hence, easy integration into an existing project. * Colorful: syntax highlighting for every piece of code in every command. +* [Memory leak detection.][leaks] Deal makes sure that a pure function doesn't leave unexpected objects in memory. * DRY: test discovery, error messages generation. * Partial execution: linter executes contracts to statically check possible values. @@ -31,6 +32,7 @@ [module_load]: https://deal.readthedocs.io/details/module_load.html [runtime]: https://deal.readthedocs.io/basic/runtime.html [colorless]: http://journal.stuffwithstuff.com/2015/02/01/what-color-is-your-function/ +[leaks]: https://deal.readthedocs.io/details/tests.html#finding-memory-leaks ## Deal in 30 seconds diff --git a/deal/_cli/_main.py b/deal/_cli/_main.py index f9201a12..acc69612 100644 --- a/deal/_cli/_main.py +++ b/deal/_cli/_main.py @@ -7,11 +7,13 @@ from ._lint import lint_command from ._stub import stub_command from ._test import test_command +from ._memtest import memtest_command CommandsType = Mapping[str, Callable[[Sequence[str]], int]] COMMANDS: CommandsType = MappingProxyType(dict( lint=lint_command, + memtest=memtest_command, stub=stub_command, test=test_command, )) diff --git a/deal/_cli/_memtest.py b/deal/_cli/_memtest.py new file mode 100644 index 00000000..5be5d1a2 --- /dev/null +++ b/deal/_cli/_memtest.py @@ -0,0 +1,118 @@ +# built-in +from contextlib import suppress +import sys +from argparse import ArgumentParser +from importlib import import_module +from pathlib import Path +from typing import Dict, Iterator, Sequence, TextIO + +# app +from .._testing import cases, TestCase +from .._mem_test import MemoryTracker +from .._state import state +from ..linter._extractors.pre import format_call_args +from ._common import get_paths +from 
.._colors import COLORS +from ._test import sys_path, get_func_names + + +def run_tests(path: Path, root: Path, count: int, stream: TextIO = sys.stdout) -> int: + names = list(get_func_names(path=path)) + if not names: + return 0 + print('{magenta}running {path}{end}'.format(path=path, **COLORS), file=stream) + module_name = '.'.join(path.relative_to(root).with_suffix('').parts) + with sys_path(path=root): + module = import_module(module_name) + failed = 0 + for func_name in names: + func = getattr(module, func_name) + ok = run_cases( + cases=cases(func=func, count=count, check_types=False), + func_name=func_name, + stream=stream, + colors=COLORS, + ) + if not ok: + failed += 1 + return failed + + +def run_cases( + cases: Iterator[TestCase], + func_name: str, + stream: TextIO, + colors: Dict[str, str], +) -> bool: + print(' {blue}running {name}{end}'.format(name=func_name, **colors), file=stream) + for case in cases: + tracker = MemoryTracker() + debug = state.debug + state.disable() + try: + with tracker, suppress(Exception): + case() + finally: + state.debug = debug + if not tracker.diff: + continue + + # show the diff and stop testing the func + line = ' {yellow}{name}({args}){end}'.format( + name=func_name, + args=format_call_args(args=case.args, kwargs=case.kwargs), + **colors, + ) + print(line, file=stream) + longest_name_len = max(len(name) for name in tracker.diff) + for name, count in tracker.diff.items(): + line = ' {red}{name}{end} x{count}'.format( + name=name.ljust(longest_name_len), + count=count, + **colors, + ) + print(line, file=stream) + return False + return True + + +def memtest_command( + argv: Sequence[str], root: Path = None, stream: TextIO = sys.stdout, +) -> int: + """Generate and run tests against pure functions and report memory leaks. 
+ + ```bash + python3 -m deal memtest project/ + ``` + + Function must be decorated by one of the following to be run: + + + `@deal.pure` + + `@deal.has()` (without arguments) + + Options: + + + `--count`: how many combinations of input values should be checked. + + Exit code is equal to the number of functions with memory leaks. + See [memory leaks][leaks] documentation for more details. + + [leaks]: https://deal.readthedocs.io/details/tests.html#finding-memory-leaks + """ + if root is None: # pragma: no cover + root = Path() + parser = ArgumentParser(prog='python3 -m deal memtest') + parser.add_argument('--count', type=int, default=50) + parser.add_argument('paths', nargs='+') + args = parser.parse_args(argv) + + failed = 0 + for arg in args.paths: + for path in get_paths(Path(arg)): + failed += run_tests( + path=Path(path), + root=root, + count=args.count, + stream=stream, + ) + return failed diff --git a/deal/_exceptions.py b/deal/_exceptions.py index 20ca0c52..7d4db3ec 100644 --- a/deal/_exceptions.py +++ b/deal/_exceptions.py @@ -28,7 +28,7 @@ def exception_hook(etype: Type[BaseException], value: BaseException, tb): if path.startswith(root): with suppress(AttributeError): # read-only attribute in <3.7 prev_tb.tb_next = None - break + break # pragma: no cover prev_tb = patched_tb patched_tb = patched_tb.tb_next else: diff --git a/deal/_mem_test.py b/deal/_mem_test.py new file mode 100644 index 00000000..3e3c4b80 --- /dev/null +++ b/deal/_mem_test.py @@ -0,0 +1,32 @@ +import gc +import typing +from collections import Counter +from._cached_property import cached_property + + +class MemoryTracker: + before: typing.Counter[str] + after: typing.Counter[str] + + def __init__(self) -> None: + self.before = Counter() + self.after = Counter() + + def __enter__(self) -> None: + self.before = self._dump() + + def __exit__(self, *exc) -> None: + self.after = self._dump() + + @cached_property + def diff(self) -> typing.Counter[str]: + return self.after - self.before - Counter({'weakref': 1}) + + 
@classmethod + def _dump(cls) -> typing.Counter[str]: + counter: typing.Counter[str] = Counter() + gc.collect() + for obj in gc.get_objects(): + name: str = type(obj).__qualname__ + counter[name] += 1 + return counter diff --git a/docs/basic/intro.md b/docs/basic/intro.md index d6f21076..a87bdbe0 100644 --- a/docs/basic/intro.md +++ b/docs/basic/intro.md @@ -41,6 +41,7 @@ It's not "advanced usage", there is nothing advanced or difficult. It's about wr 1. [module_load](../details/module_load) allows you to control what happens at the module load (import) time. 1. [Stubs](../details/stubs) is a way to store some contracts in a JSON file instead of the source code. It can be helpful for third-party libraries. Some stubs already inside Deal. +1. [More about testing](../details/tests) provides information on finding memory leaks and tweaking test generation. 1. [Validators](../details/validators) is a way to describe complex contracts using [Marshmallow](https://github.com/marshmallow-code/marshmallow) or another validation library. 1. [Recipes](../details/recipes) is the place to learn more about best practices of using contracts. diff --git a/docs/details/cli.md b/docs/details/cli.md index 4529aafc..c0ad7cdf 100644 --- a/docs/details/cli.md +++ b/docs/details/cli.md @@ -17,3 +17,9 @@ ```eval_rst .. autofunction:: deal._cli._test.test_command ``` + +## memtest + +```eval_rst +.. autofunction:: deal._cli._memtest.memtest_command +``` diff --git a/docs/details/tests.md b/docs/details/tests.md new file mode 100644 index 00000000..c967d7d9 --- /dev/null +++ b/docs/details/tests.md @@ -0,0 +1,34 @@ +# More about testing + +This section assumes that you're familiar with [basic testing](../basic/tests.md) and describes how you can get more from deal's testing mechanisms. + +## Finding memory leaks + +Sometimes, when a function completes, it leaves objects in memory other than its result. 
For example: + +```python +cache = {} +User = dict + +def get_user(name: str) -> User: + if name not in cache: + cache[name] = User(name=name) + return cache[name] +``` + +Here, `get_user` creates a `User` object and stores it in a global cache. In this case, this "leak" is a desired behavior and we don't want to fight it. This is why we can't have a tool (or something right in the Python interpreter) that catches and reports such behavior: it would have too many false positives. + +However, things are different with pure functions. A pure function can't store anything on the side because that would be a side effect. The result of a pure function is only what it returns. + +The command `memtest` uses this idea to find memory leaks in pure functions. How it works: + +1. It finds all pure functions (as `test` does). +1. For every function: + 1. It makes a memory snapshot before running the function. + 1. It runs the function with different autogenerated input arguments (as the `test` command does) without running contracts or checking the return value type (to avoid side effects from deal itself). + 1. It makes a memory snapshot after running the function. + 1. Snapshots "before" and "after" are compared. If there is a difference, it will be printed. + +The return code is equal to the number of functions with memory leaks. + +If the function fails, the command will ignore it and still test the function for leaks. Side effects shouldn't happen under any circumstances, even if the function fails. If you want to find unexpected failures, use the `test` command instead. 
diff --git a/docs/index.md b/docs/index.md index 75aff6d1..3dbf93a3 100644 --- a/docs/index.md +++ b/docs/index.md @@ -33,6 +33,7 @@ details/module_load details/stubs details/validators + details/tests details/recipes details/examples details/cli diff --git a/tests/test_cli/test_mem_test.py b/tests/test_cli/test_mem_test.py new file mode 100644 index 00000000..0949870e --- /dev/null +++ b/tests/test_cli/test_mem_test.py @@ -0,0 +1,80 @@ +import sys +from pathlib import Path +from deal._cli._memtest import memtest_command +from textwrap import dedent +from io import StringIO + + +def test_has_side_effect(tmp_path: Path, capsys): + if 'example' in sys.modules: + del sys.modules['example'] + text = """ + import deal + + a = [] + + @deal.pure + def func(b: int) -> float: + a.append({b, b+b}) + return None + """ + path = (tmp_path / 'example.py') + path.write_text(dedent(text)) + stream = StringIO() + result = memtest_command(['--count', '1', str(path)], root=tmp_path, stream=stream) + assert result == 1 + + stream.seek(0) + captured = stream.read() + assert '/example.py' in captured + assert 'running func' in captured + assert 'func(b=0)' in captured + assert 'set' in captured + assert 'x1' in captured + + +def test_no_side_effects(tmp_path: Path, capsys): + if 'example' in sys.modules: + del sys.modules['example'] + text = """ + import deal + + @deal.pure + def func(b: int) -> float: + return b+b + """ + path = (tmp_path / 'example.py') + path.write_text(dedent(text)) + stream = StringIO() + result = memtest_command(['--count', '1', str(path)], root=tmp_path, stream=stream) + assert result == 0 + + stream.seek(0) + captured = stream.read() + assert '/example.py' in captured + assert 'running func' in captured + assert 'func(b=0)' not in captured + + +def test_no_matching_funcs(tmp_path: Path): + if 'example' in sys.modules: + del sys.modules['example'] + text = """ + import deal + + def not_pure1(a: int, b: int) -> float: + return a / b + + @deal.post(lambda result: 
result > 0) + def not_pure2(a: int, b: int) -> float: + return a / b + """ + path = (tmp_path / 'example.py') + path.write_text(dedent(text)) + stream = StringIO() + result = memtest_command([str(path)], root=tmp_path, stream=stream) + assert result == 0 + + stream.seek(0) + captured = stream.read() + assert '/example.py' not in captured diff --git a/tests/test_linter/test_contract.py b/tests/test_linter/test_contract.py index dacf271b..8e117246 100644 --- a/tests/test_linter/test_contract.py +++ b/tests/test_linter/test_contract.py @@ -224,8 +224,8 @@ def f(a): assert len(func.contracts) == 1 c = func.contracts[0] - c.run(12) is False - c.run(34) is True + assert c.run(12) is False + assert c.run(34) is True def test_resolve_and_run_dependencies_lambda(): @@ -251,8 +251,8 @@ def f(a): assert len(func.contracts) == 1 c = func.contracts[0] - c.run(12) is False - c.run(34) is True + assert c.run(12) is False + assert c.run(34) is True def test_lazy_import_stdlib(): @@ -270,8 +270,8 @@ def f(a): assert len(func.contracts) == 1 c = func.contracts[0] - c.run('bcd') is False - c.run('abc') is True + assert c.run('bcd') is False + assert c.run('abc') is True def test_unresolvable(): diff --git a/tests/test_mem_test.py b/tests/test_mem_test.py new file mode 100644 index 00000000..d5c0cb95 --- /dev/null +++ b/tests/test_mem_test.py @@ -0,0 +1,36 @@ +from deal._mem_test import MemoryTracker + + +def test_mem_dump_no_diff(): + def f(): + return 123 + + tracker = MemoryTracker() + with tracker: + f() + assert not tracker.diff + + +def test_mem_dump_ignore_locals(): + def f(): + a = 456 + b = a + return b + + tracker = MemoryTracker() + with tracker: + f() + assert not tracker.diff + + +def test_mem_dump_side_effect(): + a = [] + + def f(): + a.append({12}) + return 123 + + tracker = MemoryTracker() + with tracker: + f() + assert dict(tracker.diff) == {'set': 1}