[Fix] add seed to distributed sampler (#250)
* [Fix] add seed to distributed sampler

* fix lint
fangyixiao18 authored and YuanLiuuuuuu committed Mar 31, 2022
1 parent c2b1dc4 commit b08271f
Showing 4 changed files with 72 additions and 6 deletions.
7 changes: 6 additions & 1 deletion mmselfsup/datasets/builder.py
@@ -108,7 +108,12 @@ def build_dataloader(dataset,
     rank, world_size = get_dist_info()
     if dist:
         sampler = DistributedSampler(
-            dataset, world_size, rank, shuffle=shuffle, replace=replace)
+            dataset,
+            world_size,
+            rank,
+            shuffle=shuffle,
+            replace=replace,
+            seed=seed)
         shuffle = False
         batch_size = samples_per_gpu
         num_workers = workers_per_gpu
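
For orientation, a minimal usage sketch of how the new argument reaches the sampler. Only dataset, samples_per_gpu, workers_per_gpu, dist, shuffle, replace, and seed are confirmed by the hunk above; train_set and the numeric values are placeholder assumptions.

# Hypothetical usage sketch; `train_set` stands in for an already built
# mmselfsup dataset object, and the numeric values are arbitrary.
from mmselfsup.datasets.builder import build_dataloader

data_loader = build_dataloader(
    train_set,
    samples_per_gpu=32,
    workers_per_gpu=4,
    dist=True,      # takes the DistributedSampler branch shown above
    shuffle=True,
    replace=False,
    seed=0)         # now forwarded to DistributedSampler(..., seed=seed)
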
19 changes: 17 additions & 2 deletions mmselfsup/datasets/samplers/distributed_sampler.py
@@ -5,6 +5,8 @@
 from torch.utils.data import DistributedSampler as _DistributedSampler
 from torch.utils.data import Sampler
 
+from mmselfsup.utils import sync_random_seed
+
 
 class DistributedSampler(_DistributedSampler):

@@ -13,12 +15,21 @@ def __init__(self,
                  num_replicas=None,
                  rank=None,
                  shuffle=True,
-                 replace=False):
+                 replace=False,
+                 seed=0):
         super().__init__(dataset, num_replicas=num_replicas, rank=rank)
         self.shuffle = shuffle
         self.replace = replace
         self.unif_sampling_flag = False
 
+        # In distributed sampling, different ranks should sample
+        # non-overlapping data from the dataset. Therefore, this call
+        # makes sure that each rank shuffles the data indices in the
+        # same order based on the same seed. Then different ranks can
+        # use different indices to select non-overlapping data from
+        # the same data list.
+        self.seed = sync_random_seed(seed)
+
     def __iter__(self):
         # deterministically shuffle based on epoch
         if not self.unif_sampling_flag:
@@ -31,7 +42,11 @@ def __iter__(self):
     def generate_new_list(self):
         if self.shuffle:
             g = torch.Generator()
-            g.manual_seed(self.epoch)
+            # When :attr:`shuffle=True`, this ensures all replicas
+            # use a different random ordering for each epoch.
+            # Otherwise, the next iteration of this sampler will
+            # yield the same ordering.
+            g.manual_seed(self.epoch + self.seed)
             if self.replace:
                 indices = torch.randint(
                     low=0,
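
The reasoning in the comments above can be made concrete with a short, self-contained sketch (not part of the commit): seeding a generator with epoch + seed gives every rank the same permutation in a given epoch, and each rank then keeps a disjoint slice of it. The helper name rank_indices and the stride-based slicing are illustrative assumptions, not the sampler's exact code.

import torch

def rank_indices(num_samples, epoch, seed, rank, world_size):
    """Illustrative only: same permutation on every rank, disjoint slices."""
    g = torch.Generator()
    g.manual_seed(epoch + seed)  # identical generator state on every rank
    perm = torch.randperm(num_samples, generator=g).tolist()
    return perm[rank::world_size]  # each rank keeps its own stride of indices

# With a shared seed, two ranks partition one permutation without overlap:
idx0 = rank_indices(10, epoch=0, seed=42, rank=0, world_size=2)
idx1 = rank_indices(10, epoch=0, seed=42, rank=1, world_size=2)
assert set(idx0).isdisjoint(idx1)
assert sorted(idx0 + idx1) == list(range(10))
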
8 changes: 5 additions & 3 deletions mmselfsup/utils/__init__.py
@@ -3,6 +3,7 @@
 from .batch_shuffle import batch_shuffle_ddp, batch_unshuffle_ddp
 from .collect import dist_forward_collect, nondist_forward_collect
 from .collect_env import collect_env
+from .dist_utils import sync_random_seed
 from .distributed_sinkhorn import distributed_sinkhorn
 from .extractor import Extractor
 from .gather import concat_all_gather, gather_tensors, gather_tensors_batch
@@ -14,7 +15,8 @@
 __all__ = [
     'AliasMethod', 'batch_shuffle_ddp', 'batch_unshuffle_ddp',
     'dist_forward_collect', 'nondist_forward_collect', 'collect_env',
-    'distributed_sinkhorn', 'Extractor', 'concat_all_gather', 'gather_tensors',
-    'gather_tensors_batch', 'get_root_logger', 'find_latest_checkpoint',
-    'multi_gpu_test', 'single_gpu_test', 'setup_multi_processes'
+    'sync_random_seed', 'distributed_sinkhorn', 'Extractor',
+    'concat_all_gather', 'gather_tensors', 'gather_tensors_batch',
+    'get_root_logger', 'find_latest_checkpoint', 'multi_gpu_test',
+    'single_gpu_test', 'setup_multi_processes'
 ]
44 changes: 44 additions & 0 deletions mmselfsup/utils/dist_utils.py
@@ -0,0 +1,44 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.distributed as dist
from mmcv.runner import get_dist_info


def sync_random_seed(seed=None, device='cuda'):
    """Make sure different ranks share the same seed.

    All workers must call this function, otherwise it will deadlock.
    This method is generally used in `DistributedSampler`, because the
    seed should be identical across all processes in the distributed group.

    In distributed sampling, different ranks should sample non-overlapping
    data from the dataset. Therefore, this function is used to make sure
    that each rank shuffles the data indices in the same order based on
    the same seed. Then different ranks can use different indices to
    select non-overlapping data from the same data list.

    Args:
        seed (int, optional): The seed. Defaults to None.
        device (str): The device where the seed will be put on.
            Defaults to 'cuda'.

    Returns:
        int: Seed to be used.

    References:
        .. [1] https://github.com/open-mmlab/mmdetection
               /blob/master/mmdet/core/utils/dist_utils.py
    """
    if seed is None:
        seed = np.random.randint(2**31)
    assert isinstance(seed, int)

    rank, world_size = get_dist_info()

    if world_size == 1:
        return seed

    if rank == 0:
        random_num = torch.tensor(seed, dtype=torch.int32, device=device)
    else:
        random_num = torch.tensor(0, dtype=torch.int32, device=device)
    dist.broadcast(random_num, src=0)
    return random_num.item()
