## CIFAR-10 이미지 추출
### 클라이언트별 IID

In [None]:
import os

import tensorflow as tf
import numpy as np
import pandas as pd
import PIL.Image as Image

In [None]:
# Pool the official CIFAR-10 train and test splits into one array pair;
# the extraction step re-partitions the pooled images itself.
cifar_train, cifar_test = tf.keras.datasets.cifar10.load_data()
(train_images, train_labels), (test_images, test_labels) = cifar_train, cifar_test
images, labels = (np.concatenate(parts) for parts in zip(cifar_train, cifar_test))

In [None]:
# Human-readable class names, ordered so that labelnames[i] is the name of
# integer label i (the extraction loop indexes this list with the raw label).
# Reference: https://www.tensorflow.org/api_docs/python/tf/keras/datasets/cifar10/load_data
labelnames = [
    'Airplane', 'Automobile', 'Bird', 'Cat', 'Deer',
    'Dog', 'Frog', 'Horse', 'Ship', 'Truck',
]
labelnames

In [None]:
# Number of federated clients. The images are split into one more folder
# than there are clients: folders '1'..clients go to clients, and folder '0'
# is left over as an extra partition.
clients = 10
div = 1 + clients

In [None]:
# Write every pooled CIFAR-10 image to dataset/<split>/<ClassName>/<NNNN>.jpg.
#
# Images are dealt round-robin *per class*: the n-th image of a class goes
# to split folder n % div. Every split therefore receives the same number
# of images of each class, to within one image, which makes the split IID.
# Folders '1'..'{clients}' are the per-client training sets. Folder '0' is
# the set the global model is evaluated on in the training loop.
#
# NOTE(review): .jpg is lossy (PIL's default JPEG quality), so the saved
# pixels differ from the original CIFAR-10 arrays. Switch the extension to
# .png if exact pixels matter; flow_from_directory reads both.
counter = {}  # class name -> number of images of that class written so far
output = os.path.abspath(os.path.expanduser('dataset'))
created_dirs = set()  # avoid one makedirs syscall per image
for pixels, label_row in zip(images, labels):
    image = Image.fromarray(pixels)
    # label_row is a length-1 row of the (N, 1) label array.
    label = labelnames[int(label_row[0])]
    num = counter.get(label, 0)
    odir = os.path.join(output, f'{num % div}', label)
    if odir not in created_dirs:
        os.makedirs(odir, exist_ok=True)
        created_dirs.add(odir)
    opath = os.path.join(odir, f'{num:04d}.jpg')
    image.save(opath)
    counter[label] = num + 1

## 클라이언트별 학습

In [None]:
# Show the split folders created under ./dataset, skipping hidden entries
# and plain files.
dataset_root = os.path.abspath(os.path.expanduser('dataset'))
with os.scandir(dataset_root) as it:
    for entry in it:
        if entry.name.startswith('.') or not entry.is_dir():
            continue
        print(entry, entry.path, entry.name)

In [None]:
# Datasets: one Keras directory iterator per visible split folder under
# ./dataset, keyed by folder name ('0', '1', ...). Pixel values are
# multiplied by 1/255, and passing `labelnames` fixes the class-index order
# to match the model's output units.
dataset_root = os.path.abspath(os.path.expanduser('dataset'))
image_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1/255)
with os.scandir(dataset_root) as it:
    datasets = {
        entry.name: image_generator.flow_from_directory(
            entry.path,
            classes=labelnames,
            target_size=(32, 32),
            shuffle=True,
        )
        for entry in it
        if not entry.name.startswith('.') and entry.is_dir()
    }
datasets.keys()

In [None]:
# Small CNN for 32x32 RGB images: three conv blocks, then a dense head.
# The last Dense layer has no activation, so it outputs raw logits, and the
# loss is configured with from_logits=True to match.
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)))
model.add(tf.keras.layers.MaxPooling2D((2, 2)))
model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='relu'))
model.add(tf.keras.layers.MaxPooling2D((2, 2)))
model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='relu'))
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(64, activation='relu'))
model.add(tf.keras.layers.Dense(len(labelnames)))  # one logit per class

adam = tf.keras.optimizers.Adam()
xent = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
model.compile(optimizer=adam, loss=xent, metrics=['accuracy'])
model.summary()

In [None]:
# Base model: save the untrained model as the global model that enters
# round 0 (path models/0/0). The training loop loads it from there.
base_dir = os.path.join('models', '0')
odir = os.path.join(base_dir, '0')
model.save(odir)

In [None]:
# Federated schedule: number of aggregation rounds, and the local epochs
# each client trains per round.
rounds, epochs = 10, 10

In [None]:
# Federated averaging over `rounds` rounds.
#
# On-disk layout:
#   models/<r>/0   global model entering round r
#   models/<r>/<c> client c's locally trained copy in round r
#
# Each round:
#   1. Every client 1..clients loads the round's global model and trains it
#      for `epochs` epochs on its own split.
#   2. The client weights are averaged element-wise (plain mean; the
#      extraction split gives clients near-equal sizes).
#   3. The result is saved as models/<r+1>/0 and evaluated on split '0'.
if clients < 1:
    raise ValueError(f'clients must be >= 1 for aggregation, got {clients}')

for r in range(rounds):
    rpath = os.path.join('models', f'{r}')
    mpath = os.path.join(rpath, '0')
    weight_sums = None  # running element-wise sum of client weight arrays
    for c in range(1, clients + 1):
        # Reload from disk so every client starts from the same global weights.
        model = tf.keras.models.load_model(mpath)
        model.fit(datasets[f'{c}'], epochs=epochs, verbose=0)
        opath = os.path.join(rpath, f'{c}')
        model.save(opath)
        local_weights = model.get_weights()
        if weight_sums is None:
            weight_sums = local_weights
        else:
            weight_sums = [acc + w for acc, w in zip(weight_sums, local_weights)]
        print(f'Local train: round #{r} with client #{c}')
    weights = [acc / clients for acc in weight_sums]
    # Load a fresh copy of the round's global model to receive the averaged weights.
    model = tf.keras.models.load_model(mpath)
    model.set_weights(weights)
    ndir = os.path.join('models', f'{r+1}')
    npath = os.path.join(ndir, '0')
    model.save(npath)
    metric = model.evaluate(datasets['0'], verbose=0)
    print(f'Global aggregation: round #{r+1} for {metric}')