In [2]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import numpy as np

In [1]:
def mnistIID(dataset, num_users):
    """Partition dataset indices into equally sized, disjoint random groups.

    Each user receives ``int(len(dataset) / num_users)`` indices drawn without
    replacement from the indices not yet handed out. Any remainder left over
    after the last user is simply not assigned.

    Note: NumPy's global RNG is re-seeded with the user id on every
    iteration, so the partition is reproducible but the global random state
    is modified as a side effect.

    Args:
        dataset: any sized object; only ``len(dataset)`` is used.
        num_users: number of groups to create.

    Returns:
        dict mapping user id (0 .. num_users-1) to a ``set`` of indices.
    """
    per_user = int(len(dataset)/num_users)
    remaining = list(range(len(dataset)))
    users_dict = {}
    for user in range(num_users):
        # Seed with the user id so every run yields the same split; this makes
        # the effect of new parameters and optimizers comparable across runs.
        np.random.seed(user)
        # replace=False guarantees no index is drawn twice for one user.
        picked = set(np.random.choice(remaining, per_user, replace=False))
        users_dict[user] = picked
        # Remove what was just handed out so users never share an index.
        remaining = list(set(remaining) - picked)
    return users_dict

In [None]:
def mnistNonIID(dataset, num_users, num_shards=100, shards_per_user=2):
    """Partition dataset indices into label-skewed (non-IID) per-user groups.

    Sample indices are sorted by label and cut into ``num_shards`` contiguous
    shards of equal size. Each user then receives ``shards_per_user`` distinct
    shards chosen at random (NumPy global RNG) without replacement, so a user
    mostly sees only a few labels.

    The shard size is derived from the number of labels
    (``len(labels) // num_shards``), so the function works for any split size;
    with the defaults and a 60000-sample split this gives 600 samples per
    shard. Samples that do not fill a whole shard are left unassigned.

    Args:
        dataset: object exposing its labels through a ``targets`` attribute,
            or, failing that, through ``train_labels``. The labels may be
            anything with a ``.numpy()`` method or anything ``np.asarray``
            accepts.
        num_users: number of groups to create.
        num_shards: number of label-sorted shards to cut the data into.
        shards_per_user: number of shards handed to each user.

    Returns:
        dict mapping user id (0 .. num_users-1) to a 1-D ``int64`` array of
        dataset indices.

    Raises:
        ValueError: if ``num_shards`` or ``shards_per_user`` is not positive,
            if there are not enough shards for all users, or if the dataset
            has fewer samples than shards.
    """
    if num_shards <= 0:
        raise ValueError("num_shards must be a positive integer")
    if shards_per_user <= 0:
        raise ValueError("shards_per_user must be a positive integer")
    if num_users * shards_per_user > num_shards:
        raise ValueError(
            "not enough shards: %d users x %d shards each > %d shards"
            % (num_users, shards_per_user, num_shards))

    # Prefer `targets`; fall back to the `train_labels` attribute otherwise.
    labels = getattr(dataset, "targets", None)
    if labels is None:
        labels = dataset.train_labels
    if hasattr(labels, "numpy"):
        labels = labels.numpy()
    labels = np.asarray(labels).ravel()

    shard_size = labels.size // num_shards
    if shard_size == 0:
        raise ValueError(
            "dataset has %d samples, fewer than num_shards=%d"
            % (labels.size, num_shards))

    # Indices ordered by label; a stable sort keeps the result deterministic.
    sorted_idx = np.argsort(labels, kind="stable")

    free_shards = list(range(num_shards))
    users_dict = {}
    for user in range(num_users):
        chosen = np.random.choice(free_shards, shards_per_user, replace=False)
        # Retire the chosen shards so no two users share one.
        taken = {int(s) for s in chosen}
        free_shards = [s for s in free_shards if s not in taken]

        pieces = [sorted_idx[int(s) * shard_size:(int(s) + 1) * shard_size]
                  for s in chosen]
        users_dict[user] = np.concatenate(pieces, axis=0).astype(np.int64)

    return users_dict

In [6]:
# Only MNIST is used here, so this function is deliberately kept impure,
# tightly coupled, and hard-coded.
def load_dataset(num_users, iidtype):
    """Load the MNIST train/test splits and partition them across users.

    Both splits are read from ``./data`` with ``download=False`` (presumably
    the files must already be on disk -- adjust the flags if they are not).

    Args:
        num_users: number of users to partition each split for.
        iidtype: ``'iid'`` to partition with ``mnistIID``, ``'noniid'`` to
            partition with ``mnistNonIID``. Any other value leaves both
            groups as ``None``.

    Returns:
        Tuple ``(train_dataset, test_dataset, train_group, test_group)``.
    """
    # https://discuss.pytorch.org/t/normalization-in-the-mnist-example/457
    # [14,1:12] for more on transforms
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize((0.1307,), (0.3081,))
    transform = transforms.Compose([to_tensor, normalize])

    # change train and download params to match
    train_dataset, test_dataset = (
        datasets.MNIST("./data", train=is_train, download=False, transform=transform)
        for is_train in (True, False)
    )

    # Pick the partitioning strategy once, then apply it to both splits.
    if iidtype == 'iid':
        partition = mnistIID
    elif iidtype == 'noniid':
        partition = mnistNonIID
    else:
        partition = None  # FIXME: unrecognised iidtype, no partitioning is done

    if partition is None:
        train_group, test_group = None, None
    else:
        train_group = partition(train_dataset, num_users)
        test_group = partition(test_dataset, num_users)

    return train_dataset, test_dataset, train_group, test_group

In [4]:
class FedDataset(Dataset):
    """View of a dataset restricted to a given collection of indices.

    Args:
        dataset: indexable object whose items unpack as ``(image, label)``.
        indx: iterable of indices into ``dataset``; each is coerced to
            ``int`` (so float or NumPy integer indices are accepted).
    """

    def __init__(self, dataset, indx):
        self.dataset = dataset
        self.indx = [int(i) for i in indx]

    def __len__(self):
        """Return the number of samples in this subset."""
        return len(self.indx)

    def __getitem__(self, item):
        """Return ``(image, label)`` for the ``item``-th index as tensors."""
        image, label = self.dataset[self.indx[item]]
        return self._as_tensor_copy(image), self._as_tensor_copy(label)

    @staticmethod
    def _as_tensor_copy(value):
        """Return ``value`` as a detached tensor copy.

        ``torch.tensor(t)`` on an existing tensor emits a UserWarning on
        every call; ``clone().detach()`` is the recommended equivalent.
        """
        if torch.is_tensor(value):
            return value.clone().detach()
        return torch.tensor(value)

In [5]:
def getActualImages(dataset, indeces, batch_size):
    """Build a shuffling ``DataLoader`` over the given indices of a dataset.

    Args:
        dataset: underlying dataset, wrapped in ``FedDataset``.
        indeces: indices of ``dataset`` to expose through the loader.
        batch_size: batch size passed to the ``DataLoader``.

    Returns:
        ``DataLoader`` over ``FedDataset(dataset, indeces)`` with
        ``shuffle=True``.
    """
    user_subset = FedDataset(dataset, indeces)
    loader = DataLoader(user_subset, batch_size=batch_size, shuffle=True)
    return loader