# Basics

In [1]:
import ast
import textwrap

Build an AST from code stored in a string

In [2]:
# Parse the source string into an AST; ast.parse returns the root node.
tree = ast.parse("print('hello world')")
# Printing the node shows only its default object repr, not the tree structure.
print(tree)

<_ast.Module object at 0x106ffb588>


Compile and execute AST

In [3]:
# Compile the AST into a code object (in "exec" mode) and run it.
exec(compile(tree, filename="<ast>", mode="exec"))

hello world


# Working with AST

Print the names of any functions defined in the given code

In [None]:
class FuncLister(ast.NodeVisitor):
    """Print the name of every function definition found in a tree.

    Both ``def`` and ``async def`` definitions are reported, including
    functions nested inside other functions or classes.
    """

    def visit_FunctionDef(self, node):
        """Print ``node.name``, then descend into the function body."""
        print(node.name)
        # Keep walking so functions nested inside this one are listed too.
        self.generic_visit(node)

    # ``async def`` is a distinct node type; report it the same way.
    visit_AsyncFunctionDef = visit_FunctionDef


# dedent() strips the common indentation so the snippet parses as top-level code.
tree = ast.parse(textwrap.dedent("""
    def foo():
        pass
    
    def bar():
        pass
"""))
# Visit the tree; FuncLister prints each function name it encounters.
FuncLister().visit(tree)

Walk over all nodes in the tree (the traversal order is not guaranteed)

In [None]:
# ast.walk yields every node in the tree, in no guaranteed order.
for node in ast.walk(tree):
    # ``async def`` has its own node type, so check for both kinds of function.
    if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
        print(node.name)

# Modifying the Tree
## Replace `foo` with `data['foo']`

In [None]:
# Sample program: ``foo`` is not defined anywhere in this snippet, so the
# print only works once the name is rewritten into a lookup in ``data``.
tree = ast.parse(textwrap.dedent("""
    data = {'foo': 'bar'}
    print(foo)
"""))


class RewriteName(ast.NodeTransformer):
    """Rewrite every occurrence of the name ``foo`` into ``data['foo']``."""

    def visit_Name(self, node):
        """Return a ``data[<name>]`` subscript for ``foo``; leave other names alone.

        The replacement keeps the original node's expression context
        (``node.ctx``) and source location.
        """
        import sys  # local import: only needed for the version check below

        if node.id != 'foo':
            return node

        # ast.Str is deprecated; ast.Constant is the supported way to build a
        # string literal node.
        key = ast.Constant(value=node.id)
        if sys.version_info < (3, 9):
            # Before Python 3.9 a simple index had to be wrapped in ast.Index.
            key = ast.Index(value=key)

        subscript = ast.Subscript(
            value=ast.Name(id='data', ctx=ast.Load()),
            slice=key,
            ctx=node.ctx,
        )
        return ast.copy_location(subscript, node)


# Apply the transformer; visit() returns the (possibly replaced) root node.
tree = RewriteName().visit(tree)
# Fill in lineno/col_offset on the nodes the transformer created.
ast.fix_missing_locations(tree)
exec(compile(tree, filename="<ast>", mode='exec'))

## Wrapping Integers

In [None]:
from fractions import Fraction


class IntegerWrapper(ast.NodeTransformer):
    """Wrap every integer literal in a call to ``Integer()``.

    ``Integer`` is referenced by name, so it must be resolvable in the
    namespace where the transformed code is executed.
    """

    def visit_Constant(self, node):
        """Handle literals on Python 3.8+, where they parse as ``ast.Constant``."""
        return self._wrap_if_int(node, node.value)

    def visit_Num(self, node):
        """Handle numeric literals on Python < 3.8, which parse as ``ast.Num``."""
        return self._wrap_if_int(node, node.n)

    @staticmethod
    def _wrap_if_int(node, value):
        """Return an ``Integer(<node>)`` call if ``value`` is an int, else ``node``."""
        # bool is a subclass of int, but True/False are not integer literals.
        if isinstance(value, bool) or not isinstance(value, int):
            return node
        call = ast.Call(func=ast.Name(id='Integer', ctx=ast.Load()),
                        args=[node], keywords=[])
        # Give the wrapper the literal's position so tracebacks stay accurate.
        return ast.copy_location(call, node)


class Integer:
    """Integer wrapper whose true division yields an exact ``Fraction``."""

    def __init__(self, value):
        self.value = value

    def __repr__(self):
        return "Integer({!r})".format(self.value)

    def __truediv__(self, other):
        """Return ``self / other`` as a ``Fraction`` when ``other`` is an ``Integer``."""
        if isinstance(other, Integer):
            return Fraction(numerator=self.value, denominator=other.value)
        # Signal an unsupported operand so Python raises TypeError instead of
        # the division silently evaluating to None.
        return NotImplemented


# Demo expression: plain Python evaluates these divisions with floats.
code = "print((1/10)+(2/10))"
print(code)
print()

# Baseline: run the source string unchanged.
print("Without AST transformation:")
exec(code)
print()

# Same source, but every integer literal is first wrapped in Integer(), so
# the divisions go through Integer.__truediv__ and produce Fractions.
print("With AST transformation:")
tree = ast.parse(code)
tree = IntegerWrapper().visit(tree)
# Add lineno & col_offset to the nodes we created
ast.fix_missing_locations(tree)
exec(compile(tree, "<ast>", "exec"))

## Test Runner

In [None]:
"""This will run asserts.py, but keep going if an assertion fails.

It also transforms assertions of the form a==b into a function call, which can
display more info if the assertion fails.
"""
import ast

# Read the whole test script up front; the text is kept in ``code`` for the
# rest of the cell.
filename = "asserts.py"
with open(filename, encoding='utf-8') as f:
    code = f.read()


class AssertCmpTransformer(ast.NodeTransformer):
    """Rewrite ``assert a == b`` statements as ``assert_equal(a, b)`` calls.

    Every other kind of assertion is passed through untouched.
    """

    def visit_Assert(self, node):
        test = node.test
        # Only a lone ``==`` comparison qualifies; non-comparison tests,
        # chained comparisons and other operators are returned as-is.
        if not isinstance(test, ast.Compare):
            return node
        if len(test.ops) != 1:
            return node
        if not isinstance(test.ops[0], ast.Eq):
            return node

        helper = ast.Name(id='assert_equal', ctx=ast.Load())
        operands = [test.left, test.comparators[0]]
        # The call's result is discarded, so it has to live in an Expr statement.
        replacement = ast.Expr(
            value=ast.Call(func=helper, args=operands, keywords=[])
        )
        ast.copy_location(replacement, node)
        ast.fix_missing_locations(replacement)
        return replacement


def assert_equal(a, b):
    """Raise AssertionError showing both reprs when ``a != b`` holds."""
    if a != b:
        message = "{!r} != {!r}".format(a, b)
        raise AssertionError(message)


tree = ast.parse(code)
lines = [None] + code.splitlines()  # None at [0] so we can index lines from 1
# The rewritten assertions call assert_equal by name, so expose it to the tests.
test_namespace = {'assert_equal': assert_equal}

tree = AssertCmpTransformer().visit(tree)

# Run each top-level statement on its own so that one failing assertion does
# not stop the statements after it from running.
for node in tree.body:
    # ast.Module has a ``type_ignores`` field since Python 3.8; compile()
    # rejects a hand-built Module that lacks it on 3.8-3.12.
    wrapper = ast.Module(body=[node], type_ignores=[])
    try:
        co = compile(wrapper, filename, 'exec')
        exec(co, test_namespace)
    except AssertionError as e:
        print("Assertion failed on line", node.lineno, ":")
        print(lines[node.lineno])
        # If the error has a message, show it.
        if e.args:
            print(e)
        print()