In [4]:
import json
import numpy as np
import time

from multipack_sampler import MultipackDistributedBatchSampler
from multipack_sampler_linear import MultipackDistributedBatchSampler_LinearAttention

In [5]:
class InterleavedSampler:
    """Baseline distributed batch sampler that interleaves length-sorted samples across ranks.

    Every epoch the sample indices are shuffled with a generator seeded by
    ``seed + epoch`` and cut into global batches of
    ``batch_size * num_replicas`` samples. Inside a global batch the samples
    are ordered by length and dealt out round-robin, so this rank receives
    positions ``rank, rank + num_replicas, ...`` of the sorted batch.

    Trailing samples that do not fill a whole global batch are dropped, so the
    iterator yields exactly ``num_batches()`` batches of ``batch_size`` indices.

    Args:
        lengths: Per-sample lengths (anything ``np.asarray`` accepts).
        batch_size: Number of samples this rank receives per batch.
        num_replicas: Number of ranks the dataset is shared between.
        rank: Index of this rank; used as the offset into each sorted global batch.
        seed: Base seed of the per-epoch shuffle.
    """

    def __init__(self, lengths: np.ndarray, batch_size: int, num_replicas: int, rank: int, seed: int = 0):
        self.seed = seed

        # Fancy indexing in __iter__ needs an ndarray; asarray is a no-op for one.
        self.lengths = np.asarray(lengths)
        self.batch_size = batch_size

        self.num_replicas = num_replicas
        self.rank = rank

        self.epoch = 0

    def num_batches(self):
        """Return the number of complete global batches, i.e. batches yielded per epoch."""
        return len(self.lengths) // (self.num_replicas * self.batch_size)

    def set_epoch(self, epoch):
        """Select the epoch, which changes the shuffle used by the next iteration."""
        self.epoch = epoch

    def __iter__(self):
        """Yield, for each complete global batch, the index array assigned to this rank."""
        indices = np.random.default_rng(seed=self.seed + self.epoch).permutation(len(self.lengths))

        lengths = self.lengths[indices]
        overall_batch_size = self.batch_size * self.num_replicas

        # Drive the loop from num_batches() so that only complete global batches
        # are visited. A partial trailing batch would make the strided selection
        # below index past the end of ``result``.
        for batch_no in range(self.num_batches()):
            index = batch_no * overall_batch_size
            batch = lengths[index: index + overall_batch_size]

            # Dataset indices of this global batch, ordered by sample length.
            result = indices[index + np.argsort(batch)]
            yield result[self.rank + np.arange(self.batch_size) * self.num_replicas]


In [6]:
# ---- Benchmark configuration ----
# File holding the sample lengths that are loaded into `lengths`.
DATASET = "testdata.json"

# BATCH_SIZE is the per-rank batch size of the interleaved sampler;
# BATCH_SIZE * CTX_LEN is the length budget handed to the multipack samplers
# as `batch_max_length`.
BATCH_SIZE, CTX_LEN = 16, 2048

# NUM_GPUS is passed to every sampler as `num_replicas`;
# EPOCHS is the number of epochs each sampler is iterated for.
NUM_GPUS, EPOCHS = 8, 10

# Registry of sampler factories, keyed by the display name used in the report.
# Each factory is called as factory(lengths=..., rank=...) and returns the
# sampler instance for that rank.
SAMPLERS = {
    "Multipack QuadraticAttention": lambda lengths, rank: MultipackDistributedBatchSampler(
        lengths=lengths,
        batch_max_length=BATCH_SIZE * CTX_LEN,
        num_replicas=NUM_GPUS,
        rank=rank,
    ),
    "Multipack LinearAttention": lambda lengths, rank: MultipackDistributedBatchSampler_LinearAttention(
        lengths=lengths,
        batch_max_length=BATCH_SIZE * CTX_LEN,
        num_replicas=NUM_GPUS,
        rank=rank,
    ),
    "Interleaved": lambda lengths, rank: InterleavedSampler(
        lengths=lengths,
        batch_size=BATCH_SIZE,
        num_replicas=NUM_GPUS,
        rank=rank,
    ),
}

# Read the JSON document stored at DATASET and convert it to a NumPy array.
with open(DATASET, "r") as f:
    raw_lengths = json.load(f)
lengths = np.array(raw_lengths)

# Test sampler correctness & efficiency.
#
# For each sampler, one instance per rank is built and iterated for EPOCHS
# epochs while checking that
#   * no batch exceeds the BATCH_SIZE * CTX_LEN length budget,
#   * every rank yields the same number of batches,
#   * no sample index is yielded twice within an epoch (across all ranks),
# and collecting
#   * L^2 lag     - per step, the spread (max - min over ranks) of the summed
#                   squared lengths; the sqrt of its mean / max is averaged
#                   over the epochs,
#   * Efficiency  - yielded length / (batches * CTX_LEN * BATCH_SIZE),
#   * Utilization - yielded length / (NUM_GPUS * sum over steps of the largest
#                   per-rank batch length).
for sampler_name, sampler_fn in SAMPLERS.items():
    print(f"Sampler {sampler_name}:")

    tot_len = 0
    tot_maxlen = 0
    tot_batches = 0
    avg_sql2lag = 0
    max_sql2lag = 0

    # perf_counter() is the monotonic, high-resolution clock meant for measuring
    # intervals; time.time() can jump when the system clock is adjusted.
    start_time = time.perf_counter()
    samplers = [sampler_fn(lengths=lengths, rank=rank) for rank in range(NUM_GPUS)]
    print("Batch count for ranks:", [sampler.num_batches() for sampler in samplers])
    print(f"Packing Time: {(time.perf_counter() - start_time) * 1000:.0f}ms")
    print("")

    for epoch in range(EPOCHS):
        batches = []
        sqlen = [[] for _ in range(NUM_GPUS)]
        olen = [[] for _ in range(NUM_GPUS)]

        for rank, sampler in enumerate(samplers):
            sampler.set_epoch(epoch)

            for batch in sampler:
                batches.extend(batch)

                # Check constraints
                batch_lengths = lengths[np.asarray(batch, dtype=np.intp)]
                overall_len = batch_lengths.sum()
                square_len = (batch_lengths ** 2).sum()
                assert overall_len <= BATCH_SIZE * CTX_LEN, (
                    f"batch length {overall_len} exceeds {BATCH_SIZE * CTX_LEN}"
                )

                # Add stats
                tot_len += overall_len
                tot_batches += 1

                # square len
                sqlen[rank].append(square_len)
                olen[rank].append(overall_len)

        # The per-step reductions below need one entry per rank for every step.
        steps_per_rank = {len(rank_lens) for rank_lens in olen}
        assert len(steps_per_rank) == 1, (
            f"ranks yielded different batch counts: {sorted(steps_per_rank)}"
        )

        tot_maxlen += np.sum(np.max(olen, axis=0))

        sqlen = np.array(sqlen)
        sqlag = np.max(sqlen, axis=0) - np.min(sqlen, axis=0)

        avg_sql2lag += np.sqrt(np.mean(sqlag))
        max_sql2lag += np.sqrt(np.max(sqlag))

        # Check overall unique.  Compare sizes rather than comparing against
        # list(set(...)): set iteration order is not guaranteed to be sorted.
        assert len(batches) == len(set(batches)), "a sample index was yielded more than once"

    # Check efficiency
    print(f"L^2 lag avg: {avg_sql2lag / EPOCHS:.0f} max: {max_sql2lag / EPOCHS:.0f}")
    print(f"Efficiency: {tot_len / (tot_batches * CTX_LEN * BATCH_SIZE) * 100:.2f}%")
    print(f"Utilization: {tot_len / (tot_maxlen * NUM_GPUS) * 100:.2f}%")
    print("==========\n")

Sampler Multipack QuadraticAttention:
Batch count for ranks: [37, 37, 37, 37, 37, 37, 37, 37]
Packing Time: 20ms

L^2 lag avg: 438 max: 717
Efficiency: 99.70%
Utilization: 98.16%

Sampler Multipack LinearAttention:
Batch count for ranks: [36, 36, 36, 36, 36, 36, 36, 36]
Packing Time: 19ms

L^2 lag avg: 6500 max: 6761
Efficiency: 99.64%
Utilization: 99.64%

Sampler Interleaved:
Batch count for ranks: [48, 48, 48, 48, 48, 48, 48, 48]
Packing Time: 0ms

L^2 lag avg: 1914 max: 2000
Efficiency: 96.79%
Utilization: 75.67%

