From e8c06c787e2352a8d9bd7c347116d6c12cac91e4 Mon Sep 17 00:00:00 2001 From: Caleb Donovick Date: Fri, 30 Aug 2019 10:59:20 -0700 Subject: [PATCH 1/7] Add immutable ast --- ast_tools/__init__.py | 1 + ast_tools/_immutable_ast.py | 239 +++++++++++ ast_tools/immutable_ast.py | 785 ++++++++++++++++++++++++++++++++++++ 3 files changed, 1025 insertions(+) create mode 100644 ast_tools/_immutable_ast.py create mode 100644 ast_tools/immutable_ast.py diff --git a/ast_tools/__init__.py b/ast_tools/__init__.py index dbb430a..10033e3 100644 --- a/ast_tools/__init__.py +++ b/ast_tools/__init__.py @@ -2,6 +2,7 @@ ast_tools top level package """ from .common import * +from . import immutable_ast from . import passes from . import stack from . import visitors diff --git a/ast_tools/_immutable_ast.py b/ast_tools/_immutable_ast.py new file mode 100644 index 0000000..de941f5 --- /dev/null +++ b/ast_tools/_immutable_ast.py @@ -0,0 +1,239 @@ +import functools as ft +import ast + +__ALL__ = ['ImmutableMeta', 'immutable', 'mutable'] + +class ImmutableMeta(type): + _immutable_to_mutable = dict() + _mutable_to_immutable = dict() + def __new__(mcs, name, bases, namespace, mutable, **kwargs): + def __setattr__(self, attr, value): + if attr in self._fields and hasattr(self, attr): + raise AttributeError('Cannot modify ImmutableAST fields') + elif isinstance(value, (list, ast.AST)): + value = immutable(value) + + self.__dict__[attr] = value + + def __delattr__(self, attr): + if attr in self._fields: + raise AttributeError('Cannot modify ImmutableAST fields') + del self.__dict__[attr] + + def __hash__(self): + try: + return self._hash_ + except AttributeError: + pass + + h = hash(type(self)) + for _, n in ast.iter_fields(self): + if isinstance(type(n), ImmutableMeta): + h += hash(n) + elif isinstance(n, tp.Sequence): + for c in n: + h += hash(c) + else: + h += hash(n) + self._hash_ = h + return h + + + def __eq__(self, other): + if not isinstance(other, type(self)): + return NotImplemented + elif type(self) == type(other): + for f in self._fields: + if getattr(self, f) != getattr(other, f): + return False + return True + else: + return False + + + def __ne__(self, other): + return not (self == other) + + namespace['__setattr__'] = __setattr__ + namespace['__delattr__'] = __delattr__ + namespace['__hash__'] = __hash__ + namespace['__eq__'] = __eq__ + namespace['__ne__'] = __ne__ + + cls = super().__new__(mcs, name, bases, namespace, **kwargs) + + ImmutableMeta._immutable_to_mutable[cls] = mutable + ImmutableMeta._mutable_to_immutable[mutable] = cls + + return cls + + def __instancecheck__(cls, instance): + return super().__instancecheck__(instance)\ + or isinstance(instance, ImmutableMeta._immutable_to_mutable[cls]) + + def __subclasscheck__(cls, type_): + return super().__subclasscheck__(type_)\ + or issubclass(type_, ImmutableMeta._immutable_to_mutable[cls]) + + +def _cast_tree(seq_t, n_seq_t, type_look_up, tree): + args = seq_t, n_seq_t, type_look_up + + if isinstance(tree, seq_t): + return n_seq_t(_cast_tree(*args, c) for c in tree) + + try: + T = type_look_up[type(tree)] + except KeyError: + return tree + + kwargs = {} + for field, c in ast.iter_fields(tree): + kwargs[field] = _cast_tree(*args, c) + + return T(**kwargs) + + +def immutable(tree: ast.AST): + '''Converts a mutable ast to an immutable one''' + return _cast_tree(list, tuple, ImmutableMeta._mutable_to_immutable, tree) + +def mutable(tree: 'AST'): + '''Converts an immutable ast to a mutable one''' + return _cast_tree(tuple, list, ImmutableMeta._immutable_to_mutable, tree) + + +# could actually generate the classes and put them in globals +# but that would make text editors suck (no autocomplete etc) +# so Instead generate the actual file +def _generate_immutable_ast(): + import ast + import inspect + import sys + import datetime + + immutable_ast_template = '''\ +{head_comment} + +import ast +import sys +import warnings + +from ._immutable_ast import * + +{version_check} + +__ALL__ = {ALL} + +{classes} +''' + + class_template = '''\ +class {name}({bases}mutable=ast.{name}{meta}): +{tab}_fields={fields} +{tab}def __init__(self{sig}):''' + + builder_template = '{tab}{tab}self.{arg} = {arg}' + + tab = ' '*4 + + pass_string = f'{tab}{tab}pass' + + nl = '\n' + head_comment = f'''\ +# file generated by {__file__} on {datetime.datetime.now()} +# for python {sys.version.split(nl)[0].strip()}''' + + version_check = f'''\ +if sys.version_info[:2] != {sys.version_info[:2]}: +{tab}warnings.warn(f"{{__file__}} generated for {sys.version_info[:2]}" +{tab} f"does not match system version {{sys.version_info[:2]}}")''' + + def _issubclass(t, types): + try: + return issubclass(t, types) + except TypeError: + pass + return False + + _classes = [] + for _name in dir(ast): + _obj = getattr(ast, _name) + if _issubclass(_obj, ast.AST): + _classes.append(_obj) + + _class_tree = inspect.getclasstree(_classes) + assert _class_tree == inspect.getclasstree(_classes, unique=True) + _cls_to_args = {} + + def _build_cls_from_tree(tree): + for item in tree: + if isinstance(item, list): + r = _build_cls_from_tree(item) + if r is not None: + return r + elif item[0] not in _cls_to_args: + cls = item[0] + bases = tuple(_cls_to_args[base][0] for base in item[1] if base is not object) + _cls_to_args[cls] = r = cls.__name__, bases, cls._fields + return r + + + _class_strings = [] + _all = [] + _args = _build_cls_from_tree(_class_tree[1]) + while _args is not None: + name=_args[0] + bases=', '.join(_args[1]) + if bases != '': + bases += ', ' + meta='' + else: + meta=', metaclass=ImmutableMeta' + + + fields = _args[2] + + if fields: + sig = ', ' + else: + sig = '' + sig += ', '.join(fields) + + class_ = [class_template.format( + tab=tab, + name=name, + bases=bases, + meta=meta, + fields=fields, + sig=sig, + )] + + if fields: + for arg in fields: + class_.append(builder_template.format( + tab=tab, + arg=arg, + )) + else: + class_.append(pass_string) + + class_.append('\n') + _class_strings.append('\n'.join(class_)) + _all.append(name) + _args = _build_cls_from_tree(_class_tree[1]) + + + return immutable_ast_template.format( + head_comment=head_comment, + version_check=version_check, + ALL=_all, + classes = '\n'.join(_class_strings) + ) + + +if __name__ == '__main__': + s = _generate_immutable_ast() + with open('immutable_ast.py', 'w') as f: + f.write(s) + diff --git a/ast_tools/immutable_ast.py b/ast_tools/immutable_ast.py new file mode 100644 index 0000000..0753abb --- /dev/null +++ b/ast_tools/immutable_ast.py @@ -0,0 +1,785 @@ +# file generated by ast_tools/_immutable_ast.py on 2019-08-30 12:22:11.404048 +# for python 3.7.3 (default, Apr 3 2019, 05:39:12) + +import ast +import sys +import warnings + +from ._immutable_ast import * + +if sys.version_info[:2] != (3, 7): + warnings.warn(f"{__file__} generated for (3, 7)" + f"does not match system version {sys.version_info[:2]}") + +__ALL__ = ['AST', 'alias', 'arg', 'arguments', 'boolop', 'And', 'Or', 'cmpop', 'Eq', 'Gt', 'GtE', 'In', 'Is', 'IsNot', 'Lt', 'LtE', 'NotEq', 'NotIn', 'comprehension', 'excepthandler', 'ExceptHandler', 'expr', 'Attribute', 'Await', 'BinOp', 'BoolOp', 'Bytes', 'Call', 'Compare', 'Constant', 'Dict', 'DictComp', 'Ellipsis', 'FormattedValue', 'GeneratorExp', 'IfExp', 'JoinedStr', 'Lambda', 'List', 'ListComp', 'Name', 'NameConstant', 'Num', 'Set', 'SetComp', 'Starred', 'Str', 'Subscript', 'Tuple', 'UnaryOp', 'Yield', 'YieldFrom', 'expr_context', 'AugLoad', 'AugStore', 'Del', 'Load', 'Param', 'Store', 'keyword', 'mod', 'Expression', 'Interactive', 'Module', 'Suite', 'operator', 'Add', 'BitAnd', 'BitOr', 'BitXor', 'Div', 'FloorDiv', 'LShift', 'MatMult', 'Mod', 'Mult', 'Pow', 'RShift', 'Sub', 'slice', 'ExtSlice', 'Index', 'Slice', 'stmt', 'AnnAssign', 'Assert', 'Assign', 'AsyncFor', 'AsyncFunctionDef', 'AsyncWith', 'AugAssign', 'Break', 'ClassDef', 'Continue', 'Delete', 'Expr', 'For', 'FunctionDef', 'Global', 'If', 'Import', 'ImportFrom', 'Nonlocal', 'Pass', 'Raise', 'Return', 'Try', 'While', 'With', 'unaryop', 'Invert', 'Not', 'UAdd', 'USub', 'withitem'] + +class AST(mutable=ast.AST, metaclass=ImmutableMeta): + _fields=() + def __init__(self): + pass + + +class alias(AST, mutable=ast.alias): + _fields=('name', 'asname') + def __init__(self, name, asname): + self.name = name + self.asname = asname + + +class arg(AST, mutable=ast.arg): + _fields=('arg', 'annotation') + def __init__(self, arg, annotation): + self.arg = arg + self.annotation = annotation + + +class arguments(AST, mutable=ast.arguments): + _fields=('args', 'vararg', 'kwonlyargs', 'kw_defaults', 'kwarg', 'defaults') + def __init__(self, args, vararg, kwonlyargs, kw_defaults, kwarg, defaults): + self.args = args + self.vararg = vararg + self.kwonlyargs = kwonlyargs + self.kw_defaults = kw_defaults + self.kwarg = kwarg + self.defaults = defaults + + +class boolop(AST, mutable=ast.boolop): + _fields=() + def __init__(self): + pass + + +class And(boolop, mutable=ast.And): + _fields=() + def __init__(self): + pass + + +class Or(boolop, mutable=ast.Or): + _fields=() + def __init__(self): + pass + + +class cmpop(AST, mutable=ast.cmpop): + _fields=() + def __init__(self): + pass + + +class Eq(cmpop, mutable=ast.Eq): + _fields=() + def __init__(self): + pass + + +class Gt(cmpop, mutable=ast.Gt): + _fields=() + def __init__(self): + pass + + +class GtE(cmpop, mutable=ast.GtE): + _fields=() + def __init__(self): + pass + + +class In(cmpop, mutable=ast.In): + _fields=() + def __init__(self): + pass + + +class Is(cmpop, mutable=ast.Is): + _fields=() + def __init__(self): + pass + + +class IsNot(cmpop, mutable=ast.IsNot): + _fields=() + def __init__(self): + pass + + +class Lt(cmpop, mutable=ast.Lt): + _fields=() + def __init__(self): + pass + + +class LtE(cmpop, mutable=ast.LtE): + _fields=() + def __init__(self): + pass + + +class NotEq(cmpop, mutable=ast.NotEq): + _fields=() + def __init__(self): + pass + + +class NotIn(cmpop, mutable=ast.NotIn): + _fields=() + def __init__(self): + pass + + +class comprehension(AST, mutable=ast.comprehension): + _fields=('target', 'iter', 'ifs', 'is_async') + def __init__(self, target, iter, ifs, is_async): + self.target = target + self.iter = iter + self.ifs = ifs + self.is_async = is_async + + +class excepthandler(AST, mutable=ast.excepthandler): + _fields=() + def __init__(self): + pass + + +class ExceptHandler(excepthandler, mutable=ast.ExceptHandler): + _fields=('type', 'name', 'body') + def __init__(self, type, name, body): + self.type = type + self.name = name + self.body = body + + +class expr(AST, mutable=ast.expr): + _fields=() + def __init__(self): + pass + + +class Attribute(expr, mutable=ast.Attribute): + _fields=('value', 'attr', 'ctx') + def __init__(self, value, attr, ctx): + self.value = value + self.attr = attr + self.ctx = ctx + + +class Await(expr, mutable=ast.Await): + _fields=('value',) + def __init__(self, value): + self.value = value + + +class BinOp(expr, mutable=ast.BinOp): + _fields=('left', 'op', 'right') + def __init__(self, left, op, right): + self.left = left + self.op = op + self.right = right + + +class BoolOp(expr, mutable=ast.BoolOp): + _fields=('op', 'values') + def __init__(self, op, values): + self.op = op + self.values = values + + +class Bytes(expr, mutable=ast.Bytes): + _fields=('s',) + def __init__(self, s): + self.s = s + + +class Call(expr, mutable=ast.Call): + _fields=('func', 'args', 'keywords') + def __init__(self, func, args, keywords): + self.func = func + self.args = args + self.keywords = keywords + + +class Compare(expr, mutable=ast.Compare): + _fields=('left', 'ops', 'comparators') + def __init__(self, left, ops, comparators): + self.left = left + self.ops = ops + self.comparators = comparators + + +class Constant(expr, mutable=ast.Constant): + _fields=('value',) + def __init__(self, value): + self.value = value + + +class Dict(expr, mutable=ast.Dict): + _fields=('keys', 'values') + def __init__(self, keys, values): + self.keys = keys + self.values = values + + +class DictComp(expr, mutable=ast.DictComp): + _fields=('key', 'value', 'generators') + def __init__(self, key, value, generators): + self.key = key + self.value = value + self.generators = generators + + +class Ellipsis(expr, mutable=ast.Ellipsis): + _fields=() + def __init__(self): + pass + + +class FormattedValue(expr, mutable=ast.FormattedValue): + _fields=('value', 'conversion', 'format_spec') + def __init__(self, value, conversion, format_spec): + self.value = value + self.conversion = conversion + self.format_spec = format_spec + + +class GeneratorExp(expr, mutable=ast.GeneratorExp): + _fields=('elt', 'generators') + def __init__(self, elt, generators): + self.elt = elt + self.generators = generators + + +class IfExp(expr, mutable=ast.IfExp): + _fields=('test', 'body', 'orelse') + def __init__(self, test, body, orelse): + self.test = test + self.body = body + self.orelse = orelse + + +class JoinedStr(expr, mutable=ast.JoinedStr): + _fields=('values',) + def __init__(self, values): + self.values = values + + +class Lambda(expr, mutable=ast.Lambda): + _fields=('args', 'body') + def __init__(self, args, body): + self.args = args + self.body = body + + +class List(expr, mutable=ast.List): + _fields=('elts', 'ctx') + def __init__(self, elts, ctx): + self.elts = elts + self.ctx = ctx + + +class ListComp(expr, mutable=ast.ListComp): + _fields=('elt', 'generators') + def __init__(self, elt, generators): + self.elt = elt + self.generators = generators + + +class Name(expr, mutable=ast.Name): + _fields=('id', 'ctx') + def __init__(self, id, ctx): + self.id = id + self.ctx = ctx + + +class NameConstant(expr, mutable=ast.NameConstant): + _fields=('value',) + def __init__(self, value): + self.value = value + + +class Num(expr, mutable=ast.Num): + _fields=('n',) + def __init__(self, n): + self.n = n + + +class Set(expr, mutable=ast.Set): + _fields=('elts',) + def __init__(self, elts): + self.elts = elts + + +class SetComp(expr, mutable=ast.SetComp): + _fields=('elt', 'generators') + def __init__(self, elt, generators): + self.elt = elt + self.generators = generators + + +class Starred(expr, mutable=ast.Starred): + _fields=('value', 'ctx') + def __init__(self, value, ctx): + self.value = value + self.ctx = ctx + + +class Str(expr, mutable=ast.Str): + _fields=('s',) + def __init__(self, s): + self.s = s + + +class Subscript(expr, mutable=ast.Subscript): + _fields=('value', 'slice', 'ctx') + def __init__(self, value, slice, ctx): + self.value = value + self.slice = slice + self.ctx = ctx + + +class Tuple(expr, mutable=ast.Tuple): + _fields=('elts', 'ctx') + def __init__(self, elts, ctx): + self.elts = elts + self.ctx = ctx + + +class UnaryOp(expr, mutable=ast.UnaryOp): + _fields=('op', 'operand') + def __init__(self, op, operand): + self.op = op + self.operand = operand + + +class Yield(expr, mutable=ast.Yield): + _fields=('value',) + def __init__(self, value): + self.value = value + + +class YieldFrom(expr, mutable=ast.YieldFrom): + _fields=('value',) + def __init__(self, value): + self.value = value + + +class expr_context(AST, mutable=ast.expr_context): + _fields=() + def __init__(self): + pass + + +class AugLoad(expr_context, mutable=ast.AugLoad): + _fields=() + def __init__(self): + pass + + +class AugStore(expr_context, mutable=ast.AugStore): + _fields=() + def __init__(self): + pass + + +class Del(expr_context, mutable=ast.Del): + _fields=() + def __init__(self): + pass + + +class Load(expr_context, mutable=ast.Load): + _fields=() + def __init__(self): + pass + + +class Param(expr_context, mutable=ast.Param): + _fields=() + def __init__(self): + pass + + +class Store(expr_context, mutable=ast.Store): + _fields=() + def __init__(self): + pass + + +class keyword(AST, mutable=ast.keyword): + _fields=('arg', 'value') + def __init__(self, arg, value): + self.arg = arg + self.value = value + + +class mod(AST, mutable=ast.mod): + _fields=() + def __init__(self): + pass + + +class Expression(mod, mutable=ast.Expression): + _fields=('body',) + def __init__(self, body): + self.body = body + + +class Interactive(mod, mutable=ast.Interactive): + _fields=('body',) + def __init__(self, body): + self.body = body + + +class Module(mod, mutable=ast.Module): + _fields=('body',) + def __init__(self, body): + self.body = body + + +class Suite(mod, mutable=ast.Suite): + _fields=('body',) + def __init__(self, body): + self.body = body + + +class operator(AST, mutable=ast.operator): + _fields=() + def __init__(self): + pass + + +class Add(operator, mutable=ast.Add): + _fields=() + def __init__(self): + pass + + +class BitAnd(operator, mutable=ast.BitAnd): + _fields=() + def __init__(self): + pass + + +class BitOr(operator, mutable=ast.BitOr): + _fields=() + def __init__(self): + pass + + +class BitXor(operator, mutable=ast.BitXor): + _fields=() + def __init__(self): + pass + + +class Div(operator, mutable=ast.Div): + _fields=() + def __init__(self): + pass + + +class FloorDiv(operator, mutable=ast.FloorDiv): + _fields=() + def __init__(self): + pass + + +class LShift(operator, mutable=ast.LShift): + _fields=() + def __init__(self): + pass + + +class MatMult(operator, mutable=ast.MatMult): + _fields=() + def __init__(self): + pass + + +class Mod(operator, mutable=ast.Mod): + _fields=() + def __init__(self): + pass + + +class Mult(operator, mutable=ast.Mult): + _fields=() + def __init__(self): + pass + + +class Pow(operator, mutable=ast.Pow): + _fields=() + def __init__(self): + pass + + +class RShift(operator, mutable=ast.RShift): + _fields=() + def __init__(self): + pass + + +class Sub(operator, mutable=ast.Sub): + _fields=() + def __init__(self): + pass + + +class slice(AST, mutable=ast.slice): + _fields=() + def __init__(self): + pass + + +class ExtSlice(slice, mutable=ast.ExtSlice): + _fields=('dims',) + def __init__(self, dims): + self.dims = dims + + +class Index(slice, mutable=ast.Index): + _fields=('value',) + def __init__(self, value): + self.value = value + + +class Slice(slice, mutable=ast.Slice): + _fields=('lower', 'upper', 'step') + def __init__(self, lower, upper, step): + self.lower = lower + self.upper = upper + self.step = step + + +class stmt(AST, mutable=ast.stmt): + _fields=() + def __init__(self): + pass + + +class AnnAssign(stmt, mutable=ast.AnnAssign): + _fields=('target', 'annotation', 'value', 'simple') + def __init__(self, target, annotation, value, simple): + self.target = target + self.annotation = annotation + self.value = value + self.simple = simple + + +class Assert(stmt, mutable=ast.Assert): + _fields=('test', 'msg') + def __init__(self, test, msg): + self.test = test + self.msg = msg + + +class Assign(stmt, mutable=ast.Assign): + _fields=('targets', 'value') + def __init__(self, targets, value): + self.targets = targets + self.value = value + + +class AsyncFor(stmt, mutable=ast.AsyncFor): + _fields=('target', 'iter', 'body', 'orelse') + def __init__(self, target, iter, body, orelse): + self.target = target + self.iter = iter + self.body = body + self.orelse = orelse + + +class AsyncFunctionDef(stmt, mutable=ast.AsyncFunctionDef): + _fields=('name', 'args', 'body', 'decorator_list', 'returns') + def __init__(self, name, args, body, decorator_list, returns): + self.name = name + self.args = args + self.body = body + self.decorator_list = decorator_list + self.returns = returns + + +class AsyncWith(stmt, mutable=ast.AsyncWith): + _fields=('items', 'body') + def __init__(self, items, body): + self.items = items + self.body = body + + +class AugAssign(stmt, mutable=ast.AugAssign): + _fields=('target', 'op', 'value') + def __init__(self, target, op, value): + self.target = target + self.op = op + self.value = value + + +class Break(stmt, mutable=ast.Break): + _fields=() + def __init__(self): + pass + + +class ClassDef(stmt, mutable=ast.ClassDef): + _fields=('name', 'bases', 'keywords', 'body', 'decorator_list') + def __init__(self, name, bases, keywords, body, decorator_list): + self.name = name + self.bases = bases + self.keywords = keywords + self.body = body + self.decorator_list = decorator_list + + +class Continue(stmt, mutable=ast.Continue): + _fields=() + def __init__(self): + pass + + +class Delete(stmt, mutable=ast.Delete): + _fields=('targets',) + def __init__(self, targets): + self.targets = targets + + +class Expr(stmt, mutable=ast.Expr): + _fields=('value',) + def __init__(self, value): + self.value = value + + +class For(stmt, mutable=ast.For): + _fields=('target', 'iter', 'body', 'orelse') + def __init__(self, target, iter, body, orelse): + self.target = target + self.iter = iter + self.body = body + self.orelse = orelse + + +class FunctionDef(stmt, mutable=ast.FunctionDef): + _fields=('name', 'args', 'body', 'decorator_list', 'returns') + def __init__(self, name, args, body, decorator_list, returns): + self.name = name + self.args = args + self.body = body + self.decorator_list = decorator_list + self.returns = returns + + +class Global(stmt, mutable=ast.Global): + _fields=('names',) + def __init__(self, names): + self.names = names + + +class If(stmt, mutable=ast.If): + _fields=('test', 'body', 'orelse') + def __init__(self, test, body, orelse): + self.test = test + self.body = body + self.orelse = orelse + + +class Import(stmt, mutable=ast.Import): + _fields=('names',) + def __init__(self, names): + self.names = names + + +class ImportFrom(stmt, mutable=ast.ImportFrom): + _fields=('module', 'names', 'level') + def __init__(self, module, names, level): + self.module = module + self.names = names + self.level = level + + +class Nonlocal(stmt, mutable=ast.Nonlocal): + _fields=('names',) + def __init__(self, names): + self.names = names + + +class Pass(stmt, mutable=ast.Pass): + _fields=() + def __init__(self): + pass + + +class Raise(stmt, mutable=ast.Raise): + _fields=('exc', 'cause') + def __init__(self, exc, cause): + self.exc = exc + self.cause = cause + + +class Return(stmt, mutable=ast.Return): + _fields=('value',) + def __init__(self, value): + self.value = value + + +class Try(stmt, mutable=ast.Try): + _fields=('body', 'handlers', 'orelse', 'finalbody') + def __init__(self, body, handlers, orelse, finalbody): + self.body = body + self.handlers = handlers + self.orelse = orelse + self.finalbody = finalbody + + +class While(stmt, mutable=ast.While): + _fields=('test', 'body', 'orelse') + def __init__(self, test, body, orelse): + self.test = test + self.body = body + self.orelse = orelse + + +class With(stmt, mutable=ast.With): + _fields=('items', 'body') + def __init__(self, items, body): + self.items = items + self.body = body + + +class unaryop(AST, mutable=ast.unaryop): + _fields=() + def __init__(self): + pass + + +class Invert(unaryop, mutable=ast.Invert): + _fields=() + def __init__(self): + pass + + +class Not(unaryop, mutable=ast.Not): + _fields=() + def __init__(self): + pass + + +class UAdd(unaryop, mutable=ast.UAdd): + _fields=() + def __init__(self): + pass + + +class USub(unaryop, mutable=ast.USub): + _fields=() + def __init__(self): + pass + + +class withitem(AST, mutable=ast.withitem): + _fields=('context_expr', 'optional_vars') + def __init__(self, context_expr, optional_vars): + self.context_expr = context_expr + self.optional_vars = optional_vars + + From 898d36959e8c96105db85345160316f8f990323b Mon Sep 17 00:00:00 2001 From: Caleb Donovick Date: Fri, 30 Aug 2019 12:16:23 -0700 Subject: [PATCH 2/7] add tests --- tests/test_immutable_ast.py | 53 +++++++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 tests/test_immutable_ast.py diff --git a/tests/test_immutable_ast.py b/tests/test_immutable_ast.py new file mode 100644 index 0000000..abe4b8a --- /dev/null +++ b/tests/test_immutable_ast.py @@ -0,0 +1,53 @@ +import pytest +import random +import ast + +from ast_tools import immutable_ast +from ast_tools.immutable_ast import ImmutableMeta + +with open(immutable_ast.__file__, 'r') as f: + text = f.read() + +tree = ast.parse(text) + +def test_mutable_to_immutable(): + def _test(tree, itree): + if isinstance(tree, ast.AST): + assert isinstance(itree, immutable_ast.AST) + assert isinstance(tree, type(itree)) + assert tree._fields == itree._fields + assert ImmutableMeta._mutable_to_immutable[type(tree)] is type(itree) + for field, value in ast.iter_fields(tree): + _test(value, getattr(itree, field)) + elif isinstance(tree, list): + assert isinstance(itree, tuple) + assert len(tree) == len(itree) + for c, ic in zip(tree, itree): + _test(c, ic) + else: + assert tree == itree + + + itree = immutable_ast.immutable(tree) + _test(tree, itree) + +def test_immutable_to_mutable(): + itree = immutable_ast.immutable(tree) + mtree = immutable_ast.mutable(itree) + + assert itree == immutable_ast.immutable(mtree) + +def test_mutate(): + node = immutable_ast.Name(id='foo', ctx=immutable_ast.Load()) + with pytest.raises(AttributeError): + node.id = 'bar' + + +def test_construct_from_mutable(): + node = immutable_ast.Module([ + ast.Name(id='foo', ctx=ast.Store()) + ]) + + assert isinstance(node.body, tuple) + assert type(node.body[0]) is immutable_ast.Name + assert type(node.body[0].ctx) is immutable_ast.Store From 2efbf75e922f011546d276c98e1a11b28ce31606 Mon Sep 17 00:00:00 2001 From: Caleb Donovick Date: Fri, 30 Aug 2019 12:43:30 -0700 Subject: [PATCH 3/7] Add tests --- tests/test_immutable_ast.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/tests/test_immutable_ast.py b/tests/test_immutable_ast.py index abe4b8a..dac605f 100644 --- a/tests/test_immutable_ast.py +++ b/tests/test_immutable_ast.py @@ -1,16 +1,25 @@ import pytest -import random import ast +import inspect from ast_tools import immutable_ast from ast_tools.immutable_ast import ImmutableMeta +from ast_tools import _immutable_ast -with open(immutable_ast.__file__, 'r') as f: - text = f.read() -tree = ast.parse(text) +trees = [] -def test_mutable_to_immutable(): +# inspect is about the largest module I know +# hopefully it has a diverse ast +for mod in (immutable_ast, _immutable_ast, inspect, ast, pytest): + with open(mod.__file__, 'r') as f: + text = f.read() + tree = ast.parse(text) + trees.append(tree) + + +@pytest.mark.parametrize("tree", trees) +def test_mutable_to_immutable(tree): def _test(tree, itree): if isinstance(tree, ast.AST): assert isinstance(itree, immutable_ast.AST) @@ -31,7 +40,8 @@ def _test(tree, itree): itree = immutable_ast.immutable(tree) _test(tree, itree) -def test_immutable_to_mutable(): +@pytest.mark.parametrize("tree", trees) +def test_immutable_to_mutable(tree): itree = immutable_ast.immutable(tree) mtree = immutable_ast.mutable(itree) From a57f4259d8d15df0a141cf0fcd26a90190538e2d Mon Sep 17 00:00:00 2001 From: Caleb Donovick Date: Fri, 30 Aug 2019 13:01:59 -0700 Subject: [PATCH 4/7] Fix a bug; more tests --- ast_tools/_immutable_ast.py | 3 ++- tests/test_immutable_ast.py | 34 +++++++++++++++++++++++++++++++--- 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/ast_tools/_immutable_ast.py b/ast_tools/_immutable_ast.py index de941f5..33d55a1 100644 --- a/ast_tools/_immutable_ast.py +++ b/ast_tools/_immutable_ast.py @@ -1,5 +1,6 @@ -import functools as ft import ast +import functools as ft +import typing as tp __ALL__ = ['ImmutableMeta', 'immutable', 'mutable'] diff --git a/tests/test_immutable_ast.py b/tests/test_immutable_ast.py index dac605f..c71d9c9 100644 --- a/tests/test_immutable_ast.py +++ b/tests/test_immutable_ast.py @@ -42,15 +42,43 @@ def _test(tree, itree): @pytest.mark.parametrize("tree", trees) def test_immutable_to_mutable(tree): + def _test(tree, mtree): + assert type(tree) is type(mtree) + if isinstance(tree, ast.AST): + for field, value in ast.iter_fields(tree): + _test(value, getattr(mtree, field)) + elif isinstance(tree, list): + assert len(tree) == len(mtree) + for c, mc in zip(tree, mtree): + _test(c, mc) + else: + assert tree == mtree + itree = immutable_ast.immutable(tree) mtree = immutable_ast.mutable(itree) + _test(tree, mtree) - assert itree == immutable_ast.immutable(mtree) + +@pytest.mark.parametrize("tree", trees) +def test_eq(tree): + itree = immutable_ast.immutable(tree) + jtree = immutable_ast.immutable(tree) + assert itree == jtree + assert hash(itree) == hash(jtree) def test_mutate(): node = immutable_ast.Name(id='foo', ctx=immutable_ast.Load()) - with pytest.raises(AttributeError): - node.id = 'bar' + # can add metadata to a node + node.random = 0 + del node.random + + # but cant change its fields + for field in node._fields: + with pytest.raises(AttributeError): + setattr(node, field, 'bar') + + with pytest.raises(AttributeError): + delattr(node, field) def test_construct_from_mutable(): From c85fa2d46b96554c528ffddff33c9f5cd2455172 Mon Sep 17 00:00:00 2001 From: Caleb Donovick Date: Fri, 30 Aug 2019 13:18:25 -0700 Subject: [PATCH 5/7] Add no cover pragma --- ast_tools/_immutable_ast.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ast_tools/_immutable_ast.py b/ast_tools/_immutable_ast.py index 33d55a1..cb7e411 100644 --- a/ast_tools/_immutable_ast.py +++ b/ast_tools/_immutable_ast.py @@ -107,7 +107,7 @@ def mutable(tree: 'AST'): # could actually generate the classes and put them in globals # but that would make text editors suck (no autocomplete etc) # so Instead generate the actual file -def _generate_immutable_ast(): +def _generate_immutable_ast(): # pragma: no cover import ast import inspect import sys @@ -233,7 +233,7 @@ def _build_cls_from_tree(tree): ) -if __name__ == '__main__': +if __name__ == '__main__': # pragma: no cover s = _generate_immutable_ast() with open('immutable_ast.py', 'w') as f: f.write(s) From c2c193c7a3a57983c2cc318d72ac57fa4b526ba6 Mon Sep 17 00:00:00 2001 From: Caleb Donovick Date: Mon, 2 Sep 2019 13:23:46 -0700 Subject: [PATCH 6/7] Move generation of immutable_ast to setup.py --- ast_tools/_immutable_ast.py | 240 ---------- ast_tools/immutable_ast.py | 785 -------------------------------- setup.py | 31 +- tests/test_immutable_ast.py | 3 +- util/generate_ast/__init__.py | 4 + util/generate_ast/_base.px | 45 ++ util/generate_ast/_functions.px | 113 +++++ util/generate_ast/_meta.px | 19 + util/generate_ast/generate.py | 132 ++++++ 9 files changed, 337 insertions(+), 1035 deletions(-) delete mode 100644 ast_tools/_immutable_ast.py delete mode 100644 ast_tools/immutable_ast.py create mode 100644 util/generate_ast/__init__.py create mode 100644 util/generate_ast/_base.px create mode 100644 util/generate_ast/_functions.px create mode 100644 util/generate_ast/_meta.px create mode 100644 util/generate_ast/generate.py diff --git a/ast_tools/_immutable_ast.py b/ast_tools/_immutable_ast.py deleted file mode 100644 index cb7e411..0000000 --- a/ast_tools/_immutable_ast.py +++ /dev/null @@ -1,240 +0,0 @@ -import ast -import functools as ft -import typing as tp - -__ALL__ = ['ImmutableMeta', 'immutable', 'mutable'] - -class ImmutableMeta(type): - _immutable_to_mutable = dict() - _mutable_to_immutable = dict() - def __new__(mcs, name, bases, namespace, mutable, **kwargs): - def __setattr__(self, attr, value): - if attr in self._fields and hasattr(self, attr): - raise AttributeError('Cannot modify ImmutableAST fields') - elif isinstance(value, (list, ast.AST)): - value = immutable(value) - - self.__dict__[attr] = value - - def __delattr__(self, attr): - if attr in self._fields: - raise AttributeError('Cannot modify ImmutableAST fields') - del self.__dict__[attr] - - def __hash__(self): - try: - return self._hash_ - except AttributeError: - pass - - h = hash(type(self)) - for _, n in ast.iter_fields(self): - if isinstance(type(n), ImmutableMeta): - h += hash(n) - elif isinstance(n, tp.Sequence): - for c in n: - h += hash(c) - else: - h += hash(n) - self._hash_ = h - return h - - - def __eq__(self, other): - if not isinstance(other, type(self)): - return NotImplemented - elif type(self) == type(other): - for f in self._fields: - if getattr(self, f) != getattr(other, f): - return False - return True - else: - return False - - - def __ne__(self, other): - return not (self == other) - - namespace['__setattr__'] = __setattr__ - namespace['__delattr__'] = __delattr__ - namespace['__hash__'] = __hash__ - namespace['__eq__'] = __eq__ - namespace['__ne__'] = __ne__ - - cls = super().__new__(mcs, name, bases, namespace, **kwargs) - - ImmutableMeta._immutable_to_mutable[cls] = mutable - ImmutableMeta._mutable_to_immutable[mutable] = cls - - return cls - - def __instancecheck__(cls, instance): - return super().__instancecheck__(instance)\ - or isinstance(instance, ImmutableMeta._immutable_to_mutable[cls]) - - def __subclasscheck__(cls, type_): - return super().__subclasscheck__(type_)\ - or issubclass(type_, ImmutableMeta._immutable_to_mutable[cls]) - - -def _cast_tree(seq_t, n_seq_t, type_look_up, tree): - args = seq_t, n_seq_t, type_look_up - - if isinstance(tree, seq_t): - return n_seq_t(_cast_tree(*args, c) for c in tree) - - try: - T = type_look_up[type(tree)] - except KeyError: - return tree - - kwargs = {} - for field, c in ast.iter_fields(tree): - kwargs[field] = _cast_tree(*args, c) - - return T(**kwargs) - - -def immutable(tree: ast.AST): - '''Converts a mutable ast to an immutable one''' - return _cast_tree(list, tuple, ImmutableMeta._mutable_to_immutable, tree) - -def mutable(tree: 'AST'): - '''Converts an immutable ast to a mutable one''' - return _cast_tree(tuple, list, ImmutableMeta._immutable_to_mutable, tree) - - -# could actually generate the classes and put them in globals -# but that would make text editors suck (no autocomplete etc) -# so Instead generate the actual file -def _generate_immutable_ast(): # pragma: no cover - import ast - import inspect - import sys - import datetime - - immutable_ast_template = '''\ -{head_comment} - -import ast -import sys -import warnings - -from ._immutable_ast import * - -{version_check} - -__ALL__ = {ALL} - -{classes} -''' - - class_template = '''\ -class {name}({bases}mutable=ast.{name}{meta}): -{tab}_fields={fields} -{tab}def __init__(self{sig}):''' - - builder_template = '{tab}{tab}self.{arg} = {arg}' - - tab = ' '*4 - - pass_string = f'{tab}{tab}pass' - - nl = '\n' - head_comment = f'''\ -# file generated by {__file__} on {datetime.datetime.now()} -# for python {sys.version.split(nl)[0].strip()}''' - - version_check = f'''\ -if sys.version_info[:2] != {sys.version_info[:2]}: -{tab}warnings.warn(f"{{__file__}} generated for {sys.version_info[:2]}" -{tab} f"does not match system version {{sys.version_info[:2]}}")''' - - def _issubclass(t, types): - try: - return issubclass(t, types) - except TypeError: - pass - return False - - _classes = [] - for _name in dir(ast): - _obj = getattr(ast, _name) - if _issubclass(_obj, ast.AST): - _classes.append(_obj) - - _class_tree = inspect.getclasstree(_classes) - assert _class_tree == inspect.getclasstree(_classes, unique=True) - _cls_to_args = {} - - def _build_cls_from_tree(tree): - for item in tree: - if isinstance(item, list): - r = _build_cls_from_tree(item) - if r is not None: - return r - elif item[0] not in _cls_to_args: - cls = item[0] - bases = tuple(_cls_to_args[base][0] for base in item[1] if base is not object) - _cls_to_args[cls] = r = cls.__name__, bases, cls._fields - return r - - - _class_strings = [] - _all = [] - _args = _build_cls_from_tree(_class_tree[1]) - while _args is not None: - name=_args[0] - bases=', '.join(_args[1]) - if bases != '': - bases += ', ' - meta='' - else: - meta=', metaclass=ImmutableMeta' - - - fields = _args[2] - - if fields: - sig = ', ' - else: - sig = '' - sig += ', '.join(fields) - - class_ = [class_template.format( - tab=tab, - name=name, - bases=bases, - meta=meta, - fields=fields, - sig=sig, - )] - - if fields: - for arg in fields: - class_.append(builder_template.format( - tab=tab, - arg=arg, - )) - else: - class_.append(pass_string) - - class_.append('\n') - _class_strings.append('\n'.join(class_)) - _all.append(name) - _args = _build_cls_from_tree(_class_tree[1]) - - - return immutable_ast_template.format( - head_comment=head_comment, - version_check=version_check, - ALL=_all, - classes = '\n'.join(_class_strings) - ) - - -if __name__ == '__main__': # pragma: no cover - s = _generate_immutable_ast() - with open('immutable_ast.py', 'w') as f: - f.write(s) - diff --git a/ast_tools/immutable_ast.py b/ast_tools/immutable_ast.py deleted file mode 100644 index 0753abb..0000000 --- a/ast_tools/immutable_ast.py +++ /dev/null @@ -1,785 +0,0 @@ -# file generated by ast_tools/_immutable_ast.py on 2019-08-30 12:22:11.404048 -# for python 3.7.3 (default, Apr 3 2019, 05:39:12) - -import ast -import sys -import warnings - -from ._immutable_ast import * - -if sys.version_info[:2] != (3, 7): - warnings.warn(f"{__file__} generated for (3, 7)" - f"does not match system version {sys.version_info[:2]}") - -__ALL__ = ['AST', 'alias', 'arg', 'arguments', 'boolop', 'And', 'Or', 'cmpop', 'Eq', 'Gt', 'GtE', 'In', 'Is', 'IsNot', 'Lt', 'LtE', 'NotEq', 'NotIn', 'comprehension', 'excepthandler', 'ExceptHandler', 'expr', 'Attribute', 'Await', 'BinOp', 'BoolOp', 'Bytes', 'Call', 'Compare', 'Constant', 'Dict', 'DictComp', 'Ellipsis', 'FormattedValue', 'GeneratorExp', 'IfExp', 'JoinedStr', 'Lambda', 'List', 'ListComp', 'Name', 'NameConstant', 'Num', 'Set', 'SetComp', 'Starred', 'Str', 'Subscript', 'Tuple', 'UnaryOp', 'Yield', 'YieldFrom', 'expr_context', 'AugLoad', 'AugStore', 'Del', 'Load', 'Param', 'Store', 'keyword', 'mod', 'Expression', 'Interactive', 'Module', 'Suite', 'operator', 'Add', 'BitAnd', 'BitOr', 'BitXor', 'Div', 'FloorDiv', 'LShift', 'MatMult', 'Mod', 'Mult', 'Pow', 'RShift', 'Sub', 'slice', 'ExtSlice', 'Index', 'Slice', 'stmt', 'AnnAssign', 'Assert', 'Assign', 'AsyncFor', 'AsyncFunctionDef', 'AsyncWith', 'AugAssign', 'Break', 'ClassDef', 'Continue', 'Delete', 'Expr', 'For', 'FunctionDef', 'Global', 'If', 'Import', 'ImportFrom', 'Nonlocal', 'Pass', 'Raise', 'Return', 'Try', 'While', 'With', 'unaryop', 'Invert', 'Not', 'UAdd', 'USub', 'withitem'] - -class AST(mutable=ast.AST, metaclass=ImmutableMeta): - _fields=() - def __init__(self): - pass - - -class alias(AST, mutable=ast.alias): - _fields=('name', 'asname') - def __init__(self, name, asname): - self.name = name - self.asname = asname - - -class arg(AST, mutable=ast.arg): - _fields=('arg', 'annotation') - def __init__(self, arg, annotation): - self.arg = arg - self.annotation = annotation - - -class arguments(AST, mutable=ast.arguments): - _fields=('args', 'vararg', 'kwonlyargs', 'kw_defaults', 'kwarg', 'defaults') - def __init__(self, args, vararg, kwonlyargs, kw_defaults, kwarg, defaults): - self.args = args - self.vararg = vararg - self.kwonlyargs = kwonlyargs - self.kw_defaults = kw_defaults - self.kwarg = kwarg - self.defaults = defaults - - -class boolop(AST, mutable=ast.boolop): - _fields=() - def __init__(self): - pass - - -class And(boolop, mutable=ast.And): - _fields=() - def __init__(self): - pass - - -class Or(boolop, mutable=ast.Or): - _fields=() - def __init__(self): - pass - - -class cmpop(AST, mutable=ast.cmpop): - _fields=() - def __init__(self): - pass - - -class Eq(cmpop, mutable=ast.Eq): - _fields=() - def __init__(self): - pass - - -class Gt(cmpop, mutable=ast.Gt): - _fields=() - def __init__(self): - pass - - -class GtE(cmpop, mutable=ast.GtE): - _fields=() - def __init__(self): - pass - - -class In(cmpop, mutable=ast.In): - _fields=() - def __init__(self): - pass - - -class Is(cmpop, mutable=ast.Is): - _fields=() - def __init__(self): - pass - - -class IsNot(cmpop, mutable=ast.IsNot): - _fields=() - def __init__(self): - pass - - -class Lt(cmpop, mutable=ast.Lt): - _fields=() - def __init__(self): - pass - - -class LtE(cmpop, mutable=ast.LtE): - _fields=() - def __init__(self): - pass - - -class NotEq(cmpop, mutable=ast.NotEq): - _fields=() - def __init__(self): - pass - - -class NotIn(cmpop, mutable=ast.NotIn): - _fields=() - def __init__(self): - pass - - -class comprehension(AST, mutable=ast.comprehension): - _fields=('target', 'iter', 'ifs', 'is_async') - def __init__(self, target, iter, ifs, is_async): - self.target = target - self.iter = iter - self.ifs = ifs - self.is_async = is_async - - -class excepthandler(AST, mutable=ast.excepthandler): - _fields=() - def __init__(self): - pass - - -class ExceptHandler(excepthandler, mutable=ast.ExceptHandler): - _fields=('type', 'name', 'body') - def __init__(self, type, name, body): - self.type = type - self.name = name - self.body = body - - -class expr(AST, mutable=ast.expr): - _fields=() - def __init__(self): - pass - - -class Attribute(expr, mutable=ast.Attribute): - _fields=('value', 'attr', 'ctx') - def __init__(self, value, attr, ctx): - self.value = value - self.attr = attr - self.ctx = ctx - - -class Await(expr, mutable=ast.Await): - _fields=('value',) - def __init__(self, value): - self.value = value - - -class BinOp(expr, mutable=ast.BinOp): - _fields=('left', 'op', 'right') - def __init__(self, left, op, right): - self.left = left - self.op = op - self.right = right - - -class BoolOp(expr, mutable=ast.BoolOp): - _fields=('op', 'values') - def __init__(self, op, values): - self.op = op - self.values = values - - -class Bytes(expr, mutable=ast.Bytes): - _fields=('s',) - def __init__(self, s): - self.s = s - - -class Call(expr, mutable=ast.Call): - _fields=('func', 'args', 'keywords') - def __init__(self, func, args, keywords): - self.func = func - self.args = args - self.keywords = keywords - - -class Compare(expr, mutable=ast.Compare): - _fields=('left', 'ops', 'comparators') - def __init__(self, left, ops, comparators): - self.left = left - self.ops = ops - self.comparators = comparators - - -class Constant(expr, mutable=ast.Constant): - _fields=('value',) - def __init__(self, value): - self.value = value - - -class Dict(expr, mutable=ast.Dict): - _fields=('keys', 'values') - def __init__(self, keys, values): - self.keys = keys - self.values = values - - -class DictComp(expr, mutable=ast.DictComp): - _fields=('key', 'value', 'generators') - def __init__(self, key, value, generators): - self.key = key - self.value = value - self.generators = generators - - -class Ellipsis(expr, mutable=ast.Ellipsis): - _fields=() - def __init__(self): - pass - - -class FormattedValue(expr, mutable=ast.FormattedValue): - _fields=('value', 'conversion', 'format_spec') - def __init__(self, value, conversion, format_spec): - self.value = value - self.conversion = conversion - self.format_spec = format_spec - - -class GeneratorExp(expr, mutable=ast.GeneratorExp): - _fields=('elt', 'generators') - def __init__(self, elt, generators): - self.elt = elt - self.generators = generators - - -class IfExp(expr, mutable=ast.IfExp): - _fields=('test', 'body', 'orelse') - def __init__(self, test, body, orelse): - self.test = test - self.body = body - self.orelse = orelse - - -class JoinedStr(expr, mutable=ast.JoinedStr): - _fields=('values',) - def __init__(self, values): - self.values = values - - -class Lambda(expr, mutable=ast.Lambda): - _fields=('args', 'body') - def __init__(self, args, body): - self.args = args - self.body = body - - -class List(expr, mutable=ast.List): - _fields=('elts', 'ctx') - def __init__(self, elts, ctx): - self.elts = elts - self.ctx = ctx - - -class ListComp(expr, mutable=ast.ListComp): - _fields=('elt', 'generators') - def __init__(self, elt, generators): - self.elt = elt - self.generators = generators - - -class Name(expr, mutable=ast.Name): - _fields=('id', 'ctx') - def __init__(self, id, ctx): - self.id = id - self.ctx = ctx - - -class NameConstant(expr, mutable=ast.NameConstant): - _fields=('value',) - def __init__(self, value): - self.value = value - - -class Num(expr, mutable=ast.Num): - _fields=('n',) - def __init__(self, n): - self.n = n - - -class Set(expr, mutable=ast.Set): - _fields=('elts',) - def __init__(self, elts): - self.elts = elts - - -class SetComp(expr, mutable=ast.SetComp): - _fields=('elt', 'generators') - def __init__(self, elt, generators): - self.elt = elt - self.generators = generators - - -class Starred(expr, mutable=ast.Starred): - _fields=('value', 'ctx') - def __init__(self, value, ctx): - self.value = value - self.ctx = ctx - - -class Str(expr, mutable=ast.Str): - _fields=('s',) - def __init__(self, s): - self.s = s - - -class Subscript(expr, mutable=ast.Subscript): - _fields=('value', 'slice', 'ctx') - def __init__(self, value, slice, ctx): - self.value = value - self.slice = slice - self.ctx = ctx - - -class Tuple(expr, mutable=ast.Tuple): - _fields=('elts', 'ctx') - def __init__(self, elts, ctx): - self.elts = elts - self.ctx = ctx - - -class UnaryOp(expr, mutable=ast.UnaryOp): - _fields=('op', 'operand') - def __init__(self, op, operand): - self.op = op - self.operand = operand - - -class Yield(expr, mutable=ast.Yield): - _fields=('value',) - def __init__(self, value): - self.value = value - - -class YieldFrom(expr, mutable=ast.YieldFrom): - _fields=('value',) - def __init__(self, value): - self.value = value - - -class expr_context(AST, mutable=ast.expr_context): - _fields=() - def __init__(self): - pass - - -class AugLoad(expr_context, mutable=ast.AugLoad): - _fields=() - def __init__(self): - pass - - -class AugStore(expr_context, mutable=ast.AugStore): - _fields=() - def __init__(self): - pass - - -class Del(expr_context, mutable=ast.Del): - _fields=() - def __init__(self): - pass - - -class Load(expr_context, mutable=ast.Load): - _fields=() - def __init__(self): - pass - - -class Param(expr_context, mutable=ast.Param): - _fields=() - def __init__(self): - pass - - -class Store(expr_context, mutable=ast.Store): - _fields=() - def __init__(self): - pass - - -class keyword(AST, mutable=ast.keyword): - _fields=('arg', 'value') - def __init__(self, arg, value): - self.arg = arg - self.value = value - - -class mod(AST, mutable=ast.mod): - _fields=() - def __init__(self): - pass - - -class Expression(mod, mutable=ast.Expression): - _fields=('body',) - def __init__(self, body): - self.body = body - - -class Interactive(mod, mutable=ast.Interactive): - _fields=('body',) - def __init__(self, body): - self.body = body - - -class Module(mod, mutable=ast.Module): - _fields=('body',) - def __init__(self, body): - self.body = body - - -class Suite(mod, mutable=ast.Suite): - _fields=('body',) - def __init__(self, body): - self.body = body - - -class operator(AST, mutable=ast.operator): - _fields=() - def __init__(self): - pass - - -class Add(operator, mutable=ast.Add): - _fields=() - def __init__(self): - pass - - -class BitAnd(operator, mutable=ast.BitAnd): - _fields=() - def __init__(self): - pass - - -class BitOr(operator, mutable=ast.BitOr): - _fields=() - def __init__(self): - pass - - -class BitXor(operator, mutable=ast.BitXor): - _fields=() - def __init__(self): - pass - - -class Div(operator, mutable=ast.Div): - _fields=() - def __init__(self): - pass - - -class FloorDiv(operator, mutable=ast.FloorDiv): - _fields=() - def __init__(self): - pass - - -class LShift(operator, mutable=ast.LShift): - _fields=() - def __init__(self): - pass - - -class MatMult(operator, mutable=ast.MatMult): - _fields=() - def __init__(self): - pass - - -class Mod(operator, mutable=ast.Mod): - _fields=() - def __init__(self): - pass - - -class Mult(operator, mutable=ast.Mult): - _fields=() - def __init__(self): - pass - - -class Pow(operator, mutable=ast.Pow): - _fields=() - def __init__(self): - pass - - -class RShift(operator, mutable=ast.RShift): - _fields=() - def __init__(self): - pass - - -class Sub(operator, mutable=ast.Sub): - _fields=() - def __init__(self): - pass - - -class slice(AST, mutable=ast.slice): - _fields=() - def __init__(self): - pass - - -class ExtSlice(slice, mutable=ast.ExtSlice): - _fields=('dims',) - def __init__(self, dims): - self.dims = dims - - -class Index(slice, mutable=ast.Index): - _fields=('value',) - def __init__(self, value): - self.value = value - - -class Slice(slice, mutable=ast.Slice): - _fields=('lower', 'upper', 'step') - def __init__(self, lower, upper, step): - self.lower = lower - self.upper = upper - self.step = step - - -class stmt(AST, mutable=ast.stmt): - _fields=() - def __init__(self): - pass - - -class AnnAssign(stmt, mutable=ast.AnnAssign): - _fields=('target', 'annotation', 'value', 'simple') - def __init__(self, target, annotation, value, simple): - self.target = target - self.annotation = annotation - self.value = value - self.simple = simple - - -class Assert(stmt, mutable=ast.Assert): - _fields=('test', 'msg') - def __init__(self, test, msg): - self.test = test - self.msg = msg - - -class Assign(stmt, mutable=ast.Assign): - _fields=('targets', 'value') - def __init__(self, targets, value): - self.targets = targets - self.value = value - - -class AsyncFor(stmt, mutable=ast.AsyncFor): - _fields=('target', 'iter', 'body', 'orelse') - def __init__(self, target, iter, body, orelse): - self.target = target - self.iter = iter - self.body = body - self.orelse = orelse - - -class AsyncFunctionDef(stmt, mutable=ast.AsyncFunctionDef): - _fields=('name', 'args', 'body', 'decorator_list', 'returns') - def __init__(self, name, args, body, decorator_list, returns): - self.name = name - self.args = args - self.body = body - self.decorator_list = decorator_list - self.returns = returns - - -class AsyncWith(stmt, mutable=ast.AsyncWith): - _fields=('items', 'body') - def __init__(self, items, body): - self.items = items - self.body = body - - -class AugAssign(stmt, mutable=ast.AugAssign): - _fields=('target', 'op', 'value') - def __init__(self, target, op, value): - self.target = target - self.op = op - self.value = value - - -class Break(stmt, mutable=ast.Break): - _fields=() - def __init__(self): - pass - - -class ClassDef(stmt, mutable=ast.ClassDef): - _fields=('name', 'bases', 'keywords', 'body', 'decorator_list') - def __init__(self, name, bases, keywords, body, decorator_list): - self.name = name - self.bases = bases - self.keywords = keywords - self.body = body - self.decorator_list = decorator_list - - -class Continue(stmt, mutable=ast.Continue): - _fields=() - def __init__(self): - pass - - -class Delete(stmt, mutable=ast.Delete): - _fields=('targets',) - def __init__(self, targets): - self.targets = targets - - -class Expr(stmt, mutable=ast.Expr): - _fields=('value',) - def __init__(self, value): - self.value = value - - -class For(stmt, mutable=ast.For): - _fields=('target', 'iter', 'body', 'orelse') - def __init__(self, target, iter, body, orelse): - self.target = target - self.iter = iter - self.body = body - self.orelse = orelse - - -class FunctionDef(stmt, mutable=ast.FunctionDef): - _fields=('name', 'args', 'body', 'decorator_list', 'returns') - def __init__(self, name, args, body, decorator_list, returns): - self.name = name - self.args = args - self.body = body - self.decorator_list = decorator_list - self.returns = returns - - -class Global(stmt, mutable=ast.Global): - _fields=('names',) - def __init__(self, names): - self.names = names - - -class If(stmt, mutable=ast.If): - _fields=('test', 'body', 'orelse') - def __init__(self, test, body, orelse): - self.test = test - self.body = body - self.orelse = orelse - - -class Import(stmt, mutable=ast.Import): - _fields=('names',) - def __init__(self, names): - self.names = names - - -class ImportFrom(stmt, mutable=ast.ImportFrom): - _fields=('module', 'names', 'level') - def __init__(self, module, names, level): - self.module = module - self.names = names - self.level = level - - -class Nonlocal(stmt, mutable=ast.Nonlocal): - _fields=('names',) - def __init__(self, names): - self.names = names - - -class Pass(stmt, mutable=ast.Pass): - _fields=() - def __init__(self): - pass - - -class Raise(stmt, mutable=ast.Raise): - _fields=('exc', 'cause') - def __init__(self, exc, cause): - self.exc = exc - self.cause = cause - - -class Return(stmt, mutable=ast.Return): - _fields=('value',) - def __init__(self, value): - self.value = value - - -class Try(stmt, mutable=ast.Try): - _fields=('body', 'handlers', 'orelse', 'finalbody') - def __init__(self, body, handlers, orelse, finalbody): - self.body = body - self.handlers = handlers - self.orelse = orelse - self.finalbody = finalbody - - -class While(stmt, mutable=ast.While): - _fields=('test', 'body', 'orelse') - def __init__(self, test, body, orelse): - self.test = test - self.body = body - self.orelse = orelse - - -class With(stmt, mutable=ast.With): - _fields=('items', 'body') - def __init__(self, items, body): - self.items = items - self.body = body - - -class unaryop(AST, mutable=ast.unaryop): - _fields=() - def __init__(self): - pass - - -class Invert(unaryop, mutable=ast.Invert): - _fields=() - def __init__(self): - pass - - -class Not(unaryop, mutable=ast.Not): - _fields=() - def __init__(self): - pass - - -class UAdd(unaryop, mutable=ast.UAdd): - _fields=() - def __init__(self): - pass - - -class USub(unaryop, mutable=ast.USub): - _fields=() - def __init__(self): - pass - - -class withitem(AST, mutable=ast.withitem): - _fields=('context_expr', 'optional_vars') - def __init__(self, context_expr, optional_vars): - self.context_expr = context_expr - self.optional_vars = optional_vars - - diff --git a/setup.py b/setup.py index b4e20d6..6b6442e 100644 --- a/setup.py +++ b/setup.py @@ -1,13 +1,28 @@ -""" +''' setup script for package -""" +''' from setuptools import setup +from setuptools.command.build_py import build_py +from os import path +import util.generate_ast.generate as generate -with open("README.md", "r") as fh: + +PACKAGE_NAME = 'ast_tools' + +with open('README.md', "r") as fh: LONG_DESCRIPTION = fh.read() +class BuildImmutableAst(build_py): + def run(self): + super().run() + if not self.dry_run: + src = generate.generate_immutable_ast() + with open(path.join(PACKAGE_NAME, 'immutable_ast.py'), 'w') as f: + f.write(src) + setup( + cmdclass={'build_py' : BuildImmutableAst}, name='ast_tools', url='https://github.com/leonardt/ast_tools', author='Leonard Truong', @@ -16,12 +31,12 @@ description='Toolbox for working with the Python AST', scripts=[], packages=[ - "ast_tools", - "ast_tools.visitors", - "ast_tools.transformers", - "ast_tools.passes" + f"{PACKAGE_NAME}", + f"{PACKAGE_NAME}.visitors", + f"{PACKAGE_NAME}.transformers", + f"{PACKAGE_NAME}.passes" ], install_requires=['astor'], long_description=LONG_DESCRIPTION, - long_description_content_type="text/markdown" + long_description_content_type='text/markdown' ) diff --git a/tests/test_immutable_ast.py b/tests/test_immutable_ast.py index c71d9c9..0d97fbe 100644 --- a/tests/test_immutable_ast.py +++ b/tests/test_immutable_ast.py @@ -4,14 +4,13 @@ import inspect from ast_tools import immutable_ast from ast_tools.immutable_ast import ImmutableMeta -from ast_tools import _immutable_ast trees = [] # inspect is about the largest module I know # hopefully it has a diverse ast -for mod in (immutable_ast, _immutable_ast, inspect, ast, pytest): +for mod in (immutable_ast, inspect, ast, pytest): with open(mod.__file__, 'r') as f: text = f.read() tree = ast.parse(text) diff --git a/util/generate_ast/__init__.py b/util/generate_ast/__init__.py new file mode 100644 index 0000000..8333482 --- /dev/null +++ b/util/generate_ast/__init__.py @@ -0,0 +1,4 @@ +if __name__ == '__main__': + import generate + print(generate.generate_immutable_ast()) + diff --git a/util/generate_ast/_base.px b/util/generate_ast/_base.px new file mode 100644 index 0000000..dcd8247 --- /dev/null +++ b/util/generate_ast/_base.px @@ -0,0 +1,45 @@ +class AST(mutable=ast.AST, metaclass=ImmutableMeta): + def __setattr__(self, attr, value): + if attr in self._fields and hasattr(self, attr): + raise AttributeError('Cannot modify ImmutableAST fields') + elif isinstance(value, (list, ast.AST)): + value = immutable(value) + + self.__dict__[attr] = value + + def __delattr__(self, attr): + if attr in self._fields: + raise AttributeError('Cannot modify ImmutableAST fields') + del self.__dict__[attr] + + def __hash__(self): + try: + return self._hash_ + except AttributeError: + pass + + h = hash(type(self)) + for _, n in iter_fields(self): + if isinstance(n, AST): + h += hash(n) + elif isinstance(n, tp.Sequence): + for c in n: + h += hash(c) + else: + h += hash(n) + self._hash_ = h + return h + + def __eq__(self, other): + if not isinstance(other, type(self)): + return NotImplemented + elif type(self) == type(other): + for f in self._fields: + if getattr(self, f) != getattr(other, f): + return False + return True + else: + return False + + def __ne__(self, other): + return not (self == other) diff --git a/util/generate_ast/_functions.px b/util/generate_ast/_functions.px new file mode 100644 index 0000000..1c52de1 --- /dev/null +++ b/util/generate_ast/_functions.px @@ -0,0 +1,113 @@ +__ALL__ += ['immutable', 'mutable', 'parse', 'dump', + 'iter_fields', 'iter_child_nodes', 'walk', + 'NodeVisitor', 'NodeTransformer'] + + +def _cast_tree(seq_t, n_seq_t, type_look_up, tree): + args = seq_t, n_seq_t, type_look_up + + if isinstance(tree, seq_t): + return n_seq_t(_cast_tree(*args, c) for c in tree) + + try: + T = type_look_up[type(tree)] + except KeyError: + return tree + + kwargs = {} + for field, c in iter_fields(tree): + kwargs[field] = _cast_tree(*args, c) + + return T(**kwargs) + + +def immutable(tree: ast.AST) -> 'AST': + '''Converts a mutable ast to an immutable one''' + return _cast_tree(list, tuple, ImmutableMeta._mutable_to_immutable, tree) + +def mutable(tree: 'AST') -> ast.AST: + '''Converts an immutable ast to a mutable one''' + return _cast_tree(tuple, list, ImmutableMeta._immutable_to_mutable, tree) + +def parse(source, filename='', mode='exec') -> 'AST': + tree = ast.parse(source, filename, mode) + return immutable(tree) + +def dump(node, annotate_fields=True, include_attributes=False) -> str: + tree = mutable(node) + return ast.dump(tree) + + +# duck typing ftw +iter_fields = ast.iter_fields + +# The following is more or less copied verbatim from +# CPython/Lib/ast.py. Changes are: +# s/list/tuple/ +# +# The CPython license is very permissive so I am pretty sure this is cool. +# If it is not Guido please forgive me. +def iter_child_nodes(node): + for name, field in iter_fields(node): + if isinstance(field, AST): + yield field + elif isinstance(field, tuple): + for item in field: + if isinstance(item, AST): + yield item + +# Same note as above +def walk(node): + from collections import deque + todo = deque([node]) + while todo: + node = todo.popleft() + todo.extend(iter_child_nodes(node)) + yield node + + +# Same note as above +class NodeVisitor: + def visit(self, node): + method = 'visit_' + node.__class__.__name__ + visitor = getattr(self, method, self.generic_visit) + return visitor(node) + + def generic_visit(self, node): + for field, value in iter_fields(node): + if isinstance(value, tuple): + for item in value: + if isinstance(item, AST): + self.visit(item) + elif isinstance(value, AST): + self.visit(value) + + +# Same note as above +class NodeTransformer(NodeVisitor): + ''' + Mostly equivalent to ast.NodeTransformer, except returns new nodes + instead of mutating them in place + ''' + + def generic_visit(self, node): + kwargs = {} + for field, old_value in iter_fields(node): + if instance(old_value, tuple): + new_value = [] + for value in old_value: + if isinstance(value, AST): + value = self.visit(value) + if value is None: + continue + elif not isinstance(item, AST): + new_value.extend(value) + continue + new_value.append(value) + new_value = tuple(new_value) + elif isinstance(type(old_value), ImmutableMeta): + new_value = self.visit(old_value) + else: + new_value = old_value + kwargs[field] = new_value + return type(node)(**kwargs) diff --git a/util/generate_ast/_meta.px b/util/generate_ast/_meta.px new file mode 100644 index 0000000..5009fd4 --- /dev/null +++ b/util/generate_ast/_meta.px @@ -0,0 +1,19 @@ +__ALL__ += ['ImmutableMeta'] + +class ImmutableMeta(type): + _immutable_to_mutable = dict() + _mutable_to_immutable = dict() + def __new__(mcs, name, bases, namespace, mutable, **kwargs): + cls = super().__new__(mcs, name, bases, namespace, **kwargs) + ImmutableMeta._immutable_to_mutable[cls] = mutable + ImmutableMeta._mutable_to_immutable[mutable] = cls + + return cls + + def __instancecheck__(cls, instance): + return super().__instancecheck__(instance)\ + or isinstance(instance, ImmutableMeta._immutable_to_mutable[cls]) + + def __subclasscheck__(cls, type_): + return super().__subclasscheck__(type_)\ + or issubclass(type_, ImmutableMeta._immutable_to_mutable[cls]) diff --git a/util/generate_ast/generate.py b/util/generate_ast/generate.py new file mode 100644 index 0000000..4244aad --- /dev/null +++ b/util/generate_ast/generate.py @@ -0,0 +1,132 @@ +import ast +import datetime +import inspect +from os import path +import sys + +_BASE_PATH = path.dirname(__file__) +def _make_path(f): + return path.abspath(path.join(_BASE_PATH, f)) + +META_FILE = _make_path('_meta.px') +FUNCTIONS_FILE = _make_path('_functions.px') +AST_BASE_FILE = _make_path('_base.px') +TAB = ' '*4 + +def generate_class(name, bases, fields): + bases=', '.join(bases) + (', ' if bases else '') + sig = (', ' if fields else '') + ', '.join(fields) + body = [f'self.{arg} = {arg}' for arg in fields] + if not body: + body.append('pass') + + body = f'\n{TAB}{TAB}'.join(body) + + class_ = f'''\ +class {name}({bases}mutable=ast.{name}): +{TAB}_fields={fields} +{TAB}def __init__(self{sig}): +{TAB}{TAB}{body} +''' + + return class_ + +def generate_classes(class_tree, ALL): + cls_to_args = {ast.AST : ('AST', (), ())} + + def pop_args_from_tree(tree): + for item in tree: + if isinstance(item, list): + r = pop_args_from_tree(item) + if r is not None: + return r + elif item[0] not in cls_to_args: + cls = item[0] + bases = tuple(cls_to_args[base][0] for base in item[1] if base is not object) + cls_to_args[cls] = r = cls.__name__, bases, cls._fields + return r + + classes_ = [] + + args = pop_args_from_tree(class_tree[1]) + while args is not None: + class_ = generate_class(*args) + classes_.append(class_) + ALL.append(args[0]) + args = pop_args_from_tree(class_tree[1]) + + return '\n'.join(classes_) + + +def generate_immutable_ast(): + def _issubclass(t, types): + try: + return issubclass(t, types) + except TypeError: + pass + return False + + classes = [] + for name in dir(ast): + obj = getattr(ast, name) + if _issubclass(obj, ast.AST): + classes.append(obj) + + class_tree = inspect.getclasstree(classes) + # assert the class tree is a tree and not a dag + assert class_tree == inspect.getclasstree(classes, unique=True) + # assert the class tree has a root + assert len(class_tree) == 2, class_tree[0] + # assert the root is object + assert class_tree[0][0] is object, class_tree[0][0] + # assert the root has only 1 child + assert len(class_tree[1]) == 2, class_tree[1] + # assert that the child is ast.AST + assert class_tree[1][0][0] is ast.AST, class_tree[1][0] + + nl = '\n' + head_comment = f'''\ +# file generated by {__file__} on {datetime.datetime.now()} +# for python {sys.version.split(nl)[0].strip()}''' + + version_check = f'''\ +if sys.version_info[:2] != {sys.version_info[:2]}: +{TAB}warnings.warn(f"{{__file__}} generated for {sys.version_info[:2]}" +{TAB} f"does not match system version {{sys.version_info[:2]}}")''' + + + + ALL = ['AST'] + + with open(FUNCTIONS_FILE, 'r') as f: + functions = f.read() + + with open(META_FILE, 'r') as f: + meta = f.read() + + with open(AST_BASE_FILE, 'r') as f: + ast_base = f.read() + + classes = generate_classes(class_tree, ALL) + + immutable_ast = f'''\ +{head_comment} + +import ast +import sys +import typing as tp +import warnings + +{version_check} + +__ALL__ = {ALL} + +{functions} + +{meta} + +{ast_base} + +{classes} +''' + return immutable_ast From 2387bbd8174d54093fc2878e472e96478c112331 Mon Sep 17 00:00:00 2001 From: Caleb Donovick Date: Mon, 2 Sep 2019 17:08:47 -0700 Subject: [PATCH 7/7] Make setup work maybe --- setup.py | 47 ++++++++++++++++++++++++++++++++++++----------- 1 file changed, 36 insertions(+), 11 deletions(-) diff --git a/setup.py b/setup.py index 6b6442e..31d11f8 100644 --- a/setup.py +++ b/setup.py @@ -4,37 +4,62 @@ from setuptools import setup from setuptools.command.build_py import build_py +from setuptools.command.develop import develop from os import path import util.generate_ast.generate as generate - PACKAGE_NAME = 'ast_tools' with open('README.md', "r") as fh: LONG_DESCRIPTION = fh.read() -class BuildImmutableAst(build_py): - def run(self): - super().run() +class Install(build_py): + def run(self, *args, **kwargs): + self.generated_outputs = [] + if not self.dry_run: + src = generate.generate_immutable_ast() + output_dir = path.join(self.build_lib, PACKAGE_NAME) + self.mkpath(output_dir) + output_file = path.join(output_dir, 'immutable_ast.py') + self.announce(f'generating {output_file}', 2) + with open(output_file, 'w') as f: + f.write(src) + self.generated_outputs.append(output_file) + super().run(*args, **kwargs) + + def get_outputs(self, *args, **kwargs): + outputs = super().get_outputs(*args, **kwargs) + outputs.extend(self.generated_outputs) + return outputs + + +class Develop(develop): + def run(self, *args, **kwargs): if not self.dry_run: src = generate.generate_immutable_ast() - with open(path.join(PACKAGE_NAME, 'immutable_ast.py'), 'w') as f: + output_file = path.join(PACKAGE_NAME, 'immutable_ast.py') + self.announce(f'generating {output_file}', 2) + with open(output_file, 'w') as f: f.write(src) + super().run(*args, **kwargs) setup( - cmdclass={'build_py' : BuildImmutableAst}, + cmdclass={ + 'build_py': Install, + 'develop': Develop, + }, name='ast_tools', url='https://github.com/leonardt/ast_tools', author='Leonard Truong', author_email='lenny@cs.stanford.edu', - version='0.0.5', + version='0.0.6', description='Toolbox for working with the Python AST', scripts=[], packages=[ - f"{PACKAGE_NAME}", - f"{PACKAGE_NAME}.visitors", - f"{PACKAGE_NAME}.transformers", - f"{PACKAGE_NAME}.passes" + f'{PACKAGE_NAME}', + f'{PACKAGE_NAME}.visitors', + f'{PACKAGE_NAME}.transformers', + f'{PACKAGE_NAME}.passes' ], install_requires=['astor'], long_description=LONG_DESCRIPTION,