In [8]:
import tensorflow as tf
import tensorflow_federated as tff

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

In [9]:
train, test = tff.simulation.datasets.emnist.load_data()

In [10]:
def client_data(n:int) -> tf.data.Dataset:
    return train.create_tf_dataset_for_client(train.client_ids[n]).map(
        lambda e: (tf.reshape(e['pixels'], [-1]), e['label'])).repeat(10).batch(20)
train_data = [client_data(i) for i in range(10)]

In [11]:
def model_fn() -> tff.learning.Model:
    model = tf.keras.models.Sequential([
        tf.keras.layers.Dense(64, activation='softmax', input_shape=(784,),kernel_initializer='zeros')
    ])
    return tff.learning.from_keras_model(model,
                                         input_spec=train_data[0].element_spec,
                                         loss=tf.keras.losses.SparseCategoricalCrossentropy(),
                                         metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

In [12]:
trainer = tff.learning.algorithms.build_weighted_fed_avg(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.1))

In [13]:
state = trainer.initialize()

In [14]:
for i in range(10):
    result = trainer.next(state, train_data)
    state = result.state
    metrics = result.metrics
    print(f"Client {i} Loss: {metrics['client_work']['train']['loss']}")

Client 0 Loss: 14.126679420471191
Client 1 Loss: 14.093573570251465
Client 2 Loss: 14.177961349487305
Client 3 Loss: 14.177962303161621
Client 4 Loss: 14.177962303161621
Client 5 Loss: 14.177961349487305
Client 6 Loss: 14.177961349487305
Client 7 Loss: 14.177961349487305
Client 8 Loss: 14.177961349487305
Client 9 Loss: 14.177962303161621
