In [1]:
import ast

code = """
def main():
    def plus(a, b):
        return a + b

    x, y = 1, 2
    res = plus(x, y)
    return res
"""

tree = ast.parse(code)

print(ast.dump(tree))


Module(body=[FunctionDef(name='main', args=arguments(posonlyargs=[], args=[], kwonlyargs=[], kw_defaults=[], defaults=[]), body=[FunctionDef(name='plus', args=arguments(posonlyargs=[], args=[arg(arg='a'), arg(arg='b')], kwonlyargs=[], kw_defaults=[], defaults=[]), body=[Return(value=BinOp(left=Name(id='a', ctx=Load()), op=Add(), right=Name(id='b', ctx=Load())))], decorator_list=[]), Assign(targets=[Tuple(elts=[Name(id='x', ctx=Store()), Name(id='y', ctx=Store())], ctx=Store())], value=Tuple(elts=[Constant(value=1), Constant(value=2)], ctx=Load())), Assign(targets=[Name(id='res', ctx=Store())], value=Call(func=Name(id='plus', ctx=Load()), args=[Name(id='x', ctx=Load()), Name(id='y', ctx=Load())], keywords=[])), Return(value=Name(id='res', ctx=Load()))], decorator_list=[])], type_ignores=[])


In [2]:
import networkx as nx

def ast_to_networkx(tree):
    """
    Converts an AST to a NetworkX graph.

    Args:
        tree: The AST to convert.

    Returns:
        A NetworkX graph.
    """

    graph = nx.DiGraph()

    # Create a node for the root of the tree.
    root_node = graph.add_node(id(tree), label=type(tree).__name__)

    # Recursively add nodes for the children of the root node.
    _add_nodes(graph, id(tree), tree)

    return graph

def _add_nodes(graph, parent_node, node):
    """
    Recursively adds nodes for the children of a given node.

    Args:
        graph: The NetworkX graph.
        parent_node: The node to which the new nodes will be added.
        node: The node whose children will be added.
    """

    # Create a node for the current node.
    node_id = id(node)
    graph.add_node(node_id, label=type(node).__name__)

    # Add an edge between the parent node and the current node.
    graph.add_edge(parent_node, node_id)

    # Recursively add nodes for the children of the current node.
    for child in ast.iter_child_nodes(node):
        _add_nodes(graph, node_id, child)

networkx = ast_to_networkx(tree)

In [3]:
!pip install torch-geometric

Collecting torch-geometric
  Downloading torch_geometric-2.5.3-py3-none-any.whl (1.1 MB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.1 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.3/1.1 MB[0m [31m7.7 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━[0m [32m0.7/1.1 MB[0m [31m10.7 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m1.1/1.1 MB[0m [31m11.9 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m10.2 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: torch-geometric
Successfully installed torch-geometric-2.5.3


In [4]:
import torch_geometric.nn as pyg_nn
import torch
from torch_geometric.utils.convert import from_networkx

class GraphConvModel(pyg_nn.MessagePassing):
    def __init__(self, emb_dim):
        super(GraphConvModel, self).__init__(aggr='add')

        self.linear = torch.nn.Linear(emb_dim, emb_dim)

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Perform message passing.
        x = self.propagate(edge_index, x=x)

        # Apply a linear transformation to the node features.
        x = self.linear(x)

        return x

    def message(self, x_j, edge_index, index):
        # x_j has shape [E, in_channels]

        # Return the message to be passed to the target nodes.
        return x_j


# Convert NetworkX graph to PyTorch Geometric Data object
data = from_networkx(networkx)
data.x = torch.randn((data.num_nodes, 32))

# Create a GraphConvModel instance.
model = GraphConvModel(emb_dim=32)

# Perform graph embedding.
x = model(data.x, data.edge_index)

# Print the node embeddings.
print(x)

tensor([[ 1.1743, -0.1491, -0.2841, -0.8671, -0.5194, -0.9399, -0.6772, -0.0463,
          0.7097, -0.4908,  0.1794,  0.0355,  0.6212,  0.1596, -0.7153,  0.5101,
         -0.9400,  1.1197,  0.1818,  0.1141,  0.0855, -0.9767,  0.3716,  0.3465,
         -0.5902, -0.6419, -0.0209, -1.1500, -0.2615,  0.4466,  0.3183,  0.5415],
        [ 1.1743, -0.1491, -0.2841, -0.8671, -0.5194, -0.9399, -0.6772, -0.0463,
          0.7097, -0.4908,  0.1794,  0.0355,  0.6212,  0.1596, -0.7153,  0.5101,
         -0.9400,  1.1197,  0.1818,  0.1141,  0.0855, -0.9767,  0.3716,  0.3465,
         -0.5902, -0.6419, -0.0209, -1.1500, -0.2615,  0.4466,  0.3183,  0.5415],
        [ 1.0195, -0.0443,  0.9975,  0.4516, -0.2199, -0.4213, -0.1514,  0.2207,
          0.6055,  0.2642, -0.2074,  0.8438, -0.2558,  0.3416, -0.0530, -0.1145,
          1.1637,  0.0349,  0.1074, -0.3381,  0.3954,  0.0412,  0.1043, -0.1452,
         -0.2886,  0.8808,  0.1413,  0.6297, -0.0069,  0.0617,  1.0973, -0.0807],
        [ 1.0195, -0.0443