Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add ASE's traj support #614

Merged
merged 55 commits into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from 47 commits
Commits
Show all changes
55 commits
Select commit Hold shift + click to select a range
820ca84
add ASE's traj
thangckt Mar 13, 2024
4594981
Update ase.py
thangckt Mar 13, 2024
b055b88
Update ase.py
thangckt Mar 13, 2024
bde8b7b
Update ase.py
thangckt Mar 13, 2024
0377b27
Update ase.py
thangckt Mar 13, 2024
da3daa8
finish add ASE's traj support
thangckt Mar 13, 2024
70f7813
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 13, 2024
65d2797
Update ase.py
thangckt Mar 13, 2024
42ff2b0
Merge branch 'tha_devel' of https://github.com/thangckt/dpdata into t…
thangckt Mar 13, 2024
7f1a398
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 13, 2024
59e3247
Update ase.py
thangckt Mar 13, 2024
a381f89
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 13, 2024
1585ff8
Update ase.py
thangckt Mar 14, 2024
2fefd7f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 14, 2024
793745c
Update ase.py
thangckt Mar 14, 2024
76895a7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 14, 2024
a12c374
Update ase.py
thangckt Mar 14, 2024
0a77be5
Update test_ase_traj.py
thangckt Mar 14, 2024
ea85559
u
thangckt Mar 14, 2024
9dbcbd8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 14, 2024
0b27552
Update test_ase_traj.py
thangckt Mar 14, 2024
c09894b
Merge branch 'tha_devel' of https://github.com/thangckt/dpdata into t…
thangckt Mar 14, 2024
67e7876
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 14, 2024
2afdc4f
up
thangckt Mar 14, 2024
8771f8c
u
thangckt Mar 14, 2024
fe5163a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 14, 2024
d41573b
Update test_ase_traj.py
thangckt Mar 14, 2024
17fcb16
Merge branch 'tha_devel' of https://github.com/thangckt/dpdata into t…
thangckt Mar 14, 2024
e05c9ce
u
thangckt Mar 14, 2024
95939bd
Update ase.py
thangckt Mar 14, 2024
52ef114
Update ase.py
thangckt Mar 14, 2024
60212f6
Update ase.py
thangckt Mar 14, 2024
57d938e
Update test_ase_traj.py
thangckt Mar 14, 2024
3c95376
Update test_ase_traj.py
thangckt Mar 14, 2024
c3c45ef
Update ase.py
thangckt Mar 15, 2024
46171fc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 15, 2024
1efef9b
Update test_ase_traj.py
thangckt Mar 15, 2024
b669ae1
Merge branch 'tha_devel' of https://github.com/thangckt/dpdata into t…
thangckt Mar 15, 2024
33c5084
Update ase.py
thangckt Mar 15, 2024
790d4b5
u
thangckt Mar 15, 2024
f0c5c7d
u
thangckt Mar 15, 2024
551920e
u
thangckt Mar 15, 2024
d7c98f9
u
thangckt Mar 15, 2024
e0b53f7
u
thangckt Mar 15, 2024
81f402a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 15, 2024
9248833
u
thangckt Mar 15, 2024
8f4684d
u
thangckt Mar 16, 2024
eb66571
u
thangckt Mar 16, 2024
9893423
Update test_ase_traj.py
thangckt Mar 16, 2024
77e8ef4
u
thangckt Mar 17, 2024
02b4dac
u
thangckt Mar 18, 2024
45c8018
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 18, 2024
71d0e70
Update ase.py
thangckt Mar 18, 2024
e3909e7
add a test for unlabeled system
njzjz Mar 19, 2024
b49a4e2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions dpdata/plugins/ase.py
njzjz marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
try:
import ase.io
from ase.calculators.calculator import PropertyNotImplementedError
from ase.io import Trajectory

if TYPE_CHECKING:
from ase.optimize.optimize import Optimizer
Expand Down Expand Up @@ -187,6 +188,115 @@ def to_labeled_system(self, data, *args, **kwargs):
return structures


@Format.register("ase/traj")
class ASETrajFormat(Format):
thangckt marked this conversation as resolved.
Show resolved Hide resolved
"""Format for the ASE's trajectory format <https://wiki.fysik.dtu.dk/ase/ase/io/trajectory.html#module-ase.io.trajectory>`_ (ase).'
a `traj' contains a sequence of frames, each of which is an `Atoms' object.
"""

def from_system(
self,
file_name: str,
begin: Optional[int] = 0,
end: Optional[int] = None,
step: Optional[int] = 1,
**kwargs,
) -> dict:
"""Read ASE's trajectory file to `System` of multiple frames.

Parameters
----------
file_name : str
ASE's trajectory file
begin : int, optional
begin frame index
end : int, optional
end frame index
step : int, optional
frame index step
**kwargs : dict
other parameters

Returns
-------
dict_frames: dict
a dictionary containing data of multiple frames
"""
traj = Trajectory(file_name)
sub_traj = traj[begin:end:step]
dict_frames = ASEStructureFormat().from_system(sub_traj[0])
for atoms in sub_traj[1:]:
tmp = ASEStructureFormat().from_system(atoms)
dict_frames["cells"] = np.append(dict_frames["cells"], tmp["cells"][0])
dict_frames["coords"] = np.append(dict_frames["coords"], tmp["coords"][0])

## Correct the shape of numpy arrays
dict_frames["cells"] = dict_frames["cells"].reshape(-1, 3, 3)
dict_frames["coords"] = dict_frames["coords"].reshape(len(sub_traj), -1, 3)

return dict_frames

def from_labeled_system(
self,
file_name: str,
begin: Optional[int] = 0,
end: Optional[int] = None,
step: Optional[int] = 1,
**kwargs,
) -> dict:
"""Read ASE's trajectory file to `System` of multiple frames.

Parameters
----------
file_name : str
ASE's trajectory file
begin : int, optional
begin frame index
end : int, optional
end frame index
step : int, optional
frame index step
**kwargs : dict
other parameters

Returns
-------
dict_frames: dict
a dictionary containing data of multiple frames
"""
traj = Trajectory(file_name)
sub_traj = traj[begin:end:step]

## check if the first frame has a calculator
if sub_traj[0].calc is None:
raise ValueError(
"The input trajectory does not contain energies and forces, may not be a labeled system."
)

dict_frames = ASEStructureFormat().from_labeled_system(sub_traj[0])
for atoms in sub_traj[1:]:
tmp = ASEStructureFormat().from_labeled_system(atoms)
dict_frames["cells"] = np.append(dict_frames["cells"], tmp["cells"][0])
dict_frames["coords"] = np.append(dict_frames["coords"], tmp["coords"][0])
dict_frames["energies"] = np.append(
dict_frames["energies"], tmp["energies"][0]
)
dict_frames["forces"] = np.append(dict_frames["forces"], tmp["forces"][0])
if "virials" in tmp.keys() and "virials" in dict_frames.keys():
dict_frames["virials"] = np.append(
dict_frames["virials"], tmp["virials"][0]
)

## Correct the shape of numpy arrays
dict_frames["cells"] = dict_frames["cells"].reshape(-1, 3, 3)
dict_frames["coords"] = dict_frames["coords"].reshape(len(sub_traj), -1, 3)
dict_frames["forces"] = dict_frames["forces"].reshape(len(sub_traj), -1, 3)
if "virials" in dict_frames.keys():
dict_frames["virials"] = dict_frames["virials"].reshape(-1, 3, 3)

return dict_frames


@Driver.register("ase")
class ASEDriver(Driver):
"""ASE Driver.
Expand Down
Binary file added tests/ase_traj/Cu.traj
Binary file not shown.
13 changes: 13 additions & 0 deletions tests/test_ase_traj.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,5 +43,18 @@ def setUp(self):
self.v_places = 4


@unittest.skipIf(skip_ase, "skip ase related test. install ase to fix")
class TestASEtraj3(unittest.TestCase, CompLabeledSys, IsPBC):
def setUp(self):
self.system_1 = dpdata.System("ase_traj/Cu.traj", fmt="ase/traj")
self.system_2 = dpdata.LabeledSystem("ase_traj/Cu.traj", fmt="ase/traj")
thangckt marked this conversation as resolved.
Show resolved Hide resolved
thangckt marked this conversation as resolved.
Show resolved Hide resolved
self.system_3 = dpdata.System("ase_traj/HeAlO.traj", fmt="ase/traj")
self.system_4 = dpdata.LabeledSystem("ase_traj/HeAlO.traj", fmt="ase/traj")
self.places = 6
self.e_places = 6
self.f_places = 6
self.v_places = 4


if __name__ == "__main__":
unittest.main()