Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG] partition_balanced returns wrong result. #4312

Merged
merged 10 commits on Dec 8, 2023
99 changes: 37 additions & 62 deletions deepspeed/runtime/utils.py
Expand Up @@ -14,7 +14,6 @@
import psutil
import gc
from math import sqrt
from bisect import bisect_left
from packaging import version as pkg_version

import torch
Expand Down Expand Up @@ -570,67 +569,43 @@ def partition_uniform(num_items, num_parts):
return parts


def _lprobe(weights, num_parts, bottleneck):
num_items = len(weights)
total_weight = weights[-1]

# initialize partitioning
parts = [0] * (num_parts + 1)
for p in range(1, num_parts + 1):
parts[p] = num_items

bsum = bottleneck # running sum of target weight for pth partition
chunksize = num_items // num_parts
step = chunksize
for p in range(1, num_parts):
# Jump to the next bucket
while (step < num_items) and (weights[step] < bsum):
step += chunksize

# Find the end index of partition p
parts[p] = bisect_left(weights, bsum, lo=step - chunksize, hi=min(step, num_items))
# Nothing more to partition, return early
if parts[p] == num_items:
# See if the current partition is overweight.
part_size = weights[-1] - weights[parts[p - 1]]
return parts, part_size < bottleneck

# Next partition target
bsum = weights[parts[p] - 1] + bottleneck

return parts, bsum >= total_weight


def _rb_partition_balanced(weights, num_parts, eps):
    """Binary-search the smallest bottleneck that ``_lprobe`` accepts.

    Args:
        weights: cumulative item weights; the last entry is the total weight.
        num_parts: number of partitions to form.
        eps: search tolerance; the loop stops once the bracket is this narrow.

    Returns:
        The upper end of the final search bracket, i.e. the smallest
        bottleneck value the search found to be feasible (or the total weight
        if no probe succeeded).
    """
    hi = weights[-1]  # worst case: one partition holds everything
    lo = hi / num_parts  # best case: perfectly even split

    while hi > lo + eps:
        candidate = lo + (hi - lo) / 2
        feasible = _lprobe(weights, num_parts, candidate)[1]
        if feasible:
            hi = candidate
        else:
            lo = candidate + eps

    return hi


def partition_balanced(weights, num_parts, eps=1e-3):
    """Partition ``weights`` into ``num_parts`` contiguous, balanced parts.

    Args:
        weights: per-item weights.
        num_parts: number of partitions to form.
        eps: tolerance passed to the bottleneck binary search.

    Returns:
        The partition boundaries: the result of ``partition_uniform`` when
        there are no more items than partitions, otherwise the boundary list
        produced by ``_lprobe`` for the smallest bottleneck found.
    """
    num_items = len(weights)
    # First check for the trivial edge case
    if num_items <= num_parts:
        return partition_uniform(num_items, num_parts)

    weights_ = prefix_sum_inc(weights)

    # Find the smallest bottleneck (weight of heaviest partition)
    bottleneck = _rb_partition_balanced(weights_, num_parts, eps=eps)

    # Now compute that partitioning
    parts, success = _lprobe(weights_, num_parts, bottleneck)
    assert success
    # Hand the computed boundaries back to the caller; without this the
    # function would return None on the non-trivial path.
    return parts
def partition_balanced(weights, num_parts):
    """Split ``weights`` into ``num_parts`` contiguous, balanced partitions.

    Uses dynamic programming in the style of the linear partition problem, see
    https://www8.cs.umu.se/kurser/TDBAfl/VT06/algorithms/BOOK/BOOK2/NODE45.HTM
    The cost tracked for every state is the spread between the heaviest and
    the lightest partition (``max - min``). The three nested loops run in
    O(n^2 * m) steps for ``n`` items and ``m`` partitions.

    Args:
        weights: per-item weights.
        num_parts: number of partitions to form.

    Returns:
        A list of ``num_parts + 1`` boundary indices as plain Python ints;
        partition ``p`` covers ``weights[parts[p]:parts[p + 1]]``. When there
        are no more items than partitions the result of ``partition_uniform``
        is returned instead.
    """
    import numpy as np

    n = len(weights)
    m = num_parts

    if n <= m:
        return partition_uniform(n, m)

    # State (i, j): the first i items split into j partitions.
    dp_max = np.full((n + 1, m + 1), np.inf)  # heaviest partition so far
    dp_min = np.full((n + 1, m + 1), np.inf)  # lightest partition so far
    dp_cost = np.full((n + 1, m + 1), np.inf)  # dp_max - dp_min
    position = np.zeros((n + 1, m + 1), dtype=int)  # start of the last partition
    prefix_sum = np.zeros((n + 1))
    prefix_sum[1:] = np.cumsum(weights)

    dp_max[0, 0] = 0
    dp_cost[0, 0] = 0
    for i in range(1, n + 1):
        for j in range(1, min(i, m) + 1):
            for k in range(i):
                # Weight of the candidate last partition, items k..i-1.
                segment = prefix_sum[i] - prefix_sum[k]
                max_sum = max(dp_max[k, j - 1], segment)
                min_sum = min(dp_min[k, j - 1], segment)
                cost = max_sum - min_sum
                # ">=" keeps the largest k among equally good splits.
                if dp_cost[i, j] >= cost:
                    dp_cost[i, j] = cost
                    dp_max[i, j] = max_sum
                    dp_min[i, j] = min_sum
                    position[i, j] = k

    # Walk the recorded split points back from the full problem (n, m).
    parts = [n]
    for j in range(m, 0, -1):
        # Convert to a built-in int so callers never see numpy scalars.
        parts.append(int(position[parts[-1], j]))
    parts.reverse()

    return parts

Expand Down
25 changes: 25 additions & 0 deletions tests/unit/utils/test_partition_balanced.py
@@ -0,0 +1,25 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

from deepspeed.runtime import utils as ds_utils


def check_partition(weights, num_parts, target_diff):
    """Assert that the balanced partition of ``weights`` has the given spread.

    Calls ``ds_utils.partition_balanced``, sums the weights of each resulting
    slice, and checks that the heaviest slice minus the lightest slice equals
    ``target_diff``.
    """
    result = ds_utils.partition_balanced(weights=weights, num_parts=num_parts)

    # Consecutive boundary pairs delimit one partition each.
    parts_sum = [sum(weights[start:stop]) for start, stop in zip(result[:-1], result[1:])]
    spread = max(parts_sum) - min(parts_sum)

    assert spread == target_diff, f"ds_utils.partition_balanced(weights={weights}, num_parts={num_parts}) return {result}"


def test_partition_balanced():
    """Run ``check_partition`` over a table of weights and expected spreads."""
    # (weights, num_parts, expected max-min spread between partitions)
    cases = (
        ([1, 2, 1], 4, 2),
        ([1, 1, 1, 1], 4, 0),
        ([1, 1, 1, 1, 1], 4, 1),
        ([1, 1, 1, 1, 0, 1], 4, 1),
    )
    for weights, num_parts, expected_diff in cases:
        check_partition(weights, num_parts, target_diff=expected_diff)