Skip to content

Commit

Permalink
Speeding up get_nn_info in local_env.py (#3635)
Browse files Browse the repository at this point in the history
* Checks if image and index attributes exist before recomputing

`_get_image` and `_get_original_site` do not need to recompute site and index when `site` is a `PeriodicNeighbor` which has them as attributes.

* pre-commit auto-fixes

* Removed np.mod of fractional coordinates before call to get_points_in_sphere

If the modulo is present the image returned by get_points_in_sphere does not point to the original site.

* change "Site not found" exception to ValueError

* Fixed site variable used twice

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Janosh Riebesell <janosh.riebesell@gmail.com>
  • Loading branch information
3 people committed Feb 20, 2024
1 parent bd88e0d commit 38b9b58
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 21 deletions.
30 changes: 18 additions & 12 deletions pymatgen/analysis/local_env.py
Expand Up @@ -543,7 +543,7 @@ def _get_nn_shell_info(
return list(all_sites.values())

@staticmethod
def _get_image(structure, site):
def _get_image(structure: Structure, site: Site) -> tuple[int, int, int]:
"""Private convenience method for get_nn_info,
gives lattice image from provided PeriodicSite and Structure.
Expand All @@ -552,30 +552,36 @@ def _get_image(structure, site):
Note that this method takes O(number of sites) due to searching an original site.
Args:
structure: Structure Object
site: PeriodicSite Object
structure (Structure): Structure Object
site (Site): PeriodicSite Object
Returns:
image: ((int)*3) Lattice image
tuple[int, int , int] Lattice image
"""
if isinstance(site, PeriodicNeighbor):
return site.image

original_site = structure[NearNeighbors._get_original_site(structure, site)]
image = np.around(np.subtract(site.frac_coords, original_site.frac_coords))
return tuple(image.astype(int))

@staticmethod
def _get_original_site(structure, site):
def _get_original_site(structure: Structure, site: Site) -> int:
"""Private convenience method for get_nn_info,
gives original site index from ProvidedPeriodicSite.
"""
if isinstance(site, PeriodicNeighbor):
return site.index

if isinstance(structure, (IStructure, Structure)):
for i, s in enumerate(structure):
if site.is_periodic_image(s):
return i
for idx, struc_site in enumerate(structure):
if site.is_periodic_image(struc_site):
return idx
else:
for i, s in enumerate(structure):
if site == s:
return i
raise Exception("Site not found!")
for idx, struc_site in enumerate(structure):
if site == struc_site:
return idx
raise ValueError("Site not found in structure")

def get_bonded_structure(
self,
Expand Down
17 changes: 8 additions & 9 deletions pymatgen/core/structure.py
Expand Up @@ -1492,27 +1492,26 @@ def get_sites_in_sphere(
Args:
pt (3x1 array): Cartesian coordinates of center of sphere.
r (float): Radius of sphere.
r (float): Radius of sphere in Angstrom.
include_index (bool): Whether the non-supercell site index
is included in the returned data
is included in the returned data.
include_image (bool): Whether to include the supercell image
is included in the returned data
is included in the returned data.
Returns:
PeriodicNeighbor
"""
site_fcoords = np.mod(self.frac_coords, 1)
neighbors: list[PeriodicNeighbor] = []
for frac_coord, dist, i, img in self._lattice.get_points_in_sphere(site_fcoords, pt, r):
for frac_coord, dist, idx, img in self._lattice.get_points_in_sphere(self.frac_coords, pt, r):
nn_site = PeriodicNeighbor(
self[i].species,
self[idx].species,
frac_coord,
self._lattice,
properties=self[i].properties,
properties=self[idx].properties,
nn_distance=dist,
image=img, # type: ignore
index=i,
label=self[i].label,
index=idx,
label=self[idx].label,
)
neighbors.append(nn_site)
return neighbors
Expand Down

0 comments on commit 38b9b58

Please sign in to comment.