In [2]:
import tree_sitter_python as tspython
from graphviz import Digraph
from tree_sitter import Language, Parser

# Source snippet to visualise; tree-sitter consumes bytes, not str.
code = b"\ndef add(a, b):\n    return a + b\n"

# Python grammar from tree-sitter-python, wired into a fresh parser.
PY_LANGUAGE = Language(tspython.language())
parser = Parser(PY_LANGUAGE)

# Parse the snippet once and keep its root node for the graph walk below.
tree = parser.parse(code)
root_node = tree.root_node

# Output graph, plus a one-element list used as a mutable ID counter so that
# add_node can bump it in place without a `global` statement.
dot = Digraph(comment="AST")
node_id_counter = [0]


def add_node(graph, node, parent_id=None, *, source=None):
    """Add ``node`` and its whole subtree to ``graph``, numbering nodes in pre-order.

    Each tree-sitter node becomes one Graphviz node labelled with its type;
    leaf nodes (no children) additionally show their source text, quoted.
    IDs are taken from the shared ``node_id_counter`` and never reset here,
    so they stay unique across repeated calls on the same graph.

    Args:
        graph: Graphviz graph object receiving ``node(id, label)`` and
            ``edge(tail, head)`` calls (e.g. a ``Digraph``).
        node: tree-sitter node to add.
        parent_id: ID of the already-added parent node, or ``None`` for the
            root (no incoming edge is drawn).
        source: bytes the tree was parsed from; defaults to the module-level
            ``code`` when omitted.
    """
    if source is None:
        source = code

    current_id = str(node_id_counter[0])
    label = node.type
    if node.child_count == 0:
        # Leaf nodes show text. errors="replace" keeps odd byte ranges
        # (e.g. inside ERROR nodes) from aborting the whole plot.
        text = source[node.start_byte : node.end_byte].decode("utf-8", errors="replace")
        # Graphviz interprets backslash sequences (\n, \l, \N, ...) inside
        # labels, so double every backslash coming from the source to display
        # it literally. The "\\n" separator below is intentionally a Graphviz
        # line break between the type and the text.
        escaped = text.replace("\\", "\\\\")
        label += f"\\n'{escaped}'"
    graph.node(current_id, label)

    if parent_id is not None:
        graph.edge(parent_id, current_id)

    node_id_counter[0] += 1
    for child in node.children:
        add_node(graph, child, current_id, source=source)


# Populate the graph by walking the parse tree from its root node.
add_node(dot, root_node)

# Write ast_output.png; cleanup=True deletes the intermediate DOT source file
# once rendering succeeds.
dot.render(filename="ast_output", cleanup=True, format="png")
print("AST plotted as ast_output.png")

AST plotted as ast_output.png
