update materialsbuilder to work with new docs
shyamd committed Oct 22, 2020
1 parent bdb28a5 commit 1c9dbd8
Showing 1 changed file with 36 additions and 51 deletions.
87 changes: 36 additions & 51 deletions emmet-builders/emmet/builders/vasp/materials.py
@@ -9,13 +9,16 @@
from pymatgen.analysis.structure_analyzer import oxide_type
from pymatgen.analysis.structure_matcher import ElementComparator, StructureMatcher

from emmet.builders import SETTINGS
from emmet.core import SETTINGS
from emmet.builders.utils import maximal_spanning_non_intersecting_subsets
from emmet.core.utils import group_structures, jsanitize
from emmet.core.vasp.calc_types import CalcType, TaskType, run_type, task_type
from emmet.core.vasp.material import MaterialsDoc, PropertyOrigin

from emmet.core.vasp.material import MaterialsDoc
from emmet.stubs import ComputedEntry

from emmet.core.vasp.calc_types import TaskType
from emmet.core.vasp.task import TaskDocument

__author__ = "Shyam Dwaraknath <shyamd@lbl.gov>"


@@ -45,7 +48,6 @@ def __init__(
task_validation: Optional[Store] = None,
query: Optional[Dict] = None,
allowed_task_types: Optional[List[str]] = None,
tags_to_sandboxes: Optional[Dict[str, List[str]]] = None,
symprec: float = SETTINGS.SYMPREC,
ltol: float = SETTINGS.LTOL,
stol: float = SETTINGS.STOL,
@@ -58,7 +60,6 @@ def __init__(
materials: Store of materials documents to generate
query: dictionary to limit tasks to be analyzed
allowed_task_types: list of task_types that can be processed
tags_to_sandboxes: dictionary mapping sandboxes to a list of tags
symprec: tolerance for SPGLib spacegroup finding
ltol: StructureMatcher tuning parameter for matching tasks to materials
stol: StructureMatcher tuning parameter for matching tasks to materials
@@ -68,10 +69,11 @@ def __init__(
self.tasks = tasks
self.materials = materials
self.task_validation = task_validation
self.allowed_task_types = {TaskType(t) for t in allowed_task_types} or set(
TaskType
)
self.tags_to_sandboxes = tags_to_sandboxes or SETTINGS.tags_to_sandboxes
if allowed_task_types is None:
self.allowed_task_types = set(TaskType)
else:
self.allowed_task_types = {TaskType(t) for t in allowed_task_types}

self.query = query if query else {}
self.symprec = symprec
self.ltol = ltol
@@ -90,20 +92,20 @@ def ensure_indexes(self):
"""

# Basic search index for tasks
self.tasks.ensure_index(self.tasks.key)
self.tasks.ensure_index(self.tasks.last_updated_field)
self.tasks.ensure_index("task_id")
self.tasks.ensure_index("last_updated")
self.tasks.ensure_index("state")
self.tasks.ensure_index("formula_pretty")

# Search index for materials
self.materials.ensure_index(self.materials.key)
self.materials.ensure_index(self.materials.last_updated_field)
self.materials.ensure_index("material_id")
self.materials.ensure_index("last_updated")
self.materials.ensure_index("sandboxes")
self.materials.ensure_index("task_ids")

if self.task_validation:
self.task_validation.ensure_index(self.task_validation.key)
self.task_validation.ensure_index("is_valid")
self.task_validation.ensure_index("task_id")
self.task_validation.ensure_index("valid")

def get_items(self) -> Iterator[List[Dict]]:
"""
@@ -116,7 +118,9 @@ def get_items(self) -> Iterator[List[Dict]]:
"""

self.logger.info("Materials builder started")
self.logger.info(f"Allowed task types: {self.allowed_task_types}")
self.logger.info(
f"Allowed task types: {[task_type.value for task_type in self.allowed_task_types]}"
)

self.logger.info("Setting indexes")
self.ensure_indexes()
@@ -159,23 +163,18 @@ def get_items(self) -> Iterator[List[Dict]]:
invalid_ids = set()

projected_fields = [
self.tasks.last_updated_field,
self.tasks.key,
"last_updated",
"completed_at",
"task_id",
"formula_pretty",
"output.energy_per_atom",
"output.structure",
"output.parameters",
"input.parameters",
"orig_inputs",
"input.structure",
"tags",
]

sandboxed_tags = {
sandbox
for sandbox in self.tags_to_sandboxes.values()
if self.tags_to_sandboxes is not None
}

for formula in to_process_forms:
tasks_query = dict(temp_query)
tasks_query["formula_pretty"] = formula
@@ -188,18 +187,6 @@ def get_items(self) -> Iterator[List[Dict]]:
else:
t["is_valid"] = True

if any(tag in sandboxed_tags for tag in t.get("tags", [])):
t["sandboxes"] = [
sandbox
for sandbox in self.tags_to_sandboxes
if any(
tag in t["tags"]
for tag in set(self.tags_to_sandboxes[sandbox])
)
]
else:
t["sandboxes"] = ["core"]

yield tasks

def process_item(self, tasks: List[Dict]) -> List[Dict]:
@@ -213,8 +200,9 @@ def process_item(self, tasks: List[Dict]) -> List[Dict]:
([dict],list) : a list of new materials docs and a list of task_ids that were processed
"""

formula = tasks[0]["formula_pretty"]
task_ids = [task[self.tasks.key] for task in tasks]
tasks = [TaskDocument(**task) for task in tasks]
formula = tasks[0].formula_pretty
task_ids = [task.task_id for task in tasks]
self.logger.debug(f"Processing {formula} : {task_ids}")

materials = []
@@ -223,7 +211,7 @@ def process_item(self, tasks: List[Dict]) -> List[Dict]:
materials = [MaterialsDoc.from_tasks(group) for group in grouped_tasks]
self.logger.debug(f"Produced {len(materials)} materials for {formula}")

return [mat.dict() for mat in materials if mat is not None]
return [mat.dict() for mat in materials]

def update_targets(self, items: List[List[Dict]]):
"""
@@ -239,19 +227,19 @@ def update_targets(self, items: List[List[Dict]]):
for item in items:
item.update({"_bt": self.timestamp})

material_ids = {item[self.materials.key] for item in items}
material_ids = list({item["material_id"] for item in items})

if len(items) > 0:
self.logger.info(f"Updating {len(items)} materials")
self.materials.remove_docs({self.materials.key: {"$in": material_ids}})
self.materials.update(
docs=jsanitize(items, allow_bson=True),
key=(self.materials.key, "sandboxes"),
key=["material_id", "sandboxes"],
)
else:
self.logger.info("No items to update")

def filter_and_group_tasks(self, tasks: List[Dict]) -> Iterator[List[Dict]]:
def filter_and_group_tasks(self, tasks: List[TaskDocument]) -> Iterator[List[Dict]]:
"""
Groups tasks by structure matching
"""
@@ -260,15 +248,15 @@ def filter_and_group_tasks(self, tasks: List[Dict]) -> Iterator[List[Dict]]:
task
for task in tasks
if any(
allowed_type in task_type(task.get("orig_inputs", {}))
allowed_type is task.task_type
for allowed_type in self.allowed_task_types
)
]

structures = []

for idx, t in enumerate(filtered_tasks):
s = Structure.from_dict(t["output"]["structure"])
for idx, task in enumerate(filtered_tasks):
s = task.output.structure
s.index = idx
structures.append(s)

@@ -282,14 +270,11 @@ def filter_and_group_tasks(self, tasks: List[Dict]) -> Iterator[List[Dict]]:

for group in grouped_structures:
grouped_tasks = [filtered_tasks[struc.index] for struc in group]
sandboxes = [
task["sandboxes"] for task in grouped_tasks if "sandboxes" in task
]
sandboxes = {frozenset(task.sandboxes) for task in grouped_tasks}

for sbx_set in maximal_spanning_non_intersecting_subsets(sandboxes):
yield [
task
for task in grouped_tasks
if len(set(task.get("sandboxes", ["core"])).intersection(sbx_set))
> 0
if len(set(task.sandboxes).intersection(sbx_set)) > 0
]
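
For context, a minimal sketch of how the updated builder might be driven after this change. This is not part of the commit; the MaterialsBuilder class name, the in-memory stores, and the task-type strings are assumptions based on the imports and parameters visible in the diff.

from maggma.stores import MemoryStore

from emmet.builders.vasp.materials import MaterialsBuilder

# In-memory stores standing in for real task and materials collections (assumed setup).
tasks = MemoryStore(key="task_id")
materials = MemoryStore(key="material_id")

# With allowed_task_types=None the builder now defaults to every TaskType;
# passing value strings restricts processing, e.g. to structure optimizations
# and statics (strings assumed to match the TaskType enum values).
builder = MaterialsBuilder(
    tasks=tasks,
    materials=materials,
    allowed_task_types=["Structure Optimization", "Static"],
)

# maggma's Builder.run() connects the stores and chains
# get_items -> process_item -> update_targets.
builder.run()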
