# Keras

In [1]:
import tensorflow as tf
from sklearn import model_selection
import numpy as np
import os

import fl_util

def prepare(dataName, flatten=True):
    """Load a Keras benchmark dataset, re-split it 60/20/20 and serialize it.

    Keras' own train and test partitions are concatenated into one pool.
    Pixel values are divided by 255.0 and stored as float32. The pool is then
    re-split with a fixed seed into train (60%), validation (20%) and test
    (20%). Each split is wrapped in a one-element object array holding a
    ``{'x': ..., 'y': ...}`` dict and handed to ``fl_util.serialize`` under
    the paths ``<dataName>/train``, ``<dataName>/val`` and ``<dataName>/test``.

    Args:
        dataName: ``'mnist-o'`` (MNIST), ``'mnist-f'`` (Fashion-MNIST) or
            ``'cifar10'``. Also used as the output path prefix.
        flatten: Currently unused; kept so existing callers keep working.

    Raises:
        ValueError: If ``dataName`` is not one of the supported names.
    """
    if dataName == 'mnist-o':
        trainData, testData = tf.keras.datasets.mnist.load_data()
    elif dataName == 'mnist-f':
        trainData, testData = tf.keras.datasets.fashion_mnist.load_data()
    elif dataName == 'cifar10':
        trainData, testData = tf.keras.datasets.cifar10.load_data()
    else:
        # Report the offending name; previously this referenced an undefined
        # variable and surfaced as a NameError.
        raise ValueError('unknown dataName: {!r}'.format(dataName))

    # Pool the Keras-provided partitions; they are re-split with our own ratios below.
    x = np.concatenate((trainData[0], testData[0]), axis=0)
    y = np.concatenate((trainData[1], testData[1]), axis=0)
    if dataName == 'cifar10':
        # CIFAR-10 labels need flattening so every dataset yields 1-D labels.
        y = y.flatten()
    print(x.shape, y.shape)

    # Normalize: divide by 255 (in float64) and store as float32. This is the
    # vectorized equivalent of dividing each sample separately.
    x = (x / 255.0).astype(np.float32)

    # Split with a fixed seed: 60% train, then the remaining 40% is halved
    # into validation and test.
    x_train, x_rest, y_train, y_rest = model_selection.train_test_split(
        x, y, test_size=0.4, random_state=1234)
    x_val, x_test, y_val, y_test = model_selection.train_test_split(
        x_rest, y_rest, test_size=0.5, random_state=1234)

    # One-element arrays of {'x', 'y'} dicts; presumably one entry per
    # federated node id (a single node here) — confirm against fl_util users.
    trainData_by1Nid = np.array([{'x': x_train, 'y': y_train}])
    valData_by1Nid = np.array([{'x': x_val, 'y': y_val}])
    testData_by1Nid = np.array([{'x': x_test, 'y': y_test}])
    print(len(x_train), len(x_val), len(x_test))

    fl_util.serialize(os.path.join(dataName, 'train'), trainData_by1Nid)
    fl_util.serialize(os.path.join(dataName, 'val'), valData_by1Nid)
    fl_util.serialize(os.path.join(dataName, 'test'), testData_by1Nid)

# Prepare every Keras benchmark dataset: MNIST, Fashion-MNIST and CIFAR-10.
for _kerasDataName in ('mnist-o', 'mnist-f', 'cifar10'):
    prepare(_kerasDataName)

(70000, 28, 28) (70000,)
42000 14000 14000
(70000, 28, 28) (70000,)
42000 14000 14000
(60000, 32, 32, 3) (60000,)
36000 12000 12000


# Leaf (generated with `./preprocess.sh -s iid --sf 0.1`)

In [2]:
dataName = 'femnist'
# Read the sampled LEAF FEMNIST JSON files. Presumably readJsonDir returns
# (user ids, <unused>, {uid: {'x': ..., 'y': ...}}) — TODO confirm in fl_util.
uids, _, data = fl_util.readJsonDir(os.path.join(dataName, 'sampled'))
# Pool every user's samples into one dataset; per-user grouping is discarded.
x = np.concatenate([data[uid]['x'] for uid in uids], axis=0)
y = np.concatenate([data[uid]['y'] for uid in uids], axis=0)
print(x.shape, y.shape)

# Normalize: divide by 255 (in float64) and store as float32. This is the
# vectorized equivalent of dividing each sample separately.
# NOTE(review): LEAF's FEMNIST preprocessing may already scale pixels into
# [0, 1]; if so, this extra division by 255 shrinks them further — confirm
# against the generated JSON before relying on these values.
x = (x / 255.0).astype(np.float32)

# Split with a fixed seed: 60% train, then the remaining 40% is halved into
# validation and test.
x_train, x_rest, y_train, y_rest = model_selection.train_test_split(
    x, y, test_size=0.4, random_state=1234)
x_val, x_test, y_val, y_test = model_selection.train_test_split(
    x_rest, y_rest, test_size=0.5, random_state=1234)

# One-element arrays of {'x', 'y'} dicts, same layout as the Keras datasets.
trainData_by1Nid = np.array([{'x': x_train, 'y': y_train}])
valData_by1Nid = np.array([{'x': x_val, 'y': y_val}])
testData_by1Nid = np.array([{'x': x_test, 'y': y_test}])
print(len(x_train), len(x_val), len(x_test))

fl_util.serialize(os.path.join(dataName, 'train'), trainData_by1Nid)
fl_util.serialize(os.path.join(dataName, 'val'), valData_by1Nid)
fl_util.serialize(os.path.join(dataName, 'test'), testData_by1Nid)

(78353, 784) (78353,)
47011 15671 15671
