In [None]:
import os
import random
# Fix the seed of Python's global RNG; random.shuffle / random.choices in the
# splitting cell draw from it.
random.seed(99)
import shutil

import numpy as np
# NumPy's legacy global RNG is seeded as well. Nothing in the visible cells
# draws from it -- NOTE(review): confirm whether any other cell needs it.
np.random.seed(99)

In [None]:
# When True, an existing destination directory is deleted before the split is written.
IS_RESET = True
# Source dataset root and output root, both read interactively at run time
# ('~' is expanded and the paths are made absolute by the splitting cell).
SRC_DATA = input('SRC_DATA: ')
DST_DATA = input('DST_DATA: ')
# Number of training client directories ('1'..'CLIENTS') created by the 'iid' split.
CLIENTS = 15 # the validation set is stored separately as client directory '0'
# Fraction of each label's files copied to the validation set (rounded down).
VAL_RATE = 0.1
# Split strategy; the splitting cell handles 'iid', 'non-iid' and 'non-iid-2c'.
TYPE = 'non-iid-2c'

In [None]:
def get_files(root_dir, ext=('.jpg', '.png')):
    """Walk ``root_dir`` and yield the matching files of each directory.

    Every directory below (and including) ``root_dir`` that directly contains
    at least one file whose name ends with one of ``ext`` produces one item.
    Entries whose name starts with ``'.'`` (files and directories) are skipped.

    Traversal order and the order of each file list are sorted by path, so the
    output does not depend on the order in which the filesystem happens to
    list directory entries. This keeps a seeded shuffle of the returned lists
    repeatable.

    Args:
        root_dir: Directory to start from.
        ext: A suffix string or an iterable of suffix strings; matching is
            case-sensitive.

    Yields:
        ``(label, files)`` where ``label`` is the name of the directory that
        holds the files and ``files`` is a sorted list of their paths.
    """
    # str.endswith only accepts a str or a tuple, so normalise lists/sets.
    if not isinstance(ext, (str, tuple)):
        ext = tuple(ext)

    pending = [root_dir]
    while pending:
        current = pending.pop()
        files = []
        subdirs = []
        with os.scandir(current) as it:
            for entry in it:
                if entry.name.startswith('.'):
                    continue
                if entry.is_file():
                    if entry.name.endswith(ext):
                        files.append(entry.path)
                elif entry.is_dir():
                    subdirs.append(entry.path)

        # pop() takes from the end, so push in descending order to visit the
        # subdirectories in ascending order.
        pending.extend(sorted(subdirs, reverse=True))

        if files:
            files.sort()
            # The containing directory's name is the label.
            label = os.path.basename(os.path.dirname(files[0]))
            yield label, files

def get_labels(root_dir):
    """Return the names of the non-hidden immediate subdirectories of ``root_dir``.

    Each subdirectory name is treated as a label. Entries whose name starts
    with ``'.'`` and entries that are not directories are ignored. The result
    follows the order in which ``os.scandir`` lists the entries.
    """
    with os.scandir(root_dir) as entries:
        return [
            item.name  # the directory name is the label
            for item in entries
            if not item.name.startswith('.') and item.is_dir()
        ]

In [None]:
# Resolve the user-supplied locations to absolute paths.
data_root = os.path.abspath(os.path.expanduser(SRC_DATA))
dst_base = os.path.abspath(os.path.expanduser(DST_DATA))

# Validate the configuration before anything on disk is touched; an unknown
# TYPE would otherwise wipe the destination and then silently produce nothing.
_SPLIT_TYPES = ('iid', 'non-iid', 'non-iid-2c')
if TYPE not in _SPLIT_TYPES:
    raise ValueError(f'Unknown TYPE {TYPE!r}; expected one of {_SPLIT_TYPES}')


def _is_within(path, parent):
    """Return True when ``path`` equals ``parent`` or lies somewhere below it."""
    try:
        common = os.path.commonpath([path, parent])
    except ValueError:  # e.g. paths on different drives on Windows
        return False
    return os.path.normcase(common) == os.path.normcase(parent)


def _copy_all(paths, dst):
    """Create directory ``dst`` (if missing) and copy every file in ``paths`` into it."""
    os.makedirs(dst, exist_ok=True)
    for path in paths:
        shutil.copy(path, dst)


def _split_off_validation(label, files):
    """Shuffle ``files`` in place, copy the validation share to client '0', return the rest."""
    random.shuffle(files)
    nval = int(len(files) * VAL_RATE)
    _copy_all(files[:nval], os.path.join(dst_base, '0', label))
    return files[nval:]


if IS_RESET:
    # Deleting a destination that is (or contains) the source would destroy
    # the input data before it is read.
    if _is_within(data_root, dst_base):
        raise ValueError(
            f'Refusing to reset {dst_base!r}: it contains the source data {data_root!r}')
    if os.path.exists(dst_base):
        shutil.rmtree(dst_base)

if TYPE == 'iid':
    # Every training client receives an equally sized share of every label.
    for label, files in get_files(data_root):
        train = _split_off_validation(label, files)
        # Floor division: up to CLIENTS-1 leftover files per label are not copied.
        ntrain = len(train) // CLIENTS
        for i in range(1, CLIENTS + 1):
            share = train[(i - 1) * ntrain:i * ntrain]
            _copy_all(share, os.path.join(dst_base, str(i), label))
        print(f'Done {label}')
elif TYPE == 'non-iid':
    # One client per label: the n-th label yielded by get_files goes to
    # client n in full (CLIENTS is not used by this mode).
    for i, (label, files) in enumerate(get_files(data_root), start=1):
        train = _split_off_validation(label, files)
        _copy_all(train, os.path.join(dst_base, str(i), label))
        print(f'Done {label}')
elif TYPE == 'non-iid-2c':
    labels = ['gmail', 'facebook', 'email', 'skype', 'spotify',
              'torrent', 'netflix', 'hangout', 'aim', 'youtube',
              'sftp', 'ftps', 'scp', 'voipbuster', 'vimeo']
    # Sampling weight of each label, aligned with `labels` by position.
    label_weights = [ 5, 8, 3, 5, 10,
                      5, 8, 5, 5, 10,
                      3, 3, 3, 5, 5]

    # client id -> labels held by that client: one label picked by position
    # (wrapping around if there are more clients than labels) plus three
    # weighted random draws, which may repeat.
    clients = dict()
    for i in range(1, CLIENTS + 1):
        assigned = [labels[(i - 1) % len(labels)]]
        assigned.extend(random.choices(labels, weights=label_weights, k=3))
        clients[str(i)] = assigned

    lc = dict()  # label -> number of assignments (repeated draws are counted)
    ll = dict()  # label -> set of client ids holding that label
    for key, value in clients.items():
        for name in value:
            ll.setdefault(name, set()).add(key)
            lc[name] = lc.get(name, 0) + 1

    for label, files in get_files(data_root):
        if label not in ll:
            # Fail before copying anything for this label.
            raise KeyError(
                f'label {label!r} found under {data_root!r} is not assigned to any client')
        train = _split_off_validation(label, files)
        # Iteration order of a set of strings changes between interpreter
        # runs (hash randomization); sort numerically so the same client
        # always receives the same slice for a given seed.
        owners = sorted(ll[label], key=int)
        # Floor division: leftover files of this label are not copied.
        ntrain = len(train) // len(owners)
        for n, client_id in enumerate(owners):
            share = train[n * ntrain:(n + 1) * ntrain]
            _copy_all(share, os.path.join(dst_base, client_id, label))
        print(f'Done {label}')
