In [None]:
import data_nested as data
import torch

# Build the train/eval loaders for a single process (ngpus=1).
# data.get_dataloaders is called positionally; the per-argument labels below
# mirror the annotated call in the distributed-test cell — TODO confirm them
# against data_nested.get_dataloaders' signature.
loader_args = (
    128,    # train batch size
    128,    # valid batch size
    1,      # ngpus
    1,      # accum
    '/home/kkj/axolotl/datasets/IPR036736_90_grouped/train',
    '/home/kkj/axolotl/datasets/IPR036736_90_grouped/valid',
    1024,   # max_length
    False,  # drop_last
    4,      # num_workers
)
train_dl, eval_dl = data.get_dataloaders(
    *loader_args,
    distributed=True,
    shuffle_each_epoch=True,
)

# One-shot iterators: once a later cell exhausts them, re-run this cell to
# get fresh ones.
train_iter, eval_iter = iter(train_dl), iter(eval_dl)

In [None]:
# Sanity check on the *eval* loader: every batch must contain only non-empty
# sequences. (This used to print 'train' even though it iterates eval_iter.)
# NOTE: eval_iter is the one-shot iterator created in the first cell; re-running
# this cell without re-running that one iterates nothing and silently "passes".
print('eval')
for i, batch in enumerate(eval_iter):
    # Per-sequence lengths from consecutive offsets of input_ids (presumably a
    # jagged/nested tensor exposing .offsets() — confirm in data_nested).
    seq_lens: torch.Tensor = batch['input_ids'].offsets().diff()
    if i % 100 == 0:
        print(i, seq_lens.sum().item())

    if not seq_lens.all():
        # At least one zero-length sequence: report where before failing.
        empty_positions = (seq_lens == 0).nonzero().flatten().tolist()
        print(i, seq_lens)
        print(seq_lens.sum().item())
        raise ValueError(
            f'batch {i}: zero-length sequence(s) at positions {empty_positions}'
        )

In [3]:
# Test distributed setup by mocking the distributed environment
import os
from itertools import islice
from unittest.mock import patch

import torch.distributed as dist

import data_nested as data

print("Testing Distributed DataLoader Setup", "=" * 50, sep="\n")

# Mock distributed environment for testing
def mock_distributed_setup(world_size=2, rank=0):
    """Mock distributed environment for testing purposes"""
    
    # Patch the distributed functions
    def mock_get_world_size():
        return world_size
    
    def mock_get_rank():
        return rank
    
    def mock_is_available():
        return True
    
    return patch.multiple(
        dist,
        get_world_size=mock_get_world_size,
        get_rank=mock_get_rank,
        is_available=mock_is_available
    )

# Simulate a 2-rank job (~2 GPUs): build the dataloaders once per rank under the
# mocked distributed environment and record the first few eval batches each
# rank receives.
world_size = 2
MAX_BATCHES_PER_RANK = 10  # only inspect the first N eval batches of each rank
batches_by_rank = {}

for rank in range(world_size):
    print(f"\n--- Testing Rank {rank}/{world_size-1} ---")

    # Build *and* iterate inside the patch so any rank/world-size lookups made
    # by data_nested (at construction or iteration time) see the mocked values.
    with mock_distributed_setup(world_size=world_size, rank=rank):
        train_dl_dist, eval_dl_dist = data.get_dataloaders(
            128,  # train_batch_size
            128,  # valid_batch_size
            world_size,  # ngpus
            1,    # accum
            '/home/kkj/axolotl/datasets/IPR036736_90_grouped/train',
            '/home/kkj/axolotl/datasets/IPR036736_90_grouped/valid',
            1024, # max_length
            False, # drop_last
            0,    # num_workers (set to 0 for testing to avoid multiprocessing issues)
            distributed=True,
        )

        eval_iter_dist = iter(eval_dl_dist)

        rank_batches = []
        rank_total_tokens = 0

        # islice stops after MAX_BATCHES_PER_RANK batches without fetching one
        # extra batch, which a `break` on `i >= N` inside the loop would do.
        for i, batch in enumerate(islice(eval_iter_dist, MAX_BATCHES_PER_RANK)):
            seq_lens = batch['input_ids'].offsets().diff()
            total_tokens = seq_lens.sum().item()
            rank_total_tokens += total_tokens

            rank_batches.append({
                'batch_idx': i,
                'num_sequences': len(seq_lens),
                'total_tokens': total_tokens,
                'seq_lengths': seq_lens.tolist()
            })

            print(f"  Batch {i}: {len(seq_lens)} sequences, {total_tokens} tokens")

        batches_by_rank[rank] = {
            'batches': rank_batches,
            'total_tokens': rank_total_tokens
        }

        print(f"  Rank {rank} processed {len(rank_batches)} batches, {rank_total_tokens} total tokens")

# Analyze distribution
print("\n--- Distribution Analysis ---")
print(f"World size: {world_size}")

for rank in range(world_size):
    rank_data = batches_by_rank[rank]
    print(f"Rank {rank}: {len(rank_data['batches'])} batches, {rank_data['total_tokens']} tokens")

# Check that different ranks receive different batches (they should!).
# Per-batch token totals are a weak fingerprint: most packed batches hit the
# same token budget (e.g. 65536 in the recorded output), so two ranks with
# different data can still yield identical total-token lists. Compare each
# batch's per-sequence length list instead, across every pair of ranks.
print()
for rank in range(world_size):
    token_counts = [b['total_tokens'] for b in batches_by_rank[rank]['batches']]
    print(f"Rank {rank} token counts: {token_counts[:5]}...")

fingerprints = {
    rank: [tuple(b['seq_lengths']) for b in batches_by_rank[rank]['batches']]
    for rank in range(world_size)
}
identical_pairs = [
    (a, b)
    for a in range(world_size)
    for b in range(a + 1, world_size)
    if fingerprints[a] == fingerprints[b]
]

if world_size < 2:
    print("Only one rank - nothing to compare")
elif identical_pairs:
    print(f"⚠ Ranks are getting identical batches (potential issue): {identical_pairs}")
else:
    print("✓ Ranks are getting different batches (good!)")

print("\n--- Testing Epoch Cycling ---")
# Pull a few training batches as rank 0 of the simulated job and report the
# sampler epoch when the loader exposes one.
# NOTE(review): N_CYCLE_STEPS steps only crosses an epoch boundary if the
# per-rank training set is that small; raise it past the number of batches per
# epoch to actually exercise cycling.
N_CYCLE_STEPS = 5
with mock_distributed_setup(world_size=world_size, rank=0):
    train_dl_cycle, _ = data.get_dataloaders(
        64, 64, world_size, 1,
        '/home/kkj/axolotl/datasets/IPR036736_90_grouped/train',
        '/home/kkj/axolotl/datasets/IPR036736_90_grouped/valid',
        1024, False, 0, distributed=True
    )

    train_iter_cycle = iter(train_dl_cycle)

    for i in range(N_CYCLE_STEPS):
        batch = next(train_iter_cycle)
        seq_lens = batch['input_ids'].offsets().diff()
        # torch's DataLoader exposes `.sampler` and DistributedSampler stores
        # the set_epoch() value in `.epoch`; whether data_nested's loader
        # follows that layout is unverified, so fall back to '?'. Read it each
        # step in case the sampler is swapped between epochs.
        epoch = getattr(getattr(train_dl_cycle, 'sampler', None), 'epoch', '?')
        print(f"Step {i}: Epoch {epoch}, {len(seq_lens)} sequences, {seq_lens.sum().item()} tokens")

print("\nDistributed testing complete!")

Testing Distributed DataLoader Setup

--- Testing Rank 0/1 ---
  Batch 0: 97 sequences, 65536 tokens
  Batch 0: 97 sequences, 65536 tokens
  Batch 1: 86 sequences, 65536 tokens
  Batch 1: 86 sequences, 65536 tokens
  Batch 2: 102 sequences, 65536 tokens
  Batch 2: 102 sequences, 65536 tokens
  Batch 3: 94 sequences, 65536 tokens
  Batch 3: 94 sequences, 65536 tokens
  Batch 4: 97 sequences, 65536 tokens
  Batch 4: 97 sequences, 65536 tokens
  Batch 5: 95 sequences, 65536 tokens
  Batch 5: 95 sequences, 65536 tokens
  Batch 6: 101 sequences, 65536 tokens
  Batch 6: 101 sequences, 65536 tokens
  Batch 7: 100 sequences, 65536 tokens
  Batch 7: 100 sequences, 65536 tokens
  Batch 8: 81 sequences, 53199 tokens
  Batch 8: 81 sequences, 53199 tokens
  Batch 9: 106 sequences, 65536 tokens
  Batch 9: 106 sequences, 65536 tokens
  Rank 0 processed 10 batches, 643023 total tokens

--- Testing Rank 1/1 ---
  Rank 0 processed 10 batches, 643023 total tokens

--- Testing Rank 1/1 ---
  Batch 0: 91 s