diff --git a/ast_tools/passes/base.py b/ast_tools/passes/base.py index 4462f13..ff18c10 100644 --- a/ast_tools/passes/base.py +++ b/ast_tools/passes/base.py @@ -6,7 +6,7 @@ __ALL__ = ['Pass', 'PASS_ARGS_T'] -PASS_ARGS_T = tp.Tuple[ast.AST, SymbolTable] +PASS_ARGS_T = tp.Tuple[ast.AST, SymbolTable, tp.MutableMapping] class Pass(metaclass=ABCMeta): """ @@ -21,6 +21,7 @@ def __call__(self, args: PASS_ARGS_T) -> PASS_ARGS_T: def rewrite(self, env: SymbolTable, tree: ast.AST, + metadata: tp.MutableMapping, ) -> PASS_ARGS_T: """ diff --git a/ast_tools/passes/debug.py b/ast_tools/passes/debug.py index 8d677e8..bb34996 100644 --- a/ast_tools/passes/debug.py +++ b/ast_tools/passes/debug.py @@ -17,6 +17,8 @@ def __init__(self, dump_env: bool = False, file: tp.Optional[str] = None, append: tp.Optional[bool] = None, + dump_source_filename: bool = False, + dump_source_lines: bool = False ) -> PASS_ARGS_T: self.dump_ast = dump_ast self.dump_src = dump_src @@ -25,10 +27,13 @@ def __init__(self, warnings.warn('Option append has no effect when file is None', stacklevel=2) self.file = file self.append = append + self.dump_source_filename = dump_source_filename + self.dump_source_lines = dump_source_lines def rewrite(self, tree: ast.AST, env: SymbolTable, + metadata: tp.MutableMapping, ) -> PASS_ARGS_T: def _do_dumps(dumps, dump_writer): @@ -44,6 +49,19 @@ def _do_dumps(dumps, dump_writer): dumps.append(('SRC', astor.to_source(tree))) if self.dump_env: dumps.append(('ENV', repr(env))) + if self.dump_source_filename: + if "source_filename" not in metadata: + raise Exception("Cannot dump source filename without " + "@begin_rewrite(debug=True)") + dumps.append(('SOURCE_FILENAME', metadata["source_filename"])) + if self.dump_source_lines: + if "source_lines" not in metadata: + raise Exception("Cannot dump source lines without " + "@begin_rewrite(debug=True)") + lines, start_line_number = metadata["source_lines"] + dump_str = "".join(f"{start_line_number + i}:{line}" for i, line in + enumerate(lines)) + dumps.append(('SOURCE_LINES', dump_str)) if self.file is not None: if self.append: @@ -56,4 +74,4 @@ def _do_dumps(dumps, dump_writer): def _print(*args, **kwargs): print(*args, end='', **kwargs) _do_dumps(dumps, _print) - return tree, env + return tree, env, metadata diff --git a/ast_tools/passes/ssa.py b/ast_tools/passes/ssa.py index b7ec360..6841547 100644 --- a/ast_tools/passes/ssa.py +++ b/ast_tools/passes/ssa.py @@ -342,7 +342,7 @@ class ssa(Pass): def __init__(self, return_prefix: str = '__return_value'): self.return_prefix = return_prefix - def rewrite(self, tree: ast.AST, env: SymbolTable): + def rewrite(self, tree: ast.AST, env: SymbolTable, metadata: tp.MutableMapping): if not isinstance(tree, ast.FunctionDef): raise TypeError('ssa should only be applied to functions') r_name = gen_free_prefix(tree, env, self.return_prefix) @@ -353,4 +353,4 @@ def rewrite(self, tree: ast.AST, env: SymbolTable): value=_build_return(visitor.returns) ) ) - return tree, env + return tree, env, metadata diff --git a/ast_tools/passes/util.py b/ast_tools/passes/util.py index e6c1775..855439d 100644 --- a/ast_tools/passes/util.py +++ b/ast_tools/passes/util.py @@ -1,3 +1,4 @@ +import inspect import ast import typing as tp @@ -13,13 +14,18 @@ class begin_rewrite: """ begins a chain of passes """ - def __init__(self): + def __init__(self, debug=False): env = get_symbol_table([self.__init__]) self.env = env + self.debug = debug def __call__(self, fn) -> PASS_ARGS_T: tree = get_ast(fn) - return tree, self.env + metadata = {} + if self.debug: + metadata["source_filename"] = inspect.getsourcefile(fn) + metadata["source_lines"] = inspect.getsourcelines(fn) + return tree, self.env, metadata def _issubclass(t, s) -> bool: try: @@ -37,7 +43,8 @@ def __init__(self, path: tp.Optional[str] = None): def rewrite(self, tree: ast.AST, - env: SymbolTable) -> tp.Union[tp.Callable, type]: + env: SymbolTable, + metadata: tp.MutableMapping) -> tp.Union[tp.Callable, type]: decorators = [] first_group = True in_group = False diff --git a/tests/test_passes.py b/tests/test_passes.py index c7b5b83..14d5b9f 100644 --- a/tests/test_passes.py +++ b/tests/test_passes.py @@ -1,7 +1,8 @@ +import os import functools import inspect -from ast_tools.passes import begin_rewrite, end_rewrite +from ast_tools.passes import begin_rewrite, end_rewrite, debug @@ -25,3 +26,47 @@ def foo(): def foo(): pass ''' + + +def test_debug(capsys): + + @end_rewrite + @debug(dump_source_filename=True, dump_source_lines=True) + @begin_rewrite(debug=True) + def foo(): + print("bar") + assert capsys.readouterr().out == f"""\ +BEGIN SOURCE_FILENAME +{os.path.abspath(__file__)} +END SOURCE_FILENAME + +BEGIN SOURCE_LINES +33: @end_rewrite +34: @debug(dump_source_filename=True, dump_source_lines=True) +35: @begin_rewrite(debug=True) +36: def foo(): +37: print("bar") +END SOURCE_LINES + +""" + + +def test_debug_error(): + + try: + @end_rewrite + @debug(dump_source_filename=True) + @begin_rewrite() + def foo(): + print("bar") + except Exception as e: + assert str(e) == "Cannot dump source filename without @begin_rewrite(debug=True)" + + try: + @end_rewrite + @debug(dump_source_lines=True) + @begin_rewrite() + def foo(): + print("bar") + except Exception as e: + assert str(e) == "Cannot dump source lines without @begin_rewrite(debug=True)"