make student_model_path optional in distillation (#173)
Signed-off-by: zjgemi <liuxin_zijian@163.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
zjgemi and pre-commit-ci[bot] committed Oct 10, 2023
1 parent 8e8bc16 commit 2bcc460
Showing 2 changed files with 6 additions and 4 deletions.
dpgen2/entrypoint/args.py (1 addition, 3 deletions)

@@ -48,9 +48,7 @@ def dp_dist_train_args():
         Argument(
             "template_script", [list, str], optional=False, doc=doc_template_script
         ),
-        Argument(
-            "student_model_path", str, optional=False, doc=dock_student_model_path
-        ),
+        Argument("student_model_path", str, optional=True, doc=dock_student_model_path),
     ]
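What the schema change above means for a user config (an illustration, not part of the commit): a dp-dist "train" section now validates whether or not it names a student model. The file names below are hypothetical placeholders.

# Illustrative only: both "train" sections are accepted after this commit.
train_with_student = {
    "template_script": "train.json",           # still required (optional=False)
    "student_model_path": "student_model.pb",  # now optional
}
train_without_student = {
    "template_script": "train.json",
    # "student_model_path" omitted: rejected by the old schema,
    # accepted by the new one (optional=True).
}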
dpgen2/entrypoint/submit.py (5 additions, 1 deletion)

@@ -436,7 +436,11 @@ def workflow_concurrent_learning(
             "not match numb_models={numb_models}"
         )
     elif train_style == "dp-dist" and not old_style:
-        init_models_paths = [config["train"].get("student_model_path", None)]
+        init_models_paths = (
+            [config["train"]["student_model_path"]]
+            if "student_model_path" in config["train"]
+            else None
+        )
         config["train"]["numb_models"] = 1
     else:
         raise RuntimeError(
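The submit.py change pairs with the schema change: reading the diff, the old .get("student_model_path", None) form built the one-element list [None] whenever the key was absent, while the new conditional yields None, so downstream code can tell that no student model was supplied. A minimal sketch of that difference (variable names are local to this example):

# Sketch of the behavior fixed here; train_cfg stands in for config["train"].
train_cfg = {}  # user omitted "student_model_path"

# Old form: always builds a one-element list, even when the key is absent.
old = [train_cfg.get("student_model_path", None)]
assert old == [None]  # truthy list that still contains no usable path

# New form: only build the list when the key is actually present.
new = (
    [train_cfg["student_model_path"]]
    if "student_model_path" in train_cfg
    else None
)
assert new is None  # callers can now test: init_models_paths is None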
