Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 3 additions & 16 deletions ast_tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,7 @@
"""
ast_tools top level package
"""
import inspect
import textwrap
import ast

__all__ = ["NameCollector", "collect_names", "get_ast"]

from .collect_names import NameCollector, collect_names
from .common import *
from . import passes
from . import stack


def get_ast(obj):
"""
Given an object, get the corresponding AST
"""
indented_program_txt = inspect.getsource(obj)
program_txt = textwrap.dedent(indented_program_txt)
return ast.parse(program_txt)
from . import visitors
100 changes: 100 additions & 0 deletions ast_tools/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import ast
import functools
import inspect
import logging
import os
import textwrap
import types
import typing as tp
import weakref

import astor

from ast_tools import stack
from ast_tools.stack import SymbolTable
from ast_tools.visitors import UsedNames

__ALL__ = ['exec_in_file', 'exec_def_in_file', 'get_ast', 'gen_free_name']

DefStmt = tp.Union[ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef]

def exec_def_in_file(
tree: DefStmt,
st: SymbolTable,
path: tp.Optional[str] = None,
file_name: tp.Optional[str] = None) -> None:
"""
execs a definition in a file and returns the definiton
"""
if file_name is None:
file_name = tree.name + '.py'

exec_in_file(tree, st, path, file_name)
return st.locals[tree.name]


def exec_in_file(
tree: ast.AST,
st: SymbolTable,
path: tp.Optional[str] = None,
file_name: tp.Optional[str] = None) -> None:

"""
execs a definition in a file
"""
if path is None:
path = '.ast_tools'
if file_name is None:
file_name = f'ast_tools_exec_{hash(tree)}.py'

source = astor.to_source(tree)
file_name = os.path.join(path, file_name)
os.makedirs(path, exist_ok=True)
with open(file_name, 'w') as fp:
fp.write(source)

try:
code = compile(source, filename=file_name, mode='exec')
except Exception as e:
logging.exception("Error compiling source")
raise e from None

try:
exec(code, st.globals, st.locals)
except Exception as e:
logging.exception("Error executing code")
raise e from None


_AST_CACHE = weakref.WeakKeyDictionary()
def get_ast(obj) -> ast.AST:
"""
Given an object, get the corresponding AST
"""
try:
_AST_CACHE[obj]
except KeyError:
pass

src = textwrap.dedent(inspect.getsource(obj))

if isinstance(obj, types.ModuleType):
tree = ast.parse(src)
else:
tree = ast.parse(src).body[0]

return _AST_CACHE.setdefault(obj, tree)


def gen_free_name(tree: ast.AST, env: SymbolTable, prefix: str = '__auto_name_'):
visitor = UsedNames()
visitor.visit(tree)
used_names = visitor.names | env.locals.keys() | env.globals.keys()
f_str = prefix+'{}'
c = 0
name = f_str.format(c)
while name in used_names:
c += 1
name = f_str.format(c)

return name
92 changes: 92 additions & 0 deletions ast_tools/passes/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from abc import ABCMeta, abstractmethod
import ast
import astor
import functools
import typing as tp

from ast_tools.stack import get_symbol_table, SymbolTable
from ast_tools.common import get_ast, exec_def_in_file

__ALL__ = ['Pass', 'begin_rewrite', 'end_rewite']

_PASS_ARGS_T = tp.Tuple[ast.AST, SymbolTable]
class Pass(metaclass=ABCMeta):
"""
Abstract base class for passes
Mostly a convience to unpack arguments
"""

def __call__(self, args: _PASS_ARGS_T) -> _PASS_ARGS_T:
return self.rewrite(*args)

@abstractmethod
def rewrite(self,
env: SymbolTable,
tree: ast.AST,
) -> _PASS_ARGS_T:

"""
Type annotation here should be followed except on terminal passes e.g.
end_rewite
"""
pass

class begin_rewrite:
"""
begins a chain of passes
"""
def __init__(self):
env = get_symbol_table([self.__init__])
self.env = env

def __call__(self, fn) -> _PASS_ARGS_T:
tree = get_ast(fn)
return tree, self.env

def _issubclass(t, s) -> bool:
try:
return issubclass(t, s)
except TypeError:
return False


class end_rewrite(Pass):
"""
ends a chain of passes
"""
def __init__(self, path: tp.Optional[str] = None):
self.path = path

def rewrite(self,
tree: ast.AST,
env: SymbolTable) -> tp.Union[tp.Callable, type]:
decorators = []
first_group = True
in_group = False
# filter passes from the decorator list
for node in reversed(tree.decorator_list):
if not first_group:
decorators.append(node)
continue

if isinstance(node, ast.Call):
expr = ast.Expression(node.func)
else:
assert isinstance(node, ast.Name)
expr = ast.Expression(node)

code = compile(expr, filename="<string>", mode="eval")
deco = eval(code, env.globals, env.locals)
if in_group:
if _issubclass(deco, end_rewrite):
assert in_group
in_group = False
first_group = False
elif _issubclass(deco, begin_rewrite):
assert not in_group
in_group = True
else:
decorators.append(node)

tree.decorator_list = reversed(decorators)
return exec_def_in_file(tree, env, self.path)
57 changes: 57 additions & 0 deletions ast_tools/passes/debug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import ast
import typing as tp
import warnings

import astor

from . import Pass
from . import _PASS_ARGS_T
from ast_tools.stack import SymbolTable

class debug(Pass):
def __init__(self,
dump_ast: bool = True,
dump_src: bool = False,
dump_env: bool = False,
file: tp.Optional[str] = None,
append: tp.Optional[bool] = None,
) -> _PASS_ARGS_T:
self.dump_ast = dump_ast
self.dump_src = dump_src
self.dump_env = dump_env
if append is not None and file is None:
warnings.warn('Option append has no effect when file is None', stacklevel=2)
self.file = file
self.append = append

def rewrite(self,
tree: ast.AST,
env: SymbolTable,
) -> _PASS_ARGS_T:

def _do_dumps(dumps, dump_writer):
for dump in dumps:
dump_writer(f'BEGIN {dump[0]}\n')
dump_writer(dump[1].strip())
dump_writer(f'\nEND {dump[0]}\n\n')

dumps = []
if self.dump_ast:
dumps.append(('AST', astor.dump_tree(tree)))
if self.dump_src:
dumps.append(('SRC', astor.to_source(tree)))
if self.dump_env:
dumps.append(('ENV', repr(env)))

if self.file is not None:
if self.append:
mode = 'wa'
else:
mode = 'w'
with open(self.dump_file, mode) as fp:
_do_dumps(dumps, fp.write)
else:
def _print(*args, **kwargs): print(*args, end='', **kwargs)
_do_dumps(dumps, _print)

return tree, env
4 changes: 2 additions & 2 deletions ast_tools/stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
_SKIP_FRAME_DEBUG_FAIL = False

class SymbolTable(tp.NamedTuple):
locals: tp.Mapping[str, tp.Any]
locals: tp.MutableMapping[str, tp.Any]
globals: tp.Dict[str, tp.Any]

def get_symbol_table(
Expand Down Expand Up @@ -48,7 +48,7 @@ def get_symbol_table(
logging.debug(f'{frame.function} @ {frame.filename}:{frame.lineno} might be leaking names')
locals = locals.new_child(stack[i].frame.f_locals)
globals = globals.new_child(stack[i].frame.f_globals)
return SymbolTable(locals=locals, globals=globals)
return SymbolTable(locals=locals, globals=dict(globals))

def inspect_symbol_table(
fn: tp.Callable, # tp.Callable[[SymbolTable, ...], tp.Any],
Expand Down
3 changes: 3 additions & 0 deletions ast_tools/visitors/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .used_names import *
from .collect_names import *

File renamed without changes.
17 changes: 17 additions & 0 deletions ast_tools/visitors/used_names.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import ast

class UsedNames(ast.NodeVisitor):
def __init__(self):
self.names = set()

def visit_Name(self, node: ast.Name):
self.names.add(node.id)

def visit_FunctionDef(self, node: ast.FunctionDef):
self.names.add(node.name)

def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef):
self.names.add(node.name)

def visit_ClassDef(self, node: ast.ClassDef):
self.names.add(node.name)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
description='Toolbox for working with the Python AST',
scripts=[],
packages=[],
install_requires=[],
install_requires=['astor'],
long_description=LONG_DESCRIPTION,
long_description_content_type="text/markdown"
)
19 changes: 0 additions & 19 deletions tests/test_collect_names.py

This file was deleted.

33 changes: 33 additions & 0 deletions tests/test_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import ast
import astor

from ast_tools.common import get_ast, gen_free_name, SymbolTable


def test_get_ast():
def f(): pass
f_str = 'def f(): pass'
ast_str_0 = astor.dump_tree(get_ast(f))
ast_str_1 = astor.dump_tree(ast.parse(f_str).body[0])
assert ast_str_0 == ast_str_1


def test_gen_free():
src = '''
class P0:
P5 = 1
def __init__(self): self.y = 0
def P1():
return P0.P5
P2 = P1()
'''
tree = ast.parse(src)
env = SymbolTable({}, {})
free_name = gen_free_name(tree, env, prefix='P')
assert free_name == 'P3'
env = SymbolTable({'P4': 'foo'}, {})
free_name = gen_free_name(tree, env, prefix='P')
assert free_name == 'P3'
env = SymbolTable({'P4': 'foo'}, {'P3' : 'bar'})
free_name = gen_free_name(tree, env, prefix='P')
assert free_name == 'P5'
Loading