In [77]:
from datasets import Dataset, DatasetDict
import torch
import random

In [107]:
# Build a toy dataset with one split per aspect ratio. Each split holds
# round(4 * ar) records, and image ids keep counting up across splits.
ds_lists = {}
img_id = 0
for ar in (0.5, 1.0, 1.5):
    if ar in ds_lists:
        continue
    # Reserve a contiguous id range for this aspect ratio.
    first_id, img_id = img_id, img_id + round(4 * ar)
    ds_lists[ar] = [
        {"img": f"image {i}, ar {ar}"} for i in range(first_id, img_id)
    ]

# DatasetDict is keyed by the string form of the aspect ratio.
ds = DatasetDict(
    {str(split): Dataset.from_list(rows) for split, rows in ds_lists.items()}
)

ds

DatasetDict({
    0.5: Dataset({
        features: ['img'],
        num_rows: 2
    })
    1.0: Dataset({
        features: ['img'],
        num_rows: 4
    })
    1.5: Dataset({
        features: ['img'],
        num_rows: 6
    })
})

In [220]:
class ShapeBatchingDataset(torch.utils.data.IterableDataset):
    """
    Iterable dataset that yields batches whose samples share one aspect ratio.

    One ``DataLoader`` driven by a ``RandomSampler`` is created per split
    (= aspect ratio). Iteration visits the splits round-robin, taking one batch
    from every split that still has data, until all splits are exhausted. The
    last batch of a split may be smaller than ``batch_size``.

    Batching already happens inside this class, so when wrapping an instance in
    an outer ``DataLoader`` pass ``batch_size=None``. With ``num_workers > 0``
    every worker would replay the full stream, as for any ``IterableDataset``
    without worker sharding.

    Args:
        hf_dataset: Mapping from split name to a map-style dataset, e.g. a
            HuggingFace ``DatasetDict`` whose splits represent aspect ratios.
        splits: Iterable of split names to use from ``hf_dataset``. It is
            materialised to a list, and its order defines the round-robin order.
        batch_size: Number of samples per batch (default: 8).
        seed: Random seed for the samplers (default: 42). Each split gets its
            own ``torch.Generator`` seeded with ``seed + <position of split>``,
            so the shuffle order is reproducible across runs while still
            changing from epoch to epoch. Pass ``None`` to draw from torch's
            global RNG instead.

    Yields:
        Tuples of ``(split_name, batch_data)`` where each batch contains samples
        of the same aspect ratio.
    """
    def __init__(self, hf_dataset, splits, batch_size=8, seed=42):
        self.hf_dataset = hf_dataset
        # Materialise once: splits is iterated here and indexed again in __iter__,
        # which would silently yield nothing for a one-shot iterable.
        self.splits = list(splits)  # each split is one aspect ratio
        self.seed = seed
        self.dataloaders = {}

        # Create a dataloader for each split (= aspect ratio)
        for offset, split in enumerate(self.splits):
            generator = None
            if seed is not None:
                # Dedicated generator per split keeps shuffling reproducible
                # and independent of the global RNG state.
                generator = torch.Generator()
                generator.manual_seed(seed + offset)
            sampler = torch.utils.data.RandomSampler(
                hf_dataset[split], generator=generator
            )
            self.dataloaders[split] = torch.utils.data.DataLoader(
                hf_dataset[split], sampler=sampler, batch_size=batch_size
            )

    def __iter__(self):
        # Fresh iterators at the beginning of each epoch
        iterators = {
            split: iter(dataloader) for split, dataloader in self.dataloaders.items()
        }
        active_splits = set(self.splits)  # splits that may still produce batches

        while active_splits:
            # Round robin: one pass over all splits, one batch per active split
            for split in self.splits:
                if split not in active_splits:
                    continue
                try:
                    batch = next(iterators[split])
                except StopIteration:
                    # This split's dataloader is exhausted for the epoch
                    active_splits.remove(split)
                    continue
                # Yield outside the try so only next() is guarded
                yield split, batch

    def __len__(self):
        """Total number of batches per epoch, summed over all splits."""
        return sum(len(dataloader) for dataloader in self.dataloaders.values())


In [221]:
# Wrap the toy DatasetDict so that every batch holds a single aspect ratio.
shape_dataset = ShapeBatchingDataset(
    hf_dataset=ds,
    splits=["0.5", "1.0", "1.5"],
    batch_size=3,
)

# Number of batches per epoch across all splits.
print(len(shape_dataset))

# Walk one epoch, showing each batch and counting the samples seen.
samples = 0
for split, batch in shape_dataset:
    print(f"Processing batch with aspect ratio: {split}")
    print(batch)
    samples = samples + len(batch["img"])
print("processed", samples, "samples")

5
Processing batch with aspect ratio: 0.5
{'img': ['image 1, ar 0.5', 'image 0, ar 0.5']}
Processing batch with aspect ratio: 1.0
{'img': ['image 2, ar 1.0', 'image 5, ar 1.0', 'image 3, ar 1.0']}
Processing batch with aspect ratio: 1.5
{'img': ['image 10, ar 1.5', 'image 8, ar 1.5', 'image 6, ar 1.5']}
Processing batch with aspect ratio: 1.0
{'img': ['image 4, ar 1.0']}
Processing batch with aspect ratio: 1.5
{'img': ['image 7, ar 1.5', 'image 11, ar 1.5', 'image 9, ar 1.5']}
processed 12 samples
