In [15]:
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
from tensorflow.keras import layers, models

# Load the CIFAR-10 dataset
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Preprocess the data
def preprocess(x, y):
    x = tf.cast(x, tf.float32) / 255.0
    y = tf.cast(y, tf.int64)
    return x, y

x_train, y_train = preprocess(x_train, y_train)
x_test, y_test = preprocess(x_test, y_test)

# Create the Keras model
def create_cifar10_model():
    model = models.Sequential()
    model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(64, (3, 3), activation='relu'))
    model.add(layers.Flatten())
    model.add(layers.Dense(64, activation='relu'))
    model.add(layers.Dense(10))
    return model

# Create a TFF version of the Keras model 
def model_fn():
    keras_model = create_cifar10_model() #(HERE U CHANGE THE MODEL)
    return tff.learning.from_keras_model(
        keras_model,
        input_spec=(
            tf.TensorSpec(shape=(None, 32, 32, 3), dtype=tf.float32),
            tf.TensorSpec(shape=(None, 1), dtype=tf.int64),
        ),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    )

# Simulate federated data by splitting the dataset into multiple clients
def split_data_for_clients(data, client_count):
    client_data = []
    data_len = len(data[0])
    batch_size = data_len // client_count

    for i in range(client_count):
        start = i * batch_size
        end = (i + 1) * batch_size if i != client_count - 1 else data_len
        client_data.append((data[0][start:end], data[1][start:end]))

    return client_data

client_count = 10
client_data = split_data_for_clients((x_train, y_train), client_count)

# Create a federated dataset from the client data
federated_data = [
    tff.simulation.ClientData.from_clients_and_fn(
        client_ids=[str(i)], create_tf_dataset_for_client_fn=lambda _: tf.data.Dataset.from_tensor_slices(client_data[i]).batch(20)
    ).create_tf_dataset_for_client(str(i))
    for i in range(client_count)
]


In [16]:
import nest_asyncio
nest_asyncio.apply()

# Train the federated model
trainer = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
)

# Initialize the state
state = trainer.initialize()

# Train the model for multiple rounds
num_rounds = 10
for round_num in range(1, num_rounds + 1):
    print(f"Round {round_num}")
    state, metrics = trainer.next(state, federated_data)
    print(f"Metrics: {metrics}")



Round 1
Metrics: OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.15082), ('loss', 2.266358)])), ('stat', OrderedDict([('num_examples', 50000)]))])
Round 2
Metrics: OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.2222), ('loss', 2.1104443)])), ('stat', OrderedDict([('num_examples', 50000)]))])
Round 3
Metrics: OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.27026), ('loss', 2.0038571)])), ('stat', OrderedDict([('num_examples', 50000)]))])
Round 4
Metrics: OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.3078), ('loss', 1.9133645)])), ('stat

In [18]:

# Convert the test dataset into a list of datasets
client_test_data = split_data_for_clients((x_test, y_test), client_count)

# Create a federated dataset from the client test data
federated_test_data = [
    tf.data.Dataset.from_tensor_slices(client_test_data[i]).batch(len(client_test_data[i][0]))
    for i in range(client_count)
]

# Evaluate the trained model on the federated test dataset
tff_evaluator = tff.learning.build_federated_evaluation(model_fn)
test_metrics = tff_evaluator(state.model, federated_test_data)
print(f"Test Metrics: {test_metrics}")

Test Metrics: OrderedDict([('sparse_categorical_accuracy', 0.4811), ('loss', 1.4383234)])
