In [19]:
import numpy as np
import numba

import json

In [27]:
from typing import Optional, List

import torch.distributed as dist
from torch.utils.data import Sampler

import numpy as np
import numba


@numba.njit
def ffd_check(a: np.ndarray, c: int, n: int):
    # Feasibility test: can every item of `a` be placed into exactly `n` bins
    # of capacity `c` using first-fit-decreasing (FFD)?
    # https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing
    # Returns True if all items were placed, False as soon as one item fits
    # into none of the bins.

    # Free capacity left in each of the n bins (same dtype as the items).
    free = np.full((n, ), c, dtype=a.dtype)

    # Visit items from largest to smallest.
    for item in np.sort(a)[::-1]:
        placed = False
        b = 0
        # First fit: the lowest-numbered bin with enough room takes the item.
        while b < n and not placed:
            if free[b] >= item:
                free[b] -= item
                placed = True
            b += 1

        if not placed:
            return False

    return True


@numba.njit
def ffd_with_result(a: np.ndarray, c: int, start_index: int):
    # First-fit-decreasing bin packing that also reports the packing itself.
    # Bins of capacity `c` are opened on demand. The result is one list per
    # bin, holding the positions (into `a`, offset by `start_index`) of the
    # items placed in it.

    # Item positions ordered by decreasing size.
    order = np.argsort(a)[::-1]
    sorted_sizes = a[order]

    capacity_left = []  # remaining room in each open bin
    packed = []         # item positions assigned to each open bin
    for pos in range(len(sorted_sizes)):
        size = sorted_sizes[pos]
        item_id = order[pos] + start_index

        # First open bin with enough room, or -1 if none fits.
        target = -1
        for b in range(len(capacity_left)):
            if capacity_left[b] >= size:
                target = b
                break

        if target == -1:
            # Nothing fits: open a new bin containing just this item.
            capacity_left.append(c - size)
            packed.append([item_id])
        else:
            capacity_left[target] -= size
            packed[target].append(item_id)

    return packed


@numba.njit
def allocate(lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int):
    # Dynamic batch allocator, similar to Multifit
    # https://en.wikipedia.org/wiki/Multifit_algorithm
    # ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
    #
    # Walks `lengths` from front to back. Each step binary-searches for the
    # longest prefix of the remaining items that first-fit-decreasing can pack
    # into `n` bins of capacity `c`. The search assumes feasibility is monotone
    # in prefix length. From that packing, bin number `rank` is kept for the
    # caller.
    #
    # Args:
    #   lengths:        per-item lengths, in the order they are consumed.
    #   lengths_cumsum: expected to be np.cumsum(lengths), as passed by
    #                   generate_batches; bounds the search by total capacity.
    #   rank:           which of the `n` bins of every step to return.
    #   c:              capacity of one bin.
    #   n:              number of bins per step.
    #
    # Returns (result, s, total_slots):
    #   result:      one list of item positions (into `lengths`) per step, for `rank`.
    #   s:           summed length of every item packed so far, across all `n` bins.
    #   total_slots: len(result) * c * n, the capacity those steps offered.
    # Items left when a step produces fewer than `n` bins are not returned.

    s = 0  # lengths_cumsum[start_index - 1]: total length consumed so far
    start_index = 0
    result = []

    while True:
        # binary search [l, r)
        # NOTE(review): a 1-item prefix is taken as feasible without checking.
        # If that item is longer than `c`, it is still packed alone, over
        # capacity. With n > 1 that gives fewer than n bins and ends the loop
        # below early (dropping the rest); with n == 1 it yields an
        # over-capacity batch. Confirm lengths <= c upstream.
        l = 1
        # Exclusive upper bound: 1 + number of remaining items whose running
        # total fits within the step's combined capacity c * n.
        r = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right")

        while r - l > 1:
            m = (l + r) // 2
            if ffd_check(lengths[start_index: start_index + m], c, n):
                l = m
            else:
                r = m

        # use length l
        batch = ffd_with_result(lengths[start_index: start_index + l], c, start_index)
        # ffd_with_result opens bins in the same first-fit order that ffd_check
        # tests, so a prefix ffd_check accepted never needs more than n bins.
        assert len(batch) <= n
        if len(batch) < n:
            # Not enough items left to give every one of the n bins something:
            # stop; the remaining items are not allocated.
            break

        start_index += l
        s = lengths_cumsum[start_index - 1]

        # add local rank
        result.append(batch[rank])

    return result, s, len(result) * c * n


class MultipackDistributedBatchSampler(Sampler):
    """Unpadded length sampling using Multipack.

    Approximates (at most ~1.22x) the optimal solution of the identical-machines
    scheduling problem, which is NP-hard.

    Each epoch, the samples are shuffled with seed ``seed + epoch``. They are
    then packed, ``num_replicas`` bins at a time, into bins whose summed length
    is at most ``batch_max_length`` (provided no single sample is longer than
    that). This rank yields bin number ``rank`` of every group as one batch,
    an ndarray of dataset indices. Samples that cannot complete a final full
    group are not yielded.
    """

    def __init__(
        self,
        batch_max_length: int,
        lengths: List[int],
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        seed: int = 0,
    ):
        """
        Args:
            batch_max_length: Capacity (summed sample length) of one batch.
            lengths: Length of every sample, indexed like the dataset. Any
                sequence accepted by ``np.asarray``; an existing ndarray is
                used as-is, without copying.
            num_replicas: Number of ranks; defaults to ``dist.get_world_size()``.
            rank: This process's rank; defaults to ``dist.get_rank()``.
            seed: Base seed for the per-epoch shuffle.

        Raises:
            RuntimeError: if ``num_replicas`` or ``rank`` is omitted and the
                distributed package is not available.
        """
        # Fall back to torch.distributed for world size / rank when not given.
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()

        self.num_replicas = num_replicas
        self.rank = rank
        self.seed = seed

        self.batch_max_length = batch_max_length
        # The signature advertises List[int], but generate_batches relies on
        # ndarray fancy indexing. Convert plain sequences instead of rejecting
        # them; ndarrays pass through unchanged.
        self.lengths = np.asarray(lengths)

        self.epoch = 0

        # statistics (accumulated only by __iter__)
        self.eff_total_used = 0
        self.eff_total_slots = 0

    def set_epoch(self, epoch: int):
        """Select the epoch whose shuffle the next iteration uses."""
        self.epoch = epoch

    def generate_batches(self, set_stats=False):
        """Pack the current epoch's shuffled samples and return this rank's batches.

        Args:
            set_stats: if True, add this epoch's packed length and slot
                capacity (both summed over all ranks) to the running
                efficiency statistics.

        Returns:
            A list of ndarrays, each holding the dataset indices of one batch.
        """
        # The shuffle seed does not involve the rank. Provided every rank uses
        # the same seed and epoch, all ranks compute the same packing and each
        # takes its own bin.
        indices = np.random.default_rng(seed=self.seed + self.epoch).permutation(len(self.lengths))

        lengths = self.lengths[indices]
        lengths_cumsum = np.cumsum(lengths)

        batches, total_used, total_slots = allocate(lengths=lengths,
                                                    lengths_cumsum=lengths_cumsum,
                                                    rank=self.rank,
                                                    c=self.batch_max_length,
                                                    n=self.num_replicas)

        # allocate() returns positions in the shuffled order; map them back
        # to dataset indices.
        batches = [indices[batch] for batch in batches]

        # statistics
        if set_stats:
            self.eff_total_used += total_used
            self.eff_total_slots += total_slots

        return batches

    def __iter__(self):
        """Yield this rank's batches for the current epoch and record efficiency stats."""
        batches = self.generate_batches(set_stats=True)
        return iter(batches)

    def __len__(self):
        """Number of batches this rank yields for the current epoch (recomputes the packing)."""
        return self.num_batches()

    def num_batches(self):
        """Number of batches for the current epoch, without touching the statistics."""
        batches = self.generate_batches()
        return len(batches)

    def efficiency(self):
        """Fraction of offered capacity filled, over every epoch iterated so far.

        Raises ZeroDivisionError if iteration has not yet produced any batch.
        """
        return self.eff_total_used / self.eff_total_slots


In [31]:
# Toy example: first-fit-decreasing packing of six items into bins of capacity 8.
DEMO_CAPACITY = 8
lengths = np.array([1, 5, 7, 8, 3, 2])
lengths_cumsum = np.cumsum(lengths)

demo_bins = ffd_with_result(lengths, DEMO_CAPACITY, start_index=0)
print(demo_bins)

# Verify the packing instead of eyeballing the printout: every item is used
# exactly once and no bin is filled past its capacity.
assert sorted(i for b in demo_bins for i in b) == list(range(len(lengths)))
assert all(lengths[b].sum() <= DEMO_CAPACITY for b in demo_bins)

[[3], [2, 0], [1, 4], [5]]


In [32]:
DATASET = "../../dataset_processed/openchat.train.json"
C = 12 * 2048  # capacity (summed sample length) of one packed batch per rank
N = 8          # number of simulated ranks
EPOCHS = 10

# Load dataset (JSON text is UTF-8; don't depend on the platform's default encoding)
with open(DATASET, "r", encoding="utf-8") as f:
    dataset = json.load(f)

# Check allocator efficiency
# Per-sample token counts; each record is assumed to be a (tokens, masks) pair.
lengths = np.array([len(tokens) for tokens, masks in dataset])

# test sampler correctness & efficiency
tot_len = 0
tot_batches = 0

samplers = [MultipackDistributedBatchSampler(C, lengths, N, rank) for rank in range(N)]
print([sampler.num_batches() for sampler in samplers])

for epoch in range(EPOCHS):
    epoch_indices = []

    for sampler in samplers:
        sampler.set_epoch(epoch)
        for batch in sampler:
            epoch_indices.extend(batch.tolist())

            # Check constraints: a packed batch must fit in the capacity.
            overall_len = lengths[batch].sum()
            assert overall_len <= C

            tot_len += overall_len
            tot_batches += 1

    # Check overall unique: no sample may be handed out twice in one epoch,
    # across all ranks. Compare sizes rather than `sorted == list(set(...))`,
    # because set iteration order is not guaranteed and the old form could
    # fail on a valid epoch.
    assert len(epoch_indices) == len(set(epoch_indices))

# Check efficiency
efficiency = [sampler.efficiency() for sampler in samplers]
print(f"Efficiency: {efficiency}")

print(f"Overall Efficiency: {tot_len / (tot_batches * C)}")

[48, 48, 48, 48, 48, 48, 48, 48]
Efficiency: [0.9951017803615994, 0.9951017803615994, 0.9951017803615994, 0.9951017803615994, 0.9951017803615994, 0.9951017803615994, 0.9951017803615994, 0.9951017803615994]
Overall Efficiency: 0.9951017803615994


In [23]:
# Batch count for rank 0 at that sampler's current epoch.
rank0_batches = samplers[0].num_batches()
print(rank0_batches)

48
