In [1]:
import os
import numpy as np
import torch.distributed as dist
import datetime
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import Sampler
import time
import axolotl.data_nested as data

from typing import List, Optional

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load testdata
time0 = time.time() 
train_set = data.get_dataset('/home/kkj/axolotl/datasets/IPR036736_90_grouped/train')
# train_set = data.get_dataset('/mnt/e/datasets/UniRef50_grouped/train')
time1 = time.time()
print(f"Dataset loaded in {time1 - time0:.2f} seconds")

lengths = train_set["length"]
time2 = time.time()
print(f"Lengths: {lengths}")
print(f"lengths retrieved in {time2 - time1:.2f} seconds")

cluster_sizes = train_set["cluster_size"]
time3 = time.time()
print(f"Cluster sizes: {cluster_sizes}")
print(f"Cluster sizes retrieved in {time3 - time2:.2f} seconds")

Dataset loaded in 0.16 seconds
Lengths: [tensor([94]), tensor([933]), tensor([158]), tensor([2094, 2114, 2118, 2085, 2082, 2087, 2090, 2081, 2085,  661, 2084,  609,
         407, 2086, 2094, 2084, 2087, 2085, 2089]), tensor([2383]), tensor([2485, 2485, 2499, 2485, 2485, 2485, 2485, 2044, 2504, 2434]), tensor([86]), tensor([8540]), tensor([83, 83]), tensor([1287, 1287]), tensor([78, 76]), tensor([2170, 2172, 2170, 2171]), tensor([86]), tensor([ 429, 3148, 1949]), tensor([2443, 2443]), tensor([951]), tensor([1669]), tensor([2280, 2287, 2290]), tensor([95]), tensor([2942]), tensor([1291]), tensor([589]), tensor([1817]), tensor([1184]), tensor([188]), tensor([290]), tensor([410, 406]), tensor([1121]), tensor([87, 87]), tensor([80, 81]), tensor([1168]), tensor([1312]), tensor([3086]), tensor([1788]), tensor([94]), tensor([2035]), tensor([81]), tensor([85]), tensor([1097]), tensor([86]), tensor([ 83, 135]), tensor([87]), tensor([4625]), tensor([2370]), tensor([685]), tensor([2425]), tensor([

In [None]:
class SimpleDistributedBatchSampler(Sampler):
    """Unpadded length sampling using Multipack V2, for models with quadratic attention complexity.
       It also tries to evenly distribute the sequences using LPT, so that quadratic load is more balanced.

       Approximate (at most 1.33x ?) the optimal solution of the identical-machines scheduling problem, which is NP-hard.

       Time Complexity: O(n log n log k)
       n = maximum number of sequences per batch, k = number of nodes
    """
    def __init__(
        self,
        dataset, # A clustered dataset TODO specify format
        length_key, # a list of tensors, containing the lengths of sequences in each cluster
        cluster_size_key, # a tensor of cluster sizes, where each cluster size corresponds to the number of sequences in the cluster
        max_length: int, # maximum length of each sequence in the batch, what the sequences will be truncated to
        total_length: int, # total length of the batch (total amount of tokens)
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        seed: int = 0,
        epoch: int = 0,
        shuffle: bool = True,
    ):
        # Get rank
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()

        self.dataset = dataset
        self.length_key = length_key
        self.cluster_size_key = cluster_size_key
        self.max_length = max_length # maximum length of each sequence in the batch, what the sequecnes will be truncated to
        self.total_length = total_length

        self.num_replicas = num_replicas
        self.rank = rank
        self.seed = seed
        self.epoch = epoch
        self.shuffle = shuffle

        self.indices = list(range(len(dataset)))

    def set_epoch(self, epoch: int):
        self.epoch = epoch
    
    def __iter__(self):
        if self.shuffle:
            g = torch.Generator().manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.indices), generator=g).tolist()
        else:
            indices = self.indices.copy()

        # We use the seed and epoch to deterministically select one sequence from each cluster
        cluster_selection_seed = (self.seed + self.epoch)
        
        # Track batch generation for distributed processing
        batch = []
        batch_length = 0
        batch_count = 0  # Global batch counter
        
        for idx in indices:
            # Get cluster size and select sequence deterministically
            cluster_size = self.dataset[idx][self.cluster_size_key].item()
            cluster_idx = cluster_selection_seed % cluster_size
            
            # Get the length and truncate to max_length
            sequence_length = self.dataset[idx][self.length_key][cluster_idx].item()
            truncated_length = min(sequence_length, self.max_length)
            
            # Add sample to current batch
            batch.append(idx)
            batch_length += truncated_length
            
            # Check if adding this sample caused overflow
            if batch_length >= self.total_length:
                # Check if this batch belongs to our rank
                if batch_count % self.num_replicas == self.rank:
                    yield batch
                
                # Reset for next batch
                batch = []
                batch_length = 0
                batch_count += 1
        
        # Handle remaining batch if not empty
        if len(batch) > 0:
            # Check if this final batch belongs to our rank
            if batch_count % self.num_replicas == self.rank:
                yield batch

    def collate_fn(self, batch):
        g = torch.Generator().manual_seed(self.seed + self.epoch)
        cluster_selection_seed = (self.seed + self.epoch)

        input_ids_list = []
        label_list = []
        length_sum = 0

        for i in range(len(batch) - 1):
            # Get the sample from the dataset
            cluster_size = batch['cluster_size'][i].item()
            if cluster_size == 1:
                idx = 0
            else:
                idx = (cluster_selection_seed % cluster_size)

            input_ids_list.append(data.maybe_truncate(batch["input_ids"][i][idx], self.max_length, generator=g))
            length_sum += min(batch["length"][i][idx].item(), self.max_length)
            label_list.append(batch["label"][i][idx])

        # the last input ids should be truncated to the remaining length, to not exceed the total length
        last_length = min(self.total_length - length_sum, self.max_length)
        input_ids_list.append(data.maybe_truncate(batch["input_ids"][-1][idx], last_length, generator=g))
        label_list.append(batch["label"][-1][idx])

        # convert to nested tensor for the model
        input_ids = torch.nested.nested_tensor(input_ids_list, layout=torch.jagged)
        label = torch.tensor(label_list, dtype=torch.long)

        return {"input_ids": input_ids, "label": label}

In [None]:
DATASET = "/home/kkj/axolotl/datasets/IPR036736_90_grouped/train"
BATCH_SIZE = 16
CTX_LEN = 2048
NUM_GPUS = 8
EPOCHS = 10
SEED = 42
START_EPOCH = 0

SAMPLERS = {
    "Simple Distributed Batch Sampler": lambda lengths, cluster_sizes, rank: SimpleDistributedBatchSampler(
        dataset=train_set,
        length_key="length",
        cluster_size_key="cluster_size",
        max_length=CTX_LEN,
        total_length=BATCH_SIZE * CTX_LEN,
        num_replicas=NUM_GPUS,
        rank=rank,
        seed=SEED,
        epoch=START_EPOCH,
        shuffle=True,  # Shuffle batches for each epoch
),
    "Multipack QuadraticAttention": lambda lengths, cluster_sizes, rank: data.MultipackDistributedBatchSampler(lengths=lengths, 
                                                                                                cluster_sizes=cluster_sizes,
                                                                                                max_length=CTX_LEN,
                                                                                                total_length=BATCH_SIZE * CTX_LEN, 
                                                                                                num_replicas=NUM_GPUS, 
                                                                                                rank=rank,
                                                                                                seed=SEED,
                                                                                                epoch=START_EPOCH,
),
}

# test sampler correctness & efficiency
for sampler_name, sampler_fn in SAMPLERS.items():
    print(f"Sampler {sampler_name}:")

    tot_len = 0
    tot_maxlen = 0
    tot_batches = 0
    avg_sql2lag = 0
    max_sql2lag = 0

    # Detailed timing for sampler creation
    start_time = time.time()
    samplers = [sampler_fn(lengths=lengths, cluster_sizes=cluster_sizes, rank=rank) for rank in range(NUM_GPUS)]
    creation_time = time.time() - start_time
    print(f"Sampler creation time: {creation_time*1000:.2f}ms")
    
    # Try to get batch counts (might not work for streaming samplers)
    try:
        batch_count_start = time.time()
        batch_counts = [sampler.num_batches() for sampler in samplers if hasattr(sampler, 'num_batches')]
        batch_count_time = time.time() - batch_count_start
        print("Batch count for ranks:", batch_counts)
        print(f"Batch counting time: {batch_count_time*1000:.2f}ms")
    except:
        print("Batch counting not available (streaming sampler)")
    
    print("")

    # Time each epoch
    for epoch in range(START_EPOCH, EPOCHS + START_EPOCH):
        epoch_start = time.time()
        
        batches = []
        sqlen = [[] for _ in range(NUM_GPUS)]
        olen = [[] for _ in range(NUM_GPUS)]

        # Time epoch setting
        set_epoch_start = time.time()
        for rank, sampler in enumerate(samplers):
            sampler.set_epoch(epoch)
        set_epoch_time = time.time() - set_epoch_start
        
        # Time batch iteration
        iteration_start = time.time()
        cluster_selection_seed = (SEED + epoch)
        
        batch_processing_times = []
        for rank, sampler in enumerate(samplers):
            rank_start = time.time()
            
            for batch_idx, batch in enumerate(sampler):
                batch_start = time.time()
                
                batches.extend(batch)
                cluster_selection_idx = [cluster_selection_seed % cluster_sizes[x] for x in batch]

                # Check constraints
                overall_len = sum([min(lengths[x][y], CTX_LEN) for x, y in zip(batch, cluster_selection_idx)])
                square_len = sum([lengths[x][y] ** 2 for x, y in zip(batch, cluster_selection_idx)])
                # assert overall_len <= BATCH_SIZE * CTX_LEN, f"Overall length {overall_len} exceeds maximum {BATCH_SIZE * CTX_LEN} for batch {batch} at rank {rank} in epoch {epoch}"

                # Add stats
                tot_len += overall_len
                tot_batches += 1

                # square len
                sqlen[rank].append(square_len)
                olen[rank].append(overall_len)
                
                batch_time = time.time() - batch_start
                batch_processing_times.append(batch_time * 1000)
            
            rank_time = time.time() - rank_start
            if epoch == START_EPOCH:  # Only print for first epoch
                print(f"  Rank {rank} processing time: {rank_time*1000:.2f}ms")
        
        iteration_time = time.time() - iteration_start
        
        # Time statistics computation
        stats_start = time.time()
        tot_maxlen += np.sum(np.max(olen, axis=0))

        sqlen = np.array(sqlen)
        sqlag = np.max(sqlen, axis=0) - np.min(sqlen, axis=0)

        avg_sql2lag += np.sqrt(np.mean(sqlag))
        max_sql2lag += np.sqrt(np.max(sqlag))

        # Check overall unique
        batches.sort()
        assert batches == list(set(batches))  # Unique
        stats_time = time.time() - stats_start
        
        epoch_time = time.time() - epoch_start
        
        if epoch == START_EPOCH:  # Print detailed timing for first epoch
            print(f"  Epoch {epoch} timing breakdown:")
            print(f"    Set epoch: {set_epoch_time*1000:.2f}ms")
            print(f"    Batch iteration: {iteration_time*1000:.2f}ms")
            print(f"    Statistics: {stats_time*1000:.2f}ms")
            print(f"    Total epoch: {epoch_time*1000:.2f}ms")
            if batch_processing_times:
                print(f"    Avg batch processing: {np.mean(batch_processing_times):.2f}ms")
                print(f"    Min/Max batch processing: {np.min(batch_processing_times):.2f}ms / {np.max(batch_processing_times):.2f}ms")

    # Check efficiency
    print(f"L^2 lag avg: {avg_sql2lag / EPOCHS:.0f} max: {max_sql2lag / EPOCHS:.0f}")
    print(f"Efficiency: {tot_len / (tot_batches * CTX_LEN * BATCH_SIZE) * 100:.2f}%")
    print(f"Utilization: {tot_len / (tot_maxlen * NUM_GPUS) * 100:.2f}%")
    print("==========\n")

In [7]:
# Detailed timing tests for sampler performance
print("=== DETAILED TIMING TESTS ===\n")

def time_sampler_parts(sampler_name, sampler_fn, rank=0):
    print(f"Testing {sampler_name} (Rank {rank}):")
    
    # Test sampler creation time
    start_time = time.time()
    sampler = sampler_fn(lengths=lengths, cluster_sizes=cluster_sizes, rank=rank)
    creation_time = time.time() - start_time
    print(f"  Sampler creation: {creation_time*1000:.2f}ms")
    
    # Test epoch setting time
    start_time = time.time()
    sampler.set_epoch(0)
    epoch_time = time.time() - start_time
    print(f"  Set epoch: {epoch_time*1000:.2f}ms")
    
    # Test iteration initialization time
    start_time = time.time()
    iterator = iter(sampler)
    init_time = time.time() - start_time
    print(f"  Iterator init: {init_time*1000:.2f}ms")
    
    # Test first batch generation time
    start_time = time.time()
    try:
        first_batch = next(iterator)
        first_batch_time = time.time() - start_time
        print(f"  First batch: {first_batch_time*1000:.2f}ms (size: {len(first_batch)})")
    except StopIteration:
        print("  First batch: No batches generated")
        return
    
    # Test subsequent batch generation times
    batch_times = []
    batch_sizes = []
    start_time = time.time()
    
    batch_count = 1  # Already got first batch
    for batch in iterator:
        batch_time = time.time()
        batch_times.append((batch_time - start_time) * 1000)
        batch_sizes.append(len(batch))
        start_time = batch_time
        batch_count += 1
        if batch_count >= 10:  # Test first 10 batches
            break
    
    if batch_times:
        print(f"  Avg batch time (2-10): {np.mean(batch_times):.2f}ms")
        print(f"  Min/Max batch time: {np.min(batch_times):.2f}ms / {np.max(batch_times):.2f}ms")
        print(f"  Avg batch size: {np.mean(batch_sizes):.1f}")
        print(f"  Min/Max batch size: {min(batch_sizes)} / {max(batch_sizes)}")
    
    # Test collate function time if available
    if hasattr(sampler, 'collate_fn'):
        # Create a mock batch using actual dataset samples
        mock_batch = train_set[first_batch]
        
        start_time = time.time()
        collated = sampler.collate_fn(mock_batch)
        collate_time = time.time() - start_time
        print(f"  Collate time (first batch): {collate_time*1000:.2f}ms")
    
    print()

# Test each sampler
for sampler_name, sampler_fn in SAMPLERS.items():
    time_sampler_parts(sampler_name, sampler_fn, rank=0)

print("=== MEMORY USAGE TEST ===")
import psutil
import gc

def test_memory_usage(sampler_name, sampler_fn, rank=0):
    print(f"Testing memory usage for {sampler_name}:")
    
    # Get initial memory
    process = psutil.Process()
    initial_memory = process.memory_info().rss / 1024 / 1024  # MB
    
    # Create sampler
    sampler = sampler_fn(lengths=lengths, cluster_sizes=cluster_sizes, rank=rank)
    after_creation = process.memory_info().rss / 1024 / 1024
    
    # Generate some batches
    sampler.set_epoch(0)
    batches = []
    for i, batch in enumerate(iter(sampler)):
        batches.append(batch)
        if i >= 50:  # Generate 50 batches
            break
    
    after_batches = process.memory_info().rss / 1024 / 1024
    
    print(f"  Initial memory: {initial_memory:.1f}MB")
    print(f"  After creation: {after_creation:.1f}MB (+{after_creation-initial_memory:.1f}MB)")
    print(f"  After 50 batches: {after_batches:.1f}MB (+{after_batches-initial_memory:.1f}MB)")
    print(f"  Generated {len(batches)} batches")
    
    # Cleanup
    del sampler, batches
    gc.collect()
    print()

# Test memory usage for each sampler
for sampler_name, sampler_fn in SAMPLERS.items():
    test_memory_usage(sampler_name, sampler_fn, rank=0)

=== DETAILED TIMING TESTS ===

Testing Simple Distributed Batch Sampler (Rank 0):
  Sampler creation: 11.90ms
  Set epoch: 0.01ms
  Iterator init: 0.01ms
  First batch: 108.80ms (size: 26)
  Avg batch time (2-10): 504.81ms
  Min/Max batch time: 445.52ms / 574.38ms
  Avg batch size: 33.3
  Min/Max batch size: 28 / 42
  Avg batch time (2-10): 504.81ms
  Min/Max batch time: 445.52ms / 574.38ms
  Avg batch size: 33.3
  Min/Max batch size: 28 / 42
  Collate time (first batch): 1892.92ms

Testing Multipack QuadraticAttention (Rank 0):
  Sampler creation: 22.38ms
  Set epoch: 0.00ms
  Collate time (first batch): 1892.92ms

Testing Multipack QuadraticAttention (Rank 0):
  Sampler creation: 22.38ms
  Set epoch: 0.00ms
  Iterator init: 15314.93ms
  First batch: 0.00ms (size: 29)
  Avg batch time (2-10): 0.00ms
  Min/Max batch time: 0.00ms / 0.00ms
  Avg batch size: 31.7
  Min/Max batch size: 30 / 33
  Iterator init: 15314.93ms
  First batch: 0.00ms (size: 29)
  Avg batch time (2-10): 0.00ms
  Mi

TypeError: string indices must be integers

In [None]:

train_sampler = data.MultipackDistributedBatchSampler(
        lengths = lengths, # a list of lists, where each inner list contains the lengths of the sequences in the cluster
        cluster_sizes = cluster_sizes, # a list of cluster sizes, where each cluster size corresponds to the number of sequences in the cluster
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        max_length = 1024, # maximum length of each sequence in the batch, what the sequences will be truncated to
        total_length = 4*1024, # total length of the batch (total amount of tokens)
        seed: int = 0,
        epoch: int = 0,

)

# Build data iterators
train_loader = data.cycle_loader(DataLoader(
    train_set,
    # batch_size=.batch_size // (config.ngpus * .accum),
    batch_sampler=train_sampler,
    num_workers=num_workers,
    collate_fn=train_sampler.collate_fn,
    pin_memory=True,
    shuffle=(train_sampler is None),
    persistent_workers=True,
))