diff --git a/src/kimmdy/coordinates.py b/src/kimmdy/coordinates.py index 94daa4ad..b45c7c04 100644 --- a/src/kimmdy/coordinates.py +++ b/src/kimmdy/coordinates.py @@ -19,6 +19,7 @@ Dihedral, DihedralType, ProperDihedralId, + ImproperDihedralId, MultipleDihedrals, Interaction, InteractionType, @@ -157,8 +158,12 @@ def merge_dihedrals( dihedral_key: tuple[str, str, str, str], dihedral_a: Optional[Dihedral], dihedral_b: Optional[Dihedral], - dihedral_types_a: dict[ProperDihedralId, DihedralType], - dihedral_types_b: dict[ProperDihedralId, DihedralType], + dihedral_types_a: Union[ + dict[ProperDihedralId, DihedralType], dict[ImproperDihedralId, DihedralType] + ], + dihedral_types_b: Union[ + dict[ProperDihedralId, DihedralType], dict[ImproperDihedralId, DihedralType] + ], molA: MoleculeType, molB: MoleculeType, funct: str, @@ -447,28 +452,58 @@ def merge_top_moleculetypes_slow_growth( ) # improper dihedrals - # all impropers in amber99SB ffbonded.itp have a periodicity of 2 - # but not the ones defined in aminoacids.rtp. For now, I am assuming - # a periodicity of 2 in this section + # TODO: duplicate of proper dihedrals, could refactor keys = set(molA.improper_dihedrals.keys()) | set(molB.improper_dihedrals.keys()) for key in keys: - interactionA = molA.improper_dihedrals.get(key) - interactionB = molB.improper_dihedrals.get(key) + multiple_dihedralsA = molA.improper_dihedrals.get(key) + multiple_dihedralsB = molB.improper_dihedrals.get(key) - if interactionA != interactionB: - molB.improper_dihedrals[key] = merge_dihedrals( - key, - interactionA, - interactionB, - ff.improper_dihedraltypes, - ff.improper_dihedraltypes, - molA, - molB, - "4", - "2", + if multiple_dihedralsA != multiple_dihedralsB: + multiple_dihedralsA = get_explicit_MultipleDihedrals( + key, molA, multiple_dihedralsA, ff + ) + multiple_dihedralsB = get_explicit_MultipleDihedrals( + key, molB, multiple_dihedralsB, ff + ) + keysA = ( + set(multiple_dihedralsA.dihedrals.keys()) + if multiple_dihedralsA + else set() + ) + keysB = ( + set(multiple_dihedralsB.dihedrals.keys()) + if multiple_dihedralsB + else set() ) + molB.improper_dihedrals[key] = MultipleDihedrals(*key, "9", {}) + periodicities = keysA | keysB + for periodicity in periodicities: + assert isinstance(periodicity, str) + interactionA = ( + multiple_dihedralsA.dihedrals.get(periodicity) + if multiple_dihedralsA + else None + ) + interactionB = ( + multiple_dihedralsB.dihedrals.get(periodicity) + if multiple_dihedralsB + else None + ) + + molB.improper_dihedrals[key].dihedrals[periodicity] = merge_dihedrals( + key, + interactionA, + interactionB, + ff.improper_dihedraltypes, + ff.improper_dihedraltypes, + molA, + molB, + "9", + periodicity, + ) + # amber fix for breaking/binding atom types without LJ potential # breakpoint() diff --git a/tests/test_coordinates.py b/tests/test_coordinates.py index fa53e6d6..ede29545 100644 --- a/tests/test_coordinates.py +++ b/tests/test_coordinates.py @@ -145,7 +145,9 @@ def test_merge_prm_top(arranged_tmp_path): assert top_merge.bonds[("26", "27")].funct == "3" assert top_merge.angles[("17", "19", "20")].c3 is not None assert top_merge.proper_dihedrals[("15", "17", "19", "24")].dihedrals["3"].c5 == "3" - assert top_merge.improper_dihedrals[("17", "20", "19", "24")].c5 == "2" + assert ( + top_merge.improper_dihedrals[("17", "20", "19", "24")].dihedrals["2"].c5 == "2" + ) # assert one dihedral merge improper/proper