Skip to content

Commit

Permalink
ENH: Added the new segment_dict parameter to the `PSFContainer.gene…
Browse files Browse the repository at this point in the history
…rate_x()` functions (#239)
  • Loading branch information
BvB93 committed May 4, 2021
1 parent 17f4f4b commit f7c0489
Show file tree
Hide file tree
Showing 6 changed files with 189 additions and 28 deletions.
159 changes: 147 additions & 12 deletions FOX/io/read_psf.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,14 @@

T = TypeVar('T')

_GenerateNames = Literal["angles", "bonds", "impropers", "dihedrals"]
_FUNC_DICT: MappingProxyType[_GenerateNames, Callable[[Molecule], np.ndarray]] = MappingProxyType({
"angles": get_angles,
"bonds": get_bonds,
"impropers": get_impropers,
"dihedrals": get_dihedrals,
})


class DummyGetter(Generic[T]):
def __init__(self, return_value: T) -> None:
Expand Down Expand Up @@ -761,49 +769,176 @@ def update_atom_type(self, atom_type_old: str, atom_type_new: str) -> None:
condition = self.atom_type == atom_type_old
self.atoms.loc[condition, 'atom type'] = atom_type_new

def generate_bonds(self, mol: Molecule) -> None:
def _generate_from_res_dict(
self,
residue_dict: Mapping[str, Molecule],
name: _GenerateNames,
) -> np.ndarray:
"""Helper function for the ``PSFContainer.generate_x()`` functions."""
func = _FUNC_DICT[name]

offset_dict: dict[tuple[str, int], int] = {}
for offset, ij in enumerate(zip(self.segment_name, self.residue_id)):
offset_dict.setdefault(ij, offset)

# Validate that all `residue_dict` keys are valid
segment_set = {i for i, _ in offset_dict}
if not segment_set.issuperset(residue_dict.keys()):
key = next(iter(residue_dict.keys()))
raise KeyError(key)

cls = type(self)
angle_dict = {i: func(m) for i, m in residue_dict.items()}
ret = np.concatenate(
[np.ravel(angle_dict[i] + offset) for (i, _), offset in offset_dict.items() if
i in angle_dict]
)
ret.shape = -1, cls._SHAPE_DICT[name]["shape"]
return ret

@overload
def generate_bonds(self, mol: Molecule, *, segment_dict: None = ...) -> None:
...
@overload # noqa: E301
def generate_bonds(self, mol: None = ..., *, segment_dict: Mapping[str, Molecule]) -> None:
...
def generate_bonds(self, mol=None, *, segment_dict=None) -> None: # noqa: E301
"""Update :attr:`PSFContainer.bonds` with the indices of all bond-forming atoms from **mol**.
Notes
-----
The **mol** and **segment_dict** parameters are mutually exclusive.
Examples
--------
.. code-block:: python
>>> from FOX import PSFContainer
>>> from scm.plams import Molecule
>>> psf = PSFContainer(...)
>>> segment_dict = {"MOL3": Molecule(...)}
>>> psf.generate_bonds(segment_dict=segment_dict)
Parameters
----------
mol : :class:`plams.Molecule <scm.plams.mol.molecule.Molecule>`
A PLAMS Molecule.
segment_dict : :class:`Mapping[str, plams.Molecule] <collections.abc.Mapping>`
A dictionary mapping segment names to individual ligands.
This can result in dramatic speed ups for systems wherein each segment
contains a large number of residues.
""" # noqa
self.bonds = get_bonds(mol)
""" # noqa: E501
if mol is segment_dict is None:
raise TypeError("One of `mol` and `segment_dict` must be specified")
elif mol is not None and segment_dict is not None:
raise TypeError("Only one of `mol` and `segment_dict` can be specified")

if mol is not None:
self.bonds = get_bonds(mol)
else:
self.bonds = self._generate_from_res_dict(segment_dict, "bonds")

def generate_angles(self, mol: Molecule) -> None:
@overload
def generate_angles(self, mol: Molecule, *, segment_dict: None = ...) -> None:
...
@overload # noqa: E301
def generate_angles(self, mol: None = ..., *, segment_dict: Mapping[str, Molecule]) -> None:
...
def generate_angles(self, mol=None, *, segment_dict=None) -> None: # noqa: E301
"""Update :attr:`PSFContainer.angles` with the indices of all angle-defining atoms from **mol**.
Notes
-----
The **mol** and **segment_dict** parameters are mutually exclusive.
Parameters
----------
mol : :class:`plams.Molecule <scm.plams.mol.molecule.Molecule>`
A PLAMS Molecule.
segment_dict : :class:`Mapping[str, plams.Molecule] <collections.abc.Mapping>`
A dictionary mapping segment names to individual ligands.
This can result in dramatic speed ups for systems wherein each segment
contains a large number of residues.
""" # noqa
self.angles = get_angles(mol)
""" # noqa: E501
if mol is segment_dict is None:
raise TypeError("One of `mol` and `segment_dict` must be specified")
elif mol is not None and segment_dict is not None:
raise TypeError("Only one of `mol` and `segment_dict` can be specified")

if mol is not None:
self.angles = get_angles(mol)
else:
self.angles = self._generate_from_res_dict(segment_dict, "angles")

def generate_dihedrals(self, mol: Molecule) -> None:
@overload
def generate_dihedrals(self, mol: Molecule, *, segment_dict: None = ...) -> None:
...
@overload # noqa: E301
def generate_dihedrals(self, mol: None = ..., *, segment_dict: Mapping[str, Molecule]) -> None:
...
def generate_dihedrals(self, mol=None, *, segment_dict=None) -> None: # noqa: E301
"""Update :attr:`PSFContainer.dihedrals` with the indices of all proper dihedral angle-defining atoms from **mol**.
Notes
-----
The **mol** and **segment_dict** parameters are mutually exclusive.
Parameters
----------
mol : :class:`plams.Molecule <scm.plams.mol.molecule.Molecule>`
A PLAMS Molecule.
segment_dict : :class:`Mapping[str, plams.Molecule] <collections.abc.Mapping>`
A dictionary mapping segment names to individual ligands.
This can result in dramatic speed ups for systems wherein each segment
contains a large number of residues.
""" # noqa
self.dihedrals = get_dihedrals(mol)
""" # noqa: E501
if mol is segment_dict is None:
raise TypeError("One of `mol` and `segment_dict` must be specified")
elif mol is not None and segment_dict is not None:
raise TypeError("Only one of `mol` and `segment_dict` can be specified")

if mol is not None:
self.dihedrals = get_dihedrals(mol)
else:
self.dihedrals = self._generate_from_res_dict(segment_dict, "dihedrals")

def generate_impropers(self, mol: Molecule) -> None:
@overload
def generate_impropers(self, mol: Molecule, *, segment_dict: None = ...) -> None:
...
@overload # noqa: E301
def generate_impropers(self, mol: None = ..., *, segment_dict: Mapping[str, Molecule]) -> None:
...
def generate_impropers(self, mol=None, *, segment_dict=None) -> None: # noqa: E301
"""Update :attr:`PSFContainer.impropers` with the indices of all improper dihedral angle-defining atoms from **mol**.
Notes
-----
The **mol** and **segment_dict** parameters are mutually exclusive.
Parameters
----------
mol : :class:`plams.Molecule <scm.plams.mol.molecule.Molecule>`
A PLAMS Molecule.
segment_dict : :class:`Mapping[str, plams.Molecule] <collections.abc.Mapping>`
A dictionary mapping segment names to individual ligands.
This can result in dramatic speed ups for systems wherein each segment
contains a large number of residues.
""" # noqa
self.impropers = get_impropers(mol)
""" # noqa: E501
if mol is segment_dict is None:
raise TypeError("One of `mol` and `segment_dict` must be specified")
elif mol is not None and segment_dict is not None:
raise TypeError("Only one of `mol` and `segment_dict` can be specified")

if mol is not None:
self.impropers = get_impropers(mol)
else:
self.impropers = self._generate_from_res_dict(segment_dict, "impropers")

def generate_atoms(self, mol: Molecule,
id_map: Optional[Mapping[int, Any]] = None) -> None:
Expand Down
Binary file added tests/test_files/psf/generate_angles2.npy
Binary file not shown.
Binary file added tests/test_files/psf/generate_bonds2.npy
Binary file not shown.
Binary file added tests/test_files/psf/generate_dihedrals2.npy
Binary file not shown.
Binary file added tests/test_files/psf/generate_impropers2.npy
Binary file not shown.
58 changes: 42 additions & 16 deletions tests/test_psf.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Tests for :class:`FOX.io.read_psf.PSFContainer`."""

import os
from types import MappingProxyType
from pathlib import Path
from tempfile import TemporaryFile
from itertools import zip_longest
Expand All @@ -17,6 +18,11 @@

MOL = Molecule(PATH / 'mol.pdb')
MOL.guess_bonds(atom_subset=[at for at in MOL if at.symbol in ('C', 'O', 'H')])
LIGAND = MOL.separate()[-1]

SEGMENT_DICT = MappingProxyType({
"MOL3": LIGAND,
})


def test_write() -> None:
Expand Down Expand Up @@ -62,34 +68,54 @@ def test_update_atom_type() -> None:

def test_generate_bonds() -> None:
"""Tests for :meth:`PSFContainer.generate_bonds`."""
psf = PSF.copy()
ref = np.load(PATH / 'generate_bonds.npy')
psf.generate_bonds(MOL)
np.testing.assert_array_equal(ref, psf.bonds)
psf1 = PSF.copy()
psf1.generate_bonds(MOL)
ref1 = np.load(PATH / 'generate_bonds.npy')
np.testing.assert_array_equal(psf1.bonds, ref1)

psf2 = PSF.copy()
psf2.generate_bonds(segment_dict=SEGMENT_DICT)
ref2 = np.load(PATH / 'generate_bonds2.npy')
np.testing.assert_array_equal(psf2.bonds, ref2)


def test_generate_angles() -> None:
"""Tests for :meth:`PSFContainer.generate_angles`."""
psf = PSF.copy()
ref = np.load(PATH / 'generate_angles.npy')
psf.generate_angles(MOL)
np.testing.assert_array_equal(ref, psf.angles)
psf1 = PSF.copy()
psf1.generate_angles(MOL)
ref1 = np.load(PATH / 'generate_angles.npy')
np.testing.assert_array_equal(psf1.angles, ref1)

psf2 = PSF.copy()
psf2.generate_angles(segment_dict=SEGMENT_DICT)
ref2 = np.load(PATH / 'generate_angles2.npy')
np.testing.assert_array_equal(psf2.angles, ref2)


def test_generate_dihedrals() -> None:
"""Tests for :meth:`PSFContainer.generate_dihedrals`."""
psf = PSF.copy()
ref = np.load(PATH / 'generate_dihedrals.npy')
psf.generate_dihedrals(MOL)
np.testing.assert_array_equal(ref, psf.dihedrals)
psf1 = PSF.copy()
psf1.generate_dihedrals(MOL)
ref1 = np.load(PATH / 'generate_dihedrals.npy')
np.testing.assert_array_equal(psf1.dihedrals, ref1)

psf2 = PSF.copy()
psf2.generate_dihedrals(segment_dict=SEGMENT_DICT)
ref2 = np.load(PATH / 'generate_dihedrals2.npy')
np.testing.assert_array_equal(psf2.dihedrals, ref2)


def test_generate_impropers() -> None:
"""Tests for :meth:`PSFContainer.generate_impropers`."""
psf = PSF.copy()
ref = np.load(PATH / 'generate_impropers.npy')
psf.generate_impropers(MOL)
np.testing.assert_array_equal(ref, psf.impropers)
psf1 = PSF.copy()
psf1.generate_impropers(MOL)
ref1 = np.load(PATH / 'generate_impropers.npy')
np.testing.assert_array_equal(psf1.impropers, ref1)

psf2 = PSF.copy()
psf2.generate_impropers(segment_dict=SEGMENT_DICT)
ref2 = np.load(PATH / 'generate_impropers2.npy')
np.testing.assert_array_equal(psf2.impropers, ref2)


def test_to_atom_alias_dict() -> None:
Expand Down

0 comments on commit f7c0489

Please sign in to comment.