Skip to content

Commit

Permalink
resolved indexing issues
Browse files Browse the repository at this point in the history
  • Loading branch information
jrudz committed May 15, 2024
1 parent ff415cf commit 5ec4560
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 98 deletions.
176 changes: 88 additions & 88 deletions atomisticparsers/utils/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,98 +183,98 @@ def parse_interactions(self, interactions: List[Dict], sec_model: MSection) -> N
return

## OLD ##
interaction_dict = {}
for interaction_key in Interaction.m_def.all_quantities.keys():
interaction_dict[interaction_key] = np.array(
[interaction.get(interaction_key) for interaction in interactions],
dtype=object,
)
interaction_dict = {key: val for key, val in interaction_dict.items()}
interaction_types = (
np.unique(interaction_dict["type"])
if interaction_dict.get("type") is not None
else []
)
for interaction_type in interaction_types:
sec_interaction = Interaction()
sec_model.contributions.append(sec_interaction)
interaction_indices = np.where(
interaction_dict["type"] == interaction_type
)[0]
sec_interaction.type = interaction_type
sec_interaction.n_interactions = len(interaction_indices)
sec_interaction.n_atoms
for key, val in interaction_dict.items():
if key == "type":
continue
interaction_vals = val[interaction_indices]
if type(interaction_vals[0]).__name__ == "ndarray":
interaction_vals = np.array(
[vals.tolist() for vals in interaction_vals], dtype=object
)
if interaction_vals.all() is None:
continue
if key == "parameters":
interaction_vals = interaction_vals.tolist()
elif key == "n_atoms":
interaction_vals = interaction_vals[0]
if hasattr(sec_interaction, key):
sec_interaction.m_set(
sec_interaction.m_get_quantity_definition(key), interaction_vals
)

if not sec_interaction.n_atoms:
sec_interaction.n_atoms = (
len(sec_interaction.get("atom_indices")[0])
if sec_interaction.get("atom_indices") is not None
else None
)
## NEW ##
# def write_interaction_values(values):
# interaction_dict = {}
# for interaction_key in Interaction.m_def.all_quantities.keys():
# interaction_dict[interaction_key] = np.array(
# [interaction.get(interaction_key) for interaction in interactions],
# dtype=object,
# )
# interaction_dict = {key: val for key, val in interaction_dict.items()}
# interaction_types = (
# np.unique(interaction_dict["type"])
# if interaction_dict.get("type") is not None
# else []
# )
# for interaction_type in interaction_types:
# sec_interaction = Interaction()
# sec_model.contributions.append(sec_interaction)
# sec_interaction.type = current_type
# sec_interaction.n_atoms = max(
# [len(v) for v in values.get("atom_indices", [[0]])]
# )
# for key, val in values.items():
# # TODO tempory fix: atom_labels, atom_indices not homogeneous
# # fill in missing atom label with 'X', atom index with -1
# # if key in ["atom_indices", "atom_labels"]:
# # val = [
# # (
# # v
# # + [-1 if key == "atom_indices" else "X"]
# # * sec_interaction.n_atoms
# # )[: sec_interaction.n_atoms]
# # for v in val
# # ]
# quantity_def = sec_interaction.m_def.all_quantities.get(key)
# if quantity_def:
# try:
# sec_interaction.m_set(quantity_def, val)
# except Exception:
# self.logger.error("Error setting metadata.", data={"key": key})

# interactions.sort(key=lambda x: x.get("type"))
# print(interactions)
# current_type = interactions[0].get("type")
# interaction_values: Dict[str, Any] = {}
# for interaction in interactions:
# interaction_type = interaction.get("type")
# if current_type and current_type != interaction_type:
# write_interaction_values(interaction_values)
# current_type = interaction_type
# interaction_values = {}
# interaction_values.setdefault("n_interactions", 0)
# interaction_values["n_interactions"] += 1
# for key, val in interaction.items():
# interaction_indices = np.where(
# interaction_dict["type"] == interaction_type
# )[0]
# sec_interaction.type = interaction_type
# sec_interaction.n_interactions = len(interaction_indices)
# sec_interaction.n_atoms
# for key, val in interaction_dict.items():
# if key == "type":
# continue
# interaction_values.setdefault(key, [])
# interaction_values[key].append(val)
# if interaction_values:
# write_interaction_values(interaction_values)
# interaction_vals = val[interaction_indices]
# if type(interaction_vals[0]).__name__ == "ndarray":
# interaction_vals = np.array(
# [vals.tolist() for vals in interaction_vals], dtype=object
# )
# if interaction_vals.all() is None:
# continue
# if key == "parameters":
# interaction_vals = interaction_vals.tolist()
# elif key == "n_atoms":
# interaction_vals = interaction_vals[0]
# if hasattr(sec_interaction, key):
# sec_interaction.m_set(
# sec_interaction.m_get_quantity_definition(key), interaction_vals
# )

# if not sec_interaction.n_atoms:
# sec_interaction.n_atoms = (
# len(sec_interaction.get("atom_indices")[0])
# if sec_interaction.get("atom_indices") is not None
# else None
# )
## NEW ##
def write_interaction_values(values):
sec_interaction = Interaction()
sec_model.contributions.append(sec_interaction)
sec_interaction.type = current_type
sec_interaction.n_atoms = max(
[len(v) for v in values.get("atom_indices", [[0]])]
)
for key, val in values.items():
# TODO tempory fix: atom_labels, atom_indices not homogeneous
# fill in missing atom label with 'X', atom index with -1
# if key in ["atom_indices", "atom_labels"]:
# val = [
# (
# v
# + [-1 if key == "atom_indices" else "X"]
# * sec_interaction.n_atoms
# )[: sec_interaction.n_atoms]
# for v in val
# ]
quantity_def = sec_interaction.m_def.all_quantities.get(key)
if quantity_def:
try:
sec_interaction.m_set(quantity_def, val)
except Exception:
self.logger.error("Error setting metadata.", data={"key": key})

interactions.sort(key=lambda x: x.get("type"))
print(interactions)
current_type = interactions[0].get("type")
interaction_values: Dict[str, Any] = {}
for interaction in interactions:
interaction_type = interaction.get("type")
if current_type and current_type != interaction_type:
write_interaction_values(interaction_values)
current_type = interaction_type
interaction_values = {}
interaction_values.setdefault("n_interactions", 0)
interaction_values["n_interactions"] += 1
for key, val in interaction.items():
if key == "type":
continue
interaction_values.setdefault(key, [])
interaction_values[key].append(val)
if interaction_values:
write_interaction_values(interaction_values)

def parse_interactions_by_type(
self, interactions_by_type: List[Dict], sec_model: Model
Expand Down
10 changes: 4 additions & 6 deletions tests/test_gromacsparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,7 @@ def test_md_verbose(parser):
assert sec_systems[1].atoms.positions[800][1].magnitude == approx(2.4740036e-09)
assert sec_systems[0].atoms.velocities[500][0].magnitude == approx(869.4773)
assert sec_systems[1].atoms.lattice_vectors[2][2].magnitude == approx(2.469158e-09)
# TODO fix this, fails with changes in utils.parsers.parse_interactions
# assert sec_systems[0].atoms.bond_list[200][0] == 289
assert sec_systems[0].atoms.bond_list[200, 0] == 289

sec_method = sec_run.method
assert len(sec_method) == 1
Expand All @@ -150,10 +149,9 @@ def test_md_verbose(parser):
assert sec_method[0].force_field.model[0].contributions[6].n_interactions == 1017
assert sec_method[0].force_field.model[0].contributions[6].n_atoms == 2
assert sec_method[0].force_field.model[0].contributions[6].atom_labels[10][0] == "C"
# TODO fix this, fails with changes in utils.parsers.parse_interactions
# assert (
# sec_method[0].force_field.model[0].contributions[6].atom_indices[100][1] == 141
# )
assert (
sec_method[0].force_field.model[0].contributions[6].atom_indices[100, 1] == 141
)
assert sec_method[0].force_field.model[0].contributions[6].parameters[
858
] == approx(0.9999996193044006)
Expand Down
6 changes: 2 additions & 4 deletions tests/test_lammpsparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,7 @@ def test_nvt(parser):
assert sec_method.force_field.model[0].contributions[1].type == "bond"
assert sec_method.force_field.model[0].contributions[1].n_interactions == 666
assert sec_method.force_field.model[0].contributions[1].n_atoms == 2
# TODO fix this, fails with changes in utils.parsers.parse_interactions
# assert sec_method.force_field.model[0].contributions[1].atom_indices[100][1] == 103
assert sec_method.force_field.model[0].contributions[1].atom_indices[100, 1] == 103
assert sec_method.force_field.model[0].contributions[1].parameters[200] == approx(
1.1147454117684314
)
Expand All @@ -93,8 +92,7 @@ def test_nvt(parser):
assert sec_system[5].atoms.lattice_vectors[1][1].magnitude == approx(2.24235e-09)
assert False not in sec_system[0].atoms.periodic
assert sec_system[80].atoms.labels[91:96] == ["H", "H", "H", "C", "C"]
# TODO fix this, fails with changes in utils.parsers.parse_interactions
# assert sec_system[0].atoms.bond_list[200][0] == 194
assert sec_system[0].atoms.bond_list[200, 0] == 194

sec_scc = sec_run.calculation
assert len(sec_scc) == 201
Expand Down

0 comments on commit 5ec4560

Please sign in to comment.