# Federated Learning for Image Classification

In this tutorial, we introduce the Federated Learning (FL) API, a set of
higher-level interfaces that can be used to perform common types of federated
learning tasks, such as federated training and federated evaluation, against
user-supplied models implemented in TensorFlow (e.g., with Keras). We then
demonstrate how to perform simple federated learning experiments, using a
federated version of the MNIST data set as an example.

This tutorial, and the Federated Learning API, are intended primarily for users
who want to plug their own TensorFlow models into TFF, treating the latter
mostly as a black box. For a more in-depth understanding of TFF and how to
implement your own federated learning algorithms, consider also reviewing the
follow-up tutorial
[Custom Federated Algorithms with the Federated Core API](custom_federated_algorithms.ipynb).

## Before we start

Before we start, please run the following to make sure that your environment is
correctly set up. If you don't see a greeting, please refer to the
[Installation](../install.md) guide for instructions.

In [21]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import numpy as np
import tensorflow as tf

from tensorflow_federated import python as tff

# Short alias for TensorFlow's nested-structure utilities; used below to map a
# function over every tensor in a batch.
nest = tf.contrib.framework.nest

# Seed NumPy's global random number generator (this does not seed TensorFlow's
# own random ops).
np.random.seed(0)

# Switch TensorFlow to eager execution and resource variables.
# NOTE(review): eager execution is expected to be enabled before any graph
# construction, which is presumably why this runs right after the imports.
tf.enable_eager_execution()
tf.enable_resource_variables()

# Smoke test: a trivial federated computation that should print the greeting
# 'Hello, World!' when TFF is installed correctly.
tff.federated_computation(lambda: 'Hello, World!')()

'Hello, World!'

## Preparing the input data

In [0]:
#@test {"output": "ignore"}
# load_data() returns a (train, test) pair; only the training split, an
# (images, labels) tuple, is kept here.
mnist_train, _ = tf.keras.datasets.mnist.load_data()

In [23]:
# Show the dtype and shape of the image array and the label array.
[(x.dtype, x.shape) for x in mnist_train]

[(dtype('uint8'), (60000, 28, 28)), (dtype('uint8'), (60000,))]

In [0]:
# Maximum number of examples each simulated client (user) holds.
NUM_EXAMPLES_PER_USER = 1000
# Maximum number of examples per batch in a client's data sequence.
BATCH_SIZE = 100

def get_data_for_digit(source, digit, num_examples=None, batch_size=None):
  """Builds one simulated client's data: batches of a single MNIST digit.

  Args:
    source: An `(images, labels)` pair such as the training split returned by
      `tf.keras.datasets.mnist.load_data()`. Each image is flattened; labels
      are compared against `digit`.
    digit: The label value whose examples are collected for this client.
    num_examples: Maximum number of examples kept for this client. Defaults to
      `NUM_EXAMPLES_PER_USER`, read at call time.
    batch_size: Maximum number of examples per batch. Defaults to
      `BATCH_SIZE`, read at call time.

  Returns:
    A list of `OrderedDict`s, one per batch, with keys 'x' (float32 array of
    shape `(batch, pixels)`, pixel values divided by 255.0) and 'y' (int32
    labels). The final batch may be smaller than `batch_size`; the list is
    empty when no example carries `digit`.
  """
  if num_examples is None:
    num_examples = NUM_EXAMPLES_PER_USER
  if batch_size is None:
    batch_size = BATCH_SIZE

  images = np.asarray(source[0])
  labels = np.asarray(source[1])

  # Apply the per-client cap *before* batching so it holds even when
  # num_examples is not a multiple of batch_size (max() keeps a negative cap
  # from turning into a "drop the last k" slice).
  indices = np.flatnonzero(labels == digit)[:max(num_examples, 0)]

  output_sequence = []
  for start in range(0, len(indices), batch_size):
    batch_indices = indices[start:start + batch_size]
    pixels = images[batch_indices].reshape(len(batch_indices), -1) / 255.0
    output_sequence.append(collections.OrderedDict([
        ('x', pixels.astype(np.float32)),
        ('y', labels[batch_indices].astype(np.int32)),
    ]))
  return output_sequence

# One simulated client per digit 0-9: client d holds only examples labeled d,
# so the per-client data distributions are highly non-IID.
federated_train_data = [get_data_for_digit(mnist_train, d) for d in range(10)]

In [0]:
# An example batch: the first batch of the first client (digit 0). model_fn
# below passes it to tff.learning.from_compiled_keras_model.
sample_batch = federated_train_data[0][0]

In [26]:
# Show the dtype and shape of every array in the sample batch.
nest.map_structure(lambda x: (x.dtype, x.shape), sample_batch)

OrderedDict([('x', (dtype('float32'), (100, 784))), ('y', (dtype('int32'), (100,)))])

## Creating a model with Keras

In [0]:
def model_fn():
  """Builds a compiled Keras MNIST classifier and wraps it for TFF.

  The network maps 784 input features through a 512-unit ReLU layer to a
  10-way softmax. It is compiled with sparse categorical cross-entropy, SGD
  with learning rate 0.1 and a sparse categorical accuracy metric.

  Returns:
    The result of `tff.learning.from_compiled_keras_model` for the compiled
    model and the module-level `sample_batch`.
    NOTE(review): sample_batch is presumably used to infer the input types;
    confirm against the TFF docs.
  """
  layers = [
      tf.keras.layers.Flatten(input_shape=(784,)),
      tf.keras.layers.Dense(512, activation=tf.nn.relu),
      tf.keras.layers.Dense(10, activation=tf.nn.softmax),
  ]
  keras_model = tf.keras.models.Sequential(layers)

  keras_model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      optimizer=tf.keras.optimizers.SGD(0.1),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

  return tff.learning.from_compiled_keras_model(keras_model, sample_batch)

## Training the model on federated data

In [0]:
# Build the Federated Averaging training process. model_fn is passed as a
# function (not a model instance), so TFF invokes it to obtain the model.
iterative_process = tff.learning.build_federated_averaging_process(model_fn)

In [0]:
# Initial state, threaded through iterative_process.next() in the loop below.
state = iterative_process.initialize()

In [30]:
#@test {"skip": true}
# Run five rounds of federated averaging over all clients, printing the
# aggregated training metrics reported for each round.
for _ in range(5):
  state, metrics = iterative_process.next(state, federated_train_data)
  print(metrics)

<sparse_categorical_accuracy=0.9153,loss=0.273054>
<sparse_categorical_accuracy=0.9258,loss=0.245137>
<sparse_categorical_accuracy=0.9425,loss=0.22302>
<sparse_categorical_accuracy=0.9583,loss=0.204645>
<sparse_categorical_accuracy=0.9643,loss=0.18895>


## Customizing the model implementation

In [0]:
# Container for all variables of the custom MNIST model: model parameters
# (weights, bias) plus metric accumulators (num_examples, loss_sum,
# accuracy_sum).
MnistVariables = collections.namedtuple(
    'MnistVariables',
    ['weights', 'bias', 'num_examples', 'loss_sum', 'accuracy_sum'])

In [0]:
def create_mnist_variables():
  """Creates every TensorFlow variable used by the custom MNIST model.

  Returns:
    An `MnistVariables` tuple holding:
      weights: trainable float32 variable of shape (784, 10), all zeros.
      bias: trainable float32 variable of shape (10,), all zeros.
      num_examples, loss_sum, accuracy_sum: non-trainable scalars starting at
        0.0, used to accumulate metrics across forward passes.
  """
  return MnistVariables(
      # The initial values are given as zero-argument callables, which
      # tf.Variable invokes to produce the initial value.
      weights=tf.Variable(
          lambda: tf.zeros(dtype=tf.float32, shape=(784, 10)),
          name='weights',
          trainable=True),
      # (10,) is a one-element tuple; the original `(10)` was just the int 10.
      bias=tf.Variable(
          lambda: tf.zeros(dtype=tf.float32, shape=(10,)),
          name='bias',
          trainable=True),
      num_examples=tf.Variable(0.0, name='num_examples', trainable=False),
      loss_sum=tf.Variable(0.0, name='loss_sum', trainable=False),
      accuracy_sum=tf.Variable(0.0, name='accuracy_sum', trainable=False))

In [0]:
def mnist_forward_pass(variables, batch):
  """Runs softmax regression on one batch and accumulates local metrics.

  Args:
    variables: An `MnistVariables` tuple; `weights` and `bias` parameterize
      the model, and the `num_examples`, `loss_sum` and `accuracy_sum`
      variables are incremented by this call.
    batch: A mapping with 'x' (float32 features, `[batch, 784]` per the
      model's input spec) and 'y' (int32 labels).

  Returns:
    A `(loss, predictions)` pair: the mean cross-entropy over the batch and
    the int32 predicted class for each example.
  """
  logits = tf.matmul(batch['x'], variables.weights) + variables.bias
  # argmax of the logits equals argmax of the softmax probabilities.
  predictions = tf.cast(tf.argmax(logits, 1), tf.int32)

  # Compute cross-entropy directly from logits. Taking tf.log of softmax
  # probabilities yields -inf (and NaN via 0 * -inf) whenever a probability
  # underflows to zero.
  loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=batch['y'], logits=logits))
  accuracy = tf.reduce_mean(
      tf.cast(tf.equal(predictions, batch['y']), tf.float32))

  num_examples = tf.cast(tf.size(batch['y']), tf.float32)

  # Accumulate example-weighted sums so averages can be formed later.
  tf.assign_add(variables.num_examples, num_examples)
  tf.assign_add(variables.loss_sum, loss * num_examples)
  tf.assign_add(variables.accuracy_sum, accuracy * num_examples)

  return loss, predictions

In [0]:
def get_local_mnist_metrics(variables):
  """Turns accumulated sums into this client's metrics.

  Args:
    variables: An object exposing `num_examples`, `loss_sum` and
      `accuracy_sum` (e.g. an `MnistVariables` tuple).

  Returns:
    An `OrderedDict` with 'num_examples' and the example-weighted means
    'loss' (loss_sum / num_examples) and 'accuracy'
    (accuracy_sum / num_examples), in that key order.
  """
  count = variables.num_examples
  metric_items = [
      ('num_examples', count),
      ('loss', variables.loss_sum / count),
      ('accuracy', variables.accuracy_sum / count),
  ]
  return collections.OrderedDict(metric_items)

In [0]:
@tff.federated_computation
def aggregate_local_mnist_metrics(metrics):
  """Aggregates per-client metrics into global ones.

  Args:
    metrics: Federated per-client metrics with `num_examples`, `loss` and
      `accuracy` members.

  Returns:
    A dict with the total 'num_examples' (federated sum) and the 'loss' and
    'accuracy' averaged across clients, weighted by each client's
    `num_examples`.
  """
  return dict(
      num_examples=tff.federated_sum(metrics.num_examples),
      loss=tff.federated_average(metrics.loss, metrics.num_examples),
      accuracy=tff.federated_average(metrics.accuracy, metrics.num_examples))

In [0]:
class MnistModel(tff.learning.Model):
  """A hand-written `tff.learning.Model` performing softmax regression.

  All state lives in an `MnistVariables` tuple from `create_mnist_variables()`:
  `weights` and `bias` are trainable, while the example count and the
  loss/accuracy sums are local variables that `forward_pass` accumulates into.
  """

  def __init__(self):
    self._variables = create_mnist_variables()

  @property
  def trainable_variables(self):
    v = self._variables
    return [v.weights, v.bias]

  @property
  def non_trainable_variables(self):
    # The model has no non-trainable parameters.
    return []

  @property
  def local_variables(self):
    v = self._variables
    return [v.num_examples, v.loss_sum, v.accuracy_sum]

  @property
  def input_spec(self):
    # 784 float32 features and one int32 label per example, any batch size.
    features_spec = tf.TensorSpec([None, 784], tf.float32)
    labels_spec = tf.TensorSpec([None], tf.int32)
    return collections.OrderedDict([('x', features_spec), ('y', labels_spec)])

  @tf.contrib.eager.function(autograph=False)
  def forward_pass(self, batch, training=True):
    # The model computes the same thing in training and evaluation.
    del training
    batch_loss, batch_predictions = mnist_forward_pass(self._variables, batch)
    return tff.learning.BatchOutput(
        loss=batch_loss, predictions=batch_predictions)

  @tf.contrib.eager.function(autograph=False)
  def report_local_outputs(self):
    # Per-client metrics derived from the accumulated local variables.
    return get_local_mnist_metrics(self._variables)

  @property
  def federated_output_computation(self):
    # Federated computation that combines the clients' local outputs.
    return aggregate_local_mnist_metrics

In [0]:
class MnistTrainableModel(MnistModel, tff.learning.TrainableModel):
  """`MnistModel` extended with one step of plain gradient descent per batch."""

  @tf.contrib.eager.defun(autograph=False)
  def train_on_batch(self, batch):
    """Runs a forward pass on `batch`, then an SGD step (learning rate 0.1).

    Returns the `BatchOutput` of the forward pass.
    """
    batch_output = self.forward_pass(batch)
    # NOTE(review): the optimizer is constructed inside the traced function.
    # That is fine for plain gradient descent; confirm before switching to an
    # optimizer that creates slot variables.
    sgd = tf.train.GradientDescentOptimizer(0.1)
    sgd.minimize(batch_output.loss, var_list=self.trainable_variables)
    return batch_output

In [0]:
# Build Federated Averaging for the custom model. The class itself serves as
# the model constructor passed to TFF.
iterative_process = tff.learning.build_federated_averaging_process(
    MnistTrainableModel)

In [0]:
# Initial state, threaded through iterative_process.next() in the loop below.
state = iterative_process.initialize()

In [40]:
#@test {"skip": true}
# Run five rounds of federated averaging with the custom model, printing the
# metrics produced by its federated_output_computation for each round.
for _ in range(5):
  state, metrics = iterative_process.next(state, federated_train_data)
  print(metrics)

<accuracy=0.91,loss=0.293638,num_examples=10000.0>
<accuracy=0.976,loss=0.277596,num_examples=10000.0>
<accuracy=0.9786,loss=0.263369,num_examples=10000.0>
<accuracy=0.9796,loss=0.25075,num_examples=10000.0>
<accuracy=0.98,loss=0.239585,num_examples=10000.0>
