Skip to content

Commit

Permalink
Merge pull request #230 from nlesc-nano/sorting
Browse files Browse the repository at this point in the history
BUG: Fixed an issue wherein frozen parameters weren't properly sorted
  • Loading branch information
BvB93 committed Mar 30, 2021
2 parents b6d6286 + 464c7a7 commit 481b243
Show file tree
Hide file tree
Showing 5 changed files with 201 additions and 32 deletions.
46 changes: 15 additions & 31 deletions FOX/armc/sanitization.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from ..functions.cp2k_utils import UNIT_MAP
from ..functions.molecule_utils import fix_bond_orders, residue_argsort
from ..functions.charge_parser import assign_constraints
from ..functions.sorting import sort_param

if TYPE_CHECKING:
from .package_manager import PackageManager, PkgDict
Expand Down Expand Up @@ -341,13 +342,13 @@ def get_param(dct: ParamMapping_) -> Tuple[ParamMapping, dict, dict, ValidationD
data['unit'] = units

if _sub_prm_dict_frozen is not None:
for *_key, value in _get_prm(_sub_prm_dict_frozen):
key = tuple(_key)
try:
unit = data.loc[key[:2], 'unit'].iat[0]
except KeyError:
unit = ''
data.loc[key, :] = [value, value, True, False, -np.inf, np.inf, 0, unit]
data2 = _get_param_df(_sub_prm_dict_frozen)
if len(data2) != 0:
constraints2, min_max2, units2 = _get_prm_constraints(_sub_prm_dict_frozen)
data2[['min', 'max']] = min_max2
data2['unit'] = units2
data2['frozen'] = True
data = data.append(data2)
data.sort_index(inplace=True)

param_type = prm_dict.pop('type') # type: ignore
Expand Down Expand Up @@ -555,30 +556,13 @@ def _sort_atoms(df: pd.DataFrame) -> None:
param_types = set(df["param_type"])
for prm in param_types:
condition = df['param_type'] == prm
atoms = df.loc[condition, 'atoms'].values.astype(str)
atoms_split = np.char.partition(atoms, " ")

# Sort the atoms whenever dealing with atom-pair/triplet-based parameters
n = atoms_split.shape[1]
if n == 3:
atoms_split[..., ::2].sort(axis=1)
elif n == 5:
atoms_split[..., ::4].sort(axis=1)
elif n >= 7:
m = (n - 1) // 2
warnings.warn(f"The sorting of {m}-atom based parameters is not implemented")
continue
else:
continue

new_atoms = np.array([''.join(j for j in i) for i in atoms_split])

# Check for duplicates
_, idx, counts = np.unique(new_atoms, return_index=True, return_counts=True)
if not (counts == 1).all():
duplicates = atoms[idx[counts != 1]]
raise KeyError(f"Duplicate {prm!r} keys encountered: {duplicates}")
df.loc[condition, 'atoms'] = new_atoms
atoms = df.loc[condition, 'atoms'].values.astype(np.str_)
try:
df.loc[condition, 'atoms'] = sort_param(atoms)
except NotImplementedError as ex:
warning = RuntimeWarning(str(ex))
warning.__cause__ = ex
warnings.warn(warning)


def _get_prm(dct: Mapping[str, Union[Mapping, Iterable[Mapping]]]
Expand Down
122 changes: 122 additions & 0 deletions FOX/functions/sorting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
"""A module with functions for sorting forcefield parameters.
Index
-----
.. currentmodule:: FOX.functions.sorting
.. autosummary::
sort_param
API
---
.. autofunction:: sort_param
"""

from __future__ import annotations

from typing import TYPE_CHECKING, TypeVar, Any

import numpy as np

if TYPE_CHECKING:
import numpy.typing as npt

SCT = TypeVar("SCT", bound=np.generic)
NDArray = np.ndarray[Any, np.dtype[SCT]]

__all__ = ["sort_param"]


def sort_param(
param: npt.ArrayLike,
seperator: str = " ",
check_duplicates: bool = True,
) -> NDArray[np.str_]:
"""Sort all atoms in an atom-based parameter set.
Parameters represented by two atoms are simply sorted in alphabetical order.
For parameters consisting of three atoms only the first and last atoms are
sorted alphabetically.
Parameters consisting of four or more atoms are not supported.
Examples
--------
.. code-block ::
>>> from FOX.functions.sorting import sort_param
>>> param1 = [
... "Cd Cd",
... "Se Cd",
... "Se Se",
... ]
>>> param2 = [
... "Cd Cd Cd",
... "Se Cd Cd",
... "Se Se Se",
... ]
>>> sort_param(param1)
array(['Cd Cd', 'Cd Se', 'Se Se'], dtype='<U5')
>>> sort_param(param2)
array(['Cd Cd Cd', 'Cd Cd Se', 'Se Se Se'], dtype='<U8')
Parameters
----------
param : array-like
The to-be sorted parameters.
seperator : :class:`str`
The seperator used for splitting the atoms.
check_duplicates : :class:`bool`
Whether to check for duplicate elements after sorting the array.
Returns
-------
:class:`np.ndarray[np.str_] <numpy.ndarray>`
A new array with the atoms sorted within each parameter.
Raises
------
:exc:`ValueError`
Raised when ``check_duplicates is True`` and duplicate parameters are present
in the to-be returned array.
"""
atoms: NDArray[np.str_] = np.asarray(param)
if atoms.dtype.kind != "U":
raise TypeError(f"Expected a string array; observed dtype: {atoms.dtype}")
elif atoms.size == 0:
return atoms if atoms is not param else atoms.copy()
atoms_split = np.array(np.char.split(atoms, seperator).tolist())

# Sort the atoms whenever dealing with atom-pair/triplet-based parameters
n = atoms_split.shape[-1]
if n == 1:
ret = atoms if atoms is not param else atoms.copy()
else:
if n == 2:
atoms_split.sort(axis=-1)
elif n == 3:
atoms_split[..., ::2].sort(axis=-1)
else:
raise NotImplementedError(
f"Sorting parameters consisting of {n} atoms is not supported"
)

iterator = (seperator.join(i) for i in atoms_split.reshape(-1, n))
ret = np.fromiter(iterator, dtype=atoms.dtype, count=atoms.size)
ret.shape = atoms.shape

# Check for duplicates
if not check_duplicates:
return ret

unique, idx, counts = np.unique(ret, return_index=True, return_counts=True)
is_duplicate = counts != 1
if is_duplicate.any():
duplicates = unique[is_duplicate]
raise ValueError(f"Duplicate parameters: {duplicates}")
return ret
2 changes: 1 addition & 1 deletion tests/test_armc.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def test_param_sorting() -> None:
df2.loc[4, "atoms"] = "Se Cd"
df2.loc[5, "atoms"] = "Cd Se"
df2.loc[6, "atoms"] = "Se Cd"
assertion.assert_(_sort_atoms, df2, exception=KeyError)
assertion.assert_(_sort_atoms, df2, exception=ValueError)


@delete_finally(PATH / '_ARMC', PATH / '_ARMCPT')
Expand Down
Binary file added tests/test_files/sort_param.hdf5
Binary file not shown.
63 changes: 63 additions & 0 deletions tests/test_sorting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
"""Tests for :mod:`FOX.functions.sorting`."""

from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING, Tuple, Mapping, Any, Type
from itertools import chain

import numpy as np
import h5py
import pytest
from FOX.functions.sorting import sort_param

if TYPE_CHECKING:
import numpy.typing as npt

# Relative directory containing the test fixtures
PATH = Path('tests') / 'test_files'
# HDF5 file from which `test_passes` reads its reference arrays
HDF5 = PATH / 'sort_param.hdf5'


def construct_param(sep: str) -> Tuple[Tuple[str, npt.ArrayLike, str], ...]:
    """Build the ``(name, param, sep)`` triples used for parametrizing the tests.

    Every label is wrapped in a 0D, 1D and 2D array, each of those arrays is
    repeated as a (nested) list, and three empty string arrays are added at the end.
    The name of a triple is its position in that sequence.
    """
    # NOTE(review): a plain `str`, not a 1-tuple; iterating over it yields nothing,
    # so no empty-string case is generated. The names are positional, so presumably
    # the reference file must be regenerated if this is ever changed.
    blank = ""
    single = ("Cd", "Se")
    pair = (f"Cd{sep}Se", f"Se{sep}Cd")
    triplet = (
        f"Cd{sep}Cd{sep}Cd",
        f"Cd{sep}Cd{sep}Se",
        f"Cd{sep}Se{sep}Se",
        f"Cd{sep}Se{sep}Cd",
        f"Se{sep}Se{sep}Se",
    )

    arrays = []
    for label in chain(blank, single, pair, triplet):
        for ndim in range(3):
            arrays.append(np.array(label, ndmin=ndim))

    as_lists = [ar.tolist() for ar in arrays]
    empty = [np.array(i, dtype=np.str_) for i in ([], [[]], [[[]]])]

    cases = arrays + as_lists + empty
    return tuple((str(i), item, sep) for i, item in enumerate(cases))


# The test cases for the " " and "|" separators, concatenated into a single tuple
PARAM = construct_param(" ") + construct_param("|")


@pytest.mark.parametrize("name,param,sep", PARAM)
def test_passes(name: str, param: npt.ArrayLike, sep: str) -> None:
    """Check the output of :func:`FOX.functions.sorting.sort_param` against the reference data."""
    # The datasets are stored under the case name suffixed with its separator
    dataset_key = name + sep
    with h5py.File(HDF5, "r", libver="latest") as hdf5_file:
        expected = hdf5_file[dataset_key][...].astype(np.str_)

    result = sort_param(param, sep)
    np.testing.assert_array_equal(result, expected)


@pytest.mark.parametrize(
    "kwargs,exc",
    [
        # Non-unicode input is rejected by the dtype check
        ({"param": range(3)}, TypeError),
        ({"param": [b"Cd Se", b"Se Cd"]}, TypeError),
        # Parameters consisting of four atoms are not supported
        ({"param": "Cd Cd Cd Cd"}, NotImplementedError),
        # Duplicate parameters
        ({"param": ["Cd Cd Cd", "Cd Cd Cd"]}, ValueError),
    ]
)
def test_raises(kwargs: Mapping[str, Any], exc: Type[Exception]) -> None:
    """Test that :func:`FOX.functions.sorting.sort_param` raises on invalid input."""
    with pytest.raises(exc):
        sort_param(**kwargs)

0 comments on commit 481b243

Please sign in to comment.