Skip to content

Commit

Permalink
Refactored for spring cleaning
Browse files Browse the repository at this point in the history
  • Loading branch information
pmaupin committed May 19, 2015
1 parent bf78092 commit c479d88
Show file tree
Hide file tree
Showing 7 changed files with 238 additions and 48 deletions.
37 changes: 29 additions & 8 deletions astor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,32 @@
"""

__version__ = '0.5'

from .codegen import to_source # NOQA
from .misc import iter_node, dump, all_symbols, get_anyop # NOQA
from .misc import get_boolop, get_binop, get_cmpop, get_unaryop # NOQA
from .misc import ExplicitNodeVisitor # NOQA
from .misc import parsefile, CodeToAst, codetoast # NOQA
from .treewalk import TreeWalk # NOQA
__version__ = '0.6'

from .code_gen import to_source # NOQA
from .node_util import iter_node, strip_tree, dump_tree
from .node_util import ExplicitNodeVisitor
from .file_util import CodeToAst, code_to_ast # NOQA
from .op_util import get_op_symbol, get_op_precedence # NOQA
from .op_util import symbol_data
from .tree_walk import TreeWalk # NOQA


#DEPRECATED!!!

# These aliases support old programs. Please do not use in future.

# NOTE: We should think hard about what we want to export,
# and not just dump everything here. Some things
# will never be used by other packages, and other
# things could be accessed from their submodule.


get_boolop = get_binop = get_cmpop = get_unaryop = get_op_symbol # NOQA
get_anyop = get_op_symbol
parsefile = code_to_ast.parse_file
codetoast = code_to_ast
dump = dump_tree
all_symbols = symbol_data
treewalk = tree_walk
codegen = code_gen
28 changes: 17 additions & 11 deletions astor/code_gen.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
# -*- coding: utf-8 -*-
"""
This module converts an AST into Python source code.
Part of the astor library for Python AST manipulation.
License: 3-clause BSD
Original code copyright (c) 2008 by Armin Ronacher and
is distributed under the 3-clause BSD license.
Copyright (c) 2008 Armin Ronacher
Copyright (c) 2012-2015 Patrick Maupin
Copyright (c) 2013-2015 Berker Peksag
This module converts an AST into Python source code.
It was derived from a modified version found here:
Before being version-controlled as part of astor,
this code came from here (in 2012):
https://gist.github.com/1250562
Expand All @@ -14,8 +20,8 @@
import ast
import sys

from .misc import (ExplicitNodeVisitor, get_boolop, get_binop, get_cmpop,
get_unaryop)
from .op_util import get_op_symbol
from .node_util import ExplicitNodeVisitor


def to_source(node, indent_with=' ' * 4, add_line_information=False):
Expand Down Expand Up @@ -160,7 +166,7 @@ def visit_Assign(self, node):
self.visit(node.value)

def visit_AugAssign(self, node):
self.statement(node, node.target, get_binop(node.op, ' %s= '),
self.statement(node, node.target, get_op_symbol(node.op, ' %s= '),
node.value)

def visit_ImportFrom(self, node):
Expand Down Expand Up @@ -439,23 +445,23 @@ def visit_Dict(self, node):

@enclose('()')
def visit_BinOp(self, node):
self.write(node.left, get_binop(node.op, ' %s '), node.right)
self.write(node.left, get_op_symbol(node.op, ' %s '), node.right)

@enclose('()')
def visit_BoolOp(self, node):
op = get_boolop(node.op, ' %s ')
op = get_op_symbol(node.op, ' %s ')
for idx, value in enumerate(node.values):
self.write(idx and op or '', value)

@enclose('()')
def visit_Compare(self, node):
self.visit(node.left)
for op, right in zip(node.ops, node.comparators):
self.write(get_cmpop(op, ' %s '), right)
self.write(get_op_symbol(op, ' %s '), right)

@enclose('()')
def visit_UnaryOp(self, node):
self.write(get_unaryop(node.op), ' ', node.operand)
self.write(get_op_symbol(node.op), ' ', node.operand)

def visit_Subscript(self, node):
self.write(node.value, '[', node.slice, ']')
Expand Down
82 changes: 62 additions & 20 deletions astor/file_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,16 @@
License: 3-clause BSD
Copyright 2012 (c) Patrick Maupin
Copyright 2013 (c) Berker Peksag
Copyright 2012-2015 (c) Patrick Maupin
Copyright 2013-2015 (c) Berker Peksag
Functions that interact with the filesystem go here.
"""

import ast
import sys


def parsefile(fname):
with open(fname, 'r') as f:
fstr = f.read()
fstr = fstr.replace('\r\n', '\n').replace('\r', '\n')
if not fstr.endswith('\n'):
fstr += '\n'
return ast.parse(fstr, filename=fname)
import os


class CodeToAst(object):
Expand All @@ -28,28 +22,76 @@ class CodeToAst(object):
the sub-AST for the function. Allow caching to reduce
number of compiles.
Also contains static helper utility functions to
look for python files, to parse python files, and to extract
the file/line information from a code object.
"""
def __init__(self, cache=None):
self.cache = cache or {}

def __call__(self, codeobj):
cache = self.cache
@staticmethod
def find_py_files(srctree, ignore=None):
"""Return all the python files in a source tree
Ignores any path that contains the ignore string
This is not used by other class methods, but is
designed to be used in code that uses this class.
"""

for srcpath, _, fnames in os.walk(srctree):
# Avoid infinite recursion for silly users
if ignore is not None and ignore in srcpath:
continue
for fname in (x for x in fnames if x.endswith('.py')):
yield srcpath, fname

@staticmethod
def parse_file(fname):
"""Parse a python file into an AST.
This is a very thin wrapper around ast.parse
TODO: Handle encodings other than the default (issue #26)
"""
with open(fname, 'r') as f:
fstr = f.read()
fstr = fstr.replace('\r\n', '\n').replace('\r', '\n')
if not fstr.endswith('\n'):
fstr += '\n'
return ast.parse(fstr, filename=fname)


@staticmethod
def get_file_info(codeobj):
"""Returns the file and line number of a code object.
If the code object has a __file__ attribute (e.g. if
it is a module), then the returned line number will
be 0
"""
fname = getattr(codeobj, '__file__', None)
linenum = 0
if fname is None:
func_code = codeobj.__code__
fname = func_code.co_filename
linenum = func_code.co_firstlineno
key = fname, linenum
else:
fname = key = fname.replace('.pyc', '.py')
fname = fname.replace('.pyc', '.py')
return fname, linenum

def __init__(self, cache=None):
self.cache = cache or {}

def __call__(self, codeobj):
cache = self.cache
key = self.get_file_info(codeobj)
result = cache.get(key)
if result is not None:
return result
cache[fname] = mod_ast = parsefile(fname)
fname = key[0]
cache[(fname, 0)] = mod_ast = self.parse_file(fname)
for obj in mod_ast.body:
if not isinstance(obj, ast.FunctionDef):
continue
cache[(fname, obj.lineno)] = obj
return cache[key]

codetoast = CodeToAst()
code_to_ast = CodeToAst()
74 changes: 67 additions & 7 deletions astor/node_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,26 @@
License: 3-clause BSD
Copyright 2012 (c) Patrick Maupin
Copyright 2013 (c) Berker Peksag
Copyright 2012-2015 (c) Patrick Maupin
Copyright 2013-2015 (c) Berker Peksag
Utilities for node (and, by extension, tree) manipulation.
For a whole-tree approach, see the treewalk submodule.
"""

import ast
import sys


class NonExistent(object):
"""This is not the class you are looking for.
"""
pass


def iter_node(node, name='', list=list, getattr=getattr, isinstance=isinstance,
def iter_node(node, name='',
# Runtime optimization
unknown=None, list=list, getattr=getattr, isinstance=isinstance,
enumerate=enumerate, missing=NonExistent):
"""Iterates over an object:
Expand All @@ -30,20 +36,26 @@ def iter_node(node, name='', list=list, getattr=getattr, isinstance=isinstance,
in the list, where the name is passed into
this function (defaults to blank).
- Can update a list with information about
attributes that do not exist in fields.
"""
fields = getattr(node, '_fields', None)
if fields is not None:
for name in fields:
for name in list(fields):
value = getattr(node, name, missing)
if value is not missing:
yield value, name
if unknown is not None:
unknown.update(set(vars(node)) - set(fields))
elif isinstance(node, list):
for value in node:
yield value, name


def dump(node, name=None, initial_indent='', indentation=' ',
maxline=120, maxmerged=80, iter_node=iter_node, special=ast.AST,
def dump_tree(node, name=None, initial_indent='', indentation=' ',
maxline=120, maxmerged=80,
#Runtime optimization
iter_node=iter_node, special=ast.AST,
list=list, isinstance=isinstance, type=type, len=len):
"""Dumps an AST or similar structure:
Expand Down Expand Up @@ -74,3 +86,51 @@ def dump(node, name=None, indent=''):
return dump(node, name, initial_indent)


def strip_tree(node,
#Runtime optimization
iter_node=iter_node, special=ast.AST,
list=list, isinstance=isinstance, type=type, len=len):
"""Strips an AST by removing all attributes not in _fields.
Returns a set of the names of all attributes stripped.
This canonicalizes two trees for comparison purposes.
"""
stripped = set()
def strip(node, indent):
unknown = set()
leaf = True
for subnode, _ in iter_node(node, unknown=unknown):
leaf = False
strip(subnode, indent + ' ')
if leaf:
if isinstance(node, special):
unknown = set(vars(node))
stripped.update(unknown)
for name in unknown:
delattr(node, name)
if hasattr(node, 'ctx'):
delattr(node, 'ctx')
if 'ctx' in node._fields:
mylist = list(node._fields)
mylist.remove('ctx')
node._fields = mylist
strip(node, '')
return stripped


class ExplicitNodeVisitor(ast.NodeVisitor):
"""This expands on the ast module's NodeVisitor class
to remove any implicit visits.
"""

def abort_visit(node): # XXX: self?
msg = 'No defined handler for node of type %s'
raise AttributeError(msg % node.__class__.__name__)

def visit(self, node, abort=abort_visit):
"""Visit a node."""
method = 'visit_' + node.__class__.__name__
visitor = getattr(self, method, abort)
return visitor(node)
Loading

0 comments on commit c479d88

Please sign in to comment.