In [1]:
import tensorflow as tf
import tensorflow_federated as tff

In [28]:
train, test = tff.simulation.datasets.emnist.load_data()

In [29]:
def client_data(n:int) -> tf.data.Dataset:
    return train.create_tf_dataset_for_client(train.client_ids[n]).map(
        lambda e: (tf.reshape(e['pixels'], [-1]), e['label'])).repeat(10).batch(20)
train_data = [client_data(i) for i in range(10)]

In [30]:
def model_fn() -> tff.learning.Model:
    model = tf.keras.models.Sequential([
        tf.keras.layers.Dense(64, activation='softmax', input_shape=(784,),kernel_initializer='zeros')
    ])
    return tff.learning.from_keras_model(model,
                                         input_spec=train_data[0].element_spec,
                                         loss=tf.keras.losses.SparseCategoricalCrossentropy(),
                                         metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

In [31]:
trainer = tff.learning.algorithms.build_weighted_fed_avg(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.1))

In [32]:
state = trainer.initialize()

In [38]:
for i in range(10):
    result = trainer.next(state, train_data)
    state = result.state
    metrics = result.metrics
    print(f"Client {i} Loss: {metrics['client_work']['train']['loss']}")

Client 0 Accuracy: 14.177961349487305
Client 1 Accuracy: 14.177961349487305
Client 2 Accuracy: 14.177962303161621
Client 3 Accuracy: 14.177961349487305
Client 4 Accuracy: 14.177962303161621
Client 5 Accuracy: 14.177962303161621
Client 6 Accuracy: 14.177961349487305
Client 7 Accuracy: 14.177961349487305
Client 8 Accuracy: 14.177961349487305
Client 9 Accuracy: 14.177962303161621
