Skip to content

Commit

Permalink
limit the filename length dumped by MultiSystems (#554)
Browse files Browse the repository at this point in the history
Fix #553.

---------

Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
njzjz and pre-commit-ci[bot] committed Oct 15, 2023
1 parent 5e8f0ba commit e73531b
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 2 deletions.
49 changes: 47 additions & 2 deletions dpdata/system.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# %%
import glob
import hashlib
import os
import warnings
from copy import deepcopy
Expand All @@ -19,7 +20,13 @@
from dpdata.driver import Driver, Minimizer
from dpdata.format import Format
from dpdata.plugin import Plugin
from dpdata.utils import add_atom_names, elements_index_map, remove_pbc, sort_atom_names
from dpdata.utils import (
add_atom_names,
elements_index_map,
remove_pbc,
sort_atom_names,
utf8len,
)


def load_format(fmt):
Expand Down Expand Up @@ -562,6 +569,42 @@ def uniq_formula(self):
]
)

@property
def short_formula(self) -> str:
"""Return the short formula of this system. Elements with zero number
will be removed.
"""
return "".join(
[
f"{symbol}{numb}"
for symbol, numb in zip(
self.data["atom_names"], self.data["atom_numbs"]
)
if numb
]
)

@property
def formula_hash(self) -> str:
"""Return the hash of the formula of this system."""
return hashlib.sha256(self.formula.encode("utf-8")).hexdigest()

@property
def short_name(self) -> str:
"""Return the short name of this system (no more than 255 bytes), in
the following order:
- formula
- short_formula
- formula_hash.
"""
formula = self.formula
if utf8len(formula) <= 255:
return formula
short_formula = self.short_formula
if utf8len(short_formula) <= 255:
return short_formula
return self.formula_hash

def extend(self, systems):
"""Extend a system list to this system.
Expand Down Expand Up @@ -1247,7 +1290,9 @@ def from_fmt_obj(self, fmtobj, directory, labeled=True, **kwargs):
def to_fmt_obj(self, fmtobj, directory, *args, **kwargs):
if not isinstance(fmtobj, dpdata.plugins.deepmd.DeePMDMixedFormat):
for fn, ss in zip(
fmtobj.to_multi_systems(self.systems.keys(), directory, **kwargs),
fmtobj.to_multi_systems(
[ss.short_name for ss in self.systems.values()], directory, **kwargs
),
self.systems.values(),
):
ss.to_fmt_obj(fmtobj, fn, *args, **kwargs)
Expand Down
5 changes: 5 additions & 0 deletions dpdata/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,8 @@ def uniq_atom_names(data):
sum(ii == data["atom_types"]) for ii in range(len(data["atom_names"]))
]
return data


def utf8len(s: str) -> int:
"""Return the byte length of a string."""
return len(s.encode("utf-8"))
34 changes: 34 additions & 0 deletions tests/test_multisystems.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import os
import tempfile
import unittest
from itertools import permutations

import numpy as np
from comp_sys import CompLabeledSys, IsNoPBC, MultiSystems
from context import dpdata

Expand Down Expand Up @@ -200,5 +202,37 @@ def setUp(self):
self.atom_names = ["C", "H", "O"]


class TestLongFilename(unittest.TestCase):
def test_long_filename1(self):
system = dpdata.System(
data={
"atom_names": [f"TYPE{ii}" for ii in range(200)],
"atom_numbs": [1] + [0 for _ in range(199)],
"atom_types": np.arange(1),
"coords": np.zeros((1, 1, 3)),
"orig": np.zeros(3),
"cells": np.zeros((1, 3, 3)),
}
)
ms = dpdata.MultiSystems(system)
with tempfile.TemporaryDirectory() as tmpdir:
ms.to_deepmd_npy(tmpdir)

def test_long_filename2(self):
system = dpdata.System(
data={
"atom_names": [f"TYPE{ii}" for ii in range(200)],
"atom_numbs": [1 for _ in range(200)],
"atom_types": np.arange(200),
"coords": np.zeros((1, 200, 3)),
"orig": np.zeros(3),
"cells": np.zeros((1, 3, 3)),
}
)
ms = dpdata.MultiSystems(system)
with tempfile.TemporaryDirectory() as tmpdir:
ms.to_deepmd_npy(tmpdir)


if __name__ == "__main__":
unittest.main()

0 comments on commit e73531b

Please sign in to comment.