Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix outlier handling in qap when len(seeds)==n #754

Merged
merged 6 commits into from
Apr 10, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions graspologic/match/qap.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,6 +400,8 @@ def _quadratic_assignment_faq(

# check outlier cases
if n == 0 or partial_match.shape[0] == n:
# Cannot assume partial_match is sorted.
partial_match = np.row_stack(sorted(partial_match, key=lambda x: x[0]))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does sorted require partial_match to be a list?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nope - it can be any iterable.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

great. can you add a test to the function test_barycenter_SGM in test_match.py that gives (a possibly shuffled) full partial_match and returns the correct score_ and perm_inds? (Should be very similar to the test in lines 138-142)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure - done :)

score = _calc_score(A, B, S, partial_match[:, 1])
res = {"col_ind": partial_match[:, 1], "fun": score, "nit": 0}
return OptimizeResult(res)
Expand Down
18 changes: 18 additions & 0 deletions tests/test_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,24 @@ def test_barycenter_SGM(self):
score = chr12c.score_
assert 11156 == score

W1 = np.array(range(n))
W2 = [pi[z] for z in W1]
chr12c = self.barycenter.fit(A, B, W1, W2)
score = chr12c.score_
assert 11156 == score

W1 = np.array(range(n))
W2 = [pi[z] for z in W1]
# Shuffle seed pairs.
pairs = [i for i in zip(W1, W2)]
random.shuffle(pairs)
W1 = list(list(zip(*pairs))[0])
W2 = list(list(zip(*pairs))[1])
kellymarchisio marked this conversation as resolved.
Show resolved Hide resolved
chr12c = self.barycenter.fit(A, B, W1, W2)
score = chr12c.score_
assert 11156 == score

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you check the perm_inds_ here as well


def test_rand_SGM(self):
A, B = self._get_AB()
chr12c = self.rand.fit(A, B)
Expand Down