In [39]:
import json
import numpy as np

from multipack_sampler import MultipackDistributedBatchSampler

In [40]:
class InterleavedSampler:
    """Distributed batch sampler that deals length-sorted samples round-robin to ranks.

    Each epoch the dataset indices are shuffled (seed ``seed + epoch``) and cut into
    consecutive "global" batches of ``batch_size * num_replicas`` samples. Inside a
    global batch the samples are sorted by length, and rank ``r`` receives the sorted
    positions ``r, r + num_replicas, r + 2 * num_replicas, ...``. A trailing partial
    global batch is dropped, so ``__iter__`` yields exactly ``num_batches()`` batches.

    Args:
        lengths: per-sample lengths as a NumPy array indexed by dataset index (it is
            fancy-indexed with the shuffled permutation).
        batch_size: number of samples yielded per rank per step.
        num_replicas: number of ranks sharing the dataset.
        rank: this sampler's rank, expected in ``[0, num_replicas)``.
        seed: base seed for the per-epoch shuffle.
    """

    def __init__(self, lengths: np.ndarray, batch_size: int, num_replicas: int, rank: int, seed: int = 0):
        self.seed = seed

        self.lengths = lengths
        self.batch_size = batch_size

        self.num_replicas = num_replicas
        self.rank = rank

        self.epoch = 0

    def num_batches(self) -> int:
        """Number of batches yielded per epoch (complete global batches only)."""
        return len(self.lengths) // (self.num_replicas * self.batch_size)

    def set_epoch(self, epoch: int) -> None:
        """Set the epoch; the shuffle used by ``__iter__`` is seeded with ``seed + epoch``."""
        self.epoch = epoch

    def __iter__(self):
        """Yield this rank's batches for the current epoch as arrays of dataset indices."""
        indices = np.random.default_rng(seed=self.seed + self.epoch).permutation(len(self.lengths))
        lengths = self.lengths[indices]

        overall_batch_size = self.batch_size * self.num_replicas
        # Positions (within a length-sorted global batch) that belong to this rank.
        rank_positions = self.rank + np.arange(self.batch_size) * self.num_replicas

        # Visit complete global batches only: a partial one cannot supply
        # batch_size samples to every rank, and indexing it with rank_positions
        # would run past its end.
        for start in range(0, self.num_batches() * overall_batch_size, overall_batch_size):
            batch_lengths = lengths[start: start + overall_batch_size]
            # Dataset indices of this global batch, ordered by ascending length.
            sorted_indices = indices[start + np.argsort(batch_lengths)]
            yield sorted_indices[rank_positions]


In [41]:
DATASET = "testdata.json"
BATCH_SIZE = 16
CTX_LEN = 2048
NUM_GPUS = 8
EPOCHS = 10

# Factories building one sampler per rank; both share the same capacity budget
# of BATCH_SIZE * CTX_LEN per batch.
SAMPLERS = {
    "Multipack": lambda lengths, rank: MultipackDistributedBatchSampler(lengths=lengths, batch_max_length=BATCH_SIZE * CTX_LEN, num_replicas=NUM_GPUS, rank=rank),
    "Interleaved": lambda lengths, rank: InterleavedSampler(lengths=lengths, batch_size=BATCH_SIZE, num_replicas=NUM_GPUS, rank=rank),
}

# Load testdata: a JSON array of per-sample lengths
with open(DATASET, "r") as f:
    lengths = np.array(json.load(f))

# Test sampler correctness & efficiency.
#   Correctness: every batch fits in BATCH_SIZE * CTX_LEN, and within one epoch
#                no sample index is handed out twice across all ranks.
#   Efficiency:  total sample length / total capacity of all yielded batches.
for sampler_name, sampler_fn in SAMPLERS.items():
    print(f"Sampler {sampler_name}:")

    tot_len = 0
    tot_batches = 0

    samplers = [sampler_fn(lengths=lengths, rank=rank) for rank in range(NUM_GPUS)]
    print([sampler.num_batches() for sampler in samplers])

    for epoch in range(EPOCHS):
        # Every sample index yielded by any rank during this epoch
        epoch_indices = []

        for sampler in samplers:
            sampler.set_epoch(epoch)
            for batch in sampler:
                epoch_indices.extend(batch)

                # Check constraints
                overall_len = sum(lengths[x] for x in batch)
                assert overall_len <= BATCH_SIZE * CTX_LEN

                # Add stats
                tot_len += overall_len
                tot_batches += 1

        # Check overall unique: duplicates collapse in a set, so the sizes differ
        # exactly when some index was sampled twice (no reliance on set ordering).
        assert len(epoch_indices) == len(set(epoch_indices)), f"{sampler_name}: duplicate sample indices in epoch {epoch}"

    # Check efficiency
    print(f"Overall Efficiency: {tot_len / (tot_batches * CTX_LEN * BATCH_SIZE)}")

Sampler Multipack:
[36, 36, 36, 36, 36, 36, 36, 36]
Overall Efficiency: 0.9963896327548557
Sampler Interleaved:
[48, 48, 48, 48, 48, 48, 48, 48]
Overall Efficiency: 0.756684939066569
