# In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np

class GroupDataset(Dataset):
    """Dataset of ``(X[i], y[i], groups[i])`` samples with a group index.

    Besides plain indexed access, the dataset builds ``group_to_indices``,
    a dict mapping every distinct group label to the list of sample
    positions (in ascending order) that carry that label.

    Args:
        X: Indexable collection of features.
        y: Indexable collection of targets, same length as ``X``.
        groups: Indexable collection of hashable group labels, same length
            as ``X``.

    Raises:
        ValueError: If ``X``, ``y`` and ``groups`` differ in length.
    """

    def __init__(self, X, y, groups):
        if not (len(X) == len(y) == len(groups)):
            raise ValueError(
                "X, y and groups must have the same length, got "
                f"{len(X)}, {len(y)} and {len(groups)}"
            )
        self.X = X
        self.y = y
        self.groups = groups
        self.group_to_indices = self._group_indices()

    def _group_indices(self):
        """Return a dict mapping each group label to its sample indices."""
        group_to_indices = {}
        for idx, group in enumerate(self.groups):
            # Tensors hash by object identity, so two 0-d tensors holding the
            # same label would land in different buckets. Convert them to
            # plain Python scalars before using them as dict keys.
            key = group.item() if torch.is_tensor(group) else group
            group_to_indices.setdefault(key, []).append(idx)
        return group_to_indices

    def __len__(self):
        """Return the number of samples."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return the ``(features, target, group)`` triple at ``idx``."""
        return self.X[idx], self.y[idx], self.groups[idx]

class ExperimentSampler:
    """Batch sampler that never mixes groups inside a batch.

    Groups are visited in a random order that is redrawn on every pass;
    within a group, indices are emitted in their stored order in chunks of
    at most ``batch_size``. The last chunk of a group may be smaller.

    Args:
        group_to_indices: Mapping from group label to a sequence of sample
            indices belonging to that group.
        batch_size: Maximum number of indices per batch; must be >= 1.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """

    def __init__(self, group_to_indices, batch_size):
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        self.group_to_indices = group_to_indices
        self.batch_size = batch_size
        self.group_order = list(group_to_indices.keys())
        np.random.shuffle(self.group_order)

    def __iter__(self):
        """Yield index batches, one group at a time, in a fresh group order."""
        # Redraw the group order on every pass so successive epochs do not
        # replay the same sequence. Shuffle a copy and rebind it so an
        # iterator that is already running keeps its own order.
        order = list(self.group_order)
        np.random.shuffle(order)
        self.group_order = order

        for group in order:
            indices = self.group_to_indices[group]
            for start in range(0, len(indices), self.batch_size):
                yield indices[start:start + self.batch_size]

    def __len__(self):
        """Return the total number of batches across all groups."""
        # Ceiling division per group: a partially filled last batch counts.
        return sum(
            (len(indices) + self.batch_size - 1) // self.batch_size
            for indices in self.group_to_indices.values()
        )

batch_size = 4

# Example tensors (replace with your actual data)
X_tensor = torch.randn(100, 10)
y_tensor = torch.randint(0, 2, (100,))
groups = np.random.randint(0, 5, 100)

# Create dataset
dataset = GroupDataset(X_tensor, y_tensor, groups)

# Create experiment sampler
experiment_sampler = ExperimentSampler(dataset.group_to_indices, batch_size)

# Create data loader
data_loader = DataLoader(dataset, batch_sampler=experiment_sampler)

# Example: iterating through the data loader.
# The loader uses the sampler's index batches internally and yields the
# collated (X, y, group) samples, not the indices themselves, so unpack the
# batch directly instead of indexing the dataset again.
for X_batch, y_batch, group_batch in data_loader:
    print(X_batch.shape, y_batch.shape, group_batch)
