From 1c9dbd85b20b1dde073543f57cc24fd7c268c1ed Mon Sep 17 00:00:00 2001 From: Shyam D Date: Thu, 22 Oct 2020 09:28:58 -0700 Subject: [PATCH] update materialsbuilder to work with new docs --- .../emmet/builders/vasp/materials.py | 87 ++++++++----------- 1 file changed, 36 insertions(+), 51 deletions(-) diff --git a/emmet-builders/emmet/builders/vasp/materials.py b/emmet-builders/emmet/builders/vasp/materials.py index dad29fce6b..d61cdc1d74 100644 --- a/emmet-builders/emmet/builders/vasp/materials.py +++ b/emmet-builders/emmet/builders/vasp/materials.py @@ -9,13 +9,16 @@ from pymatgen.analysis.structure_analyzer import oxide_type from pymatgen.analysis.structure_matcher import ElementComparator, StructureMatcher -from emmet.builders import SETTINGS +from emmet.core import SETTINGS from emmet.builders.utils import maximal_spanning_non_intersecting_subsets from emmet.core.utils import group_structures, jsanitize -from emmet.core.vasp.calc_types import CalcType, TaskType, run_type, task_type -from emmet.core.vasp.material import MaterialsDoc, PropertyOrigin + +from emmet.core.vasp.material import MaterialsDoc from emmet.stubs import ComputedEntry +from emmet.core.vasp.calc_types import TaskType +from emmet.core.vasp.task import TaskDocument + __author__ = "Shyam Dwaraknath " @@ -45,7 +48,6 @@ def __init__( task_validation: Optional[Store] = None, query: Optional[Dict] = None, allowed_task_types: Optional[List[str]] = None, - tags_to_sandboxes: Optional[Dict[str, List[str]]] = None, symprec: float = SETTINGS.SYMPREC, ltol: float = SETTINGS.LTOL, stol: float = SETTINGS.STOL, @@ -58,7 +60,6 @@ def __init__( materials: Store of materials documents to generate query: dictionary to limit tasks to be analyzed allowed_task_types: list of task_types that can be processed - tags_to_sandboxes: dictionary mapping sandboxes to a list of tags symprec: tolerance for SPGLib spacegroup finding ltol: StructureMatcher tuning parameter for matching tasks to materials stol: StructureMatcher tuning parameter for matching tasks to materials @@ -68,10 +69,11 @@ def __init__( self.tasks = tasks self.materials = materials self.task_validation = task_validation - self.allowed_task_types = {TaskType(t) for t in allowed_task_types} or set( - TaskType - ) - self.tags_to_sandboxes = tags_to_sandboxes or SETTINGS.tags_to_sandboxes + if allowed_task_types is None: + self.allowed_task_types = set(TaskType) + else: + self.allowed_task_types = {TaskType(t) for t in allowed_task_types} + self.query = query if query else {} self.symprec = symprec self.ltol = ltol @@ -90,20 +92,20 @@ def ensure_indexes(self): """ # Basic search index for tasks - self.tasks.ensure_index(self.tasks.key) - self.tasks.ensure_index(self.tasks.last_updated_field) + self.tasks.ensure_index("task_id") + self.tasks.ensure_index("last_updated") self.tasks.ensure_index("state") self.tasks.ensure_index("formula_pretty") # Search index for materials - self.materials.ensure_index(self.materials.key) - self.materials.ensure_index(self.materials.last_updated_field) + self.materials.ensure_index("material_id") + self.materials.ensure_index("last_updated") self.materials.ensure_index("sandboxes") self.materials.ensure_index("task_ids") if self.task_validation: - self.task_validation.ensure_index(self.task_validation.key) - self.task_validation.ensure_index("is_valid") + self.task_validation.ensure_index("task_id") + self.task_validation.ensure_index("valid") def get_items(self) -> Iterator[List[Dict]]: """ @@ -116,7 +118,9 @@ def get_items(self) -> Iterator[List[Dict]]: """ self.logger.info("Materials builder started") - self.logger.info(f"Allowed task types: {self.allowed_task_types}") + self.logger.info( + f"Allowed task types: {[task_type.value for task_type in self.allowed_task_types]}" + ) self.logger.info("Setting indexes") self.ensure_indexes() @@ -159,23 +163,18 @@ def get_items(self) -> Iterator[List[Dict]]: invalid_ids = set() projected_fields = [ - self.tasks.last_updated_field, - self.tasks.key, + "last_updated", + "completed_at", + "task_id", "formula_pretty", "output.energy_per_atom", "output.structure", - "output.parameters", + "input.parameters", "orig_inputs", "input.structure", "tags", ] - sandboxed_tags = { - sandbox - for sandbox in self.tags_to_sandboxes.values() - if self.tags_to_sandboxes is not None - } - for formula in to_process_forms: tasks_query = dict(temp_query) tasks_query["formula_pretty"] = formula @@ -188,18 +187,6 @@ def get_items(self) -> Iterator[List[Dict]]: else: t["is_valid"] = True - if any(tag in sandboxed_tags for tag in t.get("tags", [])): - t["sandboxes"] = [ - sandbox - for sandbox in self.tags_to_sandboxes - if any( - tag in t["tags"] - for tag in set(self.tags_to_sandboxes[sandbox]) - ) - ] - else: - t["sandboxes"] = ["core"] - yield tasks def process_item(self, tasks: List[Dict]) -> List[Dict]: @@ -213,8 +200,9 @@ def process_item(self, tasks: List[Dict]) -> List[Dict]: ([dict],list) : a list of new materials docs and a list of task_ids that were processsed """ - formula = tasks[0]["formula_pretty"] - task_ids = [task[self.tasks.key] for task in tasks] + tasks = [TaskDocument(**task) for task in tasks] + formula = tasks[0].formula_pretty + task_ids = [task.task_id for task in tasks] self.logger.debug(f"Processing {formula} : {task_ids}") materials = [] @@ -223,7 +211,7 @@ def process_item(self, tasks: List[Dict]) -> List[Dict]: materials = [MaterialsDoc.from_tasks(group) for group in grouped_tasks] self.logger.debug(f"Produced {len(materials)} materials for {formula}") - return [mat.dict() for mat in materials if mat is not None] + return [mat.dict() for mat in materials] def update_targets(self, items: List[List[Dict]]): """ @@ -239,19 +227,19 @@ def update_targets(self, items: List[List[Dict]]): for item in items: item.update({"_bt": self.timestamp}) - material_ids = {item[self.materials.key] for item in items} + material_ids = list({item["material_id"] for item in items}) if len(items) > 0: self.logger.info(f"Updating {len(items)} materials") self.materials.remove_docs({self.materials.key: {"$in": material_ids}}) self.materials.update( docs=jsanitize(items, allow_bson=True), - key=(self.materials.key, "sandboxes"), + key=["material_id", "sandboxes"], ) else: self.logger.info("No items to update") - def filter_and_group_tasks(self, tasks: List[Dict]) -> Iterator[List[Dict]]: + def filter_and_group_tasks(self, tasks: List[TaskDocument]) -> Iterator[List[Dict]]: """ Groups tasks by structure matching """ @@ -260,15 +248,15 @@ def filter_and_group_tasks(self, tasks: List[Dict]) -> Iterator[List[Dict]]: task for task in tasks if any( - allowed_type in task_type(task.get("orig_inputs", {})) + allowed_type is task.task_type for allowed_type in self.allowed_task_types ) ] structures = [] - for idx, t in enumerate(filtered_tasks): - s = Structure.from_dict(t["output"]["structure"]) + for idx, task in enumerate(filtered_tasks): + s = task.output.structure s.index = idx structures.append(s) @@ -282,14 +270,11 @@ def filter_and_group_tasks(self, tasks: List[Dict]) -> Iterator[List[Dict]]: for group in grouped_structures: grouped_tasks = [filtered_tasks[struc.index] for struc in group] - sandboxes = [ - task["sandboxes"] for task in grouped_tasks if "sandboxes" in task - ] + sandboxes = {frozenset(task.sandboxes) for task in grouped_tasks} for sbx_set in maximal_spanning_non_intersecting_subsets(sandboxes): yield [ task for task in grouped_tasks - if len(set(task.get("sandboxes", ["core"])).intersection(sbx_set)) - > 0 + if len(set(task.sandboxes).intersection(sbx_set)) > 0 ]