# Working on the tree

In [1]:
%reload_ext autoreload
%autoreload 2

import sys
from pathlib import Path

my_happy_flow_path = str(Path('../../src').resolve())
my_lib_path = str(Path('my_lib').resolve())

if my_lib_path not in sys.path:
    sys.path.append(my_lib_path)


if my_happy_flow_path not in sys.path:
    sys.path.append(my_happy_flow_path)
    


## ast.NodeVisitor

ast.NodeVisitor is the primary tool for ‘scanning’ the tree. 

In [2]:
import ast
import inspect

print(inspect.getsource(ast.NodeVisitor))

class NodeVisitor(object):
    """
    A node visitor base class that walks the abstract syntax tree and calls a
    visitor function for every node found.  This function may return a value
    which is forwarded by the `visit` method.

    This class is meant to be subclassed, with the subclass adding visitor
    methods.

    Per default the visitor functions for the nodes are ``'visit_'`` +
    class name of the node.  So a `TryFinally` node visit function would
    be `visit_TryFinally`.  This behavior can be changed by overriding
    the `visit` method.  If no visitor function exists for a node
    (return value `None`) the `generic_visit` visitor is used instead.

    Don't use the `NodeVisitor` if you want to apply changes to nodes during
    traversing.  For this a special visitor exists (`NodeTransformer`) that
    allows modifications.
    """

    def visit(self, node):
        """Visit a node."""
        method = 'visit_' + node.__class__.__name__
        visitor = getattr(

To use it, subclass it and override the visit_Foo methods that correspond to the node classes you are interested in (see [Meet the Nodes](https://greentreesnakes.readthedocs.io/en/latest/nodes.html)).

For example, this visitor will print the names of any functions defined in the given code, including methods and functions defined within other functions:

In [3]:
import ast
class FuncLister(ast.NodeVisitor):
    def visit_FunctionDef(self, node):
        print('func_name: ', node.name)
        self.generic_visit(node)

In [4]:
source_code = """
def a():
    print('i am a')
    
def b():
    print('call a')
    a()
    def c():
        print('i am c function')
    
b()
""".strip()
FuncLister().visit(ast.parse(source_code))

func_name:  a
func_name:  b
func_name:  c


An overridden visit_Foo method does not descend into the node's children on its own. If you want child nodes to be visited, remember to call self.generic_visit(node) in the methods you override.
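
To make the effect of generic_visit concrete, here is a minimal sketch that runs two visitors over the same code. The class names ShallowFuncLister and DeepFuncLister are hypothetical and exist only for this illustration; they collect names in a list instead of printing them.

```python
import ast

source_code = """
def a():
    pass

def b():
    def c():
        pass
"""


class ShallowFuncLister(ast.NodeVisitor):
    """Hypothetical visitor that never descends into function bodies."""

    def __init__(self):
        self.names = []

    def visit_FunctionDef(self, node):
        self.names.append(node.name)
        # No generic_visit() call: children of this FunctionDef are skipped.


class DeepFuncLister(ast.NodeVisitor):
    """Hypothetical visitor that keeps walking into nested definitions."""

    def __init__(self):
        self.names = []

    def visit_FunctionDef(self, node):
        self.names.append(node.name)
        self.generic_visit(node)  # continue into the function body


tree = ast.parse(source_code)

shallow = ShallowFuncLister()
shallow.visit(tree)

deep = DeepFuncLister()
deep.visit(tree)

print(shallow.names)  # ['a', 'b']
print(deep.names)     # ['a', 'b', 'c']
```

The nested function c is only found by the visitor that calls generic_visit.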

Alternatively, you can run through a list of all the nodes in the tree using ast.walk(). There are no guarantees about the order in which nodes will appear. The following example again prints the names of any functions defined within the given code:

In [5]:
import ast

source_code = """
def a():
    print('i am a')
    
def b():
    print('call a')
    a()
    def c():
        print('i am c function')
    
b()
""".strip()

tree = ast.parse(source_code)
for node in ast.walk(tree):
    if isinstance(node, ast.FunctionDef):
        print('func_name: ', node.name)

func_name:  a
func_name:  b
func_name:  c


You can also get the direct children of a node, using ast.iter_child_nodes(). Remember that many nodes have children in several fields: for example, an If has a single node in the test field, and lists of nodes in body and orelse. ast.iter_child_nodes() will go through all of these.
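
As a small sketch of this, the snippet below lists the direct children of an If node. The sample source and variable names are made up for illustration.

```python
import ast

tree = ast.parse("""
if x > 0:
    y = 1
else:
    y = 2
""")

if_node = tree.body[0]  # the ast.If statement

# Direct children only: the test expression, then the statements in
# body and orelse. Grandchildren (such as Name nodes) are not included.
child_types = [type(child).__name__ for child in ast.iter_child_nodes(if_node)]
print(child_types)  # ['Compare', 'Assign', 'Assign']
```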

Finally, you can navigate directly, using the attributes of the nodes. For example, if you want to get the last node within a function’s body, use node.body\[-1\]. Of course, all the normal Python tools for iterating and indexing work. In particular, isinstance() is very useful for checking what nodes are.
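
A short sketch of direct navigation, assuming a made-up function named area as the input:

```python
import ast

tree = ast.parse("""
def area(w, h):
    result = w * h
    return result
""")

func = tree.body[0]        # the FunctionDef node
last_stmt = func.body[-1]  # last statement in the function body

print(type(last_stmt).__name__)  # Return

# isinstance() checks guard each attribute access along the way.
returned_name = None
if isinstance(last_stmt, ast.Return) and isinstance(last_stmt.value, ast.Name):
    returned_name = last_stmt.value.id

print(returned_name)  # result
```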


## Inspecting nodes

The ast module has a few functions for inspecting nodes:

$\large{🍞}$ ast.iter_fields() iterates over the fields defined for a node.

$\large{🍞}$ ast.get_docstring() gets the docstring of a FunctionDef, ClassDef or Module node.

$\large{🍞}$  ast.dump() returns a string showing the node and any children. 
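
The three helpers can be tried together in a few lines. This is only a sketch on a made-up function called greet. The exact set of fields on a FunctionDef varies between Python versions, so the field names are printed without an expected output.

```python
import ast

tree = ast.parse('''
def greet(name):
    """Return a greeting."""
    return "hi " + name
''')

func = tree.body[0]

# iter_fields() yields (field_name, value) pairs for one node.
field_names = [name for name, value in ast.iter_fields(func)]
print(field_names)

# get_docstring() returns the cleaned-up docstring, or None if there is none.
docstring = ast.get_docstring(func)
print(docstring)  # Return a greeting.

# dump() returns a string representation of a node and its children.
dumped = ast.dump(ast.parse("x", mode="eval"))
print(dumped)
```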


## Modifying the tree

The key tool is ast.NodeTransformer. Like ast.NodeVisitor, you subclass this and override visit_Foo methods. The method should return the original node, a replacement node, or None to remove that node from the tree.

The ast module docs have this example, which rewrites name lookups, so foo becomes data\['foo'\]:


In [6]:
import ast

class RewriteName(ast.NodeTransformer):
    def visit_Name(self, node):
        return ast.copy_location(ast.Subscript(
            value=ast.Name(id='data', ctx=ast.Load()),
            slice=ast.Index(value=ast.Str(s=node.id)),
            ctx=node.ctx
        ), node)
    

In [10]:
from __future__ import unicode_literals

import ast
import ast_utils
from ast_utils import print_utils
import ujson

tree = ast.parse("foo")
print(ast_utils.dump(tree))

print_utils.print_separator()

print(ast_utils.dump_json(tree))

new_tree = ast.fix_missing_locations(RewriteName().visit(tree))

print_utils.print_separator()

print(ast_utils.unparse(new_tree))

print_utils.print_separator()

print(ast_utils.dump_json(new_tree))

Module(body=[Expr(value=Name(
  id='foo',
  ctx=Load()))])


{
 "_PyType": "Module",
 "body": [
  {
   "_PyType": "Expr",
   "value": {
    "_PyType": "Name",
    "id": "foo",
    "ctx": {
     "_PyType": "Load"
    }
   }
  }
 ]
}



data['foo']



{
 "_PyType": "Module",
 "body": [
  {
   "_PyType": "Expr",
   "value": {
    "_PyType": "Subscript",
    "value": {
     "_PyType": "Name",
     "id": "data",
     "ctx": {
      "_PyType": "Load"
     }
    },
    "slice": {
     "_PyType": "Index",
     "value": {
      "_PyType": "Str",
      "s": "foo"
     }
    },
    "ctx": {
     "_PyType": "Load"
    }
   }
  }
 ]
}
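
The RewriteName example above builds the subscript with ast.Index and ast.Str, which matches the dumped output shown here. Both classes are deprecated in newer Python versions, and ast.Str was removed in Python 3.12. The following sketch shows one way the same transformer might look on Python 3.9 or later, using ast.Constant directly as the slice and the standard library's ast.unparse instead of the ast_utils helper. The class name RewriteNameModern and the sample expression are assumptions made for this illustration.

```python
import ast


class RewriteNameModern(ast.NodeTransformer):
    """Hypothetical Python 3.9+ variant: rewrites foo to data['foo']."""

    def visit_Name(self, node):
        return ast.copy_location(
            ast.Subscript(
                value=ast.Name(id='data', ctx=ast.Load()),
                slice=ast.Constant(value=node.id),  # no ast.Index wrapper
                ctx=node.ctx,
            ),
            node,
        )


# Parse as an expression so the result can be passed to eval().
tree = ast.parse("foo + 1", mode="eval")
new_tree = ast.fix_missing_locations(RewriteNameModern().visit(tree))

rewritten = ast.unparse(new_tree)  # ast.unparse requires Python 3.9+
print(rewritten)  # data['foo'] + 1

# The rewritten tree compiles and looks names up in the data dict.
code = compile(new_tree, "<ast>", "eval")
result = eval(code, {"data": {"foo": 41}})
print(result)  # 42
```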


Be careful when removing nodes. You can quite easily remove a node from a required field, such as the test field of an If node. Python won’t complain about the invalid AST until you try to compile() it, at which point a TypeError is raised.
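
The difference between a safe and an unsafe removal can be sketched as follows. The transformer names DropAsserts and DropNames are hypothetical. The first removes statements from a list field (a body), which leaves a valid tree here because other statements remain. The second removes the node held in the required test field of an If.

```python
import ast


class DropAsserts(ast.NodeTransformer):
    """Hypothetical transformer: removes assert statements from bodies."""

    def visit_Assert(self, node):
        return None  # dropped from the list field that contained it


class DropNames(ast.NodeTransformer):
    """Hypothetical transformer: removes every Name node (unsafe)."""

    def visit_Name(self, node):
        return None


# Safe case: the module body still contains two assignments afterwards.
safe_tree = DropAsserts().visit(ast.parse("x = 1\nassert x == 1\ny = x + 1"))
ast.fix_missing_locations(safe_tree)
compile(safe_tree, "<ast>", "exec")  # compiles without complaint
remaining = [type(stmt).__name__ for stmt in safe_tree.body]
print(remaining)  # ['Assign', 'Assign']

# Unsafe case: the If node loses its required test field.
broken_tree = DropNames().visit(ast.parse("if flag:\n    pass"))
error_name = None
try:
    compile(broken_tree, "<ast>", "exec")
except TypeError as exc:
    error_name = type(exc).__name__
    print("compile() rejected the tree:", exc)
```

Note that the transformer itself runs without error in both cases; the problem only surfaces at compile() time.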

