In [3]:
import torch
from torch.utils.data import DataLoader, Dataset, Sampler, BatchSampler
import random
from torch.nn.utils.rnn import pad_sequence

# Define a simple dataset class for demonstration
class SimpleDataset(Dataset):
    """Toy dataset of all-ones tensors, each labelled with its own position.

    Args:
        seq_lengths: Sized collection of sizes; item ``i`` becomes
            ``torch.ones(seq_lengths[i])`` and receives the label ``i``.
    """

    def __init__(self, seq_lengths):
        # One tensor of ones per requested length.
        self.data = list(map(torch.ones, seq_lengths))
        # Labels are just positions, so a batch can be traced back to its items.
        self.labels = [*range(len(seq_lengths))]

    def __len__(self):
        """Return the number of stored sequences."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(sequence, label)`` pair stored at ``idx``."""
        sample = self.data[idx]
        label = self.labels[idx]
        return sample, label

# Define the GroupedSampler class
class GroupedSampler(Sampler):
    """Yield dataset indices so that neighbouring indices have similar lengths.

    On every pass the ``(index, length)`` pairs are shuffled, cut into
    consecutive groups of ``batch_size * group_factor`` entries, and each
    group is sorted by ascending length. Chunking the resulting index stream
    into batches of ``batch_size`` therefore gives batches whose members need
    little padding, while the shuffle still varies the order between passes.

    Args:
        seqs: Collection of sequences; only ``len(seq)`` of each item is used.
        batch_size: Positive integer batch size the indices will be chunked by.
        group_factor: Positive integer number of batches' worth of indices
            that are sorted together. Smaller values give more randomness,
            larger values give tighter length grouping. Defaults to 10.

    Raises:
        ValueError: If ``batch_size`` or ``group_factor`` is not positive.
    """

    def __init__(self, seqs, batch_size, group_factor=10):
        # A zero step makes range() fail only at iteration time and a negative
        # step silently yields nothing, so reject both up front.
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        if group_factor <= 0:
            raise ValueError(f"group_factor must be positive, got {group_factor}")
        self.seqs = seqs
        self.batch_size = batch_size
        self.group_factor = group_factor
        self.index_length_pairs = [(i, len(seq)) for i, seq in enumerate(self.seqs)]

    def __iter__(self):
        # Shuffle a copy so iterating never reorders the stored pairs.
        pairs = list(self.index_length_pairs)
        random.shuffle(pairs)

        group_size = self.batch_size * self.group_factor
        grouped_indices = []
        for start in range(0, len(pairs), group_size):
            # Sort only within the group: lengths end up locally ordered while
            # the global order stays random.
            group = sorted(pairs[start:start + group_size], key=lambda pair: pair[1])
            grouped_indices.extend(idx for idx, _ in group)

        return iter(grouped_indices)

    def __len__(self):
        """Return the number of indices produced by one pass."""
        return len(self.index_length_pairs)

# Define a custom collate function for padding sequences in a batch
def collate_fn(batch):
    """Collate ``(sequence, label)`` pairs into one padded tensor and a label tensor.

    Args:
        batch: Iterable of ``(sequence_tensor, label)`` pairs.

    Returns:
        A ``(padded, labels)`` tuple: ``padded`` stacks the sequences along
        the first dimension, right-padded with zeros to the longest sequence
        in the batch; ``labels`` is a tensor built from the labels in order.
    """
    # Transpose the list of pairs into one tuple of sequences and one of labels.
    seqs, targets = tuple(zip(*batch))
    padded = pad_sequence(seqs, padding_value=0, batch_first=True)
    label_tensor = torch.tensor(targets)
    return padded, label_tensor

# Toy dataset: eleven all-ones sequences with lengths from 5 to 50
small_seq_lengths = [5, 7, 10, 15, 20, 25, 30, 35, 40, 45, 50]
small_dataset = SimpleDataset(small_seq_lengths)

# Length-grouped index stream, chunked into batches of three
# (drop_last=False keeps the shorter final batch)
batch_size = 3
small_grouped_sampler = GroupedSampler(seqs=small_dataset.data, batch_size=batch_size)
small_batch_sampler = BatchSampler(
    small_grouped_sampler,
    batch_size=batch_size,
    drop_last=False,
)

# The batch sampler decides which indices form a batch; collate_fn pads them
small_dataloader = DataLoader(
    small_dataset,
    collate_fn=collate_fn,
    batch_sampler=small_batch_sampler,
)

# Walk through every batch and show how grouping and padding turned out
for batch in small_dataloader:
    sequences_padded, labels = batch
    print("Batch padded sequences shape:", sequences_padded.size())
    print("Padded sequences:", sequences_padded)
    print("Labels:", labels)
    # Every row has the same length after padding, so this shows the padded width
    print("Batch lengths:", list(map(len, sequences_padded)))
    print("----")


Batch padded sequences shape: torch.Size([3, 10])
Padded sequences: tensor([[1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]])
Labels: tensor([0, 1, 2])
Batch lengths: [10, 10, 10]
----
Batch padded sequences shape: torch.Size([3, 25])
Padded sequences: tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1.]])
Labels: tensor([3, 4, 5])
Batch lengths: [25, 25, 25]
----
Batch padded sequences shape: torch.Size([3, 40])
Padded sequences: tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0.,
       