diff --git a/emmet-core/emmet/core/mobility/migrationgraph.py b/emmet-core/emmet/core/mobility/migrationgraph.py index 75ec901389..099bf4227a 100644 --- a/emmet-core/emmet/core/mobility/migrationgraph.py +++ b/emmet-core/emmet/core/mobility/migrationgraph.py @@ -183,6 +183,7 @@ def generate_sc_fields( return host_sc, sc_mat, min_length_sc, minmax_num_atoms, coords_dict, combo + @staticmethod def ordered_sc_site_dict( uc_sites_only: Structure, sc_mat: List[List[Union[int, float]]] @@ -202,6 +203,7 @@ def ordered_sc_site_dict( ordered_site_dict = {i: e for i, e in enumerate(sorted(sc_site_dict.values(), key=lambda v: float(np.linalg.norm(v["site"].frac_coords))))} return ordered_site_dict + @staticmethod def get_hop_sc_combo( unique_hops: Dict, sc_mat: List[List[Union[int, float]]], @@ -235,6 +237,7 @@ def get_hop_sc_combo( return combo, ordered_sc_site_dict + @staticmethod def compare_sc_one_hop( one_hop: Dict, sc_mat: List, @@ -260,6 +263,7 @@ def compare_sc_one_hop( return False + @staticmethod def append_new_site( host_sc: Structure, ordered_sc_site_dict: Dict, @@ -296,3 +300,40 @@ def append_new_site( sc_eindex = len(ordered_sc_site_dict) - 1 return f"{sc_iindex}+{sc_eindex}", ordered_sc_site_dict + + def get_distinct_hop_sites( + self + ) -> Tuple[List, List[str], Dict]: + """ + This is a utils function that converts the site dict and combo into a site list and combo that contain only distince endpoints used the combos. + """ + if self.inserted_ion_coords is None or self.insert_coords_combo is None: + raise TypeError("Please make sure that the MGDoc passed in has inserted_ion_coords and inserted_coords_combo fields filled.") + + else: + dis_sites_list = [] + dis_combo_list = [] + mgdoc_sites_mapping = {} # type: dict + combo_mapping = {} + + for one_combo in self.insert_coords_combo: + ini, end = list(map(int, one_combo.split("+"))) + + if ini in mgdoc_sites_mapping.keys(): + dis_ini = mgdoc_sites_mapping[ini] + else: + dis_sites_list.append(list(self.inserted_ion_coords[ini]["site"].frac_coords)) + dis_ini = len(dis_sites_list) - 1 + mgdoc_sites_mapping[ini] = dis_ini + if end in mgdoc_sites_mapping.keys(): + dis_end = mgdoc_sites_mapping[end] + else: + dis_sites_list.append(list(self.inserted_ion_coords[end]["site"].frac_coords)) + dis_end = len(dis_sites_list) - 1 + mgdoc_sites_mapping[end] = dis_end + + dis_combo = f"{dis_ini}+{dis_end}" + dis_combo_list.append(dis_combo) + combo_mapping[dis_combo] = one_combo + + return dis_sites_list, dis_combo_list, combo_mapping diff --git a/tests/emmet-core/mobility/test_migrationgraph.py b/tests/emmet-core/mobility/test_migrationgraph.py index 8da8ed9257..cc99b0fe5e 100644 --- a/tests/emmet-core/mobility/test_migrationgraph.py +++ b/tests/emmet-core/mobility/test_migrationgraph.py @@ -86,3 +86,28 @@ def test_generate_sc_fields(mg_for_sc_fields): hop_sc.insert(0, "Li", sc_esite.frac_coords) check_sc_list = [sm.fit(hop_sc, check_sc) for check_sc in expected_sc_list] assert sum(check_sc_list) >= 1 + + +def test_get_distinct_hop_sites(get_entries): + mgdoc = MigrationGraphDoc.from_entries_and_distance( + battery_id="mp-1234", + grouped_entries=get_entries[0], + working_ion_entry=get_entries[1], + hop_cutoff=5, + populate_sc_fields=True, + min_length_sc=7, + minmax_num_atoms=(80, 160) + ) + dis_sites_list, dis_combo_list, combo_mapping = mgdoc.get_distinct_hop_sites() + for one_test_combo in ['0+1', '0+2', '0+3', '0+4', '0+5', '0+6', '1+7', '1+2']: + assert one_test_combo in dis_combo_list + assert combo_mapping == { + '0+1': '9+4', + '0+2': '9+8', + '0+3': '9+6', + '0+4': '9+16', + '0+5': '9+17', + '0+6': '9+13', + '1+7': '4+0', + '1+2': '4+8' + }