Skip to content

Commit

Permalink
Fix bugs (#107)
Browse files Browse the repository at this point in the history
1. init_data_sys can be str
2. not overwriting the batch size in dp input script
3. update dependencies.

Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
  • Loading branch information
wanghan-iapcm and Han Wang committed Dec 29, 2022
1 parent 22e9b13 commit 9a722c1
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 4 deletions.
2 changes: 1 addition & 1 deletion dpgen2/entrypoint/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def input_args():
Argument("type_map", list, optional=False, doc=doc_type_map),
Argument("mass_map", list, optional=False, doc=doc_mass_map),
Argument("init_data_prefix", str, optional=True, default=None, doc=doc_init_data_prefix),
Argument("init_data_sys", list, optional=False, default=None, doc=doc_init_sys),
Argument("init_data_sys", [list,str], optional=False, default=None, doc=doc_init_sys),
]


Expand Down
4 changes: 2 additions & 2 deletions dpgen2/op/run_dp_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,12 +190,12 @@ def write_data_to_input_script(
if major_version == "1":
# v1 behavior
odict['training']['systems'] = data_list
odict['training']['batch_size'] = "auto"
odict['training'].setdefault('batch_size', "auto")
odict['training']['auto_prob_style'] = auto_prob_str
elif major_version == "2":
# v2 behavior
odict['training']['training_data']['systems'] = data_list
odict['training']['training_data']['batch_size'] = "auto"
odict['training']['training_data'].setdefault('batch_size', "auto")
odict['training']['training_data']['auto_prob'] = auto_prob_str
odict['training'].pop('validation_data', None)
else:
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@ classifiers = [
dependencies = [
'numpy',
'dpdata',
'pydflow>=1.6.18',
'pydflow>=1.6.23',
'dargs>=0.3.1',
'scipy',
'lbg',
]
requires-python = ">=3.7"
readme = "README.md"
Expand Down

0 comments on commit 9a722c1

Please sign in to comment.