Skip to content

Commit

Permalink
fix template script that should be loaded to a dict from a json file (#…
Browse files Browse the repository at this point in the history
…150)

Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people committed Mar 19, 2023
1 parent 8e03205 commit 45bea74
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 5 deletions.
4 changes: 2 additions & 2 deletions dpgen2/entrypoint/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

def dp_dist_train_args():
doc_config = "Configuration of training"
doc_template_script = "File names of the template training script. It can be a `List[Dict]`, the length of which is the same as `numb_models`. Each template script in the list is used to train a model. Can be a `str`, the models share the same template training script. "
doc_template_script = "File names of the template training script. It can be a `List[str]`, the length of which is the same as `numb_models`. Each template script in the list is used to train a model. Can be a `str`, the models share the same template training script. "
dock_student_model_path = "The path of student model"

return [
Expand All @@ -55,7 +55,7 @@ def dp_dist_train_args():
def dp_train_args():
doc_numb_models = "Number of models trained for evaluating the model deviation"
doc_config = "Configuration of training"
doc_template_script = "File names of the template training script. It can be a `List[Dict]`, the length of which is the same as `numb_models`. Each template script in the list is used to train a model. Can be a `str`, the models share the same template training script. "
doc_template_script = "File names of the template training script. It can be a `List[str]`, the length of which is the same as `numb_models`. Each template script in the list is used to train a model. Can be a `str`, the models share the same template training script. "
doc_init_models_paths = "the paths to initial models"

return [
Expand Down
5 changes: 3 additions & 2 deletions dpgen2/entrypoint/submit.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
import glob
import json
import logging
import os
import pickle
Expand Down Expand Up @@ -426,9 +427,9 @@ def workflow_concurrent_learning(
else config["train"]["template_script"]
)
if isinstance(template_script_, list):
template_script = [Path(ii).read_text() for ii in template_script_]
template_script = [json.loads(Path(ii).read_text()) for ii in template_script_]
else:
template_script = Path(template_script_).read_text()
template_script = json.loads(Path(template_script_).read_text())
train_config = {} if old_style else config["train"]["config"]
lmp_config = (
config.get("lmp_config", {}) if old_style else config["explore"]["config"]
Expand Down
6 changes: 5 additions & 1 deletion tests/entrypoint/test_submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,7 @@ def setUp(self):
config["mode"] = "debug"
self.touched_files = [
"foo",
"foo1",
"init",
"bar",
"tar",
Expand All @@ -350,6 +351,8 @@ def setUp(self):
]
for ii in self.touched_files:
Path(ii).touch()
Path("foo").write_text("{}")
Path("foo1").write_text("{}")

def tearDown(self):
from dflow import (
Expand Down Expand Up @@ -381,6 +384,7 @@ def setUp(self):
for ii in self.touched_files:
Path(ii).touch()
Path("POSCAR").write_text(ifc0)
Path("foo").write_text("{}")

def tearDown(self):
from dflow import (
Expand Down Expand Up @@ -457,7 +461,7 @@ def test(self):
"type" : "dp",
"numb_models" : 2,
"config" : {},
"template_script" : "foo",
"template_script" : ["foo", "foo1"],
"init_models_paths" : ["bar", "tar"],
"_comment" : "all"
},
Expand Down

0 comments on commit 45bea74

Please sign in to comment.