-
Notifications
You must be signed in to change notification settings - Fork 62
/
task_validator.py
75 lines (65 loc) · 2.43 KB
/
task_validator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from typing import Dict, List, Optional, Union
import numpy as np
from maggma.builders import MapBuilder
from maggma.core import Store
from pymatgen.core import Structure
from emmet.builders import SETTINGS
from emmet.builders.settings import EmmetBuildSettings
from emmet.core.vasp.calc_types import run_type, task_type
from emmet.core.vasp.task import TaskDocument
from emmet.core.vasp.validation import DeprecationMessage, ValidationDoc
class TaskValidator(MapBuilder):
    """
    Build a ValidationDoc for each task document in a source store.

    Each task is run through ``ValidationDoc.from_task_doc`` using tolerances
    taken from the build settings. It is then additionally marked invalid
    when it carries any tag listed in ``settings.DEPRECATED_TAGS``.
    """

    def __init__(
        self,
        tasks: Store,
        task_validation: Store,
        settings: Optional[EmmetBuildSettings] = None,
        query: Optional[Dict] = None,
        **kwargs,
    ):
        """
        Creates validation documents from task documents.

        Args:
            tasks: Store of task documents to validate.
            task_validation: Store that receives one validation document per task.
            settings: Build settings; resolved via ``EmmetBuildSettings.autoload``.
            query: Optional query restricting which task documents are processed.
            **kwargs: Extra keyword arguments forwarded to ``MapBuilder``.
        """
        self.tasks = tasks
        self.task_validation = task_validation
        self.settings = EmmetBuildSettings.autoload(settings)
        self.query = query
        self.kwargs = kwargs

        super().__init__(
            source=tasks,
            target=task_validation,
            # Fields fetched from the source store for each task. Presumably
            # these are what TaskDocument / ValidationDoc need for the checks
            # below; revisit if the validation checks change.
            projection=[
                "orig_inputs",
                "output.structure",
                "output.bandgap",
                "input.parameters",
                "calcs_reversed.output.ionic_steps.electronic_steps.e_fr_energy",
                "tags",
            ],
            query=query,
            **kwargs,
        )

    def unary_function(self, item):
        """
        Validate a single task document.

        Args:
            item (dict): a (projection of a) task doc

        Returns:
            ValidationDoc: result of ``ValidationDoc.from_task_doc``. If the
            task carries any tag in ``settings.DEPRECATED_TAGS``, the result is
            also marked invalid, with a warning and a
            ``DeprecationMessage.MANUAL`` reason appended.
        """
        task_doc = TaskDocument(**item)

        validation_doc = ValidationDoc.from_task_doc(
            task_doc=task_doc,
            kpts_tolerance=self.settings.VASP_KPTS_TOLERANCE,
            kspacing_tolerance=self.settings.VASP_KSPACING_TOLERANCE,
            input_sets=self.settings.VASP_DEFAULT_INPUT_SETS,
            LDAU_fields=self.settings.VASP_CHECKED_LDAU_FIELDS,
            max_allowed_scf_gradient=self.settings.VASP_MAX_SCF_GRADIENT,
        )

        # Manual deprecation: a missing/None `tags` field is treated as "no
        # tags" rather than crashing on set(None). Sorting keeps the stored
        # warning text identical across rebuilds (set order is arbitrary).
        bad_tags = sorted(
            set(task_doc.tags or []).intersection(self.settings.DEPRECATED_TAGS)
        )
        if bad_tags:
            validation_doc.warnings.append(f"Manual Deprecation by tags: {bad_tags}")
            validation_doc.valid = False
            validation_doc.reasons.append(DeprecationMessage.MANUAL)

        return validation_doc