Skip to content

Commit

Permalink
Merge 70d6b6f into cb7e04d
Browse files Browse the repository at this point in the history
  • Loading branch information
C1rN09 committed Mar 1, 2023
2 parents cb7e04d + 70d6b6f commit 09b696a
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions mmengine/dist/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ def _init_dist_pytorch(backend, **kwargs) -> None:
'nccl', 'gloo' and 'mpi'. Defaults to 'nccl'.
**kwargs: keyword arguments are passed to ``init_process_group``.
"""
# TODO: use local_rank instead of rank % num_gpus
rank = int(os.environ['RANK'])
if is_mlu_available():
import torch_mlu # noqa: F401
Expand All @@ -103,8 +102,9 @@ def _init_dist_pytorch(backend, **kwargs) -> None:
world_size=int(os.environ['WORLD_SIZE']),
**kwargs)
else:
num_gpus = torch.cuda.device_count()
torch.cuda.set_device(rank % num_gpus)
# LOCAL_RANK is set by `torch.distributed.launch` since PyTorch 1.1
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)
torch_dist.init_process_group(backend=backend, **kwargs)


Expand Down Expand Up @@ -151,8 +151,14 @@ def _init_dist_slurm(backend, port=None) -> None:
proc_id = int(os.environ['SLURM_PROCID'])
ntasks = int(os.environ['SLURM_NTASKS'])
node_list = os.environ['SLURM_NODELIST']
num_gpus = torch.cuda.device_count()
torch.cuda.set_device(proc_id % num_gpus)
# Not sure when this environment variable could be None, so use a fallback
local_rank_env = os.environ.get('SLURM_LOCALID', None)
if local_rank_env is not None:
local_rank = int(local_rank_env)
else:
num_gpus = torch.cuda.device_count()
local_rank = proc_id % num_gpus
torch.cuda.set_device(local_rank)
addr = subprocess.getoutput(
f'scontrol show hostname {node_list} | head -n1')
# specify master port
Expand All @@ -167,7 +173,7 @@ def _init_dist_slurm(backend, port=None) -> None:
if 'MASTER_ADDR' not in os.environ:
os.environ['MASTER_ADDR'] = addr
os.environ['WORLD_SIZE'] = str(ntasks)
os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
os.environ['LOCAL_RANK'] = str(local_rank)
os.environ['RANK'] = str(proc_id)
torch_dist.init_process_group(backend=backend)

Expand Down

0 comments on commit 09b696a

Please sign in to comment.