Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions dpgen2/exploration/task/lmp_template_task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,19 @@ def set_lmp(
plm_template_fname: Optional[str] = None,
revisions: dict = {},
traj_freq: int = 10,
extra_pair_style_args: str = "",
) -> None:
self.lmp_template = Path(lmp_template_fname).read_text().split("\n")
self.revisions = revisions
self.traj_freq = traj_freq
self.extra_pair_style_args = extra_pair_style_args
self.lmp_set = True
self.model_list = sorted([model_name_pattern % ii for ii in range(numb_models)])
self.lmp_template = revise_lmp_input_model(
self.lmp_template, self.model_list, self.traj_freq
self.lmp_template,
self.model_list,
self.traj_freq,
self.extra_pair_style_args,
)
self.lmp_template = revise_lmp_input_dump(self.lmp_template, self.traj_freq)
if plm_template_fname is not None:
Expand Down Expand Up @@ -138,12 +143,20 @@ def find_only_one_key(lmp_lines, key):
return found[0]


def revise_lmp_input_model(lmp_lines, task_model_list, trj_freq, deepmd_version="1"):
def revise_lmp_input_model(
lmp_lines, task_model_list, trj_freq, extra_pair_style_args="", deepmd_version="1"
):
idx = find_only_one_key(lmp_lines, ["pair_style", "deepmd"])
if extra_pair_style_args:
extra_pair_style_args = " " + extra_pair_style_args
graph_list = " ".join(task_model_list)
lmp_lines[idx] = "pair_style deepmd %s out_freq %d out_file model_devi.out" % (
graph_list,
trj_freq,
lmp_lines[idx] = (
"pair_style deepmd %s out_freq %d out_file model_devi.out%s"
% (
graph_list,
trj_freq,
extra_pair_style_args,
)
)
return lmp_lines

Expand Down
10 changes: 9 additions & 1 deletion dpgen2/exploration/task/make_task_group_from_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def lmp_template_task_group_args():
doc_plm_template_fname = "The file name of plumed input template"
doc_revisions = "The revisions. Should be a dict providing the key - list of desired values pair. Key is the word to be replaced in the templates, and it may appear in both the lammps and plumed input templates. All values in the value list will be enmerated."
doc_traj_freq = "The frequency of dumping configurations and thermodynamic states"
doc_extra_pair_style_args = "The extra arguments for pair_style"

return [
Argument("conf_idx", list, optional=False, doc=doc_conf_idx, alias=["sys_idx"]),
Expand All @@ -125,7 +126,7 @@ def lmp_template_task_group_args():
doc=doc_plm_template_fname,
alias=["plm_template", "plm"],
),
Argument("revisions", dict, optional=True, default={}),
Argument("revisions", dict, optional=True, default={}, doc=doc_revisions),
Argument(
"traj_freq",
int,
Expand All @@ -134,6 +135,13 @@ def lmp_template_task_group_args():
doc=doc_traj_freq,
alias=["t_freq", "trj_freq", "trj_freq"],
),
Argument(
"extra_pair_style_args",
str,
optional=True,
default="",
doc=doc_extra_pair_style_args,
),
]


Expand Down