This code comes from https://www.tensorflow.org/federated

In [1]:
import tensorflow as tf
import tensorflow_federated as tff

INFO:tensorflow:Using local port 18181
INFO:tensorflow:Using local port 15390
INFO:tensorflow:Using local port 17022
INFO:tensorflow:Using local port 16933
INFO:tensorflow:Using local port 15981
INFO:tensorflow:Using local port 18050
INFO:tensorflow:Using local port 23066
INFO:tensorflow:Using local port 15947
INFO:tensorflow:Using local port 16758
INFO:tensorflow:Using local port 22651


TensorFlow Addons offers no support for the nightly versions of TensorFlow. Some things might work, some other might not. 
If you encounter a bug, do not file an issue on GitHub.


In [11]:
tff.__version__

'0.16.1'

In [4]:
# Load simulation data.
source, _ = tff.simulation.datasets.emnist.load_data()

In [3]:
source

<tensorflow_federated.python.simulation.hdf5_client_data.HDF5ClientData at 0x211a81226a0>

In [8]:
# for the n-th client, return a matrix with number of images rows and 28*28 columns
def client_data(n):
  return source.create_tf_dataset_for_client(source.client_ids[n]).map(
      lambda e: (tf.reshape(e['pixels'], [-1]), e['label'])
  ).repeat(10).batch(20)

In [9]:
# Pick a subset of client devices to participate in training.
train_data = [client_data(n) for n in range(3)]

# Grab a single batch of data so that TFF knows what data looks like.
sample_batch = tf.nest.map_structure(
    lambda x: x.numpy(), iter(train_data[0]).next())

# Wrap a Keras model for use with TFF.
def model_fn():
  model = tf.keras.models.Sequential([
      tf.keras.layers.Dense(10, tf.nn.softmax, input_shape=(784,),
                            kernel_initializer='zeros')
  ])
  return tff.learning.from_keras_model(
      model,
      dummy_batch=sample_batch, # this is an exceptional argument in audi9
      # downgrade to tff  0.13.1 would work
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

# Simulate a few rounds of training with the selected client devices.
trainer = tff.learning.build_federated_averaging_process(
  model_fn,
  client_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.1))
state = trainer.initialize()
for _ in range(5):
  state, metrics = trainer.next(state, train_data)
  print (metrics.loss)

TypeError: from_keras_model() got an unexpected keyword argument 'dummy_batch'