Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Speed up common/non_neighbors by using _adj dict operations #7244

Merged
merged 7 commits into from
Feb 2, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 7 additions & 7 deletions networkx/algorithms/link_prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def predict(u, v):
union_size = len(set(G[u]) | set(G[v]))
if union_size == 0:
return 0
return len(list(nx.common_neighbors(G, u, v))) / union_size
return len(nx.common_neighbors(G, u, v)) / union_size

return _apply_prediction(G, predict, ebunch)

Expand Down Expand Up @@ -329,7 +329,7 @@ def predict(u, v):
if u == v:
raise nx.NetworkXAlgorithmError("Self loops are not supported")

return sum(1 for _ in nx.common_neighbors(G, u, v))
return len(nx.common_neighbors(G, u, v))

else:
spl = dict(nx.shortest_path_length(G))
Expand All @@ -340,9 +340,9 @@ def predict(u, v):
raise nx.NetworkXAlgorithmError("Self loops are not supported")
path_len = spl[u].get(v, inf)

return alpha * sum(1 for _ in nx.common_neighbors(G, u, v)) + (
1 - alpha
) * (G.number_of_nodes() / path_len)
return alpha * len(nx.common_neighbors(G, u, v)) + (1 - alpha) * (
G.number_of_nodes() / path_len
)
rossbar marked this conversation as resolved.
Show resolved Hide resolved

return _apply_prediction(G, predict, ebunch)

Expand Down Expand Up @@ -486,7 +486,7 @@ def cn_soundarajan_hopcroft(G, ebunch=None, community="community"):
def predict(u, v):
Cu = _community(G, u, community)
Cv = _community(G, v, community)
cnbors = list(nx.common_neighbors(G, u, v))
cnbors = nx.common_neighbors(G, u, v)
neighbors = (
sum(_community(G, w, community) == Cu for w in cnbors) if Cu == Cv else 0
)
Expand Down Expand Up @@ -670,7 +670,7 @@ def predict(u, v):
Cv = _community(G, v, community)
if Cu != Cv:
return 0
cnbors = set(nx.common_neighbors(G, u, v))
cnbors = nx.common_neighbors(G, u, v)
within = {w for w in cnbors if _community(G, w, community) == Cu}
inter = cnbors - within
return len(within) / (len(inter) + delta)
Expand Down
57 changes: 30 additions & 27 deletions networkx/classes/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -809,13 +809,13 @@ def set_edge_attributes(G, values, name=None):
if G.is_multigraph():
for (u, v, key), value in values.items():
try:
G[u][v][key][name] = value
G._adj[u][v][key][name] = value
except KeyError:
pass
else:
for (u, v), value in values.items():
try:
G[u][v][name] = value
G._adj[u][v][name] = value
except KeyError:
pass
except AttributeError:
Expand All @@ -827,13 +827,13 @@ def set_edge_attributes(G, values, name=None):
if G.is_multigraph():
for (u, v, key), d in values.items():
try:
G[u][v][key].update(d)
G._adj[u][v][key].update(d)
except KeyError:
pass
else:
for (u, v), d in values.items():
try:
G[u][v].update(d)
G._adj[u][v].update(d)
except KeyError:
pass

Expand Down Expand Up @@ -918,11 +918,10 @@ def non_neighbors(graph, node):

Returns
-------
non_neighbors : iterator
Iterator of nodes in the graph that are not neighbors of the node.
non_neighbors : set
Set of nodes in the graph that are not neighbors of the node.
"""
nbors = set(neighbors(graph, node)) | {node}
return (nnode for nnode in graph if nnode not in nbors)
return graph._adj.keys() - graph._adj[node].keys() - {node}


def non_edges(graph):
Expand Down Expand Up @@ -964,8 +963,8 @@ def common_neighbors(G, u, v):

Returns
-------
cnbors : iterator
Iterator of common neighbors of u and v in the graph.
cnbors : set
Set of common neighbors of u and v in the graph.

Raises
------
Expand All @@ -983,9 +982,7 @@ def common_neighbors(G, u, v):
if v not in G:
raise nx.NetworkXError("v is not in the graph.")

# Return a generator explicitly instead of yielding so that the above
# checks are executed eagerly.
return (w for w in G[u] if w in G[v] and w not in (u, v))
return G._adj[u].keys() & G._adj[v].keys() - {u, v}


def is_weighted(G, edge=None, weight="weight"):
Expand Down Expand Up @@ -1114,7 +1111,7 @@ def is_empty(G):
is the number of nodes in the graph.

"""
return not any(G.adj.values())
return not any(G._adj.values())


def nodes_with_selfloops(G):
Expand All @@ -1141,7 +1138,7 @@ def nodes_with_selfloops(G):
[1]

"""
return (n for n, nbrs in G.adj.items() if n in nbrs)
return (n for n, nbrs in G._adj.items() if n in nbrs)


def selfloop_edges(G, data=False, keys=False, default=None):
Expand Down Expand Up @@ -1191,56 +1188,59 @@ def selfloop_edges(G, data=False, keys=False, default=None):
if keys is True:
return (
(n, n, k, d)
for n, nbrs in G.adj.items()
for n, nbrs in G._adj.items()
if n in nbrs
for k, d in nbrs[n].items()
)
else:
return (
(n, n, d)
for n, nbrs in G.adj.items()
for n, nbrs in G._adj.items()
if n in nbrs
for d in nbrs[n].values()
)
else:
return ((n, n, nbrs[n]) for n, nbrs in G.adj.items() if n in nbrs)
return ((n, n, nbrs[n]) for n, nbrs in G._adj.items() if n in nbrs)
elif data is not False:
if G.is_multigraph():
if keys is True:
return (
(n, n, k, d.get(data, default))
for n, nbrs in G.adj.items()
for n, nbrs in G._adj.items()
if n in nbrs
for k, d in nbrs[n].items()
)
else:
return (
(n, n, d.get(data, default))
for n, nbrs in G.adj.items()
for n, nbrs in G._adj.items()
if n in nbrs
for d in nbrs[n].values()
)
else:
return (
(n, n, nbrs[n].get(data, default))
for n, nbrs in G.adj.items()
for n, nbrs in G._adj.items()
if n in nbrs
)
else:
if G.is_multigraph():
if keys is True:
return (
(n, n, k) for n, nbrs in G.adj.items() if n in nbrs for k in nbrs[n]
(n, n, k)
for n, nbrs in G._adj.items()
if n in nbrs
for k in nbrs[n]
)
else:
return (
(n, n)
for n, nbrs in G.adj.items()
for n, nbrs in G._adj.items()
if n in nbrs
for i in range(len(nbrs[n])) # for easy edge removal (#4068)
)
else:
return ((n, n) for n, nbrs in G.adj.items() if n in nbrs)
return ((n, n) for n, nbrs in G._adj.items() if n in nbrs)


def number_of_selfloops(G):
Expand Down Expand Up @@ -1288,7 +1288,10 @@ def is_path(G, path):
True if `path` is a valid path in `G`

"""
return all((node in G and nbr in G[node]) for node, nbr in nx.utils.pairwise(path))
try:
return all(nbr in G._adj[node] for node, nbr in nx.utils.pairwise(path))
except (KeyError, TypeError):
return False


def path_weight(G, path, weight):
Expand Down Expand Up @@ -1323,7 +1326,7 @@ def path_weight(G, path, weight):
raise nx.NetworkXNoPath("path does not exist")
for node, nbr in nx.utils.pairwise(path):
if multigraph:
cost += min(v[weight] for v in G[node][nbr].values())
cost += min(v[weight] for v in G._adj[node][nbr].values())
else:
cost += G[node][nbr][weight]
cost += G._adj[node][nbr][weight]
return cost
8 changes: 4 additions & 4 deletions networkx/classes/tests/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,13 +297,13 @@ def test_neighbors_complete_graph(self):
def test_non_neighbors(self):
graph = nx.complete_graph(100)
pop = random.sample(list(graph), 1)
nbors = list(nx.non_neighbors(graph, pop[0]))
nbors = nx.non_neighbors(graph, pop[0])
# should be all the other vertices in the graph
assert len(nbors) == 0

graph = nx.path_graph(100)
node = random.sample(list(graph), 1)[0]
nbors = list(nx.non_neighbors(graph, node))
nbors = nx.non_neighbors(graph, node)
# should be all the other vertices in the graph
if node != 0 and node != 99:
assert len(nbors) == 97
Expand All @@ -312,13 +312,13 @@ def test_non_neighbors(self):

# create a star graph with 99 outer nodes
graph = nx.star_graph(99)
nbors = list(nx.non_neighbors(graph, 0))
nbors = nx.non_neighbors(graph, 0)
assert len(nbors) == 0

# disconnected graph
graph = nx.Graph()
graph.add_nodes_from(range(10))
nbors = list(nx.non_neighbors(graph, 0))
nbors = nx.non_neighbors(graph, 0)
assert len(nbors) == 9

def test_non_edges(self):
Expand Down