In [187]:
import torch
from torchvision.datasets import MNIST,CIFAR100,CIFAR10
from torch.utils.data import random_split, DataLoader, IterableDataset
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torch.distributions.dirichlet import Dirichlet
from itertools import groupby
import random

In [186]:
def partition_by_class(dataset: IterableDataset):
    """Group the samples of ``dataset`` by their label (the element at index 1).

    Every sample is pulled into a list and stably sorted by label. The result maps
    label -> list of the full samples carrying that label. Keys are in ascending label
    order, and samples keep their original relative order within each class.

    Despite the annotation, any iterable of indexable ``(x, label, ...)`` samples works.
    In this notebook it is called both on the CIFAR10 dataset and on plain lists.
    Labels must be mutually orderable.

    Note: because the whole dataset is sorted into a list, every (transformed) sample
    is held in memory at once.
    """
    def label_of(sample):
        return sample[1]

    ordered = sorted(dataset, key=label_of)
    groups = {}
    for label, members in groupby(ordered, key=label_of):
        groups[label] = list(members)
    return groups

In [244]:
def split(partition, nb_nodes: int, alpha: float = 1.) -> list:
    """Distribute the samples of each class across ``nb_nodes`` nodes via a Dirichlet draw.

    For every class, a proportion vector ``p ~ Dirichlet(alpha * ones(nb_nodes))`` is
    sampled. A shuffled copy of that class's samples is then cut into ``nb_nodes``
    contiguous chunks whose sizes follow ``p``. Chunk ``i`` is appended to node ``i``.

    Args:
        partition: mapping ``class -> sequence of samples`` (e.g. the output of
            ``partition_by_class``). It is not modified.
        nb_nodes: number of nodes to split into; must be >= 1.
        alpha: Dirichlet concentration shared by all nodes; must be > 0. Smaller values
            put most of a class on few nodes (skewed split). Larger values approach an
            even split.

    Returns:
        A list of ``nb_nodes`` lists. Every sample of ``partition`` appears in exactly one.

    Raises:
        ValueError: if ``nb_nodes < 1`` or ``alpha <= 0``.

    Randomness comes from Python's ``random`` (shuffling) and torch's global RNG
    (Dirichlet sampling); seed both for reproducible splits.
    """
    if nb_nodes < 1:
        raise ValueError(f"nb_nodes must be >= 1, got {nb_nodes}")
    if alpha <= 0:
        raise ValueError(f"alpha must be > 0, got {alpha}")

    splitter = Dirichlet(torch.ones(nb_nodes) * alpha)
    nodes = [[] for _ in range(nb_nodes)]

    for vs in partition.values():
        # Shuffle a private copy: shuffling in place would silently reorder the caller's
        # partition and make repeated calls depend on previous ones.
        vs = list(vs)
        random.shuffle(vs)

        counts = splitter.sample() * len(vs)
        # Round the *cumulative* sums (not each count) so chunk sizes add up to len(vs).
        # Pin the final bound explicitly so float32 error can never drop trailing samples.
        bounds = [0] + counts.cumsum(0).round().long().tolist()
        bounds[-1] = len(vs)

        for node, start, stop in zip(nodes, bounds[:-1], bounds[1:]):
            node.extend(vs[start:stop])
    return nodes

In [54]:
# CIFAR-10 training split, downloaded into ./data if missing; ToTensor converts each image to a tensor.
train_dataset = CIFAR10(
    root='data',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting data/cifar-10-python.tar.gz to data


In [78]:
# Bucket every training sample by class label (this holds all transformed samples in memory).
partition = partition_by_class(dataset=train_dataset)

In [256]:
# Sanity check: overall dataset size, then the number of samples in each class bucket.
print(f"total # of samples : {len(train_dataset)}")
for label, samples in partition.items():
    print(f" - class {label} : {len(samples)}")

total # of samples : 50000
 - class 0 : 5000
 - class 1 : 5000
 - class 2 : 5000
 - class 3 : 5000
 - class 4 : 5000
 - class 5 : 5000
 - class 6 : 5000
 - class 7 : 5000
 - class 8 : 5000
 - class 9 : 5000


In [268]:
NB_NODES = 10  # number of simulated nodes / clients
ALPHA = 0.1    # Dirichlet concentration: small values give strongly skewed per-node class mixes
SEED = 0

# `split` draws from both Python's `random` (shuffling) and torch's global RNG (Dirichlet),
# so seed both to make the node assignment reproducible under Restart & Run All.
random.seed(SEED)
torch.manual_seed(SEED)
nodes = split(partition, NB_NODES, alpha=ALPHA)

# Every sample must land on exactly one node.
assert sum(len(node) for node in nodes) == sum(len(vs) for vs in partition.values())

for i, node in enumerate(nodes):
    print(f"node {i} :")
    for k, vs in partition_by_class(node).items():
        print(f" - class {k} : {len(vs)}")
    print(f"TOTAL      : {len(node)}\n")

node 0 :
 - class 0 : 608
 - class 2 : 541
 - class 4 : 3553
 - class 5 : 2641
 - class 7 : 2
 - class 8 : 980
 - class 9 : 249
TOTAL      : 8574

node 1 :
 - class 0 : 66
 - class 1 : 4150
 - class 3 : 2558
 - class 5 : 1541
 - class 6 : 1469
 - class 7 : 4064
 - class 8 : 295
 - class 9 : 132
TOTAL      : 14275

node 2 :
 - class 2 : 528
 - class 4 : 11
 - class 5 : 83
 - class 7 : 38
 - class 9 : 2274
TOTAL      : 2934

node 3 :
 - class 1 : 24
 - class 3 : 765
 - class 4 : 873
 - class 6 : 293
 - class 7 : 823
 - class 8 : 3419
 - class 9 : 85
TOTAL      : 6282

node 4 :
 - class 0 : 817
 - class 3 : 235
 - class 4 : 2
 - class 5 : 333
 - class 6 : 2344
 - class 7 : 68
 - class 8 : 262
 - class 9 : 130
TOTAL      : 4191

node 5 :
 - class 1 : 688
 - class 3 : 354
 - class 4 : 121
 - class 5 : 26
 - class 6 : 889
 - class 7 : 1
 - class 8 : 3
 - class 9 : 2077
TOTAL      : 4159

node 6 :
 - class 1 : 22
 - class 2 : 615
 - class 3 : 456
 - class 5 : 367
TOTAL      : 1460

node 7 :
 