diff --git a/emmet-builders/emmet/builders/materials/alloys.py b/emmet-builders/emmet/builders/materials/alloys.py index 4016a65c4a..23044e7735 100644 --- a/emmet-builders/emmet/builders/materials/alloys.py +++ b/emmet-builders/emmet/builders/materials/alloys.py @@ -56,7 +56,7 @@ def __init__( targets=[alloy_pairs], chunk_size=8, ) - + def ensure_indexes(self): self.alloy_pairs.ensure_index("pair_id") diff --git a/emmet-builders/emmet/builders/vasp/thermo.py b/emmet-builders/emmet/builders/vasp/thermo.py index 527b316fd0..88a7409260 100644 --- a/emmet-builders/emmet/builders/vasp/thermo.py +++ b/emmet-builders/emmet/builders/vasp/thermo.py @@ -1,18 +1,18 @@ import warnings from collections import defaultdict from itertools import chain -from typing import Dict, Iterable, Iterator, List, Optional, Set +from typing import Dict, Iterable, Iterator, List, Optional, Set, Union from math import ceil from maggma.core import Builder, Store from maggma.utils import grouper from monty.json import MontyDecoder from pymatgen.analysis.phase_diagram import PhaseDiagramError -from pymatgen.entries.compatibility import MaterialsProject2020Compatibility from pymatgen.entries.computed_entries import ComputedStructureEntry +from pymatgen.entries.compatibility import Compatibility from emmet.builders.utils import chemsys_permutations -from emmet.core.thermo import ThermoDoc, PhaseDiagramDoc +from emmet.core.thermo import ThermoDoc, PhaseDiagramDoc, ThermoType from emmet.core.utils import jsanitize @@ -24,7 +24,7 @@ def __init__( phase_diagram: Optional[Store] = None, oxidation_states: Optional[Store] = None, query: Optional[Dict] = None, - compatibility=None, + compatibility: Optional[List[Compatibility]] = None, num_phase_diagram_eles: Optional[int] = None, **kwargs, ): @@ -39,7 +39,7 @@ def __init__( phase_diagram (Store): Store of phase diagram data for each unique chemical system oxidation_states (Store): Store of oxidation state data to use in correction scheme application query (dict): dictionary to limit materials to be analyzed - compatibility (PymatgenCompatability): Compatability module + compatibility ([Compatability]): Compatability module to ensure energies are compatible num_phase_diagram_eles (int): Maximum number of elements to use in phase diagram construction for data within the separate phase_diagram collection @@ -48,20 +48,40 @@ def __init__( self.materials = materials self.thermo = thermo self.query = query if query else {} - self.compatibility = compatibility + self.compatibility = compatibility or [None] self.oxidation_states = oxidation_states self.phase_diagram = phase_diagram self.num_phase_diagram_eles = num_phase_diagram_eles self._completed_tasks: Set[str] = set() self._entries_cache: Dict[str, List[ComputedStructureEntry]] = defaultdict(list) + if self.thermo.key != "thermo_id": + warnings.warn(f"Key for the thermo store is incorrect and has been changed from {self.thermo.key} to thermo_id!") + self.thermo.key = "thermo_id" + + if self.materials.key != "material_id": + warnings.warn(f"Key for the materials store is incorrect and has been changed from {self.materials.key} to material_id!") + self.materials.key = "material_id" + sources = [materials] - if oxidation_states is not None: - sources.append(oxidation_states) + + if self.oxidation_states is not None: + + if self.oxidation_states.key != "material_id": + warnings.warn(f"Key for the oxidation states store is incorrect and has been changed from {self.oxidation_states.key} to material_id!") + self.oxidation_states.key = "material_id" + + sources.append(oxidation_states) # type: ignore targets = [thermo] - if phase_diagram is not None: - targets.append(phase_diagram) + + if self.phase_diagram is not None: + + if self.phase_diagram.key != "phase_diagram_id": + warnings.warn(f"Key for the phase diagram store is incorrect and has been changed from {self.thphase_diagramermo.key} to phase_diagram_id!") + self.phase_diagram.key = "phase_diagram_id" + + targets.append(phase_diagram) # type: ignore super().__init__(sources=sources, targets=targets, **kwargs) @@ -77,11 +97,14 @@ def ensure_indexes(self): # Search index for thermo self.thermo.ensure_index("material_id") + self.thermo.ensure_index("thermo_id") + self.thermo.ensure_index("thermo_type") self.thermo.ensure_index("last_updated") # Search index for phase_diagram if self.phase_diagram: self.phase_diagram.ensure_index("chemsys") + self.phase_diagram.ensure_index("phase_diagram_id") def prechunk(self, number_splits: int) -> Iterable[Dict]: # pragma: no cover updated_chemsys = self.get_updated_chemsys() @@ -155,61 +178,80 @@ def process_item(self, item: List[Dict]): self.logger.debug(f"Processing {len(entries)} entries for {chemsys}") - material_entries: Dict[str, Dict[str, ComputedStructureEntry]] = defaultdict( - dict - ) - pd_entries = [] - for entry in entries: - material_entries[entry.data["material_id"]][entry.data["run_type"]] = entry - - if self.compatibility: - with warnings.catch_warnings(): - warnings.filterwarnings( - "ignore", message="Failed to guess oxidation states.*" - ) - pd_entries = self.compatibility.process_entries(entries) - else: - pd_entries = entries - self.logger.debug(f"{len(pd_entries)} remain in {chemsys} after filtering") + docs_pd_pair_list = [] - try: - docs, pd = ThermoDoc.from_entries(pd_entries, deprecated=False) + for compatability in self.compatibility: - # for doc in docs: - # doc.entries = material_entries[doc.material_id] - # doc.entry_types = list(material_entries[doc.material_id].keys()) + pd_entries = [] - pd_data = None + if compatability: + if compatability.name == "MP DFT mixing scheme": + thermo_type = ThermoType.GGA_GGA_U_R2SCAN + elif compatability.name == "MP2020": + thermo_type = ThermoType.GGA_GGA_U + else: + thermo_type = ThermoType.UNKNOWN - if self.phase_diagram: - if ( - self.num_phase_diagram_eles is None - or len(elements) <= self.num_phase_diagram_eles - ): - pd_doc = PhaseDiagramDoc(chemsys=chemsys, phase_diagram=pd) - pd_data = jsanitize(pd_doc.dict(), allow_bson=True) + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", message="Failed to guess oxidation states.*" + ) + pd_entries = compatability.process_entries(entries) + else: + all_entry_types = {e.data["run_type"] for e in entries} + if len(all_entry_types) > 1: + raise ValueError( + "More than one functional type has been provided without a mixing scheme!" + ) + else: + thermo_type = all_entry_types.pop() + pd_entries = entries + self.logger.debug(f"{len(pd_entries)} remain in {chemsys} after filtering") + + try: + docs, pd = ThermoDoc.from_entries( + pd_entries, thermo_type, deprecated=False + ) - docs_pd_pair = ( - jsanitize([d.dict() for d in docs], allow_bson=True), - [pd_data], - ) + pd_data = None + + if self.phase_diagram: + if ( + self.num_phase_diagram_eles is None + or len(elements) <= self.num_phase_diagram_eles + ): + pd_id = "{}_{}".format(chemsys, str(thermo_type)) + pd_doc = PhaseDiagramDoc( + phase_diagram_id=pd_id, + chemsys=chemsys, + phase_diagram=pd, + thermo_type=thermo_type, + ) + pd_data = jsanitize(pd_doc.dict(), allow_bson=True) + + docs_pd_pair = ( + jsanitize([d.dict() for d in docs], allow_bson=True), + [pd_data], + ) - except PhaseDiagramError as p: - elsyms = [] - for e in entries: - elsyms.extend([el.symbol for el in e.composition.elements]) + docs_pd_pair_list.append(docs_pd_pair) - self.logger.warning( - f"Phase diagram error in chemsys {'-'.join(sorted(set(elsyms)))}: {p}" - ) - return [] - except Exception as e: - self.logger.error( - f"Got unexpected error while processing {[ent_.entry_id for ent_ in entries]}: {e}" - ) - return [] + except PhaseDiagramError as p: + elsyms = [] + for e in entries: + elsyms.extend([el.symbol for el in e.composition.elements]) + + self.logger.warning( + f"Phase diagram error in chemsys {'-'.join(sorted(set(elsyms)))}: {p}" + ) + return [] + except Exception as e: + self.logger.error( + f"Got unexpected error while processing {[ent_.entry_id for ent_ in entries]}: {e}" + ) + return [] - return docs_pd_pair + return docs_pd_pair_list def update_targets(self, items): """ @@ -218,8 +260,8 @@ def update_targets(self, items): items ([[tuple(List[dict],List[dict])]]): a list of list of thermo dictionaries to update """ - thermo_docs = [item[0] for item in items] - phase_diagram_docs = [item[1] for item in items] + thermo_docs = [item[0] for pair_list in items for item in pair_list] + phase_diagram_docs = [item[1] for pair_list in items for item in pair_list] # flatten out lists thermo_docs = list(filter(None, chain.from_iterable(thermo_docs))) @@ -227,10 +269,10 @@ def update_targets(self, items): # Check if already updated this run thermo_docs = [ - i for i in thermo_docs if i["material_id"] not in self._completed_tasks + i for i in thermo_docs if i["thermo_id"] not in self._completed_tasks ] - self._completed_tasks |= {i["material_id"] for i in thermo_docs} + self._completed_tasks |= {i["thermo_id"] for i in thermo_docs} for item in thermo_docs: if isinstance(item["last_updated"], dict): @@ -243,7 +285,7 @@ def update_targets(self, items): if len(thermo_docs) > 0: self.logger.info(f"Updating {len(thermo_docs)} thermo documents") - self.thermo.update(docs=thermo_docs, key=["material_id"]) + self.thermo.update(docs=thermo_docs, key=["thermo_id"]) else: self.logger.info("No thermo items to update") @@ -302,17 +344,16 @@ def get_entries(self, chemsys: str) -> List[Dict]: f"Got {len(materials_docs)} entries from DB for {len(query_chemsys)} sub-chemsys for {chemsys}" ) - # Convert GGA, GGA+U, R2SCAN entries into ComputedEntries and store + # Convert entries into ComputedEntries and store for doc in materials_docs: for r_type, entry_dict in doc.get("entries", {}).items(): - if r_type in ["GGA", "GGA+U", "R2SCAN"]: - entry_dict["data"]["oxidation_states"] = oxi_states_data.get( - entry_dict["data"]["material_id"], {} - ) - entry_dict["data"]["run_type"] = r_type - elsyms = sorted(set([el for el in entry_dict["composition"]])) - self._entries_cache["-".join(elsyms)].append(entry_dict) - all_entries.append(entry_dict) + entry_dict["data"]["oxidation_states"] = oxi_states_data.get( + entry_dict["data"]["material_id"], {} + ) + entry_dict["data"]["run_type"] = r_type + elsyms = sorted(set([el for el in entry_dict["composition"]])) + self._entries_cache["-".join(elsyms)].append(entry_dict) + all_entries.append(entry_dict) self.logger.info(f"Total entries in {chemsys} : {len(all_entries)}") diff --git a/emmet-core/emmet/core/thermo.py b/emmet-core/emmet/core/thermo.py index 304f6916f3..64cfa6fdb5 100644 --- a/emmet-core/emmet/core/thermo.py +++ b/emmet-core/emmet/core/thermo.py @@ -2,6 +2,7 @@ from collections import defaultdict from typing import Dict, List, Union from datetime import datetime +from emmet.core.utils import ValueEnum from pydantic import BaseModel, Field from pymatgen.analysis.phase_diagram import PhaseDiagram @@ -10,6 +11,7 @@ from emmet.core.material_property import PropertyDoc from emmet.core.material import PropertyOrigin from emmet.core.mpid import MPID +from emmet.core.vasp.calc_types.enums import RunType class DecompositionProduct(BaseModel): @@ -31,6 +33,12 @@ class DecompositionProduct(BaseModel): ) +class ThermoType(ValueEnum): + GGA_GGA_U = "GGA/GGA+U" + GGA_GGA_U_R2SCAN = "GGA/GGA+U/R2SCAN" + UNKNOWN = "UNKNOWN" + + class ThermoDoc(PropertyDoc): """ A thermo entry document @@ -38,6 +46,16 @@ class ThermoDoc(PropertyDoc): property_name = "thermo" + thermo_type: Union[ThermoType, RunType] = Field( + ..., + description="Functional types of calculations involved in the energy mixing scheme.", + ) + + thermo_id: str = Field( + ..., + description="Unique document ID which is composed of the Material ID and thermo data type.", + ) + uncorrected_energy_per_atom: float = Field( ..., description="The total DFT energy of this material per atom in eV/atom." ) @@ -99,7 +117,10 @@ class ThermoDoc(PropertyDoc): @classmethod def from_entries( - cls, entries: List[Union[ComputedEntry, ComputedStructureEntry]], **kwargs + cls, + entries: List[Union[ComputedEntry, ComputedStructureEntry]], + thermo_type: Union[ThermoType, RunType], + **kwargs ): entries_by_comp = defaultdict(list) @@ -145,7 +166,9 @@ def _energy_eval(entry: ComputedStructureEntry): (decomp, ehull) = pd.get_decomp_and_e_above_hull(blessed_entry) d = { + "thermo_id": "{}_{}".format(material_id, str(thermo_type)), "material_id": material_id, + "thermo_type": thermo_type, "uncorrected_energy_per_atom": blessed_entry.uncorrected_energy / blessed_entry.composition.num_atoms, "energy_per_atom": blessed_entry.energy @@ -235,10 +258,18 @@ class PhaseDiagramDoc(BaseModel): property_name = "phase_diagram" + phase_diagram_id: str = Field( + ..., + description="Phase diagram ID consisting of the chemical system and thermo type", + ) + chemsys: str = Field( + ..., description="Dash-delimited string of elements in the material", + ) + + thermo_type: Union[ThermoType, RunType] = Field( ..., - title="Chemical System", - description="Dash-delimited string of elements in the material", + description="Functional types of calculations involved in the energy mixing scheme.", ) phase_diagram: PhaseDiagram = Field( diff --git a/tests/emmet-core/test_thermo.py b/tests/emmet-core/test_thermo.py index aac230a823..859f558b65 100644 --- a/tests/emmet-core/test_thermo.py +++ b/tests/emmet-core/test_thermo.py @@ -130,7 +130,7 @@ def entries( def test_from_entries(entries): - docs, pd = ThermoDoc.from_entries(entries, deprecated=False) + docs, pd = ThermoDoc.from_entries(entries, thermo_type="UNKNOWN", deprecated=False) assert len(docs) == len(entries)