In [187]:
import torch
from torchvision.datasets import MNIST,CIFAR100,CIFAR10
from torch.utils.data import random_split, DataLoader, IterableDataset
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torch.distributions.dirichlet import Dirichlet
from itertools import groupby
import random

In [186]:
def partition_by_class(dataset: IterableDataset):
    """Group the samples of ``dataset`` by their label.

    Each sample is expected to be indexable, with its label at position 1
    (e.g. an ``(input, label)`` pair).

    Returns a dict mapping every label to the list of samples carrying it.
    Labels appear in ascending order; within a label the samples keep the
    order in which ``dataset`` yielded them (the sort is stable).
    """
    def label_of(sample):
        return sample[1]

    # groupby only merges *adjacent* equal keys, hence the sort beforehand.
    ordered = sorted(dataset, key=label_of)

    by_label = {}
    for label, members in groupby(ordered, label_of):
        by_label[label] = list(members)
    return by_label

In [244]:
def split(partition, nb_nodes: int, alpha: float = 1.):
    """Spread the samples of every class over ``nb_nodes`` nodes.

    For each class, one vector of per-node shares is drawn from a symmetric
    Dirichlet distribution with concentration ``alpha`` (smaller values give
    more uneven shares). The class's samples are shuffled and cut into
    ``nb_nodes`` consecutive slices whose sizes follow those shares; slice
    ``i`` is appended to node ``i``.

    Args:
        partition: mapping from class key to the list of samples of that
            class. It is not modified.
        nb_nodes: number of nodes to distribute the samples over.
        alpha: concentration parameter of the symmetric Dirichlet.

    Returns:
        A list of ``nb_nodes`` lists. Every sample of ``partition`` ends up
        in exactly one of them.
    """
    splitter = Dirichlet(torch.ones(nb_nodes) * alpha)
    nodes = [[] for _ in range(nb_nodes)]

    # Iterate over the classes and give a random number of samples to each node.
    for samples in partition.values():
        # Shuffle a copy so the caller's lists keep their order.
        shuffled = list(samples)
        random.shuffle(shuffled)
        total = len(shuffled)

        # Cumulative shares -> slice boundaries. Computed in float64 to limit
        # rounding drift on large classes.
        shares = splitter.sample().double()
        cuts = (shares.cumsum(0) * total).round().long().clamp(0, total).tolist()
        # Pin the final boundary so rounding can never drop trailing samples.
        cuts[-1] = total

        for node, start, stop in zip(nodes, [0] + cuts[:-1], cuts):
            node.extend(shuffled[start:stop])
    return nodes

In [273]:
# MNIST training split, stored under ./data (downloaded on first use);
# images are converted to tensors when accessed.
train_dataset = MNIST(
    root='data',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

In [274]:
# Group the training samples by label: {label: [samples with that label]}.
# This iterates over (and sorts) the whole dataset, so the cell takes a while.
partition = partition_by_class(train_dataset)

In [275]:
# Dataset size first, then the number of samples in each class.
print(f"total # of samples : {len(train_dataset)}")
for k in partition:
    v = partition[k]
    print(f" - class {k} : {len(v)}")

total # of samples : 60000
 - class 0 : 5923
 - class 1 : 6742
 - class 2 : 5958
 - class 3 : 6131
 - class 4 : 5842
 - class 5 : 5421
 - class 6 : 5918
 - class 7 : 6265
 - class 8 : 5851
 - class 9 : 5949


In [276]:
# split() draws from both Python's `random` (shuffling) and torch's global RNG
# (Dirichlet sampling). Seed both so a fresh Restart & Run All reproduces the
# same per-node distribution.
SPLIT_SEED = 0
random.seed(SPLIT_SEED)
torch.manual_seed(SPLIT_SEED)

# Spread the classes over 10 nodes; a small alpha yields a strongly non-IID split.
nodes = split(partition, 10, alpha=0.1)

# Per-node report: class histogram followed by the node's total sample count.
for i, n in enumerate(nodes):
    print(f"node {i} :")
    for k, vs in partition_by_class(n).items():
        print(f" - class {k} : {len(vs)}")
    print(f"TOTAL      : {len(n)}\n")

node 0 :
 - class 0 : 1772
 - class 1 : 32
 - class 2 : 73
 - class 3 : 19
 - class 4 : 327
 - class 6 : 4592
 - class 9 : 257
TOTAL      : 7072

node 1 :
 - class 0 : 4
 - class 1 : 3
 - class 2 : 71
 - class 6 : 7
 - class 7 : 11
 - class 8 : 90
 - class 9 : 2562
TOTAL      : 2748

node 2 :
 - class 0 : 1777
 - class 1 : 4
 - class 3 : 75
 - class 5 : 83
 - class 6 : 186
 - class 7 : 3
 - class 8 : 79
TOTAL      : 2207

node 3 :
 - class 0 : 111
 - class 1 : 2593
 - class 2 : 192
 - class 3 : 29
 - class 4 : 563
 - class 5 : 1411
 - class 7 : 10
 - class 8 : 5651
TOTAL      : 10560

node 4 :
 - class 1 : 1
 - class 2 : 5149
 - class 3 : 1005
 - class 4 : 33
 - class 5 : 2483
 - class 6 : 1133
 - class 7 : 970
TOTAL      : 10774

node 5 :
 - class 5 : 26
 - class 7 : 1993
 - class 8 : 2
 - class 9 : 1
TOTAL      : 2022

node 6 :
 - class 0 : 26
 - class 1 : 1784
 - class 3 : 3731
 - class 4 : 193
 - class 5 : 1405
 - class 7 : 20
 - class 9 : 404
TOTAL      : 7563

node 7 :
 - class 1