Skip to content

Commit

Permalink
Fix TaskDoc.set_entry (#853)
Browse files Browse the repository at this point in the history
* ruff auto fix

* TaskDoc.set_entry don't require self.task_id to set self.entry

fixes atomate2 defect workflow test, see materialsproject/atomate2#548 (comment)

* deprecated doc.copy to doc.model_copy

* test_task_doc ensure test_doc.entry is ComputedEntry

* Linting

---------

Co-authored-by: Jason Munro <jmunro@lbl.gov>
  • Loading branch information
janosh and munrojm committed Oct 12, 2023
1 parent e363e07 commit 475c79e
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 27 deletions.
39 changes: 16 additions & 23 deletions emmet-core/emmet/core/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import numpy as np
from monty.json import MontyDecoder
from monty.serialization import loadfn
from pydantic import field_validator, ConfigDict, BaseModel, Field, model_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from pymatgen.analysis.structure_analyzer import oxide_type
from pymatgen.core.structure import Structure
from pymatgen.core.trajectory import Trajectory
Expand Down Expand Up @@ -70,7 +70,7 @@ class OrigInputs(BaseModel):
@classmethod
def potcar_ok(cls, v):
if isinstance(v, list):
return [i for i in v]
return list(v)
return v

model_config = ConfigDict(arbitrary_types_allowed=True)
Expand Down Expand Up @@ -173,7 +173,9 @@ class InputDoc(BaseModel):
is_lasph: Optional[bool] = Field(
None, description="Whether the calculation was run with aspherical corrections"
)
is_hubbard: bool = Field(False, description="Is this a Hubbard +U calculation")
is_hubbard: bool = Field(
default=False, description="Is this a Hubbard +U calculation"
)
hubbards: Optional[dict] = Field(None, description="The hubbard parameters used")

@classmethod
Expand Down Expand Up @@ -306,9 +308,7 @@ def from_vasp_calc_docs(


class TaskDoc(StructureMetadata, extra="allow"):
"""
Calculation-level details about VASP calculations that power Materials Project.
"""
"""Calculation-level details about VASP calculations that power Materials Project."""

tags: Union[List[str], None] = Field(
[], title="tag", description="Metadata tagged to a given task."
Expand Down Expand Up @@ -406,7 +406,7 @@ def last_updated_dict_ok(cls, v) -> datetime:

@model_validator(mode="after")
def set_entry(self) -> datetime:
if not self.entry and (self.calcs_reversed and self.task_id):
if not self.entry and self.calcs_reversed:
self.entry = self.get_entry(self.calcs_reversed, self.task_id)
return self

Expand All @@ -416,7 +416,7 @@ def from_directory(
dir_name: Union[Path, str],
volumetric_files: Tuple[str, ...] = _VOLUMETRIC_FILES,
store_additional_json: bool = True,
additional_fields: Dict[str, Any] = None,
additional_fields: Optional[Dict[str, Any]] = None,
volume_change_warning_tol: float = 0.2,
**vasp_calculation_kwargs,
) -> _T:
Expand Down Expand Up @@ -514,14 +514,13 @@ def from_directory(
included_objects=included_objects,
task_type=calcs_reversed[0].task_type,
)
doc = doc.model_copy(update=additional_fields)
return doc
return doc.model_copy(update=additional_fields)

@classmethod
def from_vasprun(
cls: Type[_T],
path: Union[str, Path],
additional_fields: Dict[str, Any] = None,
additional_fields: Optional[Dict[str, Any]] = None,
volume_change_warning_tol: float = 0.2,
**vasp_calculation_kwargs,
) -> _T:
Expand Down Expand Up @@ -588,7 +587,7 @@ def from_vasprun(
task_type=calcs_reversed[0].task_type,
)
if additional_fields:
doc = doc.copy(update=additional_fields)
doc = doc.model_copy(update=additional_fields)
return doc

@staticmethod
Expand Down Expand Up @@ -654,9 +653,7 @@ def structure_entry(self) -> ComputedStructureEntry:


class TrajectoryDoc(BaseModel):
"""
Model for task trajectory data
"""
"""Model for task trajectory data."""

task_id: Optional[str] = Field(
None,
Expand All @@ -671,9 +668,7 @@ class TrajectoryDoc(BaseModel):


class EntryDoc(BaseModel):
"""
Model for task entry data
"""
"""Model for task entry data."""

task_id: Optional[str] = Field(
None,
Expand All @@ -688,9 +683,7 @@ class EntryDoc(BaseModel):


class DeprecationDoc(BaseModel):
"""
Model for task deprecation data
"""
"""Model for task deprecation data."""

task_id: Optional[str] = Field(
None,
Expand Down Expand Up @@ -871,7 +864,7 @@ def _get_drift_warnings(calc_doc: Calculation) -> List[str]:
def _get_state(calcs_reversed: List[Calculation], analysis: AnalysisDoc) -> TaskState:
"""Get state from calculation documents and relaxation analysis."""
all_calcs_completed = all(
[c.has_vasp_completed == TaskState.SUCCESS for c in calcs_reversed]
c.has_vasp_completed == TaskState.SUCCESS for c in calcs_reversed
)
if len(analysis.errors) == 0 and all_calcs_completed:
return TaskState.SUCCESS # type: ignore
Expand Down Expand Up @@ -956,7 +949,7 @@ def _get_task_files(files, suffix=""):
vasp_files["outcar_file"] = file
elif file.match(f"*CONTCAR{suffix}*"):
vasp_files["contcar_file"] = file
elif any([file.match(f"*{f}{suffix}*") for f in volumetric_files]):
elif any(file.match(f"*{f}{suffix}*") for f in volumetric_files):
vol_files.append(file)
elif file.match(f"*POSCAR.T=*{suffix}*"):
elph_poscars.append(file)
Expand Down
18 changes: 14 additions & 4 deletions emmet-core/tests/test_task.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest

from tests.conftest import assert_schemas_equal, get_test_object


Expand Down Expand Up @@ -39,7 +40,7 @@ def test_analysis_summary(test_dir, object_name):


@pytest.mark.parametrize(
"object_name,task_name",
("object_name", "task_name"),
[
pytest.param("SiOptimizeDouble", "relax1", id="SiOptimizeDouble"),
pytest.param("SiStatic", "standard", id="SiStatic"),
Expand Down Expand Up @@ -70,7 +71,7 @@ def test_input_summary(test_dir, object_name, task_name):


@pytest.mark.parametrize(
"object_name,task_name",
("object_name", "task_name"),
[
pytest.param("SiOptimizeDouble", "relax2", id="SiOptimizeDouble"),
pytest.param("SiStatic", "standard", id="SiStatic"),
Expand Down Expand Up @@ -110,6 +111,7 @@ def test_output_summary(test_dir, object_name, task_name):
)
def test_task_doc(test_dir, object_name):
from monty.json import MontyDecoder, jsanitize
from pymatgen.entries.computed_entries import ComputedEntry

from emmet.core.tasks import TaskDoc

Expand All @@ -119,11 +121,19 @@ def test_task_doc(test_dir, object_name):
assert_schemas_equal(test_doc, test_object.task_doc)

# test document can be jsanitized
d = jsanitize(test_doc, strict=True, enum_values=True, allow_bson=True)
dct = jsanitize(test_doc, strict=True, enum_values=True, allow_bson=True)

# and decoded
MontyDecoder().process_decoded(d)
MontyDecoder().process_decoded(dct)

# Test that additional_fields works
test_doc = TaskDoc.from_directory(dir_name, additional_fields={"foo": "bar"})
assert test_doc.model_dump()["foo"] == "bar"

assert len(test_doc.calcs_reversed) == len(test_object.task_files)

# Check that entry is populated when calcs_reversed is not None
if test_doc.calcs_reversed:
assert isinstance(
test_doc.entry, ComputedEntry
), f"Unexpected entry {test_doc.entry} for {object_name}"

0 comments on commit 475c79e

Please sign in to comment.