# Fix Random Seed

In [1]:
import copy
import json
import os
import random
from collections import Counter

import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, Subset
from torchvision import datasets

In [2]:
seed = 42

# Seed every RNG source this notebook imports so a fresh "Restart & Run All"
# reproduces the same partition.
# NOTE: PYTHONHASHSEED is read when the interpreter starts, so setting it here
# cannot change hashing in the already-running kernel; it is only inherited
# by processes spawned afterwards.
os.environ['PYTHONHASHSEED'] = str(seed)
random.seed(seed)
torch.manual_seed(seed)
# Kept last on purpose: it returns None, so the cell shows no stray output.
np.random.seed(seed)

# Setup

In [3]:
# Partitioning scheme used further down to pick the split function.
modetype = 'diri' # 'iid', 'diri', 'clsimb'

# Number of clients taken from each class group (only read in 'clsimb' mode).
clsimb = [2, 7] # [2, 7], [2, 3, 4], [3, 3, 3] for 'clsimb'

# Total number of clients.
splits = 9 * 2 # for 'iid' and 'diri'
if modetype == 'clsimb':
    splits = sum(clsimb)
# Dirichlet concentration parameter handed to dirichlet_split().
alpha = 0.01 # 0.01, 1.5 for 'diri'

# Short tag describing the configuration; it becomes part of the output file name.
if modetype == 'clsimb':
    mode = '{}{}c{}'.format(modetype, sum(clsimb), ''.join(str(x) for x in clsimb))
elif modetype == 'diri':
    # Non-alphanumeric characters (the decimal point) are stripped from alpha: 0.01 -> '001'.
    mode = '{}{}a{}'.format(modetype, splits, ''.join(e for e in str(alpha) if e.isalnum()))
elif modetype == 'iid':
    mode = f'{modetype}{splits}'
else:
    # Fail fast: otherwise `mode` would be undefined, or silently reuse the
    # value left in the kernel by a previous run of this cell.
    raise ValueError(f"Unknown modetype {modetype!r}; expected 'iid', 'diri' or 'clsimb'")

dataset = 'cifar10'

filename = f'{dataset}_{mode}_{seed}'

In [4]:
# Datasets (and the generated JSON) live in a `data` folder located in the
# parent of the current working directory; create it on first use.
_parent_dir = os.path.dirname(os.getcwd())
path_data = os.path.join(_parent_dir, 'data')
os.makedirs(path_data, exist_ok=True)

# Get Data

In [5]:
# Candidate class ids handed to clsplit(); narrowed below for some datasets.
classes = np.arange(10)

# `train_size` is the per-client sample count passed to the 'iid' and
# 'clsimb' splitters ('diri' mode passes its own fraction instead).
if dataset == 'cifar10':
    train_dataset = datasets.CIFAR10(path_data, train=True, download=True)
    train_size = 714 * 2
elif dataset == 'svhn':
    train_dataset = datasets.SVHN(path_data, split='train', download=True)
    train_size = 1065 * 2
    # NOTE(review): hand-picked class subset for the [2, 7] layout; the
    # rationale is not recorded in this notebook.
    if clsimb == [2, 7]:
        classes = [1, 2, 3, 4]
elif dataset == 'fmnist':
    train_dataset = datasets.FashionMNIST(path_data, train=True, download=True)
    train_size = 857 * 2
elif dataset == 'mnist':
    train_dataset = datasets.MNIST(path_data, train=True, download=True)
    train_size = 851 * 2
    # NOTE(review): hand-picked class subset for the [2, 7] layout; the
    # rationale is not recorded in this notebook.
    if clsimb == [2, 7]:
        classes = [1, 2, 3, 7]
elif dataset == 'stl10':
    train_dataset = datasets.STL10(path_data, split='train', download=True)
    train_size = 71 * 2
else:
    # Fail fast: otherwise `train_dataset` / `train_size` would be undefined,
    # or silently reuse the values left in the kernel by an earlier run.
    raise ValueError(
        f"Unsupported dataset {dataset!r}; expected one of "
        "'cifar10', 'svhn', 'fmnist', 'mnist', 'stl10'"
    )

Files already downloaded and verified


In [6]:
# The label container is exposed as `.targets` on some dataset classes and as
# `.labels` on others; take whichever one exists.
y = train_dataset.targets if hasattr(train_dataset, 'targets') else train_dataset.labels
if type(y) is torch.Tensor:
    # Tensor labels are converted to a plain Python list.
    y = y.tolist()

# Exploratory per-class count check, kept for reference (disabled):
# c = np.asarray(list(Counter(y).items()))
# c = c[c[:,1].argsort()[::-1][:len(c)]]
# j = c[:,1]
# t = j // 7

# for x in t:
#     print(x, len(j[j >= x * 4]), j[j >= x * 4], len(j[j >= x * 7]), j[j >= x * 7])

labels = np.asarray(y)
# Number of distinct label values present in the training set.
num_classes = np.unique(labels).size

In [7]:
def clsplit(clsimb, classes, desig=2):
    """Draw one disjoint group of ``desig`` classes per entry of ``clsimb``.

    Args:
        clsimb: Iterable whose length fixes how many groups are drawn (its
            values are not read here).
        classes: Pool of candidate class ids.
        desig: Number of classes per group.

    Returns:
        list[list]: One ascending list of class ids per group. A class that
        has been picked is removed from the pool before the next draw.

    Uses NumPy's global RNG (``np.random.choice``), so the result depends on
    the current global seed/state.
    """
    remaining = classes
    groups = []

    for _ in clsimb:
        picked = sorted(np.random.choice(remaining, desig, replace=False).tolist())
        groups.append(picked)
        # Shrink the pool so later groups cannot reuse these classes.
        remaining = [cls for cls in remaining if cls not in picked]

    return groups

In [8]:
# Draw the class groups (2 classes per group, one group per entry of clsimb).
# NOTE: this rebinds `classes` from a flat pool to a list of groups, so the
# cell is not idempotent: re-running it feeds the nested list back into
# clsplit(). It also runs for every modetype and consumes NumPy's global RNG,
# which the Dirichlet split draws from later.
classes = clsplit(clsimb, classes, 2)
print(classes)

[[1, 8], [2, 5]]


# Split Data

In [9]:
def dirichlet_split(dataset, splits, alpha=0.05, min_size=0, min_required_size=10, K=10, train_size=1/3, seed=42):
    """Partition (a stratified subsample of) ``dataset`` across clients with Dirichlet label skew.

    For every class ``k`` in ``range(K)`` a proportion vector is drawn from a
    symmetric Dirichlet with concentration ``alpha`` and the samples of that
    class are cut among the ``splits`` clients accordingly. The whole draw is
    repeated until the smallest client holds at least ``min_required_size``
    samples.

    Args:
        dataset: Object exposing its labels as ``targets`` or ``labels`` and
            supporting ``len()``.
        splits: Number of clients.
        alpha: Dirichlet concentration parameter.
        min_size: Starting value of the "smallest client" size. If it is
            already ``>= min_required_size`` the sampling loop is skipped and
            every client comes back empty.
        min_required_size: Minimum number of samples every client must hold.
        K: Number of classes; classes are matched as ``labels == k`` for
            ``k`` in ``range(K)``.
        train_size: When ``> 0`` it is forwarded to ``train_test_split`` to
            draw a stratified subsample before partitioning; otherwise all
            samples are partitioned.
        seed: ``random_state`` of the subsampling step only. The Dirichlet
            draws and shuffles use NumPy's global RNG.

    Returns:
        dict: ``{client_id: {'index': [...], 'label': [...]}}`` where
        ``index`` holds positions in ``dataset`` and ``label`` the matching
        labels.
    """
    try:
        labels = np.asarray(dataset.targets)
    except AttributeError:
        labels = np.asarray(dataset.labels)

    # Dataset position of every entry of `labels`. It is subsampled in
    # lock-step with `labels` so that positions found inside the subsample can
    # be translated back to real dataset indices (the subsample is shuffled,
    # so its positions do not coincide with dataset positions).
    origin = np.arange(len(labels))

    if train_size > 0:
        origin, _, labels, _ = train_test_split(
            origin, labels, train_size=train_size, random_state=seed, shuffle=True, stratify=labels
        )

    idx_batch = [[] for _ in range(splits)]

    while min_size < min_required_size:
        # Start every attempt from empty clients.
        idx_batch = [[] for _ in range(splits)]

        for k in range(K):
            idx_k = np.where(labels == k)[0]
            np.random.shuffle(idx_k)
            proportions = np.random.dirichlet(np.repeat(alpha, splits))
            # Clients that already reached the size cap receive nothing more.
            # NOTE(review): the cap is based on len(dataset), not on the size
            # of the subsample actually being partitioned; confirm intended.
            proportions = np.array([p * (len(idx_j) < len(dataset) / splits) for p, idx_j in zip(proportions, idx_batch)])
            proportions = proportions / proportions.sum()
            # Turn cumulative proportions into cut points inside idx_k.
            proportions = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]
            idx_batch = [idx_j + idx.tolist() for idx_j, idx in zip(idx_batch, np.split(idx_k, proportions))]
            min_size = min([len(idx_j) for idx_j in idx_batch])

    clients = {}
    for j in range(splits):
        np.random.shuffle(idx_batch[j])
        picked = np.asarray(idx_batch[j], dtype=int)
        clients[j] = {'index': origin[picked].tolist(), 'label': labels[picked].tolist()}
    return clients

In [10]:
def stratified_split(dataset, splits, train_size, indices, seed=42):
    """Carve ``splits`` disjoint, label-stratified client shards out of ``indices``.

    Every round calls ``train_test_split`` on what is left of the pool: the
    "train" part becomes the next client and the "test" part is carried over
    as the pool for the following round.

    Args:
        dataset: Object exposing its labels as ``targets`` or ``labels``.
        splits: Number of clients to create.
        train_size: Forwarded unchanged to ``train_test_split`` each round.
        indices: Dataset positions forming the initial pool.
        seed: ``random_state`` used for every round.

    Returns:
        dict: ``{client_id: {'index': [...], 'label': [...]}}``.
    """
    try:
        pool_lbl = np.asarray(dataset.targets)[indices]
    except AttributeError:
        pool_lbl = np.asarray(dataset.labels)[indices]
    pool_idx = indices

    clients = {}
    for client_id in range(splits):
        shard_idx, pool_idx, shard_lbl, pool_lbl = train_test_split(
            pool_idx,
            pool_lbl,
            train_size=train_size,
            random_state=seed,
            shuffle=True,
            stratify=pool_lbl,
        )
        clients[client_id] = {'index': shard_idx.tolist(), 'label': shard_lbl.tolist()}

    return clients

In [11]:
def makesubset(dataset, target_cls):
    """Return the dataset positions whose label belongs to ``target_cls``.

    Args:
        dataset: Object exposing its labels as ``targets`` or ``labels`` and
            supporting ``len()``.
        target_cls: Container of class ids to keep (tested with ``in``).

    Returns:
        numpy.ndarray: Ascending positions of the matching samples.
    """
    truths = dataset.targets if hasattr(dataset, 'targets') else dataset.labels
    # One boolean per sample: does its label belong to the requested classes?
    keep = np.fromiter((y in target_cls for y in truths), dtype=bool)
    return np.arange(len(dataset))[keep]

In [12]:
def class_imbalanced_split(dataset, splits, classes, train_size, seed=42):
    """Build clients group by group, each group restricted to its own class set.

    The i-th entry of ``classes`` selects the samples of group i and the i-th
    entry of ``splits`` says how many clients are carved out of that group.
    Clients are numbered consecutively across groups, starting at 0.

    Args:
        dataset: Dataset forwarded to ``makesubset`` / ``stratified_split``.
        splits: Number of clients per group, e.g. ``[2, 7]``.
        classes: One collection of class ids per group.
        train_size: Forwarded to ``stratified_split``.
        seed: Forwarded to ``stratified_split``.

    Returns:
        dict: ``{client_id: {'index': [...], 'label': [...]}}``.
    """
    clients = {}
    next_id = 0
    # Pair every group with its own client count. The count is looked up per
    # group, so layouts whose largest group is not the last one (e.g. [7, 2])
    # get the requested number of clients for every group.
    for n_clients, target_cls in zip(splits, classes):
        group_idx = makesubset(dataset, target_cls)
        # stratified_split builds its shards sequentially, so asking for only
        # n_clients yields the same leading shards as asking for more.
        shards = stratified_split(dataset, n_clients, train_size, group_idx, seed=seed)
        for shard in shards.values():
            clients[next_id] = shard
            next_id += 1
    return clients

In [13]:
def display_dist(clients):
    """Print one summary line per client, then the overall totals.

    Each line shows the client key, its sample count, a 10-slot class
    presence map (class id if present, ``_`` otherwise), the number of
    distinct classes and the sorted per-class counts. The footer reports the
    total and how many samples of the module-level ``train_dataset`` are not
    assigned to any client.

    Args:
        clients: Mapping ``key -> {'index': [...], 'label': [...]}``.
    """
    total = 0
    for k, v in clients.items():
        present = np.unique(v['label'])
        # Presence map over class ids 0..9.
        row = ' '.join(str(c) if c in present else '_' for c in range(10))
        hist = sorted(Counter(v['label']).items())
        print(f"[{k:>2}]  #sample  {len(v['index']):>5d}  >>  Classes  {row}  ({len(present):>2d})  >>  {hist}")
        total += len(v['index'])

    print(f'---\n[Total] {total}')
    print(f'\n#sample_left {len(train_dataset)} - {total} = {len(train_dataset) - total}\n')

In [14]:
# Positions of every training sample; this is the pool handed to
# stratified_split() in 'iid' mode.
indices = np.arange(len(train_dataset))

In [34]:
# Build the raw partition with the scheme selected in the setup cell.
if modetype == 'clsimb':
    clients = class_imbalanced_split(
        dataset=train_dataset, splits=clsimb, classes=classes, train_size=train_size, seed=seed
    )
elif modetype == 'diri':
    clients = dirichlet_split(
        dataset=train_dataset, splits=splits, alpha=alpha, min_size=0,
        min_required_size=10, K=num_classes, train_size=0.3, seed=seed
    )
elif modetype == 'iid':
    clients = stratified_split(
        dataset=train_dataset, splits=splits, train_size=train_size, indices=indices, seed=seed
    )

display_dist(clients)

thres = 20
print(f'>>>>> Drop classes with <= {thres} sample!\n')

# Per client, keep only the samples whose class occurs more than `thres`
# times on that client.
inter = {}
for k, v in clients.items():
    client_labels = np.asarray(v['label'])
    class_counts = np.asarray(list(Counter(v['label']).items()))  # rows of (class, count)
    kept_classes = class_counts[class_counts[:, 1] > thres][:, 0]
    keep_mask = np.isin(client_labels, kept_classes)
    inter[k] = {
        'index': list(np.asarray(v['index'])[keep_mask]),
        'label': list(client_labels[keep_mask]),
    }

display_dist(inter)

[ 0]  #sample   1220  >>  Classes  _ 1 _ _ _ _ 6 _ _ _  ( 2)  >>  [(1, 17), (6, 1203)]
[ 1]  #sample    884  >>  Classes  _ _ _ 3 _ _ _ _ _ _  ( 1)  >>  [(3, 884)]
[ 2]  #sample   2755  >>  Classes  _ _ _ _ _ 5 _ _ 8 9  ( 3)  >>  [(5, 1303), (8, 1327), (9, 125)]
[ 3]  #sample    862  >>  Classes  _ 1 _ _ _ _ 6 7 _ _  ( 3)  >>  [(1, 32), (6, 296), (7, 534)]
[ 4]  #sample    169  >>  Classes  _ _ _ _ 4 _ _ _ _ _  ( 1)  >>  [(4, 169)]
[ 5]  #sample   1596  >>  Classes  0 _ _ _ 4 _ _ _ _ _  ( 2)  >>  [(0, 1498), (4, 98)]
[ 6]  #sample    911  >>  Classes  _ _ _ _ _ _ _ 7 _ _  ( 1)  >>  [(7, 911)]
[ 7]  #sample   1084  >>  Classes  _ _ _ _ _ _ _ _ 8 9  ( 2)  >>  [(8, 172), (9, 912)]
[ 8]  #sample   1447  >>  Classes  _ 1 _ _ _ _ _ _ _ _  ( 1)  >>  [(1, 1447)]
[ 9]  #sample     53  >>  Classes  _ _ _ _ _ _ _ 7 _ _  ( 1)  >>  [(7, 53)]
[10]  #sample    462  >>  Classes  _ _ _ _ _ _ _ _ _ 9  ( 1)  >>  [(9, 462)]
[11]  #sample   1221  >>  Classes  _ 1 _ _ 4 _ _ _ _ _  ( 2)  >>  [(1, 1), (4, 122

# Manual Mixture

In [23]:
def mix(inter, fromclient, target, toclient):
    """Move all samples of class(es) ``target`` from one client to another.

    Args:
        inter: Mapping ``client -> {'index': list, 'label': list}``; it is
            modified in place.
        fromclient: Key of the client giving samples away.
        target: Class id, or collection of class ids, to move.
        toclient: Key of the client receiving the samples (appended at the
            end of its lists).

    Returns:
        The same ``inter`` object, to allow chained calls.
    """
    source, dest = inter[fromclient], inter[toclient]
    src_idx = np.asarray(source['index'])
    src_lbl = np.asarray(source['label'])
    moving = np.isin(src_lbl, target)

    # Append the selected samples to the receiver ...
    dest['index'].extend(list(src_idx[moving]))
    dest['label'].extend(list(src_lbl[moving]))
    # ... and leave only the remaining ones with the donor.
    source['index'] = list(src_idx[~moving])
    source['label'] = list(src_lbl[~moving])
    return inter

In [39]:
# Hand-written transfers, applied in order: (source client, class, destination client).
# NOTE(review): presumably tuned by eye to the partition printed above; they
# likely need revisiting whenever the seed or split settings change.
transfers = [
    (17, 3, 0),
    (16, 5, 1),
    (15, 2, 4),
    (14, 5, 1),

    (12, 5, 1),
    (11, 4, 6),
    (10, 9, 2),
    (9, 9, 6),

    (2, 9, 8),
    (9, 7, 3),
]

# mix() edits its argument in place and returns it, so `tmp` and `inter`
# refer to the same dictionary throughout.
tmp = inter
for src_client, moved_cls, dst_client in transfers:
    tmp = mix(tmp, src_client, moved_cls, dst_client)

display_dist(tmp)

[ 0]  #sample   1819  >>  Classes  _ _ _ 3 _ _ 6 _ _ _  ( 2)  >>  [(3, 616), (6, 1203)]
[ 1]  #sample   1081  >>  Classes  _ _ _ 3 _ 5 _ _ _ _  ( 2)  >>  [(3, 884), (5, 197)]
[ 2]  #sample   2630  >>  Classes  _ _ _ _ _ 5 _ _ 8 _  ( 2)  >>  [(5, 1303), (8, 1327)]
[ 3]  #sample    915  >>  Classes  _ 1 _ _ _ _ 6 7 _ _  ( 3)  >>  [(1, 32), (6, 296), (7, 587)]
[ 4]  #sample   1668  >>  Classes  _ _ 2 _ 4 _ _ _ _ _  ( 2)  >>  [(2, 1499), (4, 169)]
[ 5]  #sample   1596  >>  Classes  0 _ _ _ 4 _ _ _ _ _  ( 2)  >>  [(0, 1498), (4, 98)]
[ 6]  #sample   2131  >>  Classes  _ _ _ _ 4 _ _ 7 _ _  ( 2)  >>  [(4, 1220), (7, 911)]
[ 7]  #sample   1084  >>  Classes  _ _ _ _ _ _ _ _ 8 9  ( 2)  >>  [(8, 172), (9, 912)]
[ 8]  #sample   2034  >>  Classes  _ 1 _ _ _ _ _ _ _ 9  ( 2)  >>  [(1, 1447), (9, 587)]
[ 9]  #sample      0  >>  Classes  _ _ _ _ _ _ _ _ _ _  ( 0)  >>  []
[10]  #sample      0  >>  Classes  _ _ _ _ _ _ _ _ _ _  ( 0)  >>  []
[11]  #sample      0  >>  Classes  _ _ _ _ _ _ _ _ _ _  ( 0)  >>

# Clean Up

In [40]:
def cleanup(inter):
    """Drop empty clients, cast entries to plain ``int`` and renumber clients.

    Args:
        inter: Mapping ``client -> {'index': list, 'label': list}``. It is
            modified in place: clients without labels are removed and the
            remaining lists are replaced by lists of built-in ints (so they
            can be serialised with ``json``).

    Returns:
        dict: The surviving clients, in their original order, under
        consecutive string keys ``'0'``, ``'1'``, ...
    """
    for k in list(inter.keys()):
        if len(inter[k]['label']) == 0:
            inter.pop(k)
        else:
            inter[k]['index'] = [int(x) for x in inter[k]['index']]
            inter[k]['label'] = [int(x) for x in inter[k]['label']]
    # Build the result from the new consecutive keys only, so that removing a
    # client from the middle cannot leave a stale entry under its old key.
    return {str(i): v for i, v in enumerate(inter.values())}

In [41]:
# First pass: drop the emptied clients, renumber the rest and show the result.
tmp = cleanup(tmp)
display_dist(tmp)

# Second pass over the already-cleaned mapping; its result is bound to
# `clients`, the name that gets written to disk in the save cell.
clients = cleanup(tmp)
display_dist(clients)

[ 0]  #sample   1819  >>  Classes  _ _ _ 3 _ _ 6 _ _ _  ( 2)  >>  [(3, 616), (6, 1203)]
[ 1]  #sample   1081  >>  Classes  _ _ _ 3 _ 5 _ _ _ _  ( 2)  >>  [(3, 884), (5, 197)]
[ 2]  #sample   2630  >>  Classes  _ _ _ _ _ 5 _ _ 8 _  ( 2)  >>  [(5, 1303), (8, 1327)]
[ 3]  #sample    915  >>  Classes  _ 1 _ _ _ _ 6 7 _ _  ( 3)  >>  [(1, 32), (6, 296), (7, 587)]
[ 4]  #sample   1668  >>  Classes  _ _ 2 _ 4 _ _ _ _ _  ( 2)  >>  [(2, 1499), (4, 169)]
[ 5]  #sample   1596  >>  Classes  0 _ _ _ 4 _ _ _ _ _  ( 2)  >>  [(0, 1498), (4, 98)]
[ 6]  #sample   2131  >>  Classes  _ _ _ _ 4 _ _ 7 _ _  ( 2)  >>  [(4, 1220), (7, 911)]
[ 7]  #sample   1084  >>  Classes  _ _ _ _ _ _ _ _ 8 9  ( 2)  >>  [(8, 172), (9, 912)]
[ 8]  #sample   2034  >>  Classes  _ 1 _ _ _ _ _ _ _ 9  ( 2)  >>  [(1, 1447), (9, 587)]
---
[Total] 14958

#sample_left 50000 - 14958 = 35042

[ 0]  #sample   1819  >>  Classes  _ _ _ 3 _ _ 6 _ _ _  ( 2)  >>  [(3, 616), (6, 1203)]
[ 1]  #sample   1081  >>  Classes  _ _ _ 3 _ 5 _ _ _ _  ( 2

# Save Data

In [42]:
# Write the final client -> {'index', 'label'} mapping as JSON into the data folder.
out_path = os.path.join(path_data, f'{filename}.json')
with open(out_path, 'w') as f:
    f.write(json.dumps(clients))