[Fix] add seed to distributed sampler #250

Merged 2 commits on Mar 30, 2022

7 changes: 6 additions & 1 deletion mmselfsup/datasets/builder.py
@@ -108,7 +108,12 @@ def build_dataloader(dataset,
     rank, world_size = get_dist_info()
     if dist:
         sampler = DistributedSampler(
-            dataset, world_size, rank, shuffle=shuffle, replace=replace)
+            dataset,
+            world_size,
+            rank,
+            shuffle=shuffle,
+            replace=replace,
+            seed=seed)
         shuffle = False
         batch_size = samples_per_gpu
         num_workers = workers_per_gpu
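For context on how the new argument reaches the sampler, here is an illustrative sketch (not part of the PR). It assumes `build_dataloader` is importable from `mmselfsup.datasets`, that `dataset` is an already-built dataset object, and that the process group has been initialized by the launcher; only keyword arguments visible in this hunk are used, and all values are placeholders.

```python
# Hypothetical usage sketch: forwarding a seed through build_dataloader.
# `dataset` is a placeholder for an object built elsewhere (e.g. via
# build_dataset(cfg.data.train)); the values below are illustrative.
from mmselfsup.datasets import build_dataloader

data_loader = build_dataloader(
    dataset,
    samples_per_gpu=32,
    workers_per_gpu=2,
    dist=True,      # takes the DistributedSampler branch patched above
    shuffle=True,
    seed=42)        # now forwarded as DistributedSampler(..., seed=42)
```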
19 changes: 17 additions & 2 deletions mmselfsup/datasets/samplers/distributed_sampler.py
@@ -5,6 +5,8 @@
 from torch.utils.data import DistributedSampler as _DistributedSampler
 from torch.utils.data import Sampler
 
+from mmselfsup.utils import sync_random_seed
+
 
 class DistributedSampler(_DistributedSampler):
 
@@ -13,12 +15,21 @@ def __init__(self,
                  num_replicas=None,
                  rank=None,
                  shuffle=True,
-                 replace=False):
+                 replace=False,
+                 seed=0):
         super().__init__(dataset, num_replicas=num_replicas, rank=rank)
         self.shuffle = shuffle
         self.replace = replace
         self.unif_sampling_flag = False
 
+        # In distributed sampling, different ranks should sample
+        # non-overlapped data in the dataset. Therefore, this function
+        # is used to make sure that each rank shuffles the data indices
+        # in the same order based on the same seed. Then different ranks
+        # could use different indices to select non-overlapped data from the
+        # same data list.
+        self.seed = sync_random_seed(seed)
+
     def __iter__(self):
         # deterministically shuffle based on epoch
         if not self.unif_sampling_flag:
@@ -31,7 +42,11 @@ def __iter__(self):
     def generate_new_list(self):
         if self.shuffle:
             g = torch.Generator()
-            g.manual_seed(self.epoch)
+            # When :attr:`shuffle=True`, this ensures all replicas
+            # use a different random ordering for each epoch.
+            # Otherwise, the next iteration of this sampler will
+            # yield the same ordering.
+            g.manual_seed(self.epoch + self.seed)
             if self.replace:
                 indices = torch.randint(
                     low=0,
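The effect of seeding the generator with `epoch + seed` can be seen without any distributed setup. The self-contained sketch below is not mmselfsup code and omits the padding to `total_size` and the `replace=True` branch of the real sampler; it only shows that an identical seed on every rank yields one shared permutation, from which each rank takes a disjoint strided slice (the base `torch.utils.data.DistributedSampler` slices per rank in the same strided fashion).

```python
import torch


def rank_indices(num_samples, world_size, rank, epoch, seed):
    # Same generator state on every rank: epoch + seed is identical everywhere.
    g = torch.Generator()
    g.manual_seed(epoch + seed)
    indices = torch.randperm(num_samples, generator=g).tolist()
    # Each rank keeps a disjoint strided slice of the shared permutation.
    return indices[rank::world_size]


epoch, seed = 3, 42
parts = [rank_indices(10, world_size=2, rank=r, epoch=epoch, seed=seed)
         for r in range(2)]
assert not set(parts[0]) & set(parts[1])  # no sample is seen by both "ranks"
print(parts)
```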
8 changes: 5 additions & 3 deletions mmselfsup/utils/__init__.py
@@ -3,6 +3,7 @@
 from .batch_shuffle import batch_shuffle_ddp, batch_unshuffle_ddp
 from .collect import dist_forward_collect, nondist_forward_collect
 from .collect_env import collect_env
+from .dist_utils import sync_random_seed
 from .distributed_sinkhorn import distributed_sinkhorn
 from .extractor import Extractor
 from .gather import concat_all_gather, gather_tensors, gather_tensors_batch
@@ -14,7 +15,8 @@
 __all__ = [
     'AliasMethod', 'batch_shuffle_ddp', 'batch_unshuffle_ddp',
     'dist_forward_collect', 'nondist_forward_collect', 'collect_env',
-    'distributed_sinkhorn', 'Extractor', 'concat_all_gather', 'gather_tensors',
-    'gather_tensors_batch', 'get_root_logger', 'find_latest_checkpoint',
-    'multi_gpu_test', 'single_gpu_test', 'setup_multi_processes'
+    'sync_random_seed', 'distributed_sinkhorn', 'Extractor',
+    'concat_all_gather', 'gather_tensors', 'gather_tensors_batch',
+    'get_root_logger', 'find_latest_checkpoint', 'multi_gpu_test',
+    'single_gpu_test', 'setup_multi_processes'
 ]
44 changes: 44 additions & 0 deletions mmselfsup/utils/dist_utils.py
@@ -0,0 +1,44 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import torch
+import torch.distributed as dist
+from mmcv.runner import get_dist_info
+
+
+def sync_random_seed(seed=None, device='cuda'):
+    """Make sure different ranks share the same seed. All workers must call
+    this function, otherwise it will deadlock. This method is generally used in
+    `DistributedSampler`, because the seed should be identical across all
+    processes in the distributed group.
+
+    In distributed sampling, different ranks should sample non-overlapped
+    data in the dataset. Therefore, this function is used to make sure that
+    each rank shuffles the data indices in the same order based
+    on the same seed. Then different ranks could use different indices
+    to select non-overlapped data from the same data list.
+
+    Args:
+        seed (int, Optional): The seed. Default to None.
+        device (str): The device where the seed will be put on.
+            Default to 'cuda'.
+    Returns:
+        int: Seed to be used.
+    References:
+        .. [1] https://github.com/open-mmlab/mmdetection
+               /blob/master/mmdet/core/utils/dist_utils.py
+    """
+    if seed is None:
+        seed = np.random.randint(2**31)
+    assert isinstance(seed, int)
+
+    rank, world_size = get_dist_info()
+
+    if world_size == 1:
+        return seed
+
+    if rank == 0:
+        random_num = torch.tensor(seed, dtype=torch.int32, device=device)
+    else:
+        random_num = torch.tensor(0, dtype=torch.int32, device=device)
+    dist.broadcast(random_num, src=0)
+    return random_num.item()
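To see the synchronization behaviour without mmcv or a GPU, here is a minimal, self-contained sketch that reproduces the core of `sync_random_seed` on the gloo backend with CPU tensors. It is an illustration, not the real helper: the actual function additionally short-circuits when `world_size == 1` and reads rank and world size from mmcv's `get_dist_info`, and the master address/port values below are arbitrary placeholders.

```python
import os

import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank, world_size):
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29501')
    dist.init_process_group('gloo', rank=rank, world_size=world_size)

    # Same logic as sync_random_seed, but on CPU so it runs without a GPU:
    # rank 0 draws a random seed, every rank receives it via broadcast.
    seed = np.random.randint(2**31) if rank == 0 else 0
    seed_t = torch.tensor(seed, dtype=torch.int32)
    dist.broadcast(seed_t, src=0)  # all ranks must reach this call
    print(f'rank {rank} uses seed {seed_t.item()}')
    dist.destroy_process_group()


if __name__ == '__main__':
    mp.spawn(worker, args=(2,), nprocs=2)
```

Because every process ends up with the same integer, seeding each rank's sampler with it (plus the epoch) keeps the shuffled index lists aligned across ranks, which is exactly what the comment in `DistributedSampler.__init__` describes.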