Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ast_tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
ast_tools top level package
"""
from .common import *
from . import immutable_ast
from . import passes
from . import stack
from . import visitors
58 changes: 49 additions & 9 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,67 @@
"""
'''
setup script for package
"""
'''

from setuptools import setup
from setuptools.command.build_py import build_py
from setuptools.command.develop import develop
from os import path
import util.generate_ast.generate as generate

with open("README.md", "r") as fh:
PACKAGE_NAME = 'ast_tools'

with open('README.md', "r") as fh:
LONG_DESCRIPTION = fh.read()

class Install(build_py):
def run(self, *args, **kwargs):
self.generated_outputs = []
if not self.dry_run:
src = generate.generate_immutable_ast()
output_dir = path.join(self.build_lib, PACKAGE_NAME)
self.mkpath(output_dir)
output_file = path.join(output_dir, 'immutable_ast.py')
self.announce(f'generating {output_file}', 2)
with open(output_file, 'w') as f:
f.write(src)
self.generated_outputs.append(output_file)
super().run(*args, **kwargs)

def get_outputs(self, *args, **kwargs):
outputs = super().get_outputs(*args, **kwargs)
outputs.extend(self.generated_outputs)
return outputs


class Develop(develop):
def run(self, *args, **kwargs):
if not self.dry_run:
src = generate.generate_immutable_ast()
output_file = path.join(PACKAGE_NAME, 'immutable_ast.py')
self.announce(f'generating {output_file}', 2)
with open(output_file, 'w') as f:
f.write(src)
super().run(*args, **kwargs)

setup(
cmdclass={
'build_py': Install,
'develop': Develop,
},
name='ast_tools',
url='https://github.com/leonardt/ast_tools',
author='Leonard Truong',
author_email='lenny@cs.stanford.edu',
version='0.0.5',
version='0.0.6',
description='Toolbox for working with the Python AST',
scripts=[],
packages=[
"ast_tools",
"ast_tools.visitors",
"ast_tools.transformers",
"ast_tools.passes"
f'{PACKAGE_NAME}',
f'{PACKAGE_NAME}.visitors',
f'{PACKAGE_NAME}.transformers',
f'{PACKAGE_NAME}.passes'
],
install_requires=['astor'],
long_description=LONG_DESCRIPTION,
long_description_content_type="text/markdown"
long_description_content_type='text/markdown'
)
90 changes: 90 additions & 0 deletions tests/test_immutable_ast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import pytest
import ast

import inspect
from ast_tools import immutable_ast
from ast_tools.immutable_ast import ImmutableMeta


trees = []

# inspect is about the largest module I know
# hopefully it has a diverse ast
for mod in (immutable_ast, inspect, ast, pytest):
with open(mod.__file__, 'r') as f:
text = f.read()
tree = ast.parse(text)
trees.append(tree)


@pytest.mark.parametrize("tree", trees)
def test_mutable_to_immutable(tree):
def _test(tree, itree):
if isinstance(tree, ast.AST):
assert isinstance(itree, immutable_ast.AST)
assert isinstance(tree, type(itree))
assert tree._fields == itree._fields
assert ImmutableMeta._mutable_to_immutable[type(tree)] is type(itree)
for field, value in ast.iter_fields(tree):
_test(value, getattr(itree, field))
elif isinstance(tree, list):
assert isinstance(itree, tuple)
assert len(tree) == len(itree)
for c, ic in zip(tree, itree):
_test(c, ic)
else:
assert tree == itree


itree = immutable_ast.immutable(tree)
_test(tree, itree)

@pytest.mark.parametrize("tree", trees)
def test_immutable_to_mutable(tree):
def _test(tree, mtree):
assert type(tree) is type(mtree)
if isinstance(tree, ast.AST):
for field, value in ast.iter_fields(tree):
_test(value, getattr(mtree, field))
elif isinstance(tree, list):
assert len(tree) == len(mtree)
for c, mc in zip(tree, mtree):
_test(c, mc)
else:
assert tree == mtree

itree = immutable_ast.immutable(tree)
mtree = immutable_ast.mutable(itree)
_test(tree, mtree)


@pytest.mark.parametrize("tree", trees)
def test_eq(tree):
itree = immutable_ast.immutable(tree)
jtree = immutable_ast.immutable(tree)
assert itree == jtree
assert hash(itree) == hash(jtree)

def test_mutate():
node = immutable_ast.Name(id='foo', ctx=immutable_ast.Load())
# can add metadata to a node
node.random = 0
del node.random

# but cant change its fields
for field in node._fields:
with pytest.raises(AttributeError):
setattr(node, field, 'bar')

with pytest.raises(AttributeError):
delattr(node, field)


def test_construct_from_mutable():
node = immutable_ast.Module([
ast.Name(id='foo', ctx=ast.Store())
])

assert isinstance(node.body, tuple)
assert type(node.body[0]) is immutable_ast.Name
assert type(node.body[0].ctx) is immutable_ast.Store
4 changes: 4 additions & 0 deletions util/generate_ast/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
if __name__ == '__main__':
import generate
print(generate.generate_immutable_ast())

45 changes: 45 additions & 0 deletions util/generate_ast/_base.px
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
class AST(mutable=ast.AST, metaclass=ImmutableMeta):
def __setattr__(self, attr, value):
if attr in self._fields and hasattr(self, attr):
raise AttributeError('Cannot modify ImmutableAST fields')
elif isinstance(value, (list, ast.AST)):
value = immutable(value)

self.__dict__[attr] = value

def __delattr__(self, attr):
if attr in self._fields:
raise AttributeError('Cannot modify ImmutableAST fields')
del self.__dict__[attr]

def __hash__(self):
try:
return self._hash_
except AttributeError:
pass

h = hash(type(self))
for _, n in iter_fields(self):
if isinstance(n, AST):
h += hash(n)
elif isinstance(n, tp.Sequence):
for c in n:
h += hash(c)
else:
h += hash(n)
self._hash_ = h
return h

def __eq__(self, other):
if not isinstance(other, type(self)):
return NotImplemented
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could this just be False?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this will try the other way (inverse) which may be okay too, this seems reasonable.

elif type(self) == type(other):
for f in self._fields:
if getattr(self, f) != getattr(other, f):
return False
return True
else:
return False

def __ne__(self, other):
return not (self == other)
113 changes: 113 additions & 0 deletions util/generate_ast/_functions.px
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
__ALL__ += ['immutable', 'mutable', 'parse', 'dump',
'iter_fields', 'iter_child_nodes', 'walk',
'NodeVisitor', 'NodeTransformer']


def _cast_tree(seq_t, n_seq_t, type_look_up, tree):
args = seq_t, n_seq_t, type_look_up

if isinstance(tree, seq_t):
return n_seq_t(_cast_tree(*args, c) for c in tree)

try:
T = type_look_up[type(tree)]
except KeyError:
return tree

kwargs = {}
for field, c in iter_fields(tree):
kwargs[field] = _cast_tree(*args, c)

return T(**kwargs)


def immutable(tree: ast.AST) -> 'AST':
'''Converts a mutable ast to an immutable one'''
return _cast_tree(list, tuple, ImmutableMeta._mutable_to_immutable, tree)

def mutable(tree: 'AST') -> ast.AST:
'''Converts an immutable ast to a mutable one'''
return _cast_tree(tuple, list, ImmutableMeta._immutable_to_mutable, tree)

def parse(source, filename='<unknown>', mode='exec') -> 'AST':
tree = ast.parse(source, filename, mode)
return immutable(tree)

def dump(node, annotate_fields=True, include_attributes=False) -> str:
tree = mutable(node)
return ast.dump(tree)


# duck typing ftw
iter_fields = ast.iter_fields

# The following is more or less copied verbatim from
# CPython/Lib/ast.py. Changes are:
# s/list/tuple/
#
# The CPython license is very permissive so I am pretty sure this is cool.
# If it is not Guido please forgive me.
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

😂

def iter_child_nodes(node):
for name, field in iter_fields(node):
if isinstance(field, AST):
yield field
elif isinstance(field, tuple):
for item in field:
if isinstance(item, AST):
yield item

# Same note as above
def walk(node):
from collections import deque
todo = deque([node])
while todo:
node = todo.popleft()
todo.extend(iter_child_nodes(node))
yield node


# Same note as above
class NodeVisitor:
def visit(self, node):
method = 'visit_' + node.__class__.__name__
visitor = getattr(self, method, self.generic_visit)
return visitor(node)

def generic_visit(self, node):
for field, value in iter_fields(node):
if isinstance(value, tuple):
for item in value:
if isinstance(item, AST):
self.visit(item)
elif isinstance(value, AST):
self.visit(value)


# Same note as above
class NodeTransformer(NodeVisitor):
'''
Mostly equivalent to ast.NodeTransformer, except returns new nodes
instead of mutating them in place
'''

def generic_visit(self, node):
kwargs = {}
for field, old_value in iter_fields(node):
if instance(old_value, tuple):
new_value = []
for value in old_value:
if isinstance(value, AST):
value = self.visit(value)
if value is None:
continue
elif not isinstance(item, AST):
new_value.extend(value)
continue
new_value.append(value)
new_value = tuple(new_value)
elif isinstance(type(old_value), ImmutableMeta):
new_value = self.visit(old_value)
else:
new_value = old_value
kwargs[field] = new_value
return type(node)(**kwargs)
19 changes: 19 additions & 0 deletions util/generate_ast/_meta.px
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
__ALL__ += ['ImmutableMeta']

class ImmutableMeta(type):
_immutable_to_mutable = dict()
_mutable_to_immutable = dict()
def __new__(mcs, name, bases, namespace, mutable, **kwargs):
cls = super().__new__(mcs, name, bases, namespace, **kwargs)
ImmutableMeta._immutable_to_mutable[cls] = mutable
ImmutableMeta._mutable_to_immutable[mutable] = cls

return cls

def __instancecheck__(cls, instance):
return super().__instancecheck__(instance)\
or isinstance(instance, ImmutableMeta._immutable_to_mutable[cls])

def __subclasscheck__(cls, type_):
return super().__subclasscheck__(type_)\
or issubclass(type_, ImmutableMeta._immutable_to_mutable[cls])
Loading