Merge pull request #197 from materialsproject/xas
Move XAS Logic into document and clean up builder
shyamd committed May 21, 2021
2 parents 8f3edad + 03592ab commit 4f5f85d
Showing 6 changed files with 335 additions and 220 deletions.
229 changes: 24 additions & 205 deletions emmet-builders/emmet/builders/feff/xas.py
@@ -1,19 +1,13 @@
from typing import List, Dict
from itertools import groupby, chain
from itertools import chain
from datetime import datetime
import traceback

import numpy as np
from monty.json import jsanitize

from maggma.core import Store
from maggma.builders import GroupBuilder

from pymatgen.core import Structure
from pymatgen.analysis.xas.spectrum import XAS, site_weighted_spectrum
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from emmet.core.feff.task import TaskDocument as FEFFTaskDocument
from emmet.core.xas import XASDoc
from emmet.builders.utils import maximal_spanning_non_intersecting_subsets
from emmet.core.utils import jsanitize


class XASBuilder(GroupBuilder):
@@ -34,137 +28,36 @@ def __init__(self, tasks: Store, xas: Store, num_samples: int = 200.0, **kwargs)

def process_item(self, spectra: List[Dict]) -> Dict:

# TODO: Change this to do structure matching against materials collection
mpid = spectra[0]["mp_id"]
sandboxes = [doc.get("sandboxes", []) for doc in spectra]
sbxn_sets = maximal_spanning_non_intersecting_subsets(sandboxes)

self.logger.debug(f"Processing: {mpid}")
all_processed = []

for sbxns in sbxn_sets:
sbxn_spectra = [
doc
for doc in spectra
if doc.get("sandboxes", []) == list(sbxns)
or doc.get("sandboxes", []) == []
]

try:

processed = self.process_spectra(sbxn_spectra)
for d in processed:
d.update({"state": "successful"})

all_processed.extend(processed)
except Exception as e:
self.logger.error(traceback.format_exc())
all_processed.append(
{
"error": str(e),
"state": "failed",
"task_ids": list(d[self.xas.key] for d in sbxn_spectra),
}
)
tasks = [FEFFTaskDocument(**task) for task in spectra]

try:
docs = XASDoc.from_task_docs(tasks, material_id=mpid)
processed = [d.dict() for d in docs]

for d in processed:
d.update({"state": "successful"})
except Exception as e:
self.logger.error(traceback.format_exc())
processed = [
{
"error": str(e),
"state": "failed",
"task_ids": list(d.task_id for d in tasks),
}
]

update_doc = {
"_bt": datetime.utcnow(),
}
all_processed.update(
{k: v for k, v in update_doc.items() if k not in processed}
)

return all_processed
for d in processed:
d.update({k: v for k, v in update_doc.items() if k not in d})

def process_spectra(self, items: List[Dict]) -> Dict:

all_spectra = [feff_task_to_spectrum(task) for task in items]

# Dictionary of all site to spectra mapping
sites_to_spectra = {
index: list(group)
for index, group in groupby(
sorted(all_spectra, key=lambda x: x.absorbing_index),
key=lambda x: x.absorbing_index,
)
}

# perform spectra merging
for site, spectra in sites_to_spectra.items():
type_to_spectra = {
index: list(group)
for index, group in groupby(
sorted(
spectra, key=lambda x: (x.edge, x.spectrum_type, x.last_updated)
),
key=lambda x: (x.edge, x.spectrum_type),
)
}
# Make K-Total
if ("K", "XANES") in type_to_spectra and ("K", "EXAFS") in type_to_spectra:
xanes = type_to_spectra[("K", "XANES")][-1]
exafs = type_to_spectra[("K", "EXAFS")][-1]
try:
total_spectrum = xanes.stitch(exafs, mode="XAFS")
total_spectrum.absorbing_index = site
total_spectrum.task_ids = xanes.task_ids + exafs.task_ids
all_spectra.append(total_spectrum)
except ValueError as e:
self.logger.warning(e)

# Make L23
if ("L2", "XANES") in type_to_spectra and (
"L3",
"XANES",
) in type_to_spectra:
l2 = type_to_spectra[("L2", "XANES")][-1]
l3 = type_to_spectra[("L3", "XANES")][-1]
try:
total_spectrum = l2.stitch(l3, mode="L23")
total_spectrum.absorbing_index = site
total_spectrum.task_ids = l2.task_ids + l3.task_ids
all_spectra.append(total_spectrum)
except ValueError as e:
self.logger.warning(e)

self.logger.debug(f"Found {len(all_spectra)} spectra")

# Site-weighted averaging
spectra_to_average = [
list(group)
for _, group in groupby(
sorted(
all_spectra,
key=lambda x: (x.absorbing_element, x.edge, x.spectrum_type),
),
key=lambda x: (x.absorbing_element, x.edge, x.spectrum_type),
)
]
averaged_spectra = []

for relevant_spectra in spectra_to_average:

if len(relevant_spectra) > 0 and not is_missing_sites(relevant_spectra):
if len(relevant_spectra) > 1:
try:
avg_spectrum = site_weighted_spectrum(
relevant_spectra, num_samples=self.num_samples
)
avg_spectrum.task_ids = [
id
for spectrum in relevant_spectra
for id in spectrum.task_ids
]
averaged_spectra.append(avg_spectrum)
except ValueError as e:
self.logger.error(e)
else:
averaged_spectra.append(relevant_spectra[0])

spectra_docs = [
XASDoc.from_spectrum(spectrum).dict() for spectrum in averaged_spectra
]

return spectra_docs
return jsanitize(processed, allow_bson=True)

def update_targets(self, items):
"""
@@ -173,77 +66,3 @@ def update_targets(self, items):

items = list(filter(None.__ne__, chain.from_iterable(items)))
super().update_targets(items)


def is_missing_sites(spectra):
"""
Determines whether the collection of spectra is missing any indices for the given element
"""
structure = spectra[0].structure
element = spectra[0].absorbing_element

# Find missing symmetrically inequivalent sites
symm_sites = SymmSites(structure)
absorption_indicies = {spectrum.absorbing_index for spectrum in spectra}

missing_site_spectra_indicies = (
set(structure.indices_from_symbol(element)) - absorption_indicies
)
for site_index in absorption_indicies:
missing_site_spectra_indicies -= set(
symm_sites.get_equivalent_site_indices(site_index)
)

return len(missing_site_spectra_indicies) != 0


class SymmSites:
"""
Wrapper to get equivalent site indices from SpacegroupAnalyzer
"""

def __init__(self, structure):
self.structure = structure
sa = SpacegroupAnalyzer(self.structure)
symm_data = sa.get_symmetry_dataset()
# equivalency mapping for the structure
# i'th site in the input structure equivalent to eq_atoms[i]'th site
self.eq_atoms = symm_data["equivalent_atoms"]

def get_equivalent_site_indices(self, i):
"""
Site indices in the structure that are equivalent to the given site i.
"""
rv = np.argwhere(self.eq_atoms == self.eq_atoms[i]).squeeze().tolist()
if isinstance(rv, int):
rv = [rv]
return rv


def feff_task_to_spectrum(doc):
energy = doc["spectrum"][0] # (eV)
intensity = doc["spectrum"][3] # (mu)
structure: Structure = Structure.from_dict(doc["structure"])
# Clean site properties
for site_prop in structure.site_properties.keys():
structure.remove_site_property(site_prop)

absorbing_index = doc["absorbing_atom"]
absorbing_element = structure[absorbing_index].specie
edge = doc["edge"]
spectrum_type = doc["spectrum_type"]

spectrum = XAS(
x=energy,
y=intensity,
structure=structure,
absorbing_element=absorbing_element,
absorbing_index=absorbing_index,
edge=edge,
spectrum_type=spectrum_type,
)
# Adding an attr is not a robust process
# Figure out a better solution later
spectrum.last_updated = doc["last_updated"]
spectrum.task_ids = [doc["xas_id"]]
return spectrum
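
For reference, the cleaned-up process_item flow above reduces to a few lines. The sketch below is illustrative only: it assumes a list of raw FEFF task dicts for a single material keyed as in the diff, and that the stitching and site-weighted averaging formerly done by the deleted helpers now happens inside XASDoc.from_task_docs.

from emmet.core.feff.task import TaskDocument as FEFFTaskDocument
from emmet.core.xas import XASDoc

def build_xas_docs(spectra):
    """Illustrative: turn raw FEFF task dicts for one material into XAS docs."""
    mpid = spectra[0]["mp_id"]  # builder TODO: structure-match against the materials collection
    tasks = [FEFFTaskDocument(**task) for task in spectra]
    # K-total/L23 stitching and site-weighted averaging now live in the document layer
    docs = XASDoc.from_task_docs(tasks, material_id=mpid)
    return [d.dict() for d in docs]
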
Empty file.
86 changes: 86 additions & 0 deletions emmet-core/emmet/core/feff/task.py
@@ -0,0 +1,86 @@
""" Core definition of a VASP Task Document """
from functools import lru_cache
from typing import Any, ClassVar, Dict, List

from pydantic import Field
from pymatgen.analysis.xas.spectrum import XAS, site_weighted_spectrum
from pymatgen.core import Structure
from pymatgen.core.periodic_table import Element

from emmet.core.mpid import MPID
from emmet.core.structure import StructureMetadata
from emmet.core.task import TaskDocument as BaseTaskDocument
from emmet.core.utils import ValueEnum


class CalcType(ValueEnum):
"""
The type of FEFF Calculation
XANES - Just the near-edge region
EXAFS - Just the extended region
XAFS - Fully stitched XANES + EXAFS
"""

XANES = "XANES"
EXAFS = "EXAFS"
XAFS = "XAFS"


class TaskDocument(BaseTaskDocument, StructureMetadata):
"""Task Document for a FEFF XAS Calculation. Doesn't support EELS for now"""

calc_code: ClassVar[str] = "FEFF"

structure: Structure
input_parameters: Dict[str, Any] = Field(
{}, description="Input parameters for the FEFF calculation"
)
spectrum: List[List[float]] = Field(
[[]], description="Raw spectrum data from FEFF xmu.dat or eels.dat"
)

absorbing_atom: int = Field(
..., description="Index in the cluster or structure for the absorbing atom"
)
spectrum_type: CalcType = Field(..., title="XAS Spectrum Type")
edge: str = Field(
..., title="Absorption Edge", description="The interaction edge for XAS"
)

# TEMP Stub properties for compatibility with the atomate drone

@property
def absorbing_element(self) -> Element:
if isinstance(self.structure[self.absorbing_atom].specie, Element):
return self.structure[self.absorbing_atom].specie
return self.structure[self.absorbing_atom].specie.element

@property
def xas_spectrum(self) -> XAS:

if not hasattr(self, "_xas_spectrum"):

if not all([len(p) == 6 for p in self.spectrum]):
raise ValueError(
"Spectrum data doesn't appear to be from xmu.dat which holds XAS data"
)

energy = [point[0] for point in self.spectrum] # (eV)
intensity = [point[3] for point in self.spectrum] # (mu)
structure = self.structure
absorbing_index = self.absorbing_atom
absorbing_element = self.absorbing_element
edge = self.edge
spectrum_type = str(self.spectrum_type)

self._xas_spectrum = XAS(
x=energy,
y=intensity,
structure=structure,
absorbing_element=absorbing_element,
absorbing_index=absorbing_index,
edge=edge,
spectrum_type=spectrum_type,
)

return self._xas_spectrum
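
A hypothetical usage sketch for the new FEFF TaskDocument follows. The structure and spectrum rows are invented, and it assumes the inherited StructureMetadata fields all have defaults; the six-column rows mimic xmu.dat (omega, e, k, mu, mu0, chi), from which the property reads energy from column 0 and mu from column 3.

from pymatgen.core import Lattice, Structure
from emmet.core.feff.task import TaskDocument

# Invented two-site structure and fake xmu.dat-style rows (6 columns each)
structure = Structure(Lattice.cubic(4.2), ["Ti", "O"], [[0, 0, 0], [0.5, 0.5, 0.5]])
rows = [[4960.0 + 0.5 * i, -10.0 + 0.5 * i, 0.0, 0.01 * (i + 1), 0.0, 0.0] for i in range(20)]

task = TaskDocument(
    structure=structure,
    absorbing_atom=0,        # the Ti site above
    spectrum_type="XANES",   # coerced into CalcType
    edge="K",
    spectrum=rows,
)

xas = task.xas_spectrum      # built lazily from the raw rows and cached
print(task.absorbing_element, xas.edge, len(xas.x))
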
35 changes: 35 additions & 0 deletions emmet-core/emmet/core/task.py
@@ -0,0 +1,35 @@
""" Core definition of a Task Document which represents a calculation from some program"""
from datetime import datetime
from typing import ClassVar, List

from pydantic import BaseModel, Field

from emmet.core.mpid import MPID


class TaskDocument(BaseModel):
"""
Definition of Task Document
"""

calc_code: ClassVar[str] = Field(
..., description="The calculation code used to compute this task"
)
version: str = Field(None, description="The version of the calculation code")
dir_name: str = Field(None, description="The directory for this task")
task_id: MPID = Field(None, description="The Task ID for this document")

completed: bool = Field(False, description="Whether this calculation completed")
completed_at: datetime = Field(
None, description="Timestamp for when this task was completed"
)
last_updated: datetime = Field(
default_factory=datetime.utcnow,
description="Timestamp for this task document was last updateed",
)

tags: List[str] = Field([], description="Metadata tags for this task document")

warnings: List[str] = Field(
None, description="Any warnings related to this property"
)
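
As a hypothetical illustration of how another calculation code could build on this generic document (mirroring what the FEFF TaskDocument above does), the NWChem name and extra field below are made up:

from typing import ClassVar

from pydantic import Field

from emmet.core.task import TaskDocument as BaseTaskDocument


class NWChemTaskDocument(BaseTaskDocument):
    """Illustrative task document for a made-up NWChem integration"""

    calc_code: ClassVar[str] = "NWChem"

    total_energy: float = Field(None, description="Final total energy in eV")


doc = NWChemTaskDocument(task_id="mp-1234", completed=True, total_energy=-123.4)
print(doc.calc_code, doc.last_updated)
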
