In [1]:
from datasets.utils.logging import disable_progress_bar
from torch.utils.data import DataLoader, Subset, random_split
import torchvision
import numpy as np
import pickle


DATA_PATH='/tmp/data/cifar10'
NUM_CLIENTS = 1
DUMP_FILE_NAME = '/tmp/data/CIFAR10-IID-1-CLIENT.pkl'

transform = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(
                    (0.5, 0.5, 0.5), 
                    (0.5, 0.5, 0.5)),
            ])

cifar10_train = torchvision.datasets.CIFAR10(
    root=DATA_PATH,
    train=True,
    transform=transform,
    download=True
)

cifar10_test = torchvision.datasets.CIFAR10(
    root=DATA_PATH,
    train=False,
    transform=transform,
    download=True
)

def cifar_iid(dataset, num_users):
    """
    Sample I.I.D. client data from CIFAR10 dataset
    :param dataset:
    :param num_users:
    :return: dict of image index
    """
    num_items = int(len(dataset)/num_users)
    dict_users, all_idxs = {}, [i for i in range(len(dataset))]
    for i in range(num_users):
        dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
        all_idxs = list(set(all_idxs) - dict_users[i])
    return dict_users

def cifar_non_iid(dataset, num_classes, num_users, alpha = 0.5):
    N = len(dataset)
    min_size = 0
    print("Dataset size:", N)

    dict_users = {}
    while min_size < 10:
        idx_batch = [[] for _ in range(num_users)]
        for k in range(num_classes):
            idx_k = np.where(np.asarray(dataset.targets) == k)[0]
            np.random.shuffle(idx_k)
            proportions = np.random.dirichlet(np.repeat(alpha, num_users))
            ## Balance
            proportions = np.array([p*(len(idx_j)<N/num_users) for p,idx_j in zip(proportions,idx_batch)])
            proportions = proportions/proportions.sum()
            proportions = (np.cumsum(proportions)*len(idx_k)).astype(int)[:-1]
            idx_batch = [idx_j + idx.tolist() for idx_j,idx in zip(idx_batch,np.split(idx_k,proportions))]
            min_size = min([len(idx_j) for idx_j in idx_batch])

    for j in range(num_users):
        np.random.shuffle(idx_batch[j])
        dict_users[j] = idx_batch[j]
    return dict_users

  from .autonotebook import tqdm as notebook_tqdm


Files already downloaded and verified
Files already downloaded and verified


In [2]:
idxs = cifar_iid(cifar10_train, 10)

In [4]:
print(idxs[0])

{6, 32785, 29, 30, 32, 39, 32808, 41, 46, 47, 32817, 32822, 60, 32830, 64, 67, 32835, 32861, 94, 32863, 32868, 32891, 129, 131, 139, 143, 32914, 32917, 163, 32951, 32958, 195, 32966, 32982, 32987, 229, 232, 33001, 236, 33007, 33013, 247, 33019, 253, 263, 33034, 33038, 33042, 33059, 33061, 306, 33077, 33083, 33086, 338, 33107, 33117, 354, 355, 33127, 33129, 362, 363, 33134, 366, 33138, 382, 33155, 33156, 391, 394, 401, 402, 33174, 420, 33194, 441, 444, 451, 455, 466, 477, 479, 487, 489, 33275, 510, 33284, 520, 33289, 33288, 33291, 33293, 33296, 33301, 535, 541, 543, 547, 553, 563, 33340, 588, 33367, 33371, 611, 33380, 616, 628, 631, 632, 638, 33408, 33412, 653, 33421, 33432, 33438, 671, 680, 687, 689, 694, 33464, 702, 33471, 33483, 728, 33497, 740, 746, 751, 752, 33521, 754, 33524, 760, 33535, 770, 772, 775, 776, 33544, 788, 789, 793, 33571, 812, 33583, 823, 33593, 33601, 834, 839, 841, 848, 869, 876, 33647, 33650, 884, 33654, 887, 898, 899, 33671, 908, 910, 33683, 33689, 33691, 33696, 