Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

simplify: support using true error as error indicator #1321

Merged
merged 2 commits into from
Sep 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
32 changes: 32 additions & 0 deletions dpgen/simplify/arginfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ def general_simplify_arginfo() -> Argument:
)
doc_model_devi_e_trust_lo = "The lower bound of energy per atom for the selection for the model deviation. Requires DeePMD-kit version >=2.2.2."
doc_model_devi_e_trust_hi = "The higher bound of energy per atom for the selection for the model deviation. Requires DeePMD-kit version >=2.2.2."
doc_true_error_f_trust_lo = "The lower bound of forces for the selection for the true error. Requires DeePMD-kit version >=2.2.4."
doc_true_error_f_trust_hi = "The higher bound of forces for the selection for the true error. Requires DeePMD-kit version >=2.2.4."
doc_true_error_e_trust_lo = "The lower bound of energy per atom for the selection for the true error. Requires DeePMD-kit version >=2.2.4."
doc_true_error_e_trust_hi = "The higher bound of energy per atom for the selection for the true error. Requires DeePMD-kit version >=2.2.4."

return [
Argument("labeled", bool, optional=True, default=False, doc=doc_labeled),
Expand Down Expand Up @@ -66,6 +70,34 @@ def general_simplify_arginfo() -> Argument:
default=float("inf"),
doc=doc_model_devi_e_trust_hi,
),
Argument(
"true_error_f_trust_lo",
float,
optional=True,
default=float("inf"),
doc=doc_true_error_f_trust_lo,
),
Argument(
"true_error_f_trust_hi",
float,
optional=True,
default=float("inf"),
doc=doc_true_error_f_trust_hi,
),
Argument(
"true_error_e_trust_lo",
float,
optional=True,
default=float("inf"),
doc=doc_true_error_e_trust_lo,
),
Argument(
"true_error_e_trust_hi",
float,
optional=True,
default=float("inf"),
doc=doc_true_error_e_trust_hi,
),
]


Expand Down
126 changes: 96 additions & 30 deletions dpgen/simplify/simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
rest_data_name = "data.rest"
accurate_data_name = "data.accurate"
detail_file_name_prefix = "details"
true_error_file_name = "true_error"
sys_name_fmt = "sys." + data_system_fmt
sys_name_pattern = "sys.[0-9]*[0-9]"

Expand Down Expand Up @@ -238,6 +239,18 @@
forward_files = [system_file_name]
backward_files = [detail_file_name]

f_trust_lo_err = jdata.get("true_error_f_trust_lo", float("inf"))
e_trust_lo_err = jdata.get("true_error_e_trust_lo", float("inf"))
if f_trust_lo_err < float("inf") or e_trust_lo_err < float("inf"):
command_true_error = "{dp} model-devi -m {model} -s {system} -o {detail_file} --real_error".format(
dp=mdata.get("model_devi_command", "dp"),
model=" ".join(task_model_list),
system=system_file_name,
detail_file=true_error_file_name,
)
commands.append(command_true_error)
backward_files.append(true_error_file_name)

api_version = mdata.get("api_version", "1.0")
if Version(api_version) < Version("1.0"):
raise RuntimeError(
Expand Down Expand Up @@ -270,6 +283,11 @@
f_trust_hi = jdata["model_devi_f_trust_hi"]
e_trust_lo = jdata["model_devi_e_trust_lo"]
e_trust_hi = jdata["model_devi_e_trust_hi"]
f_trust_lo_err = jdata.get("true_error_f_trust_lo", float("inf"))
f_trust_hi_err = jdata.get("true_error_f_trust_hi", float("inf"))
e_trust_lo_err = jdata.get("true_error_e_trust_lo", float("inf"))
e_trust_hi_err = jdata.get("true_error_e_trust_hi", float("inf"))
use_true_error = f_trust_lo_err < float("inf") or e_trust_lo_err < float("inf")

type_map = jdata.get("type_map", [])
sys_accurate = dpdata.MultiSystems(type_map=type_map)
Expand All @@ -282,38 +300,86 @@
)

detail_file_name = detail_file_name_prefix
with open(os.path.join(work_path, detail_file_name)) as f:
for line in f:
if line.startswith("# data.rest.old"):
name = (line.split()[1]).split("/")[-1]
elif line.startswith("#"):
columns = line.split()[1:]
cidx_step = columns.index("step")
cidx_max_devi_f = columns.index("max_devi_f")
try:
cidx_devi_e = columns.index("devi_e")
except ValueError:
# DeePMD-kit < 2.2.2
cidx_devi_e = None
else:
idx = int(line.split()[cidx_step])
f_devi = float(line.split()[cidx_max_devi_f])
if cidx_devi_e is not None:
e_devi = float(line.split()[cidx_devi_e])
if not use_true_error:
with open(os.path.join(work_path, detail_file_name)) as f:
for line in f:
if line.startswith("# data.rest.old"):
name = (line.split()[1]).split("/")[-1]
elif line.startswith("#"):
columns = line.split()[1:]
cidx_step = columns.index("step")
cidx_max_devi_f = columns.index("max_devi_f")
try:
cidx_devi_e = columns.index("devi_e")
except ValueError:

Check warning on line 314 in dpgen/simplify/simplify.py

View check run for this annotation

Codecov / codecov/patch

dpgen/simplify/simplify.py#L314

Added line #L314 was not covered by tests
# DeePMD-kit < 2.2.2
cidx_devi_e = None

Check warning on line 316 in dpgen/simplify/simplify.py

View check run for this annotation

Codecov / codecov/patch

dpgen/simplify/simplify.py#L316

Added line #L316 was not covered by tests
else:
e_devi = 0.0
subsys = sys_entire[name][idx]
if f_devi >= f_trust_hi or e_devi >= e_trust_hi:
sys_failed.append(subsys)
elif (
f_trust_lo <= f_devi < f_trust_hi
or e_trust_lo <= e_devi < e_trust_hi
):
sys_candinate.append(subsys)
elif f_devi < f_trust_lo and e_devi < e_trust_lo:
sys_accurate.append(subsys)
idx = int(line.split()[cidx_step])
f_devi = float(line.split()[cidx_max_devi_f])
if cidx_devi_e is not None:
e_devi = float(line.split()[cidx_devi_e])
else:
e_devi = 0.0

Check warning on line 323 in dpgen/simplify/simplify.py

View check run for this annotation

Codecov / codecov/patch

dpgen/simplify/simplify.py#L323

Added line #L323 was not covered by tests
subsys = sys_entire[name][idx]
if f_devi >= f_trust_hi or e_devi >= e_trust_hi:
sys_failed.append(subsys)
elif (
f_trust_lo <= f_devi < f_trust_hi
or e_trust_lo <= e_devi < e_trust_hi
):
sys_candinate.append(subsys)
elif f_devi < f_trust_lo and e_devi < e_trust_lo:
sys_accurate.append(subsys)
else:
raise RuntimeError(

Check warning on line 335 in dpgen/simplify/simplify.py

View check run for this annotation

Codecov / codecov/patch

dpgen/simplify/simplify.py#L335

Added line #L335 was not covered by tests
"reach a place that should NOT be reached..."
)
else:
with open(os.path.join(work_path, detail_file_name)) as f, open(
os.path.join(work_path, true_error_file_name)
) as f_err:
for line, line_err in zip(f, f_err):
if line.startswith("# data.rest.old"):
name = (line.split()[1]).split("/")[-1]
elif line.startswith("#"):
columns = line.split()[1:]
cidx_step = columns.index("step")
cidx_max_devi_f = columns.index("max_devi_f")
cidx_devi_e = columns.index("devi_e")
else:
raise RuntimeError("reach a place that should NOT be reached...")
idx = int(line.split()[cidx_step])
f_devi = float(line.split()[cidx_max_devi_f])
f_err = float(line_err.split()[cidx_max_devi_f])
e_devi = float(line.split()[cidx_devi_e])
e_err = float(line_err.split()[cidx_devi_e])

subsys = sys_entire[name][idx]
if (
f_devi >= f_trust_hi
or e_devi >= e_trust_hi
or f_err >= f_trust_hi_err
or e_err >= e_trust_hi_err
):
sys_failed.append(subsys)

Check warning on line 364 in dpgen/simplify/simplify.py

View check run for this annotation

Codecov / codecov/patch

dpgen/simplify/simplify.py#L364

Added line #L364 was not covered by tests
elif (
f_trust_lo <= f_devi < f_trust_hi
or e_trust_lo <= e_devi < e_trust_hi
or f_trust_lo_err <= f_err < f_trust_hi_err
or e_trust_lo_err <= e_err < e_trust_hi_err
):
sys_candinate.append(subsys)
elif (

Check warning on line 372 in dpgen/simplify/simplify.py

View check run for this annotation

Codecov / codecov/patch

dpgen/simplify/simplify.py#L372

Added line #L372 was not covered by tests
f_devi < f_trust_lo
and e_devi < e_trust_lo
and f_err < f_trust_lo_err
and e_err < e_trust_lo_err
):
sys_accurate.append(subsys)

Check warning on line 378 in dpgen/simplify/simplify.py

View check run for this annotation

Codecov / codecov/patch

dpgen/simplify/simplify.py#L378

Added line #L378 was not covered by tests
else:
raise RuntimeError(

Check warning on line 380 in dpgen/simplify/simplify.py

View check run for this annotation

Codecov / codecov/patch

dpgen/simplify/simplify.py#L380

Added line #L380 was not covered by tests
"reach a place that should NOT be reached..."
)

counter = {
"candidate": sys_candinate.get_nframes(),
Expand Down
26 changes: 26 additions & 0 deletions tests/simplify/test_post_model_devi.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,14 @@ def setUp(self):
+ self.system.formula
+ "\n step max_devi_v min_devi_v avg_devi_v max_devi_f min_devi_f avg_devi_f devi_e",
)
np.savetxt(
self.work_path / "true_error",
model_devi,
fmt=["%12d"] + ["%19.6e" for _ in range(7)],
header="data.rest.old/"
+ self.system.formula
+ "\n step max_devi_v min_devi_v avg_devi_v max_devi_f min_devi_f avg_devi_f devi_e",
)

def tearDown(self):
shutil.rmtree("iter.000001", ignore_errors=True)
Expand Down Expand Up @@ -114,3 +122,21 @@ def test_post_model_devi_accurate(self):
{},
)
assert (self.work_path / "data.accurate" / self.system.formula).exists()

def test_post_model_devi_true_error_candidate(self):
dpgen.simplify.simplify.post_model_devi(
1,
{
"model_devi_e_trust_lo": 0.15,
"model_devi_e_trust_hi": 0.25,
"model_devi_f_trust_lo": float("inf"),
"model_devi_f_trust_hi": float("inf"),
"true_error_e_trust_lo": float("inf"),
"true_error_e_trust_hi": float("inf"),
"true_error_f_trust_lo": 0.15,
"true_error_f_trust_hi": 0.25,
"iter_pick_number": 1,
},
{},
)
assert (self.work_path / "data.picked" / self.system.formula).exists()
26 changes: 26 additions & 0 deletions tests/simplify/test_run_model_devi.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,29 @@ def test_one_h5(self):
},
}
dpgen.simplify.simplify.run_model_devi(0, jdata=jdata, mdata=mdata)

def test_true_error(self):
jdata = {
"type_map": ["H"],
"true_error_f_trust_lo": 0.15,
"true_error_f_trust_hi": 0.25,
}
with tempfile.TemporaryDirectory() as remote_root:
mdata = {
"model_devi_command": (
f"test -d {dpgen.simplify.simplify.rest_data_name}.old"
f"&& touch {dpgen.simplify.simplify.detail_file_name_prefix}"
f"&& touch {dpgen.simplify.simplify.true_error_file_name}"
"&& echo dp"
),
"model_devi_machine": {
"context_type": "LocalContext",
"batch_type": "shell",
"local_root": "./",
"remote_root": remote_root,
},
"model_devi_resources": {
"group_size": 1,
},
}
dpgen.simplify.simplify.run_model_devi(0, jdata=jdata, mdata=mdata)