Skip to content

Commit

Permalink
Default weight to 1 for unweighted graph during n2v (#789)
Browse files Browse the repository at this point in the history
* Allow undirected graphs for n2v
  • Loading branch information
nicaurvi committed Jun 23, 2021
1 parent 5ed3648 commit c266aa3
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 5 deletions.
2 changes: 1 addition & 1 deletion graspologic/embed/n2v.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def node2vec_embed(
f"Completed. Ending time is {str(end)} Elapsed time is {str(start - end)}"
)

labels = node2vec_graph.original_graph.nodes()
labels = list(node2vec_graph.original_graph.nodes())
remapped_labels = node2vec_graph.label_map_to_string

return (
Expand Down
15 changes: 12 additions & 3 deletions graspologic/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1004,7 +1004,7 @@ def remap_labels(


def remap_node_ids(
graph: nx.Graph, weight_attribute: str = "weight"
graph: nx.Graph, weight_attribute: str = "weight", weight_default: float = 1.0
) -> Tuple[nx.Graph, Dict[Any, str]]:
"""
Given a graph with arbitrarily types node ids, return a new graph that contains the exact same edgelist
Expand All @@ -1016,7 +1016,8 @@ def remap_node_ids(
A graph that has node ids of arbitrary types.
weight_attribute : str,
Default is ``weight``. An optional attribute to specify which column in your graph contains the weight value.
weight_default : float,
Default edge weight to use if a weight is not found on an edge in the graph
Returns
-------
Tuple[nx.Graph, Dict[Any, str]]
Expand All @@ -1030,10 +1031,18 @@ def remap_node_ids(
if not isinstance(graph, nx.Graph):
raise TypeError("graph must be of type nx.Graph")

if not nx.is_weighted(graph, weight=weight_attribute):
warnings.warn(
f'Graph has at least one unweighted edge using weight_attribute "{weight_attribute}". '
f'Defaulting unweighted edges to "{weight_default}"'
)

node_id_dict = dict()
graph_remapped = type(graph)()

for source, target, weight in graph.edges(data=weight_attribute):
for source, target, weight in graph.edges(
data=weight_attribute, default=weight_default
):
if source not in node_id_dict:
node_id_dict[source] = str(len(node_id_dict.keys()))

Expand Down
29 changes: 28 additions & 1 deletion tests/embed/test_n2v.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Copyright (c) Microsoft Corporation and contributors.
# Licensed under the MIT License.

import io
import networkx as nx
import numpy as np
import unittest
Expand Down Expand Up @@ -106,6 +105,34 @@ def test_node2vec_embedding_florentine_graph_correct_shape_is_returned(self):
# vocab list should have exactly 34 elements
self.assertEqual(len(vocab_list), 15)

def test_node2vec_embedding_unweighted_florentine_graph_correct_shape_is_returned(
self,
):
graph = nx.florentine_families_graph()

model = gc.embed.node2vec_embed(graph)
model_matrix: np.ndarray = model[0]
vocab_list = model[1]
self.assertIsNotNone(model)
self.assertIsNotNone(model[0])
self.assertIsNotNone(model[1])

# model matrix should be 34 x 128
self.assertEqual(model_matrix.shape[0], 15)
self.assertEqual(model_matrix.shape[1], 128)

# vocab list should have exactly 34 elements
self.assertEqual(len(vocab_list), 15)

def test_node2vec_same_labels_are_returned(self):
graph = nx.florentine_families_graph()
node_ids = list(graph.nodes())

embedding, labels = gc.embed.node2vec_embed(graph)

for i in range(len(node_ids)):
self.assertEqual(node_ids[i], labels[i])

def test_node2vec_embedding_barbell_graph_correct_shape_is_returned(self):
graph = nx.barbell_graph(25, 2)
for s, t in graph.edges():
Expand Down
17 changes: 17 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# Copyright (c) Microsoft Corporation and contributors.
# Licensed under the MIT License.

import logging
import unittest
import warnings
from math import sqrt

import networkx as nx
Expand Down Expand Up @@ -587,6 +589,21 @@ def test_remap_node_ids_invalid_typ_raises_typeerror(self):
with pytest.raises(TypeError):
gus.remap_node_ids(graph=type())

def test_remap_node_ids_unweighted_graph_raises_warning(self):
with warnings.catch_warnings(record=True) as warnings_context_manager:
graph = nx.florentine_families_graph()

gus.remap_node_ids(graph)

self.assertEqual(len(warnings_context_manager), 1)
self.assertTrue(
issubclass(warnings_context_manager[0].category, UserWarning)
)
self.assertTrue(
"Graph has at least one unweighted edge"
in str(warnings_context_manager[0].message)
)

def _assert_graphs_are_equivalent(self, graph, new_graph, new_node_ids):
self.assertTrue(len(new_graph.nodes()) == len(graph.nodes()))
self.assertTrue(len(new_graph.edges()) == len(graph.edges()))
Expand Down

0 comments on commit c266aa3

Please sign in to comment.