Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion ast_tools/passes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

__ALL__ = ['Pass', 'PASS_ARGS_T']

PASS_ARGS_T = tp.Tuple[ast.AST, SymbolTable]
PASS_ARGS_T = tp.Tuple[ast.AST, SymbolTable, tp.MutableMapping]

class Pass(metaclass=ABCMeta):
"""
Expand All @@ -21,6 +21,7 @@ def __call__(self, args: PASS_ARGS_T) -> PASS_ARGS_T:
def rewrite(self,
env: SymbolTable,
tree: ast.AST,
metadata: tp.MutableMapping,
) -> PASS_ARGS_T:

"""
Expand Down
20 changes: 19 additions & 1 deletion ast_tools/passes/debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ def __init__(self,
dump_env: bool = False,
file: tp.Optional[str] = None,
append: tp.Optional[bool] = None,
dump_source_filename: bool = False,
dump_source_lines: bool = False
) -> PASS_ARGS_T:
self.dump_ast = dump_ast
self.dump_src = dump_src
Expand All @@ -25,10 +27,13 @@ def __init__(self,
warnings.warn('Option append has no effect when file is None', stacklevel=2)
self.file = file
self.append = append
self.dump_source_filename = dump_source_filename
self.dump_source_lines = dump_source_lines

def rewrite(self,
tree: ast.AST,
env: SymbolTable,
metadata: tp.MutableMapping,
) -> PASS_ARGS_T:

def _do_dumps(dumps, dump_writer):
Expand All @@ -44,6 +49,19 @@ def _do_dumps(dumps, dump_writer):
dumps.append(('SRC', astor.to_source(tree)))
if self.dump_env:
dumps.append(('ENV', repr(env)))
if self.dump_source_filename:
if "source_filename" not in metadata:
raise Exception("Cannot dump source filename without "
"@begin_rewrite(debug=True)")
dumps.append(('SOURCE_FILENAME', metadata["source_filename"]))
if self.dump_source_lines:
if "source_lines" not in metadata:
raise Exception("Cannot dump source lines without "
"@begin_rewrite(debug=True)")
lines, start_line_number = metadata["source_lines"]
dump_str = "".join(f"{start_line_number + i}:{line}" for i, line in
enumerate(lines))
dumps.append(('SOURCE_LINES', dump_str))

if self.file is not None:
if self.append:
Expand All @@ -56,4 +74,4 @@ def _do_dumps(dumps, dump_writer):
def _print(*args, **kwargs): print(*args, end='', **kwargs)
_do_dumps(dumps, _print)

return tree, env
return tree, env, metadata
4 changes: 2 additions & 2 deletions ast_tools/passes/ssa.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ class ssa(Pass):
def __init__(self, return_prefix: str = '__return_value'):
self.return_prefix = return_prefix

def rewrite(self, tree: ast.AST, env: SymbolTable):
def rewrite(self, tree: ast.AST, env: SymbolTable, metadata: tp.MutableMapping):
if not isinstance(tree, ast.FunctionDef):
raise TypeError('ssa should only be applied to functions')
r_name = gen_free_prefix(tree, env, self.return_prefix)
Expand All @@ -353,4 +353,4 @@ def rewrite(self, tree: ast.AST, env: SymbolTable):
value=_build_return(visitor.returns)
)
)
return tree, env
return tree, env, metadata
13 changes: 10 additions & 3 deletions ast_tools/passes/util.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
import ast
import typing as tp

Expand All @@ -13,13 +14,18 @@ class begin_rewrite:
"""
begins a chain of passes
"""
def __init__(self):
def __init__(self, debug=False):
env = get_symbol_table([self.__init__])
self.env = env
self.debug = debug

def __call__(self, fn) -> PASS_ARGS_T:
tree = get_ast(fn)
return tree, self.env
metadata = {}
if self.debug:
metadata["source_filename"] = inspect.getsourcefile(fn)
metadata["source_lines"] = inspect.getsourcelines(fn)
return tree, self.env, metadata

def _issubclass(t, s) -> bool:
try:
Expand All @@ -37,7 +43,8 @@ def __init__(self, path: tp.Optional[str] = None):

def rewrite(self,
tree: ast.AST,
env: SymbolTable) -> tp.Union[tp.Callable, type]:
env: SymbolTable,
metadata: tp.MutableMapping) -> tp.Union[tp.Callable, type]:
decorators = []
first_group = True
in_group = False
Expand Down
47 changes: 46 additions & 1 deletion tests/test_passes.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import os
import functools
import inspect

from ast_tools.passes import begin_rewrite, end_rewrite
from ast_tools.passes import begin_rewrite, end_rewrite, debug



Expand All @@ -25,3 +26,47 @@ def foo():
def foo():
pass
'''


def test_debug(capsys):

@end_rewrite
@debug(dump_source_filename=True, dump_source_lines=True)
@begin_rewrite(debug=True)
def foo():
print("bar")
assert capsys.readouterr().out == f"""\
BEGIN SOURCE_FILENAME
{os.path.abspath(__file__)}
END SOURCE_FILENAME

BEGIN SOURCE_LINES
33: @end_rewrite
34: @debug(dump_source_filename=True, dump_source_lines=True)
35: @begin_rewrite(debug=True)
36: def foo():
37: print("bar")
END SOURCE_LINES

"""


def test_debug_error():

try:
@end_rewrite
@debug(dump_source_filename=True)
@begin_rewrite()
def foo():
print("bar")
except Exception as e:
assert str(e) == "Cannot dump source filename without @begin_rewrite(debug=True)"

try:
@end_rewrite
@debug(dump_source_lines=True)
@begin_rewrite()
def foo():
print("bar")
except Exception as e:
assert str(e) == "Cannot dump source lines without @begin_rewrite(debug=True)"