From 2c7a8af32f8d81069f40d87a368b23f26fbc3887 Mon Sep 17 00:00:00 2001
From: Yucheng Zhao
Date: Fri, 21 Oct 2022 22:00:13 +0800
Subject: [PATCH] [Fix] Fix bugs about cfg.gpu_ids in distributed training (#745)

Co-authored-by: Yucheng Zhao
---
 tools/train.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/tools/train.py b/tools/train.py
index 0bc503460..269820cd6 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -10,7 +10,7 @@
 import torch
 import torch.distributed as dist
 from mmcv import Config, DictAction
-from mmcv.runner import init_dist
+from mmcv.runner import get_dist_info, init_dist
 from mmdet.apis import set_random_seed
 
 from mmtrack import __version__
@@ -135,6 +135,9 @@ def main():
     else:
         distributed = True
         init_dist(args.launcher, **cfg.dist_params)
+        # gpu_ids is used to calculate iter when resuming checkpoint,
+        _, world_size = get_dist_info()
+        cfg.gpu_ids = range(world_size)
 
     # create work_dir
     mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
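
Note: the snippet below is a minimal standalone sketch of what the patch changes, not part of the patch itself. It assumes an MMCV-style setup launched by a distributed launcher (e.g. torchrun) that sets RANK / WORLD_SIZE / MASTER_ADDR / MASTER_PORT; the `cfg` object here is a stand-in for the Config used in tools/train.py.

    # Sketch of the fixed distributed branch, under the assumptions above.
    from mmcv import Config
    from mmcv.runner import get_dist_info, init_dist

    cfg = Config(dict(dist_params=dict(backend='nccl')))  # stand-in config

    init_dist('pytorch', **cfg.dist_params)   # initialize torch.distributed
    rank, world_size = get_dist_info()        # valid once dist is initialized

    # Before the fix, cfg.gpu_ids kept its single-GPU default in the
    # distributed branch; setting it from world_size makes len(cfg.gpu_ids)
    # match the real number of processes, which the resume logic uses to
    # calculate the starting iteration.
    cfg.gpu_ids = range(world_size)
    print(f'rank {rank}: len(cfg.gpu_ids) = {len(cfg.gpu_ids)}')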