## Fashion MNIST 이미지 추출
### 클라이언트별 IID

In [None]:
import os
import random

import tensorflow as tf
import numpy as np
import pandas as pd
import PIL.Image as Image

In [None]:
# Load Fashion-MNIST and merge its train and test splits into one pool;
# the partitioning into clients / held-out set happens in a later cell.
splits = tf.keras.datasets.fashion_mnist.load_data()
# zip(*splits) pairs up (train_x, test_x) and (train_y, test_y).
images, labels = (np.concatenate(parts) for parts in zip(*splits))

In [None]:
# Class index -> class name, in the label order documented at
# https://www.tensorflow.org/api_docs/python/tf/keras/datasets/fashion_mnist/load_data
# These names are also used as the per-class folder names on disk.
labelnames = list((
    'T-shirt',
    'Trouser',
    'Pullover',
    'Dress',
    'Coat',
    'Sandal',
    'Shirt',
    'Sneaker',
    'Bag',
    'Ankle boot',
))
labelnames

In [None]:
# Number of federated clients; party '0' is reserved for the held-out set,
# so images are dealt over clients + 1 slots.
clients = 10
div = 1 + clients

In [None]:
# Export every Fashion-MNIST image to dataset/<party>/<label>/<num>.jpg.
# Within each label, every `div`-th image (num % div == 0) goes to party '0',
# the held-out global test set. Every other image is assigned uniformly at
# random to one of the clients 1..clients, which gives an IID split.
# A dedicated, seeded RNG makes the assignment reproducible. Re-running this
# cell therefore rewrites the same paths instead of scattering duplicate
# copies of one image into different client folders.
# NOTE(review): delete a `dataset/` folder produced by an earlier unseeded
# run before re-running; otherwise that run's stale files remain.
SPLIT_SEED = 0
split_rng = random.Random(SPLIT_SEED)
counter = {}
output = os.path.abspath(os.path.expanduser('dataset'))
for pixels, label_idx in zip(images, labels):
    image = Image.fromarray(pixels)
    label = labelnames[int(label_idx)]
    num = counter.get(label, 0)
    party = num % div
    if party != 0:
        party = split_rng.randint(1, clients)
    odir = os.path.join(output, f'{party}', label)
    os.makedirs(odir, exist_ok=True)
    opath = os.path.join(odir, f'{num:04d}.jpg')
    image.save(opath)
    counter[label] = num + 1

## 클라이언트별 학습

In [None]:
# Print every non-hidden party directory under dataset/ (0 = held-out, 1.. = clients).
dataset_root = os.path.abspath(os.path.expanduser('dataset'))
with os.scandir(dataset_root) as entries:
    party_dirs = (d for d in entries if not d.name.startswith('.') and d.is_dir())
    for party_dir in party_dirs:
        print(party_dir, party_dir.path, party_dir.name)

In [None]:
# Datasets
datasets = {}
dataset_root = os.path.abspath(os.path.expanduser('dataset'))
with os.scandir(dataset_root) as it:
    for entry in it:
        if not entry.name.startswith('.') and entry.is_dir():
            if entry.name == '0':
                continue
            image_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1/255, validation_split=0.2)
            train = image_generator.flow_from_directory(entry.path,
                                                        classes=labelnames,
                                                        target_size=(32, 32),
                                                        subset='training',
                                                        shuffle=True)
            test = image_generator.flow_from_directory(entry.path,
                                                       classes=labelnames,
                                                       target_size=(32, 32),
                                                       subset='validation',
                                                       shuffle=True)
            datasets[entry.name] = (train, test)
datasets.keys()

In [None]:
# Datasets
image_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1/255)
datasets['0'] = image_generator.flow_from_directory(os.path.join(dataset_root, '0'),
                                                    classes=labelnames, 
                                                    target_size=(32, 32), 
                                                    shuffle=True)
datasets.keys()

In [None]:
# Small CNN on 32x32 RGB inputs. The final Dense layer emits raw logits
# (one per class), hence from_logits=True in the loss below.
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)))
model.add(tf.keras.layers.MaxPooling2D((2, 2)))
model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='relu'))
model.add(tf.keras.layers.MaxPooling2D((2, 2)))
model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='relu'))
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(64, activation='relu'))
model.add(tf.keras.layers.Dense(len(labelnames)))

model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)
model.summary()

In [None]:
# Base model
# Base model: the global model entering round 0 lives at models/0/0.
odir = os.path.join('models', '0', '0')
model.save(odir)

In [None]:
# Federated rounds, and local epochs each client trains per round.
rounds, epochs = 10, 10

In [None]:
# Federated training with unweighted weight averaging (every client counts equally).
# Layout on disk:
#   models/<r>/0    global model entering round r
#   models/<r>/<c>  client c's model after local training in round r
for r in range(rounds):
    rpath = os.path.join('models', f'{r}')
    mpath = os.path.join(rpath, '0')
    summed = None  # running element-wise sum of every client's weight list
    for c in range(1, clients + 1):
        # Each client starts from the same global model of this round.
        model = tf.keras.models.load_model(mpath)
        model.fit(datasets[f'{c}'][0], epochs=epochs, verbose=0)
        model.save(os.path.join(rpath, f'{c}'))
        client_w = model.get_weights()
        summed = client_w if summed is None else [acc + w for acc, w in zip(summed, client_w)]
        print(f'Local train: round #{r} with client #{c}')
    weights = [layer_sum / clients for layer_sum in summed]
    # Reload the round's global model so the aggregated weights go into a fresh
    # copy of it, then save it as the global model for round r + 1.
    model = tf.keras.models.load_model(mpath)
    model.set_weights(weights)
    model.save(os.path.join('models', f'{r+1}', '0'))
    metric = model.evaluate(datasets['0'], verbose=0)
    print(f'Global aggregation: round #{r+1} for {metric}')

In [None]:
# For each aggregated global model (rounds 1..rounds), record two metrics:
#   Global    - loss/accuracy on the held-out party '0' set
#   Federated - unweighted mean of loss/accuracy over each client's validation split
global_acc = []
global_loss = []
federated_acc = []
federated_loss = []
for r in range(1, rounds + 1):
    mpath = os.path.join('models', f'{r}', '0')
    model = tf.keras.models.load_model(mpath)

    metric = model.evaluate(datasets['0'], verbose=0)
    global_loss.append(metric[0])
    global_acc.append(metric[1])
    print(f'round #{r} - Global: {metric}')

    # evaluate() returns [loss, accuracy] because the model is compiled with metrics=['accuracy'].
    client_metrics = [model.evaluate(datasets[f'{c}'][1], verbose=0) for c in range(1, clients + 1)]
    loss = sum(m[0] for m in client_metrics) / clients
    acc = sum(m[1] for m in client_metrics) / clients
    federated_loss.append(loss)
    federated_acc.append(acc)
    print(f'round #{r} - Federated: {(loss, acc)}')

In [None]:
import matplotlib.pyplot as plt # 시각화 도구
%matplotlib inline

In [None]:
# x-axis: round numbers; Y1 = global (held-out) metrics, Y2 = federated (client-average) metrics.
X = range(1, rounds+1)
Y1_loss = [global_loss[k] for k in range(rounds)]
Y1_acc = [global_acc[k] for k in range(rounds)]
Y2_loss = [federated_loss[k] for k in range(rounds)]
Y2_acc = [federated_acc[k] for k in range(rounds)]

In [None]:
# Accuracy per federated round: global held-out set vs. average over client validation splits.
fig = plt.figure(figsize=(8, 6))
fig.set_facecolor('white')
ax = fig.add_subplot()
ax.plot(X, Y1_acc, label='Global')
ax.plot(X, Y2_acc, label='Federated')
# The dataset used throughout this notebook is Fashion-MNIST (not CIFAR-10).
ax.set_title('Fashion-MNIST', fontsize='x-large')
ax.set_xlabel('Federated rounds', fontsize='x-large')
ax.set_ylabel('Accuracy', fontsize='x-large')
ax.set_ylim([0, 1])
ax.legend()
# Uncomment to save the figure to disk:
# fig.savefig('accuracy.png', bbox_inches='tight')