Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: update dihedrals, make pairs unique #383

Merged
merged 4 commits into from
Feb 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 42 additions & 17 deletions src/kimmdy/topology/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,8 @@ def reindex_atomnrs(self) -> dict[str, str]:
"""

update_map = {
atom_nr: str(i + 1) for i, atom_nr in enumerate(self.atoms.keys())
atom_nr: str(i + 1)
for i, atom_nr in enumerate(sorted(list(self.atoms.keys()), key=int))
}

new_atoms = {}
Expand Down Expand Up @@ -539,8 +540,8 @@ def reindex_atomnrs(self) -> dict[str, str]:

new_pairs = {}
new_multiple_dihedrals = {}
new_dihedrals = {}
for dihedrals in self.proper_dihedrals.values():
new_dihedrals = {}
ai = update_map.get(dihedrals.ai)
aj = update_map.get(dihedrals.aj)
ak = update_map.get(dihedrals.ak)
Expand All @@ -550,9 +551,10 @@ def reindex_atomnrs(self) -> dict[str, str]:
continue

# do pairs before the dihedrals are updated
if pair := self.pairs.get((dihedrals.ai, dihedrals.al)):
if pair := self.pairs.pop((dihedrals.ai, dihedrals.al), False):
pair_ai = update_map.get(pair.ai)
pair_aj = update_map.get(pair.aj)

if None not in (pair_ai, pair_aj):
pair.ai = pair_ai # type: ignore (pyright bug)
pair.aj = pair_aj # type: ignore
Expand All @@ -569,6 +571,7 @@ def reindex_atomnrs(self) -> dict[str, str]:
dihedral.ak = ak # type: ignore
dihedral.al = al # type: ignore
new_dihedrals[dihedral.periodicity] = dihedral
dihedrals.dihedrals = new_dihedrals

new_multiple_dihedrals[
(
Expand All @@ -582,21 +585,31 @@ def reindex_atomnrs(self) -> dict[str, str]:
self.proper_dihedrals = new_multiple_dihedrals
self.pairs = new_pairs

new_impropers = {}
for dihedral in self.improper_dihedrals.values():
ai = update_map.get(dihedral.ai)
aj = update_map.get(dihedral.aj)
ak = update_map.get(dihedral.ak)
al = update_map.get(dihedral.al)
new_multiple_dihedrals = {}
for dihedrals in self.improper_dihedrals.values():
new_dihedrals = {}
ai = update_map.get(dihedrals.ai)
aj = update_map.get(dihedrals.aj)
ak = update_map.get(dihedrals.ak)
al = update_map.get(dihedrals.al)
# drop dihedrals to a deleted atom
if None in (ai, aj, ak, al):
continue
dihedral.ai = ai # type: ignore
dihedral.aj = aj # type: ignore
dihedral.ak = ak # type: ignore
dihedral.al = al # type: ignore
new_impropers[(ai, aj, ak, al)] = dihedral
self.improper_dihedrals = new_impropers
dihedrals.ai = ai # type: ignore
dihedrals.aj = aj # type: ignore
dihedrals.ak = ak # type: ignore
dihedrals.al = al # type: ignore

for dihedral in dihedrals.dihedrals.values():
dihedral.ai = ai # type: ignore
dihedral.aj = aj # type: ignore
dihedral.ak = ak # type: ignore
dihedral.al = al # type: ignore
new_dihedrals[dihedral.periodicity] = dihedral
dihedrals.dihedrals = new_dihedrals

new_multiple_dihedrals[(ai, aj, ak, al)] = dihedrals
self.improper_dihedrals = new_multiple_dihedrals

return update_map

Expand Down Expand Up @@ -975,11 +988,23 @@ def del_atom(
f"{float(self.atoms[atom.bound_to_nrs[0]].charge) + float(atom.charge):7.4f}"
)

# break all bonds and delete all pairs, diheadrals etc
# break all bonds and delete all pairs, diheadrals with these bonds
for bound_nr in copy(atom.bound_to_nrs):
self.break_bond((bound_nr, _atom_nr))
self.radicals.pop(_atom_nr)

for an in tuple(self.angles.keys()):
if _atom_nr in an:
self.angles.pop(an)

for pd in tuple(self.proper_dihedrals.keys()):
if _atom_nr in pd:
self.proper_dihedrals.pop(pd)

for id in tuple(self.improper_dihedrals.keys()):
if _atom_nr in id:
self.improper_dihedrals.pop(id)

self.radicals.pop(_atom_nr)
self.atoms.pop(_atom_nr)

update_map_all = self.reindex_atomnrs()
Expand Down
94 changes: 93 additions & 1 deletion tests/test_topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,16 +252,60 @@ def test_top_ab(self, raw_top_a_fix, raw_top_b_fix):


class TestTopology:

def test_reindex_no_change(self, hexala_top_fix: Topology):
org_top: Topology = deepcopy(hexala_top_fix)
update = hexala_top_fix.reindex_atomnrs()

# test produced mapping
assert len(update.keys()) == 12
for mapping in update.values():
for k, v in mapping.items():
assert k == v

# test topology
assert org_top.atoms == hexala_top_fix.atoms
assert org_top.bonds == hexala_top_fix.bonds
assert org_top.angles == hexala_top_fix.angles
assert org_top.proper_dihedrals == hexala_top_fix.proper_dihedrals
assert org_top.improper_dihedrals == hexala_top_fix.improper_dihedrals

@given(atomindex=st.integers(min_value=1, max_value=72))
@settings(suppress_health_check=[HealthCheck.function_scoped_fixture], deadline=400)
def test_del_atom_hexala(self, hexala_top_fix, atomindex):
top: Topology = deepcopy(hexala_top_fix)

atom = top.atoms[str(atomindex)]
# bonds
bound_nrs = deepcopy(atom.bound_to_nrs)
bound_atms = [top.atoms[i] for i in bound_nrs]
# angles
angles_to_delete = []
angles_to_update = []
for key in top.angles.keys():
if str(atomindex) in key:
angles_to_delete.append(key)
else:
angles_to_update.append(key)
# proper dihedrals
pd_to_delete = []
pd_to_update = []
for key in top.proper_dihedrals.keys():
if str(atomindex) in key:
pd_to_delete.append(key)
else:
pd_to_update.append(key)
# improper dihedrals
id_to_delete = []
id_to_update = []
for key in top.improper_dihedrals.keys():
if str(atomindex) in key:
id_to_delete.append(key)
else:
id_to_update.append(key)

top.del_atom(str(atomindex), parameterize=False)
update = top.del_atom(str(atomindex), parameterize=False)
rev_update = {v: k for k, v in update.items()}

for nr, atm in zip(bound_nrs, bound_atms):
if int(nr) > atomindex:
Expand All @@ -273,6 +317,54 @@ def test_del_atom_hexala(self, hexala_top_fix, atomindex):
assert atom not in top.atoms.values()
assert len(atom.bound_to_nrs) == 0

# angles
for a_del in angles_to_delete:
assert None in [update.get(a) for a in a_del]
for a_up in angles_to_update:
new = top.angles[tuple([update.get(a) for a in a_up])]
old = hexala_top_fix.angles[a_up]
assert old.ai == rev_update.get(new.ai)
assert old.aj == rev_update.get(new.aj)
assert old.ak == rev_update.get(new.ak)

# proper dihedrals
for pd_del in pd_to_delete:
assert None in tuple([update.get(a) for a in pd_del])
for pd_up in pd_to_update:
new = top.proper_dihedrals[tuple([update.get(a) for a in pd_up])]
old = hexala_top_fix.proper_dihedrals[pd_up]
assert old.ai == rev_update.get(new.ai)
assert old.aj == rev_update.get(new.aj)
assert old.ak == rev_update.get(new.ak)
assert old.al == rev_update.get(new.al)

for d_key in old.dihedrals:
old_d = old.dihedrals[d_key]
new_d = new.dihedrals[d_key]
assert new_d.ai == update.get(old_d.ai)
assert new_d.aj == update.get(old_d.aj)
assert new_d.ak == update.get(old_d.ak)
assert new_d.al == update.get(old_d.al)

# improper dihedrals
for id_del in id_to_delete:
assert None in tuple([update.get(a) for a in id_del])
for id_up in id_to_update:
new = top.improper_dihedrals[tuple([update.get(a) for a in id_up])]
old = hexala_top_fix.improper_dihedrals[id_up]
assert old.ai == rev_update.get(new.ai)
assert old.aj == rev_update.get(new.aj)
assert old.ak == rev_update.get(new.ak)
assert old.al == rev_update.get(new.al)

for d_key in old.dihedrals:
old_d = old.dihedrals[d_key]
new_d = new.dihedrals[d_key]
assert new_d.ai == update.get(old_d.ai)
assert new_d.aj == update.get(old_d.aj)
assert new_d.ak == update.get(old_d.ak)
assert new_d.al == update.get(old_d.al)

def test_break_bind_bond_hexala(self, hexala_top_fix):
top = deepcopy(hexala_top_fix)
og_top = deepcopy(top)
Expand Down