
[BUG] partition_balanced returns wrong result. #4312

Merged
merged 10 commits into microsoft:master on Dec 8, 2023

Conversation

zjjMaiMai
Contributor

Background

In pipeline parallelism, DeepSpeed uses `ds_utils.partition_balanced` to partition the model's layers across pipeline stages, balancing either by parameter count or by layer class name.

https://github.com/microsoft/DeepSpeed/blob/581e44dd1ab3c409a5905335867c761d5cb4db5b/deepspeed/runtime/pipe/module.py#L380-L395

```python
if method == 'uniform':
    num_layers = len(self._layer_specs)
    self.parts = ds_utils.partition_uniform(num_items=num_layers, num_parts=num_stages)
elif method == 'parameters':
    param_counts = self._count_layer_params()
    self.parts = ds_utils.partition_balanced(weights=param_counts, num_parts=num_stages)
elif method.startswith('type:'):
    layertype = method.split(':')[1]
    binary_weights = [0] * len(self._layer_specs)
    for idx in self._find_layer_type(layertype):
        binary_weights[idx] = 1
    self.parts = ds_utils.partition_balanced(weights=binary_weights, num_parts=num_stages)
elif method == 'profile':
    raise NotImplementedError(f'Partitioning method {method} not implemented.')
else:
    raise NotImplementedError(f'Partitioning method {method} not implemented.')
```
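
To make the `type:` branch above concrete, here is a standalone sketch (the method string and layer names are made up for illustration; this is not DeepSpeed code) of how the binary weight vector gets built:

```python
# Toy illustration of the 'type:' branch: mark the layers whose class name
# matches the requested type, so partition_balanced can spread the marked
# layers evenly across pipeline stages.
method = 'type:TransformerLayer'
layer_types = ['Embedding', 'TransformerLayer', 'TransformerLayer',
               'TransformerLayer', 'LMHead']

layertype = method.split(':')[1]
binary_weights = [1 if t == layertype else 0 for t in layer_types]
print(binary_weights)  # [0, 1, 1, 1, 0]
```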

What's wrong?

```
>>> import deepspeed
>>> deepspeed.__version__
'0.10.3+542dc0d5'
>>> from deepspeed.runtime import utils as ds_utils
>>> ds_utils.partition_balanced([1, 1, 1, 1, 1], 4)
[0, 2, 4, 5, 5]
```

The result `[0, 2, 4, 5, 5]` means `[2, 2, 1, 0]` layers per part, which is not balanced at all. The last part will throw an exception because it has no parameters to train.
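
To make the boundary encoding explicit, here is a small plain-Python check (not DeepSpeed code) that converts the returned offsets into per-stage layer counts:

```python
# partition_balanced returns stage boundaries: stage i owns the layers in
# the half-open range [parts[i], parts[i + 1]).
parts = [0, 2, 4, 5, 5]
sizes = [parts[i + 1] - parts[i] for i in range(len(parts) - 1)]
print(sizes)  # [2, 2, 1, 0] -- the last pipeline stage gets no layers at all
```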

I added some unit tests for this function, and I will fix it later if anyone needs it.
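
A sketch of the kind of regression test described above (the test name and assertion are illustrative, not the exact tests added in this PR):

```python
from deepspeed.runtime import utils as ds_utils

def test_partition_balanced_no_empty_parts():
    # Five equal-weight layers over four stages: every stage should get work.
    parts = ds_utils.partition_balanced([1, 1, 1, 1, 1], 4)
    sizes = [parts[i + 1] - parts[i] for i in range(len(parts) - 1)]
    assert 0 not in sizes, 'every pipeline stage should receive at least one layer'
```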

@zjjMaiMai
Contributor Author

@microsoft-github-policy-service agree

@zjjMaiMai
Contributor Author

Already fixed! cc @tjruwase

@ShadenSmith
Contributor

Thanks for this PR, @zjjMaiMai!

A note on balance: the objective function in the original code minimizes the maximum load per partition. The maximum load on a pipeline stage determines the pipeline throughput, so the original result is also balanced in that sense. Example paper: *Fast Optimal Load Balancing Algorithms for 1D Partitioning*.
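
A quick plain-Python illustration of that objective, using the weights and boundaries from the example above (the helper below exists only for this example): both results have the same bottleneck, so the original output is optimal under the min-max metric even though it leaves a stage empty.

```python
def bottleneck(weights, parts):
    # The pipeline runs at the speed of its most heavily loaded stage.
    return max(sum(weights[parts[i]:parts[i + 1]]) for i in range(len(parts) - 1))

w = [1, 1, 1, 1, 1]
print(bottleneck(w, [0, 2, 4, 5, 5]))  # 2 -- the reported result, with an empty stage
print(bottleneck(w, [0, 2, 3, 4, 5]))  # 2 -- same bottleneck, no empty stage
```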

@tjruwase tjruwase added this pull request to the merge queue Oct 6, 2023
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks Oct 6, 2023
@tjruwase tjruwase added this pull request to the merge queue Dec 4, 2023
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks Dec 4, 2023
@tjruwase tjruwase added this pull request to the merge queue Dec 8, 2023
Merged via the queue into microsoft:master with commit 2bdf061 Dec 8, 2023
15 checks passed
mauryaavinash95 pushed a commit to mauryaavinash95/DeepSpeed that referenced this pull request Feb 17, 2024