# Tree

如何遍历（访问）以及修改 AST（抽象语法树）中的节点。

## NodeVisitor

In [1]:
from ast import *

In [2]:
# Sample program: a module-level constant, a function using it, and a call.
source = (
    'PI = 3.1415\n'
    'def area(radius):\n'
    '    return PI * radius ** 2\n'
    '    \n'
    'print(area(10))\n'
)

# Parse the text into a syntax tree. The name `ast` is reused by later cells.
ast = parse(source)

In [3]:
class MyVisitor(NodeVisitor):
    """Print each function definition and each name found while walking a tree."""

    def _report(self, label, value, node):
        """Print ``label`` and ``value``, then continue into ``node``'s children."""
        print(label, value)
        # Child nodes are only reached when generic_visit is called explicitly.
        self.generic_visit(node)

    # A method named visit_<NodeType> is called by the visitor whenever it
    # reaches a node of that type.
    def visit_FunctionDef(self, node):
        """Report the name of a function definition."""
        self._report('Def function: ', node.name, node)

    def visit_Name(self, node):
        """Report the identifier of a Name node."""
        self._report('Name:', node.id, node)
        
# Walk the parsed tree; the visitor prints one line per Name / FunctionDef node.
MyVisitor().visit(ast)

# Compile the tree to a code object and run it.
exec(compile(source=ast, filename='', mode='exec'))

Name: PI
Def function:  area
Name: PI
Name: radius
Name: print
Name: area
314.15000000000003


## NodeTransformer

NodeTransformer 是 NodeVisitor 的子类，用于转换节点。

In [4]:
class PrintTransformer(NodeTransformer):
    """Add a default prefix argument to every call of the plain name ``print``.

    The nodes created here carry no location information, so run
    ``fix_missing_locations`` on the tree before compiling it.
    """

    def visit_Call(self, node):
        """Return the call with the prefix prepended when it targets ``print``.

        Any other call is returned as-is, after the calls nested inside it
        have been processed.
        """
        # Descend first so that print calls nested inside this call's
        # arguments (e.g. ``len([print(x)])``) are rewritten as well.
        self.generic_visit(node)

        if isinstance(node.func, Name) and node.func.id == 'print':
            # Constant replaces the deprecated Str node (removed in Python 3.12).
            prefix = Constant(value='You see me? -->')
            return Call(
                func=node.func,
                args=[prefix] + node.args,
                keywords=node.keywords,
            )
        return node
    
# Rewrite the print calls in the tree.
PrintTransformer().visit(ast)
# The transformed nodes lack required attributes such as lineno.
fix_missing_locations(ast)

# Compile the rewritten tree and run it.
exec(compile(source=ast, filename='', mode='exec'))

You see me? --> 314.15000000000003


我们来实现类似 Rust 的隐式返回：函数体的最后一个表达式无需 return 关键字，即作为函数的返回值。

In [5]:
class ReturnTransformer(PrintTransformer):
    """Make functions implicitly return their trailing expression statement.

    Subclasses PrintTransformer, so ``print`` calls are rewritten as well.
    """

    def visit_FunctionDef(self, node):
        """Turn a final expression statement in ``node``'s body into a ``return``.

        Functions whose last statement is not a bare expression are returned
        unchanged (apart from the transformation of their children).
        """
        # Process the body first (nested definitions, print calls).
        self.generic_visit(node)

        last = node.body[-1]
        if isinstance(last, Expr):
            # Swap the statement in place instead of rebuilding the FunctionDef:
            # a hand-built copy drops every field that is not listed explicitly
            # (type_comment, and type_params on Python 3.12+).
            node.body[-1] = copy_location(Return(value=last.value), last)
        return node
    
# Sample program: `add` ends with a bare expression instead of a return.
def_source = (
    "def add(x, y):\n"
    "    print('x is', x)\n"
    "    print('y is', y)\n"
    "    x + y\n"
    "\n"
    "z = add(1, 2)\n"
    "print('z is', z)\n"
)

# Parse the sample into its own tree.
def_ast = parse(def_source)
# Apply the implicit-return rewrite (and the inherited print rewrite).
ReturnTransformer().visit(def_ast)
# Fill in location attributes on the newly created nodes.
fix_missing_locations(def_ast)
# Compile the rewritten tree and run it.
exec(compile(source=def_ast, filename='', mode='exec'))

You see me? --> x is 1
You see me? --> y is 2
You see me? --> z is 3
