# Chapter 37: Node Visitors and Transformers

This notebook covers the visitor pattern for traversing and modifying Abstract Syntax Trees. You will learn how to use `ast.NodeVisitor` to analyze code, `ast.NodeTransformer` to modify trees, and `ast.walk` for simple iteration over all nodes.

## Key Concepts
- **`ast.NodeVisitor`**: Base class for read-only AST traversal using `visit_*` methods
- **`ast.NodeTransformer`**: Base class for modifying, replacing, or removing AST nodes; the tree is updated in place
- **`ast.walk`**: Simple iterator that yields every node in the tree
- **`ast.fix_missing_locations`**: Fills in missing line and column numbers after tree modification
- **`generic_visit`**: Continues traversal to child nodes

## Section 1: NodeVisitor Basics

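`ast.NodeVisitor` provides a `visit` method that dispatches to `visit_<NodeType>` methods. For example, `visit_Name` is called for every `Name` node. If no matching method is defined, `visit` falls back to `generic_visit`, which visits the node's children. Inside your own `visit_*` methods, call `self.generic_visit(node)` to continue into child nodes; leave the call out to skip that node's subtree.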
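
To make the dispatch concrete, here is a rough, simplified re-implementation. `MiniVisitor` and `PrintNames` are hypothetical names used only for illustration. The real `ast.NodeVisitor` follows the same lookup idea (find `visit_<ClassName>`, otherwise fall back to `generic_visit`), although its `generic_visit` iterates over the node's fields directly instead of calling `ast.iter_child_nodes`.

```python
import ast


class MiniVisitor:
    """Hypothetical, simplified version of NodeVisitor's dispatch logic."""

    def visit(self, node: ast.AST) -> None:
        # Look up visit_<ClassName>; fall back to generic_visit if it is not defined
        method = getattr(self, "visit_" + type(node).__name__, self.generic_visit)
        method(node)

    def generic_visit(self, node: ast.AST) -> None:
        # Recurse into every direct child node
        for child in ast.iter_child_nodes(node):
            self.visit(child)


class PrintNames(MiniVisitor):
    """Records Name ids; every other node type goes through generic_visit."""

    def __init__(self) -> None:
        self.seen: list[str] = []

    def visit_Name(self, node: ast.Name) -> None:
        self.seen.append(node.id)


printer = PrintNames()
printer.visit(ast.parse("x = y + z"))
print(printer.seen)  # ['x', 'y', 'z']
```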

In [None]:
import ast


class NameCollector(ast.NodeVisitor):
    """Collects all variable names referenced in the code."""

    def __init__(self) -> None:
        self.names: list[str] = []

    def visit_Name(self, node: ast.Name) -> None:
        """Called for every Name node in the AST."""
        self.names.append(node.id)
        self.generic_visit(node)


# Parse and visit
tree: ast.Module = ast.parse("x = y + z")
collector: NameCollector = NameCollector()
collector.visit(tree)

print(f"Names found: {collector.names}")
print(f"'y' in names: {'y' in collector.names}")
print(f"'z' in names: {'z' in collector.names}")

In [None]:
import ast


class NameCollector(ast.NodeVisitor):
    """Collects all variable names referenced in the code."""

    def __init__(self) -> None:
        self.names: list[str] = []

    def visit_Name(self, node: ast.Name) -> None:
        self.names.append(node.id)
        self.generic_visit(node)


# Collect names from a more complex expression
source: str = """
total = price * quantity
tax = total * rate
final = total + tax
"""
tree: ast.Module = ast.parse(source)
collector: NameCollector = NameCollector()
collector.visit(tree)

print(f"All names: {collector.names}")
unique_names: set[str] = set(collector.names)
print(f"Unique names: {sorted(unique_names)}")
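
Leaving out `generic_visit` is also a way to prune the traversal. In this sketch, where the class name `ShallowNameCollector` is illustrative, `visit_Call` deliberately does not recurse. As a result, names used inside the call, including the function name `f`, are never collected.

```python
import ast


class ShallowNameCollector(ast.NodeVisitor):
    """Illustrative visitor: records Name nodes but does not descend into calls."""

    def __init__(self) -> None:
        self.names: list[str] = []

    def visit_Name(self, node: ast.Name) -> None:
        self.names.append(node.id)

    def visit_Call(self, node: ast.Call) -> None:
        # No generic_visit here, so the call's function and arguments are skipped
        pass


shallow = ShallowNameCollector()
shallow.visit(ast.parse("x = f(a, b) + c"))
print(shallow.names)  # ['x', 'c']
```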

## Section 2: Counting Specific Node Types

A common use of `NodeVisitor` is counting occurrences of specific constructs, such as function definitions, loops, or imports.

In [None]:
import ast


class FuncCounter(ast.NodeVisitor):
    """Counts function definitions in the AST."""

    def __init__(self) -> None:
        self.count: int = 0

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        """Called for every function definition."""
        self.count += 1
        self.generic_visit(node)


source: str = """
def foo(): pass
def bar(): pass
def baz(): pass
"""
tree: ast.Module = ast.parse(source)
counter: FuncCounter = FuncCounter()
counter.visit(tree)

print(f"Number of functions: {counter.count}")
print(f"Count is 3: {counter.count == 3}")

In [None]:
import ast


class CodeAnalyzer(ast.NodeVisitor):
    """Counts multiple types of nodes in the AST."""

    def __init__(self) -> None:
        self.functions: int = 0
        self.assignments: int = 0
        self.names: list[str] = []

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        self.functions += 1
        self.names.append(node.name)
        self.generic_visit(node)

    def visit_Assign(self, node: ast.Assign) -> None:
        self.assignments += 1
        self.generic_visit(node)


source: str = """
x = 10
y = 20

def add(a, b):
    return a + b

def multiply(a, b):
    result = a * b
    return result

z = add(x, y)
"""
tree: ast.Module = ast.parse(source)
analyzer: CodeAnalyzer = CodeAnalyzer()
analyzer.visit(tree)

print(f"Functions: {analyzer.functions}")
print(f"Function names: {analyzer.names}")
print(f"Assignments: {analyzer.assignments}")
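
Because `generic_visit` is called explicitly, a visitor can do work both before and after a subtree is traversed. The following sketch uses that to keep a stack of enclosing function names and attribute each assignment to its scope, instead of counting assignments globally as `CodeAnalyzer` does. The class and attribute names are illustrative.

```python
import ast
from collections import Counter


class AssignmentsPerScope(ast.NodeVisitor):
    """Illustrative: counts assignments per enclosing function ('<module>' at top level)."""

    def __init__(self) -> None:
        self.scope: list[str] = ["<module>"]
        self.counts: Counter[str] = Counter()

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        self.scope.append(node.name)
        self.generic_visit(node)  # assignments in the body see the new scope
        self.scope.pop()          # restore the outer scope afterwards

    def visit_Assign(self, node: ast.Assign) -> None:
        self.counts[self.scope[-1]] += 1
        self.generic_visit(node)


source = """
x = 10
def multiply(a, b):
    result = a * b
    return result
z = multiply(x, 2)
"""
scoper = AssignmentsPerScope()
scoper.visit(ast.parse(source))
print(dict(scoper.counts))  # {'<module>': 2, 'multiply': 1}
```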

## Section 3: NodeTransformer for Modifying the AST

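`ast.NodeTransformer` works like `NodeVisitor`, but the return value of each `visit_*` method replaces the original node: return the node itself to keep it, a new node to replace it, or `None` to remove it. A `visit_*` method that should also transform the node's children must call `self.generic_visit(node)` itself. After transforming, call `ast.fix_missing_locations` to fill in line numbers on any new nodes before compiling.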

In [None]:
import ast


class DoubleConstants(ast.NodeTransformer):
    """Doubles all integer constants in the AST."""

    def visit_Constant(self, node: ast.Constant) -> ast.Constant:
        # bool is a subclass of int, so exclude True/False explicitly
        if isinstance(node.value, int) and not isinstance(node.value, bool):
            return ast.Constant(value=node.value * 2)
        return node


# Parse, transform, and execute
tree: ast.Module = ast.parse("x = 5")
print(f"Before: {ast.dump(tree)}")

new_tree: ast.Module = DoubleConstants().visit(tree)
ast.fix_missing_locations(new_tree)
print(f"After:  {ast.dump(new_tree)}")

# Compile and run the modified tree
code = compile(new_tree, "<transformed>", "exec")
ns: dict[str, object] = {}
exec(code, ns)  # noqa: S102
print(f"\nx = {ns['x']}")
print(f"x == 10: {ns['x'] == 10}")
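
If a transformer's `visit_*` method should also rewrite nested nodes, it has to call `self.generic_visit(node)`. Calling it *first* gives bottom-up rewriting, which is what constant folding needs. This is a minimal sketch that assumes only `+`, `-`, and `*` on numeric literals; `FoldConstants` and `_OPS` are illustrative names. `ast.copy_location` is used so each folded constant keeps the position of the expression it replaces.

```python
import ast
import operator

# Operators handled by this sketch (an assumption, not an exhaustive list)
_OPS = {ast.Add: operator.add, ast.Sub: operator.sub, ast.Mult: operator.mul}


class FoldConstants(ast.NodeTransformer):
    """Illustrative constant folder for +, -, * on numeric literals."""

    def visit_BinOp(self, node: ast.BinOp) -> ast.AST:
        # Transform children first so nested expressions fold bottom-up
        self.generic_visit(node)
        op = _OPS.get(type(node.op))
        if (
            op is not None
            and isinstance(node.left, ast.Constant)
            and isinstance(node.right, ast.Constant)
            and isinstance(node.left.value, (int, float))
            and isinstance(node.right.value, (int, float))
        ):
            folded = ast.Constant(value=op(node.left.value, node.right.value))
            return ast.copy_location(folded, node)
        return node


folded_tree = FoldConstants().visit(ast.parse("y = 2 * 3 + x - (4 - 1)"))
ast.fix_missing_locations(folded_tree)
print(ast.unparse(folded_tree))  # y = 6 + x - 3
```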

In [None]:
import ast


class NegateConstants(ast.NodeTransformer):
    """Negates all integer constants in the AST."""

    def visit_Constant(self, node: ast.Constant) -> ast.Constant:
        # bool is a subclass of int, so exclude True/False explicitly
        if isinstance(node.value, int) and not isinstance(node.value, bool):
            return ast.Constant(value=-node.value)
        return node


source: str = "result = 10 + 20"
tree: ast.Module = ast.parse(source)
new_tree: ast.Module = NegateConstants().visit(tree)
ast.fix_missing_locations(new_tree)

code = compile(new_tree, "<negated>", "exec")
ns: dict[str, object] = {}
exec(code, ns)  # noqa: S102
print(f"result = {ns['result']}")
print(f"result == -30: {ns['result'] == -30}")

## Section 4: `ast.fix_missing_locations`

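`compile` requires every statement and expression node to carry `lineno` and `col_offset` attributes, and nodes you create by hand (for example inside a `NodeTransformer`) do not have them. `ast.fix_missing_locations` walks the tree recursively and fills in any missing positions by copying them from the parent node, starting from line 1, column 0 at the root.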

In [None]:
import ast

# Create a new node without location info
new_node: ast.Constant = ast.Constant(value=99)
print(f"Has lineno: {hasattr(new_node, 'lineno') and new_node.lineno is not None}")

# Build a module with the node
module: ast.Module = ast.Module(
    body=[ast.Expr(value=new_node)],
    type_ignores=[],
)

# fix_missing_locations adds line/col info
ast.fix_missing_locations(module)
print(f"After fix - lineno: {module.body[0].lineno}")
print(f"After fix - col_offset: {module.body[0].col_offset}")

# Now it can be compiled
code = compile(module, "<test>", "exec")
print(f"\nCompiled successfully: {code is not None}")
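
The two location helpers differ in where the position comes from. `ast.fix_missing_locations` copies it from the enclosing node, so a replacement expression inherits its statement's column. `ast.copy_location(new, old)` copies it from the node being replaced. The snippet illustrates both on small parsed trees.

```python
import ast

# Case 1: fix_missing_locations copies the position of the enclosing node
tree = ast.parse("a = 1\nb = 2\nc = 3")
third = tree.body[2]                  # the Assign statement on line 3
third.value = ast.Constant(value=30)  # hand-built node with no position yet
ast.fix_missing_locations(tree)
print(third.value.lineno, third.value.col_offset)  # 3 0 (the Assign's position)

# Case 2: copy_location copies the position of the node being replaced
original = ast.parse("c = 3").body[0].value  # the literal 3, at column 4
replacement = ast.copy_location(ast.Constant(value=30), original)
print(replacement.lineno, replacement.col_offset)  # 1 4
```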

## Section 5: Iterating All Nodes with `ast.walk`

`ast.walk` iterates over a node and all of its descendants, including the starting node itself, without requiring a visitor class. It yields nodes in no specified order, and unlike a visitor it gives you no way to skip subtrees or otherwise control the recursion.

In [None]:
import ast

# Walk all nodes in an expression
tree: ast.Module = ast.parse("a = b + c * d")
node_types: set[str] = {type(node).__name__ for node in ast.walk(tree)}

print(f"All node types: {sorted(node_types)}")
print(f"'BinOp' found: {'BinOp' in node_types}")
print(f"'Name' found: {'Name' in node_types}")
print(f"'Add' found: {'Add' in node_types}")
print(f"'Mult' found: {'Mult' in node_types}")

In [None]:
import ast

# Count node types using ast.walk
source: str = """
def greet(name):
    message = "Hello, " + name
    return message

def farewell(name):
    return "Goodbye, " + name
"""
tree: ast.Module = ast.parse(source)

# Count specific node types with walk
func_count: int = sum(1 for node in ast.walk(tree) if isinstance(node, ast.FunctionDef))
name_count: int = sum(1 for node in ast.walk(tree) if isinstance(node, ast.Name))
const_count: int = sum(1 for node in ast.walk(tree) if isinstance(node, ast.Constant))

print(f"Functions: {func_count}")
print(f"Name references: {name_count}")
print(f"Constants: {const_count}")
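
Each generator expression in the previous cell walks the whole tree again. A single pass with `collections.Counter` tallies every node type at once. The sketch also shows why `name` in a function signature is not counted among the `Name` nodes: function parameters are `ast.arg` nodes, not `ast.Name`.

```python
import ast
from collections import Counter

source = """
def greet(name):
    message = "Hello, " + name
    return message
"""
tree = ast.parse(source)

# One pass over the tree, tallying every node type by class name
type_counts = Counter(type(node).__name__ for node in ast.walk(tree))
print(type_counts["FunctionDef"], type_counts["Name"], type_counts["arg"])  # 1 3 1
```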

In [None]:
import ast

# Collect all string constants using walk
source: str = """
name = "Alice"
greeting = "Hello"
message = greeting + ", " + name + "!"
"""
tree: ast.Module = ast.parse(source)

strings: list[str] = [
    node.value
    for node in ast.walk(tree)
    if isinstance(node, ast.Constant) and isinstance(node.value, str)
]

print(f"String constants: {strings}")
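
`ast.walk` yields nodes without any parent information. A common workaround is to build a child-to-parent dictionary in one pass using `ast.iter_child_nodes`. This is a sketch with one caveat: CPython may reuse a single instance for context and operator nodes such as `ast.Load` or `ast.Add`, so their entries in the map are not meaningful.

```python
import ast

tree = ast.parse("total = price * (1 + rate)")

# Map each child node to its parent (shared ctx/operator singletons get overwritten)
parents: dict[ast.AST, ast.AST] = {}
for node in ast.walk(tree):
    for child in ast.iter_child_nodes(node):
        parents[child] = node

# Find the expression that directly uses the name 'rate'
rate = next(n for n in ast.walk(tree) if isinstance(n, ast.Name) and n.id == "rate")
parent = parents[rate]
print(type(parent).__name__, type(parent.op).__name__)  # BinOp Add
```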

## Section 6: Practical Code Analysis Patterns

Combining visitors and walkers enables practical code analysis tools such as finding unused imports, detecting complexity, or extracting documentation.

In [None]:
import ast


class ImportCollector(ast.NodeVisitor):
    """Collects all import statements from source code."""

    def __init__(self) -> None:
        self.imports: list[str] = []

    def visit_Import(self, node: ast.Import) -> None:
        for alias in node.names:
            self.imports.append(alias.name)
        self.generic_visit(node)

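    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        # node.level counts the leading dots of a relative import (0 if absolute)
        module: str = "." * node.level + (node.module or "")
        for alias in node.names:
            # Avoid a doubled dot for forms like "from . import x"
            sep: str = "" if module.endswith(".") else "."
            self.imports.append(f"{module}{sep}{alias.name}")
        self.generic_visit(node)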


source: str = """
import os
import sys
from pathlib import Path
from collections import defaultdict, Counter
"""
tree: ast.Module = ast.parse(source)
collector: ImportCollector = ImportCollector()
collector.visit(tree)

print("Imports found:")
for imp in collector.imports:
    print(f"  {imp}")
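
The section introduction mentions finding unused imports. A rough heuristic, sketched here, is to compare the names that imports bind against the names that are read (`ast.Load` context). It ignores `__all__`, star imports, names used only inside strings, and scoping, so it is a starting point rather than a linter.

```python
import ast

source = """
import os
import sys
from pathlib import Path
print(sys.argv)
"""
tree = ast.parse(source)

# Names bound by imports: "import a.b" binds "a"; an "as" alias binds the alias
imported: set[str] = set()
for node in ast.walk(tree):
    if isinstance(node, (ast.Import, ast.ImportFrom)):
        for alias in node.names:
            imported.add(alias.asname or alias.name.split(".")[0])

# Names read anywhere in the module
used: set[str] = {
    node.id
    for node in ast.walk(tree)
    if isinstance(node, ast.Name) and isinstance(node.ctx, ast.Load)
}

print(sorted(imported - used))  # ['Path', 'os']
```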

In [None]:
import ast


class FunctionSummarizer(ast.NodeVisitor):
    """Extracts function signatures and docstrings."""

    def __init__(self) -> None:
        self.summaries: list[dict[str, object]] = []

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        args: list[str] = [arg.arg for arg in node.args.args]
        docstring: str | None = ast.get_docstring(node)
        self.summaries.append({
            "name": node.name,
            "args": args,
            "lineno": node.lineno,
            "docstring": docstring,
        })
        self.generic_visit(node)


source: str = """
def add(a, b):
    \"\"\"Add two numbers.\"\"\"  
    return a + b

def multiply(x, y, z):
    \"\"\"Multiply three numbers together.\"\"\"  
    return x * y * z

def no_doc():
    pass
"""
tree: ast.Module = ast.parse(source)
summarizer: FunctionSummarizer = FunctionSummarizer()
summarizer.visit(tree)

for summary in summarizer.summaries:
    print(f"Function: {summary['name']}({', '.join(summary['args'])})")
    print(f"  Line: {summary['lineno']}")
    print(f"  Docstring: {summary['docstring']}")
    print()
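
The introduction also mentions detecting complexity. As a rough illustration, and not a standard metric implementation, this sketch counts branching constructs inside each function and adds one, loosely in the spirit of cyclomatic complexity. `BRANCH_NODES` and `branch_counts` are illustrative names, and the choice of node types is an assumption. Nested functions are counted both on their own and as part of their enclosing function.

```python
import ast

# Node types treated as decision points in this illustrative metric
BRANCH_NODES = (ast.If, ast.For, ast.While, ast.Try, ast.BoolOp, ast.IfExp)


def branch_counts(source: str) -> dict[str, int]:
    """Map each function name to 1 + the number of branch nodes inside it."""
    result: dict[str, int] = {}
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            branches = sum(isinstance(n, BRANCH_NODES) for n in ast.walk(node))
            result[node.name] = 1 + branches
    return result


source = """
def simple(x):
    return x + 1

def classify(n):
    if n < 0:
        return "negative"
    for _ in range(3):
        if n > 10 and n < 100:
            return "medium"
    return "other"
"""
print(branch_counts(source))  # {'simple': 1, 'classify': 5}
```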

In [None]:
import ast


class RenameVariable(ast.NodeTransformer):
    """Renames all occurrences of a variable in the AST."""

    def __init__(self, old_name: str, new_name: str) -> None:
        self.old_name: str = old_name
        self.new_name: str = new_name

    def visit_Name(self, node: ast.Name) -> ast.Name:
        if node.id == self.old_name:
            return ast.Name(id=self.new_name, ctx=node.ctx)
        return node


source: str = "result = x + y * x"
tree: ast.Module = ast.parse(source)
print(f"Before: {ast.dump(tree, indent=2)}")

# Rename 'x' to 'value'
renamed: ast.Module = RenameVariable("x", "value").visit(tree)
ast.fix_missing_locations(renamed)
print(f"\nAfter: {ast.dump(renamed, indent=2)}")

# Verify by compiling and running
code = compile(renamed, "<renamed>", "exec")
ns: dict[str, object] = {"value": 3, "y": 4}
exec(code, ns)  # noqa: S102
print(f"\nresult = {ns['result']}  (expected: 3 + 4 * 3 = 15)")
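
`RenameVariable` only matches `ast.Name`, so a function parameter named `x` (an `ast.arg` node) would keep its old name and the renamed body would no longer refer to it. The variant sketched here, with the illustrative name `RenameEverywhere`, also defines `visit_arg`. It mutates the existing nodes instead of building new ones, so their source positions are preserved.

```python
import ast


class RenameEverywhere(ast.NodeTransformer):
    """Illustrative variant: renames variable references and function parameters."""

    def __init__(self, old_name: str, new_name: str) -> None:
        self.old_name = old_name
        self.new_name = new_name

    def visit_Name(self, node: ast.Name) -> ast.AST:
        if node.id == self.old_name:
            node.id = self.new_name  # mutate in place; location info is kept
        return self.generic_visit(node)

    def visit_arg(self, node: ast.arg) -> ast.AST:
        # Parameters are ast.arg nodes, which visit_Name never sees
        if node.arg == self.old_name:
            node.arg = self.new_name
        return self.generic_visit(node)  # also visits any annotation


func_tree = RenameEverywhere("x", "value").visit(ast.parse("def f(x):\n    return x * 2"))
print(ast.unparse(func_tree))
```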

## Summary

### NodeVisitor
- Subclass `ast.NodeVisitor` and define `visit_<NodeType>` methods for read-only traversal
- Call `self.generic_visit(node)` to continue visiting child nodes; omit it to skip the node's subtree
- Use `visitor.visit(tree)` to start the traversal from the root

### NodeTransformer
- Subclass `ast.NodeTransformer` and return modified or replacement nodes from `visit_*` methods
- Return the original node to keep it unchanged
- Always call `ast.fix_missing_locations(tree)` after transformation before compiling

### ast.walk
- `ast.walk(tree)` yields every node in the tree with no guaranteed order
- Simpler than `NodeVisitor` for flat queries (counting, collecting)
- No control over recursion depth or traversal order

### Common Patterns
- **Name collection**: `visit_Name` to gather variable references
- **Function counting**: `visit_FunctionDef` to count or analyze functions
- **Import analysis**: `visit_Import` and `visit_ImportFrom` for dependency tracking
- **Code transformation**: `NodeTransformer` for renaming, constant folding, or instrumentation