In [None]:
import time
import axolotl.data_nested as data
import numpy as np

In [None]:
# Load testdata
# Each step is timed with time.perf_counter(): it is monotonic and
# high-resolution, whereas time.time() follows the wall clock and can jump
# (e.g. on an NTP adjustment), which would corrupt the measured intervals.
time0 = time.perf_counter()
train_set = data.get_dataset('/mnt/e/datasets/IPR036736_90_grouped/train')
# train_set = data.get_dataset('/mnt/e/datasets/UniRef50_grouped/train')
time1 = time.perf_counter()
print(f"Dataset loaded in {time1 - time0:.2f} seconds")

# Materialise the per-cluster sequence lengths column.
lengths = train_set["length"]
time2 = time.perf_counter()
# print(f"Lengths: {lengths}")
print(f"lengths retrieved in {time2 - time1:.2f} seconds")

# Materialise the per-cluster size column.
cluster_sizes = train_set["cluster_size"]
time3 = time.perf_counter()
# Show only a short preview: printing the whole column floods the cell output
# on any non-trivial dataset.
print(f"Cluster sizes: {len(cluster_sizes)} entries, first 10: {list(cluster_sizes[:10])}")
print(f"Cluster sizes retrieved in {time3 - time2:.2f} seconds")

In [4]:
# Benchmark configuration for the multi-GPU sampler comparison.
DATASET = "/home/kkj/axolotl/datasets/IPR036736_90_grouped/train"
BATCH_SIZE = 16
CTX_LEN = 2048
NUM_GPUS = 8
EPOCHS = 10
SEED = 42
START_EPOCH = 0


def _make_simple_sampler(lengths, cluster_sizes, rank):
    """Build a SimpleDistributedBatchSampler for one rank.

    `lengths` and `cluster_sizes` are accepted only so that every factory in
    SAMPLERS shares one signature; this sampler is handed the global
    `train_set` together with the names of its length / cluster-size columns.
    The configuration constants are read when the factory is called.
    """
    return data.SimpleDistributedBatchSampler(
        dataset=train_set,
        length_key="length",
        cluster_size_key="cluster_size",
        max_length=CTX_LEN,
        total_length=BATCH_SIZE * CTX_LEN,
        num_replicas=NUM_GPUS,
        rank=rank,
        seed=SEED,
        epoch=START_EPOCH,
        shuffle=True,  # Shuffle batches for each epoch
    )


def _make_multipack_sampler(lengths, cluster_sizes, rank):
    """Build a MultipackDistributedBatchSampler for one rank.

    Unlike the simple sampler, this one receives the pre-extracted `lengths`
    and `cluster_sizes` directly. The configuration constants are read when
    the factory is called.
    """
    return data.MultipackDistributedBatchSampler(
        lengths=lengths,
        cluster_sizes=cluster_sizes,
        max_length=CTX_LEN,
        total_length=BATCH_SIZE * CTX_LEN,
        num_replicas=NUM_GPUS,
        rank=rank,
        seed=SEED,
        epoch=START_EPOCH,
    )


# Display name -> factory taking (lengths, cluster_sizes, rank).
SAMPLERS = {
    "Simple Distributed Batch Sampler": _make_simple_sampler,
    "Multipack QuadraticAttention": _make_multipack_sampler,
}


In [None]:
# Test sampler correctness & efficiency.
# For every factory in SAMPLERS: build one sampler per rank, run EPOCHS epochs
# over all ranks, and assert that no index is emitted twice within an epoch
# (across all ranks). Timings use time.perf_counter(), which is monotonic.
for sampler_name, sampler_fn in SAMPLERS.items():
    print(f"Sampler {sampler_name}:")

    # Accumulators feeding the parked efficiency statistics at the end of the cell.
    tot_len = 0
    tot_maxlen = 0
    tot_batches = 0
    avg_sql2lag = 0
    max_sql2lag = 0

    # Detailed timing for sampler creation
    start_time = time.perf_counter()
    samplers = [sampler_fn(lengths=lengths, cluster_sizes=cluster_sizes, rank=rank) for rank in range(NUM_GPUS)]
    creation_time = time.perf_counter() - start_time
    print(f"Sampler creation time: {creation_time*1000:.2f}ms")

    # Try to get batch counts (might not work for streaming samplers).
    # Only Exception is caught so Ctrl-C still interrupts the cell, and the
    # actual error is shown instead of being silently discarded.
    try:
        batch_count_start = time.perf_counter()
        batch_counts = [sampler.num_batches() for sampler in samplers if hasattr(sampler, 'num_batches')]
        batch_count_time = time.perf_counter() - batch_count_start
        print("Batch count for ranks:", batch_counts)
        print(f"Batch counting time: {batch_count_time:.2f}s")
    except Exception as exc:
        print(f"Batch counting not available (streaming sampler): {exc!r}")

    print("")

    # Time each epoch
    for epoch in range(START_EPOCH, EPOCHS + START_EPOCH):
        epoch_start = time.perf_counter()

        batches = []  # every index emitted this epoch, over all ranks
        sqlen = [[] for _ in range(NUM_GPUS)]  # per rank: sum of squared lengths per batch
        olen = [[] for _ in range(NUM_GPUS)]   # per rank: sum of (clamped) lengths per batch

        # Time epoch setting
        set_epoch_start = time.perf_counter()
        for sampler in samplers:
            sampler.set_epoch(epoch)
        set_epoch_time = time.perf_counter() - set_epoch_start

        # Time batch iteration
        iteration_start = time.perf_counter()
        # NOTE(review): presumably mirrors how the data pipeline picks one
        # member per cluster for this epoch -- confirm against axolotl.data_nested.
        cluster_selection_seed = SEED + epoch

        batch_processing_times = []
        for rank, sampler in enumerate(samplers):
            rank_start = time.perf_counter()

            for batch in sampler:
                batch_start = time.perf_counter()

                batches.extend(batch)
                cluster_selection_idx = [cluster_selection_seed % cluster_sizes[x] for x in batch]

                # Check constraints
                overall_len = sum(min(lengths[x][y], CTX_LEN) for x, y in zip(batch, cluster_selection_idx))
                square_len = sum(lengths[x][y] ** 2 for x, y in zip(batch, cluster_selection_idx))
                # assert overall_len <= BATCH_SIZE * CTX_LEN, f"Overall length {overall_len} exceeds maximum {BATCH_SIZE * CTX_LEN} for batch {batch} at rank {rank} in epoch {epoch}"

                # Add stats
                tot_len += overall_len
                tot_batches += 1

                # square len
                sqlen[rank].append(square_len)
                olen[rank].append(overall_len)

                batch_processing_times.append(time.perf_counter() - batch_start)

            rank_time = time.perf_counter() - rank_start
            if epoch == START_EPOCH:  # Only print for first epoch
                # rank_time is measured in seconds.
                print(f"  Rank {rank} processing time: {rank_time:.2f}s")

        iteration_time = time.perf_counter() - iteration_start

        # # Time statistics computation
        # stats_start = time.time()
        # tot_maxlen += np.sum(np.max(olen, axis=0))

        # sqlen = np.array(sqlen)
        # sqlag = np.max(sqlen, axis=0) - np.min(sqlen, axis=0)

        # avg_sql2lag += np.sqrt(np.mean(sqlag))
        # max_sql2lag += np.sqrt(np.max(sqlag))

        # Check overall unique.
        # Compare sizes rather than `batches == list(set(batches))`: a set has
        # no guaranteed iteration order, so that comparison against the sorted
        # list can fail even when every index is unique.
        batches.sort()
        num_duplicates = len(batches) - len(set(batches))
        assert num_duplicates == 0, (
            f"{num_duplicates} duplicate indices emitted in epoch {epoch} by sampler {sampler_name!r}"
        )
        # stats_time = time.time() - stats_start

        epoch_time = time.perf_counter() - epoch_start

        if epoch == START_EPOCH:  # Print detailed timing for first epoch
            print(f"  Epoch {epoch} timing breakdown:")
            print(f"    Set epoch: {set_epoch_time:.2f}s")
            print(f"    Batch iteration: {iteration_time:.2f}s")
            # print(f"    Statistics: {stats_time*1000:.2f}ms")
            print(f"    Total epoch: {epoch_time:.2f}s")
            if batch_processing_times:
                print(f"    Avg batch processing: {np.mean(batch_processing_times):.2f}s")
                print(f"    Min/Max batch processing: {np.min(batch_processing_times):.2f}s / {np.max(batch_processing_times):.2f}s")

    # # Check efficiency
    # print(f"L^2 lag avg: {avg_sql2lag / EPOCHS:.0f} max: {max_sql2lag / EPOCHS:.0f}")
    # print(f"Efficiency: {tot_len / (tot_batches * CTX_LEN * BATCH_SIZE) * 100:.2f}%")
    # print(f"Utilization: {tot_len / (tot_maxlen * NUM_GPUS) * 100:.2f}%")
    # print("==========\n")

In [None]:
# Detailed timing tests for sampler performance.
# The banner separates this cell's output from the per-sampler reports printed below it.
print("=== DETAILED TIMING TESTS ===\n")

def time_sampler_parts(sampler_name, sampler_fn, rank=0, max_batches=10):
    """Print a timing breakdown of the individual stages of one sampler.

    Stages timed: construction, ``set_epoch(0)``, ``iter()``, the first batch,
    and the following batches up to ``max_batches`` in total. If the sampler
    exposes ``collate_fn``, collating the first batch is timed as well.

    Reads the notebook globals ``lengths``, ``cluster_sizes`` and (only for the
    collate step) ``train_set``.

    Args:
        sampler_name: Label used in the printed report.
        sampler_fn: Factory called as
            ``sampler_fn(lengths=..., cluster_sizes=..., rank=...)``.
        rank: Rank passed to the factory.
        max_batches: Total number of batches to draw, first batch included.

    Returns:
        None. Results are printed.
    """
    print(f"Testing {sampler_name} (Rank {rank}):")

    # time.perf_counter() is monotonic and high-resolution, which matters for
    # the millisecond-scale intervals measured here.

    # Test sampler creation time
    tick = time.perf_counter()
    sampler = sampler_fn(lengths=lengths, cluster_sizes=cluster_sizes, rank=rank)
    creation_time = time.perf_counter() - tick
    print(f"  Sampler creation: {creation_time*1000:.2f}ms")

    # Test epoch setting time
    tick = time.perf_counter()
    sampler.set_epoch(0)
    epoch_time = time.perf_counter() - tick
    print(f"  Set epoch: {epoch_time*1000:.2f}ms")

    # Test iteration initialization time
    tick = time.perf_counter()
    iterator = iter(sampler)
    init_time = time.perf_counter() - tick
    print(f"  Iterator init: {init_time*1000:.2f}ms")

    # Test first batch generation time
    tick = time.perf_counter()
    try:
        first_batch = next(iterator)
    except StopIteration:
        print("  First batch: No batches generated")
        return
    first_batch_time = time.perf_counter() - tick
    print(f"  First batch: {first_batch_time*1000:.2f}ms (size: {len(first_batch)})")

    # Test subsequent batch generation times
    batch_times = []
    batch_sizes = []
    batch_count = 1  # Already got first batch
    tick = time.perf_counter()
    while batch_count < max_batches:
        try:
            batch = next(iterator)
        except StopIteration:
            break
        tock = time.perf_counter()
        batch_times.append((tock - tick) * 1000)
        batch_sizes.append(len(batch))
        tick = tock
        batch_count += 1

    if batch_times:
        # Report the range that was actually timed; with the default limit and
        # enough batches this reads "(2-10)".
        print(f"  Avg batch time (2-{batch_count}): {np.mean(batch_times):.2f}ms")
        print(f"  Min/Max batch time: {np.min(batch_times):.2f}ms / {np.max(batch_times):.2f}ms")
        print(f"  Avg batch size: {np.mean(batch_sizes):.1f}")
        print(f"  Min/Max batch size: {min(batch_sizes)} / {max(batch_sizes)}")

    # Test collate function time if available
    if hasattr(sampler, 'collate_fn'):
        # Create a mock batch using actual dataset samples
        mock_batch = train_set[first_batch]

        tick = time.perf_counter()
        sampler.collate_fn(mock_batch)
        collate_time = time.perf_counter() - tick
        print(f"  Collate time (first batch): {collate_time*1000:.2f}ms")

    print()

# Test each sampler: run the timing breakdown once per entry in SAMPLERS, always as rank 0.
for sampler_name, sampler_fn in SAMPLERS.items():
    time_sampler_parts(sampler_name, sampler_fn, rank=0)

print("=== MEMORY USAGE TEST ===")
# psutil (third-party) is used below to read the process RSS;
# gc is used to force a collection after each memory test.
import psutil
import gc

def test_memory_usage(sampler_name, sampler_fn, rank=0):
    print(f"Testing memory usage for {sampler_name}:")
    
    # Get initial memory
    process = psutil.Process()
    initial_memory = process.memory_info().rss / 1024 / 1024  # MB
    
    # Create sampler
    sampler = sampler_fn(lengths=lengths, cluster_sizes=cluster_sizes, rank=rank)
    after_creation = process.memory_info().rss / 1024 / 1024
    
    # Generate some batches
    sampler.set_epoch(0)
    batches = []
    for i, batch in enumerate(iter(sampler)):
        batches.append(batch)
        if i >= 50:  # Generate 50 batches
            break
    
    after_batches = process.memory_info().rss / 1024 / 1024
    
    print(f"  Initial memory: {initial_memory:.1f}MB")
    print(f"  After creation: {after_creation:.1f}MB (+{after_creation-initial_memory:.1f}MB)")
    print(f"  After 50 batches: {after_batches:.1f}MB (+{after_batches-initial_memory:.1f}MB)")
    print(f"  Generated {len(batches)} batches")
    
    # Cleanup
    del sampler, batches
    gc.collect()
    print()

# Test memory usage for each sampler: one run per entry in SAMPLERS, always as rank 0.
for sampler_name, sampler_fn in SAMPLERS.items():
    test_memory_usage(sampler_name, sampler_fn, rank=0)

In [21]:
# Single-GPU configuration; rebinds the constants and SAMPLERS used by later cells.
BATCH_SIZE = 32
CTX_LEN = 1024
NUM_GPUS = 1
EPOCHS = 10
SEED = 42
START_EPOCH = 0


def _shared_sampler_kwargs(rank):
    """Keyword arguments common to both samplers, evaluated at call time."""
    return dict(
        max_length=CTX_LEN,
        total_length=BATCH_SIZE * CTX_LEN,
        num_replicas=NUM_GPUS,
        rank=rank,
        seed=SEED,
        epoch=START_EPOCH,
    )


# Display name -> factory taking (lengths, cluster_sizes, rank).
SAMPLERS = {
    # Handed the global train_set plus column names; `lengths` and
    # `cluster_sizes` are accepted only to keep the factory signature uniform.
    "Simple Distributed Batch Sampler": lambda lengths, cluster_sizes, rank: (
        data.SimpleDistributedBatchSampler(
            dataset=train_set,
            length_key="length",
            cluster_size_key="cluster_size",
            shuffle=True,  # Shuffle batches for each epoch
            **_shared_sampler_kwargs(rank),
        )
    ),
    # Receives the pre-extracted lengths / cluster sizes directly.
    "Multipack QuadraticAttention": lambda lengths, cluster_sizes, rank: (
        data.MultipackDistributedBatchSampler(
            lengths=lengths,
            cluster_sizes=cluster_sizes,
            **_shared_sampler_kwargs(rank),
        )
    ),
}


In [22]:
import axolotl.data_nested as data
import time
import numpy as np

DATASET_TRAIN = "/mnt/e/datasets/IPR036736_90_grouped/train"
DATASET_TEST = "/mnt/e/datasets/IPR036736_90_grouped/test"
MAX_TIMED_BATCHES = 50  # upper bound on train batches drawn in the timing loop

# Build the train/test loaders from the configuration constants
# (BATCH_SIZE, NUM_GPUS, CTX_LEN, SEED, START_EPOCH) defined in the config cell.
train_loader, test_loader = data.get_dataloaders(
    train_batch_size=BATCH_SIZE,
    valid_batch_size=BATCH_SIZE,
    ngpus=NUM_GPUS,
    accum=1,
    train_path=DATASET_TRAIN,
    valid_path=DATASET_TEST,
    max_length=CTX_LEN,
    drop_last=True,
    num_workers=4,
    distributed=False,
    seed=SEED, # Seed for picking the data in the clusters and making torch generators.
    epoch=START_EPOCH,
    shuffle_each_epoch=True,
)

train_iter = iter(train_loader)
test_iter = iter(test_loader)

# Test first batch generation.
# Only StopIteration means "the loader is empty"; any other exception (a
# loading failure, a missing key, ...) is a real error and must surface
# instead of being reported as "No batches generated".
print("Testing first batch generation:")
start_time = time.perf_counter()
try:
    first_train_batch = next(train_iter)
except StopIteration:
    print("  First train batch: No batches generated")
else:
    first_train_time = time.perf_counter() - start_time
    print(f"  First train batch: {first_train_time:.2f}s")
    print(f"  First train batch tokens: {first_train_batch['input_ids'].offsets()}")

start_time = time.perf_counter()
try:
    first_test_batch = next(test_iter)
except StopIteration:
    print("  First test batch: No batches generated")
else:
    first_test_time = time.perf_counter() - start_time
    print(f"  First test batch: {first_test_time:.2f}s")
    # This line reports the *test* batch (the label used to say "train").
    print(f"  First test batch tokens: {first_test_batch['input_ids'].offsets()}")

# Test subsequent batch generation: time each next() call on the train
# iterator until MAX_TIMED_BATCHES batches were drawn or it is exhausted.
print("\nTesting subsequent batch generation:")
batch_count = 0
times = []
while batch_count < MAX_TIMED_BATCHES:
    start_time = time.perf_counter()
    try:
        batch = next(train_iter)
    except StopIteration:
        print("  No more batches available")
        break
    batch_time = time.perf_counter() - start_time
    times.append(batch_time)
    batch_count += 1
    if batch_count % 10 == 0:  # Print every 10 batches
        print(f"  Batch {batch_count}: {batch['input_ids'].offsets()}")

print(f"\nTotal batches generated: {batch_count}")
if times:
    print(f"Average batch time: {np.mean(times):.2f}s")
    print(f"Min/Max batch time: {np.min(times):.2f}s / {np.max(times):.2f}s")



Loading dataset from /mnt/e/datasets/IPR036736_90_grouped/train
Loading dataset from /mnt/e/datasets/IPR036736_90_grouped/test
Testing first batch generation:
  First train batch: 0.94s
  First train batch tokens: tensor([    0,    81,   169,   323,   399,   489,  1513,  2537,  3561,  4585,
         5528,  5618,  6642,  7666,  8690,  9714, 10738, 11762, 12786, 13810,
        14834, 15855, 16879, 17885, 17976, 19000, 20024, 20162, 20259, 21283,
        21384, 22408, 22517, 22602, 23626, 24650, 24731, 25706, 26623, 27647,
        27733, 28757, 29781, 30601, 31625, 32649])
  First test batch: 3.86s
  First train batch tokens: tensor([    0,   192,   604,   682,  1706,  2730,  3754,  4778,  4859,  5883,
         5969,  6993,  8017,  9041,  9200,  9273,  9355, 10379, 11403, 12427,
        13439, 13518, 14228, 15252, 16276, 16639, 16714, 17738, 18762, 19786,
        20810, 21834, 21920, 22564, 23588, 24612, 25636, 25742, 26766, 27790,
        27906, 27994, 28071, 28193, 28376, 28641, 29665, 