[Bug] DataLoader always returns batches in the same order (expected random order) #2157

Closed

allanj opened this issue Nov 16, 2023 · 8 comments

allanj commented Nov 16, 2023

System Info

`accelerate==0.24.1`

The error does not happen with accelerate `0.23.0`.

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
  • My own task or dataset (give details below)

Reproduction

Using the following example script:

import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
from accelerate import Accelerator


class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


# Example data: 10 samples, 5 features each
data = np.random.randn(10, 5)

# Convert to a PyTorch tensor
data = torch.tensor(data, dtype=torch.float32)

accelerator = Accelerator()
dataset = CustomDataset(data)

# Create the data loader with shuffling enabled
data_loader = DataLoader(dataset, batch_size=3, shuffle=True)
data_loader = accelerator.prepare(data_loader)

epochs = 2
for ep in range(epochs):
    for batch_idx, batch in enumerate(data_loader):
        if accelerator.is_main_process:
            print(batch)

    if accelerator.is_main_process:
        print("end batch")

Run the above in a distributed environment by running:

accelerate launch --main_process_port 9205 example.py

Expected behavior

We expect different epochs to iterate over the batches in a different order (and with different batch compositions).

But currently the two epochs produce exactly the same batches in the same order:

tensor([[ 0.7063, -0.0190, -2.6028,  0.7356,  0.9286],
        [ 0.1154, -1.0464, -0.0675, -0.3442,  0.9731],
        [ 0.8266, -1.0352, -0.6741, -0.8682, -0.5972]], device='cuda:0')
tensor([[ 0.8992, -1.0027,  0.1585,  1.6470, -1.2425],
        [-0.3033,  0.3626, -0.8370,  0.0797,  0.5549],
        [-1.3101,  0.2919, -0.5106,  0.7952,  2.1747]], device='cuda:0')
end batch
tensor([[ 0.7063, -0.0190, -2.6028,  0.7356,  0.9286],
        [ 0.1154, -1.0464, -0.0675, -0.3442,  0.9731],
        [ 0.8266, -1.0352, -0.6741, -0.8682, -0.5972]], device='cuda:0')
tensor([[ 0.8992, -1.0027,  0.1585,  1.6470, -1.2425],
        [-0.3033,  0.3626, -0.8370,  0.0797,  0.5549],
        [-1.3101,  0.2919, -0.5106,  0.7952,  2.1747]], device='cuda:0')
end batch
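
For comparison, here is a minimal standalone sketch (not part of the original report) of the expected behaviour with a plain PyTorch DataLoader and no Accelerate wrapping: with shuffle=True, the batch order should differ between epochs.

import torch
from torch.utils.data import DataLoader, TensorDataset

# Ten easily identifiable samples so the shuffled order is visible in the output.
data = torch.arange(10).float().unsqueeze(1)
loader = DataLoader(TensorDataset(data), batch_size=3, shuffle=True)

for ep in range(2):
    # Each pass over the loader should visit the samples in a new random order.
    order = [batch[0].squeeze(1).tolist() for batch in loader]
    print(f"epoch {ep}: {order}")

The report above is that, after accelerator.prepare, the two epochs print identical output instead.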

Quick fixes that we tried:

  1. The data sampler seems to be the problem; calling the following before each epoch restores per-epoch shuffling (see the sketch after this list):

     data_loader.batch_sampler.batch_sampler.sampler.set_epoch(ep)

  2. Downgrading to accelerate 0.23.0 also works.
  3. Simply commenting out the following code in accelerate also works:

     if isinstance(sampler, RandomSampler):
         # When iterating through the dataloader during distributed processes
         # we want to ensure that on each process we are iterating through the same
         # samples in the same order if a seed is set. This requires a tweak
         # to the `torch.utils.data.RandomSampler` class (if used).
         sampler = SeedableRandomSampler(
             data_source=sampler.data_source,
             replacement=sampler.replacement,
             num_samples=sampler._num_samples,
             generator=getattr(sampler, "generator", torch.Generator()),
         )
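
Below is a minimal sketch of how workaround 1 slots into the reproduction script above (not verbatim from our runs). It assumes the data_loader, epochs, and accelerator objects defined there, and guards the set_epoch call since the wrapper nesting around the sampler is version-dependent.

for ep in range(epochs):
    # Attribute path taken from workaround 1 above; it may differ across accelerate versions.
    sampler = data_loader.batch_sampler.batch_sampler.sampler
    if hasattr(sampler, "set_epoch"):
        sampler.set_epoch(ep)  # re-seeds the shuffle for this epoch
    for batch in data_loader:
        if accelerator.is_main_process:
            print(batch)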

allanj (Author) commented Nov 16, 2023

It would be great to fix this quickly, as it seems to affect a lot of previous experiments that used the newer version of accelerate.

hyang0511 commented Nov 16, 2023

I have the same issue with accelerate version 0.24.1.
It would be great if this could be fixed ASAP.

muellerzr self-assigned this Nov 16, 2023

allanj (Author) commented Nov 18, 2023

I checked the master branch and the output seems fine. Does that mean the master branch fixes it? (Though it may not be an elegant fix.)

muellerzr (Collaborator) commented:

Thanks for the update! We’ll be pushing a new release next week that should include the fix. #2126

ashawkey commented Nov 30, 2023

I can confirm the master branch fixes this problem, while the latest release (0.24.1) still doesn't.
It's quite an important issue, so I hope the fix is released soon.

allanj (Author) commented Nov 30, 2023

Yes. For now I've switched back to 0.23.0, since the fix hasn't made it into a released version yet.

muellerzr (Collaborator) commented:

Sorry for the delay on this; a release will be happening Friday with the fix included!

github-actions bot commented:

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions bot closed this as completed Jan 1, 2024