# In [10]:
def keep_first_n_per_class(dataset, n=10):
    """Return the first ``n`` samples of every class, in encounter order.

    Args:
        dataset: Iterable yielding ``(data, class_label)`` pairs. Labels are
            used as dictionary keys, so they must be hashable.
        n: Maximum number of samples to keep for each distinct label.

    Returns:
        A list of ``(data, class_label)`` tuples containing, for each label,
        at most the first ``n`` pairs seen while iterating ``dataset``.
    """
    taken_per_label = {}
    kept = []

    for sample, label in dataset:
        # Register the label on first sight and read how many we already kept.
        taken_so_far = taken_per_label.setdefault(label, 0)
        if taken_so_far < n:
            kept.append((sample, label))
            taken_per_label[label] = taken_so_far + 1

    return kept

def get_datasets(data_dir, data_transforms, train_ratio, val_ratio, random_seed):
    """Build the train/val/test datasets (plus augmented train copies).

    When ``datasets_is_split(data_dir)`` is true, the ``train``, ``val`` and
    ``test`` sub-directories of ``data_dir`` are loaded as they are. Otherwise
    ``data_dir`` itself is loaded three times (once per transform) and the
    sample indices are shuffled with ``random_seed`` and partitioned into
    three disjoint subsets.

    Args:
        data_dir: Root directory handed to ``datasets.ImageFolder``.
        data_transforms: Mapping with the keys ``'train'``, ``'val'`` and
            ``'test'``, plus optional ``'aug0'``, ``'aug1'``, ... entries.
            Every entry beyond the first three is looked up as ``'aug{i}'``.
        train_ratio: Fraction of all samples assigned to the train subset
            (only used when the directory is not already split).
        val_ratio: Fraction of the samples *remaining after the train split*
            assigned to the validation subset; the rest becomes the test
            subset (only used when the directory is not already split).
        random_seed: Seed passed to ``random.seed`` before shuffling. Note
            that this reseeds the module-level ``random`` generator.

    Returns:
        A dict with the keys ``'train'``, ``'val'`` and ``'test'``. The train
        entry is a ``ConcatDataset`` of the base train data and one copy per
        ``aug{i}`` transform when such transforms are present.
    """
    # Evaluate once so both decisions below are guaranteed to agree.
    is_split = datasets_is_split(data_dir)

    if is_split:
        train_path = os.path.join(data_dir, "train")
        val_path = os.path.join(data_dir, "val")
        test_path = os.path.join(data_dir, "test")
    else:
        # Single folder: every dataset reads the same files and is
        # restricted to its own index subset further down.
        train_path = val_path = test_path = data_dir

    train_dataset = datasets.ImageFolder(train_path, transform=data_transforms['train'])
    val_dataset = datasets.ImageFolder(val_path, transform=data_transforms['val'])
    test_dataset = datasets.ImageFolder(test_path, transform=data_transforms['test'])

    for i, class_name in enumerate(train_dataset.classes):
        pprint(f"Class label {i}: {class_name}")

    # Stays None for a pre-split directory; the augmentation loop then uses
    # the whole train folder instead of failing on an undefined name.
    train_idx = None

    if not is_split:
        # All three datasets were loaded from the same directory here, so
        # this is the size of the full, unsplit sample list.
        num_samples = len(train_dataset)
        indices = list(range(num_samples))
        pprint("--------- INDEX checking ---------")
        pprint(f"Original: {indices[:5]}")
        random.seed(random_seed)
        random.shuffle(indices)
        pprint(f"Shuffled: {indices[:5]}")
        pprint("--------- INDEX shuffled ---------\n")

        split_train = int(np.floor(train_ratio * num_samples))
        # val_ratio applies to what is left once the train part is removed.
        split_val = split_train + int(np.floor(val_ratio * (num_samples - split_train)))
        train_idx = indices[:split_train]
        val_idx = indices[split_train:split_val]
        test_idx = indices[split_val:]

        train_dataset = Subset(train_dataset, train_idx)
        val_dataset = Subset(val_dataset, val_idx)
        test_dataset = Subset(test_dataset, test_idx)

    # Every entry besides 'train', 'val' and 'test' is an 'aug{i}' transform.
    num_aug = len(data_transforms) - 3
    for ii in range(num_aug):
        aug_dataset = datasets.ImageFolder(train_path, transform=data_transforms[f'aug{ii}'])
        if train_idx is not None:
            # Keep augmented copies confined to the train indices so no
            # val/test sample reaches the training set.
            aug_dataset = Subset(aug_dataset, train_idx)
        train_dataset = ConcatDataset([train_dataset, aug_dataset])

    return {
        'train': train_dataset,
        'val': val_dataset,
        'test': test_dataset,
    }