Skip to content

Commit

Permalink
Fix transformation portion of materials builder
Browse files Browse the repository at this point in the history
  • Loading branch information
munrojm committed Oct 18, 2023
1 parent c70ea4d commit f17f5f7
Showing 1 changed file with 26 additions and 59 deletions.
85 changes: 26 additions & 59 deletions emmet-builders/emmet/builders/vasp/materials.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,17 +102,11 @@ def prechunk(self, number_splits: int) -> Iterable[Dict]: # pragma: no cover
temp_query["tags"] = {"$in": self.settings.BUILD_TAGS}

self.logger.info("Finding tasks to process")
all_tasks = list(
self.tasks.query(temp_query, [self.tasks.key, "formula_pretty"])
)
all_tasks = list(self.tasks.query(temp_query, [self.tasks.key, "formula_pretty"]))

processed_tasks = set(self.materials.distinct("task_ids"))
to_process_tasks = {d[self.tasks.key] for d in all_tasks} - processed_tasks
to_process_forms = {
d["formula_pretty"]
for d in all_tasks
if d[self.tasks.key] in to_process_tasks
}
to_process_forms = {d["formula_pretty"] for d in all_tasks if d[self.tasks.key] in to_process_tasks}

N = ceil(len(to_process_forms) / number_splits)

Expand Down Expand Up @@ -152,17 +146,11 @@ def get_items(self) -> Iterator[List[Dict]]:
temp_query["tags"] = {"$in": self.settings.BUILD_TAGS}

self.logger.info("Finding tasks to process")
all_tasks = list(
self.tasks.query(temp_query, [self.tasks.key, "formula_pretty"])
)
all_tasks = list(self.tasks.query(temp_query, [self.tasks.key, "formula_pretty"]))

processed_tasks = set(self.materials.distinct("task_ids"))
to_process_tasks = {d[self.tasks.key] for d in all_tasks} - processed_tasks
to_process_forms = {
d["formula_pretty"]
for d in all_tasks
if d[self.tasks.key] in to_process_tasks
}
to_process_forms = {d["formula_pretty"] for d in all_tasks if d[self.tasks.key] in to_process_tasks}

self.logger.info(f"Found {len(to_process_tasks)} unprocessed tasks")
self.logger.info(f"Found {len(to_process_forms)} unprocessed formulas")
Expand All @@ -172,10 +160,7 @@ def get_items(self) -> Iterator[List[Dict]]:

if self.task_validation:
invalid_ids = {
doc[self.tasks.key]
for doc in self.task_validation.query(
{"valid": False}, [self.task_validation.key]
)
doc[self.tasks.key] for doc in self.task_validation.query({"valid": False}, [self.task_validation.key])
}
else:
invalid_ids = set()
Expand All @@ -199,17 +184,15 @@ def get_items(self) -> Iterator[List[Dict]]:
"input.hubbards",
"input.potcar_spec",
# needed for transform deformation structure back for grouping
"transmuter",
"transformations",
# misc info for materials doc
"tags",
]

for formula in to_process_forms:
tasks_query = dict(temp_query)
tasks_query["formula_pretty"] = formula
tasks = list(
self.tasks.query(criteria=tasks_query, properties=projected_fields)
)
tasks = list(self.tasks.query(criteria=tasks_query, properties=projected_fields))
for t in tasks:
t["is_valid"] = t[self.tasks.key] not in invalid_ids

Expand All @@ -231,12 +214,12 @@ def process_item(self, items: List[Dict]) -> List[Dict]:
formula = tasks[0].formula_pretty
task_ids = [task.task_id for task in tasks]

# not all tasks contains transmuter
transmuters = [task.get("transmuter", None) for task in items]
# not all tasks contains transformation information
task_transformations = [task.get("transformations", None) for task in items]

self.logger.debug(f"Processing {formula}: {task_ids}")

grouped_tasks = self.filter_and_group_tasks(tasks, transmuters)
grouped_tasks = self.filter_and_group_tasks(tasks, task_transformations)
materials = []
for group in grouped_tasks:
try:
Expand All @@ -253,8 +236,7 @@ def process_item(self, items: List[Dict]) -> List[Dict]:
doc.warnings.append(str(e))
materials.append(doc)
self.logger.warn(
f"Failed making material for {failed_ids}."
f" Inserted as deprecated Material: {doc.material_id}"
f"Failed making material for {failed_ids}." f" Inserted as deprecated Material: {doc.material_id}"
)

self.logger.debug(f"Produced {len(materials)} materials for {formula}")
Expand Down Expand Up @@ -285,38 +267,29 @@ def update_targets(self, items: List[List[Dict]]):
self.logger.info("No items to update")

def filter_and_group_tasks(
self, tasks: List[TaskDocument], transmuters: List[Union[Dict, None]]
self, tasks: List[TaskDocument], task_transformations: List[Union[Dict, None]]
) -> Iterator[List[TaskDocument]]:
"""
Groups tasks by structure matching
"""

filtered_tasks = []
filtered_transmuters = []
for task, transmuter in zip(tasks, transmuters):
if any(
allowed_type == task.task_type
for allowed_type in self.settings.VASP_ALLOWED_VASP_TYPES
):
filtered_transformations = []
for task, transformations in zip(tasks, task_transformations):
if any(allowed_type == task.task_type for allowed_type in self.settings.VASP_ALLOWED_VASP_TYPES):
filtered_tasks.append(task)
filtered_transmuters.append(transmuter)
filtered_transformations.append(transformations)

structures = []
for idx, (task, transmuter) in enumerate(
zip(filtered_tasks, filtered_transmuters)
):
for idx, (task, transformations) in enumerate(zip(filtered_tasks, filtered_transformations)):
if task.task_type == TaskType.Deformation:
if (
transmuter is None
): # Do not include deformed tasks without transmuter information
if transformations is None: # Do not include deformed tasks without transformation information
self.logger.debug(
"Cannot find transmuter for deformation task {}. Excluding task.".format(
task.task_id
)
"Cannot find transformation for deformation task {}. Excluding task.".format(task.task_id)
)
continue
else:
s = undeform_structure(task.input.structure, transmuter)
s = undeform_structure(task.input.structure, transformations)
else:
s = task.output.structure
s.index: int = idx # type: ignore
Expand All @@ -334,32 +307,26 @@ def filter_and_group_tasks(
yield grouped_tasks


def undeform_structure(structure: Structure, transmuter: Dict) -> Structure:
def undeform_structure(structure: Structure, transformations: Dict) -> Structure:
"""
Get the undeformed structure by applying the transformations in a reverse order.
Args:
structure: deformed structure
transmuter: transformation that deforms the structure
transformation: transformation that deforms the structure
Returns:
undeformed structure
"""

for trans, params in reversed(
list(zip(transmuter["transformations"], transmuter["transformation_params"]))
):
# The transmuter only stores the transformation class and parameter, without
# module info and such. Therefore, there is no general way to reconstruct it,
# and has to do if else check.
if trans == "DeformStructureTransformation":
deform = Deformation(params["deformation"])
for transformation in reversed(transformations.get("history", [])):
if transformation["@class"] == "DeformStructureTransformation":
deform = Deformation(transformation["deformation"])
dst = DeformStructureTransformation(deform.inv)
structure = dst.apply_transformation(structure)
else:
raise RuntimeError(
"Expect transformation to be `DeformStructureTransformation`; "
f"got {trans}"
"Expect transformation to be `DeformStructureTransformation`; " f"got {transformation['@class']}"
)

return structure

0 comments on commit f17f5f7

Please sign in to comment.