Skip to content

Commit

Permalink
simplify: support model deviation of energy per atom (deepmodeling#1312)
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <jinzhe.zeng@rutgers.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Han Wang <92130845+wanghan-iapcm@users.noreply.github.com>
  • Loading branch information
3 people committed Aug 25, 2023
1 parent 6c5a48f commit 0761763
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 7 deletions.
16 changes: 16 additions & 0 deletions dpgen/simplify/arginfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ def general_simplify_arginfo() -> Argument:
doc_model_devi_f_trust_hi = (
"The higher bound of forces for the selection for the model deviation."
)
doc_model_devi_e_trust_lo = "The lower bound of energy per atom for the selection for the model deviation. Requires DeePMD-kit version >=2.2.2."
doc_model_devi_e_trust_hi = "The higher bound of energy per atom for the selection for the model deviation. Requires DeePMD-kit version >=2.2.2."

return [
Argument("labeled", bool, optional=True, default=False, doc=doc_labeled),
Expand All @@ -50,6 +52,20 @@ def general_simplify_arginfo() -> Argument:
optional=False,
doc=doc_model_devi_f_trust_hi,
),
Argument(
"model_devi_e_trust_lo",
float,
optional=True,
default=float("inf"),
doc=doc_model_devi_e_trust_lo,
),
Argument(
"model_devi_e_trust_hi",
float,
optional=True,
default=float("inf"),
doc=doc_model_devi_e_trust_hi,
),
]


Expand Down
30 changes: 23 additions & 7 deletions dpgen/simplify/simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,8 @@ def post_model_devi(iter_index, jdata, mdata):

f_trust_lo = jdata["model_devi_f_trust_lo"]
f_trust_hi = jdata["model_devi_f_trust_hi"]
e_trust_lo = jdata["model_devi_e_trust_lo"]
e_trust_hi = jdata["model_devi_e_trust_hi"]

type_map = jdata.get("type_map", [])
sys_accurate = dpdata.MultiSystems(type_map=type_map)
Expand All @@ -285,16 +287,30 @@ def post_model_devi(iter_index, jdata, mdata):
if line.startswith("# data.rest.old"):
name = (line.split()[1]).split("/")[-1]
elif line.startswith("#"):
pass
columns = line.split()[1:]
cidx_step = columns.index("step")
cidx_max_devi_f = columns.index("max_devi_f")
try:
cidx_devi_e = columns.index("devi_e")
except ValueError:
# DeePMD-kit < 2.2.2
cidx_devi_e = None
else:
idx = int(line.split()[0])
f_devi = float(line.split()[4])
idx = int(line.split()[cidx_step])
f_devi = float(line.split()[cidx_max_devi_f])
if cidx_devi_e is not None:
e_devi = float(line.split()[cidx_devi_e])
else:
e_devi = 0.0
subsys = sys_entire[name][idx]
if f_trust_lo <= f_devi < f_trust_hi:
sys_candinate.append(subsys)
elif f_devi >= f_trust_hi:
if f_devi >= f_trust_hi or e_devi >= e_trust_hi:
sys_failed.append(subsys)
elif f_devi < f_trust_lo:
elif (
f_trust_lo <= f_devi < f_trust_hi
or e_trust_lo <= e_devi < e_trust_hi
):
sys_candinate.append(subsys)
elif f_devi < f_trust_lo and e_devi < e_trust_lo:
sys_accurate.append(subsys)
else:
raise RuntimeError("reach a place that should NOT be reached...")
Expand Down
116 changes: 116 additions & 0 deletions tests/simplify/test_post_model_devi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import os
import shutil
import sys
import unittest
from pathlib import Path

import dpdata
import numpy as np

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
__package__ = "simplify"
from .context import dpgen


class TestSimplifyModelDevi(unittest.TestCase):
def setUp(self):
self.work_path = Path("iter.000001") / dpgen.simplify.simplify.model_devi_name
self.work_path.mkdir(exist_ok=True, parents=True)
self.system = dpdata.System(
data={
"atom_names": ["H"],
"atom_numbs": [1],
"atom_types": np.zeros((1,), dtype=int),
"coords": np.zeros((1, 1, 3), dtype=np.float32),
"cells": np.zeros((1, 3, 3), dtype=np.float32),
"orig": np.zeros(3, dtype=np.float32),
"nopbc": True,
"energies": np.zeros((1,), dtype=np.float32),
"forces": np.zeros((1, 1, 3), dtype=np.float32),
}
)
self.system.to_deepmd_npy(
self.work_path / "data.rest.old" / self.system.formula
)
model_devi = np.array([[0, 0.2, 0.1, 0.15, 0.2, 0.1, 0.15, 0.2]])
np.savetxt(
self.work_path / "details",
model_devi,
fmt=["%12d"] + ["%19.6e" for _ in range(7)],
header="data.rest.old/"
+ self.system.formula
+ "\n step max_devi_v min_devi_v avg_devi_v max_devi_f min_devi_f avg_devi_f devi_e",
)

def tearDown(self):
shutil.rmtree("iter.000001", ignore_errors=True)

def test_post_model_devi_f_candidate(self):
dpgen.simplify.simplify.post_model_devi(
1,
{
"model_devi_f_trust_lo": 0.15,
"model_devi_f_trust_hi": 0.25,
"model_devi_e_trust_lo": float("inf"),
"model_devi_e_trust_hi": float("inf"),
"iter_pick_number": 1,
},
{},
)
assert (self.work_path / "data.picked" / self.system.formula).exists()

def test_post_model_devi_e_candidate(self):
dpgen.simplify.simplify.post_model_devi(
1,
{
"model_devi_e_trust_lo": 0.15,
"model_devi_e_trust_hi": 0.25,
"model_devi_f_trust_lo": float("inf"),
"model_devi_f_trust_hi": float("inf"),
"iter_pick_number": 1,
},
{},
)
assert (self.work_path / "data.picked" / self.system.formula).exists()

def test_post_model_devi_f_failed(self):
with self.assertRaises(RuntimeError):
dpgen.simplify.simplify.post_model_devi(
1,
{
"model_devi_f_trust_lo": 0.0,
"model_devi_f_trust_hi": 0.0,
"model_devi_e_trust_lo": float("inf"),
"model_devi_e_trust_hi": float("inf"),
"iter_pick_number": 1,
},
{},
)

def test_post_model_devi_e_failed(self):
with self.assertRaises(RuntimeError):
dpgen.simplify.simplify.post_model_devi(
1,
{
"model_devi_e_trust_lo": 0.0,
"model_devi_e_trust_hi": 0.0,
"model_devi_f_trust_lo": float("inf"),
"model_devi_f_trust_hi": float("inf"),
"iter_pick_number": 1,
},
{},
)

def test_post_model_devi_accurate(self):
dpgen.simplify.simplify.post_model_devi(
1,
{
"model_devi_e_trust_lo": 0.3,
"model_devi_e_trust_hi": 0.4,
"model_devi_f_trust_lo": 0.3,
"model_devi_f_trust_hi": 0.4,
"iter_pick_number": 1,
},
{},
)
assert (self.work_path / "data.accurate" / self.system.formula).exists()

0 comments on commit 0761763

Please sign in to comment.