In [6]:
import collections

import numpy as np
import tensorflow as tf
# tf.enable_eager_execution() # this is necessary in audi8
import tensorflow_federated as tff
import nest_asyncio # this is necessary in audi 9 
# The original note says this is required on the 'audi 9' machine; presumably
# because the Jupyter kernel already runs an asyncio event loop that TFF's
# runtime needs to nest into.
nest_asyncio.apply()

# Seed every RNG source the notebook relies on. NumPy's seed does NOT affect
# TensorFlow, and the unseeded `tf.data.Dataset.shuffle` in `preprocess`
# draws its seed from TensorFlow's global seed, so set that one as well.
np.random.seed(0)
tf.random.set_seed(0)

# Smoke test that the TFF runtime works. See the error message reported at
# https://github.com/tensorflow/federated/issues/842
tff.federated_computation(lambda: 'Hello, World!')()

b'Hello, World!'

In [7]:
# Load the federated EMNIST train/test splits. Each split exposes per-client
# datasets via `client_ids` and `create_tf_dataset_for_client` (used below).
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

In [8]:
# Raw dataset of the first client, plus one raw element from it for inspecting
# the unprocessed feature structure.
example_dataset = emnist_train.create_tf_dataset_for_client(emnist_train.client_ids[0])
example_element = next(iter(example_dataset))

In [9]:
# Sanity check that TensorFlow executes eagerly; the `.numpy()` conversions
# used to inspect dataset batches below depend on it.
tf.executing_eagerly() 

True

In [10]:
NUM_CLIENTS = 10
NUM_EPOCHS = 5
BATCH_SIZE = 20
SHUFFLE_BUFFER = 100
PREFETCH_BUFFER = 10

def preprocess(dataset):
  """Convert one client's raw dataset into shuffled, flattened training batches.

  The pipeline applies these steps in order, using the module-level constants:
  repeat NUM_EPOCHS times, shuffle with a SHUFFLE_BUFFER-element buffer,
  batch BATCH_SIZE elements, reformat each batch into `x`/`y`, and prefetch
  PREFETCH_BUFFER batches.

  Args:
    dataset: a client's dataset whose elements have 'pixels' and 'label'.

  Returns:
    The transformed dataset, yielding `OrderedDict(x=..., y=...)` batches.
  """

  def _to_xy(raw_batch):
    """Return the batch as `OrderedDict(x=flat pixels, y=label column)`."""
    flat_pixels = tf.reshape(raw_batch['pixels'], [-1, 784])
    label_column = tf.reshape(raw_batch['label'], [-1, 1])
    return collections.OrderedDict(x=flat_pixels, y=label_column)

  pipeline = dataset.repeat(NUM_EPOCHS)
  pipeline = pipeline.shuffle(SHUFFLE_BUFFER)
  pipeline = pipeline.batch(BATCH_SIZE)
  pipeline = pipeline.map(_to_xy)
  return pipeline.prefetch(PREFETCH_BUFFER)

In [11]:
# Run the example client through the pipeline. Its `element_spec` is reused
# later as the model's `input_spec` in `model_fn`.
preprocessed_example_dataset = preprocess(example_dataset)

# Take the first preprocessed batch and convert each tensor in it to a NumPy
# array for display.
sample_batch = tf.nest.map_structure(lambda x: x.numpy(),
                                     next(iter(preprocessed_example_dataset)))

sample_batch

OrderedDict([('x',
              array([[1., 1., 1., ..., 1., 1., 1.],
                     [1., 1., 1., ..., 1., 1., 1.],
                     [1., 1., 1., ..., 1., 1., 1.],
                     ...,
                     [1., 1., 1., ..., 1., 1., 1.],
                     [1., 1., 1., ..., 1., 1., 1.],
                     [1., 1., 1., ..., 1., 1., 1.]], dtype=float32)),
             ('y',
              array([[7],
                     [5],
                     [2],
                     [1],
                     [5],
                     [5],
                     [1],
                     [4],
                     [9],
                     [7],
                     [0],
                     [8],
                     [0],
                     [1],
                     [7],
                     [9],
                     [6],
                     [9],
                     [9],
                     [3]]))])

In [12]:
def make_federated_data(client_data, client_ids):
  """Build one preprocessed dataset per requested client.

  Args:
    client_data: object providing `create_tf_dataset_for_client(client_id)`.
    client_ids: iterable of client ids to include.

  Returns:
    A list with `preprocess(...)` applied to each client's dataset, in the
    same order as `client_ids`.
  """
  per_client = []
  for client_id in client_ids:
    raw_dataset = client_data.create_tf_dataset_for_client(client_id)
    per_client.append(preprocess(raw_dataset))
  return per_client

- In a simulation, we would normally sample a random subset of clients in every round, and re-sampling like this can **slow down convergence**. Here we sample the set of clients only once (the first `NUM_CLIENTS` clients) and reuse that same set across rounds to speed up convergence (intentionally over-fitting to these few users' data).

In [13]:
sample_clients = emnist_train.client_ids[0:NUM_CLIENTS] # the first NUM_CLIENTS clients (fixed, not random)

federated_train_data = make_federated_data(emnist_train, sample_clients) # one preprocessed dataset per sampled client

print('Number of client datasets: {l}'.format(l=len(federated_train_data)))
print('First dataset: {d}'.format(d=federated_train_data[0]))

Number of client datasets: 10
First dataset: <PrefetchDataset shapes: OrderedDict([(x, (None, 784)), (y, (None, 1))]), types: OrderedDict([(x, tf.float32), (y, tf.int32)])>


- First, we define a Keras model:

In [14]:
def create_keras_model():
  """Build a fresh softmax classifier over 784 flattened input features.

  The model is a single zero-initialized `Dense(10)` layer followed by a
  `Softmax` layer, so it outputs probabilities rather than logits.
  """
  layer_stack = [
      tf.keras.layers.Input(shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax(),
  ]
  return tf.keras.models.Sequential(layer_stack)

- Next, we need to wrap the Keras model in TFF's `tff.learning` model interface using `tff.learning.from_keras_model`:

In [15]:
def model_fn():
  """Return a new `tff.learning` model wrapping a freshly built Keras model.

  The wrapper uses the preprocessed example dataset's `element_spec` as its
  input spec, sparse categorical cross-entropy as its loss, and sparse
  categorical accuracy as its metric.
  """
  # Everything must be constructed inside this call, never captured from an
  # outer scope: TFF invokes this constructor within different graph contexts.
  fresh_model = create_keras_model()
  spec = preprocessed_example_dataset.element_spec
  loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
  metric_list = [tf.keras.metrics.SparseCategoricalAccuracy()]
  return tff.learning.from_keras_model(
      fresh_model, input_spec=spec, loss=loss_fn, metrics=metric_list)

- Next, we build the Federated Averaging training process with `tff.learning.build_federated_averaging_process`. **Keep in mind that its argument must be a model constructor (such as `model_fn` above), not an already-constructed model instance.**

In [16]:
# Build the Federated Averaging process. Like `model_fn`, both optimizer
# arguments are zero-argument constructors (lambdas), not optimizer instances:
# clients use SGD with learning rate 0.02 and the server uses SGD with
# learning rate 1.0.
iterative_process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

Instructions for updating:
If using Keras pass *_constraint arguments to layers.


Instructions for updating:
If using Keras pass *_constraint arguments to layers.
  self._parent_message_weakref = weakref.proxy(parent_message)
  self._parent_message_weakref = weakref.proxy(parent_message)


In [17]:
# Inspect `initialize`: it takes no arguments and returns the initial server
# state placed @SERVER.
str(iterative_process.initialize.type_signature)

'( -> <model=<trainable=<float32[784,10],float32[10]>,non_trainable=<>>,optimizer_state=<int64>,delta_aggregate_state=<>,model_broadcast_state=<>>@SERVER)'

In [18]:
# Create the initial server state (model weights, optimizer state, aggregation
# and broadcast state, as shown in the type signature above).
state = iterative_process.initialize()

In [19]:
# Conceptual signature of `iterative_process.next` (pseudo-notation, not Python):
#   SERVER_STATE, FEDERATED_DATA -> SERVER_STATE, TRAINING_METRICS
# The concrete TFF type signature can be inspected directly:
str(iterative_process.next.type_signature)

In [20]:
# Run one round of Federated Averaging on the sampled clients' data. `next`
# returns the updated server state and that round's metrics.
state, metrics = iterative_process.next(state, federated_train_data)
print('round 1, metrics = {}'.format(metrics))

round 1, metrics = OrderedDict([('broadcast', ()), ('aggregation', ()), ('train', OrderedDict([('sparse_categorical_accuracy', 0.11255144), ('loss', 3.0456898)]))])


In [21]:
# Continue training. Round 1 ran in the previous cell, so this loop covers
# rounds 2 .. NUM_ROUNDS - 1. NUM_ROUNDS is an exclusive upper bound, so a
# value of 11 means 10 rounds in total.
NUM_ROUNDS = 11
for round_num in range(2, NUM_ROUNDS):
    # Each round reuses the same fixed set of sampled clients.
    state, metrics = iterative_process.next(state, federated_train_data)
    print('round {:2d}, metrics={}'.format(round_num, metrics))

round  2, metrics=OrderedDict([('broadcast', ()), ('aggregation', ()), ('train', OrderedDict([('sparse_categorical_accuracy', 0.1382716), ('loss', 2.9017148)]))])
round  3, metrics=OrderedDict([('broadcast', ()), ('aggregation', ()), ('train', OrderedDict([('sparse_categorical_accuracy', 0.16111112), ('loss', 2.7337723)]))])
round  4, metrics=OrderedDict([('broadcast', ()), ('aggregation', ()), ('train', OrderedDict([('sparse_categorical_accuracy', 0.16296296), ('loss', 2.7093928)]))])
round  5, metrics=OrderedDict([('broadcast', ()), ('aggregation', ()), ('train', OrderedDict([('sparse_categorical_accuracy', 0.21090534), ('loss', 2.5592349)]))])
round  6, metrics=OrderedDict([('broadcast', ()), ('aggregation', ()), ('train', OrderedDict([('sparse_categorical_accuracy', 0.23806584), ('loss', 2.4227028)]))])
round  7, metrics=OrderedDict([('broadcast', ()), ('aggregation', ()), ('train', OrderedDict([('sparse_categorical_accuracy', 0.2516461), ('loss', 2.3178978)]))])
round  8, metrics=