Skip to content

Commit

Permalink
Correct typing of TaskDoc.transformations field (#937)
Browse files Browse the repository at this point in the history
* Fix TaskDoc.transformations

* Add tests for transformation typing
  • Loading branch information
esoteric-ephemera committed Feb 15, 2024
1 parent d0276a1 commit 2fa80de
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 1 deletion.
2 changes: 1 addition & 1 deletion emmet-core/emmet/core/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ class TaskDoc(StructureMetadata, extra="allow"):
icsd_id: Optional[Union[str, int]] = Field(
None, description="Inorganic Crystal Structure Database id of the structure"
)
transformations: Optional[Dict[str, Any]] = Field(
transformations: Optional[Any] = Field(
None,
description="Information on the structural transformations, parsed from a "
"transformations.json file",
Expand Down
45 changes: 45 additions & 0 deletions emmet-core/tests/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,12 @@ def test_output_summary(test_dir, object_name, task_name):
)
def test_task_doc(test_dir, object_name):
from monty.json import jsanitize
from monty.serialization import dumpfn
from pymatgen.alchemy.materials import TransformedStructure
from pymatgen.entries.computed_entries import ComputedEntry
from pymatgen.transformations.standard_transformations import (
DeformStructureTransformation,
)

from emmet.core.tasks import TaskDoc

Expand All @@ -138,3 +143,43 @@ def test_task_doc(test_dir, object_name):
assert isinstance(
test_doc.entry, ComputedEntry
), f"Unexpected entry {test_doc.entry} for {object_name}"

# Test that transformations field works, using hydrostatic compression as example
ts = TransformedStructure(
test_doc.output.structure,
transformations=[
DeformStructureTransformation(
deformation=[
[0.9 if i == j else 0.0 for j in range(3)] for i in range(3)
]
)
],
)
ts_json = jsanitize(ts.as_dict())
dumpfn(ts, f"{dir_name}/transformations.json")
test_doc = TaskDoc.from_directory(dir_name)
# if other_parameters == {}, this is popped from the TaskDoc.transformations field
# seems like @version is added by monty serialization
# jsanitize needed because pymatgen.core.Structure.pbc is a tuple
assert all(
test_doc.transformations[k] == v
for k, v in ts_json.items()
if k
not in (
"other_parameters",
"@version",
"last_modified",
)
)
assert isinstance(test_doc.transformations, dict)

# now test case when transformations are serialized, relevant for atomate2
test_doc = TaskDoc(
**{
"transformations": ts,
**{
k: v for k, v in test_doc.model_dump().items() if k != "transformations"
},
}
)
assert test_doc.transformations == ts

0 comments on commit 2fa80de

Please sign in to comment.