From 5ec4560a9f1ce3be88553ada3d9b03ee4f6d964f Mon Sep 17 00:00:00 2001 From: jrudz Date: Wed, 15 May 2024 15:50:18 +0200 Subject: [PATCH] resolved indexing issues --- atomisticparsers/utils/parsers.py | 176 +++++++++++++++--------------- tests/test_gromacsparser.py | 10 +- tests/test_lammpsparser.py | 6 +- 3 files changed, 94 insertions(+), 98 deletions(-) diff --git a/atomisticparsers/utils/parsers.py b/atomisticparsers/utils/parsers.py index 6263534..e6a44f2 100644 --- a/atomisticparsers/utils/parsers.py +++ b/atomisticparsers/utils/parsers.py @@ -183,98 +183,98 @@ def parse_interactions(self, interactions: List[Dict], sec_model: MSection) -> N return ## OLD ## - interaction_dict = {} - for interaction_key in Interaction.m_def.all_quantities.keys(): - interaction_dict[interaction_key] = np.array( - [interaction.get(interaction_key) for interaction in interactions], - dtype=object, - ) - interaction_dict = {key: val for key, val in interaction_dict.items()} - interaction_types = ( - np.unique(interaction_dict["type"]) - if interaction_dict.get("type") is not None - else [] - ) - for interaction_type in interaction_types: - sec_interaction = Interaction() - sec_model.contributions.append(sec_interaction) - interaction_indices = np.where( - interaction_dict["type"] == interaction_type - )[0] - sec_interaction.type = interaction_type - sec_interaction.n_interactions = len(interaction_indices) - sec_interaction.n_atoms - for key, val in interaction_dict.items(): - if key == "type": - continue - interaction_vals = val[interaction_indices] - if type(interaction_vals[0]).__name__ == "ndarray": - interaction_vals = np.array( - [vals.tolist() for vals in interaction_vals], dtype=object - ) - if interaction_vals.all() is None: - continue - if key == "parameters": - interaction_vals = interaction_vals.tolist() - elif key == "n_atoms": - interaction_vals = interaction_vals[0] - if hasattr(sec_interaction, key): - sec_interaction.m_set( - sec_interaction.m_get_quantity_definition(key), interaction_vals - ) - - if not sec_interaction.n_atoms: - sec_interaction.n_atoms = ( - len(sec_interaction.get("atom_indices")[0]) - if sec_interaction.get("atom_indices") is not None - else None - ) - ## NEW ## - # def write_interaction_values(values): + # interaction_dict = {} + # for interaction_key in Interaction.m_def.all_quantities.keys(): + # interaction_dict[interaction_key] = np.array( + # [interaction.get(interaction_key) for interaction in interactions], + # dtype=object, + # ) + # interaction_dict = {key: val for key, val in interaction_dict.items()} + # interaction_types = ( + # np.unique(interaction_dict["type"]) + # if interaction_dict.get("type") is not None + # else [] + # ) + # for interaction_type in interaction_types: # sec_interaction = Interaction() # sec_model.contributions.append(sec_interaction) - # sec_interaction.type = current_type - # sec_interaction.n_atoms = max( - # [len(v) for v in values.get("atom_indices", [[0]])] - # ) - # for key, val in values.items(): - # # TODO tempory fix: atom_labels, atom_indices not homogeneous - # # fill in missing atom label with 'X', atom index with -1 - # # if key in ["atom_indices", "atom_labels"]: - # # val = [ - # # ( - # # v - # # + [-1 if key == "atom_indices" else "X"] - # # * sec_interaction.n_atoms - # # )[: sec_interaction.n_atoms] - # # for v in val - # # ] - # quantity_def = sec_interaction.m_def.all_quantities.get(key) - # if quantity_def: - # try: - # sec_interaction.m_set(quantity_def, val) - # except Exception: - # self.logger.error("Error setting metadata.", data={"key": key}) - - # interactions.sort(key=lambda x: x.get("type")) - # print(interactions) - # current_type = interactions[0].get("type") - # interaction_values: Dict[str, Any] = {} - # for interaction in interactions: - # interaction_type = interaction.get("type") - # if current_type and current_type != interaction_type: - # write_interaction_values(interaction_values) - # current_type = interaction_type - # interaction_values = {} - # interaction_values.setdefault("n_interactions", 0) - # interaction_values["n_interactions"] += 1 - # for key, val in interaction.items(): + # interaction_indices = np.where( + # interaction_dict["type"] == interaction_type + # )[0] + # sec_interaction.type = interaction_type + # sec_interaction.n_interactions = len(interaction_indices) + # sec_interaction.n_atoms + # for key, val in interaction_dict.items(): # if key == "type": # continue - # interaction_values.setdefault(key, []) - # interaction_values[key].append(val) - # if interaction_values: - # write_interaction_values(interaction_values) + # interaction_vals = val[interaction_indices] + # if type(interaction_vals[0]).__name__ == "ndarray": + # interaction_vals = np.array( + # [vals.tolist() for vals in interaction_vals], dtype=object + # ) + # if interaction_vals.all() is None: + # continue + # if key == "parameters": + # interaction_vals = interaction_vals.tolist() + # elif key == "n_atoms": + # interaction_vals = interaction_vals[0] + # if hasattr(sec_interaction, key): + # sec_interaction.m_set( + # sec_interaction.m_get_quantity_definition(key), interaction_vals + # ) + + # if not sec_interaction.n_atoms: + # sec_interaction.n_atoms = ( + # len(sec_interaction.get("atom_indices")[0]) + # if sec_interaction.get("atom_indices") is not None + # else None + # ) + ## NEW ## + def write_interaction_values(values): + sec_interaction = Interaction() + sec_model.contributions.append(sec_interaction) + sec_interaction.type = current_type + sec_interaction.n_atoms = max( + [len(v) for v in values.get("atom_indices", [[0]])] + ) + for key, val in values.items(): + # TODO tempory fix: atom_labels, atom_indices not homogeneous + # fill in missing atom label with 'X', atom index with -1 + # if key in ["atom_indices", "atom_labels"]: + # val = [ + # ( + # v + # + [-1 if key == "atom_indices" else "X"] + # * sec_interaction.n_atoms + # )[: sec_interaction.n_atoms] + # for v in val + # ] + quantity_def = sec_interaction.m_def.all_quantities.get(key) + if quantity_def: + try: + sec_interaction.m_set(quantity_def, val) + except Exception: + self.logger.error("Error setting metadata.", data={"key": key}) + + interactions.sort(key=lambda x: x.get("type")) + print(interactions) + current_type = interactions[0].get("type") + interaction_values: Dict[str, Any] = {} + for interaction in interactions: + interaction_type = interaction.get("type") + if current_type and current_type != interaction_type: + write_interaction_values(interaction_values) + current_type = interaction_type + interaction_values = {} + interaction_values.setdefault("n_interactions", 0) + interaction_values["n_interactions"] += 1 + for key, val in interaction.items(): + if key == "type": + continue + interaction_values.setdefault(key, []) + interaction_values[key].append(val) + if interaction_values: + write_interaction_values(interaction_values) def parse_interactions_by_type( self, interactions_by_type: List[Dict], sec_model: Model diff --git a/tests/test_gromacsparser.py b/tests/test_gromacsparser.py index a938d57..440e507 100644 --- a/tests/test_gromacsparser.py +++ b/tests/test_gromacsparser.py @@ -140,8 +140,7 @@ def test_md_verbose(parser): assert sec_systems[1].atoms.positions[800][1].magnitude == approx(2.4740036e-09) assert sec_systems[0].atoms.velocities[500][0].magnitude == approx(869.4773) assert sec_systems[1].atoms.lattice_vectors[2][2].magnitude == approx(2.469158e-09) - # TODO fix this, fails with changes in utils.parsers.parse_interactions - # assert sec_systems[0].atoms.bond_list[200][0] == 289 + assert sec_systems[0].atoms.bond_list[200, 0] == 289 sec_method = sec_run.method assert len(sec_method) == 1 @@ -150,10 +149,9 @@ def test_md_verbose(parser): assert sec_method[0].force_field.model[0].contributions[6].n_interactions == 1017 assert sec_method[0].force_field.model[0].contributions[6].n_atoms == 2 assert sec_method[0].force_field.model[0].contributions[6].atom_labels[10][0] == "C" - # TODO fix this, fails with changes in utils.parsers.parse_interactions - # assert ( - # sec_method[0].force_field.model[0].contributions[6].atom_indices[100][1] == 141 - # ) + assert ( + sec_method[0].force_field.model[0].contributions[6].atom_indices[100, 1] == 141 + ) assert sec_method[0].force_field.model[0].contributions[6].parameters[ 858 ] == approx(0.9999996193044006) diff --git a/tests/test_lammpsparser.py b/tests/test_lammpsparser.py index 98afafd..5d83b2b 100644 --- a/tests/test_lammpsparser.py +++ b/tests/test_lammpsparser.py @@ -72,8 +72,7 @@ def test_nvt(parser): assert sec_method.force_field.model[0].contributions[1].type == "bond" assert sec_method.force_field.model[0].contributions[1].n_interactions == 666 assert sec_method.force_field.model[0].contributions[1].n_atoms == 2 - # TODO fix this, fails with changes in utils.parsers.parse_interactions - # assert sec_method.force_field.model[0].contributions[1].atom_indices[100][1] == 103 + assert sec_method.force_field.model[0].contributions[1].atom_indices[100, 1] == 103 assert sec_method.force_field.model[0].contributions[1].parameters[200] == approx( 1.1147454117684314 ) @@ -93,8 +92,7 @@ def test_nvt(parser): assert sec_system[5].atoms.lattice_vectors[1][1].magnitude == approx(2.24235e-09) assert False not in sec_system[0].atoms.periodic assert sec_system[80].atoms.labels[91:96] == ["H", "H", "H", "C", "C"] - # TODO fix this, fails with changes in utils.parsers.parse_interactions - # assert sec_system[0].atoms.bond_list[200][0] == 194 + assert sec_system[0].atoms.bond_list[200, 0] == 194 sec_scc = sec_run.calculation assert len(sec_scc) == 201