# Federated Learning for Image Classification

In this tutorial, we use the classic MNIST training example to introduce the
Federated Learning (FL) component of TensorFlow Federated (TFF), a set of
higher-level interfaces that can be used to perform common types of federated
learning tasks, such as federated training, against user-supplied models
implemented in TensorFlow.

This tutorial, and the Federated Learning API, are intended primarily for users
who want to plug their own TensorFlow models into TFF, treating the latter
mostly as a black box. For a more in-depth understanding of TFF and how to
implement your own federated learning algorithms, consider reviewing the
follow-up tutorial on lower-level interfaces,
[Custom Federated Algorithms with the Federated Core API](custom_federated_algorithms.ipynb).

## Before we start

Before we start, please run the following to make sure that your environment is
correctly set up. If you don't see a greeting, please refer to the
[Installation](../install.md) guide for instructions.

In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import numpy as np
import tensorflow as tf

from tensorflow_federated import python as tff

nest = tf.contrib.framework.nest

np.random.seed(0)

tf.enable_eager_execution()
tf.enable_resource_variables()
tf.compat.v1.enable_v2_behavior()

tff.federated_computation(lambda: 'Hello, World!')()

'Hello, World!'

## Preparing the input data

Let's start with the data. Federated Learning requires a federated data set,
i.e., a collection of data from multiple users. Federated data is typically
non-[i.i.d.](https://en.wikipedia.org/wiki/Independent_and_identically_distributed_random_variables),
which poses a unique set of challenges. To illustrate this, we're going to
manually craft a scenario with 10 users, each of whom contributes data on how
to recognize a different digit, and we'll use the federated learning framework
to learn a combined model.
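
To make the non-i.i.d. setup concrete, here is a small stdlib-only sketch (not
TFF code) that partitions a labeled data set so that each hypothetical client
holds examples of exactly one label, mirroring the one-digit-per-user scenario
described above. The toy label list and the `partition_by_label` helper are
illustrative assumptions.

```python
from collections import Counter, defaultdict


def partition_by_label(labels):
    """Group example indices by label, giving one client per distinct label.

    Returns a dict mapping a client id (the label itself) to the list of
    indices of the examples that carry that label.
    """
    clients = defaultdict(list)
    for index, label in enumerate(labels):
        clients[label].append(index)
    return dict(clients)


# A toy label sequence standing in for MNIST digits (hypothetical data).
labels = [3, 1, 3, 0, 1, 3, 0, 2]
clients = partition_by_label(labels)

# Each client's label histogram contains a single class: maximally non-i.i.d.
for client_id, indices in sorted(clients.items()):
    print(client_id, Counter(labels[i] for i in indices))
```

A model trained on any single client here would only ever see one class, which
is why combining updates across clients matters.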

Let's start by loading the standard MNIST data from `tf.keras.datasets`.

In [0]:
#@test {"output": "ignore"}
mnist_train, _ = tf.keras.datasets.mnist.load_data()

The standard MNIST data comes as a pair of NumPy arrays: the first holds 28x28
pixel images, and the second holds the integer labels of the corresponding
digits. Both have a leading dimension that indexes the 60,000 training examples.

In [3]:
[(x.dtype, x.shape) for x in mnist_train]

[(dtype('uint8'), (60000, 28, 28)), (dtype('uint8'), (60000,))]

Here's a helper method that constructs a portion of the training data for a
given digit, in batches of a given size, with images flattened into vectors and
pixel values scaled to [0, 1] to simplify the model code, and with the features
renamed to `x` and `y` as expected by Keras.

In [0]:
NUM_EXAMPLES_PER_DIGIT = 1000
BATCH_SIZE = 100

def get_training_data_for_digit(digit):
  images, digits = mnist_train
  output_sequence = []
  all_samples = [i for i, d in enumerate(digits) if d == digit]
  for i in range(0, min(len(all_samples), NUM_EXAMPLES_PER_DIGIT), BATCH_SIZE):
    batch_samples = all_samples[i:i + BATCH_SIZE]
    output_sequence.append(collections.OrderedDict([
        ('x', np.array([images[j].flatten() / 255.0 for j in batch_samples],
                       dtype=np.float32)),
        ('y', np.array([digits[j] for j in batch_samples], dtype=np.int32))
    ]))
  return output_sequence
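
The per-image and batching steps of `get_training_data_for_digit` can be
mirrored without NumPy, which may help when reasoning about shapes. In this
plain-Python sketch, the helper names and the 5,000-example toy client are
hypothetical; the slicing logic matches the loop above, so a client with at
least 1,000 examples yields 10 batches of 100.

```python
def flatten_and_normalize(image):
    """Flatten a 2-D grid of 0-255 pixel values row by row and scale to [0, 1]."""
    return [pixel / 255.0 for row in image for pixel in row]


def batch_indices(indices, batch_size, max_examples):
    """Split up to max_examples indices into consecutive batches of batch_size."""
    limit = min(len(indices), max_examples)
    return [indices[i:i + batch_size] for i in range(0, limit, batch_size)]


# A blank 28x28 "image" with one saturated pixel (hypothetical data).
image = [[0] * 28 for _ in range(28)]
image[0][0] = 255
vector = flatten_and_normalize(image)
print(len(vector), vector[0])

# A client with 5,000 examples, capped at 1,000 and split into batches of 100.
batches = batch_indices(list(range(5000)), batch_size=100, max_examples=1000)
print(len(batches), [len(b) for b in batches])
```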

One way to feed federated data to TFF is as a simple Python list with one
element per user, where each element is that user's data set (here, a list of
batches). We can construct such a representation for our 10 users as follows.

In [0]:
federated_train_data = [get_training_data_for_digit(d) for d in range(10)]

As a quick sanity check, let's grab the first batch of data from the first user
(we're going to need a sample data batch for use with Keras anyway), and check
the dimensions to make sure the data looks as expected.

In [0]:
sample_batch = federated_train_data[0][0]

In [7]:
nest.map_structure(lambda x: (x.dtype, x.shape), sample_batch)

OrderedDict([('x', (dtype('float32'), (100, 784))), ('y', (dtype('int32'), (100,)))])

## Creating a model with Keras

If you are using Keras, you likely already have code that constructs a Keras
model. Here's an example based on the simple Keras model posted on
[the tensorflow.org tutorials page](https://www.tensorflow.org/tutorials); the
only modification we make is removing the dropout layer. Dropout is not needed
here because we will be combining models across users, and the process of
federated averaging in and of itself will help to avoid overfitting.

In [0]:
def create_compiled_keras_model():
  model = tf.keras.models.Sequential([
      tf.keras.layers.Flatten(input_shape=(784,)),
      tf.keras.layers.Dense(512, activation=tf.nn.relu),
      tf.keras.layers.Dense(10, activation=tf.nn.softmax)
  ])
  model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      optimizer=tf.keras.optimizers.SGD(0.1),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
  return model
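
As a quick arithmetic check on the architecture above, each of the two `Dense`
layers holds a kernel and a bias, while `Flatten` holds no parameters, so the
number of trainable parameters can be computed by hand. The plain-Python sketch
below (the helper name is hypothetical) does that count using the layer sizes
from `create_compiled_keras_model`; these are the same shapes that appear as
`dense` and `dense_1` in the server-state type signature printed later.

```python
def dense_param_count(fan_in, fan_out):
    """Parameters of a fully connected layer: a fan_in x fan_out kernel plus a bias."""
    return fan_in * fan_out + fan_out


# Input width, hidden width, and number of classes, as in the Keras model above.
layer_sizes = [784, 512, 10]
total = sum(dense_param_count(a, b) for a, b in zip(layer_sizes, layer_sizes[1:]))
print(total)  # 407050
```

Every one of these parameters is shipped to each client and averaged back at
the server in every round, which is worth keeping in mind when sizing models for
federated training.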

In order to use any model with TFF, it needs to be wrapped in an instance of the
`tff.learning.Model` interface, which exposes methods to stamp the model's
forward pass, metadata properties, etc., similarly to Keras, but also introduces
additional elements, such as ways to control the process of computing federated
metrics. Let's not worry about this for now; if you have a compiled Keras model
like the one we've just defined above, you can have TFF wrap it for you by
invoking `tff.learning.from_compiled_keras_model`, passing the model and a
sample data batch as arguments, as shown below.

In [0]:
def model_fn():
  keras_model = create_compiled_keras_model()
  return tff.learning.from_compiled_keras_model(keras_model, sample_batch)

## Training the model on federated data

Now that we have a model wrapped as `tff.learning.Model` for use with TFF, we
can let TFF construct a federated averaging algorithm by invoking the helper
function `tff.learning.build_federated_averaging_process`, as follows.

Keep in mind that the argument needs to be a constructor (such as `model_fn`
above), not the already-constructed instance, so that the construction of your
model can happen in a context controlled by TFF (if you're curious about the
reasons for this, we encourage you to read the follow-up tutorial on
[custom algorithms](custom_federated_algorithms.ipynb)).

In [0]:
iterative_process = tff.learning.build_federated_averaging_process(model_fn)

What just happened? TFF has constructed a pair of *federated computations* and
packaged them into `tff.utils.IterativeProcess`, a standardized iterator-like
structure in which these computations are available as a pair of properties,
`initialize` and `next`.

In a nutshell, *federated computations* are programs in TFF's internal language
that can express various federated algorithms (you can find more about this in
the [custom algorithms](custom_federated_algorithms.ipynb) tutorial). In this
case, the two computations generated and packed into `iterative_process`
implement [federated model averaging](https://arxiv.org/abs/1602.05629).
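
At its core, the server-side step of federated averaging, as described in the
linked paper, replaces the global model with an average of the client models,
weighted by the number of examples each client trained on. The following is a
minimal pure-Python sketch of that weighted average over flat parameter lists;
it is not how TFF implements the step, and the function name and toy numbers
are illustrative assumptions.

```python
def weighted_average(client_weights, client_num_examples):
    """Average parameter vectors, weighting each client by its example count.

    client_weights: list of equal-length lists of floats, one per client.
    client_num_examples: list of positive example counts, one per client.
    """
    total = sum(client_num_examples)
    averaged = [0.0] * len(client_weights[0])
    for weights, n in zip(client_weights, client_num_examples):
        for j, w in enumerate(weights):
            averaged[j] += w * n / total
    return averaged


# Two hypothetical clients: the second trained on three times as much data.
print(weighted_average([[1.0, 0.0], [5.0, 4.0]], [100, 300]))
```

With equal example counts this reduces to a plain mean; with unequal counts the
larger client pulls the average toward its own parameters.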

In one of the upcoming releases of the framework, we'll enable you to deploy
such computations for execution in real environments, such as on groups of
`Android` devices. In this tutorial, we'll execute federated computations in a
simple interpreted simulation environment, inside this notebook. To execute a
computation in the simulator, you simply invoke it like a Python function, as
we will demonstrate shortly. This default interpreted environment is not
designed for high performance, but it will suffice for this tutorial.

Let's start with the `initialize` computation. As is the case for all federated
computations, you can think of it as a function. The computation takes no
arguments, and returns one result - the representation of the state of the
federated averaging process on the server. While we don't want to dive into the
details of TFF, it may be instructive to see what this state looks like. You can
visualize it as follows.

In [11]:
str(iterative_process.initialize.type_signature)

'( -> <model=<trainable=<dense/bias=float32[512],dense/kernel=float32[784,512],dense_1/bias=float32[10],dense_1/kernel=float32[512,10]>,non_trainable=<>>,optimizer_state=<int64>>@SERVER)'

While the above type signature may at first seem a bit cryptic, you can
recognize that the server state consists of a `model` (the initial model
parameters for MNIST that will be distributed to all devices), and
`optimizer_state` (additional information maintained by the server, such as the
number of rounds to use for hyperparameter schedules, etc.).

Let's invoke the `initialize` computation to construct the server state.

In [0]:
state = iterative_process.initialize()

The second of the pair of federated computations, `next`, represents a single
round of federated averaging, which consists of pushing the server state
(including the model parameters) to the clients, on-device training on their
local data, collecting and averaging model updates, and producing a new updated
model at the server.

Conceptually, you can think of `next` as having a functional type signature
`SERVER_STATE, FEDERATED_DATA -> SERVER_STATE, TRAINING_METRICS`.
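
The `initialize`/`next` contract can be illustrated without TFF. The sketch below
is a hypothetical, single-process stand-in: `ToyIterativeProcess`, the scalar
"model", and its update rule are invented for illustration. Only the shape of
the interface (state and client data in, new state and metrics out) mirrors the
signature above; the server step reuses the example-weighted averaging idea of
federated averaging.

```python
import collections

ServerState = collections.namedtuple('ServerState', ['model', 'round_num'])


class ToyIterativeProcess:
    """Hypothetical stand-in for an iterative process over a scalar model.

    Each client holds a list of floats; the "model" is one float w, and local
    training takes a single gradient step on the client's mean squared error.
    """

    def __init__(self, learning_rate=0.1):
        self.learning_rate = learning_rate

    def initialize(self):
        return ServerState(model=0.0, round_num=0)

    def next(self, state, federated_data):
        total, weighted_sum, loss_sum = 0, 0.0, 0.0
        for client_data in federated_data:
            n = len(client_data)
            # Squared error of the broadcast model on this client's data.
            loss_sum += sum((state.model - x) ** 2 for x in client_data)
            # Gradient of the client's mean squared error with respect to w.
            grad = sum(2 * (state.model - x) for x in client_data) / n
            local_model = state.model - self.learning_rate * grad
            # Weight each client's updated model by its number of examples.
            weighted_sum += local_model * n
            total += n
        new_state = ServerState(model=weighted_sum / total,
                                round_num=state.round_num + 1)
        metrics = {'loss': loss_sum / total, 'num_examples': total}
        return new_state, metrics


process = ToyIterativeProcess()
state = process.initialize()
data = [[1.0, 3.0], [6.0]]  # two hypothetical clients
for _ in range(3):
    state, metrics = process.next(state, data)
    print(state.round_num, state.model, metrics)
```

Note that the metrics returned by each call describe the model that was
broadcast at the start of that round, not the newly averaged one.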

Let's run a single round of training and visualize the results.

In [0]:
#@test {"timeout": 600}
state, metrics = iterative_process.next(state, federated_train_data)

In [14]:
#@test {"output": "ignore"}
str(metrics)

'<sparse_categorical_accuracy=0.9138,loss=0.271792>'

Let's run a few more rounds, just to confirm that the loss decreases.

In [15]:
#@test {"skip": true}
for _ in range(5):
  state, metrics = iterative_process.next(state, federated_train_data)
  print(metrics)

<sparse_categorical_accuracy=0.9259,loss=0.244756>
<sparse_categorical_accuracy=0.9488,loss=0.223349>
<sparse_categorical_accuracy=0.9616,loss=0.205643>
<sparse_categorical_accuracy=0.9683,loss=0.190492>
<sparse_categorical_accuracy=0.9704,loss=0.177308>


In a typical federated training scenario, we would be dealing with a potentially
very large population of user devices, only a fraction of which would be
available for training at any given point in time. To simulate this, you could
split the data into smaller portions, and then pick a random subset of elements
from within `federated_train_data` to simulate a fraction of users participating
in each round of training. We leave it as an exercise to the reader to try
making this simple change and to assess the impact of parameters, such as the
number of users chosen per round, on the duration of each training round and
the rate of convergence of the model at the server.

## Customizing the model implementation

While handing your Keras model to `tff.learning.from_keras_model` or
`tff.learning.from_compiled_keras_model` and letting TFF automatically wrap it
for use in federated learning may be a good place to start, in many cases you
will want to have more explicit control over the process and customize it for
your scenario, so let's do it all over again from scratch.

### Defining model variables, forward pass, and metrics

The first step is to identify the TensorFlow variables we're going to work with.
In order to make the following code more legible, let's define a data structure
to represent the entire set. This will include variables such as `weights` and
`bias` that we will train, as well as variables that will hold various
cumulative statistics and counters we will update during training, such as
`loss_sum`, `accuracy_sum`, and `num_examples`.

In [0]:
MnistVariables = collections.namedtuple(
    'MnistVariables', 'weights bias num_examples loss_sum accuracy_sum')

Here's a method that creates the variables. For the sake of simplicity, we
represent all statistics as `tf.float32`, as that will eliminate the need for
type conversions at a later stage. Wrapping variable initializers as lambdas is
a requirement imposed by
[resource variables](https://www.tensorflow.org/api_docs/python/tf/enable_resource_variables).

In [0]:
def create_mnist_variables():
  return MnistVariables(
      weights=tf.Variable(
          lambda: tf.zeros(dtype=tf.float32, shape=(784, 10)),
          name='weights',
          trainable=True),
      bias=tf.Variable(
          lambda: tf.zeros(dtype=tf.float32, shape=(10,)),
          name='bias',
          trainable=True),
      num_examples=tf.Variable(0.0, name='num_examples', trainable=False),
      loss_sum=tf.Variable(0.0, name='loss_sum', trainable=False),
      accuracy_sum=tf.Variable(0.0, name='accuracy_sum', trainable=False))

With the variables for model parameters and cumulative statistics in place, we
can now define the forward pass method that computes loss, emits predictions,
and updates the cumulative statistics for a single batch of input data, as
follows.

In [0]:
def mnist_forward_pass(variables, batch):
  y = tf.nn.softmax(tf.matmul(batch['x'], variables.weights) + variables.bias)
  predictions = tf.cast(tf.argmax(y, 1), tf.int32)

  loss = -tf.reduce_mean(tf.reduce_sum(
      tf.one_hot(batch['y'], 10) * tf.log(y), reduction_indices=[1]))
  accuracy = tf.reduce_mean(
      tf.cast(tf.equal(predictions, batch['y']), tf.float32))

  num_examples = tf.to_float(tf.size(batch['y']))

  tf.assign_add(variables.num_examples, num_examples)
  tf.assign_add(variables.loss_sum, loss * num_examples)
  tf.assign_add(variables.accuracy_sum, accuracy * num_examples)

  return loss, predictions
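
To make the arithmetic in `mnist_forward_pass` concrete, here is a plain-Python
sketch (no TensorFlow; the helper names are hypothetical) of the same per-batch
computation on raw logits: softmax probabilities, mean cross-entropy against the
true labels, accuracy, and the running sums. One deliberate difference is that
the sketch subtracts the largest logit before exponentiating, a standard
numerical-stability trick. Multiplying each batch mean by the batch size before
accumulating is what lets `loss_sum / num_examples` recover the per-example
average across batches of different sizes.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def forward_pass_stats(batch_logits, batch_labels, stats):
    """Return (mean cross-entropy, accuracy) for one batch and update the
    running sums in stats in place, weighting the batch means by batch size."""
    losses, correct = [], 0
    for logits, label in zip(batch_logits, batch_labels):
        probs = softmax(logits)
        # The one-hot label selects the log-probability of the true class.
        losses.append(-math.log(probs[label]))
        prediction = max(range(len(probs)), key=probs.__getitem__)
        correct += int(prediction == label)
    n = len(batch_labels)
    loss, accuracy = sum(losses) / n, correct / n
    stats['num_examples'] += n
    stats['loss_sum'] += loss * n
    stats['accuracy_sum'] += accuracy * n
    return loss, accuracy


stats = {'num_examples': 0.0, 'loss_sum': 0.0, 'accuracy_sum': 0.0}
forward_pass_stats([[2.0, 0.0], [0.0, 0.0]], [0, 1], stats)  # batch of two
forward_pass_stats([[0.0, 3.0]], [1], stats)                 # batch of one
print(stats['loss_sum'] / stats['num_examples'],
      stats['accuracy_sum'] / stats['num_examples'])
```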

With the single-batch forward pass defined, TFF has enough information to put
together a training loop that processes the on-device data for a single user.
However, we may want to control how the cumulative statistics accumulated during
training are translated into the set of metrics exported by the device, so let's
write another helper function to perform this translation. We're just going to
export the average `loss` and `accuracy`, as well as `num_examples`; we'll need
the latter to correctly weight the contributions from different users when
computing the global statistics at the server.

In [0]:
def get_local_mnist_metrics(variables):
  return collections.OrderedDict([
      ('num_examples', variables.num_examples),
      ('loss', variables.loss_sum / variables.num_examples),
      ('accuracy', variables.accuracy_sum / variables.num_examples)
    ])

Finally, we need to determine how to aggregate the local metrics emitted by each
device. This is the only part of the code that isn't written in pure TensorFlow
and Python: it's a *federated computation* expressed in TFF. If you'd like to
dig deeper, skim over the [custom algorithms](custom_federated_algorithms.ipynb)
tutorial, but in most applications you won't really need to; variants of the
pattern shown below should suffice. Simply apply `tff.federated_sum` to metrics
you want to sum and `tff.federated_average` to those you want to average, return
a dictionary of what you wish to report globally, and decorate your function
with `tff.federated_computation`.

In [0]:
@tff.federated_computation
def aggregate_local_mnist_metrics(metrics):
  return {
      'num_examples': tff.federated_sum(metrics.num_examples),
      'loss': tff.federated_average(metrics.loss, metrics.num_examples),
      'accuracy': tff.federated_average(metrics.accuracy, metrics.num_examples)
  }
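
The effect of passing `metrics.num_examples` as the weight to
`tff.federated_average` can be seen with plain Python. The sketch below is not a
federated computation; it is a hypothetical helper that applies the same
arithmetic to a list of per-client metric dictionaries. With the two made-up
clients shown, an unweighted mean of accuracies would be 0.7, whereas weighting
by example counts gives 0.86, because the larger client dominates.

```python
def aggregate_metrics(client_metrics):
    """Combine per-client metric dicts with the same arithmetic as
    aggregate_local_mnist_metrics: sum num_examples, and average loss and
    accuracy weighted by num_examples."""
    total = sum(m['num_examples'] for m in client_metrics)
    return {
        'num_examples': total,
        'loss': sum(m['loss'] * m['num_examples'] for m in client_metrics) / total,
        'accuracy': sum(m['accuracy'] * m['num_examples']
                        for m in client_metrics) / total,
    }


# Two hypothetical clients with very different amounts of data.
clients = [
    {'num_examples': 900.0, 'loss': 0.2, 'accuracy': 0.9},
    {'num_examples': 100.0, 'loss': 1.0, 'accuracy': 0.5},
]
print(aggregate_metrics(clients))
```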

### Constructing an instance of `tff.learning.Model`

With all of the above in place, we are ready to construct a model representation
for use with TFF, similar to the one that's generated for you when you let TFF
ingest a Keras model.

In [0]:
class MnistModel(tff.learning.Model):

  def __init__(self):
    self._variables = create_mnist_variables()

  @property
  def trainable_variables(self):
    return [self._variables.weights, self._variables.bias]

  @property
  def non_trainable_variables(self):
    return []

  @property
  def local_variables(self):
    return [
        self._variables.num_examples, self._variables.loss_sum,
        self._variables.accuracy_sum
    ]

  @property
  def input_spec(self):
    return collections.OrderedDict([('x', tf.TensorSpec([None, 784],
                                                        tf.float32)),
                                    ('y', tf.TensorSpec([None], tf.int32))])

  @tf.contrib.eager.function(autograph=False)
  def forward_pass(self, batch, training=True):
    del training
    loss, predictions = mnist_forward_pass(self._variables, batch)
    return tff.learning.BatchOutput(loss=loss, predictions=predictions)

  @tf.contrib.eager.function(autograph=False)
  def report_local_outputs(self):
    return get_local_mnist_metrics(self._variables)

  @property
  def federated_output_computation(self):
    return aggregate_local_mnist_metrics

As you can see, the abstract methods and properties defined by
`tff.learning.Model` correspond closely to the code snippets in the preceding
section that introduce the variables and define the loss and statistics. Here
are a few points worth highlighting:

*   All state that your model will use must be captured as TensorFlow variables,
    as TFF does not use Python at runtime (remember your code should be written
    such that it can be deployed to mobile devices). The variables should always
    be created in the constructor.
*   Your model should describe what form of data it accepts (`input_spec`), as
    in general, TFF is a strongly-typed environment and wants to determine type
    signatures for all components. Declaring the format of your model's input is
    an essential part of it.
*   Although technically not required, we recommend wrapping all TensorFlow
    logic (forward pass, metric calculations) as `tf.contrib.eager.function`s.

While the above is technically sufficient, and TFF can take over from here,
there's one more optional part you may want to customize, and that's defining
the optimizer and how it's applied to the data to perform a single step of
training. While TFF can do it for you, it's common to take control over this
step as well.

In [0]:
class MnistTrainableModel(MnistModel, tff.learning.TrainableModel):

  @tf.contrib.eager.defun(autograph=False)
  def train_on_batch(self, batch):
    output = self.forward_pass(batch)
    optimizer = tf.train.GradientDescentOptimizer(0.1)
    optimizer.minimize(output.loss, var_list=self.trainable_variables)
    return output
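
`train_on_batch` delegates the update to `tf.train.GradientDescentOptimizer(0.1)`,
which moves each trainable variable against its gradient:
`variable -= 0.1 * gradient`. The following framework-free sketch shows that same
update rule on a one-parameter least-squares problem; the toy data, which follow
`y = 2x`, and the helper names are assumptions made for illustration only.

```python
def sgd_step(params, grads, learning_rate=0.1):
    """One plain gradient-descent update: p <- p - learning_rate * g."""
    return [p - learning_rate * g for p, g in zip(params, grads)]


def squared_loss_and_grad(w, xs, ys):
    """Mean squared error of the prediction w * x and its derivative in w."""
    n = len(xs)
    loss = sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / n
    grad = sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / n
    return loss, grad


xs, ys = [1.0, 2.0, 3.0], [2.0, 4.0, 6.0]  # hypothetical data with y = 2x
params = [0.0]
for step in range(5):
    loss, grad = squared_loss_and_grad(params[0], xs, ys)
    params = sgd_step(params, [grad])
    print(step, loss, params[0])
```

In the federated setting, each client runs steps like these on its own batches,
and only the resulting parameters (not the raw data) are sent back for averaging.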

### Simulating federated training with the new model

With all the above in place, the remainder of the process looks like what we've
seen already - just replace the model constructor with the constructor of our
new model class, as follows.

In [0]:
iterative_process = tff.learning.build_federated_averaging_process(
    MnistTrainableModel)

In [0]:
state = iterative_process.initialize()

In [0]:
#@test {"timeout": 600}
state, metrics = iterative_process.next(state, federated_train_data)

In [26]:
#@test {"output": "ignore"}
str(metrics)

'<accuracy=0.91,loss=0.293638,num_examples=10000.0>'

In [27]:
#@test {"skip": true}
for _ in range(5):
  state, metrics = iterative_process.next(state, federated_train_data)
  print(metrics)

<accuracy=0.976,loss=0.277596,num_examples=10000.0>
<accuracy=0.9786,loss=0.263369,num_examples=10000.0>
<accuracy=0.9796,loss=0.25075,num_examples=10000.0>
<accuracy=0.98,loss=0.239585,num_examples=10000.0>
<accuracy=0.9799,loss=0.229727,num_examples=10000.0>


This concludes our tutorial.
