diff --git a/dpgen/generator/arginfo.py b/dpgen/generator/arginfo.py
index 0ad63f22e..d13d7f6c7 100644
--- a/dpgen/generator/arginfo.py
+++ b/dpgen/generator/arginfo.py
@@ -154,28 +154,28 @@ def training_args() -> list[Argument]:
             [None, int],
             alias=["training_reuse_stop_batch"],
             optional=True,
-            default=400000,
+            default=None,
             doc=doc_training_reuse_numb_steps,
         ),
         Argument(
             "training_reuse_start_lr",
             [None, float],
             optional=True,
-            default=1e-4,
+            default=None,
             doc=doc_training_reuse_start_lr,
         ),
         Argument(
             "training_reuse_start_pref_e",
             [None, float, int],
             optional=True,
-            default=0.1,
+            default=None,
             doc=doc_training_reuse_start_pref_e,
         ),
         Argument(
             "training_reuse_start_pref_f",
             [None, float, int],
             optional=True,
-            default=100,
+            default=None,
             doc=doc_training_reuse_start_pref_f,
         ),
         Argument(
diff --git a/dpgen/generator/run.py b/dpgen/generator/run.py
index d773cb861..cc5a9fc44 100644
--- a/dpgen/generator/run.py
+++ b/dpgen/generator/run.py
@@ -276,11 +276,11 @@ def make_train(iter_index, jdata, mdata):
     elif "training_reuse_numb_steps" in jdata.keys():
         training_reuse_stop_batch = jdata["training_reuse_numb_steps"]
     else:
-        training_reuse_stop_batch = 400000
+        training_reuse_stop_batch = None
 
-    training_reuse_start_lr = jdata.get("training_reuse_start_lr", 1e-4)
-    training_reuse_start_pref_e = jdata.get("training_reuse_start_pref_e", 0.1)
-    training_reuse_start_pref_f = jdata.get("training_reuse_start_pref_f", 100)
+    training_reuse_start_lr = jdata.get("training_reuse_start_lr")
+    training_reuse_start_pref_e = jdata.get("training_reuse_start_pref_e")
+    training_reuse_start_pref_f = jdata.get("training_reuse_start_pref_f")
     model_devi_activation_func = jdata.get("model_devi_activation_func", None)
 
     auto_ratio = False
@@ -509,11 +509,18 @@ def make_train(iter_index, jdata, mdata):
             raise RuntimeError(
                 "Unsupported DeePMD-kit version: %s" % mdata["deepmd_version"]
             )
-        if jinput["loss"].get("start_pref_e") is not None:
+        if (
+            jinput["loss"].get("start_pref_e") is not None
+            and training_reuse_start_pref_e is not None
+        ):
             jinput["loss"]["start_pref_e"] = training_reuse_start_pref_e
-        if jinput["loss"].get("start_pref_f") is not None:
+        if (
+            jinput["loss"].get("start_pref_f") is not None
+            and training_reuse_start_pref_f is not None
+        ):
             jinput["loss"]["start_pref_f"] = training_reuse_start_pref_f
-        jinput["learning_rate"]["start_lr"] = training_reuse_start_lr
+        if training_reuse_start_lr is not None:
+            jinput["learning_rate"]["start_lr"] = training_reuse_start_lr
 
     input_files = []
     for ii in range(numb_models):