# Basics

In [1]:
import ast
import textwrap

Build an AST from source code stored in a string

In [2]:
# mode="exec" (the default) parses a whole module, so the result is an ast.Module.
tree = ast.parse("print('hello world')", mode="exec")
print(tree)

<_ast.Module object at 0x106ffb588>


Compile the AST into a code object and execute it

In [3]:
# compile() turns the Module node into a code object; exec() then runs it.
hello_code = compile(tree, "<ast>", "exec")
exec(hello_code)

hello world


# Working with AST

Print the names of all functions defined in the given code, using an `ast.NodeVisitor` subclass

In [4]:
class FuncLister(ast.NodeVisitor):
    """Print the name of every function defined in a tree.

    Both ``def`` and ``async def`` definitions are reported, depth-first in
    the order they appear. Nested definitions are reported too, because each
    handler keeps descending into the function body.
    """

    def visit_FunctionDef(self, node):
        print(node.name)
        # Keep descending so functions nested inside this one are listed too.
        self.generic_visit(node)

    # ``async def`` parses to a distinct AsyncFunctionDef node; without this
    # alias coroutine functions would be silently skipped.
    visit_AsyncFunctionDef = visit_FunctionDef


funcs_source = textwrap.dedent("""
    def foo():
        pass
    
    def bar():
        pass
""")
tree = ast.parse(funcs_source)
lister = FuncLister()
lister.visit(tree)

foo
bar


Walk over all nodes with `ast.walk` (nodes are yielded in no specified order, so don't rely on source order)

In [7]:
def main(module=None):
    """Print the name of every function definition found by ``ast.walk``.

    Args:
        module: AST to search. When omitted, the notebook-global ``tree``
            (parsed in an earlier cell) is used, so ``main()`` keeps working.

    ``ast.walk`` yields nodes in no specified order, so names may not be
    printed in source order.
    """
    if module is None:
        module = tree
    for node in ast.walk(module):
        # Include ``async def`` functions, which parse to AsyncFunctionDef.
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            print(node.name)


# Reuses the global `tree` parsed in the FuncLister cell (defines foo and bar).
main()

foo
bar


# Modifying the Tree
## Replace `foo` with `data['foo']`

In [29]:
rewrite_source = textwrap.dedent("""
    data = {'foo': 'bar'}
    print(foo)
""")
tree = ast.parse(rewrite_source)


class RewriteName(ast.NodeTransformer):
    """Rewrite every bare name ``foo`` into the subscript ``data['foo']``.

    The original expression context (Load/Store/Del) is kept, so an
    assignment ``foo = 1`` would become ``data['foo'] = 1``.

    Only the outer Subscript receives a source position. Run
    ``ast.fix_missing_locations`` on the result before compiling it.
    Requires Python 3.9+, where ``Subscript.slice`` takes the key
    expression directly.
    """

    def visit_Name(self, node):
        if node.id == 'foo':
            return ast.copy_location(ast.Subscript(
                value=ast.Name(id='data', ctx=ast.Load()),
                # Plain Constant key: ast.Index (deprecated since 3.9) and
                # ast.Str (deprecated since 3.8) are legacy wrappers slated
                # for removal.
                slice=ast.Constant(value=node.id),
                ctx=node.ctx
            ), node)
        return node


# fix_missing_locations() returns the same tree after filling in lineno/col_offset
# for the inner nodes RewriteName created, which compile() requires.
tree = ast.fix_missing_locations(RewriteName().visit(tree))
exec(compile(tree, filename="<ast>", mode='exec'))

bar


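## Wrapping Integers

Wrap every integer literal in an `Integer(...)` call so that dividing two integers produces an exact `Fraction` instead of a floating-point approximation.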

In [24]:
from fractions import Fraction


class IntegerWrapper(ast.NodeTransformer):
    """Wraps all integer literals in a call to Integer().

    A literal such as ``10`` becomes ``Integer(10)``. Floats, complex
    numbers, strings and ``True``/``False`` are left untouched. ``Integer``
    is looked up by name when the transformed code runs, so it must be
    defined in the namespace the code is executed in.
    """

    def visit_Constant(self, node):
        # On Python 3.8+ every literal parses to ast.Constant. The old
        # visit_Num hook is only reached via a deprecated compatibility shim.
        # ``type(...) is int`` (not isinstance) keeps bools, an int subclass,
        # from being wrapped.
        if type(node.value) is int:
            call = ast.Call(func=ast.Name(id='Integer', ctx=ast.Load()),
                            args=[node], keywords=[])
            # Give the Call the literal's position. The inner Name still
            # needs ast.fix_missing_locations() before compiling.
            return ast.copy_location(call, node)
        return node


class Integer:
    """Integer wrapper whose true division yields an exact Fraction.

    ``Integer(1) / Integer(10)`` evaluates to ``Fraction(1, 10)`` rather than
    the float ``0.1``.
    """

    def __init__(self, value):
        self.value = value

    def __truediv__(self, other):
        if isinstance(other, Integer):
            # Fraction raises ZeroDivisionError for a zero denominator.
            return Fraction(numerator=self.value, denominator=other.value)
        # Returning NotImplemented (instead of falling through to None) lets
        # Python try other.__rtruediv__ and otherwise raise TypeError.
        return NotImplemented


code = "print((1/10)+(2/10))"
print(code, end="\n\n")  # the source, followed by a blank line

print("Without AST transformation:")
exec(code)  # plain float division
print()

print("With AST transformation:")
# Wrap the int literals in Integer(...), then let fix_missing_locations add the
# lineno & col_offset that compile() needs on the newly created nodes.
tree = ast.fix_missing_locations(IntegerWrapper().visit(ast.parse(code)))
exec(compile(tree, "<ast>", "exec"))


print((1/10)+(2/10))

Without AST transformation:
0.30000000000000004

With AST transformation:
3/10
