# Visitors

The following example shows how to navigate an AST in Python.

In [9]:
import ast
your_code = """
def foo():  # type: ignore
    ret = 2 * 2
    return ret
"""
your_ast = ast.parse(your_code)
# print(your_ast)
# print(your_ast.type_ignores)
print(your_ast.body[0].body[1].value.id)
# print(your_ast.body[0].name)

ret


You can refer to the official document (https://docs.python.org/3/library/ast.html) to figure out available attributes for each AST node type. For example,

```
FunctionDef(identifier name, arguments args,
                       stmt* body, expr* decorator_list, expr? returns,
                       string? type_comment, type_param* type_params)
```

However, it is too tedious to explore AST nodes and their attributes one-by-one. Thus, it is necessary to come up with a better navigation method. Fortunately, `ast` package provides visitors to traverse a given AST. You can find an example below:

In [10]:
from ast import NodeVisitor, FunctionDef

class YourVisitor(NodeVisitor):
    def __init__(self) -> None:
        super().__init__()
    
    def visit_FunctionDef(self, node: FunctionDef):
        print("FunctionDef:", node.name)
            

In [22]:
avisitor = YourVisitor()
avisitor.visit(your_ast)

FunctionDef: foo


Let's try a longer code.

In [23]:
long_code = """
import pickle

import six


class MetaNode(type):
    def __new__(mcs, name, bases, dict):
        attrs = list(dict['attrs'])
        dict['attrs'] = list()

        for base in bases:
            if hasattr(base, 'attrs'):
                dict['attrs'].extend(base.attrs)

        dict['attrs'].extend(attrs)

        return type.__new__(mcs, name, bases, dict)


@six.add_metaclass(MetaNode)
class Node(object):
    attrs = ()

    def __init__(self, **kwargs):
        values = kwargs.copy()

        for attr_name in self.attrs:
            value = values.pop(attr_name, None)
            setattr(self, attr_name, value)

        if values:
            raise ValueError('Extraneous arguments')

    def __equals__(self, other):
        if type(other) is not type(self):
            return False

        for attr in self.attrs:
            if getattr(other, attr) != getattr(self, attr):
                return False

        return True

    def __repr__(self):
        attr_values = []
        for attr in sorted(self.attrs):
            attr_values.append('%s=%s' % (attr, getattr(self, attr)))
        return '%s(%s)' % (type(self).__name__, ', '.join(attr_values))

    def __iter__(self):
        return walk_tree(self)

    def filter(self, pattern):
        for path, node in self:
            if ((isinstance(pattern, type) and isinstance(node, pattern)) or
                (node == pattern)):
                yield path, node

    @property
    def children(self):
        return [getattr(self, attr_name) for attr_name in self.attrs]
    
    @property
    def position(self):
        if hasattr(self, "_position"):
            return self._position

def walk_tree(root):
    children = None

    if isinstance(root, Node):
        yield (), root
        children = root.children
    else:
        children = root

    for child in children:
        if isinstance(child, (Node, list, tuple)):
            for path, node in walk_tree(child):
                yield (root,) + path, node

def dump(ast, file):
    pickle.dump(ast, file)

def load(file):
    return pickle.load(file)
"""

In [24]:
long_ast = ast.parse(long_code)
# ast.dump(long_ast)

In [25]:
avisitor.visit(long_ast)

FunctionDef: __new__
FunctionDef: __init__
FunctionDef: __equals__
FunctionDef: __repr__
FunctionDef: __iter__
FunctionDef: filter
FunctionDef: children
FunctionDef: position
FunctionDef: walk_tree
FunctionDef: dump
FunctionDef: load


In [26]:
from ast import Name, Call

class YourVisitor(YourVisitor):
    def visit_Name(self, node: Name):
        print("Name reference:", node.id)
    
    def visit_Call(self, node: Call):
        print("Func Invocation:", node.func, node.args)

In [27]:
bvisitor = YourVisitor()
bvisitor.visit(long_ast)

Name reference: type
FunctionDef: __new__
Name reference: object
Name reference: attrs
FunctionDef: __init__
FunctionDef: __equals__
FunctionDef: __repr__
FunctionDef: __iter__
FunctionDef: filter
FunctionDef: children
FunctionDef: position
Func Invocation: <ast.Attribute object at 0x109726680> [<ast.Name object at 0x109726800>]
FunctionDef: walk_tree
FunctionDef: dump
FunctionDef: load


In [28]:
class YourVisitor(YourVisitor):
    def __init__(self) -> None:
        super().__init__()
        
    def visit_FunctionDef(self, node: FunctionDef):
        print("FunctionDef:", node.name)
        return self.generic_visit(node)
    
    def visit_Name(self, node: Name):
        print("Name reference:", node.id)
        return self.generic_visit(node)
    
    def visit_Call(self, node: Call):
        print("Func Invocation:", node.func, node.args)
        return self.generic_visit(node)

In [29]:
bvisitor = YourVisitor()
bvisitor.visit(long_ast)

Name reference: type
FunctionDef: __new__
Name reference: attrs
Func Invocation: <ast.Name object at 0x10973c2e0> [<ast.Subscript object at 0x10973c2b0>]
Name reference: list
Name reference: dict
Name reference: dict
Func Invocation: <ast.Name object at 0x10973c100> []
Name reference: list
Name reference: base
Name reference: bases
Func Invocation: <ast.Name object at 0x10973cf70> [<ast.Name object at 0x10973cfa0>, <ast.Constant object at 0x10973cfd0>]
Name reference: hasattr
Name reference: base
Func Invocation: <ast.Attribute object at 0x10973c3d0> [<ast.Attribute object at 0x10973c370>]
Name reference: dict
Name reference: base
Func Invocation: <ast.Attribute object at 0x10973c610> [<ast.Name object at 0x10973c8b0>]
Name reference: dict
Name reference: attrs
Func Invocation: <ast.Attribute object at 0x10973c670> [<ast.Name object at 0x10973c700>, <ast.Name object at 0x10973c7c0>, <ast.Name object at 0x10973c6d0>, <ast.Name object at 0x10973ca90>]
Name reference: type
Name reference:

In [34]:
bvisitor.visit(long_ast.body[2])
# long_ast.body[2]

Name reference: type
FunctionDef: __new__
Name reference: attrs
Func Invocation: <ast.Name object at 0x10973c2e0> [<ast.Subscript object at 0x10973c2b0>]
Name reference: list
Name reference: dict
Name reference: dict
Func Invocation: <ast.Name object at 0x10973c100> []
Name reference: list
Name reference: base
Name reference: bases
Func Invocation: <ast.Name object at 0x10973cf70> [<ast.Name object at 0x10973cfa0>, <ast.Constant object at 0x10973cfd0>]
Name reference: hasattr
Name reference: base
Func Invocation: <ast.Attribute object at 0x10973c3d0> [<ast.Attribute object at 0x10973c370>]
Name reference: dict
Name reference: base
Func Invocation: <ast.Attribute object at 0x10973c610> [<ast.Name object at 0x10973c8b0>]
Name reference: dict
Name reference: attrs
Func Invocation: <ast.Attribute object at 0x10973c670> [<ast.Name object at 0x10973c700>, <ast.Name object at 0x10973c7c0>, <ast.Name object at 0x10973c6d0>, <ast.Name object at 0x10973ca90>]
Name reference: type
Name reference:

In [35]:
class YourVisitor(YourVisitor):
    def __init__(self) -> None:
        super().__init__()
        self.list_of_functions = []
        
    def visit_FunctionDef(self, node: FunctionDef):
        self.list_of_functions.append(node)
        return self.generic_visit(node)

In [36]:
cvisitor = YourVisitor()
cvisitor.visit(long_ast)

Name reference: type
Name reference: attrs
Func Invocation: <ast.Name object at 0x10973c2e0> [<ast.Subscript object at 0x10973c2b0>]
Name reference: list
Name reference: dict
Name reference: dict
Func Invocation: <ast.Name object at 0x10973c100> []
Name reference: list
Name reference: base
Name reference: bases
Func Invocation: <ast.Name object at 0x10973cf70> [<ast.Name object at 0x10973cfa0>, <ast.Constant object at 0x10973cfd0>]
Name reference: hasattr
Name reference: base
Func Invocation: <ast.Attribute object at 0x10973c3d0> [<ast.Attribute object at 0x10973c370>]
Name reference: dict
Name reference: base
Func Invocation: <ast.Attribute object at 0x10973c610> [<ast.Name object at 0x10973c8b0>]
Name reference: dict
Name reference: attrs
Func Invocation: <ast.Attribute object at 0x10973c670> [<ast.Name object at 0x10973c700>, <ast.Name object at 0x10973c7c0>, <ast.Name object at 0x10973c6d0>, <ast.Name object at 0x10973ca90>]
Name reference: type
Name reference: mcs
Name reference: 

In [37]:
print(cvisitor.list_of_functions)

[<ast.FunctionDef object at 0x10973c790>, <ast.FunctionDef object at 0x10973c8e0>, <ast.FunctionDef object at 0x109727e50>, <ast.FunctionDef object at 0x1097273d0>, <ast.FunctionDef object at 0x109726b60>, <ast.FunctionDef object at 0x109726d40>, <ast.FunctionDef object at 0x109726140>, <ast.FunctionDef object at 0x1097265c0>, <ast.FunctionDef object at 0x1097267d0>, <ast.FunctionDef object at 0x1096d4490>, <ast.FunctionDef object at 0x1096d42b0>]


In [41]:
cvisitor.visit(cvisitor.list_of_functions[3])

Name reference: attr_values
Name reference: attr
Func Invocation: <ast.Name object at 0x109727640> [<ast.Attribute object at 0x1097276a0>]
Name reference: sorted
Name reference: self
Func Invocation: <ast.Attribute object at 0x1097272b0> [<ast.BinOp object at 0x109727220>]
Name reference: attr_values
Name reference: attr
Func Invocation: <ast.Name object at 0x109727190> [<ast.Name object at 0x109727310>, <ast.Name object at 0x1097277c0>]
Name reference: getattr
Name reference: self
Name reference: attr
Func Invocation: <ast.Name object at 0x109727100> [<ast.Name object at 0x1097270d0>]
Name reference: type
Name reference: self
Func Invocation: <ast.Attribute object at 0x109726f20> [<ast.Name object at 0x109726b90>]
Name reference: attr_values
