diff --git a/matminer/featurizers/structure/sites.py b/matminer/featurizers/structure/sites.py
index 57f9f04fd..b993eead4 100644
--- a/matminer/featurizers/structure/sites.py
+++ b/matminer/featurizers/structure/sites.py
@@ -2,8 +2,10 @@
 Structure featurizers based on aggregating site features.
 """
 
+import itertools
 import numpy as np
 from pymatgen.analysis.local_env import VoronoiNN
+from pymatgen.core.periodic_table import Element, Specie
 
 from matminer.featurizers.base import BaseFeaturizer
 from matminer.featurizers.site import (
@@ -147,8 +149,8 @@ def citations(self):
     def implementors(self):
         return ["Nils E. R. Zimmermann", "Alireza Faghaninia", "Anubhav Jain", "Logan Ward", "Alex Dunn"]
 
-    @staticmethod
-    def from_preset(preset, **kwargs):
+    @classmethod
+    def from_preset(cls, preset, **kwargs):
         """
         Create a SiteStatsFingerprint class according to a preset
 
@@ -158,37 +160,37 @@ def from_preset(preset, **kwargs):
         """
 
         if preset == "SOAP_formation_energy":
-            return SiteStatsFingerprint(SOAP.from_preset("formation_energy"), **kwargs)
+            return cls(SOAP.from_preset("formation_energy"), **kwargs)
 
         elif preset == "CrystalNNFingerprint_cn":
-            return SiteStatsFingerprint(CrystalNNFingerprint.from_preset("cn", cation_anion=False), **kwargs)
+            return cls(CrystalNNFingerprint.from_preset("cn", cation_anion=False), **kwargs)
 
         elif preset == "CrystalNNFingerprint_cn_cation_anion":
-            return SiteStatsFingerprint(CrystalNNFingerprint.from_preset("cn", cation_anion=True), **kwargs)
+            return cls(CrystalNNFingerprint.from_preset("cn", cation_anion=True), **kwargs)
 
         elif preset == "CrystalNNFingerprint_ops":
-            return SiteStatsFingerprint(CrystalNNFingerprint.from_preset("ops", cation_anion=False), **kwargs)
+            return cls(CrystalNNFingerprint.from_preset("ops", cation_anion=False), **kwargs)
 
         elif preset == "CrystalNNFingerprint_ops_cation_anion":
-            return SiteStatsFingerprint(CrystalNNFingerprint.from_preset("ops", cation_anion=True), **kwargs)
+            return cls(CrystalNNFingerprint.from_preset("ops", cation_anion=True), **kwargs)
 
         elif preset == "OPSiteFingerprint":
-            return SiteStatsFingerprint(OPSiteFingerprint(), **kwargs)
+            return cls(OPSiteFingerprint(), **kwargs)
 
         elif preset == "LocalPropertyDifference_ward-prb-2017":
-            return SiteStatsFingerprint(
+            return cls(
                 LocalPropertyDifference.from_preset("ward-prb-2017"),
                 stats=["minimum", "maximum", "range", "mean", "avg_dev"],
             )
 
         elif preset == "CoordinationNumber_ward-prb-2017":
-            return SiteStatsFingerprint(
+            return cls(
                 CoordinationNumber(nn=VoronoiNN(weight="area"), use_weights="effective"),
                 stats=["minimum", "maximum", "range", "mean", "avg_dev"],
             )
 
         elif preset == "Composition-dejong2016_AD":
-            return SiteStatsFingerprint(
+            return cls(
                 LocalPropertyDifference(
                     properties=[
                         "Number",
@@ -204,7 +206,7 @@ def from_preset(preset, **kwargs):
             )
 
         elif preset == "Composition-dejong2016_SD":
-            return SiteStatsFingerprint(
+            return cls(
                 LocalPropertyDifference(
                     properties=[
                         "Number",
@@ -220,13 +222,13 @@ def from_preset(preset, **kwargs):
             )
 
         elif preset == "BondLength-dejong2016":
-            return SiteStatsFingerprint(
+            return cls(
                 AverageBondLength(VoronoiNN()),
                 stats=["holder_mean::%d" % d for d in range(-4, 4 + 1)] + ["std_dev", "geom_std_dev"],
             )
 
         elif preset == "BondAngle-dejong2016":
-            return SiteStatsFingerprint(
+            return cls(
                 AverageBondAngle(VoronoiNN()),
                 stats=["holder_mean::%d" % d for d in range(-4, 4 + 1)] + ["std_dev", "geom_std_dev"],
             )
@@ -236,8 +238,189 @@ def from_preset(preset, **kwargs):
             # One of the various Coordination Number presets:
             # MinimumVIRENN, MinimumDistanceNN, JmolNN, VoronoiNN, etc.
             try:
-                return SiteStatsFingerprint(CoordinationNumber.from_preset(preset), **kwargs)
+                return cls(CoordinationNumber.from_preset(preset), **kwargs)
             except Exception:
                 pass
 
         raise ValueError("Unrecognized preset!")
+
+
+class PartialsSiteStatsFingerprint(SiteStatsFingerprint):
+    """
+    Computes statistics of properties across all sites in a structure, and
+    breaks these down by element. This featurizer first uses a site featurizer
+    class (see site.py for options) to compute features of each site of a
+    specific element in a structure, and then computes features of the entire
+    structure by measuring statistics of each attribute.
+
+    The featurizer must be fit before use: `fit` determines the elements for
+    which per-element statistics are reported.
+
+    Features:
+        - Returns each statistic of each site feature, broken down by element
+    """
+
+    def __init__(
+        self,
+        site_featurizer,
+        stats=("mean", "std_dev"),
+        min_oxi=None,
+        max_oxi=None,
+        covariance=False,
+        include_elems=(),
+        exclude_elems=(),
+    ):
+        """
+        Args:
+            site_featurizer (BaseFeaturizer): a site-based featurizer
+            stats ([str]): list of weighted statistics to compute for each feature.
+                If stats is None, a list is returned for each feature
+                that contains the calculated feature for each site in the
+                structure.
+                *Note for nth mode, stat must be 'n*_mode'; e.g. stat='2nd_mode'
+            min_oxi (int): minimum site oxidation state for inclusion (e.g.,
+                zero means metals/cations only)
+            max_oxi (int): maximum site oxidation state for inclusion
+            covariance (bool): Whether to compute the covariance of site features
+            include_elems ([str]): symbols of elements that are always
+                reported, even if absent from the structures passed to `fit`
+            exclude_elems ([str]): symbols of elements found during `fit`
+                that should not be reported
+        """
+
+        self.include_elems = list(include_elems)
+        self.exclude_elems = list(exclude_elems)
+        super().__init__(site_featurizer, stats, min_oxi, max_oxi, covariance)
+
+    def fit(self, X, y=None):
+        """Define the list of elements to be included in the PSSF. By default,
+        the PSSF will include all of the elements in `X`
+
+        Args:
+            X: (numpy array nx1) structures used in the training set. Each entry
+                must be Pymatgen Structure objects.
+            y: *Not used*
+
+        Returns:
+            self, so that calls can be chained (e.g. `fit(X).featurize(s)`)
+        """
+
+        # This method largely copies code from the partial-RDF fingerprint
+
+        # Initialize list with included elements
+        elements = [Element(e) for e in self.include_elems]
+
+        # Get all of elements that appear
+        for structure in X:
+            for element in structure.composition.elements:
+                if isinstance(element, Specie):
+                    element = element.element  # converts from Specie to Element object
+                if element not in elements and element.name not in self.exclude_elems:
+                    elements.append(element)
+
+        # Store the elements
+        self.elements_ = [e.symbol for e in sorted(elements)]
+
+        return self
+
+    def featurize(self, s):
+        """
+        Get PSSF of the input structure.
+
+        Args:
+            s: Pymatgen Structure object.
+
+        Returns:
+            pssf: array of each element's ssf, stacked in the order of the
+                elements found by `fit`
+
+        Raises:
+            ValueError: if the structure is disordered
+            Exception: if `fit` has not been run yet
+        """
+
+        if not s.is_ordered:
+            raise ValueError("Disordered structure support not built yet")
+        if not hasattr(self, "elements_") or self.elements_ is None:
+            raise Exception("You must run 'fit' first!")
+
+        output = [self.compute_pssf(s, e) for e in self.elements_]
+
+        return np.hstack(output)
+
+    def compute_pssf(self, s, e):
+        """
+        Get the site statistics for the sites of a single element.
+
+        Args:
+            s: Pymatgen Structure object.
+            e (str): symbol of the element whose sites are analyzed
+
+        Returns:
+            list of the requested statistics of each site feature, followed
+            by the covariances if requested. If `stats` is None, the site
+            features themselves (one list per feature) are returned instead.
+        """
+
+        # This code is extremely similar to super().featurize(). The key
+        # difference is that only one specific element is analyzed.
+
+        # NOTE(review): min_oxi / max_oxi are accepted by __init__ but are not
+        # applied to the sites here - confirm whether that is intended
+
+        # Get each feature for each site
+        vals = [[] for _ in self._site_labels]
+        for i, site in enumerate(s.sites):
+            if site.specie.symbol == e:
+                opvalstmp = self.site_featurizer.featurize(s, i)
+                for j, opval in enumerate(opvalstmp):
+                    if opval is None:
+                        vals[j].append(0.0)
+                    else:
+                        vals[j].append(opval)
+
+        # If the user does not request statistics, return the site features now
+        if self.stats is None:
+            return vals
+
+        # Compute the requested statistics
+        stats = []
+        for op in vals:
+            for stat in self.stats:
+                stats.append(PropertyStats().calc_stat(op, stat))
+
+        # If desired, compute covariances
+        if self.covariance:
+            # The covariance is undefined (NaN) with fewer than two
+            # observations, so count the sites of this element rather than
+            # all of the sites in the structure
+            n_sites = len(vals[0]) if vals else 0
+            if n_sites < 2:
+                stats.extend([0] * (len(vals) * (len(vals) - 1) // 2))
+            else:
+                covar = np.cov(vals)
+                tri_ind = np.triu_indices(len(vals), 1)
+                stats.extend(covar[tri_ind].tolist())
+
+        return stats
+
+    def feature_labels(self):
+        """
+        Get the feature labels, prefixed by the element they describe.
+
+        Returns:
+            ([str]) labels ordered element by element, matching `featurize`
+
+        Raises:
+            Exception: if `fit` has not been run yet
+        """
+        if not hasattr(self, "elements_") or self.elements_ is None:
+            raise Exception("You must run 'fit' first!")
+
+        # The parent labels are the same for every element, so get them once
+        site_stat_labels = super().feature_labels()
+
+        return [f"{e} {label}" for e in self.elements_ for label in site_stat_labels]
+
+    def implementors(self):
+        return ["Jack Sundberg"]
diff --git a/matminer/featurizers/structure/tests/test_sites.py b/matminer/featurizers/structure/tests/test_sites.py
index f34ff7341..8124d313b 100644
--- a/matminer/featurizers/structure/tests/test_sites.py
+++ b/matminer/featurizers/structure/tests/test_sites.py
@@ -3,7 +3,10 @@
 import numpy as np
 
 from matminer.featurizers.site import SiteElementalProperty
-from matminer.featurizers.structure.sites import SiteStatsFingerprint
+from matminer.featurizers.structure.sites import (
+    SiteStatsFingerprint,
+    PartialsSiteStatsFingerprint,
+)
 from matminer.featurizers.structure.tests.base import StructureFeaturesTest
 
 
@@ -111,5 +114,115 @@ def test_ward_prb_2017_efftcn(self):
         self.assertArrayAlmostEqual([12, 12, 0, 12, 0], features)
 
 
+class PartialStructureSitesFeaturesTest(StructureFeaturesTest):
+    def test_partialsitestatsfingerprint(self):
+        # Test matrix.
+        op_struct_fp = PartialsSiteStatsFingerprint.from_preset("OPSiteFingerprint", stats=None)
+
+        op_struct_fp.fit([self.diamond])
+        opvals = op_struct_fp.featurize(self.diamond)
+        _ = op_struct_fp.feature_labels()
+        self.assertAlmostEqual(opvals[10][0], 0.9995, places=7)
+        self.assertAlmostEqual(opvals[10][1], 0.9995, places=7)
+
+        op_struct_fp.fit([self.nacl])
+        opvals = op_struct_fp.featurize(self.nacl)
+        self.assertAlmostEqual(opvals[18][0], 0.9995, places=7)
+        self.assertAlmostEqual(opvals[18][1], 0.9995, places=7)
+
+        op_struct_fp.fit([self.cscl])
+        opvals = op_struct_fp.featurize(self.cscl)
+        self.assertAlmostEqual(opvals[22][0], 0.9995, places=7)
+        self.assertAlmostEqual(opvals[22][1], 0.9995, places=7)
+
+        # Test stats.
+        op_struct_fp = PartialsSiteStatsFingerprint.from_preset("OPSiteFingerprint")
+        op_struct_fp.fit([self.diamond])
+        opvals = op_struct_fp.featurize(self.diamond)
+        self.assertAlmostEqual(opvals[0], 0.0005, places=7)
+        self.assertAlmostEqual(opvals[1], 0, places=7)
+        self.assertAlmostEqual(opvals[2], 0.0005, places=7)
+        self.assertAlmostEqual(opvals[3], 0.0, places=7)
+        self.assertAlmostEqual(opvals[4], 0.0005, places=7)
+        self.assertAlmostEqual(opvals[18], 0.0805, places=7)
+        self.assertAlmostEqual(opvals[20], 0.9995, places=7)
+        self.assertAlmostEqual(opvals[21], 0, places=7)
+        self.assertAlmostEqual(opvals[22], 0.0075, places=7)
+        self.assertAlmostEqual(opvals[24], 0.2355, places=7)
+        self.assertAlmostEqual(opvals[-1], 0.0, places=7)
+
+        # Test coordination number (fit must return the featurizer itself)
+        cn_fp = PartialsSiteStatsFingerprint.from_preset("JmolNN", stats=("mean",))
+        self.assertIs(cn_fp, cn_fp.fit([self.diamond]))
+        cn_vals = cn_fp.featurize(self.diamond)
+        self.assertEqual(cn_vals[0], 4.0)
+
+        # Test the covariance
+        prop_fp = PartialsSiteStatsFingerprint(
+            SiteElementalProperty(properties=["Number", "AtomicWeight"]),
+            stats=["mean"],
+            covariance=True,
+        )
+
+        # Test the feature labels
+        prop_fp.fit([self.diamond])
+        labels = prop_fp.feature_labels()
+        self.assertEqual(3, len(labels))
+
+        # Test a structure with all the same type (cov should be zero)
+        prop_fp.fit([self.diamond])
+        features = prop_fp.featurize(self.diamond)
+        self.assertArrayAlmostEqual(features, [6, 12.0107, 0])
+
+        # Test a structure with only one atom (cov should be zero too)
+        prop_fp.fit([self.sc])
+        features = prop_fp.featurize(self.sc)
+        self.assertArrayAlmostEqual([13, 26.9815386, 0], features)
+
+        # Test a structure with one site per element (cov should be zero too)
+        prop_fp.fit([self.nacl])
+        features = prop_fp.featurize(self.nacl)
+        self.assertArrayAlmostEqual([11, 22.9897693, 0, 17, 35.453, 0], features)
+
+    def test_ward_prb_2017_lpd(self):
+        """Test the local property difference attributes from Ward 2017"""
+        f = PartialsSiteStatsFingerprint.from_preset("LocalPropertyDifference_ward-prb-2017")
+
+        # Test diamond
+        f.fit([self.diamond])
+        features = f.featurize(self.diamond)
+        self.assertArrayAlmostEqual(features, [0] * (22 * 5))
+        features = f.featurize(self.diamond_no_oxi)
+        self.assertArrayAlmostEqual(features, [0] * (22 * 5))
+
+        # Test CsCl
+        f.fit([self.cscl])
+        big_face_area = np.sqrt(3) * 3 / 2 * (2 / 4 / 4)
+        small_face_area = 0.125
+        big_face_diff = 55 - 17
+        features = f.featurize(self.cscl)
+        labels = f.feature_labels()
+        my_label = "Cs mean local difference in Number"
+        self.assertAlmostEqual(
+            (8 * big_face_area * big_face_diff) / (8 * big_face_area + 6 * small_face_area),
+            features[labels.index(my_label)],
+            places=3,
+        )
+        my_label = "Cs range local difference in Electronegativity"
+        self.assertAlmostEqual(0, features[labels.index(my_label)], places=3)
+
+    def test_ward_prb_2017_efftcn(self):
+        """Test the effective coordination number attributes of Ward 2017"""
+        f = PartialsSiteStatsFingerprint.from_preset("CoordinationNumber_ward-prb-2017")
+
+        # Test Ni3Al
+        f.fit([self.ni3al])
+        features = f.featurize(self.ni3al)
+        labels = f.feature_labels()
+        self.assertAlmostEqual(12, features[labels.index("Al mean CN_VoronoiNN")])
+        self.assertAlmostEqual(12, features[labels.index("Ni mean CN_VoronoiNN")])
+        self.assertArrayAlmostEqual([12, 12, 0, 12, 0] * 2, features)
+
+
 if __name__ == "__main__":
     unittest.main()