In [1]:
import tensorflow as tf
from sklearn import model_selection
import numpy as np
import os

import fl_util

# Global seeds so any NumPy / TensorFlow randomness in this notebook is reproducible.
SEED = 1234
np.random.seed(SEED)
# `tf.set_random_seed` is the TensorFlow 1.x API; TensorFlow 2.x only exposes
# `tf.random.set_seed` at the top level. Prefer the 1.x call when it exists so
# behavior on the TF version that produced the outputs below is unchanged.
if hasattr(tf, 'set_random_seed'):
    tf.set_random_seed(SEED)
else:
    tf.random.set_seed(SEED)

# Keras datasets: MNIST, Fashion-MNIST, CIFAR-10

In [2]:
def trainTestValSplit(dataName, x, y, val_size=0.2, test_size=0.2, seed=1234):
    """Split (x, y) into train/val/test sets and serialize each under ``dataName``.

    The data is first split into train vs. a held-out part of size
    ``val_size + test_size``; the held-out part is then split into val and
    test. The defaults give a roughly 60/20/20 split (first split 0.4,
    second split 0.5), i.e. the same splits as the previous hard-coded values.

    Each split is stored as a one-element object array holding a
    ``{'x': ..., 'y': ...}`` dict (the ``by1Nid`` naming suggests a single
    node id -- TODO confirm against the consumers of these files), written via
    ``fl_util.serialize`` to ``<dataName>/train``, ``<dataName>/val`` and
    ``<dataName>/test``.

    Args:
        dataName: dataset name; also used as the output directory.
        x: samples, split along the first axis.
        y: labels aligned with ``x``.
        val_size: fraction of all samples used for validation.
        test_size: fraction of all samples used for testing.
        seed: ``random_state`` for both splits, so the split is reproducible.

    Raises:
        ValueError: if the fractions are not positive or sum to 1 or more.
    """
    holdout_size = val_size + test_size
    if not (val_size > 0 and test_size > 0 and holdout_size < 1):
        raise ValueError(
            'need val_size > 0, test_size > 0 and val_size + test_size < 1; '
            'got val_size={}, test_size={}'.format(val_size, test_size))

    # Split with a fixed seed so the result is identical across runs.
    x_train, x_hold, y_train, y_hold = model_selection.train_test_split(
        x, y, test_size=holdout_size, random_state=seed)
    x_val, x_test, y_val, y_test = model_selection.train_test_split(
        x_hold, y_hold, test_size=test_size / holdout_size, random_state=seed)
    print(len(x_train), len(x_val), len(x_test))

    splits = (('train', x_train, y_train),
              ('val', x_val, y_val),
              ('test', x_test, y_test))
    for split_name, split_x, split_y in splits:
        split_by1Nid = np.array([ { 'x': split_x, 'y': split_y } ])
        fl_util.serialize(os.path.join(dataName, split_name), split_by1Nid)
    
def _load_keras_dataset(dataName):
    """Load a Keras dataset by short name and merge its train and test parts.

    Args:
        dataName: one of 'mnist-o' (MNIST), 'mnist-f' (Fashion-MNIST) or 'cifar10'.

    Returns:
        (x, y): Keras' train and test arrays concatenated along the first axis.

    Raises:
        ValueError: if ``dataName`` is not a supported name.
    """
    # Short name -> attribute of tf.keras.datasets. Kept as strings so an
    # unknown name is rejected before TensorFlow is touched.
    keras_modules = {
        'mnist-o': 'mnist',
        'mnist-f': 'fashion_mnist',
        'cifar10': 'cifar10',
    }
    if dataName not in keras_modules:
        raise ValueError('unknown dataName {!r}; expected one of {}'.format(
            dataName, sorted(keras_modules)))
    dataset = getattr(tf.keras.datasets, keras_modules[dataName])
    (x_train, y_train), (x_test, y_test) = dataset.load_data()
    x = np.concatenate((x_train, x_test), axis=0)
    y = np.concatenate((y_train, y_test), axis=0)
    return x, y

def prepare(dataName, expand_dims_x):
    """Load, normalize and re-split one Keras dataset, then serialize the splits.

    The Keras train and test parts are merged and handed to
    ``trainTestValSplit``, which re-splits them and writes the result to disk.

    Args:
        dataName: 'mnist-o', 'mnist-f' or 'cifar10'.
        expand_dims_x: if true, append a trailing channel axis to ``x``
            (used for the single-channel MNIST variants).

    Raises:
        ValueError: if ``dataName`` is not a supported name.
    """
    x, y = _load_keras_dataset(dataName)

    # Normalize 8-bit pixel values by 255. Dividing in float64 and then casting
    # yields exactly the same float32 values as a per-image loop would.
    x = (x / 255.0).astype(np.float32)
    if expand_dims_x:
        x = np.expand_dims(x, axis=-1)
    # Labels must be 1-D: CIFAR-10 labels need flattening, while the MNIST
    # labels are already 1-D, so this is a no-op for them.
    y = y.reshape(-1)
    print(x.shape, y.shape)

    trainTestValSplit(dataName, x, y)

# Keras datasets. The MNIST variants are (N, 28, 28) grayscale and need a
# trailing channel axis; CIFAR-10 images already have 3 channels.
for _dataset, _add_channel in (('mnist-o', True), ('mnist-f', True), ('cifar10', False)):
    prepare(_dataset, expand_dims_x=_add_channel)

(70000, 28, 28, 1) (70000,)
42000 14000 14000
(70000, 28, 28, 1) (70000,)
42000 14000 14000
(60000, 32, 32, 3) (60000,)
36000 12000 12000


# LEAF - FEMNIST (./preprocess.sh -s iid --sf 0.1)

In [3]:
IMAGE_SIZE = 28

dataName = 'femnist'
uids, _, data = fl_util.readJsonDir(os.path.join(dataName, 'sampled'))

# Pool the samples of every sampled user into a single dataset.
x = np.concatenate([ data[uid]['x'] for uid in uids ], axis=0)
y = np.concatenate([ data[uid]['y'] for uid in uids ], axis=0)

# Rows are flattened IMAGE_SIZE x IMAGE_SIZE images. Cast to float32 so FEMNIST
# has the same dtype as the Keras datasets and CelebA; the dtype coming out of
# readJsonDir is not visible here (presumably float64 from JSON parsing).
x = x.astype(np.float32).reshape((-1, IMAGE_SIZE, IMAGE_SIZE))
assert len(x) == len(y), (x.shape, y.shape)

# Visual sanity check: show one sample together with its label.
fid = 580  # arbitrary sample index to eyeball
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
ax.imshow(x[fid], cmap='gray')
ax.set_title('{} sample #{} (label {})'.format(dataName, fid, y[fid]))
ax.axis('off')
plt.show()
print(y[fid])

# Append a channel axis, (N, H, W) -> (N, H, W, 1), to match the Keras MNIST variants.
x = x[..., np.newaxis]
print(x.shape, y.shape)
print(np.unique(y).size)  # number of distinct classes

trainTestValSplit(dataName, x, y)

<Figure size 640x480 with 1 Axes>

8
(78353, 28, 28, 1) (78353,)
62
47011 15671 15671


# LEAF - CELEBA (./preprocess.sh -s iid --sf 0.05)

In [4]:
from PIL import Image

IMAGE_SIZE = 84
IMAGES_DIR = os.path.join('celeba', 'img_align_celeba')

def _load_image(img_name, image_size=IMAGE_SIZE, images_dir=IMAGES_DIR):
    """Load one CelebA image as a float32 RGB array scaled by 1/255.

    The image is resized to ``image_size`` x ``image_size`` and converted to
    RGB, so the result has shape ``(image_size, image_size, 3)``.

    Args:
        img_name: file name relative to ``images_dir``.
        image_size: target width and height in pixels.
        images_dir: directory containing the images.

    The defaults are bound when this cell runs, so later rebinding of the
    notebook-global ``IMAGE_SIZE`` (the FEMNIST cell sets it to 28) does not
    change the CelebA image size.
    """
    # Use Image.open as a context manager so the file handle is closed right
    # away instead of lingering until garbage collection (one per image).
    with Image.open(os.path.join(images_dir, img_name)) as img:
        rgb = img.resize((image_size, image_size)).convert('RGB')
        pixels = np.array(rgb, dtype=np.float32)
    return pixels / 255.0

dataName = 'celeba'
# Read the sampled LEAF split and pool all users: 'x' holds image file names,
# 'y' the matching labels.
uids, _, data = fl_util.readJsonDir(os.path.join(dataName, 'sampled'))
x_fileNames = np.concatenate(tuple(data[uid]['x'] for uid in uids))
y = np.concatenate(tuple(data[uid]['y'] for uid in uids))
print(x_fileNames.shape, y.shape)

# Decode every referenced image into one (N, H, W, 3) float32 array.
x = np.array(list(map(_load_image, x_fileNames)), dtype=np.float32)
print(x.shape, y.shape)

trainTestValSplit(dataName, x, y)

(10014,) (10014,)
(10014, 84, 84, 3) (10014,)
6008 2003 2003
