Skip to content

Commit

Permalink
BUG: Fixed an issue wherein incorrect indices were assigned to atom a…
Browse files Browse the repository at this point in the history
…liases (#217)
  • Loading branch information
BvB93 committed Feb 22, 2021
1 parent 314b7ea commit 6fc7837
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 10 deletions.
20 changes: 10 additions & 10 deletions FOX/io/read_psf.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from typing import (Dict, Optional, Any, Set, Iterator, Iterable, TypeVar, Tuple, cast,
List, Mapping, Union, Collection, Generic, IO, Callable)
from itertools import chain
from collections import defaultdict
from types import MappingProxyType

import numpy as np
Expand Down Expand Up @@ -889,17 +890,16 @@ def to_atom_dict(self) -> Dict[str, List[int]]:

def to_atom_alias_dict(self) -> Dict[str, Tuple[str, np.ndarray[Any, np.dtype[np.intp]]]]:
"""Create a with atom aliases."""
iterator: Iterator[Tuple[str, str]] = (
(i, j) for i, j in self.atoms[['atom type', 'atom name']].values if i != j
)
counter: defaultdict[str, int] = defaultdict(lambda: -1)
dct: defaultdict[Tuple[str, str], List[int]] = defaultdict(list)
for (at1, at2) in self.atoms[['atom type', 'atom name']].values: # type: str, str
counter[at2] += 1
if at1 == at2:
continue

dct: Dict[Tuple[str, str], int] = {}
for i, j in iterator:
try:
dct[i, j] += 1
except KeyError:
dct[i, j] = 1
return {i: (j, np.arange(v, dtype=np.intp)) for (i, j), v in dct.items()}
i = counter[at2]
dct[at1, at2].append(i)
return {at1: (at2, np.array(lst, dtype=np.intp)) for (at1, at2), lst in dct.items()}

@raise_if(RDKIT_EX)
def write_pdb(self, mol: Molecule,
Expand Down
8 changes: 8 additions & 0 deletions tests/test_psf.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,11 @@ def test_generate_impropers() -> None:
ref = np.load(join(PATH, 'generate_impropers.npy'))
psf.generate_impropers(MOL)
np.testing.assert_array_equal(ref, psf.impropers)


def test_to_atom_alias_dict() -> None:
"""Tests for :meth:`PSFContainer.to_atom_alias_dict`."""
dct = PSF.to_atom_alias_dict()
for at1, (at2, idx) in dct.items():
at2_slice = PSF.atom_type[PSF.atom_name == at2]
np.testing.assert_array_equal(at2_slice.iloc[idx], at1)

0 comments on commit 6fc7837

Please sign in to comment.