[Enhance] Support sync random seed for distributed sampler (#57)
* [Docs] update batch size

* add sync seed

* add sync seed

* update comments
linyq17 committed Mar 30, 2022
1 parent 1c4c270 commit f871f3c
Showing 3 changed files with 81 additions and 9 deletions.
3 changes: 2 additions & 1 deletion mmfewshot/utils/__init__.py
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .collate import multi_pipeline_collate_fn
from .dist_utils import check_dist_init, sync_random_seed
from .infinite_sampler import (DistributedInfiniteGroupSampler,
                               DistributedInfiniteSampler,
                               InfiniteGroupSampler, InfiniteSampler)
@@ -11,5 +12,5 @@
'multi_pipeline_collate_fn', 'local_numpy_seed',
'InfiniteEpochBasedRunner', 'InfiniteSampler', 'InfiniteGroupSampler',
'DistributedInfiniteSampler', 'DistributedInfiniteGroupSampler',
'get_root_logger'
'get_root_logger', 'check_dist_init', 'sync_random_seed'
]
45 changes: 45 additions & 0 deletions mmfewshot/utils/dist_utils.py
@@ -0,0 +1,45 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.distributed as dist
from mmcv.runner import get_dist_info


def check_dist_init():
    """Return True if torch.distributed is available and initialized."""
    return dist.is_available() and dist.is_initialized()


def sync_random_seed(seed=None, device='cuda'):
"""Propagating the seed of rank 0 to all other ranks.
Make sure different ranks share the same seed. All workers must call
this function, otherwise it will deadlock. This method is generally used in
`DistributedSampler`, because the seed should be identical across all
processes in the distributed group.
In distributed sampling, different ranks should sample non-overlapped
data in the dataset. Therefore, this function is used to make sure that
each rank shuffles the data indices in the same order based
on the same seed. Then different ranks could use different indices
to select non-overlapped data from the same data list.
Args:
seed (int, Optional): The seed. Default to None.
device (str): The device where the seed will be put on.
Default to 'cuda'.
Returns:
int: Seed to be used.
"""
if seed is None:
seed = np.random.randint(2**31)
assert isinstance(seed, int)

rank, world_size = get_dist_info()

if world_size == 1:
return seed

if rank == 0:
random_num = torch.tensor(seed, dtype=torch.int32, device=device)
else:
random_num = torch.tensor(0, dtype=torch.int32, device=device)
dist.broadcast(random_num, src=0)
return random_num.item()
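
A minimal usage sketch (not part of this commit) of how sync_random_seed is typically consumed by a distributed sampler: every rank calls it, rank 0's seed is broadcast to the rest, and all ranks therefore build the same permutation before taking disjoint strided slices of it. The _ToySampler class below is a hypothetical illustration; only sync_random_seed and get_dist_info come from the libraries touched here.

import torch
from mmcv.runner import get_dist_info

from mmfewshot.utils import sync_random_seed


class _ToySampler:
    """Hypothetical sampler showing the intended use of sync_random_seed."""

    def __init__(self, dataset, seed=None):
        self.rank, self.num_replicas = get_dist_info()
        self.dataset = dataset
        # Identical on every rank after the broadcast from rank 0.
        self.seed = sync_random_seed(seed)
        self.epoch = 0

    def __iter__(self):
        g = torch.Generator()
        # Same seed (plus epoch offset) on every rank -> same permutation.
        g.manual_seed(self.seed + self.epoch)
        indices = torch.randperm(len(self.dataset), generator=g).tolist()
        # Each rank keeps a strided, non-overlapping subset of the indices.
        return iter(indices[self.rank::self.num_replicas])
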
42 changes: 34 additions & 8 deletions mmfewshot/utils/infinite_sampler.py
@@ -8,6 +8,8 @@
from mmcv.runner import get_dist_info
from torch.utils.data.sampler import Sampler

from .dist_utils import sync_random_seed


class InfiniteSampler(Sampler):
"""Return a infinite stream of index.
@@ -28,7 +30,13 @@ def __init__(self,
seed: int = 0,
shuffle: bool = True) -> None:
self.dataset = dataset
self.seed = seed if seed is not None else 0
        # In distributed sampling, different ranks should sample
        # non-overlapping data from the dataset. Therefore, sync_random_seed
        # is used to make sure that each rank shuffles the data indices in
        # the same order based on the same seed. Then different ranks can
        # use different indices to select non-overlapping data from the
        # same data list.
self.seed = sync_random_seed(seed)
self.shuffle = shuffle
self.size = len(dataset)
self.indices = self._indices()
@@ -37,7 +45,7 @@ def __init__(self,
def _infinite_indices(self) -> Iterator:
"""Infinitely yield a sequence of indices."""
g = torch.Generator()
g.manual_seed(self.seed)
g.manual_seed(self.seed + self.epoch)
while True:
if self.shuffle:
yield from torch.randperm(self.size, generator=g).tolist()
@@ -89,7 +97,13 @@ def __init__(self,
shuffle: bool = True) -> None:
self.dataset = dataset
self.samples_per_gpu = samples_per_gpu
self.seed = seed if seed is not None else 0
        # In distributed sampling, different ranks should sample
        # non-overlapping data from the dataset. Therefore, sync_random_seed
        # is used to make sure that each rank shuffles the data indices in
        # the same order based on the same seed. Then different ranks can
        # use different indices to select non-overlapping data from the
        # same data list.
self.seed = sync_random_seed(seed)
self.shuffle = shuffle

assert hasattr(self.dataset, 'flag')
@@ -105,7 +119,7 @@ def __init__(self,
def _infinite_indices(self) -> Iterator:
"""Infinitely yield a sequence of indices."""
g = torch.Generator()
g.manual_seed(self.seed)
g.manual_seed(self.seed + self.epoch)
while True:
if self.shuffle:
yield from torch.randperm(self.size, generator=g).tolist()
@@ -168,7 +182,13 @@ def __init__(self,
self.rank = rank
self.num_replicas = num_replicas
self.dataset = dataset
self.seed = seed if seed is not None else 0
        # In distributed sampling, different ranks should sample
        # non-overlapping data from the dataset. Therefore, sync_random_seed
        # is used to make sure that each rank shuffles the data indices in
        # the same order based on the same seed. Then different ranks can
        # use different indices to select non-overlapping data from the
        # same data list.
self.seed = sync_random_seed(seed)
self.shuffle = shuffle
self.size = len(dataset)
self.indices = self._indices_of_rank()
@@ -177,7 +197,7 @@ def __init__(self,
def _infinite_indices(self) -> Iterator:
"""Infinitely yield a sequence of indices."""
g = torch.Generator()
g.manual_seed(self.seed)
g.manual_seed(self.seed + self.epoch)
while True:
if self.shuffle:
indices = []
@@ -244,7 +264,13 @@ def __init__(self,
self.num_replicas = num_replicas
self.dataset = dataset
self.samples_per_gpu = samples_per_gpu
self.seed = seed if seed is not None else 0
        # In distributed sampling, different ranks should sample
        # non-overlapping data from the dataset. Therefore, sync_random_seed
        # is used to make sure that each rank shuffles the data indices in
        # the same order based on the same seed. Then different ranks can
        # use different indices to select non-overlapping data from the
        # same data list.
self.seed = sync_random_seed(seed)
self.shuffle = shuffle

assert hasattr(self.dataset, 'flag')
Expand All @@ -260,7 +286,7 @@ def __init__(self,
def _infinite_indices(self) -> Iterator:
"""Infinitely yield a sequence of indices."""
g = torch.Generator()
g.manual_seed(self.seed)
g.manual_seed(self.seed + self.epoch)
while True:
if self.shuffle:
indices = []
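
A small self-contained sketch (hypothetical, not from the commit) of why the synced seed matters for the samplers above: when every rank shuffles with the same seed, the per-rank strided slices partition the dataset; with unsynced seeds the slices can overlap and some samples may never be drawn. The rank_indices helper is an illustration only.

import torch


def rank_indices(seed, rank, world_size, size=8):
    # Mirror the samplers' pattern: shuffle with a seeded generator,
    # then let each rank take a strided slice of the permutation.
    g = torch.Generator()
    g.manual_seed(seed)
    perm = torch.randperm(size, generator=g).tolist()
    return perm[rank::world_size]


world_size = 2
# Synced seed: the two slices are disjoint and together cover all 8 samples.
synced = [rank_indices(0, r, world_size) for r in range(world_size)]
print(synced, sorted(synced[0] + synced[1]))  # second list is [0, 1, ..., 7]

# Unsynced seeds: slices can overlap and miss samples.
unsynced = [rank_indices(r, r, world_size) for r in range(world_size)]
print(unsynced, sorted(unsynced[0] + unsynced[1]))
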
