
Fix broadcast deadlock for incomplete batches in data sample for data analysis #5117

Conversation

@bm-synth (Contributor) commented Feb 12, 2024

When the batch is not a full batch (`drop_last=False`), the size of the current batch is smaller than expected:

self.global_batch_size = self.micro_batch_times_data_parallel_size * self.gradient_accumulation_steps

The `get_next_global_batch()` method then tries to broadcast a tensor smaller than `self.global_batch_size` from the master rank (`0`), i.e. the master rank sends a shorter tensor than the other ranks expect. This leads to unexpected behaviour (a deadlock, a crash, or a `None` tensor on the receiving ranks). The documentation for the [broadcast](https://pytorch.org/docs/stable/distributed.html#torch.distributed.broadcast) operation states that the "tensor must have the same number of elements in all processes participating in the collective." In the following call, `tensor` can have a different size on the master rank than on the other participating ranks. File `deepspeed/runtime/data_pipeline/data_sampling/data_sampler.py`, line `289`:

dist.broadcast(batch, 0, group=self.data_parallel_group)

This PR fixes that bug by padding incomplete batches with `-1` indices, so that the `batch` tensor always has the same size.

Note: an alternative resolution is to broadcast the size of the batch tensor beforehand, but that adds an extra communication step. The current method of padding the `batch` tensor with `-1`s is memory-safe, as the padded tensor matches the size used in previous iterations with a full batch.
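For illustration, a minimal sketch of this padding approach (a hypothetical helper, not the actual DeepSpeed code; the function name, `device` argument, and `torch.long` dtype are assumptions):

```
import torch
import torch.distributed as dist

def broadcast_global_batch(batch_indices, global_batch_size, data_parallel_group, device):
    # Allocate a fixed-size tensor filled with -1 so every rank broadcasts the
    # same number of elements, regardless of whether the last batch is full.
    batch = torch.full((global_batch_size,), -1, dtype=torch.long, device=device)
    if dist.get_rank() == 0:
        # Only the master rank knows the real (possibly shorter) list of indices.
        batch[:len(batch_indices)] = torch.as_tensor(batch_indices, dtype=torch.long, device=device)
    # Same collective as in data_sampler.py, but now all ranks pass an equally sized tensor.
    dist.broadcast(batch, 0, group=data_parallel_group)
    # Drop the -1 padding to recover the actual batch indices on every rank.
    return batch[batch >= 0]
```

Receiving ranks simply filter out the `-1` entries, which is what makes the padded broadcast safe for the final, shorter batch.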

@bm-synth bm-synth marked this pull request as ready for review February 12, 2024 14:45
@conglongli conglongli added this pull request to the merge queue Feb 15, 2024
Merged via the queue into microsoft:master with commit 2b41110 Feb 15, 2024
12 checks passed
@bm-synth bm-synth deleted the fix_broadcast_deadlock_for_incomplete_batches branch February 15, 2024 14:37
mauryaavinash95 pushed a commit to mauryaavinash95/DeepSpeed that referenced this pull request Feb 17, 2024
… analysis (microsoft#5117)

rraminen pushed a commit to ROCm/DeepSpeed that referenced this pull request May 9, 2024
… analysis (microsoft#5117)
