Skip to content

Commit

Permalink
Merge pull request #222 from nlesc-nano/adf
Browse files Browse the repository at this point in the history
ENH: Add support for periodic ADF calculations with `r_max=inf`
  • Loading branch information
BvB93 committed Mar 1, 2021
2 parents 7d467da + 9831b39 commit d42e490
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 44 deletions.
27 changes: 15 additions & 12 deletions FOX/classes/multi_mol.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@
_warn.__cause__ = ex
warnings.warn(_warn)
del _warn
DASK_EX = Exception()

try:
from ase import Atoms
Expand Down Expand Up @@ -1716,11 +1715,6 @@ def init_adf(

# Periodic calculations
if periodic is not None:
if not r_max:
raise NotImplementedError(
"periodic calculations are not supported for `r_max=inf`"
)

# Validate the parameters
periodic_set = {i.lower() for i in periodic}
lattice = self.lattice
Expand All @@ -1730,11 +1724,13 @@ def init_adf(
elif lattice is None:
raise TypeError("cannot perform periodic calculations if the "
"molecules `lattice` is None")
lattice = lattice[m_subset]

# Scipy's `cKDTree` only supports cuboid lattices
is0 = np.abs(lattice - 0) < 1e-8
if not (np.count_nonzero(is0, axis=-1) == 2).all():
raise NotImplementedError("non-cuboid lattices are not supported")
if r_max_:
is0 = np.abs(lattice - 0) < 1e-8
if not (np.count_nonzero(is0, axis=-1) == 2).all():
raise NotImplementedError("non-cuboid lattices are not supported")

# Set the vector-length of all absent axes to `inf`
xyz_dct = {"x": 0, "y": 1, "z": 2}
Expand All @@ -1746,9 +1742,14 @@ def init_adf(
mol -= mol.min(axis=1)[..., None, :]
if lattice.ndim == 2:
boxsize_iter = repeat(np.linalg.norm(lattice, axis=-1))
lattice_iter = repeat(lattice)
else:
boxsize_iter = (i for i in np.linalg.norm(lattice, axis=-1))
boxsize_iter = iter(np.linalg.norm(lattice, axis=-1))
lattice_iter = iter(lattice)
periodic_iter = repeat(sorted(periodic_set))
else:
lattice_iter = repeat(None)
periodic_iter = repeat("xyz")
boxsize_iter = repeat(None)

# Construct the angular distribution function
Expand All @@ -1760,15 +1761,17 @@ def init_adf(
results = dask.compute(*jobs)
elif DASK_EX is None and not r_max_:
func = dask.delayed(_adf_inner)
jobs = [func(m, atom_pairs.values(), weight) for m in mol]
jobs = [func(m, atom_pairs.values(), l, p, weight) for m, l, p in
zip(mol, lattice_iter, periodic_iter)]
results = dask.compute(*jobs)
elif DASK_EX is not None and r_max_:
func = _adf_inner_cdktree
results = [func(m, n, r_max_, atom_pairs.values(), b, weight) for m, b in
zip(mol, boxsize_iter)]
elif DASK_EX is not None and not r_max_:
func = _adf_inner
results = [func(m, atom_pairs.values(), weight) for m in mol]
results = [func(m, atom_pairs.values(), l, p, weight) for m, l, p in
zip(mol, lattice_iter, periodic_iter)]

df = get_adf_df(atom_pairs)
df.loc[:, :] = np.array(results).mean(axis=0).T
Expand Down
92 changes: 62 additions & 30 deletions FOX/functions/adf.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from __future__ import annotations

import sys
from typing import (
Sequence,
Hashable,
Expand All @@ -34,6 +35,12 @@
from scipy.spatial.distance import cdist

if TYPE_CHECKING:
if sys.version_info >= (3, 8):
from typing import Literal
else:
from typing_extensions import Literal
from numpy import float64 as f8, int64 as i8

_T = TypeVar("_T")
_SCT = TypeVar("_SCT", bound=np.generic)

Expand All @@ -44,33 +51,33 @@


def _adf_inner_cdktree(
m: NDArray[np.float64],
m: NDArray[f8],
n: int,
r_max: float,
idx_list: Iterable[_3Tuple[NDArray[np.integer[Any]]]],
boxsize: None | NDArray[np.float64],
weight: None | Callable[[NDArray[np.float64]], NDArray[np.float64]] = None,
) -> List[NDArray[np.float64]]:
idx_list: Iterable[_3Tuple[NDArray[np.bool_]]],
boxsize: None | NDArray[f8],
weight: None | Callable[[NDArray[f8]], NDArray[f8]] = None,
) -> List[NDArray[f8]]:
"""Perform the loop of :meth:`.init_adf` with a distance cutoff."""
# Construct slices and a distance matrix
tree = cKDTree(m, boxsize=boxsize)
dist, idx = tree.query(m, n, distance_upper_bound=r_max, p=2) # type: NDArray[np.float64], NDArray[np.intp] # noqa: E501
dist, idx = tree.query(m, n, distance_upper_bound=r_max, p=2) # type: NDArray[f8], NDArray[i8] # noqa: E501
dist[dist == np.inf] = 0.0
idx[idx == m.shape[0]] = 0

# Slice the Cartesian coordinates
coords13: NDArray[np.float64] = m[idx]
coords2: NDArray[np.float64] = m[..., None, :]
coords13: NDArray[f8] = m[idx]
coords2: NDArray[f8] = m[..., None, :]

# Construct (3D) angle- and distance-matrices
with np.errstate(divide='ignore', invalid='ignore'):
vec: NDArray[np.float64] = ((coords13 - coords2) / dist[..., None])
ang: NDArray[np.float64] = np.arccos(np.einsum('jkl,jml->jkm', vec, vec))
vec: NDArray[f8] = ((coords13 - coords2) / dist[..., None])
ang: NDArray[f8] = np.arccos(np.einsum('jkl,jml->jkm', vec, vec))
dist = np.maximum(dist[..., None], dist[..., None, :])
ang[np.isnan(ang)] = 0.0

# Radian (float) to degrees (int)
ang_int: NDArray[np.int64] = np.degrees(ang).astype(np.int64)
ang_int: NDArray[i8] = np.degrees(ang).astype(np.int64)

# Construct and return the ADF
ret = []
Expand All @@ -82,27 +89,32 @@ def _adf_inner_cdktree(


def _adf_inner(
m: NDArray[np.float64],
idx_list: Iterable[_3Tuple[NDArray[np.integer[Any]]]],
weight: None | Callable[[NDArray[np.float64]], NDArray[np.float64]] = None,
) -> List[NDArray[np.float64]]:
m: NDArray[f8],
idx_list: Iterable[_3Tuple[NDArray[np.bool_]]],
lattice: None | NDArray[f8],
periodicity: Iterable[Literal["x", "y", "z"]] = "xyz",
weight: None | Callable[[NDArray[f8]], NDArray[f8]] = None,
) -> List[NDArray[f8]]:
"""Perform the loop of :meth:`.init_adf` without a distance cutoff."""
# Construct a distance matrix
dist: NDArray[np.float64] = cdist(m, m)

# Slice the Cartesian coordinates
coords13: NDArray[np.float64] = m
coords2: NDArray[np.float64] = m[..., None, :]

# Construct (3D) angle- and distance-matrices
with np.errstate(divide='ignore', invalid='ignore'):
vec: NDArray[np.float64] = ((coords13 - coords2) / dist[..., None])
ang: NDArray[np.float64] = np.arccos(np.einsum('jkl,jml->jkm', vec, vec))
dist = np.maximum(dist[..., None], dist[..., None, :])
if lattice is None:
# Construct a distance matrix
dist: NDArray[f8] = cdist(m, m)

# Slice the Cartesian coordinates
coords13: NDArray[f8] = m
coords2: NDArray[f8] = m[..., None, :]

vec: NDArray[f8] = (coords13 - coords2) / dist[..., None]
else:
dist, vec = _adf_inner_periodic(m, lattice, periodicity)
ang: NDArray[f8] = np.arccos(np.einsum('jkl,jml->jkm', vec, vec))
dist = np.maximum(dist[..., :, None], dist[..., None, :])
ang[np.isnan(ang)] = 0.0

# Radian (float) to degrees (int)
ang_int: NDArray[np.int64] = np.degrees(ang).astype(np.int64)
ang_int: NDArray[i8] = np.degrees(ang).astype(np.int64)

# Construct and return the ADF
ret = []
Expand All @@ -113,6 +125,26 @@ def _adf_inner(
return ret


def _adf_inner_periodic(
    m: NDArray[f8],
    lattice: NDArray[f8],
    periodicity: Iterable[Literal["x", "y", "z"]] = "xyz",
) -> Tuple[NDArray[f8], NDArray[f8]]:
    """Construct the distance matrix and angle-defining vectors for periodic systems.

    Parameters
    ----------
    m : np.ndarray[np.float64]
        The Cartesian coordinates; the last axis holds the x, y and z components.
        Assumed to be of shape :math:`(n, 3)` -- TODO confirm against the caller.
    lattice : np.ndarray[np.float64]
        The lattice vectors.
        The norm of ``lattice[i]`` is used as the period along Cartesian axis ``i``.
    periodicity : Iterable[str]
        The axes (``"x"``, ``"y"`` and/or ``"z"``) along which the
        displacement vectors are wrapped.

    Returns
    -------
    np.ndarray[np.float64] & np.ndarray[np.float64]
        The pair-wise distance matrix and the matching normalized displacement vectors,
        ``vec[j, k]`` pointing from atom ``j`` to atom ``k``.
        Zero-length displacements (*e.g.* the diagonal) are normalized to ``nan``.

    Raises
    ------
    KeyError
        Raised if **periodicity** contains anything other than ``"x"``, ``"y"`` or ``"z"``.

    Notes
    -----
    Each displacement component is shifted by at most a single period, so components
    larger than 1.5 periods are not fully wrapped back into the central image.

    NOTE(review): using the lattice-vector length as the period along a Cartesian axis
    presumably relies on lattice vector ``i`` being parallel to axis ``i`` -- confirm
    that callers only pass such lattices.

    """
    # Coerce to float so the in-place subtraction/division below cannot fail on integer input
    m = np.asarray(m, dtype=np.float64)
    vec = m - m[..., None, :]
    lat_norm = np.linalg.norm(lattice, axis=-1)

    axis_dct = {"x": 0, "y": 1, "z": 2}
    for axis_name in periodicity:
        i = axis_dct[axis_name]
        vec_len = lat_norm[i]
        half_len = vec_len / 2

        # `component` is a view into `vec`: the masked in-place updates modify `vec` itself
        component = vec[..., i]
        component[component > half_len] -= vec_len
        component[component < -half_len] += vec_len

    dist = np.linalg.norm(vec, axis=-1)

    # The distance of an atom to itself is zero; silence the resulting 0/0 warnings
    # locally (the `nan` values are kept) rather than relying on the caller to do so
    with np.errstate(divide='ignore', invalid='ignore'):
        vec /= dist[..., None]
    return dist, vec


def get_adf_df(atom_pairs: Sequence[Hashable]) -> pd.DataFrame:
"""Construct and return a pandas dataframe filled to hold angular distribution functions.
Expand All @@ -137,7 +169,7 @@ def get_adf_df(atom_pairs: Sequence[Hashable]) -> pd.DataFrame:
def get_adf(
ang: NDArray[np.integer[Any]],
weights: None | NDArray[np.number[Any]] = None,
) -> NDArray[np.float64]:
) -> NDArray[f8]:
r"""Calculate and return the angular distribution function (ADF).
Parameters
Expand All @@ -159,14 +191,14 @@ def get_adf(
"""
# Calculate and normalize the density
denominator = len(ang) / 180
at_count: NDArray[np.int64] = np.bincount(ang, minlength=181)[1:181]
dens: NDArray[np.float64] = at_count / denominator
at_count: NDArray[i8] = np.bincount(ang, minlength=181)[1:181]
dens: NDArray[f8] = at_count / denominator

if weights is None:
return dens

# Weight (and re-normalize) the density based on the distance matrix **dist**
area: np.float64 = dens.sum()
area: f8 = dens.sum()
with np.errstate(divide='ignore', invalid='ignore'):
dens *= np.bincount(ang, weights=weights, minlength=181)[1:181] / at_count
dens *= area / np.nansum(dens)
Expand Down
Binary file added tests/test_files/adf_2d_inf.npy
Binary file not shown.
Binary file added tests/test_files/adf_3d_inf.npy
Binary file not shown.
Binary file added tests/test_files/adf_periodic_2d_inf.npy
Binary file not shown.
Binary file added tests/test_files/adf_periodic_3d_inf.npy
Binary file not shown.
10 changes: 8 additions & 2 deletions tests/test_multi_mol.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,8 +288,15 @@ class TestADF:
("adf", MOL, {"atom_subset": ("Cd", "Se"), "weight": None}),
("adf_periodic_2d", MOL_LATTICE_2D_ORTH, {"atom_subset": ("Pb",), "periodic": "xyz"}),
("adf_periodic_3d", MOL_LATTICE_3D_ORTH, {"atom_subset": ("Pb",), "periodic": "xy"}),
("adf_2d_inf", MOL_LATTICE_2D, {"atom_subset": ("Pb",), "r_max": np.inf}),
("adf_3d_inf", MOL_LATTICE_3D, {"atom_subset": ("Pb",), "r_max": np.inf}),
("adf_periodic_2d_inf", MOL_LATTICE_2D,
{"atom_subset": ("Pb",), "periodic": "xyz", "r_max": np.inf}),
("adf_periodic_3d_inf", MOL_LATTICE_3D,
{"atom_subset": ("Pb",), "periodic": "xy", "r_max": np.inf}),
],
ids=["adf_weighted", "adf", "adf_periodic_2d", "adf_periodic_3d"],
ids=["adf_weighted", "adf", "adf_periodic_2d", "adf_periodic_3d",
"adf_2d_inf", "adf_3d_inf", "adf_periodic_2d_inf", "adf_periodic_3d_inf"],
)
def test_passes(self, name: str, mol: MultiMolecule, kwargs: Mapping[str, Any]) -> None:
adf = mol.init_adf(**kwargs)
Expand All @@ -300,7 +307,6 @@ def test_passes(self, name: str, mol: MultiMolecule, kwargs: Mapping[str, Any])
"kwargs,exc",
[
({"periodic": "bob"}, ValueError),
({"periodic": "xyz", "r_max": np.inf}, NotImplementedError),
({"periodic": "xyz"}, NotImplementedError),
]
)
Expand Down

0 comments on commit d42e490

Please sign in to comment.