In [1]:
import tensorflow as tf
import tensorflow_federated as tff
import numpy as np

In [2]:
cifar = tf.keras.datasets.cifar10
(train_images, train_labels), (test_images, test_labels) = cifar.load_data()

In [3]:
def process_data(images, labels, splits):
    images = images.reshape(images.shape[0], 32, 32, 3)
    images = images.astype('float32')
    images /= 255
    labels = labels.astype(np.int32)
    labels = tf.one_hot(tf.constant(labels), 10)
    if splits == 0:
        labels = np.array(labels)
        labels = labels.reshape(labels.shape[0], 10)
        return images, labels
    else:
        split_images = np.array_split(images, splits)
        split_labels = np.array_split(labels, splits)
        reshaped_labels = []
        for label_portion in split_labels:
            reshaped_labels.append(label_portion.reshape(label_portion.shape[0], 10))
    return split_images, reshaped_labels

def model():
    input_shape = (32, 32, 3)
    inputs = tf.keras.layers.Input(shape=input_shape)
    x = tf.keras.layers.Conv2D(32, (3, 3), activation='relu')(inputs)
    x = tf.keras.layers.MaxPooling2D((2, 2))(x)
    x = tf.keras.layers.Flatten()(x)
    x = tf.keras.layers.Dense(64, activation='relu')(x)
    outputs = tf.keras.layers.Dense(10, activation='softmax')(x)
    model = tf.keras.models.Model(inputs, outputs)
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    return model

def federated_average(weight1, weight2):
    new_weights = np.array(weight1) + np.array(weight2)
    new_weights = new_weights / 2
    return new_weights.tolist()

In [4]:
split_images, split_labels = process_data(train_images, train_labels,3)

server_model = model()
client_1 = model()
client_2 = model()

server_model.fit(split_images[0], split_labels[0], epochs=1)
server_weights = server_model.get_weights()

client_1.set_weights(server_weights)
client_2.set_weights(server_weights)

client_1.fit(split_images[1], split_labels[1], epochs=1)
client_1_weights = client_1.get_weights()

client_2.fit(split_images[2], split_labels[2], epochs=1)
client_2_weights = client_2.get_weights()

global_update = federated_average(client_1_weights,client_2_weights)
server_model.set_weights(global_update)
test_images_proc, test_labels_proc = process_data(test_images, test_labels,0)
loss, accuracy = server_model.evaluate(test_images_proc, test_labels_proc)

Metal device set to: Apple M1 Pro

systemMemory: 16.00 GB
maxCacheSize: 5.33 GB



2023-02-24 13:37:08.171192: W tensorflow/core/platform/profile_utils/cpu_utils.cc:128] Failed to get CPU frequency: 0 Hz




  new_weights = np.array(weight1) + np.array(weight2)


