In [70]:
from pyro_graph_nets.utils.data import random_input_output_graphs

import torch
import networkx as nx
from collections import namedtuple
from typing import *
import numpy as np
from pyro_graph_nets.utils.graph_tuple import to_graph_tuple

def graph_generator(
    n_nodes: Tuple[int, int],
    n_features: Tuple[int, int],
    e_features: Tuple[int, int],
    g_features: Tuple[int, int],
):
    """Create a random graph generator backed by ``random_input_output_graphs``.

    :param n_nodes: pair that is unpacked into ``np.random.randint`` every time
        a node count is drawn
    :param n_features: pair of node feature lengths; element 0 sizes the first
        node sampler and element 1 the second (presumably input and target,
        given the attribute names passed below - confirm against the helper)
    :param e_features: pair of edge feature lengths, used the same way
    :param g_features: pair of global feature lengths, used the same way
    :return: the object returned by ``random_input_output_graphs``
    """

    def draw_node_count():
        return np.random.randint(*n_nodes)

    def uniform_sampler(sizes, position):
        # The size is looked up at draw time, just as an inline lambda would do.
        return lambda: np.random.uniform(1, 10, sizes[position])

    # Sampler order: node/edge/global for position 0, then node/edge/global
    # for position 1.
    samplers = [
        uniform_sampler(sizes, position)
        for position in (0, 1)
        for sizes in (n_features, e_features, g_features)
    ]

    return random_input_output_graphs(
        draw_node_count,
        20,  # second positional argument; its meaning is defined by the helper
        *samplers,
        input_attr_name="features",
        target_attr_name="target",
        do_copy=False,
    )

# Node counts are drawn with np.random.randint(2, 20); feature lengths are
# 10 per node, 5 per edge and 3 per graph for both entries of each pair.
gen = graph_generator(
    n_nodes=(2, 20),
    n_features=(10, 10),
    e_features=(5, 5),
    g_features=(3, 3),
)

# Container for a batch of graphs. Fields, in positional order:
#   node_attr     node level attributes
#   edge_attr     edge level attributes
#   global_attr   global level attributes
#   edges         node-to-node connectivity
#   node_indices  tensor where each element is the index of the graph the
#                 corresponding node_attr row belongs to
#   edge_indices  tensor where each element is the index of the graph the
#                 corresponding edge_attr / edges row belongs to
GraphTuple = namedtuple(
    "GraphTuple",
    (
        "node_attr",
        "edge_attr",
        "global_attr",
        "edges",
        "node_indices",
        "edge_indices",
    ),
)

In [71]:
# Draw 1000 graphs from `gen` once so the benchmark helpers that follow all
# operate on the same list.
graphs = [next(gen) for _ in range(1000)]


In [72]:
def test0():
    """Benchmark body: a single ``to_graph_tuple`` call on the whole ``graphs`` list."""
    whole_batch = graphs
    to_graph_tuple(whole_batch)
    
def test1():
    """Benchmark body: ten ``to_graph_tuple`` calls, each on the first 100 graphs."""
    repeats = 10
    for _ in range(repeats):
        # The slice is taken inside the loop so it is part of what gets timed.
        to_graph_tuple(graphs[:100])
        
# Time one conversion of the full graph list; -n 10 fixes 10 loops per run.
%timeit -n 10 test0()

32.7 ms ± 136 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [73]:
# Time ten conversions of 100-graph slices; -n 10 fixes 10 loops per run.
%timeit -n 10 test1()

32.8 ms ± 1.2 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [85]:
def pick_edge(g):
    """Return the first item yielded by ``g.edges(data=True)``, or ``None`` if there is none."""
    return next(iter(g.edges(data=True)), None)
    
def pick_node(g):
    """Return the first item yielded by ``g.nodes(data=True)``, or ``None`` if there is none."""
    return next(iter(g.nodes(data=True)), None)
    


def to_graph_tuple_new(
    graphs: List[nx.DiGraph],
    feature_key: str = "features",
    global_attr_key: str = "data",
    requires_grad: bool = False,
    device: str = None
) -> GraphTuple:
    """Convert a list of networkx graphs into a single batched GraphTuple.

    The output arrays are preallocated from the total node and edge counts and
    then filled in one pass over the graphs.

    :param graphs: list of graphs. Every node, every edge and every graph's
        global attribute must hold a feature array under ``feature_key``, and
        the feature shape must be the same for all items of one level.
    :param feature_key: key to find the node, edge, and global features
    :param global_attr_key: attribute on the NetworkX graph to find the global data (default: 'data')
    :param requires_grad: whether the three floating point attribute tensors
        record gradients (the integer index tensors never do)
    :param device: optional torch device on which the tensors are created
    :return: GraphTuple, a namedtuple of ['node_attr', 'edge_attr', 'global_attr',
        'edges', 'node_indices', 'edge_indices']. ``edges`` has shape
        ``(n_edges, 2)``; each row holds the ``node_attr`` row numbers of the
        two endpoints of that edge.
    :raises KeyError: if a node, edge or global attribute lacks ``feature_key``
    :raises AttributeError: if a graph has no ``global_attr_key`` attribute
    """
    n_graphs = len(graphs)
    n_nodes = sum(graph.number_of_nodes() for graph in graphs)
    n_edges = sum(graph.number_of_edges() for graph in graphs)

    # Probe the per-item feature shape from the first item found anywhere in
    # the batch. Falling back to (0,) keeps an empty batch, or a batch with no
    # edges at all, from failing here.
    node_shape = next(
        (
            np.shape(ndata[feature_key])
            for graph in graphs
            for _, ndata in graph.nodes(data=True)
        ),
        (0,),
    )
    edge_shape = next(
        (
            np.shape(edata[feature_key])
            for graph in graphs
            for _, _, edata in graph.edges(data=True)
        ),
        (0,),
    )
    global_shape = next(
        (np.shape(getattr(graph, global_attr_key)[feature_key]) for graph in graphs),
        (0,),
    )

    v = np.empty((n_nodes, *node_shape))
    e = np.empty((n_edges, *edge_shape))
    u = np.empty((n_graphs, *global_shape))
    connectivity = np.empty((n_edges, 2), dtype=np.int64)
    node_idx = np.empty(n_nodes, dtype=np.int64)
    edge_idx = np.empty(n_edges, dtype=np.int64)

    node_row = 0
    edge_row = 0
    for gidx, graph in enumerate(graphs):
        # Maps this graph's node labels to their row in `v`. It is rebuilt for
        # every graph so equal labels in different graphs cannot be confused.
        ndict = {}
        for node, ndata in graph.nodes(data=True):
            v[node_row] = ndata[feature_key]
            ndict[node] = node_row
            node_idx[node_row] = gidx
            node_row += 1

        for n1, n2, edata in graph.edges(data=True):
            e[edge_row] = edata[feature_key]
            edge_idx[edge_row] = gidx
            # Write into the row; rebinding `connectivity` here would discard
            # every edge but the last one.
            connectivity[edge_row] = (ndict[n1], ndict[n2])
            edge_row += 1

        u[gidx] = getattr(graph, global_attr_key)[feature_key]

    def as_float(array):
        return torch.tensor(
            array, dtype=torch.float, device=device, requires_grad=requires_grad
        )

    def as_long(array):
        return torch.tensor(array, dtype=torch.long, device=device)

    return GraphTuple(
        as_float(v),
        as_float(e),
        as_float(u),
        as_long(connectivity),
        as_long(node_idx),
        as_long(edge_idx),
    )
#     for index, graph in enumerate(graphs):

#         nodes = list(graph.nodes(data=True))
#         edges = list(graph.edges(data=True))

#         new_nodes = list(
#             range(len(node_attributes), len(node_attributes) + graph.number_of_nodes())
#         )
#         ndict = dict(zip([n[0] for n in nodes], new_nodes))

#         if not hasattr(graph, global_attr_key):
#             global_attributes.append([0.])
#         else:
#             global_attributes.append(graph.data[feature_key])
#         for node, ndata in nodes:
#             node_attributes.append(ndata[feature_key])
#             node_indices.append(index)
#         for n1, n2, edata in edges:
#             senders.append(ndict[n1])
#             receivers.append(ndict[n2])
#             edge_attributes.append(edata[feature_key])
#             edge_indices.append(index)

#     def vstack(arr):
#         if not arr:
#             return []
#         return np.vstack(arr)
    
#     node_attr = torch.tensor(vstack(node_attributes), dtype=torch.float, requires_grad=requires_grad)
#     edge_attr = torch.tensor(vstack(edge_attributes), dtype=torch.float, requires_grad=requires_grad)
#     edges = torch.tensor(np.vstack([senders, receivers]).T, dtype=torch.long)
#     global_attr = torch.tensor(vstack(global_attributes), dtype=torch.float, requires_grad=requires_grad)
#     node_indices = torch.tensor(node_indices, dtype=torch.long).detach()
#     edge_indices = torch.tensor(edge_indices, dtype=torch.long).detach()
#     result = GraphTuple(
#         node_attr, edge_attr, global_attr, edges, node_indices, edge_indices
#     )
#     if device:
#         return gt_to_device(result, device)
#     return result

# Rebind `graphs` to 1000 freshly drawn graphs for the comparison below.
graphs = [next(gen) for _ in range(1000)]


def test0():
    """Benchmark body: convert the whole ``graphs`` list with ``to_graph_tuple``."""
    whole_batch = graphs
    to_graph_tuple(whole_batch)
    
def test1():
    """Benchmark body: convert the whole ``graphs`` list with ``to_graph_tuple_new``."""
    whole_batch = graphs
    to_graph_tuple_new(whole_batch)
    

# Compare the two converters on the same `graphs` list: test1 times
# to_graph_tuple_new, test0 times to_graph_tuple (10 loops per run each).
%timeit -n 10 test1()
%timeit -n 10 test0()

20.2 ms ± 291 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
33.8 ms ± 797 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [91]:
# NOTE(review): this definition shadows the `to_graph_tuple` imported from
# pyro_graph_nets.utils.graph_tuple at the top of the notebook.
def to_graph_tuple(
        graphs: List[nx.DiGraph],
        feature_key: str = "features",
        global_attr_key: str = "data",
        device: str = None
) -> GraphTuple:
    """Convert a list of networkx graphs into a single batched GraphTuple.

    The output arrays are preallocated from the total node and edge counts and
    then filled in one pass over the graphs.

    :param graphs: list of graphs. Every node, every edge and every graph's
        global attribute must hold a feature array under ``feature_key``, and
        the feature shape must be the same for all items of one level.
    :param feature_key: key to find the node, edge, and global features
    :param global_attr_key: attribute on the NetworkX graph to find the global data (default: 'data')
    :param device: optional torch device for the tensors; a falsy value keeps
        torch's default device
    :return: GraphTuple, a namedtuple of ['node_attr', 'edge_attr', 'global_attr',
        'edges', 'node_indices', 'edge_indices']. ``edges`` has shape
        ``(n_edges, 2)``; each row holds the ``node_attr`` row numbers of the
        two endpoints of that edge.
    :raises KeyError: if a node, edge or global attribute lacks ``feature_key``
    :raises AttributeError: if a graph has no ``global_attr_key`` attribute
    """
    n = len(graphs)
    n_nodes = sum(graph.number_of_nodes() for graph in graphs)
    n_edges = sum(graph.number_of_edges() for graph in graphs)

    # Probe feature shapes from the first item found anywhere in the batch
    # rather than from whichever graph happened to be iterated last. That
    # graph may have no edges, and an empty batch has no graph at all. The
    # (0,) fallback covers both cases.
    vshape = next(
        (
            np.shape(ndata[feature_key])
            for graph in graphs
            for _, ndata in graph.nodes(data=True)
        ),
        (0,),
    )
    eshape = next(
        (
            np.shape(edata[feature_key])
            for graph in graphs
            for _, _, edata in graph.edges(data=True)
        ),
        (0,),
    )
    ushape = next(
        (np.shape(getattr(graph, global_attr_key)[feature_key]) for graph in graphs),
        (0,),
    )

    v = np.empty((n_nodes, *vshape))
    e = np.empty((n_edges, *eshape))
    u = np.empty((n, *ushape))
    connectivity = np.empty((n_edges, 2), dtype=np.int64)
    node_idx = np.empty(n_nodes, dtype=np.int64)
    edge_idx = np.empty(n_edges, dtype=np.int64)

    _v = 0
    _e = 0

    for gidx, graph in enumerate(graphs):
        # Maps this graph's node labels to their row in `v`. It is rebuilt for
        # every graph so equal labels in different graphs cannot be confused.
        ndict = {}
        for node, ndata in graph.nodes(data=True):
            v[_v] = ndata[feature_key]
            ndict[node] = _v
            node_idx[_v] = gidx
            _v += 1

        for n1, n2, edata in graph.edges(data=True):
            e[_e] = edata[feature_key]
            edge_idx[_e] = gidx
            connectivity[_e] = (ndict[n1], ndict[n2])
            _e += 1

        u[gidx] = getattr(graph, global_attr_key)[feature_key]

    # Create the tensors directly on the requested device instead of building
    # them first and copying afterwards.
    target = device if device else None
    return GraphTuple(
        torch.tensor(v, dtype=torch.float, device=target),
        torch.tensor(e, dtype=torch.float, device=target),
        torch.tensor(u, dtype=torch.float, device=target),
        torch.tensor(connectivity, dtype=torch.long, device=target),
        torch.tensor(node_idx, dtype=torch.long, device=target),
        torch.tensor(edge_idx, dtype=torch.long, device=target),
    )

# Sanity check: `edges` should have one row per edge in the batch and two
# columns (the endpoints of that edge).
to_graph_tuple(graphs).edges.shape

torch.Size([8026, 2])

In [None]:
# NOTE(review): unexecuted scratch cell. No module-level `x` is defined in the
# visible notebook, so this presumably raises NameError on a fresh run - confirm.
x

In [30]:
def pick_edge(g):
    """Return the first item of ``g.edges(data=True)``, or ``None`` when it yields nothing.

    NOTE(review): ``pick_edge`` is defined more than once in this notebook.
    """
    edge_iter = iter(g.edges(data=True))
    return next(edge_iter, None)
    
def pick_node(g):
    """Return the first item of ``g.nodes(data=True)``, or ``None`` when it yields nothing.

    NOTE(review): ``pick_node`` is defined more than once in this notebook.
    """
    node_iter = iter(g.nodes(data=True))
    return next(node_iter, None)
    
    

AttributeError: 'DiGraph' object has no attribute 'edge_data'

In [14]:
# Scratch lookup: the bare name is evaluated only to display the function
# object; nothing is called.
torch.empty

<function _VariableFunctions.empty>