In [1]:
import json
import os
import random
from collections import Counter

import numpy as np
from torchvision import datasets
from tqdm.notebook import tqdm

In [2]:
seed = 42

# NOTE: PYTHONHASHSEED is only read when an interpreter starts, so assigning it
# here does not change str/bytes hashing inside this running kernel; it only
# takes effect for child processes launched with a fresh interpreter afterwards.
os.environ.update(PYTHONHASHSEED=str(seed))

# Seed Python's and NumPy's global RNGs (in that order) for reproducible splits.
for seed_fn in (random.seed, np.random.seed):
    seed_fn(seed)

In [3]:
# Downloaded datasets and generated split files go into a `data/` directory
# located in the PARENT of the current working directory (typically the
# notebook's folder when Jupyter is launched from it — confirm for your setup).
path_data = os.path.join(os.path.split(os.getcwd())[0], 'data')
os.makedirs(name=path_data, exist_ok=True)

In [4]:
def display_dist(clients, num_labels, num_samples=None):
    """Print a per-client summary of a label split, then totals.

    For every client one line is printed with: its sample count, which of the
    ``num_labels`` classes it holds (absent classes shown as ``_``), how many
    distinct classes that is, and the sorted per-class counts. Afterwards the
    total number of assigned samples and the number left unassigned are printed.

    Args:
        clients: mapping client-id -> {'index': [...], 'label': [...]} where the
            two lists are parallel (one entry per assigned sample).
        num_labels: number of classes; classes are enumerated as
            ``range(num_labels)``.
        num_samples: size of the dataset the split was drawn from. When omitted
            it falls back to ``len(train_dataset)`` (the notebook-level global)
            for backward compatibility; pass it explicitly to avoid depending on
            hidden kernel state.
    """
    if num_samples is None:
        num_samples = len(train_dataset)

    cnt = 0
    for k, v in clients.items():
        # Classes present for this client, computed once per client rather
        # than once per class.
        present = set(v['label'])
        classes = [c if c in present else '_' for c in range(num_labels)]
        print(f"[{k:>2}]  #sample  {len(v['index']):>5d}  >>  Classes  {' '.join(str(y) for y in classes)}  ({len(present):>2d})  >>  {sorted(Counter(v['label']).items())}")
        cnt += len(v['index'])

    print(f'---\n[Total] {cnt}')
    print(f'\n#sample_left {num_samples} - {cnt} = {num_samples - cnt}\n')

# Get Data

In [5]:
dataset = 'cifar10'  # one of: 'cifar10', 'svhn', 'fmnist', 'mnist', 'stl10'

if dataset == 'cifar10':
    train_dataset = datasets.CIFAR10(path_data, train=True, download=True)
elif dataset == 'svhn':
    train_dataset = datasets.SVHN(path_data, split='train', download=True)
elif dataset == 'fmnist':
    train_dataset = datasets.FashionMNIST(path_data, train=True, download=True)
elif dataset == 'mnist':
    train_dataset = datasets.MNIST(path_data, train=True, download=True)
elif dataset == 'stl10':
    train_dataset = datasets.STL10(path_data, split='train', download=True)
else:
    # Fail loudly: otherwise an unknown name raises a confusing NameError below
    # or, worse, silently reuses a `train_dataset` left over in the kernel.
    raise ValueError(
        f"Unknown dataset {dataset!r}; expected one of "
        "'cifar10', 'svhn', 'fmnist', 'mnist', 'stl10'"
    )

# Some torchvision datasets expose their labels as `.targets`, others
# (e.g. SVHN, STL10) as `.labels`.
try:
    labels = np.asarray(train_dataset.targets)
except AttributeError:
    labels = np.asarray(train_dataset.labels)

num_labels = len(np.unique(labels))
N = len(train_dataset)
indices = np.arange(N)

# The split treats each label value as a class id in range(num_labels), so the
# labels must be exactly 0..num_labels-1.
assert np.array_equal(np.unique(labels), np.arange(num_labels)), \
    "labels must be contiguous integers 0..num_labels-1"

Files already downloaded and verified


# Split Data

In [6]:
alpha = 0.1        # Dirichlet concentration: smaller => more label-skewed clients
num_clients = 100

# Every client receives the same number of samples; the N % num_clients
# remainder stays unassigned (reported as #sample_left by display_dist).
train_size = N // num_clients

clients = {}
taken = []                          # indices already assigned, in assignment order
available = np.ones(N, dtype=bool)  # mask form of `taken`, avoids re-indexing a growing list
for c in tqdm(range(num_clients)):
    # Draw this client's class mixture and give every sample the weight of its
    # class (p[labels] is the vectorized per-class assignment); samples already
    # handed to earlier clients get weight 0 so clients stay disjoint.
    p = np.random.dirichlet(np.repeat(alpha, num_labels))
    w = p[labels]
    w[~available] = 0.0
    w /= w.sum()
    x = np.random.choice(indices, size=train_size, p=w, replace=False).tolist()
    y = labels[x].tolist()
    clients[str(c)] = {'index': x, 'label': y}
    taken.extend(x)
    available[x] = False

# Sanity checks: clients are pairwise disjoint and each got exactly train_size samples.
assert len(set(taken)) == len(taken) == train_size * num_clients
assert all(len(cli['index']) == train_size for cli in clients.values())

display_dist(clients, num_labels)

  0%|          | 0/100 [00:00<?, ?it/s]

[ 0]  #sample    500  >>  Classes  _ 1 _ _ 4 _ 6 _ _ _  ( 3)  >>  [(1, 104), (4, 20), (6, 376)]
[ 1]  #sample    500  >>  Classes  _ _ 2 3 4 _ 6 _ 8 _  ( 5)  >>  [(2, 48), (3, 34), (4, 71), (6, 343), (8, 4)]
[ 2]  #sample    500  >>  Classes  _ 1 _ 3 _ _ _ _ _ 9  ( 3)  >>  [(1, 486), (3, 6), (9, 8)]
[ 3]  #sample    500  >>  Classes  _ _ 2 3 4 5 _ 7 8 9  ( 7)  >>  [(2, 4), (3, 266), (4, 4), (5, 42), (7, 50), (8, 38), (9, 96)]
[ 4]  #sample    500  >>  Classes  _ _ _ 3 4 5 6 _ 8 _  ( 5)  >>  [(3, 1), (4, 57), (5, 264), (6, 1), (8, 177)]
[ 5]  #sample    500  >>  Classes  0 _ 2 _ 4 _ 6 7 8 9  ( 7)  >>  [(0, 143), (2, 2), (4, 7), (6, 39), (7, 8), (8, 31), (9, 270)]
[ 6]  #sample    500  >>  Classes  0 _ 2 _ 4 _ _ _ 8 9  ( 5)  >>  [(0, 31), (2, 220), (4, 175), (8, 8), (9, 66)]
[ 7]  #sample    500  >>  Classes  _ _ 2 _ _ _ _ _ _ 9  ( 2)  >>  [(2, 485), (9, 15)]
[ 8]  #sample    500  >>  Classes  _ 1 _ _ _ _ 6 7 _ 9  ( 4)  >>  [(1, 10), (6, 4), (7, 479), (9, 7)]
[ 9]  #sample    500  >>  Cl

# Save Data

In [7]:
# Encode the split parameters in the file name, e.g. 'cifar10_diri100a01_42'.
# NOTE(review): the alpha tag simply drops non-alphanumeric characters, so
# distinct values can collide (alpha=1.0 and alpha=10 both become '10').
alpha_tag = ''.join(filter(str.isalnum, str(alpha)))
mode = f'diri{num_clients}a{alpha_tag}'
filename = f'{dataset}_{mode}_{seed}'

out_path = os.path.join(path_data, filename + '.json')
with open(out_path, 'w') as f:
    json.dump(clients, f)