Skip to content

Commit

Permalink
fix: 📝 update distributed train
Browse files Browse the repository at this point in the history
  • Loading branch information
chaofengc committed Apr 25, 2024
1 parent 7a006e3 commit f1c7005
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 9 deletions.
7 changes: 1 addition & 6 deletions options/train/CLIPIQA/train_CLIPIQA_koniq10k.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,7 @@
# name: debug_DBCNN_LIVEC
name: 002_CLIPIQA_ViT-L14_KonIQ10k
name: 003_CLIPIQA_RN50_KonIQ10k
# name: debug_CLIPIQA
model_type: GeneralIQAModel
num_gpu: 1 # set num_gpu: 0 for cpu mode
manual_seed: 123

# name: debug_CLIPIQA
name: debug_CLIPIQA
model_type: GeneralIQAModel
num_gpu: 1 # set num_gpu: 0 for cpu mode
manual_seed: 123
Expand Down
12 changes: 10 additions & 2 deletions pyiqa/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
from pyiqa.utils import (AvgTimer, MessageLogger, check_resume, get_env_info, get_root_logger, get_time_str,
init_tb_logger, init_wandb_logger, make_exp_dirs, mkdir_and_rename, scandir, load_file_from_url)
from pyiqa.utils.options import copy_opt_file, dict2str, parse_options
from pyiqa.utils.dist_util import master_only


@master_only
def init_tb_loggers(opt):
# initialize wandb logger before tensorboard logger to allow proper sync
if (opt['logger'].get('wandb') is not None) and (opt['logger']['wandb'].get('project')
Expand Down Expand Up @@ -199,7 +201,7 @@ def train_pipeline(root_path, opt=None, args=None):
log_vars.update({'time': iter_timer.get_avg_time(), 'data_time': data_timer.get_avg_time()})
log_vars.update(model.get_current_log())
msg_logger(log_vars)

# log images
log_img_freq = opt['logger'].get('log_imgs_freq', 1e99)
if current_iter % log_img_freq == 0:
Expand Down Expand Up @@ -227,10 +229,15 @@ def train_pipeline(root_path, opt=None, args=None):
data_timer.start()
iter_timer.start()
train_data = prefetcher.next()

if 'debug' in opt['name'] and current_iter >= 8:
break
# end of iter
# use epoch based learning rate scheduler
model.update_learning_rate(epoch+2, warmup_iter=opt['train'].get('warmup_iter', -1))

if 'debug' in opt['name'] and epoch >= 2:
break
# end of epoch

consumed_time = str(datetime.timedelta(seconds=int(time.time() - start_time)))
Expand All @@ -243,7 +250,8 @@ def train_pipeline(root_path, opt=None, args=None):
if tb_logger:
tb_logger.close()

return model.best_metric_results
if opt['rank'] == 0:
return model.best_metric_results

if __name__ == '__main__':
root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
Expand Down
2 changes: 1 addition & 1 deletion pyiqa/utils/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def parse_options(root_path, is_train=True):
parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher')
parser.add_argument('--auto_resume', action='store_true')
parser.add_argument('--debug', action='store_true')
parser.add_argument('--local_rank', type=int, default=0)
parser.add_argument('--local-rank', type=int, default=0)
parser.add_argument(
'--force_yml', nargs='+', default=None, help='Force to update yml files. Examples: train:ema_decay=0.999')
args = parser.parse_args()
Expand Down

0 comments on commit f1c7005

Please sign in to comment.