Skip to content

Commit

Permalink
Merge pull request #2490 from lan496/speedup-group-structures
Browse files Browse the repository at this point in the history
Speed up to group structures by equivalence
  • Loading branch information
shyuep committed Apr 12, 2022
2 parents a82f60d + e65f680 commit 834dc9c
Showing 1 changed file with 54 additions and 28 deletions.
82 changes: 54 additions & 28 deletions pymatgen/analysis/structure_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
"""
This module provides classes to perform fitting of structures.
"""

import abc
import itertools

Expand Down Expand Up @@ -579,7 +578,9 @@ def _get_mask(self, struct1, struct2, fu, s1_supercell):
inds = inds[::fu]
return np.array(mask, dtype=int), inds, i

def fit(self, struct1, struct2, symmetric=False):
def fit(
self, struct1: Structure, struct2: Structure, symmetric: bool = False, skip_structure_reduction: bool = False
) -> bool:
"""
Fit two structures.
Expand All @@ -589,6 +590,8 @@ def fit(self, struct1, struct2, symmetric=False):
symmetric (Bool): Defaults to False
If True, check the equality both ways.
This only impacts a small percentage of structures
skip_structure_reduction (Bool): Defaults to False
If True, skip the primitive-cell transformation and Niggli reduction of struct1 and struct2
Returns:
True or False.
Expand All @@ -598,20 +601,26 @@ def fit(self, struct1, struct2, symmetric=False):
if not self._subset and self._comparator.get_hash(struct1.composition) != self._comparator.get_hash(
struct2.composition
):
return None
return False

if not symmetric:
struct1, struct2, fu, s1_supercell = self._preprocess(struct1, struct2)
struct1, struct2, fu, s1_supercell = self._preprocess(
struct1, struct2, skip_structure_reduction=skip_structure_reduction
)
match = self._match(struct1, struct2, fu, s1_supercell, break_on_match=True)
if match is None:
return False

return match[0] <= self.stol

struct1, struct2, fu, s1_supercell = self._preprocess(struct1, struct2)
struct1, struct2, fu, s1_supercell = self._preprocess(
struct1, struct2, skip_structure_reduction=skip_structure_reduction
)
match1 = self._match(struct1, struct2, fu, s1_supercell, break_on_match=True)
struct1, struct2 = struct2, struct1
struct1, struct2, fu, s1_supercell = self._preprocess(struct1, struct2)
struct1, struct2, fu, s1_supercell = self._preprocess(
struct1, struct2, skip_structure_reduction=skip_structure_reduction
)
match2 = self._match(struct1, struct2, fu, s1_supercell, break_on_match=True)

if match1 is None or match2 is None:
Expand Down Expand Up @@ -652,23 +661,21 @@ def _process_species(self, structures):
copied_structures.append(ss)
return copied_structures

def _preprocess(self, struct1, struct2, niggli=True):
def _preprocess(self, struct1, struct2, niggli=True, skip_structure_reduction: bool = False):
"""
Rescales, finds the reduced structures (primitive and niggli),
and finds fu, the supercell size to make struct1 comparable to
s2
"""
struct1 = struct1.copy()
struct2 = struct2.copy()

if niggli:
struct1 = struct1.get_reduced_structure(reduction_algo="niggli")
struct2 = struct2.get_reduced_structure(reduction_algo="niggli")

# primitive cell transformation
if self._primitive_cell:
struct1 = struct1.get_primitive_structure()
struct2 = struct2.get_primitive_structure()
s2.
If skip_structure_reduction is True, the reduction step (primitive transformation and
Niggli reduction) is skipped. This option is useful when fitting the same set of structures several times.
"""
if skip_structure_reduction:
# Need to copy original structures to rescale lattices later
struct1 = struct1.copy()
struct2 = struct2.copy()
else:
struct1 = self._get_reduced_structure(struct1, self._primitive_cell, niggli)
struct2 = self._get_reduced_structure(struct2, self._primitive_cell, niggli)

if self._supercell:
fu, s1_supercell = self._get_supercell_size(struct1, struct2)
Expand Down Expand Up @@ -805,6 +812,8 @@ def group_structures(self, s_list, anonymous=False):

original_s_list = list(s_list)
s_list = self._process_species(s_list)
# Prepare reduced structures beforehand
s_list = [self._get_reduced_structure(s, self._primitive_cell, niggli=True) for s in s_list]

# Use structure hash to pre-group structures
if anonymous:
Expand All @@ -822,19 +831,19 @@ def s_hash(s):
all_groups = []

# For each pre-grouped list of structures, perform actual matching.
for k, g in itertools.groupby(sorted_s_list, key=s_hash):
for _, g in itertools.groupby(sorted_s_list, key=s_hash):
unmatched = list(g)
while len(unmatched) > 0:
i, refs = unmatched.pop(0)
matches = [i]
if anonymous:
inds = filter(
lambda i: self.fit_anonymous(refs, unmatched[i][1]),
lambda i: self.fit_anonymous(refs, unmatched[i][1], skip_structure_reduction=True),
list(range(len(unmatched))),
)
else:
inds = filter(
lambda i: self.fit(refs, unmatched[i][1]),
lambda i: self.fit(refs, unmatched[i][1], skip_structure_reduction=True),
list(range(len(unmatched))),
)
inds = list(inds)
Expand Down Expand Up @@ -942,6 +951,18 @@ def _anonymous_match(
break
return matches

@classmethod
def _get_reduced_structure(cls, struct: Structure, primitive_cell: bool = True, niggli: bool = True) -> Structure:
    """
    Helper method that builds a reduced structure starting from ``struct.copy()``.

    Args:
        struct (Structure): structure to reduce
        primitive_cell (bool): If True, call ``get_primitive_structure()`` on the
            result (applied after the Niggli step, if any)
        niggli (bool): If True, first call
            ``get_reduced_structure(reduction_algo="niggli")``

    Returns:
        The reduced structure; a plain copy of ``struct`` when both flags are False.
    """
    result = struct.copy()
    if niggli:
        result = result.get_reduced_structure(reduction_algo="niggli")
    # The primitive-cell transformation always comes last, matching the original order.
    return result.get_primitive_structure() if primitive_cell else result

def get_rms_anonymous(self, struct1, struct2):
"""
Performs an anonymous fitting, which allows distinct species in one
Expand Down Expand Up @@ -1029,7 +1050,9 @@ def get_all_anonymous_mappings(self, struct1, struct2, niggli=True, include_dist

return None

def fit_anonymous(self, struct1, struct2, niggli=True):
def fit_anonymous(
self, struct1: Structure, struct2: Structure, niggli: bool = True, skip_structure_reduction: bool = False
):
"""
Performs an anonymous fitting, which allows distinct species in one
structure to map to another. E.g., to compare if the Li2O and Na2O
Expand All @@ -1038,12 +1061,15 @@ def fit_anonymous(self, struct1, struct2, niggli=True):
Args:
struct1 (Structure): 1st structure
struct2 (Structure): 2nd structure
niggli (Bool): If true, perform Niggli reduction for struct1 and struct2
skip_structure_reduction (Bool): Defaults to False
If True, skip the primitive-cell transformation and Niggli reduction of struct1 and struct2
Returns:
True/False: Whether a species mapping can map struct1 to struct2
"""
struct1, struct2 = self._process_species([struct1, struct2])
struct1, struct2, fu, s1_supercell = self._preprocess(struct1, struct2, niggli)
struct1, struct2, fu, s1_supercell = self._preprocess(struct1, struct2, niggli, skip_structure_reduction)

matches = self._anonymous_match(struct1, struct2, fu, s1_supercell, break_on_match=True, single_match=True)

Expand All @@ -1057,7 +1083,7 @@ def get_supercell_matrix(self, supercell, struct):
"""
if self._primitive_cell:
raise ValueError("get_supercell_matrix cannot be used with the primitive cell option")
struct, supercell, fu, s1_supercell = self._preprocess(struct, supercell, False)
struct, supercell, fu, s1_supercell = self._preprocess(struct, supercell, niggli=False)

if not s1_supercell:
raise ValueError("The non-supercell must be put onto the basis of the supercell, not the other way around")
Expand Down Expand Up @@ -1092,7 +1118,7 @@ def get_transformation(self, struct1, struct2):

struct1, struct2 = self._process_species((struct1, struct2))

s1, s2, fu, s1_supercell = self._preprocess(struct1, struct2, False)
s1, s2, fu, s1_supercell = self._preprocess(struct1, struct2, niggli=False)
ratio = fu if s1_supercell else 1 / fu
if s1_supercell and fu > 1:
raise ValueError("Struct1 must be the supercell, not the other way around")
Expand Down Expand Up @@ -1180,7 +1206,7 @@ def get_mapping(self, superset, subset):
if len(subset) > len(superset):
raise ValueError("subset is larger than superset")

superset, subset, _, _ = self._preprocess(superset, subset, True)
superset, subset, _, _ = self._preprocess(superset, subset, niggli=True)
match = self._strict_match(superset, subset, 1, break_on_match=False)

if match is None or match[0] > self.stol:
Expand Down

0 comments on commit 834dc9c

Please sign in to comment.