Skip to content

Commit

Permalink
reset worker_seed (#2111)
Browse files Browse the repository at this point in the history
* reset worker_seed

* fix isort

* minor fix

* fix comment
  • Loading branch information
yhcao6 committed Feb 22, 2020
1 parent c47e36a commit 5d75636
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions mmdet/datasets/loader/build_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ def build_dataloader(dataset,
Returns:
DataLoader: A PyTorch dataloader.
"""
rank, world_size = get_dist_info()
if dist:
rank, world_size = get_dist_info()
# DistributedGroupSampler will definitely shuffle the data to satisfy
# that images on each GPU are in the same group
if shuffle:
Expand All @@ -61,6 +61,13 @@ def build_dataloader(dataset,
batch_size = num_gpus * imgs_per_gpu
num_workers = num_gpus * workers_per_gpu

def worker_init_fn(worker_id):
# The seed of each worker equals to
# num_worker * rank + worker_id + user_seed
worker_seed = num_workers * rank + worker_id + seed
np.random.seed(worker_seed)
random.seed(worker_seed)

data_loader = DataLoader(
dataset,
batch_size=batch_size,
Expand All @@ -72,8 +79,3 @@ def build_dataloader(dataset,
**kwargs)

return data_loader


def worker_init_fn(seed):
np.random.seed(seed)
random.seed(seed)

0 comments on commit 5d75636

Please sign in to comment.