Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
working structure grouping document and tests
- Loading branch information
Showing
3 changed files
with
287 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,222 @@ | ||
import logging | ||
import operator | ||
from datetime import datetime | ||
from itertools import groupby | ||
from typing import List | ||
|
||
from monty.json import MontyDecoder | ||
from pydantic import BaseModel, Field, validator | ||
from pymatgen.analysis.structure_matcher import ElementComparator, StructureMatcher | ||
from pymatgen.entries.computed_entries import ComputedStructureEntry | ||
|
||
from emmet.stubs import Composition, Structure | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def generic_groupby(list_in, comp=operator.eq): | ||
""" | ||
Group a list of unsortable objects | ||
Args: | ||
list_in: A list of generic objects | ||
comp: (Default value = operator.eq) The comparator | ||
Returns: | ||
[int] list of labels for the input list | ||
""" | ||
list_out = [None] * len(list_in) | ||
label_num = 0 | ||
for i1, ls1 in enumerate(list_out): | ||
if ls1 is not None: | ||
continue | ||
list_out[i1] = label_num | ||
for i2, ls2 in list(enumerate(list_out))[i1 + 1 :]: | ||
if comp(list_in[i1], list_in[i2]): | ||
if list_out[i2] is None: | ||
list_out[i2] = list_out[i1] | ||
else: | ||
list_out[i1] = list_out[i2] | ||
label_num -= 1 | ||
label_num += 1 | ||
return list_out | ||
|
||
|
||
def s_hash(el): | ||
return el.data["comp_delith"] | ||
|
||
|
||
class StructureGroupDoc(BaseModel): | ||
""" | ||
Group of structure | ||
""" | ||
|
||
task_id: str = Field( | ||
None, | ||
description="The id of the group is represented by the lowest numerical valued id amoung this group.", | ||
) | ||
|
||
grouped_ids: list = Field( | ||
None, | ||
description="A list of materials ids for all of the materials that were grouped together.", | ||
) | ||
|
||
framework: str = Field( | ||
None, | ||
description="The chemical formula for the framwork (the materials system without working ion).", | ||
) | ||
|
||
ignored_species: list = Field(None, description="List of ignored atomic species.") | ||
|
||
chemsys: str = Field( | ||
None, | ||
description="The chemical system this group belongs to, if the atoms for the ignored species is present the chemsys will also include the ignored species.", | ||
) | ||
|
||
last_updated: datetime = Field( | ||
None, | ||
description="Timestamp when this document was built.", | ||
) | ||
|
||
# Make sure that the datetime field is properly formatted | ||
@validator("last_updated", pre=True) | ||
def last_updated_dict_ok(cls, v): | ||
return MontyDecoder().process_decoded(v) | ||
|
||
@classmethod | ||
def from_grouped_entries( | ||
cls, entries: List[ComputedStructureEntry], ignored_species: List[str] | ||
) -> "StructureGroupDoc": | ||
""" " | ||
Assuming a list of entries are already grouped together, create a StructureGroupDoc | ||
Args: | ||
entries: A list of entries that is already grouped together. | ||
""" | ||
all_atoms = set() | ||
for ient in entries: | ||
all_atoms |= set(ient.composition.as_dict().keys()) | ||
|
||
common_atoms = all_atoms - set(ignored_species) | ||
if len(common_atoms) == 0: | ||
framework_str = "ignored" | ||
else: | ||
comp_d = {k: entries[0].composition.as_dict()[k] for k in common_atoms} | ||
framework_comp = Composition.from_dict(comp_d) | ||
framework_str = framework_comp.reduced_formula | ||
ids = [ient.entry_id for ient in entries] | ||
lowest_id = min(ids, key=_get_id_num) | ||
|
||
fields = { | ||
"task_id": lowest_id, | ||
"grouped_ids": ids, | ||
"framework": framework_str, | ||
"ignored_species": sorted(ignored_species), | ||
"chemsys": "-".join(sorted(all_atoms)), | ||
} | ||
|
||
return cls(**fields) | ||
|
||
@classmethod | ||
def from_ungrouped_structure_entries( | ||
cls, | ||
entries: List[ComputedStructureEntry], | ||
ignored_species: List[str], | ||
ltol: float = 0.2, | ||
stol: float = 0.3, | ||
angle_tol: float = 5.0, | ||
) -> List["StructureGroupDoc"]: | ||
""" | ||
Create a list of StructureGroupDocs from a list of ungrouped entries. | ||
Args: | ||
entries: The list of ComputedStructureEntries to process. | ||
ignored_species: the list of ignored species for the structure matcher | ||
ltol: length tolerance for the structure matcher | ||
stol: site position tolerance for the structure matcher | ||
angle_tol: angel tolerance for the structure matcher | ||
""" | ||
|
||
results = [] | ||
sm = StructureMatcher( | ||
comparator=ElementComparator(), | ||
primitive_cell=True, | ||
ignored_species=ignored_species, | ||
ltol=ltol, | ||
stol=stol, | ||
angle_tol=angle_tol, | ||
) | ||
|
||
# Add a framework field to each entry's data attribute | ||
for ient in entries: | ||
ient.data["framework"] = _get_framework( | ||
ient.composition.reduced_formula, ignored_species | ||
) | ||
|
||
# split into groups for each framework, must sort before grouping | ||
entries.sort(key=lambda x: x.data["framework"]) | ||
framework_groups = groupby(entries, key=lambda x: x.data["framework"]) | ||
|
||
cnt_ = 0 | ||
for framework, f_group in framework_groups: | ||
# if you only have ignored atoms put them into one "ignored" groupd | ||
f_group_l = list(f_group) | ||
if framework == "ignored": | ||
struct_group = cls.from_grouped_entries( | ||
f_group_l, ignored_species=ignored_species | ||
) | ||
cnt_ += len(struct_group.grouped_ids) | ||
continue | ||
|
||
logger.debug( | ||
f"Performing structure matching for {framework} with {len(f_group_l)} documents." | ||
) | ||
for g in group_entries_with_structure_matcher(f_group_l, sm): | ||
struct_group = cls.from_grouped_entries( | ||
g, ignored_species=ignored_species | ||
) | ||
cnt_ += len(struct_group.grouped_ids) | ||
results.append(struct_group) | ||
if cnt_ != len(entries): | ||
raise RuntimeError( | ||
"The number of entries in all groups the end does not match the number of supplied entries documents." | ||
"Something is seriously wrong, please rebuild the entire database and see if the problem persists." | ||
) | ||
return results | ||
|
||
|
||
def group_entries_with_structure_matcher(g, struct_matcher): | ||
""" | ||
Group the entries together based on similarity of the primitive cells | ||
Args: | ||
g: a list of entries | ||
Returns: | ||
subgroups: subgroups that are grouped together based on structure similarity | ||
""" | ||
labs = generic_groupby( | ||
g, | ||
comp=lambda x, y: struct_matcher.fit(x.structure, y.structure, symmetric=True), | ||
) | ||
for ilab in set(labs): | ||
sub_g = [g[itr] for itr, jlab in enumerate(labs) if jlab == ilab] | ||
yield [el for el in sub_g] | ||
|
||
|
||
def _get_id_num(task_id): | ||
if isinstance(task_id, int): | ||
return task_id | ||
if isinstance(task_id, str) and "-" in task_id: | ||
return int(task_id.split("-")[-1]) | ||
else: | ||
raise ValueError("TaskID needs to be either a number or of the form xxx-#####") | ||
|
||
|
||
def _get_framework(formula, ignored_species) -> str: | ||
""" | ||
Return the reduced formula of the entry without any of the ignored species | ||
Return 'ignored' if the all the atoms are ignored | ||
""" | ||
dd_ = Composition(formula).as_dict() | ||
if dd_.keys() == set(ignored_species): | ||
return "ignored" | ||
for ignored_sp in ignored_species: | ||
if ignored_sp in dd_: | ||
dd_.pop(ignored_sp) | ||
return Composition.from_dict(dd_).reduced_formula |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import pytest | ||
from monty.serialization import loadfn | ||
from pymatgen import Composition | ||
from pymatgen.apps.battery.conversion_battery import ConversionElectrode | ||
from pymatgen.apps.battery.insertion_battery import InsertionElectrode | ||
from pymatgen.entries.computed_entries import ComputedEntry | ||
|
||
from emmet.core.electrode import ( | ||
ConversionElectrodeDoc, | ||
ConversionVoltagePairDoc, | ||
InsertionElectrodeDoc, | ||
InsertionVoltagePairDoc, | ||
) | ||
from emmet.core.structure_group import StructureGroupDoc | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def entries_lto(test_dir): | ||
""" | ||
Recycle the test cases from pymatgen | ||
""" | ||
entries = loadfn(test_dir / "LiTiO2_batt.json") | ||
for itr, ient in enumerate(entries): | ||
ient.entry_id = f"mp-{itr}" | ||
return entries | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def entries_lfeo(test_dir): | ||
""" | ||
Recycle the test cases from pymatgen | ||
""" | ||
entries = loadfn(test_dir / "Li-Fe-O.json") | ||
return entries | ||
|
||
|
||
def test_StructureGroupDoc_from_grouped_entries(entries_lto): | ||
sgroup_doc = StructureGroupDoc.from_grouped_entries( | ||
entries_lto, ignored_species=["Li"] | ||
) | ||
assert sgroup_doc.task_id == "mp-0" | ||
assert sgroup_doc.grouped_ids == ["mp-0", "mp-1", "mp-2", "mp-3", "mp-4", "mp-5"] | ||
assert sgroup_doc.framework == "TiO2" | ||
assert sgroup_doc.ignored_species == ["Li"] | ||
assert sgroup_doc.chemsys == "Li-O-Ti" | ||
|
||
|
||
def test_StructureGroupDoc_from_ungrouped_entries(entries_lfeo): | ||
entry_dict = {ient.entry_id: ient for ient in entries_lfeo} | ||
sgroup_docs = StructureGroupDoc.from_ungrouped_structure_entries( | ||
entries_lfeo, ignored_species=["Li"] | ||
) | ||
|
||
# Make sure that all the structure in each group has the same framework | ||
for sgroup_doc in sgroup_docs: | ||
framework_ref = sgroup_doc.framework | ||
ignored = sgroup_doc.ignored_species | ||
for entry_id in sgroup_doc.grouped_ids: | ||
dd_ = entry_dict[entry_id].composition.as_dict() | ||
for k in ignored: | ||
if k in dd_: | ||
dd_.pop(k) | ||
framework = Composition.from_dict(dd_).reduced_formula | ||
assert framework == framework_ref |
Large diffs are not rendered by default.
Oops, something went wrong.