diff --git a/emmet-core/emmet/core/vasp/validation.py b/emmet-core/emmet/core/vasp/validation.py index 9de9997a73..f4961dc6fd 100644 --- a/emmet-core/emmet/core/vasp/validation.py +++ b/emmet-core/emmet/core/vasp/validation.py @@ -13,6 +13,7 @@ class DeprecationMessage(DocEnum): MANUAL = "M", "manual deprecation" + DEPRECATED_TAGS = "M001", "Deprecated tag" KPTS = "C001", "Too few KPoints" KSPACING = "C002", "KSpacing not high enough" ENCUT = "C002", "ENCUT too low" @@ -56,6 +57,7 @@ def from_task_doc( input_sets: Dict[str, PyObject] = SETTINGS.VASP_DEFAULT_INPUT_SETS, LDAU_fields: List[str] = SETTINGS.VASP_CHECKED_LDAU_FIELDS, max_allowed_scf_gradient: float = SETTINGS.VASP_MAX_SCF_GRADIENT, + deprecated_tags: Optional[List[str]] = None, ) -> "ValidationDoc": """ Determines if a calculation is valid based on expected input parameters from a pymatgen inputset @@ -73,12 +75,18 @@ def from_task_doc( structure = task_doc.output.structure calc_type = task_doc.calc_type inputs = task_doc.orig_inputs + bandgap = task_doc.output.bandgap reasons = [] data = {} if str(calc_type) in input_sets: - valid_input_set = input_sets[str(calc_type)](structure) + + # Ensure inputsets that need the bandgap get it + try: + valid_input_set = input_sets[str(calc_type)](structure, bandgap=bandgap) + except TypeError: + valid_input_set = input_sets[str(calc_type)](structure) # Checking K-Points # Calculations that use KSPACING will not have a .kpoints attr @@ -137,18 +145,25 @@ def from_task_doc( ): reasons.append(DeprecationMessage.LDAU) - # Check the max upwards SCF step - skip = inputs.get("incar", {}).get("NLEMDL") - energies = [ - d["e_fr_energy"] - for d in task_doc.calcs_reversed[0]["output"]["ionic_steps"][-1][ - "electronic_steps" - ] + # Check the max upwards SCF step + skip = inputs.get("incar", {}).get("NLEMDL") + energies = [ + d["e_fr_energy"] + for d in task_doc.calcs_reversed[0]["output"]["ionic_steps"][-1][ + "electronic_steps" ] - max_gradient = np.max(np.gradient(energies)[skip:]) - data["max_gradient"] = max_gradient - if max_gradient > max_allowed_scf_gradient: - reasons.append(DeprecationMessage.MAX_SCF) + ] + max_gradient = np.max(np.gradient(energies)[skip:]) + data["max_gradient"] = max_gradient + if max_gradient > max_allowed_scf_gradient: + reasons.append(DeprecationMessage.MAX_SCF) + + # Check for Manual deprecations + if deprecated_tags is not None: + bad_tags = list(set(task_doc.tags).intersection(deprecated_tags)) + if len(bad_tags) > 0: + reasons.append(DeprecationMessage.DEPRECATED_TAGS) + data["bad_tags"] = bad_tags doc = ValidationDoc( task_id=task_doc.task_id,