New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Clean up similarity.py and use dataclasses for storing state #5831
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,9 +16,8 @@ | |
import math | ||
import time | ||
import warnings | ||
from functools import reduce | ||
from dataclasses import dataclass | ||
from itertools import product | ||
from operator import mul | ||
|
||
import networkx as nx | ||
|
||
|
@@ -187,7 +186,7 @@ def graph_edit_distance( | |
|
||
""" | ||
bestcost = None | ||
for vertex_path, edge_path, cost in optimize_edit_paths( | ||
for _, _, cost in optimize_edit_paths( | ||
G1, | ||
G2, | ||
node_match, | ||
|
@@ -503,7 +502,7 @@ def optimize_graph_edit_distance( | |
<10.5220/0005209202710278>. <hal-01168816> | ||
https://hal.archives-ouvertes.fr/hal-01168816 | ||
""" | ||
for vertex_path, edge_path, cost in optimize_edit_paths( | ||
for _, _, cost in optimize_edit_paths( | ||
G1, | ||
G2, | ||
node_match, | ||
|
@@ -672,18 +671,12 @@ def optimize_edit_paths( | |
import scipy as sp | ||
import scipy.optimize # call as sp.optimize | ||
|
||
@dataclass | ||
class CostMatrix: | ||
def __init__(self, C, lsa_row_ind, lsa_col_ind, ls): | ||
# assert C.shape[0] == len(lsa_row_ind) | ||
# assert C.shape[1] == len(lsa_col_ind) | ||
# assert len(lsa_row_ind) == len(lsa_col_ind) | ||
# assert set(lsa_row_ind) == set(range(len(lsa_row_ind))) | ||
# assert set(lsa_col_ind) == set(range(len(lsa_col_ind))) | ||
# assert ls == C[lsa_row_ind, lsa_col_ind].sum() | ||
self.C = C | ||
self.lsa_row_ind = lsa_row_ind | ||
self.lsa_col_ind = lsa_col_ind | ||
self.ls = ls | ||
C: ... | ||
lsa_row_ind: ... | ||
lsa_col_ind: ... | ||
ls: ... | ||
Comment on lines
+676
to
+679
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we all just agree that dataclasses without type annotations are a bit daft? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We can add annotations once we have settled on them. Dataclasses in itself doesn't care or need annotations https://docs.python.org/3/library/dataclasses.html#dataclasses.dataclass There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I simply mean in the sense that you need to have an ellipsis as a placeholder for annotations anyway. So the arguments against type annotations made in that other monster thread (namely visual clutter) fall away entirely in the case of these dataclasses. |
||
|
||
def make_CostMatrix(C, m, n): | ||
# assert(C.shape == (m + n, m + n)) | ||
|
@@ -694,9 +687,9 @@ def make_CostMatrix(C, m, n): | |
# NOTE: fast reduce of Cv relies on it | ||
# assert len(lsa_row_ind) == len(lsa_col_ind) | ||
indexes = zip(range(len(lsa_row_ind)), lsa_row_ind, lsa_col_ind) | ||
subst_ind = list(k for k, i, j in indexes if i < m and j < n) | ||
subst_ind = [k for k, i, j in indexes if i < m and j < n] | ||
indexes = zip(range(len(lsa_row_ind)), lsa_row_ind, lsa_col_ind) | ||
dummy_ind = list(k for k, i, j in indexes if i >= m and j >= n) | ||
dummy_ind = [k for k, i, j in indexes if i >= m and j >= n] | ||
# assert len(subst_ind) == len(dummy_ind) | ||
lsa_row_ind[dummy_ind] = lsa_col_ind[subst_ind] + m | ||
lsa_col_ind[dummy_ind] = lsa_row_ind[subst_ind] + n | ||
|
@@ -724,7 +717,7 @@ def reduce_ind(ind, i): | |
rind[rind >= k] -= 1 | ||
return rind | ||
|
||
def match_edges(u, v, pending_g, pending_h, Ce, matched_uv=[]): | ||
def match_edges(u, v, pending_g, pending_h, Ce, matched_uv=None): | ||
""" | ||
Parameters: | ||
u, v: matched vertices, u=None or v=None for | ||
|
@@ -748,7 +741,10 @@ def match_edges(u, v, pending_g, pending_h, Ce, matched_uv=[]): | |
# only attempt to match edges after one node match has been made | ||
# this will stop self-edges on the first node being automatically deleted | ||
# even when a substitution is the better option | ||
if matched_uv: | ||
if matched_uv is None or len(matched_uv) == 0: | ||
MridulS marked this conversation as resolved.
Show resolved
Hide resolved
|
||
g_ind = [] | ||
h_ind = [] | ||
else: | ||
g_ind = [ | ||
i | ||
for i in range(M) | ||
|
@@ -765,9 +761,6 @@ def match_edges(u, v, pending_g, pending_h, Ce, matched_uv=[]): | |
pending_h[j][:2] in ((q, v), (v, q), (q, q)) for p, q in matched_uv | ||
) | ||
] | ||
else: | ||
g_ind = [] | ||
h_ind = [] | ||
|
||
m = len(g_ind) | ||
n = len(h_ind) | ||
|
@@ -778,9 +771,9 @@ def match_edges(u, v, pending_g, pending_h, Ce, matched_uv=[]): | |
|
||
# Forbid structurally invalid matches | ||
# NOTE: inf remembered from Ce construction | ||
for k, i in zip(range(m), g_ind): | ||
for k, i in enumerate(g_ind): | ||
dschult marked this conversation as resolved.
Show resolved
Hide resolved
|
||
g = pending_g[i][:2] | ||
for l, j in zip(range(n), h_ind): | ||
for l, j in enumerate(h_ind): | ||
h = pending_h[j][:2] | ||
if nx.is_directed(G1) or nx.is_directed(G2): | ||
if any( | ||
|
@@ -822,8 +815,7 @@ def reduce_Ce(Ce, ij, m, n): | |
m_i = m - sum(1 for t in i if t < m) | ||
n_j = n - sum(1 for t in j if t < n) | ||
return make_CostMatrix(reduce_C(Ce.C, i, j, m, n), m_i, n_j) | ||
else: | ||
return Ce | ||
return Ce | ||
|
||
def get_edit_ops( | ||
matched_uv, pending_u, pending_v, Cv, pending_g, pending_h, Ce, matched_cost | ||
|
@@ -982,8 +974,9 @@ def get_edit_paths( | |
# assert not len(pending_g) | ||
# assert not len(pending_h) | ||
# path completed! | ||
# assert matched_cost <= maxcost.value | ||
maxcost.value = min(maxcost.value, matched_cost) | ||
# assert matched_cost <= maxcost_value | ||
nonlocal maxcost_value | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
maxcost_value = min(maxcost_value, matched_cost) | ||
yield matched_uv, matched_gh, matched_cost | ||
|
||
else: | ||
|
@@ -1051,7 +1044,7 @@ def get_edit_paths( | |
for y, h in zip(sortedy, reversed(H)): | ||
if h is not None: | ||
pending_h.insert(y, h) | ||
for t in xy: | ||
for _ in xy: | ||
matched_gh.pop() | ||
|
||
# Initialization | ||
|
@@ -1167,13 +1160,7 @@ def get_edit_paths( | |
# debug_print(Ce.C) | ||
# debug_print() | ||
|
||
class MaxCost: | ||
def __init__(self): | ||
# initial upper-bound estimate | ||
# NOTE: should work for empty graph | ||
self.value = Cv.C.sum() + Ce.C.sum() + 1 | ||
|
||
maxcost = MaxCost() | ||
maxcost_value = Cv.C.sum() + Ce.C.sum() + 1 | ||
|
||
if timeout is not None: | ||
if timeout <= 0: | ||
|
@@ -1187,10 +1174,11 @@ def prune(cost): | |
if upper_bound is not None: | ||
if cost > upper_bound: | ||
return True | ||
if cost > maxcost.value: | ||
if cost > maxcost_value: | ||
return True | ||
elif strictly_decreasing and cost >= maxcost.value: | ||
if strictly_decreasing and cost >= maxcost_value: | ||
return True | ||
return False | ||
|
||
# Now go! | ||
|
||
|
@@ -1204,7 +1192,7 @@ def prune(cost): | |
# assert sorted(G1.edges) == sorted(g for g, h in edge_path if g is not None) | ||
# assert sorted(G2.edges) == sorted(h for g, h in edge_path if h is not None) | ||
# print(vertex_path, edge_path, cost, file = sys.stderr) | ||
# assert cost == maxcost.value | ||
# assert cost == maxcost_value | ||
yield list(vertex_path), list(edge_path), cost | ||
|
||
|
||
|
@@ -1324,9 +1312,9 @@ def simrank(G, u, v): | |
|
||
if isinstance(x, np.ndarray): | ||
if x.ndim == 1: | ||
return {node: val for node, val in zip(G, x)} | ||
else: # x.ndim == 2: | ||
return {u: dict(zip(G, row)) for u, row in zip(G, x)} | ||
return dict(zip(G, x)) | ||
# else x.ndim == 2 | ||
return {u: dict(zip(G, row)) for u, row in zip(G, x)} | ||
return x | ||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i don't necessarily think this an improvement. I would typically prefix unused loop variables with an underscore, but by removing the names entirely you've stripped out some of the context
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Well this is what flake8 would complain about 🙃
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That depends on the configuration that you use. A leading underscore is a widely used convention (in many languages, not just python) for indicating an unused variable, that you still wish to explicitly name.
I'm not sure vanilla flake8 will pick this up, but
flake8-bugbear
certainly will