Skip to content

Commit

Permalink
Improved attribute handling in Graph.from_networkx (#3432)
Browse files Browse the repository at this point in the history
  • Loading branch information
philippjfr committed Jan 27, 2019
1 parent 9b23328 commit 2aa0a25
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 44 deletions.
112 changes: 87 additions & 25 deletions holoviews/element/graphs.py
@@ -1,4 +1,5 @@
from types import FunctionType
from collections import defaultdict

import param
import numpy as np
Expand Down Expand Up @@ -383,37 +384,98 @@ def edgepaths(self):


@classmethod
def from_networkx(cls, G, layout_function, nodes=None, **kwargs):
def from_networkx(cls, G, positions, nodes=None, **kwargs):
"""
Generate a HoloViews Graph from a networkx.Graph object and
networkx layout function. Any keyword arguments will be passed
to the layout function. By default it will extract all node
and edge attributes from the networkx.Graph but explicit node
information may also be supplied.
networkx layout function or dictionary of node positions.
Any keyword arguments will be passed to the layout
function. By default it will extract all node and edge
attributes from the networkx.Graph but explicit node
information may also be supplied. Any non-scalar attributes,
such as lists or dictionaries will be ignored.
Args:
G (networkx.Graph): Graph to convert to Graph element
positions (dict or callable): Node positions
Node positions defined as a dictionary mapping from
node id to (x, y) tuple or networkx layout function
which computes a positions dictionary
kwargs (dict): Keyword arguments for layout function
Returns:
Graph element
"""
positions = layout_function(G, **kwargs)
edges = []
for start, end in G.edges():
attrs = sorted(G.adj[start][end].items())
edges.append((start, end)+tuple(v for k, v in attrs))
edge_vdims = [k for k, v in attrs] if edges else []
if not isinstance(positions, dict):
positions = positions(G, **kwargs)

# Unpack edges
edges = defaultdict(list)
for start, end in G.edges():
for attr, value in sorted(G.adj[start][end].items()):
if isinstance(value, (list, dict)):
continue # Cannot handle list or dict attrs
edges[attr].append(value)

# Handle tuple node indexes (used in 2D grid Graphs)
if isinstance(start, tuple):
start = str(start)
if isinstance(end, tuple):
end = str(end)
edges['start'].append(start)
edges['end'].append(end)
edge_cols = sorted([k for k in edges if k not in ('start', 'end')
and len(edges[k]) == len(edges['start'])])
edge_vdims = [str(col) if isinstance(col, int) else col for col in edge_cols]
edge_data = tuple(edges[col] for col in ['start', 'end']+edge_cols)

# Unpack user node info
xdim, ydim, idim = cls.node_type.kdims[:3]
if nodes:
idx_dim = nodes.kdims[-1].name
xs, ys = zip(*[v for k, v in sorted(positions.items())])
indices = list(nodes.dimension_values(idx_dim))
edges = [edge for edge in edges if edge[0] in indices and edge[1] in indices]
nodes = nodes.select(**{idx_dim: [eid for e in edges for eid in e]}).sort()
nodes = nodes.add_dimension('x', 0, xs)
nodes = nodes.add_dimension('y', 1, ys).clone(new_type=cls.node_type)
node_columns = nodes.columns()
idx_dim = nodes.kdims[0].name
info_cols, values = zip(*((k, v) for k, v in node_columns.items() if k != idx_dim))
node_info = {i: vals for i, vals in zip(node_columns[idx_dim], zip(*values))}
else:
nodes = []
for idx, pos in sorted(positions.items()):
attrs = sorted(G.nodes[idx].items())
nodes.append(tuple(pos)+(idx,)+tuple(v for k, v in attrs))
vdims = [k for k, v in attrs] if nodes else []
nodes = cls.node_type(nodes, vdims=vdims)
return cls((edges, nodes), vdims=edge_vdims)
info_cols = []
node_info = None
node_columns = defaultdict(list)

# Unpack node positions
for idx, pos in sorted(positions.items()):
node = G.nodes.get(idx)
if node is None:
continue
x, y = pos
node_columns[xdim.name].append(x)
node_columns[ydim.name].append(y)
for attr, value in node.items():
if isinstance(value, (list, dict)):
continue
node_columns[attr].append(value)
for i, col in enumerate(info_cols):
node_columns[col].append(node_info[idx][i])
if isinstance(idx, tuple):
idx = str(idx) # Tuple node indexes handled as strings
node_columns[idim.name].append(idx)
node_cols = sorted([k for k in node_columns if k not in cls.node_type.kdims
and len(node_columns[k]) == len(node_columns[xdim.name])])
columns = [xdim.name, ydim.name, idim.name]+node_cols+list(info_cols)
node_data = tuple(node_columns[col] for col in columns)

# Construct nodes
vdims = []
for col in node_cols:
if isinstance(col, int):
dim = str(col)
elif nodes is not None and col in nodes.vdims:
dim = nodes.get_dimension(col)
else:
dim = col
vdims.append(dim)
nodes = cls.node_type(node_data, vdims=vdims)

# Construct graph
return cls((edge_data, nodes), vdims=edge_vdims)



Expand Down
56 changes: 37 additions & 19 deletions holoviews/tests/element/testgraphelement.py
Expand Up @@ -2,7 +2,6 @@
Unit tests of Graph Element.
"""
from unittest import SkipTest
from nose.plugins.attrib import attr

import numpy as np
from holoviews.core.data import Dataset
Expand Down Expand Up @@ -129,12 +128,16 @@ def test_graph_redim_nodes(self):
self.assertEqual(redimmed.nodes, graph.nodes.redim(x='x2', y='y2'))
self.assertEqual(redimmed.edgepaths, graph.edgepaths.redim(x='x2', y='y2'))

@attr(optional=1)
def test_from_networkx_with_node_attrs(self):
class FromNetworkXTests(ComparisonTestCase):

def setUp(self):
try:
import networkx as nx
import networkx as nx # noqa
except:
raise SkipTest('Test requires networkx to be installed')

def test_from_networkx_with_node_attrs(self):
import networkx as nx
G = nx.karate_club_graph()
graph = Graph.from_networkx(G, nx.circular_layout)
clubs = np.array([
Expand All @@ -146,41 +149,56 @@ def test_from_networkx_with_node_attrs(self):
'Officer', 'Officer', 'Officer', 'Officer'])
self.assertEqual(graph.nodes.dimension_values('club'), clubs)

@attr(optional=1)
def test_from_networkx_with_invalid_node_attrs(self):
import networkx as nx
FG = nx.Graph()
FG.add_node(1, test=[])
FG.add_node(2, test=[])
FG.add_edge(1, 2)
graph = Graph.from_networkx(FG, nx.circular_layout)
self.assertEqual(graph.nodes.vdims, [])
self.assertEqual(graph.nodes.dimension_values(2), np.array([1, 2]))
self.assertEqual(graph.array(), np.array([(1, 2)]))

def test_from_networkx_with_edge_attrs(self):
try:
import networkx as nx
except:
raise SkipTest('Test requires networkx to be installed')
import networkx as nx
FG = nx.Graph()
FG.add_weighted_edges_from([(1,2,0.125), (1,3,0.75), (2,4,1.2), (3,4,0.375)])
graph = Graph.from_networkx(FG, nx.circular_layout)
self.assertEqual(graph.dimension_values('weight'), np.array([0.125, 0.75, 1.2, 0.375]))

@attr(optional=1)
def test_from_networkx_with_invalid_edge_attrs(self):
import networkx as nx
FG = nx.Graph()
FG.add_weighted_edges_from([(1,2,[]), (1,3,[]), (2,4,[]), (3,4,[])])
graph = Graph.from_networkx(FG, nx.circular_layout)
self.assertEqual(graph.vdims, [])

def test_from_networkx_only_nodes(self):
try:
import networkx as nx
except:
raise SkipTest('Test requires networkx to be installed')
import networkx as nx
G = nx.Graph()
G.add_nodes_from([1, 2, 3])
graph = Graph.from_networkx(G, nx.circular_layout)
self.assertEqual(graph.nodes.dimension_values(2), np.array([1, 2, 3]))

@attr(optional=1)
def test_from_networkx_custom_nodes(self):
try:
import networkx as nx
except:
raise SkipTest('Test requires networkx to be installed')
import networkx as nx
FG = nx.Graph()
FG.add_weighted_edges_from([(1,2,0.125), (1,3,0.75), (2,4,1.2), (3,4,0.375)])
nodes = Dataset([(1, 'A'), (2, 'B'), (3, 'A'), (4, 'B')], 'index', 'some_attribute')
graph = Graph.from_networkx(FG, nx.circular_layout, nodes=nodes)
self.assertEqual(graph.nodes.dimension_values('some_attribute'), np.array(['A', 'B', 'A', 'B']))

def test_from_networkx_dictionary_positions(self):
import networkx as nx
G = nx.Graph()
G.add_nodes_from([1, 2, 3])
positions = nx.circular_layout(G)
graph = Graph.from_networkx(G, positions)
self.assertEqual(graph.nodes.dimension_values(2), np.array([1, 2, 3]))



class ChordTests(ComparisonTestCase):

def setUp(self):
Expand Down

0 comments on commit 2aa0a25

Please sign in to comment.