From 166110da3b92f64234d6e84c99a00f47fbcf2ce8 Mon Sep 17 00:00:00 2001 From: zhezhaoa <1152543959@qq.com> Date: Wed, 1 May 2024 18:36:45 +0800 Subject: [PATCH] Update config.py --- uer/utils/config.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/uer/utils/config.py b/uer/utils/config.py index 6f221a37..e247cace 100644 --- a/uer/utils/config.py +++ b/uer/utils/config.py @@ -3,21 +3,21 @@ from argparse import Namespace -def load_hyperparam(default_args): +def load_hyperparam(args): """ Load arguments form argparse and config file Priority: default options < config file < command line args """ - with open(default_args.config_path, mode="r", encoding="utf-8") as f: + with open(dargs.config_path, mode="r", encoding="utf-8") as f: config_args_dict = json.load(f) - default_args_dict = vars(default_args) + args_dict = vars(args) - command_line_args_dict = {k: default_args_dict[k] for k in [ + command_line_args_dict = {k: args_dict[k] for k in [ a[2:] for a in sys.argv if (a[:2] == "--" and "local_rank" not in a) ]} - default_args_dict.update(config_args_dict) - default_args_dict.update(command_line_args_dict) - args = Namespace(**default_args_dict) + args_dict.update(config_args_dict) + args_dict.update(command_line_args_dict) + args = Namespace(**args_dict) return args