
DDP Dataloader shuffle order not synchronized between different DDP workers #2268

Closed
Teoge opened this issue Dec 20, 2023 · 3 comments · Fixed by #2319

Comments


Teoge commented Dec 20, 2023

System Info

accelerate version: 0.25.0
torch version: 1.12.1
accelerate's configuration:

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: MULTI_GPU
downcast_bf16: 'no'
gpu_ids: all
machine_rank: 0
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
  • My own task or dataset (give details below)

Reproduction

from accelerate import Accelerator
from torch.utils.data import DataLoader
import time

accelerator = Accelerator()

# Toy dataset: 24 integers, shuffled, one sample per batch
dataloader = DataLoader(list(range(24)), shuffle=True, batch_size=1)

# prepare() shards the dataloader so that each DDP worker should see a disjoint subset
dataloader = accelerator.prepare(dataloader)

for batch in dataloader:
    print(batch)
    time.sleep(1)
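(The launch command is not given in the report; presumably something like accelerate launch repro.py with the configuration above, i.e. num_processes: 8, where repro.py is a hypothetical name for the snippet.)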

Expected behavior

I expect the dataset to be distributed across all DDP workers without replacement, but the results show that the shuffle order differs between DDP workers, leading to repeated sampling: for example, "9" is sampled three times and "3" is sampled twice in the output below (a gather-based check is sketched after it).

tensor([1], device='cuda:7')
tensor([9], device='cuda:3')
tensor([15], device='cuda:5')
tensor([3], device='cuda:1')
tensor([23], device='cuda:6')
tensor([16], device='cuda:2')
tensor([3], device='cuda:0')
tensor([14], device='cuda:4')
tensor([4], device='cuda:7')
tensor([12], device='cuda:3')
tensor([10], device='cuda:5')
tensor([0], device='cuda:1')
tensor([20], device='cuda:6')
tensor([13], device='cuda:2')
tensor([4], device='cuda:0')
tensor([19], device='cuda:4')
tensor([14], device='cuda:7')
tensor([7], device='cuda:3')
tensor([9], device='cuda:5')
tensor([9], device='cuda:1')
tensor([10], device='cuda:6')
tensor([10], device='cuda:2')
tensor([5], device='cuda:0')
tensor([8], device='cuda:4')
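
A minimal sketch (not part of the original report) that makes the duplication easier to spot: gather the samples seen by every process and count them on the main process. It reuses the 24-element toy dataset from the reproduction above, so every process runs the same number of steps and accelerator.gather() stays in sync.

from collections import Counter

import torch
from accelerate import Accelerator
from torch.utils.data import DataLoader

accelerator = Accelerator()
dataloader = accelerator.prepare(DataLoader(list(range(24)), shuffle=True, batch_size=1))

seen = []
for batch in dataloader:
    # gather() concatenates the per-process batches across all processes
    seen.append(accelerator.gather(batch))

if accelerator.is_main_process:
    counts = Counter(torch.cat(seen).tolist())
    duplicates = {sample: n for sample, n in counts.items() if n > 1}
    # With correct sharding this prints an empty dict; with the bug it lists the repeats
    print("duplicated samples:", duplicates)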
Teoge changed the title from "DDP Dataloader shuffle order not synchronize between different devices" to "DDP Dataloader shuffle order not synchronize between different ddp workers" on Dec 20, 2023

SunMarc commented Dec 20, 2023

Hi @Teoge, thanks for reporting. I'm unable to reproduce the error. I'm using 5 processes with 25 elements.

tensor([2], device='cuda:4')
tensor([19], device='cuda:3')
tensor([18], device='cuda:4')
tensor([9], device='cuda:3')
tensor([11], device='cuda:4')
tensor([23], device='cuda:3')
tensor([4], device='cuda:4')
tensor([20], device='cuda:3')
tensor([0], device='cuda:4')
tensor([15], device='cuda:3')
tensor([3], device='cuda:2')
tensor([1], device='cuda:2')
tensor([14], device='cuda:2')
tensor([5], device='cuda:2')
tensor([13], device='cuda:2')
tensor([21], device='cuda:1')
tensor([10], device='cuda:0')
tensor([24], device='cuda:0')
tensor([7], device='cuda:0')
tensor([12], device='cuda:0')
tensor([22], device='cuda:0')
tensor([8], device='cuda:1')
tensor([16], device='cuda:1')
tensor([6], device='cuda:1')
tensor([17], device='cuda:1')

Do you know what might be happening, @muellerzr?


Teoge commented Dec 27, 2023

I am able to reproduce it on another machine, and it can only be reproduced with 0.25.0, not with 0.24.0 or lower versions.


SunMarc commented Dec 27, 2023

Thanks for testing again @Teoge! I can indeed reproduce it now; I was probably on another version when I first tested. To add more details, it seems that it also only happens when shuffle=True. We will fix this ASAP. cc @muellerzr

EDIT: I was able to find the PR #2126 that caused this. I also noticed that setting the seed resolves the issue, but that's not intuitive at all:

from accelerate.utils import set_seed
set_seed(42)
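
For reference, here is a sketch of that workaround applied to the reproduction script above; the seed value and placing the call before the dataloader is built are assumptions, not details from the comment.

from accelerate import Accelerator
from accelerate.utils import set_seed
from torch.utils.data import DataLoader

# Seeding every process identically keeps the shuffle order in sync,
# so the prepared dataloader shards the dataset without overlap (assumed placement).
set_seed(42)

accelerator = Accelerator()
dataloader = accelerator.prepare(DataLoader(list(range(24)), shuffle=True, batch_size=1))

for batch in dataloader:
    print(batch)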
