In [100]:
import random
import time

import numpy as np


def iid_divide(l, g):
    """
    Split the sequence `l` into `g` contiguous groups of near-equal size.

    Adapted from
    https://github.com/TalwalkarLab/leaf/blob/master/data/utils/sample.py

    The first groups hold `int(len(l)/g)` elements; the remaining
    `len(l) - g*int(len(l)/g)` groups hold one extra element, so the last
    groups are the bigger ones. Order of `l` is preserved and each group is a
    slice of `l` (same sequence type as `l`).

    :param l: sliceable sequence to divide
    :param g: number of groups
    :return: list of `g` slices of `l`
    """
    total = len(l)
    base = int(total / g)
    n_big = total - g * base
    n_small = g - n_big

    # leading groups of size `base`
    small = [l[k * base:(k + 1) * base] for k in range(n_small)]

    # trailing groups of size `base + 1`, starting right after the small ones
    offset = n_small * base
    wide = base + 1
    big = [l[offset + k * wide:offset + (k + 1) * wide] for k in range(n_big)]

    return small + big


def split_list_by_indices(l, indices):
    """
    Cut `l` at the (non-decreasing) end offsets in `indices`.

    Returns `len(indices)` slices: slice 0 is `l[0:indices[0]]` and slice `i`
    (for `i >= 1`) is `l[indices[i-1]:indices[i]]`. Elements located after
    `indices[-1]` are NOT returned, so pass cumulative counts whose last value
    equals `len(l)` to cover the whole list.

    :param l: sliceable sequence
    :param indices: iterable of end offsets (e.g. the output of `np.cumsum`)
    :return: list of slices of `l`
    """
    ends = list(indices)
    # each slice starts where the previous one ended; the first starts at 0
    starts = [0] + ends[:-1]
    return [l[start:end] for start, end in zip(starts, ends)]


def iid_split(dataset, n_clients, frac, seed=1234):
    """
    Split a dataset among `n_clients` in an IID fashion.

    A random subset of `int(len(dataset) * frac)` indices is drawn, shuffled,
    and then divided into `n_clients` contiguous groups whose sizes differ by
    at most one (see `iid_divide`).

    :param dataset: sized dataset; only `len(dataset)` is used
    :type dataset: torch.utils.Dataset
    :param n_clients: number of clients
    :param frac: fraction of dataset to use
    :param seed: seed for python's `random.Random`, also passed to `np.random.seed`;
        `None` or a negative value seeds from the current time
    :return: list (size `n_clients`) of subgroups, each subgroup is a list of indices.
    """
    if seed is not None and seed >= 0:
        rng_seed = seed
    else:
        rng_seed = int(time.time())
    rng = random.Random(rng_seed)
    np.random.seed(rng_seed)

    population = list(range(len(dataset)))
    chosen = rng.sample(population, int(len(dataset) * frac))
    rng.shuffle(chosen)

    return iid_divide(chosen, n_clients)


def by_labels_non_iid_split(dataset, n_classes, n_clients, n_clusters, alpha, frac, seed=1234):
    """
    split classification dataset among `n_clients`. The dataset is split as follow:
        1) classes are grouped into `n_clusters`
        2) for each cluster `c`, samples are partitioned across clients using dirichlet distribution

    Inspired by the split in "Federated Learning with Matched Averaging"__(https://arxiv.org/abs/2002.06440)

    :param dataset: indexable dataset; `dataset[idx]` must return a `(sample, label)` pair whose
        label converts with `int()` to a value in `range(n_classes)`
    :type dataset: torch.utils.Dataset
    :param n_classes: number of classes present in `dataset`
    :param n_clients: number of clients
    :param n_clusters: number of clusters to consider; if it is `-1`, then `n_clusters = n_classes`
    :param alpha: concentration of the symmetric Dirichlet used to draw the per-client weights inside
        each cluster; it controls the diversity among clients
    :param frac: fraction of dataset to use
    :param seed: seed for python's `random.Random` and NumPy's global RNG; `None` or a negative
        value seeds from the current time
    :return: list (size `n_clients`) of subgroups, each subgroup is a list of indices.
    """
    if n_clusters == -1:
        n_clusters = n_classes

    rng_seed = (seed if (seed is not None and seed >= 0) else int(time.time()))
    rng = random.Random(rng_seed)
    # the Dirichlet / multinomial draws below use NumPy's global RNG
    np.random.seed(rng_seed)

    # 1) randomly group the labels into `n_clusters` clusters
    all_labels = list(range(n_classes))
    rng.shuffle(all_labels)
    clusters_labels = iid_divide(all_labels, n_clusters)
    label2cluster = {
        label: group_idx
        for group_idx, labels in enumerate(clusters_labels)
        for label in labels
    }

    # get subset
    n_samples = int(len(dataset) * frac)
    selected_indices = rng.sample(list(range(len(dataset))), n_samples)

    # bucket the selected samples by the cluster of their label
    clusters_sizes = np.zeros(n_clusters, dtype=int)
    clusters = {k: [] for k in range(n_clusters)}
    for idx in selected_indices:
        _, label = dataset[idx]
        # int() normalises numpy-scalar / 0-d tensor labels to the python-int keys of label2cluster
        group_id = label2cluster[int(label)]
        clusters_sizes[group_id] += 1
        clusters[group_id].append(idx)

    for cluster in clusters.values():
        rng.shuffle(cluster)

    # 2) number of samples each client receives from each cluster
    clients_counts = np.zeros((n_clusters, n_clients), dtype=np.int64)
    for cluster_id in range(n_clusters):
        weights = np.random.dirichlet(alpha=alpha * np.ones(n_clients))
        clients_counts[cluster_id] = np.random.multinomial(clusters_sizes[cluster_id], weights)

    # cumulative counts are the end offsets of each client's slice within a cluster
    clients_bounds = np.cumsum(clients_counts, axis=1)

    clients_indices = [[] for _ in range(n_clients)]
    for cluster_id in range(n_clusters):
        cluster_split = split_list_by_indices(clusters[cluster_id], clients_bounds[cluster_id])
        for client_id, indices in enumerate(cluster_split):
            clients_indices[client_id] += indices

    return clients_indices


def pathological_non_iid_split(dataset, n_classes, n_clients, n_classes_per_client, frac=1, seed=1234):
    """
    split classification dataset among `n_clients`. The dataset is split as follow:
        1) sort the data by label
        2) divide it into `n_clients * n_classes_per_client` shards whose sizes differ by at most one
        3) assign each of the `n_clients` with `n_classes_per_client` shards

    Inspired by the split in
     "Communication-Efficient Learning of Deep Networks from Decentralized Data"__(https://arxiv.org/abs/1602.05629)

    :param dataset: indexable dataset; `dataset[idx]` must return a `(sample, label)` pair whose
        label converts with `int()` to a value in `range(n_classes)`
    :type dataset: torch.utils.Dataset
    :param n_classes: number of classes present in `dataset`
    :param n_clients: number of clients
    :param n_classes_per_client: number of label-sorted shards given to each client
    :param frac: fraction of dataset to use
    :param seed: seed for the sampling / shuffling RNG (also passed to `np.random.seed`);
        `None` or a negative value seeds from the current time
    :return: list (size `n_clients`) of subgroups, each subgroup is a list of indices.
    """
    rng_seed = (seed if (seed is not None and seed >= 0) else int(time.time()))
    rng = random.Random(rng_seed)
    np.random.seed(rng_seed)

    # get subset
    n_samples = int(len(dataset) * frac)
    selected_indices = rng.sample(list(range(len(dataset))), n_samples)

    label2index = {k: [] for k in range(n_classes)}
    for idx in selected_indices:
        _, label = dataset[idx]
        # int() normalises numpy-scalar / 0-d tensor labels to the python-int keys of label2index
        label2index[int(label)].append(idx)

    # concatenate the per-label buckets in label order 0..n_classes-1
    sorted_indices = [idx for label in range(n_classes) for idx in label2index[label]]

    n_shards = n_clients * n_classes_per_client
    shards = iid_divide(sorted_indices, n_shards)
    # shuffle with the seeded `rng` (not the global `random` module) so that the
    # shard-to-client assignment is reproducible for a given seed
    rng.shuffle(shards)
    tasks_shards = iid_divide(shards, n_clients)

    return [[idx for shard in client_shards for idx in shard] for client_shards in tasks_shards]

In [101]:
from pathlib import Path
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import json

In [102]:
# MNIST train/test splits read from the local benchmark folder (download=False, so the
# files must already be there); images become tensors normalized with mean 0.1307 / std 0.3081.
_mnist_common = dict(
    root="./benchmark/mnist/data/",
    download=False,
    transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]),
)

training_data = datasets.MNIST(train=True, **_mnist_common)
testing_data = datasets.MNIST(train=False, **_mnist_common)

In [103]:
num_clients = 300

labels = training_data.targets
total_label = len(np.unique(labels))  # number of distinct labels in the training targets
dirichlet = 100  # passed as the Dirichlet concentration `alpha` of the split

res = by_labels_non_iid_split(
    training_data,
    n_classes=total_label,
    n_clients=num_clients,
    n_clusters=-1,
    alpha=dirichlet,
    frac=0.2,
    seed=1,
)

# dis_mtx[c, k] = number of samples with label k assigned to client c
target_array = np.asarray(labels)
dis_mtx = np.zeros([num_clients, total_label])
for client_id, sample_ids in enumerate(res):
    np.add.at(dis_mtx[client_id], target_array[np.asarray(sample_ids, dtype=np.int64)], 1)

In [104]:
# Show the per-client label-count matrix (rows: clients, columns: labels).
dis_mtx

array([[4., 6., 3., ..., 6., 2., 4.],
       [4., 2., 1., ..., 1., 7., 4.],
       [4., 2., 4., ..., 3., 8., 2.],
       ...,
       [2., 8., 6., ..., 2., 2., 6.],
       [3., 8., 1., ..., 4., 2., 4.],
       [3., 5., 0., ..., 4., 3., 6.]])

In [105]:
# Total number of samples held by each client (row sums).
dis_mtx.sum(axis=1)

array([38., 34., 43., 41., 40., 31., 54., 36., 33., 35., 46., 42., 31.,
       40., 57., 28., 52., 26., 47., 30., 42., 39., 37., 44., 43., 46.,
       35., 35., 35., 39., 43., 42., 32., 36., 42., 41., 39., 47., 40.,
       33., 46., 36., 31., 46., 40., 46., 51., 46., 31., 52., 51., 36.,
       42., 45., 51., 52., 46., 46., 34., 45., 41., 33., 42., 40., 42.,
       30., 42., 39., 51., 34., 37., 30., 43., 45., 32., 36., 40., 32.,
       32., 32., 42., 43., 41., 34., 45., 34., 51., 41., 47., 41., 40.,
       37., 40., 44., 37., 41., 49., 44., 39., 39., 41., 32., 40., 40.,
       47., 51., 38., 48., 28., 35., 36., 34., 44., 37., 36., 48., 38.,
       42., 45., 26., 46., 35., 36., 39., 42., 40., 37., 31., 44., 42.,
       36., 43., 41., 33., 47., 46., 34., 36., 43., 47., 41., 44., 36.,
       37., 49., 42., 43., 37., 30., 31., 45., 37., 40., 29., 38., 55.,
       48., 41., 45., 35., 40., 37., 36., 33., 49., 35., 48., 50., 28.,
       42., 36., 38., 34., 39., 44., 46., 42., 30., 40., 52., 36

In [106]:
# Total number of selected samples per label across all clients (column sums).
dis_mtx.sum(axis=0)

array([1200., 1327., 1149., 1222., 1191., 1072., 1203., 1271., 1154.,
       1211.])

In [107]:
# Key each client's list of dataset indices by its client id, for the JSON export.
res_dict = dict(enumerate(res))

res_dict[0]

[55867,
 14805,
 15321,
 22745,
 35804,
 56618,
 42013,
 10906,
 31674,
 56052,
 53852,
 4571,
 16882,
 52937,
 51950,
 2037,
 17869,
 15014,
 30511,
 37847,
 26554,
 6299,
 49450,
 4756,
 17725,
 2084,
 36975,
 48596,
 35946,
 31701,
 19167,
 13725,
 3540,
 28981,
 33176,
 25551,
 277,
 45387]

In [108]:
class NpEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy scalars and arrays to native Python values."""

    def default(self, obj):
        """Return a JSON-serializable stand-in for `obj`, or defer to the base class.

        NumPy integers -> int, floats -> float, booleans -> bool, arrays -> nested lists.
        Anything else is handed to `json.JSONEncoder.default`, which raises TypeError.
        """
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        # np.bool_ is not a subclass of np.integer nor of Python bool, so json rejects it otherwise
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)

In [109]:
import os
# Produce json file
dataset = "mnist"

dir_path = f"./dataset_idx/{dataset}/sparse_dir{dirichlet}/{num_clients}client/"
# exist_ok avoids the check-then-create race of `os.path.exists` + `os.makedirs`
os.makedirs(dir_path, exist_ok=True)
# context manager guarantees the JSON file is flushed and closed before the stats step
with open(os.path.join(dir_path, f"{dataset}_sparse.json"), "w", encoding="utf-8") as json_file:
    json.dump(res_dict, json_file, indent=4, cls=NpEncoder)
print("Output generated successfully")

# Produce stat file: one row per client, one column per label, integer counts
np.savetxt(os.path.join(dir_path, f"{dataset}_sparse_stat.csv"), dis_mtx, delimiter=",", fmt="%d")
print("Stats generated successfully")

Output generated successfully
Stats generated successfully
