In [2]:
import tensorflow as tf
import tensorflow_federated as tff
import nest_asyncio

In [3]:
# Patch asyncio so that an event loop can be entered while another is already running.
# NOTE(review): presumably required because the notebook kernel already owns a running
# loop that the TFF runtime needs to re-enter — confirm for the TFF version in use.
nest_asyncio.apply()

In [4]:
# Fetch CIFAR-10 through Keras and unpack the (train, test) pairs of (images, labels).
(
    (cifar10_train_images, cifar10_train_labels),
    (cifar10_test_images, cifar10_test_labels),
) = tf.keras.datasets.cifar10.load_data()

In [5]:
# Per-element preprocessing applied to every dataset via `Dataset.map`
def preprocess(images, labels):
    """Cast images to float32 and divide by 255.0; labels pass through unchanged.

    Args:
        images: Image tensor (or array-like) to be converted.
        labels: Labels paired with ``images``; returned as received.

    Returns:
        A ``(scaled_images, labels)`` tuple.
    """
    scaled_images = tf.cast(images, tf.float32) / 255.0
    return scaled_images, labels

In [6]:
# Carve the training set into contiguous shards, one per simulated client
num_clients = 10
client_datasets = []
for client_idx in range(num_clients):
    # Multiply before the floor division so shard boundaries stay exact even when
    # the number of examples is not a multiple of num_clients.
    shard_start = client_idx * len(cifar10_train_images) // num_clients
    shard_stop = (client_idx + 1) * len(cifar10_train_images) // num_clients
    shard = slice(shard_start, shard_stop)

    # Same pipeline for every client: slice -> preprocess -> batches of 20.
    client_datasets.append(
        tf.data.Dataset
        .from_tensor_slices((cifar10_train_images[shard], cifar10_train_labels[shard]))
        .map(preprocess)
        .batch(20)
    )

In [7]:
# Small convolutional classifier for 32x32 RGB inputs
def create_keras_model():
    """Build an uncompiled Sequential CNN with two conv/pool stages and a dense head.

    The last layer is ``Dense(10)`` with no activation, so the model outputs
    raw logits.

    Returns:
        A new ``tf.keras.models.Sequential`` instance.
    """
    layers = tf.keras.layers

    stack = [layers.Input(shape=(32, 32, 3))]
    # Convolutional feature extractor: two conv + max-pool stages.
    stack += [
        layers.Conv2D(32, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
    ]
    # Classification head producing 10 logits.
    stack += [
        layers.Flatten(),
        layers.Dense(64, activation='relu'),
        layers.Dense(10),
    ]
    return tf.keras.models.Sequential(stack)

In [8]:
# Model factory handed to the federated learning process builder
def model_fn():
    """Create a fresh Keras model and wrap it with ``tff.learning.from_keras_model``.

    A new Keras model, loss and metric object are constructed on every call.
    The input spec is taken from the first client dataset's ``element_spec``.

    Returns:
        The object returned by ``tff.learning.from_keras_model``.
    """
    fresh_model = create_keras_model()
    batch_spec = client_datasets[0].element_spec
    # from_logits=True because the network's final Dense layer has no activation.
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy()
    return tff.learning.from_keras_model(
        fresh_model,
        input_spec=batch_spec,
        loss=loss_fn,
        metrics=[accuracy_metric])

In [9]:
# Build the federated averaging process from the model factory and two optimizer factories
def make_client_optimizer():
    """Return a new SGD optimizer (learning rate 0.02) for the client side."""
    return tf.keras.optimizers.SGD(learning_rate=0.02)


def make_server_optimizer():
    """Return a new SGD optimizer (learning rate 1.0) for the server side."""
    return tf.keras.optimizers.SGD(learning_rate=1.0)


iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=make_client_optimizer,
    server_optimizer_fn=make_server_optimizer)

In [10]:
# Create the initial server state. `state` is passed into and replaced by every
# `iterative_process.next` call in the training loop, and is later read by
# `evaluate_model` to recover the trained weights.
state = iterative_process.initialize()

In [11]:
# Train for a fixed number of rounds, passing every client dataset in each round
NUM_ROUNDS = 10
for round_num in range(NUM_ROUNDS):
    # One round: takes the current server state, returns the updated state and metrics.
    round_output = iterative_process.next(state, client_datasets)
    state, metrics = round_output
    print('round {:2d}, metrics={}'.format(round_num, metrics))


round  0, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.1695), ('loss', 2.2137127)]))])
round  1, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.2506), ('loss', 2.037138)]))])
round  2, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.29942), ('loss', 1.9327714)]))])
round  3, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.3473), ('loss', 1.8198311)]))])
round  4, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_proce

In [12]:
def evaluate_model(state, test_dataset):
    """Evaluate the weights held in ``state`` on a batched test dataset.

    A fresh Keras model is built and compiled with sparse categorical
    cross-entropy (on logits) and sparse categorical accuracy, the weights
    from ``state.model`` are copied into it, and the model is evaluated on
    ``test_dataset``.

    Args:
        state: Server state whose ``model`` attribute provides
            ``assign_weights_to(keras_model)``.
        test_dataset: Batched ``tf.data.Dataset`` yielding
            ``(images, labels)`` pairs.

    Returns:
        A ``(loss, accuracy)`` tuple as reported by ``keras_model.evaluate``.
    """
    keras_model = create_keras_model()
    keras_model.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

    # Copy the federated model weights into the freshly built Keras model.
    state.model.assign_weights_to(keras_model)

    # Feed the batched dataset to Keras directly instead of eagerly listing every
    # batch and concatenating the whole test set into two in-memory tensors.
    # Keras iterates the dataset's own batches, so no manual unbatching is needed.
    loss, accuracy = keras_model.evaluate(test_dataset, verbose=0)
    return loss, accuracy

# Held-out evaluation pipeline: same preprocessing and batch size as the client datasets
test_dataset = (
    tf.data.Dataset
    .from_tensor_slices((cifar10_test_images, cifar10_test_labels))
    .map(preprocess)
    .batch(20)
)

# Score the trained federated model on the test split and report the result.
loss, accuracy = evaluate_model(state, test_dataset)
print(f'Test loss: {loss}, Test accuracy: {accuracy}')


Test loss: 1.3613240718841553, Test accuracy: 0.515500009059906
