Skip to content

Commit

Permalink
MAINT: Create a dedicated module for handling periodicty-related util…
Browse files Browse the repository at this point in the history
…ity functions (#224)
  • Loading branch information
BvB93 committed Mar 2, 2021
1 parent 3e483da commit f416544
Show file tree
Hide file tree
Showing 4 changed files with 177 additions and 37 deletions.
52 changes: 19 additions & 33 deletions FOX/classes/multi_mol.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
)
from typing import (
Sequence, Optional, Union, List, Hashable, Callable, Iterable, Dict, Tuple, Any, Mapping,
overload, TypeVar, Type, Container, cast, Generator, TYPE_CHECKING,
overload, TypeVar, Type, Container, cast, TYPE_CHECKING,
)

import numpy as np
Expand All @@ -44,6 +44,7 @@
from ..functions.rdf import get_rdf, get_rdf_df
from ..functions.adf import get_adf_df, _adf_inner_cdktree, _adf_inner
from ..functions.molecule_utils import fix_bond_orders, separate_mod
from ..functions.periodic import parse_periodic

if TYPE_CHECKING:
import numpy.typing as npt
Expand Down Expand Up @@ -80,16 +81,6 @@ def neg_exp(x: np.ndarray) -> np.ndarray:
return np.exp(-x)


def _periodicity_iter(
periodicity: Iterable[Literal["x", "y", "z"]],
ar: np.ndarray[Any, _DType],
) -> Generator[Tuple[int, np.ndarray[Any, _DType]], None, None]:
dct = {"x": 0, "y": 1, "z": 2}
for i in sorted(periodicity):
j = dct[i]
yield j, ar[j]


class _GetNone:
def __getitem__(self, key: object) -> None:
return None
Expand Down Expand Up @@ -1286,7 +1277,7 @@ def init_rdf(
*,
dr: float = 0.05,
r_max: float = 12.0,
periodic: None | Iterable[Literal["x", "y", "z"]] = None,
periodic: None | Sequence[Literal["x", "y", "z"]] | Sequence[Literal[0, 1, 2]] = None,
) -> pd.DataFrame:
"""Initialize the calculation of radial distribution functions (RDFs).
Expand Down Expand Up @@ -1337,17 +1328,14 @@ def init_rdf(

# Parse the lattice and periodicty settings
if periodic is not None:
periodic_set = {i.lower() for i in periodic}
if not periodic_set.issubset("xyz"):
raise ValueError("periodic expected `x`, `y` and/or `z`; "
f"observed value: {periodic!r}")
elif self.lattice is None:
periodic_ar = parse_periodic(periodic)
if self.lattice is None:
raise TypeError("cannot perform periodic calculations if the "
"molecules `lattice` is None")
lattice_ar = self.lattice if self.lattice.ndim == 2 else self.lattice[m_subset]
else:
lattice_ar = _GetNone()
periodic_set = "xyz"
periodic_ar = np.arange(3, dtype=np.int64)

# Identify the volume occupied by the system
if self.lattice is None:
Expand All @@ -1365,7 +1353,7 @@ def init_rdf(
for slc in iterator:
dist_mat = m_self.get_dist_mat(
mol_subset=slc, atom_subset=(i, j),
lattice=lattice_ar[slc], periodicity=periodic_set,
lattice=lattice_ar[slc], periodicity=periodic_ar,
)
df[key] += get_rdf(dist_mat, dr=dr, r_max=r_max, volume=volume)
df /= n_mol
Expand All @@ -1376,7 +1364,7 @@ def get_dist_mat(
mol_subset: MolSubset = None,
atom_subset: Tuple[AtomSubset, AtomSubset] = (None, None),
lattice: None | np.ndarray[Any, np.dtype[np.float64]] = None,
periodicity: Iterable[Literal["x", "y", "z"]] = "xyz",
periodicity: Iterable[Literal[0, 1, 2]] = range(3),
) -> np.ndarray[Any, np.dtype[np.float64]]:
"""Create and return a distance matrix for all molecules and atoms in this instance.
Expand Down Expand Up @@ -1426,11 +1414,11 @@ def get_dist_mat(
ret = np.abs(A[..., None, :] - B[..., None, :, :])
lat_norm = np.linalg.norm(lattice, axis=-1)
if lat_norm.ndim == 1:
iterator = _periodicity_iter(periodicity, lat_norm)
iterator = ((i, lat_norm[i]) for i in periodicity)
for i, ar1 in iterator:
ret[..., i][ret[..., i] > (ar1 / 2)] -= ar1
elif lat_norm.ndim == 2:
iterator = _periodicity_iter(periodicity, lat_norm.T)
iterator = ((i, lat_norm[:, i]) for i in periodicity)
for i, _ar2 in iterator:
ar2 = np.full(ret.shape[:-1], _ar2[..., None, None])
condition = ret[..., i] > (ar2 / 2)
Expand Down Expand Up @@ -1633,7 +1621,7 @@ def init_adf(
atom_subset: AtomSubset = None,
r_max: Union[float, str] = 8.0,
weight: Callable[[np.ndarray], np.ndarray] = neg_exp,
periodic: None | Iterable[Literal["x", "y", "z"]] = None,
periodic: None | Sequence[Literal["x", "y", "z"]] | Sequence[Literal[0, 1, 2]] = None,
) -> pd.DataFrame:
r"""Initialize the calculation of distance-weighted angular distribution functions (ADFs).
Expand Down Expand Up @@ -1715,16 +1703,15 @@ def init_adf(

# Periodic calculations
if periodic is not None:
periodic_ar = parse_periodic(periodic)

# Validate the parameters
periodic_set = {i.lower() for i in periodic}
lattice = self.lattice
if not periodic_set.issubset("xyz"):
raise ValueError("periodic expected `x`, `y` and/or `z`; "
f"observed value: {periodic!r}")
elif lattice is None:
if lattice is None:
raise TypeError("cannot perform periodic calculations if the "
"molecules `lattice` is None")
lattice = lattice[m_subset]
else:
lattice = cast("np.ndarray[Any, np.dtype[np.float64]]", lattice[m_subset])

# Scipy's `cKDTree` only supports cuboid lattices
if r_max_:
Expand All @@ -1733,8 +1720,7 @@ def init_adf(
raise NotImplementedError("non-cuboid lattices are not supported")

# Set the vector-length of all absent axes to `inf`
xyz_dct = {"x": 0, "y": 1, "z": 2}
slc = [xyz_dct[i] for i in sorted(xyz_dct.keys() - periodic_set)]
slc = [i for i in range(3) if i not in periodic_ar]
lattice[..., slc, :] = np.inf

# Perform a translation to remove negative elements, as `cKDTree` cannot
Expand All @@ -1746,10 +1732,10 @@ def init_adf(
else:
boxsize_iter = iter(np.linalg.norm(lattice, axis=-1))
lattice_iter = iter(lattice)
periodic_iter = repeat(sorted(periodic_set))
periodic_iter = repeat(periodic_ar)
else:
lattice_iter = repeat(None)
periodic_iter = repeat("xyz")
periodic_iter = repeat(range(3))
boxsize_iter = repeat(None)

# Construct the angular distribution function
Expand Down
7 changes: 3 additions & 4 deletions FOX/functions/adf.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def _adf_inner(
m: NDArray[f8],
idx_list: Iterable[_3Tuple[NDArray[np.bool_]]],
lattice: None | NDArray[f8],
periodicity: Iterable[Literal["x", "y", "z"]] = "xyz",
periodicity: Iterable[Literal[0, 1, 2]] = range(3),
weight: None | Callable[[NDArray[f8]], NDArray[f8]] = None,
) -> List[NDArray[f8]]:
"""Perform the loop of :meth:`.init_adf` without a distance cutoff."""
Expand Down Expand Up @@ -128,14 +128,13 @@ def _adf_inner(
def _adf_inner_periodic(
m: NDArray[f8],
lattice: NDArray[f8],
periodicity: Iterable[Literal["x", "y", "z"]] = "xyz",
periodicity: Iterable[Literal[0, 1, 2]] = range(3),
) -> Tuple[NDArray[f8], NDArray[f8]]:
"""Construct the distance matrix and angle-defining vectors for periodic systems."""
vec = m - m[..., None, :]
lat_norm = np.linalg.norm(lattice, axis=-1)

dct = {"x": 0, "y": 1, "z": 2}
iterator = ((dct[i], lat_norm[dct[i]]) for i in periodicity)
iterator = ((i, lat_norm[i]) for i in periodicity)
for i, vec_len in iterator:
vec[..., i][vec[..., i] > (vec_len / 2)] -= vec_len
vec[..., i][vec[..., i] < -(vec_len / 2)] += vec_len
Expand Down
108 changes: 108 additions & 0 deletions FOX/functions/periodic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
"""Utility functions related to calculations on periodic systems.
Index
-----
.. currentmodule:: FOX.functions.periodic
.. autosummary::
parse_periodic
API
---
.. autofunction:: parse_periodic
"""

from __future__ import annotations

from types import MappingProxyType
from typing import TypeVar, Any, Sequence, Mapping, Union, Iterable, TYPE_CHECKING
from itertools import chain

import numpy as np

if TYPE_CHECKING:
import sys
if sys.version_info >= (3, 8):
from typing import Literal, TypedDict
else:
from typing_extensions import Literal, TypedDict

SCT = TypeVar("SCT", bound=np.generic)
NDArray = np.ndarray[Any, np.dtype[SCT]]

class _XYZDict(TypedDict):
x: Literal[0]
y: Literal[1]
z: Literal[2]


def parse_periodic(
xyz: Union[
int, str,
Sequence[str], Sequence[int],
NDArray[np.str_], NDArray[np.integer],
]
) -> NDArray[np.intp]:
"""Parse the passed periodicity specifier and convert it into an array of integers.
Parameters
----------
xyz : :class:`str` or :class:`Sequence[int] <collections.abc.Sequence>`
A string or sequence of integers representing .
Expects either ``"x"``, ``"y"`` and/or ``"z"`` or ``0``, ``1`` and/or ``2``.
Returns
-------
:class:`np.ndarray[np.intp] <numpy.ndarray>`
An array with the indices representing the ``x``, ``y`` and/or ``z`` axes.
"""
ar: NDArray[np.str_ | np.integer] = np.array(xyz, ndmin=1)
if ar.ndim != 1:
raise ValueError(f"Expected a 1D array; observed dimensionality: {ar.ndim}")
elif ar.size == 0:
raise ValueError("Expected a non-empty array")

if ar.dtype.kind == "U":
return _parse_char(ar)
elif ar.dtype.kind in "ui":
return _parse_int(ar)
else:
raise TypeError(f"Invalid dtype: {ar.dtype}")


_XYZ_DICT: _XYZDict = MappingProxyType({ # type: ignore[assignment]
"x": 0,
"y": 1,
"z": 2,
})

_012_DICT: Mapping[int, int] = MappingProxyType({
0: 0,
1: 1,
2: 2,
})


def _parse_char(ar: NDArray[np.str_]) -> NDArray[np.intp]:
"""Helper for :func:`parse_periodic`; parses :class:`numpy.str_`-based arrays."""
ar_low: Iterable[Literal["x", "y", "z"]] = np.fromiter(
chain.from_iterable(np.char.lower(ar)),
dtype="U1",
)
try:
ret: NDArray[np.intp] = np.fromiter({_XYZ_DICT[i] for i in ar_low}, dtype=np.intp)
except KeyError as ex:
raise ValueError(f"Invalid axis specifier: {ex}") from None
ret.sort()
return ret


def _parse_int(ar: NDArray[np.integer]) -> NDArray[np.intp]:
"""Helper for :func:`parse_periodic`; parses :class:`numpy.integer`-based arrays."""
try:
ret: NDArray[np.intp] = np.fromiter({_012_DICT[i] for i in ar}, dtype=np.intp)
except KeyError as ex:
raise ValueError(f"Invalid axis specifier: {ex}") from None
ret.sort()
return ret
47 changes: 47 additions & 0 deletions tests/test_periodic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
"""Tests for :mod:`FOX.functions.periodic`."""

from typing import Any, Type
from itertools import combinations, chain

import pytest
import numpy as np
from assertionlib import assertion

from FOX.functions.periodic import parse_periodic


class TestParsePeriodic:
"""Tests for :func:`FOX.functions.periodic.parse_periodic`."""

@pytest.mark.parametrize(
"xyz",
chain(
["xyz", "zxxyxyyzzyz", "x", "y", "z", 0, 1, 2],
combinations("xyz", r=1),
combinations("xyz", r=2),
combinations("xyz", r=3),
combinations(range(3), r=1),
combinations(range(3), r=2),
combinations(range(3), r=3),
[np.arange(3), np.array(["x", "y", "z"])],
)
)
def test_pass(self, xyz: Any) -> None:
out = parse_periodic(xyz)
assertion.issubset(out, {0, 1, 2})

@pytest.mark.parametrize(
"xyz,exc",
[
(["a", "b"], ValueError),
(np.array([], dtype=np.intp), ValueError),
([-1, 0, 1], ValueError),
(np.arange(9).reshape(3, 3), ValueError),
([True, False], TypeError),
([0.0, 1.0], TypeError),
(b"xyz", TypeError),
]
)
def test_raises(self, xyz: Any, exc: Type[Exception]) -> None:
with pytest.raises(exc):
parse_periodic(xyz)

0 comments on commit f416544

Please sign in to comment.