In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.models import Model

In [None]:
# CIFAR-10 ships with Keras; load_data() yields ((train_x, train_y), (test_x, test_y)).
cifar = tf.keras.datasets.cifar10
(train_images, train_labels), (test_images, test_labels) = cifar.load_data()

# Report the array shapes of both splits so the layout expected by process_data is visible.
shape_summary = (train_images.shape, train_labels.shape,
                 test_images.shape, test_labels.shape)
print('Training:\timg:{}\tlabel:{}\nTesting:\timg:{}\tlabel:{}\n'.format(*shape_summary))

In [None]:
def process_data(images, labels, splits, num_classes=10, image_shape=(32, 32, 3)):
    """Scale images, one-hot encode labels and optionally shard both.

    Args:
        images: Array of N images; reshaped to ``(N, *image_shape)``.
        labels: N integer class ids, shape ``(N,)`` or ``(N, 1)``.
        splits: ``0`` to return whole arrays; a positive integer to cut the
            data into that many near-equal shards with ``np.array_split``.
        num_classes: Width of each one-hot label vector.
        image_shape: Per-image shape ``(H, W, C)``.

    Returns:
        If ``splits == 0``: ``(images, labels)`` as float32 arrays, images
        divided by 255 and labels of shape ``(N, num_classes)``.
        Otherwise: ``(image_shards, label_shards)``, two lists with
        ``splits`` entries each.

    Raises:
        ValueError: if the number of labels differs from the number of images,
            or (from ``np.array_split``) if ``splits`` is negative.
    """
    images = np.asarray(images)
    images = images.reshape((images.shape[0],) + tuple(image_shape))
    # float32 / int stays float32, matching the original in-place `/= 255`.
    images = images.astype('float32') / 255

    class_ids = np.asarray(labels).astype(np.int32).reshape(-1)
    if class_ids.shape[0] != images.shape[0]:
        raise ValueError(
            'expected one label per image: got {} labels for {} images'.format(
                class_ids.shape[0], images.shape[0]))

    # Compare every id with 0..num_classes-1. This matches tf.one_hot (float32,
    # all-zero row for an out-of-range id) and yields (N, num_classes) directly,
    # with no round-trip through a TF tensor and no per-shard reshape.
    one_hot = (class_ids[:, None] == np.arange(num_classes)).astype(np.float32)

    if splits == 0:
        return images, one_hot
    return np.array_split(images, splits), np.array_split(one_hot, splits)

def model(input_shape=(32, 32, 3), num_classes=10):
    """Build and compile a small convolutional classifier.

    Layers: Conv2D(32 filters, 3x3, ReLU) -> MaxPooling2D(2x2) -> Flatten
    -> Dense(64, ReLU) -> Dense(num_classes, softmax).

    Args:
        input_shape: Shape of one input image ``(H, W, C)``.
        num_classes: Number of output units; must equal the one-hot label
            width, since the loss is categorical cross-entropy.

    Returns:
        A compiled Keras ``Model`` (Adam optimizer, categorical cross-entropy
        loss, accuracy metric). Every call constructs new layers, so the
        server and each client get separate weight variables.
    """
    inputs = Input(shape=input_shape)
    x = Conv2D(32, (3, 3), activation='relu')(inputs)
    x = MaxPooling2D((2, 2))(x)
    x = Flatten()(x)
    x = Dense(64, activation='relu')(x)
    outputs = Dense(num_classes, activation='softmax')(x)
    # Named `net` so the local does not shadow this factory function's name.
    net = Model(inputs, outputs)
    net.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    return net

def federated_average(weight1, weight2, *more_weights):
    """Average model weights layer by layer (unweighted federated averaging).

    Each argument is a list of per-layer arrays, as returned by
    ``Model.get_weights()``. Layers have different shapes, so they are
    averaged one layer at a time. Wrapping the whole list in ``np.array``
    creates a ragged array, which NumPy >= 1.24 rejects with ``ValueError``.

    Args:
        weight1, weight2: Per-layer weight lists from models with the same
            architecture.
        *more_weights: Optional additional weight lists to include in the mean.

    Returns:
        A list of ``np.ndarray`` (one per layer) holding the element-wise
        mean, in the form ``Model.set_weights`` accepts.

    Raises:
        ValueError: if the lists differ in length or a layer's shapes differ.
    """
    all_weights = (weight1, weight2) + tuple(more_weights)
    n_layers = len(weight1)
    if any(len(w) != n_layers for w in all_weights):
        raise ValueError('all weight lists must contain the same number of layers')

    averaged = []
    for layer_idx, layer_group in enumerate(zip(*all_weights)):
        arrays = [np.asarray(layer) for layer in layer_group]
        reference_shape = arrays[0].shape
        # Refuse to broadcast mismatched layers: that would silently corrupt weights.
        if any(a.shape != reference_shape for a in arrays):
            raise ValueError('shape mismatch at layer {}: {}'.format(
                layer_idx, [a.shape for a in arrays]))
        averaged.append(np.mean(arrays, axis=0))
    return averaged

In [None]:
# Cut the training data into 5 shards: shard 0 bootstraps the server,
# shards 1-2 belong to client 1 and shards 3-4 to client 2.
split_images, split_labels = process_data(train_images, train_labels, 5)

# Three separately built instances of the same architecture (built left to right).
server_model, client_1, client_2 = model(), model(), model()

# Bootstrap: the server trains one epoch on shard 0; both clients start from its weights.
server_model.fit(split_images[0], split_labels[0], epochs=1)
server_weights = server_model.get_weights()
for client in (client_1, client_2):
    client.set_weights(server_weights)

# Local training: each client only sees its own shards, one epoch per shard.
for shard in (1, 2):
    client_1.fit(split_images[shard], split_labels[shard], epochs=1)
client_1_weights = client_1.get_weights()

for shard in (3, 4):
    client_2.fit(split_images[shard], split_labels[shard], epochs=1)
client_2_weights = client_2.get_weights()

In [None]:
# Aggregation: the mean of both clients' weights becomes the new server model.
global_update = federated_average(client_1_weights,
                                  client_2_weights)
server_model.set_weights(global_update)

# Score the aggregated model on the whole, unsplit test set (splits=0).
test_images_proc, test_labels_proc = process_data(test_images, test_labels, 0)
# With metrics=['accuracy'] in compile(), evaluate() returns [loss, accuracy].
loss, accuracy = server_model.evaluate(
    test_images_proc,
    test_labels_proc,
)