diff --git a/emmet-builders/emmet/builders/feff/xas.py b/emmet-builders/emmet/builders/feff/xas.py
index 1c7e2e7a38..dbc83140b8 100644
--- a/emmet-builders/emmet/builders/feff/xas.py
+++ b/emmet-builders/emmet/builders/feff/xas.py
@@ -1,19 +1,13 @@
 from typing import List, Dict
-from itertools import groupby, chain
+from itertools import chain
 from datetime import datetime
 import traceback
 
-import numpy as np
-from monty.json import jsanitize
-
 from maggma.core import Store
 from maggma.builders import GroupBuilder
-
-from pymatgen.core import Structure
-from pymatgen.analysis.xas.spectrum import XAS, site_weighted_spectrum
-from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
+from emmet.core.feff.task import TaskDocument as FEFFTaskDocument
 from emmet.core.xas import XASDoc
-from emmet.builders.utils import maximal_spanning_non_intersecting_subsets
+from emmet.core.utils import jsanitize
 
 
 class XASBuilder(GroupBuilder):
@@ -34,137 +28,36 @@ def __init__(self, tasks: Store, xas: Store, num_samples: int = 200.0, **kwargs)
 
     def process_item(self, spectra: List[Dict]) -> Dict:
+        # TODO: Change this to do structure matching against materials collection
         mpid = spectra[0]["mp_id"]
-        sandboxes = [doc.get("sandboxes", []) for doc in spectra]
-        sbxn_sets = maximal_spanning_non_intersecting_subsets(sandboxes)
 
         self.logger.debug(f"Processing: {mpid}")
-        all_processed = []
-
-        for sbxns in sbxn_sets:
-            sbxn_spectra = [
-                doc
-                for doc in spectra
-                if doc.get("sandboxes", []) == list(sbxns)
-                or doc.get("sandboxes", []) == []
-            ]
-
-            try:
-                processed = self.process_spectra(sbxn_spectra)
-                for d in processed:
-                    d.update({"state": "successful"})
-
-                all_processed.extend(processed)
-            except Exception as e:
-                self.logger.error(traceback.format_exc())
-                all_processed.append(
-                    {
-                        "error": str(e),
-                        "state": "failed",
-                        "task_ids": list(d[self.xas.key] for d in sbxn_spectra),
-                    }
-                )
+        tasks = [FEFFTaskDocument(**task) for task in spectra]
+
+        try:
+            docs = XASDoc.from_task_docs(tasks, material_id=mpid)
+            processed = [d.dict() for d in docs]
+
+            for d in processed:
+                d.update({"state": "successful"})
+        except Exception as e:
+            self.logger.error(traceback.format_exc())
+            processed = [
+                {
+                    "error": str(e),
+                    "state": "failed",
+                    "task_ids": list(d.task_id for d in tasks),
+                }
+            ]
 
         update_doc = {
             "_bt": datetime.utcnow(),
         }
-        all_processed.update(
-            {k: v for k, v in update_doc.items() if k not in processed}
-        )
-
-        return all_processed
+        for d in processed:
+            d.update({k: v for k, v in update_doc.items() if k not in d})
 
-    def process_spectra(self, items: List[Dict]) -> Dict:
-
-        all_spectra = [feff_task_to_spectrum(task) for task in items]
-
-        # Dictionary of all site to spectra mapping
-        sites_to_spectra = {
-            index: list(group)
-            for index, group in groupby(
-                sorted(all_spectra, key=lambda x: x.absorbing_index),
-                key=lambda x: x.absorbing_index,
-            )
-        }
-
-        # perform spectra merging
-        for site, spectra in sites_to_spectra.items():
-            type_to_spectra = {
-                index: list(group)
-                for index, group in groupby(
-                    sorted(
-                        spectra, key=lambda x: (x.edge, x.spectrum_type, x.last_updated)
-                    ),
-                    key=lambda x: (x.edge, x.spectrum_type),
-                )
-            }
-            # Make K-Total
-            if ("K", "XANES") in type_to_spectra and ("K", "EXAFS") in type_to_spectra:
-                xanes = type_to_spectra[("K", "XANES")][-1]
-                exafs = type_to_spectra[("K", "EXAFS")][-1]
-                try:
-                    total_spectrum = xanes.stitch(exafs, mode="XAFS")
-                    total_spectrum.absorbing_index = site
-                    total_spectrum.task_ids = xanes.task_ids + exafs.task_ids
-                    all_spectra.append(total_spectrum)
-                except ValueError as e:
-                    self.logger.warning(e)
-
-            # Make L23
-            if ("L2", "XANES") in type_to_spectra and (
-                "L3",
-                "XANES",
-            ) in type_to_spectra:
-                l2 = type_to_spectra[("L2", "XANES")][-1]
-                l3 = type_to_spectra[("L3", "XANES")][-1]
-                try:
-                    total_spectrum = l2.stitch(l3, mode="L23")
-                    total_spectrum.absorbing_index = site
-                    total_spectrum.task_ids = l2.task_ids + l3.task_ids
-                    all_spectra.append(total_spectrum)
-                except ValueError as e:
-                    self.logger.warning(e)
-
-        self.logger.debug(f"Found {len(all_spectra)} spectra")
-
-        # Site-weighted averaging
-        spectra_to_average = [
-            list(group)
-            for _, group in groupby(
-                sorted(
-                    all_spectra,
-                    key=lambda x: (x.absorbing_element, x.edge, x.spectrum_type),
-                ),
-                key=lambda x: lambda x: (x.absorbing_element, x.edge, x.spectrum_type),
-            )
-        ]
-        averaged_spectra = []
-
-        for relevant_spectra in spectra_to_average:
-
-            if len(relevant_spectra) > 0 and not is_missing_sites(relevant_spectra):
-                if len(relevant_spectra) > 1:
-                    try:
-                        avg_spectrum = site_weighted_spectrum(
-                            relevant_spectra, num_samples=self.num_samples
-                        )
-                        avg_spectrum.task_ids = [
-                            id
-                            for spectrum in relevant_spectra
-                            for id in spectrum.task_ids
-                        ]
-                        averaged_spectra.append(avg_spectrum)
-                    except ValueError as e:
-                        self.logger.error(e)
-                else:
-                    averaged_spectra.append(relevant_spectra[0])
-
-        spectra_docs = [
-            XASDoc.from_spectrum(spectrum).dict() for spectrum in averaged_spectra
-        ]
-
-        return spectra_docs
+        return jsanitize(processed, allow_bson=True)
 
     def update_targets(self, items):
         """
@@ -173,77 +66,3 @@ def update_targets(self, items):
         items = list(filter(None.__ne__, chain.from_iterable(items)))
 
         super().update_targets(items)
-
-
-def is_missing_sites(spectra):
-    """
-    Determines if the collection of spectra are missing any indicies for the given element
-    """
-    structure = spectra[0].structure
-    element = spectra[0].absorbing_element
-
-    # Find missing symmeterically inequivalent sites
-    symm_sites = SymmSites(structure)
-    absorption_indicies = {spectrum.absorbing_index for spectrum in spectra}
-
-    missing_site_spectra_indicies = (
-        set(structure.indices_from_symbol(element)) - absorption_indicies
-    )
-    for site_index in absorption_indicies:
-        missing_site_spectra_indicies -= set(
-            symm_sites.get_equivalent_site_indices(site_index)
-        )
-
-    return len(missing_site_spectra_indicies) != 0
-
-
-class SymmSites:
-    """
-    Wrapper to get equivalent site indicies from SpacegroupAnalyzer
-    """
-
-    def __init__(self, structure):
-        self.structure = structure
-        sa = SpacegroupAnalyzer(self.structure)
-        symm_data = sa.get_symmetry_dataset()
-        # equivalency mapping for the structure
-        # i'th site in the input structure equivalent to eq_atoms[i]'th site
-        self.eq_atoms = symm_data["equivalent_atoms"]
-
-    def get_equivalent_site_indices(self, i):
-        """
-        Site indices in the structure that are equivalent to the given site i.
-        """
-        rv = np.argwhere(self.eq_atoms == self.eq_atoms[i]).squeeze().tolist()
-        if isinstance(rv, int):
-            rv = [rv]
-        return rv
-
-
-def feff_task_to_spectrum(doc):
-    energy = doc["spectrum"][0]  # (eV)
-    intensity = doc["spectrum"][3]  # (mu)
-    structure: Structure = Structure.from_dict(doc["structure"])
-    # Clean site properties
-    for site_prop in structure.site_properties.keys():
-        structure.remove_site_property(site_prop)
-
-    absorbing_index = doc["absorbing_atom"]
-    absorbing_element = structure[absorbing_index].specie
-    edge = doc["edge"]
-    spectrum_type = doc["spectrum_type"]
-
-    spectrum = XAS(
-        x=energy,
-        y=intensity,
-        structure=structure,
-        absorbing_element=absorbing_element,
-        absorbing_index=absorbing_index,
-        edge=edge,
-        spectrum_type=spectrum_type,
-    )
-    # Adding a attr is not a robust process
-    # Figure out better solution later
-    spectrum.last_updated = doc["last_updated"]
-    spectrum.task_ids = [doc["xas_id"]]
-    return spectrum
diff --git a/emmet-core/emmet/core/feff/__init__.py b/emmet-core/emmet/core/feff/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/emmet-core/emmet/core/feff/task.py b/emmet-core/emmet/core/feff/task.py
new file mode 100644
index 0000000000..11fa2d126b
--- /dev/null
+++ b/emmet-core/emmet/core/feff/task.py
@@ -0,0 +1,86 @@
+""" Core definition of a FEFF Task Document """
+from functools import lru_cache
+from typing import Any, ClassVar, Dict, List
+
+from pydantic import Field
+from pymatgen.analysis.xas.spectrum import XAS, site_weighted_spectrum
+from pymatgen.core import Structure
+from pymatgen.core.periodic_table import Element
+
+from emmet.core.mpid import MPID
+from emmet.core.structure import StructureMetadata
+from emmet.core.task import TaskDocument as BaseTaskDocument
+from emmet.core.utils import ValueEnum
+
+
+class CalcType(ValueEnum):
+    """
+    The type of FEFF Calculation
+    XANES - Just the near-edge region
+    EXAFS - Just the extended region
+    XAFS - Fully stitched XANES + EXAFS
+    """
+
+    XANES = "XANES"
+    EXAFS = "EXAFS"
+    XAFS = "XAFS"
+
+
+class TaskDocument(BaseTaskDocument, StructureMetadata):
+    """Task Document for a FEFF XAS Calculation. Doesn't support EELS for now"""
+
+    calc_code: ClassVar[str] = "FEFF"
+
+    structure: Structure
+    input_parameters: Dict[str, Any] = Field(
+        {}, description="Input parameters for the FEFF calculation"
+    )
+    spectrum: List[List[float]] = Field(
+        [[]], description="Raw spectrum data from FEFF xmu.dat or eels.dat"
+    )
+
+    absorbing_atom: int = Field(
+        ..., description="Index in the cluster or structure for the absorbing atom"
+    )
+    spectrum_type: CalcType = Field(..., title="XAS Spectrum Type")
+    edge: str = Field(
+        ..., title="Absorption Edge", description="The interaction edge for XAS"
+    )
+
+    # TEMP Stub properties for compatibility with the atomate drone
+
+    @property
+    def absorbing_element(self) -> Element:
+        if isinstance(self.structure[self.absorbing_atom].specie, Element):
+            return self.structure[self.absorbing_atom].specie
+        return self.structure[self.absorbing_atom].specie.element
+
+    @property
+    def xas_spectrum(self) -> XAS:
+
+        if not hasattr(self, "_xas_spectrum"):
+
+            if not all([len(p) == 6 for p in self.spectrum]):
+                raise ValueError(
+                    "Spectrum data doesn't appear to be from xmu.dat which holds XAS data"
+                )
+
+            energy = [point[0] for point in self.spectrum]  # (eV)
+            intensity = [point[3] for point in self.spectrum]  # (mu)
+            structure = self.structure
+            absorbing_index = self.absorbing_atom
+            absorbing_element = self.absorbing_element
+            edge = self.edge
+            spectrum_type = str(self.spectrum_type)
+
+            self._xas_spectrum = XAS(
+                x=energy,
+                y=intensity,
+                structure=structure,
+                absorbing_element=absorbing_element,
+                absorbing_index=absorbing_index,
+                edge=edge,
+                spectrum_type=spectrum_type,
+            )
+
+        return self._xas_spectrum
diff --git a/emmet-core/emmet/core/task.py b/emmet-core/emmet/core/task.py
new file mode 100644
index 0000000000..3db9c22a51
--- /dev/null
+++ b/emmet-core/emmet/core/task.py
@@ -0,0 +1,35 @@
+""" Core definition of a Task Document which represents a calculation from some program"""
+from datetime import datetime
+from typing import ClassVar, List
+
+from pydantic import BaseModel, Field
+
+from emmet.core.mpid import MPID
+
+
+class TaskDocument(BaseModel):
+    """
+    Definition of Task Document
+    """
+
+    calc_code: ClassVar[str] = Field(
+        ..., description="The calculation code used to compute this task"
+    )
+    version: str = Field(None, description="The version of the calculation code")
+    dir_name: str = Field(None, description="The directory for this task")
+    task_id: MPID = Field(None, description="The Task ID for this document")
+
+    completed: bool = Field(False, description="Whether this calculation completed")
+    completed_at: datetime = Field(
+        None, description="Timestamp for when this task was completed"
+    )
+    last_updated: datetime = Field(
+        default_factory=datetime.utcnow,
+        description="Timestamp for when this task document was last updated",
+    )
+
+    tags: List[str] = Field([], description="Metadata tags for this task document")
+
+    warnings: List[str] = Field(
+        None, description="Any warnings related to this task"
+    )
diff --git a/emmet-core/emmet/core/vasp/task.py b/emmet-core/emmet/core/vasp/task.py
index 02de32b11e..681e7a0b3e 100644
--- a/emmet-core/emmet/core/vasp/task.py
+++ b/emmet-core/emmet/core/vasp/task.py
@@ -13,6 +13,7 @@
 from emmet.core.math import Matrix3D, Vector3D
 from emmet.core.mpid import MPID
 from emmet.core.structure import StructureMetadata
+from emmet.core.task import TaskDocument as BaseTaskDocument
 from emmet.core.utils import ValueEnum
 from emmet.core.vasp.calc_types import (
     CalcType,
@@ -98,22 +99,16 @@ class RunStatistics(BaseModel):
     )
 
 
-class TaskDocument(StructureMetadata):
+class TaskDocument(BaseTaskDocument, StructureMetadata):
     """
     Definition of VASP Task Document
     """
 
-    dir_name: str = Field(None, description="The directory for this VASP task")
+    calc_code: ClassVar[str] = "VASP"
     run_stats: Dict[str, RunStatistics] = Field(
         {},
         description="Summary of runtime statisitics for each calcualtion in this task",
     )
-    completed_at: datetime = Field(
-        None, description="Timestamp for when this task was completed"
-    )
-    last_updated: datetime = Field(
-        None, description="Timestamp for this task document was last updateed"
-    )
 
     is_valid: bool = Field(
         True, description="Whether this task document passed validation or not"
@@ -127,8 +122,6 @@ class TaskDocument(StructureMetadata):
     orig_inputs: Dict[str, Any] = Field(
         {}, description="Summary of the original VASP inputs"
    )
-    task_id: MPID = Field(None, description="the Task ID For this document")
-    tags: List[str] = Field([], description="Metadata tags for this task document")
 
     calcs_reversed: List[Dict] = Field(
         [], description="The 'raw' calculation docs used to assembled this task"
diff --git a/emmet-core/emmet/core/xas.py b/emmet-core/emmet/core/xas.py
index cd4cddb9bb..55a895052d 100644
--- a/emmet-core/emmet/core/xas.py
+++ b/emmet-core/emmet/core/xas.py
@@ -1,11 +1,15 @@
-from datetime import datetime
-from typing import List, Optional, Union
+import warnings
+from itertools import groupby
+from typing import Dict, List
 
-from pydantic import BaseModel, Field, root_validator
-from pymatgen.analysis.xas.spectrum import XAS
+import numpy as np
+from pydantic import Field
+from pymatgen.analysis.xas.spectrum import XAS, site_weighted_spectrum
 from pymatgen.core import Structure
 from pymatgen.core.periodic_table import Element
+from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
 
+from emmet.core.feff.task import TaskDocument
 from emmet.core.mpid import MPID
 from emmet.core.spectrum import SpectrumDoc
 from emmet.core.utils import ValueEnum
@@ -47,7 +51,7 @@ class XASDoc(SpectrumDoc):
 
     spectrum: XAS
 
-    xas_ids: List[str] = Field(
+    task_ids: List[str] = Field(
         ...,
         title="Calculation IDs",
         description="List of Calculations IDs used to make this XAS spectrum.",
@@ -70,6 +74,7 @@ def from_spectrum(
         el = xas_spectrum.absorbing_element
         edge = xas_spectrum.edge
         xas_id = f"{material_id}-{spectrum_type}-{el}-{edge}"
+
         if xas_spectrum.absorbing_index is not None:
             xas_id += f"-{xas_spectrum.absorbing_index}"
 
@@ -83,3 +88,180 @@
             spectrum_id=xas_id,
             **kwargs,
         )
+
+    @classmethod
+    def from_task_docs(
+        cls, all_tasks: List[TaskDocument], material_id: MPID, num_samples: int = 200
+    ) -> List["XASDoc"]:
+        """
+        Converts a set of FEFF Task Documents into XASDocs by first merging XANES + EXAFS into full XAFS spectra
+        and then averaging over symmetrically equivalent sites to get element-averaged spectra
+
+        Args:
+            all_tasks: FEFF Task documents that all share the same structure
+            material_id: The material ID for the generated XASDocs
+            num_samples: number of sampled points for site-weighted averaging
+        """
+
+        all_spectra: List[XAS] = []
+        averaged_spectra: List[XAS] = []
+
+        # This is a hack using extra attributes within this function to carry some extra information
+        # without generating new objects
+        for task in all_tasks:
+            spectrum = task.xas_spectrum
+            spectrum.last_updated = task.last_updated
+            spectrum.task_ids = [task.task_id]
+            all_spectra.append(spectrum)
+
+        # Pre-sort by keys so the group-by stages below do not need to re-sort
+        all_spectra = sorted(
+            all_spectra,
+            key=lambda x: (
+                x.absorbing_index,
+                x.edge,
+                x.spectrum_type,
+                x.last_updated,
+            ),
+        )
+
+        # Generate Merged Spectra
+        # Dictionary of all site to spectra mapping
+        sites_to_spectra = {
+            index: list(group)
+            for index, group in groupby(
+                all_spectra,
+                key=lambda x: x.absorbing_index,
+            )
+        }
+
+        # Perform spectra merging
+        for site, spectra in sites_to_spectra.items():
+            type_to_spectra = {
+                index: list(group)
+                for index, group in groupby(
+                    spectra,
+                    key=lambda x: (x.edge, x.spectrum_type),
+                )
+            }
+            # Make K-edge XAFS spectra by merging XANES + EXAFS
+            if ("K", "XANES") in type_to_spectra and ("K", "EXAFS") in type_to_spectra:
+                xanes = type_to_spectra[("K", "XANES")][-1]
+                exafs = type_to_spectra[("K", "EXAFS")][-1]
+                try:
+                    total_spectrum = xanes.stitch(exafs, mode="XAFS")
+                    total_spectrum.absorbing_index = site
+                    total_spectrum.task_ids = xanes.task_ids + exafs.task_ids
+                    all_spectra.append(total_spectrum)
+                except ValueError as e:
+                    warnings.warn(f"Warning during spectral merging in XASDoc: {e}")
+
+            # Make L2,3 XANES spectra by merging L2 and L3 spectra
+            if ("L2", "XANES") in type_to_spectra and (
+                "L3",
+                "XANES",
+            ) in type_to_spectra:
+                l2 = type_to_spectra[("L2", "XANES")][-1]
+                l3 = type_to_spectra[("L3", "XANES")][-1]
+                try:
+                    total_spectrum = l2.stitch(l3, mode="L23")
+                    total_spectrum.absorbing_index = site
+                    total_spectrum.task_ids = l2.task_ids + l3.task_ids
+                    all_spectra.append(total_spectrum)
+                except ValueError as e:
+                    warnings.warn(f"Warning during spectral merging in XASDoc: {e}")
+
+            # We don't have L2,3 EXAFS yet, so there is no merging for those
+
+        # Site-weighted averaging
+        spectra_to_average = [
+            list(group)
+            for _, group in groupby(
+                sorted(
+                    all_spectra,
+                    key=lambda x: (x.absorbing_element, x.edge, x.spectrum_type),
+                ),
+                key=lambda x: (x.absorbing_element, x.edge, x.spectrum_type),
+            )
+        ]
+
+        for relevant_spectra in spectra_to_average:
+            if len(relevant_spectra) > 0 and not _is_missing_sites(relevant_spectra):
+                if len(relevant_spectra) > 1:
+                    try:
+                        avg_spectrum = site_weighted_spectrum(
+                            relevant_spectra, num_samples=num_samples
+                        )
+                        avg_spectrum.task_ids = [
+                            id
+                            for spectrum in relevant_spectra
+                            for id in spectrum.task_ids
+                        ]
+                        avg_spectrum.last_updated = max(
+                            [spectrum.last_updated for spectrum in relevant_spectra]
+                        )
+                        averaged_spectra.append(avg_spectrum)
+                    except ValueError as e:
+                        warnings.warn(
+                            f"Warning during site-weighted averaging in XASDoc: {e}"
+                        )
+                else:
+                    averaged_spectra.append(relevant_spectra[0])
+
+        spectra_docs = []
+
+        for spectrum in averaged_spectra:
+            doc = XASDoc.from_spectrum(
+                xas_spectrum=spectrum,
+                material_id=material_id,
+                task_ids=spectrum.task_ids,
+                last_updated=spectrum.last_updated,
+            )
+            spectra_docs.append(doc)
+
+        return spectra_docs
+
+
+def _is_missing_sites(spectra: List[XAS]):
+    """
+    Determines if the collection of spectra is missing any site indices for the given element
+    """
+    structure = spectra[0].structure
+    element = spectra[0].absorbing_element
+
+    # Find missing symmetrically inequivalent sites
+    symm_sites = SymmSites(structure)
+    absorption_indicies = {spectrum.absorbing_index for spectrum in spectra}
+
+    missing_site_spectra_indicies = (
+        set(structure.indices_from_symbol(element)) - absorption_indicies
+    )
+    for site_index in absorption_indicies:
+        missing_site_spectra_indicies -= set(
+            symm_sites.get_equivalent_site_indices(site_index)
+        )
+
+    return len(missing_site_spectra_indicies) != 0
+
+
+class SymmSites:
+    """
+    Wrapper to get equivalent site indices from SpacegroupAnalyzer
+    """
+
+    def __init__(self, structure):
+        self.structure = structure
+        sa = SpacegroupAnalyzer(self.structure)
+        symm_data = sa.get_symmetry_dataset()
+        # equivalency mapping for the structure
+        # the i-th site in the input structure is equivalent to the eq_atoms[i]-th site
+        self.eq_atoms = symm_data["equivalent_atoms"]
+
+    def get_equivalent_site_indices(self, i):
+        """
+        Site indices in the structure that are equivalent to the given site i.
+        """
+        rv = np.argwhere(self.eq_atoms == self.eq_atoms[i]).squeeze().tolist()
+        if isinstance(rv, int):
+            rv = [rv]
+        return rv
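
For reference, below is a minimal usage sketch (not part of the diff) of the new XASDoc.from_task_docs API, mirroring what XASBuilder.process_item now does with FEFF task documents. The structure, spectrum rows, and IDs are made-up placeholders, and the sketch assumes the emmet-core models behave as defined in the files above.

    from pymatgen.core import Lattice, Structure

    from emmet.core.feff.task import TaskDocument as FEFFTaskDocument
    from emmet.core.xas import XASDoc

    # Hypothetical structure and spectra, purely for illustration
    structure = Structure(Lattice.cubic(3.0), ["Ti", "O"], [[0, 0, 0], [0.5, 0.5, 0.5]])

    tasks = [
        FEFFTaskDocument(
            structure=structure,
            absorbing_atom=0,
            spectrum_type=spectrum_type,
            edge="K",
            # xmu.dat-style rows have 6 columns; only columns 0 (energy) and 3 (mu) are read
            spectrum=[[4966.0 + i, 0.0, 0.0, 0.01 * i, 0.0, 0.0] for i in range(50)],
            task_id=task_id,
        )
        for spectrum_type, task_id in [("XANES", "mp-1234567"), ("EXAFS", "mp-1234568")]
    ]

    # One XASDoc per (element, edge, spectrum type); K-edge XANES + EXAFS are stitched into XAFS when possible
    docs = XASDoc.from_task_docs(tasks, material_id="mp-390")
    for doc in docs:
        print(doc.spectrum_id, doc.task_ids)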