Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Default weight to 1 for unweighted graph during n2v #789

Merged
merged 13 commits into from
Jun 23, 2021
Merged
2 changes: 1 addition & 1 deletion graspologic/embed/n2v.py
Expand Up @@ -130,7 +130,7 @@ def node2vec_embed(
f"Completed. Ending time is {str(end)} Elapsed time is {str(start - end)}"
)

labels = node2vec_graph.original_graph.nodes()
labels = list(node2vec_graph.original_graph.nodes())
nicaurvi marked this conversation as resolved.
Show resolved Hide resolved
remapped_labels = node2vec_graph.label_map_to_string

return (
Expand Down
14 changes: 11 additions & 3 deletions graspologic/utils/utils.py
Expand Up @@ -1004,7 +1004,7 @@ def remap_labels(


def remap_node_ids(
graph: nx.Graph, weight_attribute: str = "weight"
graph: nx.Graph, weight_attribute: str = "weight", weight_default: float = 1.0
) -> Tuple[nx.Graph, Dict[Any, str]]:
"""
Given a graph with arbitrarily typed node ids, return a new graph that contains the exact same edgelist
Expand All @@ -1016,7 +1016,8 @@ def remap_node_ids(
A graph that has node ids of arbitrary types.
weight_attribute : str,
Default is ``weight``. An optional attribute to specify which column in your graph contains the weight value.

weight_default : float,
Default edge weight to use if a weight is not found on an edge in the graph
Returns
-------
Tuple[nx.Graph, Dict[Any, str]]
Expand All @@ -1030,10 +1031,17 @@ def remap_node_ids(
if not isinstance(graph, nx.Graph):
raise TypeError("graph must be of type nx.Graph")

if not nx.is_weighted(graph, weight=weight_attribute):
warnings.warn(
f'Graph is unweighted using weight_attribute "{weight_attribute}". Defaulting weights to "{weight_default}"'
nicaurvi marked this conversation as resolved.
Show resolved Hide resolved
)

node_id_dict = dict()
graph_remapped = type(graph)()

for source, target, weight in graph.edges(data=weight_attribute):
for source, target, weight in graph.edges(
data=weight_attribute, default=weight_default
):
if source not in node_id_dict:
node_id_dict[source] = str(len(node_id_dict.keys()))

Expand Down
29 changes: 28 additions & 1 deletion tests/embed/test_n2v.py
@@ -1,7 +1,6 @@
# Copyright (c) Microsoft Corporation and contributors.
# Licensed under the MIT License.

import io
import networkx as nx
import numpy as np
import unittest
Expand Down Expand Up @@ -106,6 +105,34 @@ def test_node2vec_embedding_florentine_graph_correct_shape_is_returned(self):
# vocab list should have exactly 15 elements
self.assertEqual(len(vocab_list), 15)

def test_node2vec_embedding_unweighted_florentine_graph_correct_shape_is_returned(
    self,
):
    """node2vec_embed returns a 15 x 128 matrix and 15 labels for the florentine graph.

    The graph is passed exactly as ``nx.florentine_families_graph()`` builds
    it; this test sets no edge weight attribute before embedding.
    """
    graph = nx.florentine_families_graph()

    model = gc.embed.node2vec_embed(graph)

    # Check for None *before* indexing, so a missing result is reported as an
    # assertion failure rather than a TypeError raised by ``model[0]``.
    self.assertIsNotNone(model)
    model_matrix: np.ndarray = model[0]
    vocab_list = model[1]
    self.assertIsNotNone(model_matrix)
    self.assertIsNotNone(vocab_list)

    # model matrix should be 15 x 128
    self.assertEqual(model_matrix.shape[0], 15)
    self.assertEqual(model_matrix.shape[1], 128)

    # vocab list should have exactly 15 elements
    self.assertEqual(len(vocab_list), 15)

def test_node2vec_same_labels_are_returned(self):
    """The labels returned by node2vec_embed equal the graph's node ids, in order."""
    graph = nx.florentine_families_graph()
    node_ids = list(graph.nodes())

    # Only the labels matter here; the embedding matrix is ignored.
    _, labels = gc.embed.node2vec_embed(graph)

    # Compare the whole sequences at once: an index loop over ``node_ids``
    # would not notice extra trailing labels, and would raise IndexError
    # (not an assertion failure) if too few labels came back.
    self.assertEqual(node_ids, list(labels))

def test_node2vec_embedding_barbell_graph_correct_shape_is_returned(self):
graph = nx.barbell_graph(25, 2)
for s, t in graph.edges():
Expand Down
16 changes: 16 additions & 0 deletions tests/test_utils.py
@@ -1,7 +1,9 @@
# Copyright (c) Microsoft Corporation and contributors.
# Licensed under the MIT License.

import logging
import unittest
import warnings
from math import sqrt

import networkx as nx
Expand Down Expand Up @@ -587,6 +589,20 @@ def test_remap_node_ids_invalid_typ_raises_typeerror(self):
with pytest.raises(TypeError):
gus.remap_node_ids(graph=type())

def test_remap_node_ids_unweighted_graph_raises_warning(self):
    """remap_node_ids emits exactly one UserWarning for a graph without weights."""
    # Build the graph outside the recording context so only warnings raised
    # by remap_node_ids itself are captured.
    graph = nx.florentine_families_graph()

    with warnings.catch_warnings(record=True) as recorded:
        # Force UserWarnings to be delivered every time; otherwise a prior
        # emission from the same location can be suppressed by the
        # once-per-location default filter and never be recorded.
        warnings.simplefilter("always", category=UserWarning)
        gus.remap_node_ids(graph)

    self.assertEqual(len(recorded), 1)
    self.assertTrue(issubclass(recorded[0].category, UserWarning))
    self.assertIn("Graph is unweighted using", str(recorded[0].message))

def _assert_graphs_are_equivalent(self, graph, new_graph, new_node_ids):
self.assertTrue(len(new_graph.nodes()) == len(graph.nodes()))
self.assertTrue(len(new_graph.edges()) == len(graph.edges()))
Expand Down