diff --git a/ast_tools/__init__.py b/ast_tools/__init__.py index dbb430a..10033e3 100644 --- a/ast_tools/__init__.py +++ b/ast_tools/__init__.py @@ -2,6 +2,7 @@ ast_tools top level package """ from .common import * +from . import immutable_ast from . import passes from . import stack from . import visitors diff --git a/setup.py b/setup.py index b4e20d6..31d11f8 100644 --- a/setup.py +++ b/setup.py @@ -1,27 +1,67 @@ -""" +''' setup script for package -""" +''' from setuptools import setup +from setuptools.command.build_py import build_py +from setuptools.command.develop import develop +from os import path +import util.generate_ast.generate as generate -with open("README.md", "r") as fh: +PACKAGE_NAME = 'ast_tools' + +with open('README.md', "r") as fh: LONG_DESCRIPTION = fh.read() +class Install(build_py): + def run(self, *args, **kwargs): + self.generated_outputs = [] + if not self.dry_run: + src = generate.generate_immutable_ast() + output_dir = path.join(self.build_lib, PACKAGE_NAME) + self.mkpath(output_dir) + output_file = path.join(output_dir, 'immutable_ast.py') + self.announce(f'generating {output_file}', 2) + with open(output_file, 'w') as f: + f.write(src) + self.generated_outputs.append(output_file) + super().run(*args, **kwargs) + + def get_outputs(self, *args, **kwargs): + outputs = super().get_outputs(*args, **kwargs) + outputs.extend(self.generated_outputs) + return outputs + + +class Develop(develop): + def run(self, *args, **kwargs): + if not self.dry_run: + src = generate.generate_immutable_ast() + output_file = path.join(PACKAGE_NAME, 'immutable_ast.py') + self.announce(f'generating {output_file}', 2) + with open(output_file, 'w') as f: + f.write(src) + super().run(*args, **kwargs) + setup( + cmdclass={ + 'build_py': Install, + 'develop': Develop, + }, name='ast_tools', url='https://github.com/leonardt/ast_tools', author='Leonard Truong', author_email='lenny@cs.stanford.edu', - version='0.0.5', + version='0.0.6', description='Toolbox for working with the Python AST', scripts=[], packages=[ - "ast_tools", - "ast_tools.visitors", - "ast_tools.transformers", - "ast_tools.passes" + f'{PACKAGE_NAME}', + f'{PACKAGE_NAME}.visitors', + f'{PACKAGE_NAME}.transformers', + f'{PACKAGE_NAME}.passes' ], install_requires=['astor'], long_description=LONG_DESCRIPTION, - long_description_content_type="text/markdown" + long_description_content_type='text/markdown' ) diff --git a/tests/test_immutable_ast.py b/tests/test_immutable_ast.py new file mode 100644 index 0000000..0d97fbe --- /dev/null +++ b/tests/test_immutable_ast.py @@ -0,0 +1,90 @@ +import pytest +import ast + +import inspect +from ast_tools import immutable_ast +from ast_tools.immutable_ast import ImmutableMeta + + +trees = [] + +# inspect is about the largest module I know +# hopefully it has a diverse ast +for mod in (immutable_ast, inspect, ast, pytest): + with open(mod.__file__, 'r') as f: + text = f.read() + tree = ast.parse(text) + trees.append(tree) + + +@pytest.mark.parametrize("tree", trees) +def test_mutable_to_immutable(tree): + def _test(tree, itree): + if isinstance(tree, ast.AST): + assert isinstance(itree, immutable_ast.AST) + assert isinstance(tree, type(itree)) + assert tree._fields == itree._fields + assert ImmutableMeta._mutable_to_immutable[type(tree)] is type(itree) + for field, value in ast.iter_fields(tree): + _test(value, getattr(itree, field)) + elif isinstance(tree, list): + assert isinstance(itree, tuple) + assert len(tree) == len(itree) + for c, ic in zip(tree, itree): + _test(c, ic) + else: + assert tree == itree + + + itree = immutable_ast.immutable(tree) + _test(tree, itree) + +@pytest.mark.parametrize("tree", trees) +def test_immutable_to_mutable(tree): + def _test(tree, mtree): + assert type(tree) is type(mtree) + if isinstance(tree, ast.AST): + for field, value in ast.iter_fields(tree): + _test(value, getattr(mtree, field)) + elif isinstance(tree, list): + assert len(tree) == len(mtree) + for c, mc in zip(tree, mtree): + _test(c, mc) + else: + assert tree == mtree + + itree = immutable_ast.immutable(tree) + mtree = immutable_ast.mutable(itree) + _test(tree, mtree) + + +@pytest.mark.parametrize("tree", trees) +def test_eq(tree): + itree = immutable_ast.immutable(tree) + jtree = immutable_ast.immutable(tree) + assert itree == jtree + assert hash(itree) == hash(jtree) + +def test_mutate(): + node = immutable_ast.Name(id='foo', ctx=immutable_ast.Load()) + # can add metadata to a node + node.random = 0 + del node.random + + # but cant change its fields + for field in node._fields: + with pytest.raises(AttributeError): + setattr(node, field, 'bar') + + with pytest.raises(AttributeError): + delattr(node, field) + + +def test_construct_from_mutable(): + node = immutable_ast.Module([ + ast.Name(id='foo', ctx=ast.Store()) + ]) + + assert isinstance(node.body, tuple) + assert type(node.body[0]) is immutable_ast.Name + assert type(node.body[0].ctx) is immutable_ast.Store diff --git a/util/generate_ast/__init__.py b/util/generate_ast/__init__.py new file mode 100644 index 0000000..8333482 --- /dev/null +++ b/util/generate_ast/__init__.py @@ -0,0 +1,4 @@ +if __name__ == '__main__': + import generate + print(generate.generate_immutable_ast()) + diff --git a/util/generate_ast/_base.px b/util/generate_ast/_base.px new file mode 100644 index 0000000..dcd8247 --- /dev/null +++ b/util/generate_ast/_base.px @@ -0,0 +1,45 @@ +class AST(mutable=ast.AST, metaclass=ImmutableMeta): + def __setattr__(self, attr, value): + if attr in self._fields and hasattr(self, attr): + raise AttributeError('Cannot modify ImmutableAST fields') + elif isinstance(value, (list, ast.AST)): + value = immutable(value) + + self.__dict__[attr] = value + + def __delattr__(self, attr): + if attr in self._fields: + raise AttributeError('Cannot modify ImmutableAST fields') + del self.__dict__[attr] + + def __hash__(self): + try: + return self._hash_ + except AttributeError: + pass + + h = hash(type(self)) + for _, n in iter_fields(self): + if isinstance(n, AST): + h += hash(n) + elif isinstance(n, tp.Sequence): + for c in n: + h += hash(c) + else: + h += hash(n) + self._hash_ = h + return h + + def __eq__(self, other): + if not isinstance(other, type(self)): + return NotImplemented + elif type(self) == type(other): + for f in self._fields: + if getattr(self, f) != getattr(other, f): + return False + return True + else: + return False + + def __ne__(self, other): + return not (self == other) diff --git a/util/generate_ast/_functions.px b/util/generate_ast/_functions.px new file mode 100644 index 0000000..1c52de1 --- /dev/null +++ b/util/generate_ast/_functions.px @@ -0,0 +1,113 @@ +__ALL__ += ['immutable', 'mutable', 'parse', 'dump', + 'iter_fields', 'iter_child_nodes', 'walk', + 'NodeVisitor', 'NodeTransformer'] + + +def _cast_tree(seq_t, n_seq_t, type_look_up, tree): + args = seq_t, n_seq_t, type_look_up + + if isinstance(tree, seq_t): + return n_seq_t(_cast_tree(*args, c) for c in tree) + + try: + T = type_look_up[type(tree)] + except KeyError: + return tree + + kwargs = {} + for field, c in iter_fields(tree): + kwargs[field] = _cast_tree(*args, c) + + return T(**kwargs) + + +def immutable(tree: ast.AST) -> 'AST': + '''Converts a mutable ast to an immutable one''' + return _cast_tree(list, tuple, ImmutableMeta._mutable_to_immutable, tree) + +def mutable(tree: 'AST') -> ast.AST: + '''Converts an immutable ast to a mutable one''' + return _cast_tree(tuple, list, ImmutableMeta._immutable_to_mutable, tree) + +def parse(source, filename='', mode='exec') -> 'AST': + tree = ast.parse(source, filename, mode) + return immutable(tree) + +def dump(node, annotate_fields=True, include_attributes=False) -> str: + tree = mutable(node) + return ast.dump(tree) + + +# duck typing ftw +iter_fields = ast.iter_fields + +# The following is more or less copied verbatim from +# CPython/Lib/ast.py. Changes are: +# s/list/tuple/ +# +# The CPython license is very permissive so I am pretty sure this is cool. +# If it is not Guido please forgive me. +def iter_child_nodes(node): + for name, field in iter_fields(node): + if isinstance(field, AST): + yield field + elif isinstance(field, tuple): + for item in field: + if isinstance(item, AST): + yield item + +# Same note as above +def walk(node): + from collections import deque + todo = deque([node]) + while todo: + node = todo.popleft() + todo.extend(iter_child_nodes(node)) + yield node + + +# Same note as above +class NodeVisitor: + def visit(self, node): + method = 'visit_' + node.__class__.__name__ + visitor = getattr(self, method, self.generic_visit) + return visitor(node) + + def generic_visit(self, node): + for field, value in iter_fields(node): + if isinstance(value, tuple): + for item in value: + if isinstance(item, AST): + self.visit(item) + elif isinstance(value, AST): + self.visit(value) + + +# Same note as above +class NodeTransformer(NodeVisitor): + ''' + Mostly equivalent to ast.NodeTransformer, except returns new nodes + instead of mutating them in place + ''' + + def generic_visit(self, node): + kwargs = {} + for field, old_value in iter_fields(node): + if instance(old_value, tuple): + new_value = [] + for value in old_value: + if isinstance(value, AST): + value = self.visit(value) + if value is None: + continue + elif not isinstance(item, AST): + new_value.extend(value) + continue + new_value.append(value) + new_value = tuple(new_value) + elif isinstance(type(old_value), ImmutableMeta): + new_value = self.visit(old_value) + else: + new_value = old_value + kwargs[field] = new_value + return type(node)(**kwargs) diff --git a/util/generate_ast/_meta.px b/util/generate_ast/_meta.px new file mode 100644 index 0000000..5009fd4 --- /dev/null +++ b/util/generate_ast/_meta.px @@ -0,0 +1,19 @@ +__ALL__ += ['ImmutableMeta'] + +class ImmutableMeta(type): + _immutable_to_mutable = dict() + _mutable_to_immutable = dict() + def __new__(mcs, name, bases, namespace, mutable, **kwargs): + cls = super().__new__(mcs, name, bases, namespace, **kwargs) + ImmutableMeta._immutable_to_mutable[cls] = mutable + ImmutableMeta._mutable_to_immutable[mutable] = cls + + return cls + + def __instancecheck__(cls, instance): + return super().__instancecheck__(instance)\ + or isinstance(instance, ImmutableMeta._immutable_to_mutable[cls]) + + def __subclasscheck__(cls, type_): + return super().__subclasscheck__(type_)\ + or issubclass(type_, ImmutableMeta._immutable_to_mutable[cls]) diff --git a/util/generate_ast/generate.py b/util/generate_ast/generate.py new file mode 100644 index 0000000..4244aad --- /dev/null +++ b/util/generate_ast/generate.py @@ -0,0 +1,132 @@ +import ast +import datetime +import inspect +from os import path +import sys + +_BASE_PATH = path.dirname(__file__) +def _make_path(f): + return path.abspath(path.join(_BASE_PATH, f)) + +META_FILE = _make_path('_meta.px') +FUNCTIONS_FILE = _make_path('_functions.px') +AST_BASE_FILE = _make_path('_base.px') +TAB = ' '*4 + +def generate_class(name, bases, fields): + bases=', '.join(bases) + (', ' if bases else '') + sig = (', ' if fields else '') + ', '.join(fields) + body = [f'self.{arg} = {arg}' for arg in fields] + if not body: + body.append('pass') + + body = f'\n{TAB}{TAB}'.join(body) + + class_ = f'''\ +class {name}({bases}mutable=ast.{name}): +{TAB}_fields={fields} +{TAB}def __init__(self{sig}): +{TAB}{TAB}{body} +''' + + return class_ + +def generate_classes(class_tree, ALL): + cls_to_args = {ast.AST : ('AST', (), ())} + + def pop_args_from_tree(tree): + for item in tree: + if isinstance(item, list): + r = pop_args_from_tree(item) + if r is not None: + return r + elif item[0] not in cls_to_args: + cls = item[0] + bases = tuple(cls_to_args[base][0] for base in item[1] if base is not object) + cls_to_args[cls] = r = cls.__name__, bases, cls._fields + return r + + classes_ = [] + + args = pop_args_from_tree(class_tree[1]) + while args is not None: + class_ = generate_class(*args) + classes_.append(class_) + ALL.append(args[0]) + args = pop_args_from_tree(class_tree[1]) + + return '\n'.join(classes_) + + +def generate_immutable_ast(): + def _issubclass(t, types): + try: + return issubclass(t, types) + except TypeError: + pass + return False + + classes = [] + for name in dir(ast): + obj = getattr(ast, name) + if _issubclass(obj, ast.AST): + classes.append(obj) + + class_tree = inspect.getclasstree(classes) + # assert the class tree is a tree and not a dag + assert class_tree == inspect.getclasstree(classes, unique=True) + # assert the class tree has a root + assert len(class_tree) == 2, class_tree[0] + # assert the root is object + assert class_tree[0][0] is object, class_tree[0][0] + # assert the root has only 1 child + assert len(class_tree[1]) == 2, class_tree[1] + # assert that the child is ast.AST + assert class_tree[1][0][0] is ast.AST, class_tree[1][0] + + nl = '\n' + head_comment = f'''\ +# file generated by {__file__} on {datetime.datetime.now()} +# for python {sys.version.split(nl)[0].strip()}''' + + version_check = f'''\ +if sys.version_info[:2] != {sys.version_info[:2]}: +{TAB}warnings.warn(f"{{__file__}} generated for {sys.version_info[:2]}" +{TAB} f"does not match system version {{sys.version_info[:2]}}")''' + + + + ALL = ['AST'] + + with open(FUNCTIONS_FILE, 'r') as f: + functions = f.read() + + with open(META_FILE, 'r') as f: + meta = f.read() + + with open(AST_BASE_FILE, 'r') as f: + ast_base = f.read() + + classes = generate_classes(class_tree, ALL) + + immutable_ast = f'''\ +{head_comment} + +import ast +import sys +import typing as tp +import warnings + +{version_check} + +__ALL__ = {ALL} + +{functions} + +{meta} + +{ast_base} + +{classes} +''' + return immutable_ast