Skip to content

Commit e0a749b

Browse files
committed
Just delete linear_assignment and move to structure amtcher, which is
the only place this is used.
1 parent fe485dc commit e0a749b

File tree

4 files changed

+273
-295
lines changed

4 files changed

+273
-295
lines changed

src/pymatgen/analysis/structure_matcher.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010

1111
import numpy as np
1212
from monty.json import MSONable
13+
from scipy.optimize import linear_sum_assignment
1314

1415
from pymatgen.core import SETTINGS, Composition, IStructure, Lattice, Structure, get_el_sp
15-
from pymatgen.optimization.linear_assignment import LinearAssignment
1616
from pymatgen.util.coord import lattice_points_in_supercell
1717
from pymatgen.util.coord_cython import is_coord_subset_pbc, pbc_shortest_vectors
1818

@@ -34,6 +34,33 @@
3434
LRU_CACHE_SIZE = SETTINGS.get("STRUCTURE_MATCHER_CACHE_SIZE", 300)
3535

3636

37+
def get_linear_assignment_solution(cost_matrix: np.ndarray):
38+
"""Wrapper for SciPy's linear_sum_assignment.
39+
40+
Args:
41+
costs: The cost matrix of the problem. cost[i,j] should be the
42+
cost of matching x[i] to y[j]. The cost matrix may be
43+
rectangular
44+
epsilon: Tolerance for determining if solution vector is < 0
45+
46+
Returns:
47+
min_cost: The minimum cost of the matching.
48+
solution: The matching of the rows to columns. i.e solution = [1, 2, 0]
49+
would match row 0 to column 1, row 1 to column 2 and row 2
50+
to column 0. Total cost would be c[0, 1] + c[1, 2] + c[2, 0].
51+
"""
52+
53+
orig_c = np.asarray(cost_matrix, dtype=float)
54+
55+
n_rows, n_cols = orig_c.shape
56+
if n_rows > n_cols:
57+
raise ValueError("cost matrix must have at least as many columns as rows")
58+
59+
row_ind, col_ind = linear_sum_assignment(orig_c)
60+
61+
return col_ind, orig_c[row_ind, col_ind].sum()
62+
63+
3764
class SiteOrderedIStructure(IStructure):
3865
"""
3966
Imutable structure where the order of sites matters.
@@ -542,8 +569,7 @@ def _cart_dists(cls, s1, s2, avg_lattice, mask, normalization, lll_frac_tol=None
542569

543570
# vectors are from s2 to s1
544571
vecs, d_2 = pbc_shortest_vectors(avg_lattice, s2, s1, mask, return_d2=True, lll_frac_tol=lll_frac_tol)
545-
lin = LinearAssignment(d_2)
546-
sol = lin.solution
572+
sol, _ = get_linear_assignment_solution(d_2)
547573
short_vecs = vecs[np.arange(len(sol)), sol]
548574
translation = np.mean(short_vecs, axis=0)
549575
f_translation = avg_lattice.get_fractional_coords(translation)
@@ -771,7 +797,7 @@ def _strict_match(
771797
if (not self._subset) and mask.shape[1] != mask.shape[0]:
772798
return None
773799

774-
if LinearAssignment(mask).min_cost > 0:
800+
if get_linear_assignment_solution(mask)[1] > 0:
775801
return None
776802

777803
best_match = None

src/pymatgen/optimization/linear_assignment.py

Lines changed: 0 additions & 36 deletions
This file was deleted.

tests/analysis/test_structure_matcher.py

Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
OccupancyComparator,
1515
OrderDisorderElementComparator,
1616
StructureMatcher,
17+
get_linear_assignment_solution,
1718
)
1819
from pymatgen.core import Element, Lattice, Structure, SymmOp
1920
from pymatgen.util.coord import find_in_coord_list_pbc
@@ -22,6 +23,248 @@
2223
TEST_DIR = f"{TEST_FILES_DIR}/analysis/structure_matcher"
2324

2425

26+
class TestLinearAssignment:
27+
def test_square_cost_matrix(self):
28+
costs_0 = np.array(
29+
[
30+
[19, 95, 9, 43, 62, 90, 10, 77, 71, 27],
31+
[26, 30, 88, 78, 87, 2, 14, 71, 78, 11],
32+
[48, 70, 26, 82, 32, 16, 36, 26, 42, 79],
33+
[47, 46, 93, 66, 38, 20, 73, 39, 55, 51],
34+
[1, 81, 31, 49, 20, 24, 95, 80, 82, 11],
35+
[81, 48, 35, 54, 35, 55, 27, 87, 96, 7],
36+
[42, 17, 60, 73, 37, 36, 79, 3, 60, 82],
37+
[14, 57, 23, 69, 93, 78, 56, 49, 83, 36],
38+
[11, 37, 24, 70, 62, 35, 64, 18, 99, 20],
39+
[73, 11, 98, 50, 19, 96, 61, 73, 98, 14],
40+
]
41+
)
42+
sol, min_cost = get_linear_assignment_solution(costs_0)
43+
assert min_cost == 194, "Incorrect cost"
44+
45+
costs_1 = np.array(
46+
[
47+
[95, 60, 89, 38, 36, 38, 58, 94, 66, 23],
48+
[37, 0, 40, 58, 97, 85, 18, 54, 86, 21],
49+
[9, 74, 11, 45, 65, 64, 27, 88, 24, 26],
50+
[58, 90, 6, 36, 17, 21, 2, 12, 80, 90],
51+
[33, 0, 74, 75, 11, 84, 34, 7, 39, 0],
52+
[17, 61, 94, 68, 27, 41, 33, 86, 59, 2],
53+
[61, 94, 36, 53, 66, 33, 15, 87, 97, 11],
54+
[22, 20, 57, 69, 15, 9, 15, 8, 82, 68],
55+
[40, 0, 13, 61, 67, 40, 29, 25, 72, 44],
56+
[13, 97, 97, 54, 5, 30, 44, 75, 16, 0],
57+
]
58+
)
59+
sol, min_cost = get_linear_assignment_solution(costs_1)
60+
assert min_cost == 125, "Incorrect cost"
61+
62+
costs_2 = np.array(
63+
[
64+
[34, 44, 72, 13, 10, 58, 16, 1, 10, 61],
65+
[54, 70, 99, 4, 64, 0, 15, 94, 39, 46],
66+
[49, 21, 80, 68, 96, 58, 24, 87, 79, 67],
67+
[86, 46, 58, 83, 83, 56, 83, 65, 4, 96],
68+
[48, 95, 64, 34, 75, 82, 64, 47, 35, 19],
69+
[11, 49, 6, 57, 80, 26, 47, 63, 75, 75],
70+
[74, 7, 15, 83, 64, 26, 78, 17, 67, 46],
71+
[19, 13, 2, 26, 52, 16, 65, 24, 2, 98],
72+
[36, 7, 93, 93, 11, 39, 94, 26, 46, 69],
73+
[32, 95, 37, 50, 97, 96, 12, 70, 40, 93],
74+
]
75+
)
76+
sol, min_cost = get_linear_assignment_solution(costs_2)
77+
assert min_cost == 110, "Incorrect cost"
78+
79+
def test_rectangular_cost_matrix(self):
80+
costs_0 = np.array(
81+
[
82+
[19, 95, 9, 43, 62, 90, 10, 77, 71, 27],
83+
[26, 30, 88, 78, 87, 2, 14, 71, 78, 11],
84+
[48, 70, 26, 82, 32, 16, 36, 26, 42, 79],
85+
[47, 46, 93, 66, 38, 20, 73, 39, 55, 51],
86+
[1, 81, 31, 49, 20, 24, 95, 80, 82, 11],
87+
[81, 48, 35, 54, 35, 55, 27, 87, 96, 7],
88+
[42, 17, 60, 73, 37, 36, 79, 3, 60, 82],
89+
[14, 57, 23, 69, 93, 78, 56, 49, 83, 36],
90+
[11, 37, 24, 70, 62, 35, 64, 18, 99, 20],
91+
]
92+
)
93+
sol, min_cost_0 = get_linear_assignment_solution(costs_0)
94+
95+
costs_1 = np.array(
96+
[
97+
[19, 95, 9, 43, 62, 90, 10, 77, 71, 27],
98+
[26, 30, 88, 78, 87, 2, 14, 71, 78, 11],
99+
[48, 70, 26, 82, 32, 16, 36, 26, 42, 79],
100+
[47, 46, 93, 66, 38, 20, 73, 39, 55, 51],
101+
[1, 81, 31, 49, 20, 24, 95, 80, 82, 11],
102+
[81, 48, 35, 54, 35, 55, 27, 87, 96, 7],
103+
[42, 17, 60, 73, 37, 36, 79, 3, 60, 82],
104+
[14, 57, 23, 69, 93, 78, 56, 49, 83, 36],
105+
[11, 37, 24, 70, 62, 35, 64, 18, 99, 20],
106+
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
107+
]
108+
)
109+
sol, min_cost_1 = get_linear_assignment_solution(costs_1)
110+
assert list(sol) == [6, 5, 8, 4, 0, 9, 1, 2, 7, 3]
111+
112+
assert min_cost_0 == min_cost_1
113+
114+
with pytest.raises(ValueError, match="cost matrix must have at least as many columns as rows"):
115+
_ = get_linear_assignment_solution(costs_0.T)
116+
117+
def test_another_case(self):
118+
costs = np.array(
119+
[
120+
[
121+
0.03900238875468465,
122+
0.003202415721817453,
123+
0.20107156847937024,
124+
0.0,
125+
0.5002116398420846,
126+
0.11951326861160616,
127+
0.0,
128+
0.5469032363997579,
129+
0.3243791041219123,
130+
0.1119882291981289,
131+
],
132+
[
133+
0.6048342640688928,
134+
0.3847629088356139,
135+
0.0,
136+
0.44358269535118944,
137+
0.45925670625165016,
138+
0.31416882324798145,
139+
0.8065128182180494,
140+
0.0,
141+
0.26153475286065075,
142+
0.6862799559241944,
143+
],
144+
[
145+
0.5597215814025246,
146+
0.15133664165478322,
147+
0.0,
148+
0.6218101659263295,
149+
0.15438455134183793,
150+
0.17281467064043232,
151+
0.8458127968475472,
152+
0.020860721537078075,
153+
0.1926886361228456,
154+
0.0,
155+
],
156+
[
157+
0.0,
158+
0.0,
159+
0.6351848838666995,
160+
0.21261247074659906,
161+
0.4811603832432241,
162+
0.6663733668270337,
163+
0.63970145187428,
164+
0.1415815172623256,
165+
0.5294574133825874,
166+
0.5576702829768786,
167+
],
168+
[
169+
0.25052904388309016,
170+
0.2309392544588127,
171+
0.0656162006684271,
172+
0.0248922362001176,
173+
0.0,
174+
0.2101808638720748,
175+
0.6529031699724193,
176+
0.1503003886507902,
177+
0.375576165698992,
178+
0.7368328849560374,
179+
],
180+
[
181+
0.0,
182+
0.042215873587668984,
183+
0.10326920761908365,
184+
0.3562551151517992,
185+
0.9170343984958856,
186+
0.818783531026254,
187+
0.7896770426052844,
188+
0.0,
189+
0.6573135097946438,
190+
0.17806189728574429,
191+
],
192+
[
193+
0.44992199118890386,
194+
0.0,
195+
0.38548898339412585,
196+
0.6269193883601244,
197+
1.0022861602564634,
198+
0.0,
199+
0.1869765500803764,
200+
0.03474156273982543,
201+
0.3715310534696664,
202+
0.6197122486230232,
203+
],
204+
[
205+
0.37939853696836545,
206+
0.2421427374018027,
207+
0.5586150342727723,
208+
0.0,
209+
0.7171485794073893,
210+
0.8021029235865014,
211+
0.11213464903613135,
212+
0.6497896761660467,
213+
0.3274108706187846,
214+
0.0,
215+
],
216+
[
217+
0.6674685746225324,
218+
0.5347953626128863,
219+
0.11461835366075113,
220+
0.0,
221+
0.8170639855163434,
222+
0.7291931505979982,
223+
0.3149153087053108,
224+
0.1008681103294512,
225+
0.0,
226+
0.18751172321112997,
227+
],
228+
[
229+
0.6985944652913342,
230+
0.6139921045056471,
231+
0.0,
232+
0.4393266955771965,
233+
0.0,
234+
0.47265399761400695,
235+
0.3674241844351025,
236+
0.04731761392352629,
237+
0.21484886069716147,
238+
0.16488710920126137,
239+
],
240+
]
241+
)
242+
sol, min_cost = get_linear_assignment_solution(costs)
243+
assert min_cost == approx(0)
244+
245+
def test_small_range(self):
246+
# can be tricky for the augment step
247+
costs = np.array(
248+
[
249+
[4, 5, 5, 6, 8, 4, 7, 4, 7, 8],
250+
[5, 6, 6, 6, 7, 6, 6, 5, 6, 7],
251+
[4, 4, 5, 7, 7, 4, 8, 4, 7, 7],
252+
[6, 7, 6, 6, 7, 6, 6, 6, 6, 6],
253+
[4, 4, 4, 6, 6, 4, 7, 4, 7, 7],
254+
[4, 5, 5, 6, 8, 4, 7, 4, 7, 8],
255+
[5, 7, 5, 5, 5, 6, 4, 5, 4, 6],
256+
[8, 9, 8, 4, 5, 9, 4, 8, 4, 4],
257+
[5, 6, 6, 6, 7, 6, 6, 5, 6, 7],
258+
[5, 6, 6, 6, 7, 6, 6, 5, 6, 7],
259+
]
260+
)
261+
assert get_linear_assignment_solution(costs)[1] == 48
262+
263+
def test_boolean_inputs(self):
264+
costs_ones = np.ones((135, 135), dtype=bool)
265+
np.fill_diagonal(costs_ones, val=False)
266+
267+
25268
class TestStructureMatcher(MatSciTest):
26269
def setup_method(self):
27270
with open(f"{TEST_FILES_DIR}/entries/TiO2_entries.json", encoding="utf-8") as file:

0 commit comments

Comments
 (0)