Skip to content

Commit

Permalink
Allow ThermoBuilder to build with multiple compatibility schemes (#486
Browse files Browse the repository at this point in the history
)

* Support multiple correction schemes in thermo

* Allow multiple compatability schemes in builder

* Fix emme-core tests

* Updated version file

* Updated version file

* Updated version file

* ThermoType now a ValueEnum

* Updated CHANGELOG.md

* Ensure correct store keys in thermo builder

* Linting

* Fix mypy

* More mypy fixes

* Add `ensure_indexes`

* Support multiple correction schemes in thermo

* Allow multiple compatability schemes in builder

* Fix emme-core tests

* ThermoType now a ValueEnum

* Ensure correct store keys in thermo builder

* Linting

* Fix mypy

* More mypy fixes

* Flake8 linting

Co-authored-by: materialsproject <feedback@materialsproject.org>
Co-authored-by: Matthew Horton <mkhorton@users.noreply.github.com>
  • Loading branch information
3 people committed Jul 28, 2022
1 parent ba9b447 commit 3e2ea81
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 76 deletions.
2 changes: 1 addition & 1 deletion emmet-builders/emmet/builders/materials/alloys.py
Expand Up @@ -56,7 +56,7 @@ def __init__(
targets=[alloy_pairs],
chunk_size=8,
)

def ensure_indexes(self):

self.alloy_pairs.ensure_index("pair_id")
Expand Down
183 changes: 112 additions & 71 deletions emmet-builders/emmet/builders/vasp/thermo.py
@@ -1,18 +1,18 @@
import warnings
from collections import defaultdict
from itertools import chain
from typing import Dict, Iterable, Iterator, List, Optional, Set
from typing import Dict, Iterable, Iterator, List, Optional, Set, Union
from math import ceil

from maggma.core import Builder, Store
from maggma.utils import grouper
from monty.json import MontyDecoder
from pymatgen.analysis.phase_diagram import PhaseDiagramError
from pymatgen.entries.compatibility import MaterialsProject2020Compatibility
from pymatgen.entries.computed_entries import ComputedStructureEntry
from pymatgen.entries.compatibility import Compatibility

from emmet.builders.utils import chemsys_permutations
from emmet.core.thermo import ThermoDoc, PhaseDiagramDoc
from emmet.core.thermo import ThermoDoc, PhaseDiagramDoc, ThermoType
from emmet.core.utils import jsanitize


Expand All @@ -24,7 +24,7 @@ def __init__(
phase_diagram: Optional[Store] = None,
oxidation_states: Optional[Store] = None,
query: Optional[Dict] = None,
compatibility=None,
compatibility: Optional[List[Compatibility]] = None,
num_phase_diagram_eles: Optional[int] = None,
**kwargs,
):
Expand All @@ -39,7 +39,7 @@ def __init__(
phase_diagram (Store): Store of phase diagram data for each unique chemical system
oxidation_states (Store): Store of oxidation state data to use in correction scheme application
query (dict): dictionary to limit materials to be analyzed
compatibility (PymatgenCompatability): Compatability module
compatibility ([Compatability]): Compatability module
to ensure energies are compatible
num_phase_diagram_eles (int): Maximum number of elements to use in phase diagram construction
for data within the separate phase_diagram collection
Expand All @@ -48,20 +48,40 @@ def __init__(
self.materials = materials
self.thermo = thermo
self.query = query if query else {}
self.compatibility = compatibility
self.compatibility = compatibility or [None]
self.oxidation_states = oxidation_states
self.phase_diagram = phase_diagram
self.num_phase_diagram_eles = num_phase_diagram_eles
self._completed_tasks: Set[str] = set()
self._entries_cache: Dict[str, List[ComputedStructureEntry]] = defaultdict(list)

if self.thermo.key != "thermo_id":
warnings.warn(f"Key for the thermo store is incorrect and has been changed from {self.thermo.key} to thermo_id!")
self.thermo.key = "thermo_id"

if self.materials.key != "material_id":
warnings.warn(f"Key for the materials store is incorrect and has been changed from {self.materials.key} to material_id!")
self.materials.key = "material_id"

sources = [materials]
if oxidation_states is not None:
sources.append(oxidation_states)

if self.oxidation_states is not None:

if self.oxidation_states.key != "material_id":
warnings.warn(f"Key for the oxidation states store is incorrect and has been changed from {self.oxidation_states.key} to material_id!")
self.oxidation_states.key = "material_id"

sources.append(oxidation_states) # type: ignore

targets = [thermo]
if phase_diagram is not None:
targets.append(phase_diagram)

if self.phase_diagram is not None:

if self.phase_diagram.key != "phase_diagram_id":
warnings.warn(f"Key for the phase diagram store is incorrect and has been changed from {self.thphase_diagramermo.key} to phase_diagram_id!")
self.phase_diagram.key = "phase_diagram_id"

targets.append(phase_diagram) # type: ignore

super().__init__(sources=sources, targets=targets, **kwargs)

Expand All @@ -77,11 +97,14 @@ def ensure_indexes(self):

# Search index for thermo
self.thermo.ensure_index("material_id")
self.thermo.ensure_index("thermo_id")
self.thermo.ensure_index("thermo_type")
self.thermo.ensure_index("last_updated")

# Search index for phase_diagram
if self.phase_diagram:
self.phase_diagram.ensure_index("chemsys")
self.phase_diagram.ensure_index("phase_diagram_id")

def prechunk(self, number_splits: int) -> Iterable[Dict]: # pragma: no cover
updated_chemsys = self.get_updated_chemsys()
Expand Down Expand Up @@ -155,61 +178,80 @@ def process_item(self, item: List[Dict]):

self.logger.debug(f"Processing {len(entries)} entries for {chemsys}")

material_entries: Dict[str, Dict[str, ComputedStructureEntry]] = defaultdict(
dict
)
pd_entries = []
for entry in entries:
material_entries[entry.data["material_id"]][entry.data["run_type"]] = entry

if self.compatibility:
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", message="Failed to guess oxidation states.*"
)
pd_entries = self.compatibility.process_entries(entries)
else:
pd_entries = entries
self.logger.debug(f"{len(pd_entries)} remain in {chemsys} after filtering")
docs_pd_pair_list = []

try:
docs, pd = ThermoDoc.from_entries(pd_entries, deprecated=False)
for compatability in self.compatibility:

# for doc in docs:
# doc.entries = material_entries[doc.material_id]
# doc.entry_types = list(material_entries[doc.material_id].keys())
pd_entries = []

pd_data = None
if compatability:
if compatability.name == "MP DFT mixing scheme":
thermo_type = ThermoType.GGA_GGA_U_R2SCAN
elif compatability.name == "MP2020":
thermo_type = ThermoType.GGA_GGA_U
else:
thermo_type = ThermoType.UNKNOWN

if self.phase_diagram:
if (
self.num_phase_diagram_eles is None
or len(elements) <= self.num_phase_diagram_eles
):
pd_doc = PhaseDiagramDoc(chemsys=chemsys, phase_diagram=pd)
pd_data = jsanitize(pd_doc.dict(), allow_bson=True)
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore", message="Failed to guess oxidation states.*"
)
pd_entries = compatability.process_entries(entries)
else:
all_entry_types = {e.data["run_type"] for e in entries}
if len(all_entry_types) > 1:
raise ValueError(
"More than one functional type has been provided without a mixing scheme!"
)
else:
thermo_type = all_entry_types.pop()
pd_entries = entries
self.logger.debug(f"{len(pd_entries)} remain in {chemsys} after filtering")

try:
docs, pd = ThermoDoc.from_entries(
pd_entries, thermo_type, deprecated=False
)

docs_pd_pair = (
jsanitize([d.dict() for d in docs], allow_bson=True),
[pd_data],
)
pd_data = None

if self.phase_diagram:
if (
self.num_phase_diagram_eles is None
or len(elements) <= self.num_phase_diagram_eles
):
pd_id = "{}_{}".format(chemsys, str(thermo_type))
pd_doc = PhaseDiagramDoc(
phase_diagram_id=pd_id,
chemsys=chemsys,
phase_diagram=pd,
thermo_type=thermo_type,
)
pd_data = jsanitize(pd_doc.dict(), allow_bson=True)

docs_pd_pair = (
jsanitize([d.dict() for d in docs], allow_bson=True),
[pd_data],
)

except PhaseDiagramError as p:
elsyms = []
for e in entries:
elsyms.extend([el.symbol for el in e.composition.elements])
docs_pd_pair_list.append(docs_pd_pair)

self.logger.warning(
f"Phase diagram error in chemsys {'-'.join(sorted(set(elsyms)))}: {p}"
)
return []
except Exception as e:
self.logger.error(
f"Got unexpected error while processing {[ent_.entry_id for ent_ in entries]}: {e}"
)
return []
except PhaseDiagramError as p:
elsyms = []
for e in entries:
elsyms.extend([el.symbol for el in e.composition.elements])

self.logger.warning(
f"Phase diagram error in chemsys {'-'.join(sorted(set(elsyms)))}: {p}"
)
return []
except Exception as e:
self.logger.error(
f"Got unexpected error while processing {[ent_.entry_id for ent_ in entries]}: {e}"
)
return []

return docs_pd_pair
return docs_pd_pair_list

def update_targets(self, items):
"""
Expand All @@ -218,19 +260,19 @@ def update_targets(self, items):
items ([[tuple(List[dict],List[dict])]]): a list of list of thermo dictionaries to update
"""

thermo_docs = [item[0] for item in items]
phase_diagram_docs = [item[1] for item in items]
thermo_docs = [item[0] for pair_list in items for item in pair_list]
phase_diagram_docs = [item[1] for pair_list in items for item in pair_list]

# flatten out lists
thermo_docs = list(filter(None, chain.from_iterable(thermo_docs)))
phase_diagram_docs = list(filter(None, chain.from_iterable(phase_diagram_docs)))

# Check if already updated this run
thermo_docs = [
i for i in thermo_docs if i["material_id"] not in self._completed_tasks
i for i in thermo_docs if i["thermo_id"] not in self._completed_tasks
]

self._completed_tasks |= {i["material_id"] for i in thermo_docs}
self._completed_tasks |= {i["thermo_id"] for i in thermo_docs}

for item in thermo_docs:
if isinstance(item["last_updated"], dict):
Expand All @@ -243,7 +285,7 @@ def update_targets(self, items):

if len(thermo_docs) > 0:
self.logger.info(f"Updating {len(thermo_docs)} thermo documents")
self.thermo.update(docs=thermo_docs, key=["material_id"])
self.thermo.update(docs=thermo_docs, key=["thermo_id"])
else:
self.logger.info("No thermo items to update")

Expand Down Expand Up @@ -302,17 +344,16 @@ def get_entries(self, chemsys: str) -> List[Dict]:
f"Got {len(materials_docs)} entries from DB for {len(query_chemsys)} sub-chemsys for {chemsys}"
)

# Convert GGA, GGA+U, R2SCAN entries into ComputedEntries and store
# Convert entries into ComputedEntries and store
for doc in materials_docs:
for r_type, entry_dict in doc.get("entries", {}).items():
if r_type in ["GGA", "GGA+U", "R2SCAN"]:
entry_dict["data"]["oxidation_states"] = oxi_states_data.get(
entry_dict["data"]["material_id"], {}
)
entry_dict["data"]["run_type"] = r_type
elsyms = sorted(set([el for el in entry_dict["composition"]]))
self._entries_cache["-".join(elsyms)].append(entry_dict)
all_entries.append(entry_dict)
entry_dict["data"]["oxidation_states"] = oxi_states_data.get(
entry_dict["data"]["material_id"], {}
)
entry_dict["data"]["run_type"] = r_type
elsyms = sorted(set([el for el in entry_dict["composition"]]))
self._entries_cache["-".join(elsyms)].append(entry_dict)
all_entries.append(entry_dict)

self.logger.info(f"Total entries in {chemsys} : {len(all_entries)}")

Expand Down
37 changes: 34 additions & 3 deletions emmet-core/emmet/core/thermo.py
Expand Up @@ -2,6 +2,7 @@
from collections import defaultdict
from typing import Dict, List, Union
from datetime import datetime
from emmet.core.utils import ValueEnum

from pydantic import BaseModel, Field
from pymatgen.analysis.phase_diagram import PhaseDiagram
Expand All @@ -10,6 +11,7 @@
from emmet.core.material_property import PropertyDoc
from emmet.core.material import PropertyOrigin
from emmet.core.mpid import MPID
from emmet.core.vasp.calc_types.enums import RunType


class DecompositionProduct(BaseModel):
Expand All @@ -31,13 +33,29 @@ class DecompositionProduct(BaseModel):
)


class ThermoType(ValueEnum):
GGA_GGA_U = "GGA/GGA+U"
GGA_GGA_U_R2SCAN = "GGA/GGA+U/R2SCAN"
UNKNOWN = "UNKNOWN"


class ThermoDoc(PropertyDoc):
"""
A thermo entry document
"""

property_name = "thermo"

thermo_type: Union[ThermoType, RunType] = Field(
...,
description="Functional types of calculations involved in the energy mixing scheme.",
)

thermo_id: str = Field(
...,
description="Unique document ID which is composed of the Material ID and thermo data type.",
)

uncorrected_energy_per_atom: float = Field(
..., description="The total DFT energy of this material per atom in eV/atom."
)
Expand Down Expand Up @@ -99,7 +117,10 @@ class ThermoDoc(PropertyDoc):

@classmethod
def from_entries(
cls, entries: List[Union[ComputedEntry, ComputedStructureEntry]], **kwargs
cls,
entries: List[Union[ComputedEntry, ComputedStructureEntry]],
thermo_type: Union[ThermoType, RunType],
**kwargs
):

entries_by_comp = defaultdict(list)
Expand Down Expand Up @@ -145,7 +166,9 @@ def _energy_eval(entry: ComputedStructureEntry):
(decomp, ehull) = pd.get_decomp_and_e_above_hull(blessed_entry)

d = {
"thermo_id": "{}_{}".format(material_id, str(thermo_type)),
"material_id": material_id,
"thermo_type": thermo_type,
"uncorrected_energy_per_atom": blessed_entry.uncorrected_energy
/ blessed_entry.composition.num_atoms,
"energy_per_atom": blessed_entry.energy
Expand Down Expand Up @@ -235,10 +258,18 @@ class PhaseDiagramDoc(BaseModel):

property_name = "phase_diagram"

phase_diagram_id: str = Field(
...,
description="Phase diagram ID consisting of the chemical system and thermo type",
)

chemsys: str = Field(
..., description="Dash-delimited string of elements in the material",
)

thermo_type: Union[ThermoType, RunType] = Field(
...,
title="Chemical System",
description="Dash-delimited string of elements in the material",
description="Functional types of calculations involved in the energy mixing scheme.",
)

phase_diagram: PhaseDiagram = Field(
Expand Down
2 changes: 1 addition & 1 deletion tests/emmet-core/test_thermo.py
Expand Up @@ -130,7 +130,7 @@ def entries(


def test_from_entries(entries):
docs, pd = ThermoDoc.from_entries(entries, deprecated=False)
docs, pd = ThermoDoc.from_entries(entries, thermo_type="UNKNOWN", deprecated=False)

assert len(docs) == len(entries)

Expand Down

0 comments on commit 3e2ea81

Please sign in to comment.