diff --git a/options/train/CLIPIQA/train_CLIPIQA_koniq10k.yml b/options/train/CLIPIQA/train_CLIPIQA_koniq10k.yml
index 6382ebb..278807c 100644
--- a/options/train/CLIPIQA/train_CLIPIQA_koniq10k.yml
+++ b/options/train/CLIPIQA/train_CLIPIQA_koniq10k.yml
@@ -2,12 +2,7 @@
 # name: debug_DBCNN_LIVEC
 name: 002_CLIPIQA_ViT-L14_KonIQ10k
 name: 003_CLIPIQA_RN50_KonIQ10k
-# name: debug_CLIPIQA
-model_type: GeneralIQAModel
-num_gpu: 1 # set num_gpu: 0 for cpu mode
-manual_seed: 123
-
-# name: debug_CLIPIQA
+name: debug_CLIPIQA
 model_type: GeneralIQAModel
 num_gpu: 1 # set num_gpu: 0 for cpu mode
 manual_seed: 123
diff --git a/pyiqa/train.py b/pyiqa/train.py
index 93469f1..bcf69f7 100644
--- a/pyiqa/train.py
+++ b/pyiqa/train.py
@@ -14,8 +14,10 @@
 from pyiqa.utils import (AvgTimer, MessageLogger, check_resume, get_env_info, get_root_logger, get_time_str,
                          init_tb_logger, init_wandb_logger, make_exp_dirs, mkdir_and_rename, scandir, load_file_from_url)
 from pyiqa.utils.options import copy_opt_file, dict2str, parse_options
+from pyiqa.utils.dist_util import master_only
 
 
+@master_only
 def init_tb_loggers(opt):
     # initialize wandb logger before tensorboard logger to allow proper sync
     if (opt['logger'].get('wandb') is not None) and (opt['logger']['wandb'].get('project')
@@ -199,7 +201,7 @@ def train_pipeline(root_path, opt=None, args=None):
                 log_vars.update({'time': iter_timer.get_avg_time(), 'data_time': data_timer.get_avg_time()})
                 log_vars.update(model.get_current_log())
                 msg_logger(log_vars)
-            
+
             # log images
             log_img_freq = opt['logger'].get('log_imgs_freq', 1e99)
             if current_iter % log_img_freq == 0:
@@ -227,10 +229,15 @@
             data_timer.start()
             iter_timer.start()
             train_data = prefetcher.next()
+
+            if 'debug' in opt['name'] and current_iter >= 8:
+                break
         # end of iter
 
         # use epoch based learning rate scheduler
         model.update_learning_rate(epoch+2, warmup_iter=opt['train'].get('warmup_iter', -1))
+        if 'debug' in opt['name'] and epoch >= 2:
+            break
     # end of epoch
 
     consumed_time = str(datetime.timedelta(seconds=int(time.time() - start_time)))
@@ -243,7 +250,8 @@
     if tb_logger:
         tb_logger.close()
-    return model.best_metric_results
+    if opt['rank'] == 0:
+        return model.best_metric_results
 
 
 if __name__ == '__main__':
     root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
diff --git a/pyiqa/utils/options.py b/pyiqa/utils/options.py
index f5a30b7..4f3b051 100644
--- a/pyiqa/utils/options.py
+++ b/pyiqa/utils/options.py
@@ -107,7 +107,7 @@ def parse_options(root_path, is_train=True):
     parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher')
     parser.add_argument('--auto_resume', action='store_true')
     parser.add_argument('--debug', action='store_true')
-    parser.add_argument('--local_rank', type=int, default=0)
+    parser.add_argument('--local-rank', type=int, default=0)
     parser.add_argument(
         '--force_yml', nargs='+', default=None, help='Force to update yml files. Examples: train:ema_decay=0.999')
     args = parser.parse_args()
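Note on the distributed-training changes above: `init_tb_loggers` is now decorated with `master_only`, and `model.best_metric_results` is only returned when `opt['rank'] == 0`, so in a DDP run only the rank-0 process creates the tensorboard/wandb loggers and reports the best metrics (other ranks return `None`). For reference, below is a minimal sketch of a `master_only`-style decorator in the BasicSR tradition; it illustrates the expected behaviour and is not necessarily the exact code in `pyiqa/utils/dist_util.py`.

```python
import functools

import torch.distributed as dist


def get_rank() -> int:
    """Return the current process rank, or 0 when torch.distributed is not initialized."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0


def master_only(func):
    """Run the wrapped function only on the master (rank-0) process.

    Non-master ranks skip the call and get None, which is why a function like
    init_tb_loggers can be decorated without extra rank checks at the call site.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if get_rank() == 0:
            return func(*args, **kwargs)
        return None

    return wrapper
```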