diff --git a/ast_tools/__init__.py b/ast_tools/__init__.py index 8a42b56..dbb430a 100644 --- a/ast_tools/__init__.py +++ b/ast_tools/__init__.py @@ -1,20 +1,7 @@ """ ast_tools top level package """ -import inspect -import textwrap -import ast - -__all__ = ["NameCollector", "collect_names", "get_ast"] - -from .collect_names import NameCollector, collect_names +from .common import * +from . import passes from . import stack - - -def get_ast(obj): - """ - Given an object, get the corresponding AST - """ - indented_program_txt = inspect.getsource(obj) - program_txt = textwrap.dedent(indented_program_txt) - return ast.parse(program_txt) +from . import visitors diff --git a/ast_tools/common.py b/ast_tools/common.py new file mode 100644 index 0000000..5e14a08 --- /dev/null +++ b/ast_tools/common.py @@ -0,0 +1,100 @@ +import ast +import functools +import inspect +import logging +import os +import textwrap +import types +import typing as tp +import weakref + +import astor + +from ast_tools import stack +from ast_tools.stack import SymbolTable +from ast_tools.visitors import UsedNames + +__ALL__ = ['exec_in_file', 'exec_def_in_file', 'get_ast', 'gen_free_name'] + +DefStmt = tp.Union[ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef] + +def exec_def_in_file( + tree: DefStmt, + st: SymbolTable, + path: tp.Optional[str] = None, + file_name: tp.Optional[str] = None) -> None: + """ + execs a definition in a file and returns the definiton + """ + if file_name is None: + file_name = tree.name + '.py' + + exec_in_file(tree, st, path, file_name) + return st.locals[tree.name] + + +def exec_in_file( + tree: ast.AST, + st: SymbolTable, + path: tp.Optional[str] = None, + file_name: tp.Optional[str] = None) -> None: + + """ + execs a definition in a file + """ + if path is None: + path = '.ast_tools' + if file_name is None: + file_name = f'ast_tools_exec_{hash(tree)}.py' + + source = astor.to_source(tree) + file_name = os.path.join(path, file_name) + os.makedirs(path, exist_ok=True) + with open(file_name, 'w') as fp: + fp.write(source) + + try: + code = compile(source, filename=file_name, mode='exec') + except Exception as e: + logging.exception("Error compiling source") + raise e from None + + try: + exec(code, st.globals, st.locals) + except Exception as e: + logging.exception("Error executing code") + raise e from None + + +_AST_CACHE = weakref.WeakKeyDictionary() +def get_ast(obj) -> ast.AST: + """ + Given an object, get the corresponding AST + """ + try: + _AST_CACHE[obj] + except KeyError: + pass + + src = textwrap.dedent(inspect.getsource(obj)) + + if isinstance(obj, types.ModuleType): + tree = ast.parse(src) + else: + tree = ast.parse(src).body[0] + + return _AST_CACHE.setdefault(obj, tree) + + +def gen_free_name(tree: ast.AST, env: SymbolTable, prefix: str = '__auto_name_'): + visitor = UsedNames() + visitor.visit(tree) + used_names = visitor.names | env.locals.keys() | env.globals.keys() + f_str = prefix+'{}' + c = 0 + name = f_str.format(c) + while name in used_names: + c += 1 + name = f_str.format(c) + + return name diff --git a/ast_tools/passes/__init__.py b/ast_tools/passes/__init__.py new file mode 100644 index 0000000..6bf8842 --- /dev/null +++ b/ast_tools/passes/__init__.py @@ -0,0 +1,92 @@ +from abc import ABCMeta, abstractmethod +import ast +import astor +import functools +import typing as tp + +from ast_tools.stack import get_symbol_table, SymbolTable +from ast_tools.common import get_ast, exec_def_in_file + +__ALL__ = ['Pass', 'begin_rewrite', 'end_rewite'] + +_PASS_ARGS_T = tp.Tuple[ast.AST, SymbolTable] +class Pass(metaclass=ABCMeta): + """ + Abstract base class for passes + Mostly a convience to unpack arguments + """ + + def __call__(self, args: _PASS_ARGS_T) -> _PASS_ARGS_T: + return self.rewrite(*args) + + @abstractmethod + def rewrite(self, + env: SymbolTable, + tree: ast.AST, + ) -> _PASS_ARGS_T: + + """ + Type annotation here should be followed except on terminal passes e.g. + end_rewite + """ + pass + +class begin_rewrite: + """ + begins a chain of passes + """ + def __init__(self): + env = get_symbol_table([self.__init__]) + self.env = env + + def __call__(self, fn) -> _PASS_ARGS_T: + tree = get_ast(fn) + return tree, self.env + +def _issubclass(t, s) -> bool: + try: + return issubclass(t, s) + except TypeError: + return False + + +class end_rewrite(Pass): + """ + ends a chain of passes + """ + def __init__(self, path: tp.Optional[str] = None): + self.path = path + + def rewrite(self, + tree: ast.AST, + env: SymbolTable) -> tp.Union[tp.Callable, type]: + decorators = [] + first_group = True + in_group = False + # filter passes from the decorator list + for node in reversed(tree.decorator_list): + if not first_group: + decorators.append(node) + continue + + if isinstance(node, ast.Call): + expr = ast.Expression(node.func) + else: + assert isinstance(node, ast.Name) + expr = ast.Expression(node) + + code = compile(expr, filename="", mode="eval") + deco = eval(code, env.globals, env.locals) + if in_group: + if _issubclass(deco, end_rewrite): + assert in_group + in_group = False + first_group = False + elif _issubclass(deco, begin_rewrite): + assert not in_group + in_group = True + else: + decorators.append(node) + + tree.decorator_list = reversed(decorators) + return exec_def_in_file(tree, env, self.path) diff --git a/ast_tools/passes/debug.py b/ast_tools/passes/debug.py new file mode 100644 index 0000000..169628a --- /dev/null +++ b/ast_tools/passes/debug.py @@ -0,0 +1,57 @@ +import ast +import typing as tp +import warnings + +import astor + +from . import Pass +from . import _PASS_ARGS_T +from ast_tools.stack import SymbolTable + +class debug(Pass): + def __init__(self, + dump_ast: bool = True, + dump_src: bool = False, + dump_env: bool = False, + file: tp.Optional[str] = None, + append: tp.Optional[bool] = None, + ) -> _PASS_ARGS_T: + self.dump_ast = dump_ast + self.dump_src = dump_src + self.dump_env = dump_env + if append is not None and file is None: + warnings.warn('Option append has no effect when file is None', stacklevel=2) + self.file = file + self.append = append + + def rewrite(self, + tree: ast.AST, + env: SymbolTable, + ) -> _PASS_ARGS_T: + + def _do_dumps(dumps, dump_writer): + for dump in dumps: + dump_writer(f'BEGIN {dump[0]}\n') + dump_writer(dump[1].strip()) + dump_writer(f'\nEND {dump[0]}\n\n') + + dumps = [] + if self.dump_ast: + dumps.append(('AST', astor.dump_tree(tree))) + if self.dump_src: + dumps.append(('SRC', astor.to_source(tree))) + if self.dump_env: + dumps.append(('ENV', repr(env))) + + if self.file is not None: + if self.append: + mode = 'wa' + else: + mode = 'w' + with open(self.dump_file, mode) as fp: + _do_dumps(dumps, fp.write) + else: + def _print(*args, **kwargs): print(*args, end='', **kwargs) + _do_dumps(dumps, _print) + + return tree, env diff --git a/ast_tools/stack.py b/ast_tools/stack.py index 7011176..2287c25 100644 --- a/ast_tools/stack.py +++ b/ast_tools/stack.py @@ -18,7 +18,7 @@ _SKIP_FRAME_DEBUG_FAIL = False class SymbolTable(tp.NamedTuple): - locals: tp.Mapping[str, tp.Any] + locals: tp.MutableMapping[str, tp.Any] globals: tp.Dict[str, tp.Any] def get_symbol_table( @@ -48,7 +48,7 @@ def get_symbol_table( logging.debug(f'{frame.function} @ {frame.filename}:{frame.lineno} might be leaking names') locals = locals.new_child(stack[i].frame.f_locals) globals = globals.new_child(stack[i].frame.f_globals) - return SymbolTable(locals=locals, globals=globals) + return SymbolTable(locals=locals, globals=dict(globals)) def inspect_symbol_table( fn: tp.Callable, # tp.Callable[[SymbolTable, ...], tp.Any], diff --git a/ast_tools/visitors/__init__.py b/ast_tools/visitors/__init__.py new file mode 100644 index 0000000..d8fc3ca --- /dev/null +++ b/ast_tools/visitors/__init__.py @@ -0,0 +1,3 @@ +from .used_names import * +from .collect_names import * + diff --git a/ast_tools/collect_names.py b/ast_tools/visitors/collect_names.py similarity index 100% rename from ast_tools/collect_names.py rename to ast_tools/visitors/collect_names.py diff --git a/ast_tools/visitors/used_names.py b/ast_tools/visitors/used_names.py new file mode 100644 index 0000000..4c512b9 --- /dev/null +++ b/ast_tools/visitors/used_names.py @@ -0,0 +1,17 @@ +import ast + +class UsedNames(ast.NodeVisitor): + def __init__(self): + self.names = set() + + def visit_Name(self, node: ast.Name): + self.names.add(node.id) + + def visit_FunctionDef(self, node: ast.FunctionDef): + self.names.add(node.name) + + def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef): + self.names.add(node.name) + + def visit_ClassDef(self, node: ast.ClassDef): + self.names.add(node.name) diff --git a/setup.py b/setup.py index 05a1e97..12574e3 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ description='Toolbox for working with the Python AST', scripts=[], packages=[], - install_requires=[], + install_requires=['astor'], long_description=LONG_DESCRIPTION, long_description_content_type="text/markdown" ) diff --git a/tests/test_collect_names.py b/tests/test_collect_names.py deleted file mode 100644 index 53a27e2..0000000 --- a/tests/test_collect_names.py +++ /dev/null @@ -1,19 +0,0 @@ -""" -Test collecting instances of `ast.Name` -""" -import ast -from ast_tools import get_ast, collect_names - - -def test_collect_names_basic(): - """ - Test collecting names from a simple function including the `ctx` feature - """ - def foo(bar, baz): # pylint: disable=blacklisted-name - buzz = bar + baz - return buzz - - foo_ast = get_ast(foo) - assert collect_names(foo_ast) == {"bar", "baz", "buzz"} - assert collect_names(foo_ast, ctx=ast.Load) == {"bar", "baz", "buzz"} - assert collect_names(foo_ast, ctx=ast.Store) == {"buzz"} diff --git a/tests/test_common.py b/tests/test_common.py new file mode 100644 index 0000000..a6fc5e9 --- /dev/null +++ b/tests/test_common.py @@ -0,0 +1,33 @@ +import ast +import astor + +from ast_tools.common import get_ast, gen_free_name, SymbolTable + + +def test_get_ast(): + def f(): pass + f_str = 'def f(): pass' + ast_str_0 = astor.dump_tree(get_ast(f)) + ast_str_1 = astor.dump_tree(ast.parse(f_str).body[0]) + assert ast_str_0 == ast_str_1 + + +def test_gen_free(): + src = ''' +class P0: + P5 = 1 + def __init__(self): self.y = 0 +def P1(): + return P0.P5 +P2 = P1() +''' + tree = ast.parse(src) + env = SymbolTable({}, {}) + free_name = gen_free_name(tree, env, prefix='P') + assert free_name == 'P3' + env = SymbolTable({'P4': 'foo'}, {}) + free_name = gen_free_name(tree, env, prefix='P') + assert free_name == 'P3' + env = SymbolTable({'P4': 'foo'}, {'P3' : 'bar'}) + free_name = gen_free_name(tree, env, prefix='P') + assert free_name == 'P5' diff --git a/tests/test_passes.py b/tests/test_passes.py new file mode 100644 index 0000000..c7b5b83 --- /dev/null +++ b/tests/test_passes.py @@ -0,0 +1,27 @@ +import functools +import inspect + +from ast_tools.passes import begin_rewrite, end_rewrite + + + +def test_begin_end(): + def wrapper(fn): + @functools.wraps(fn) + def wrapped(*args, **kwargs): + return fn(*args, **kwargs) + return wrapped + + @wrapper + @end_rewrite() + @begin_rewrite() + @wrapper + def foo(): + pass + + assert inspect.getsource(foo) == '''\ +@wrapper +@wrapper +def foo(): + pass +''' diff --git a/tests/test_visitors.py b/tests/test_visitors.py new file mode 100644 index 0000000..9f2d33f --- /dev/null +++ b/tests/test_visitors.py @@ -0,0 +1,39 @@ +""" +Test collecting instances of `ast.Name` +""" +import ast +from ast_tools.visitors import collect_names +from ast_tools.visitors import UsedNames + +def test_collect_names_basic(): + """ + Test collecting names from a simple function including the `ctx` feature + """ + s = ''' +def foo(bar, baz): # pylint: disable=blacklisted-name + buzz = bar + baz + return buzz +''' + + foo_ast = ast.parse(s) + assert collect_names(foo_ast) == {"bar", "baz", "buzz"} + assert collect_names(foo_ast, ctx=ast.Load) == {"bar", "baz", "buzz"} + assert collect_names(foo_ast, ctx=ast.Store) == {"buzz"} + + +def test_used_names(): + tree = ast.parse(''' +x = 1 +def foo(): + def g(): + pass +class A: + def __init__(self): pass + + class B: pass + +async def h(): pass +''') + visitor = UsedNames() + visitor.visit(tree) + assert visitor.names == {'x', 'foo', 'A', 'h'}