|
14 | 14 | OccupancyComparator,
|
15 | 15 | OrderDisorderElementComparator,
|
16 | 16 | StructureMatcher,
|
| 17 | + get_linear_assignment_solution, |
17 | 18 | )
|
18 | 19 | from pymatgen.core import Element, Lattice, Structure, SymmOp
|
19 | 20 | from pymatgen.util.coord import find_in_coord_list_pbc
|
|
22 | 23 | TEST_DIR = f"{TEST_FILES_DIR}/analysis/structure_matcher"
|
23 | 24 |
|
24 | 25 |
|
| 26 | +class TestLinearAssignment: |
| 27 | + def test_square_cost_matrix(self): |
| 28 | + costs_0 = np.array( |
| 29 | + [ |
| 30 | + [19, 95, 9, 43, 62, 90, 10, 77, 71, 27], |
| 31 | + [26, 30, 88, 78, 87, 2, 14, 71, 78, 11], |
| 32 | + [48, 70, 26, 82, 32, 16, 36, 26, 42, 79], |
| 33 | + [47, 46, 93, 66, 38, 20, 73, 39, 55, 51], |
| 34 | + [1, 81, 31, 49, 20, 24, 95, 80, 82, 11], |
| 35 | + [81, 48, 35, 54, 35, 55, 27, 87, 96, 7], |
| 36 | + [42, 17, 60, 73, 37, 36, 79, 3, 60, 82], |
| 37 | + [14, 57, 23, 69, 93, 78, 56, 49, 83, 36], |
| 38 | + [11, 37, 24, 70, 62, 35, 64, 18, 99, 20], |
| 39 | + [73, 11, 98, 50, 19, 96, 61, 73, 98, 14], |
| 40 | + ] |
| 41 | + ) |
| 42 | + sol, min_cost = get_linear_assignment_solution(costs_0) |
| 43 | + assert min_cost == 194, "Incorrect cost" |
| 44 | + |
| 45 | + costs_1 = np.array( |
| 46 | + [ |
| 47 | + [95, 60, 89, 38, 36, 38, 58, 94, 66, 23], |
| 48 | + [37, 0, 40, 58, 97, 85, 18, 54, 86, 21], |
| 49 | + [9, 74, 11, 45, 65, 64, 27, 88, 24, 26], |
| 50 | + [58, 90, 6, 36, 17, 21, 2, 12, 80, 90], |
| 51 | + [33, 0, 74, 75, 11, 84, 34, 7, 39, 0], |
| 52 | + [17, 61, 94, 68, 27, 41, 33, 86, 59, 2], |
| 53 | + [61, 94, 36, 53, 66, 33, 15, 87, 97, 11], |
| 54 | + [22, 20, 57, 69, 15, 9, 15, 8, 82, 68], |
| 55 | + [40, 0, 13, 61, 67, 40, 29, 25, 72, 44], |
| 56 | + [13, 97, 97, 54, 5, 30, 44, 75, 16, 0], |
| 57 | + ] |
| 58 | + ) |
| 59 | + sol, min_cost = get_linear_assignment_solution(costs_1) |
| 60 | + assert min_cost == 125, "Incorrect cost" |
| 61 | + |
| 62 | + costs_2 = np.array( |
| 63 | + [ |
| 64 | + [34, 44, 72, 13, 10, 58, 16, 1, 10, 61], |
| 65 | + [54, 70, 99, 4, 64, 0, 15, 94, 39, 46], |
| 66 | + [49, 21, 80, 68, 96, 58, 24, 87, 79, 67], |
| 67 | + [86, 46, 58, 83, 83, 56, 83, 65, 4, 96], |
| 68 | + [48, 95, 64, 34, 75, 82, 64, 47, 35, 19], |
| 69 | + [11, 49, 6, 57, 80, 26, 47, 63, 75, 75], |
| 70 | + [74, 7, 15, 83, 64, 26, 78, 17, 67, 46], |
| 71 | + [19, 13, 2, 26, 52, 16, 65, 24, 2, 98], |
| 72 | + [36, 7, 93, 93, 11, 39, 94, 26, 46, 69], |
| 73 | + [32, 95, 37, 50, 97, 96, 12, 70, 40, 93], |
| 74 | + ] |
| 75 | + ) |
| 76 | + sol, min_cost = get_linear_assignment_solution(costs_2) |
| 77 | + assert min_cost == 110, "Incorrect cost" |
| 78 | + |
| 79 | + def test_rectangular_cost_matrix(self): |
| 80 | + costs_0 = np.array( |
| 81 | + [ |
| 82 | + [19, 95, 9, 43, 62, 90, 10, 77, 71, 27], |
| 83 | + [26, 30, 88, 78, 87, 2, 14, 71, 78, 11], |
| 84 | + [48, 70, 26, 82, 32, 16, 36, 26, 42, 79], |
| 85 | + [47, 46, 93, 66, 38, 20, 73, 39, 55, 51], |
| 86 | + [1, 81, 31, 49, 20, 24, 95, 80, 82, 11], |
| 87 | + [81, 48, 35, 54, 35, 55, 27, 87, 96, 7], |
| 88 | + [42, 17, 60, 73, 37, 36, 79, 3, 60, 82], |
| 89 | + [14, 57, 23, 69, 93, 78, 56, 49, 83, 36], |
| 90 | + [11, 37, 24, 70, 62, 35, 64, 18, 99, 20], |
| 91 | + ] |
| 92 | + ) |
| 93 | + sol, min_cost_0 = get_linear_assignment_solution(costs_0) |
| 94 | + |
| 95 | + costs_1 = np.array( |
| 96 | + [ |
| 97 | + [19, 95, 9, 43, 62, 90, 10, 77, 71, 27], |
| 98 | + [26, 30, 88, 78, 87, 2, 14, 71, 78, 11], |
| 99 | + [48, 70, 26, 82, 32, 16, 36, 26, 42, 79], |
| 100 | + [47, 46, 93, 66, 38, 20, 73, 39, 55, 51], |
| 101 | + [1, 81, 31, 49, 20, 24, 95, 80, 82, 11], |
| 102 | + [81, 48, 35, 54, 35, 55, 27, 87, 96, 7], |
| 103 | + [42, 17, 60, 73, 37, 36, 79, 3, 60, 82], |
| 104 | + [14, 57, 23, 69, 93, 78, 56, 49, 83, 36], |
| 105 | + [11, 37, 24, 70, 62, 35, 64, 18, 99, 20], |
| 106 | + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], |
| 107 | + ] |
| 108 | + ) |
| 109 | + sol, min_cost_1 = get_linear_assignment_solution(costs_1) |
| 110 | + assert list(sol) == [6, 5, 8, 4, 0, 9, 1, 2, 7, 3] |
| 111 | + |
| 112 | + assert min_cost_0 == min_cost_1 |
| 113 | + |
| 114 | + with pytest.raises(ValueError, match="cost matrix must have at least as many columns as rows"): |
| 115 | + _ = get_linear_assignment_solution(costs_0.T) |
| 116 | + |
| 117 | + def test_another_case(self): |
| 118 | + costs = np.array( |
| 119 | + [ |
| 120 | + [ |
| 121 | + 0.03900238875468465, |
| 122 | + 0.003202415721817453, |
| 123 | + 0.20107156847937024, |
| 124 | + 0.0, |
| 125 | + 0.5002116398420846, |
| 126 | + 0.11951326861160616, |
| 127 | + 0.0, |
| 128 | + 0.5469032363997579, |
| 129 | + 0.3243791041219123, |
| 130 | + 0.1119882291981289, |
| 131 | + ], |
| 132 | + [ |
| 133 | + 0.6048342640688928, |
| 134 | + 0.3847629088356139, |
| 135 | + 0.0, |
| 136 | + 0.44358269535118944, |
| 137 | + 0.45925670625165016, |
| 138 | + 0.31416882324798145, |
| 139 | + 0.8065128182180494, |
| 140 | + 0.0, |
| 141 | + 0.26153475286065075, |
| 142 | + 0.6862799559241944, |
| 143 | + ], |
| 144 | + [ |
| 145 | + 0.5597215814025246, |
| 146 | + 0.15133664165478322, |
| 147 | + 0.0, |
| 148 | + 0.6218101659263295, |
| 149 | + 0.15438455134183793, |
| 150 | + 0.17281467064043232, |
| 151 | + 0.8458127968475472, |
| 152 | + 0.020860721537078075, |
| 153 | + 0.1926886361228456, |
| 154 | + 0.0, |
| 155 | + ], |
| 156 | + [ |
| 157 | + 0.0, |
| 158 | + 0.0, |
| 159 | + 0.6351848838666995, |
| 160 | + 0.21261247074659906, |
| 161 | + 0.4811603832432241, |
| 162 | + 0.6663733668270337, |
| 163 | + 0.63970145187428, |
| 164 | + 0.1415815172623256, |
| 165 | + 0.5294574133825874, |
| 166 | + 0.5576702829768786, |
| 167 | + ], |
| 168 | + [ |
| 169 | + 0.25052904388309016, |
| 170 | + 0.2309392544588127, |
| 171 | + 0.0656162006684271, |
| 172 | + 0.0248922362001176, |
| 173 | + 0.0, |
| 174 | + 0.2101808638720748, |
| 175 | + 0.6529031699724193, |
| 176 | + 0.1503003886507902, |
| 177 | + 0.375576165698992, |
| 178 | + 0.7368328849560374, |
| 179 | + ], |
| 180 | + [ |
| 181 | + 0.0, |
| 182 | + 0.042215873587668984, |
| 183 | + 0.10326920761908365, |
| 184 | + 0.3562551151517992, |
| 185 | + 0.9170343984958856, |
| 186 | + 0.818783531026254, |
| 187 | + 0.7896770426052844, |
| 188 | + 0.0, |
| 189 | + 0.6573135097946438, |
| 190 | + 0.17806189728574429, |
| 191 | + ], |
| 192 | + [ |
| 193 | + 0.44992199118890386, |
| 194 | + 0.0, |
| 195 | + 0.38548898339412585, |
| 196 | + 0.6269193883601244, |
| 197 | + 1.0022861602564634, |
| 198 | + 0.0, |
| 199 | + 0.1869765500803764, |
| 200 | + 0.03474156273982543, |
| 201 | + 0.3715310534696664, |
| 202 | + 0.6197122486230232, |
| 203 | + ], |
| 204 | + [ |
| 205 | + 0.37939853696836545, |
| 206 | + 0.2421427374018027, |
| 207 | + 0.5586150342727723, |
| 208 | + 0.0, |
| 209 | + 0.7171485794073893, |
| 210 | + 0.8021029235865014, |
| 211 | + 0.11213464903613135, |
| 212 | + 0.6497896761660467, |
| 213 | + 0.3274108706187846, |
| 214 | + 0.0, |
| 215 | + ], |
| 216 | + [ |
| 217 | + 0.6674685746225324, |
| 218 | + 0.5347953626128863, |
| 219 | + 0.11461835366075113, |
| 220 | + 0.0, |
| 221 | + 0.8170639855163434, |
| 222 | + 0.7291931505979982, |
| 223 | + 0.3149153087053108, |
| 224 | + 0.1008681103294512, |
| 225 | + 0.0, |
| 226 | + 0.18751172321112997, |
| 227 | + ], |
| 228 | + [ |
| 229 | + 0.6985944652913342, |
| 230 | + 0.6139921045056471, |
| 231 | + 0.0, |
| 232 | + 0.4393266955771965, |
| 233 | + 0.0, |
| 234 | + 0.47265399761400695, |
| 235 | + 0.3674241844351025, |
| 236 | + 0.04731761392352629, |
| 237 | + 0.21484886069716147, |
| 238 | + 0.16488710920126137, |
| 239 | + ], |
| 240 | + ] |
| 241 | + ) |
| 242 | + sol, min_cost = get_linear_assignment_solution(costs) |
| 243 | + assert min_cost == approx(0) |
| 244 | + |
| 245 | + def test_small_range(self): |
| 246 | + # can be tricky for the augment step |
| 247 | + costs = np.array( |
| 248 | + [ |
| 249 | + [4, 5, 5, 6, 8, 4, 7, 4, 7, 8], |
| 250 | + [5, 6, 6, 6, 7, 6, 6, 5, 6, 7], |
| 251 | + [4, 4, 5, 7, 7, 4, 8, 4, 7, 7], |
| 252 | + [6, 7, 6, 6, 7, 6, 6, 6, 6, 6], |
| 253 | + [4, 4, 4, 6, 6, 4, 7, 4, 7, 7], |
| 254 | + [4, 5, 5, 6, 8, 4, 7, 4, 7, 8], |
| 255 | + [5, 7, 5, 5, 5, 6, 4, 5, 4, 6], |
| 256 | + [8, 9, 8, 4, 5, 9, 4, 8, 4, 4], |
| 257 | + [5, 6, 6, 6, 7, 6, 6, 5, 6, 7], |
| 258 | + [5, 6, 6, 6, 7, 6, 6, 5, 6, 7], |
| 259 | + ] |
| 260 | + ) |
| 261 | + assert get_linear_assignment_solution(costs)[1] == 48 |
| 262 | + |
| 263 | + def test_boolean_inputs(self): |
| 264 | + costs_ones = np.ones((135, 135), dtype=bool) |
| 265 | + np.fill_diagonal(costs_ones, val=False) |
| 266 | + |
| 267 | + |
25 | 268 | class TestStructureMatcher(MatSciTest):
|
26 | 269 | def setup_method(self):
|
27 | 270 | with open(f"{TEST_FILES_DIR}/entries/TiO2_entries.json", encoding="utf-8") as file:
|
|
0 commit comments