# Hands-on Federated Learning: Image Classification

In their recent (and extremely thorough!) review of the federated learning literature, [*Kairouz et al. (2019)*](https://arxiv.org/pdf/1912.04977.pdf) define federated learning as a machine learning setting where multiple entities (clients) collaborate in solving a machine learning problem, under the coordination of a central server or service provider. Each client’s raw data is stored locally and not exchanged or transferred; instead, focused updates intended for immediate aggregation are used to achieve the learning objective.

In this tutorial we will use a federated version of the classic MNIST dataset to introduce the Federated Learning (FL) API layer of TensorFlow Federated (TFF), [`tff.learning`](https://www.tensorflow.org/federated/api_docs/python/tff/learning) - a set of high-level interfaces that can be used to perform common types of federated learning tasks, such as federated training, against user-supplied models implemented in TensorFlow or Keras.

# Preliminaries

In [None]:
import collections
import os
import typing

import matplotlib.pyplot as plt
import numpy as np

import tensorflow as tf
import tensorflow.keras as keras
import tensorflow_federated as tff

In [None]:
# required to run TFF inside Jupyter notebooks
import nest_asyncio
nest_asyncio.apply()

In [None]:
tff.federated_computation(lambda: 'Hello, World!')()

# Preparing the data

In the IID setting the local data on each "client" is assumed to be a representative sample of the global data distribution. This is typically the case by construction when performing data parallel training of deep learning models across multiple CPU/GPU "clients".

The non-IID case is significantly more complicated, as there are many ways in which data can be non-IID and different degrees of "non-IIDness". Consider a supervised task with features $X$ and labels $y$. A statistical model of federated learning involves two levels of sampling:

1. Sampling a client $i$ from the distribution over available clients $Q$
2. Sampling an example $(X,y)$ from that client’s local data distribution $P_i(X,y)$.

Non-IID data in federated learning typically refers to differences between $P_i$ and $P_j$ for different clients $i$ and $j$. However, it is worth remembering that both the distribution of available clients, $Q$, and the distribution of local data for client $i$, $P_i$, may change over time which introduces another dimension of “non-IIDness”. Finally, if the local data on a client's device is insufficiently randomized, perhaps ordered by time, then independence is violated locally as well. 
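The two levels of sampling described above can be sketched in plain Python. This is only an illustration: the toy client data, the weights standing in for $Q$, and the helper `sample_example` are all hypothetical and not part of TFF.

```python
import random


def sample_example(client_data, client_weights, rng):
    """Draw one (client_id, example) pair via two-level sampling."""
    client_ids = list(client_data)
    # Level 1: sample a client i from the distribution Q over available clients.
    weights = [client_weights[c] for c in client_ids]
    client_id = rng.choices(client_ids, weights=weights, k=1)[0]
    # Level 2: sample an example (X, y) from that client's local data, i.e. from P_i.
    example = rng.choice(client_data[client_id])
    return client_id, example


# Hypothetical toy data: each client holds (feature, label) pairs.
toy_client_data = {
    "client_a": [(0.1, 0), (0.2, 0), (0.3, 1)],
    "client_b": [(5.0, 1), (5.5, 1)],
}
toy_client_weights = {"client_a": 0.75, "client_b": 0.25}  # stands in for Q

rng = random.Random(42)
draws = [sample_example(toy_client_data, toy_client_weights, rng) for _ in range(1000)]
share_a = sum(1 for client_id, _ in draws if client_id == "client_a") / len(draws)
print(f"fraction of draws from client_a: {share_a:.2f}")
```

Because every example is drawn from the local data of the client selected first, any difference between the clients' local distributions shows up directly in the sampled stream.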

In order to facilitate experimentation TFF includes federated versions of several popular datasets that exhibit different forms and degrees of non-IIDness. 

In [None]:
# What datasets are available?
[name for name in dir(tff.simulation.datasets) if not name.startswith("_")]

This tutorial uses a federated version of MNIST: a version of the original NIST dataset that has been re-processed using [LEAF](https://leaf.cmu.edu/) so that the data is keyed by the original writer of the digits.

The federated MNIST dataset displays a particular type of non-IIDness: feature distribution skew (covariate shift). With feature distribution skew the marginal distributions $P_i(X)$ vary across clients, even though $P(y|X)$ is shared. In the federated MNIST dataset users are writing the same numbers, but each user has a different writing style characterized by different stroke width, slant, etc.
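To make feature distribution skew concrete, here is a minimal one-dimensional sketch. It is an assumption-laden toy, not the EMNIST data: each hypothetical client draws its features from a Gaussian with a different mean (a different $P_i(X)$), while a single shared rule `label_rule` plays the role of $P(y|X)$.

```python
import random
import statistics


def label_rule(x):
    """Shared conditional P(y | X): the same deterministic rule on every client."""
    return int(x > 0.0)


def make_client_data(feature_mean, n, rng):
    """Client-specific marginal P_i(X): a unit-variance Gaussian centred on feature_mean."""
    xs = [rng.gauss(feature_mean, 1.0) for _ in range(n)]
    return [(x, label_rule(x)) for x in xs]


rng = random.Random(0)
client_a = make_client_data(feature_mean=-1.0, n=2000, rng=rng)
client_b = make_client_data(feature_mean=+1.0, n=2000, rng=rng)

mean_a = statistics.fmean(x for x, _ in client_a)
mean_b = statistics.fmean(x for x, _ in client_b)
print(f"mean feature on client A: {mean_a:.2f}, on client B: {mean_b:.2f}")

# The labelling rule is identical on both clients even though the features differ.
same_rule = all(y == label_rule(x) for x, y in client_a + client_b)
print(f"same labelling rule everywhere: {same_rule}")
```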

In [None]:
tff.simulation.datasets.emnist.load_data?

In [None]:
emnist_train, emnist_test = (tff.simulation
                                .datasets
                                .emnist
                                .load_data(only_digits=True, cache_dir="../data"))


In [None]:
NUMBER_CLIENTS = len(emnist_train.client_ids)
NUMBER_CLIENTS

In [None]:
def sample_client_ids(client_ids: typing.List[str],
                      sample_size: typing.Union[float, int],
                      random_state: np.random.RandomState) -> typing.List[str]:
    """Randomly selects a subset of clients ids."""
    number_clients = len(client_ids)
    error_msg = "'client_ids' must be non-empty."
    assert number_clients > 0, error_msg
    if isinstance(sample_size, float):
        error_msg = "Sample size must be between 0 and 1."
        assert 0 <= sample_size <= 1, error_msg
        size = int(sample_size * number_clients)
    elif isinstance(sample_size, int):
        error_msg = f"Sample size must be between 0 and {number_clients}."
        assert 0 <= sample_size <= number_clients, error_msg
        size = sample_size
    else:
        error_msg = "Type of 'sample_size' must be 'float' or 'int'."
        raise TypeError(error_msg)
    # sample without replacement so that no client is selected twice
    random_idxs = random_state.choice(number_clients, size=size, replace=False)
    return [client_ids[i] for i in random_idxs]


In [None]:
# these are what the client ids look like
_random_state = np.random.RandomState(42)
sample_client_ids(emnist_train.client_ids, 10, _random_state)

In [None]:
def create_tf_datasets(source: tff.simulation.ClientData,
                       client_ids: typing.Union[None, typing.List[str]]) -> typing.Dict[str, tf.data.Dataset]:
    """Create tf.data.Dataset instances for clients using their client_id."""
    if client_ids is None:
        client_ids = source.client_ids
    datasets = {client_id: source.create_tf_dataset_for_client(client_id) for client_id in client_ids}
    return datasets


def sample_client_datasets(source: tff.simulation.ClientData,
                           sample_size: typing.Union[float, int],
                           random_state: np.random.RandomState) -> typing.Dict[str, tf.data.Dataset]:
    """Randomly selects a subset of client datasets."""
    client_ids = sample_client_ids(source.client_ids, sample_size, random_state)
    client_datasets = create_tf_datasets(source, client_ids)
    return client_datasets


In [None]:
_random_state = np.random.RandomState()
client_datasets = sample_client_datasets(emnist_train, sample_size=1, random_state=_random_state)
(client_id, client_dataset), *_ = client_datasets.items()

fig, axes = plt.subplots(1, 5, figsize=(12,6), sharex=True, sharey=True)
for i, example in enumerate(client_dataset.take(5)):
    axes[i].imshow(example["pixels"].numpy(), cmap="gray")
    axes[i].set_title(example["label"].numpy())
_ = fig.suptitle(f"Training examples for client {client_id}", x=0.5, y=0.75, fontsize=15)


## Data preprocessing

Since each client dataset is already a [`tf.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset), preprocessing can be accomplished using Dataset transformations. Another option would be to use preprocessing operations from [`sklearn.preprocessing`](https://scikit-learn.org/stable/modules/preprocessing.html).

Preprocessing consists of the following steps:

1. `map` a function that flattens the 28 x 28 images into 784-element tensors
2. `map` a function that renames the features from pixels and label to X and y for use with Keras
3. `shuffle` the individual examples
4. `batch` the examples into training batches

We also throw in a `repeat` over the data set to run several epochs on each client device before sending parameters to the server for averaging.
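The order of the transformations determines how many batches each client processes per round. The sketch below mimics shuffle, then `repeat`, then `batch` on a plain Python list; `toy_pipeline` is a hypothetical helper written for illustration and does not use `tf.data`.

```python
import math
import random


def toy_pipeline(examples, number_epochs, batch_size, seed=0):
    """Mimics shuffle -> repeat -> batch on a plain Python list."""
    rng = random.Random(seed)
    stream = []
    for _ in range(number_epochs):
        epoch = list(examples)
        rng.shuffle(epoch)      # reshuffle on every pass over the data
        stream.extend(epoch)    # `repeat`: concatenate the epochs
    # `batch`: cut the repeated stream into fixed-size chunks (the last may be smaller)
    return [stream[i:i + batch_size] for i in range(0, len(stream), batch_size)]


batches = toy_pipeline(examples=list(range(100)), number_epochs=5, batch_size=32)
expected_batches = math.ceil(100 * 5 / 32)
print(len(batches), expected_batches)  # 16 16
print(len(batches[-1]))                # 20
```

Because repeating happens before batching, a batch can straddle an epoch boundary, and a client with 100 examples produces 16 batches rather than 5 x 4 = 20.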

In [None]:
AUTOTUNE = (tf.data
              .experimental
              .AUTOTUNE)
SHUFFLE_BUFFER_SIZE = 1000
NUMBER_TRAINING_EPOCHS = 5 # number of local passes over each client's data per round
TRAINING_BATCH_SIZE = 32
TESTING_BATCH_SIZE = 32

NUMBER_FEATURES = 28 * 28
NUMBER_TARGETS = 10

In [None]:
def _reshape(training_batch):
    """Extracts and reshapes data from a training sample """
    pixels = training_batch["pixels"]
    label = training_batch["label"]
    X = tf.reshape(pixels, shape=[-1]) # flattens 2D pixels to 1D
    y = tf.reshape(label, shape=[1])
    return X, y


def create_training_dataset(client_dataset: tf.data.Dataset) -> tf.data.Dataset:
    """Create a training dataset for a client from a raw client dataset."""
    training_dataset = (client_dataset.map(_reshape, num_parallel_calls=AUTOTUNE)
                                      .shuffle(SHUFFLE_BUFFER_SIZE, seed=None, reshuffle_each_iteration=True)
                                      .repeat(NUMBER_TRAINING_EPOCHS)
                                      .batch(TRAINING_BATCH_SIZE)
                                      .prefetch(buffer_size=AUTOTUNE))
    return training_dataset


def create_testing_dataset(client_dataset: tf.data.Dataset) -> tf.data.Dataset:
    """Create a testing dataset for a client from a raw client dataset."""
    testing_dataset = (client_dataset.map(_reshape, num_parallel_calls=AUTOTUNE)
                                     .batch(TESTING_BATCH_SIZE))
    return testing_dataset


## How to choose the clients included in each training round

In a typical federated training scenario there will be a very large population of user devices; however, only a fraction of these devices are likely to be available for training at a given point in time. For example, if the client devices are mobile phones then they might only participate in training when plugged into a power source, off a metered network, and otherwise idle.

In a simulated environment, where all data is locally available, a simple approach is to sample a random subset of the clients to be involved in each round of training, so that the subset of clients involved varies from round to round.
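A minimal sketch of per-round client sampling, using only the standard library. The client ids and the helper `sample_clients_for_round` are hypothetical; the point is simply that each round draws a fresh set of distinct clients.

```python
import random


def sample_clients_for_round(client_ids, number_clients, rng):
    """Pick distinct clients for one round (sampling without replacement)."""
    return rng.sample(client_ids, number_clients)


toy_client_ids = [f"client_{i:03d}" for i in range(100)]  # hypothetical ids
rng = random.Random(42)
rounds = [sample_clients_for_round(toy_client_ids, 10, rng) for _ in range(3)]
for n, selected in enumerate(rounds):
    print(f"round {n}: {selected[:3]} ...")
```

Reusing one random number generator across rounds keeps the whole simulation reproducible from a single seed while still varying the participants.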

### How many clients to include in each round?

Updating and averaging a larger number of client models per training round yields better convergence, so in a simulated training environment it probably makes sense to include as many clients as is computationally feasible. However, in a real-world training scenario, while averaging a larger number of clients improves convergence, it also makes training vulnerable to slowdown due to unpredictable tail delays in computation at, and communication with, the clients.
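The tail-delay effect can be illustrated with a small simulation. This is a sketch under a strong assumption: client latencies are modelled as independent exponential random variables with mean 1 (arbitrary units), which is chosen for convenience and is not a claim about real devices.

```python
import random
import statistics


def simulated_round_time(number_clients, rng):
    """A synchronous round finishes only when the slowest client reports back."""
    # Assumption: client latencies are i.i.d. exponential with mean 1.
    return max(rng.expovariate(1.0) for _ in range(number_clients))


rng = random.Random(0)
average_round_time = {}
for number_clients in (1, 10, 100):
    times = [simulated_round_time(number_clients, rng) for _ in range(2000)]
    average_round_time[number_clients] = statistics.fmean(times)
    print(f"{number_clients:>3} clients: mean round time "
          f"{average_round_time[number_clients]:.2f}")
```

Under this toy model the mean round time grows with the number of clients per round, even though the mean latency of each individual client stays the same.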

In [None]:
def create_federated_data(training_source: tff.simulation.ClientData,
                          testing_source: tff.simulation.ClientData,
                          sample_size: typing.Union[float, int],
                          random_state: np.random.RandomState) -> typing.Dict[str, typing.Tuple[tf.data.Dataset, tf.data.Dataset]]:
    
    # sample clients ids from the training dataset
    client_ids = sample_client_ids(training_source.client_ids, sample_size, random_state)
    
    federated_data = {}
    for client_id in client_ids:
        # create training dataset for the client
        _tf_dataset = training_source.create_tf_dataset_for_client(client_id)
        training_dataset = create_training_dataset(_tf_dataset)
        
        # create the testing dataset for the client
        _tf_dataset = testing_source.create_tf_dataset_for_client(client_id)
        testing_dataset = create_testing_dataset(_tf_dataset)
        
        federated_data[client_id] = (training_dataset, testing_dataset)
    
    return federated_data

In [None]:
_random_state = np.random.RandomState(42)
federated_data = create_federated_data(emnist_train,
                                       emnist_test,
                                       sample_size=0.01,
                                       random_state=_random_state)

In [None]:
# keys are client ids, values are (training_dataset, testing_dataset) pairs
len(federated_data)

# Creating a model with Keras

If you are using Keras, you likely already have code that constructs a Keras model. Since the model will need to be replicated on each of the client devices, we wrap the model in a no-argument Python function, a representation of which will eventually be invoked on each client to create the model on that client.

In [None]:
def create_keras_model_fn() -> keras.Model:
    model_fn = keras.models.Sequential([
      keras.layers.Input(shape=(NUMBER_FEATURES,)),
      keras.layers.Dense(units=NUMBER_TARGETS),
      keras.layers.Softmax(),
    ])
    return model_fn


In order to use any model with TFF, it needs to be wrapped in an instance of the [`tff.learning.Model`](https://www.tensorflow.org/federated/api_docs/python/tff/learning/Model) interface, which exposes methods to stamp the model's forward pass, metadata properties, etc, and also introduces additional elements such as ways to control the process of computing federated metrics. 

Once you have a Keras model like the one we've just defined above, you can have TFF wrap it for you by invoking [`tff.learning.from_keras_model`](https://www.tensorflow.org/federated/api_docs/python/tff/learning/from_keras_model), passing the model and a sample data batch as arguments, as shown below.

In [None]:
tff.learning.from_keras_model?

In [None]:
def create_tff_model_fn() -> tff.learning.Model:
    keras_model = create_keras_model_fn()
    dummy_batch = (tf.constant(0.0, shape=(TRAINING_BATCH_SIZE, NUMBER_FEATURES), dtype=tf.float32),
                   tf.constant(0, shape=(TRAINING_BATCH_SIZE, 1), dtype=tf.int32))
    loss_fn = (keras.losses
                    .SparseCategoricalCrossentropy())
    metrics = [
        keras.metrics.SparseCategoricalAccuracy()
    ]
    tff_model_fn = (tff.learning
                       .from_keras_model(keras_model, dummy_batch, loss_fn, None, metrics))
    return tff_model_fn


Again, since our model will need to be replicated on each of the client devices, we wrap the model in a no-argument Python function, a representation of which will eventually be invoked on each client to create the model on that client.

# Training the model on federated data

Now that we have a model wrapped as `tff.learning.Model` for use with TFF, we can let TFF construct a Federated Averaging algorithm by invoking the helper function `tff.learning.build_federated_averaging_process` as follows.

Keep in mind that the argument needs to be a constructor (such as `create_tff_model_fn` above), not an already-constructed instance, so that the construction of your model can happen in a context controlled by TFF.

One critical note on the Federated Averaging algorithm below: there are two optimizers.

1. `client_optimizer_fn`, which is only used to compute local model updates on each client.
2. `server_optimizer_fn`, which applies the averaged update to the global model on the server.

N.B. the choice of optimizer and learning rate may need to be different than those you would use to train the model on a standard i.i.d. dataset. Start with stochastic gradient descent with a smaller (than normal) learning rate.
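As a rough sketch of how the two optimizers interact (based on the usual formulation of Federated Averaging, and not a transcription of TFF's internals): suppose $K$ clients take part in round $t$, client $k$ holds $n_k$ examples, starts from the global weights $w_t$, and finishes local training with the client optimizer at $w_t^{k}$. The symbols $K$, $n_k$, $w_t$, $w_t^{k}$ and the server learning rate $\eta_s$ are introduced here for illustration, and weighting clients by their number of examples is an assumption. Plain SGD on the server then amounts to

$$
\Delta_t = \sum_{k=1}^{K} \frac{n_k}{\sum_{j=1}^{K} n_j} \left( w_t^{k} - w_t \right),
\qquad
w_{t+1} = w_t + \eta_s \, \Delta_t .
$$

With $\eta_s = 1$ and no momentum, this reduces to $w_{t+1} = \sum_{k=1}^{K} \frac{n_k}{\sum_{j=1}^{K} n_j} \, w_t^{k}$, the example-weighted average of the client models.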

In [None]:
tff.learning.build_federated_averaging_process?

In [None]:
CLIENT_LEARNING_RATE = 1e-2
SERVER_LEARNING_RATE = 1e0


def create_client_optimizer(learning_rate: float = CLIENT_LEARNING_RATE,
                            momentum: float = 0.0,
                            nesterov: bool = False) -> keras.optimizers.Optimizer:
    client_optimizer = (keras.optimizers
                             .SGD(learning_rate, momentum, nesterov))
    return client_optimizer


def create_server_optimizer(learning_rate: float = SERVER_LEARNING_RATE,
                            momentum: float = 0.0,
                            nesterov: bool = False) -> keras.optimizers.Optimizer:
    server_optimizer = (keras.optimizers
                             .SGD(learning_rate, momentum, nesterov))
    return server_optimizer


federated_averaging_process = (tff.learning
                                  .build_federated_averaging_process(create_tff_model_fn, 
                                                                     create_client_optimizer,
                                                                     create_server_optimizer,
                                                                     client_weight_fn=None,
                                                                     stateful_delta_aggregate_fn=None,
                                                                     stateful_model_broadcast_fn=None))


What just happened? TFF has constructed a pair of *federated computations* (i.e., programs in TFF's internal glue language) and packaged them into a [`tff.utils.IterativeProcess`](https://www.tensorflow.org/federated/api_docs/python/tff/utils/IterativeProcess) in which these computations are available as a pair of properties `initialize` and `next`.

It is a goal of TFF to define computations in a way that they could be executed in real federated learning settings, but currently only a local execution simulation runtime is implemented. To execute a computation in a simulator, you simply invoke it like a Python function. This default interpreted environment is not designed for high performance, but it will suffice for this tutorial.


## `initialize`

A function that takes no arguments and returns the state of the federated averaging process on the server. This function is only called to initialize a federated averaging process after it has been created.

In [None]:
# () -> SERVER_STATE
print(federated_averaging_process.initialize.type_signature)

In [None]:
state = federated_averaging_process.initialize()

## `next`

A function that takes current server state and federated data as arguments and returns the updated server state as well as any training metrics. Calling `next` performs a single round of federated averaging consisting of the following steps.

1. pushing the server state (including the model parameters) to the clients
2. on-device training on their local data
3. collecting and averaging model updates
4. producing a new updated model at the server.
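The four steps above can be sketched without TFF for a one-parameter model $y = w x$. Everything in this block is a hypothetical toy (the functions `initialize`, `client_update` and `next_round`, the data, and the learning rates) meant only to mirror the shape of an `initialize`/`next` loop, not TFF's implementation.

```python
def initialize():
    """Server state: here just a single scalar model weight."""
    return 0.0


def client_update(global_weight, examples, learning_rate=0.1, number_epochs=5):
    """Step 2: local gradient descent on squared error for the model y = w * x."""
    w = global_weight
    for _ in range(number_epochs):
        for x, y in examples:
            gradient = 2.0 * (w * x - y) * x
            w -= learning_rate * gradient
    # Return the model delta together with the number of local examples.
    return w - global_weight, len(examples)


def next_round(global_weight, federated_data, server_learning_rate=1.0):
    """One round: broadcast, local training, weighted averaging, server update."""
    # Step 1: every client starts from the same broadcast weight.
    updates = [client_update(global_weight, examples) for examples in federated_data]
    # Step 3: average the deltas, weighted by each client's number of examples.
    total_examples = sum(n for _, n in updates)
    average_delta = sum(delta * n for delta, n in updates) / total_examples
    # Step 4: the server applies the averaged delta.
    return global_weight + server_learning_rate * average_delta


# Hypothetical toy clients whose data all follow y = 3 * x.
toy_federated_data = [
    [(1.0, 3.0), (2.0, 6.0)],
    [(0.5, 1.5)],
    [(1.5, 4.5), (1.0, 3.0), (0.5, 1.5)],
]

state = initialize()
for n in range(10):
    state = next_round(state, toy_federated_data)
print(f"weight after 10 rounds: {state:.3f}")
```

In this toy setting the global weight approaches 3, the value that generated every client's data, even though no client ever shares its examples with the server.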

In [None]:
# extract the training datasets from the federated data
federated_training_data = [training_dataset for _, (training_dataset, _) in federated_data.items()]

# SERVER_STATE, FEDERATED_DATA -> SERVER_STATE, TRAINING_METRICS
state, metrics = federated_averaging_process.next(state, federated_training_data)
print(f"round: 0, metrics: {metrics}")

Let's run a few more rounds on the same training data (which will over-fit to a particular set of clients but will converge faster).

In [None]:
number_training_rounds = 15
for n in range(1, number_training_rounds):
    state, metrics = federated_averaging_process.next(state, federated_training_data)
    print(f"round: {n}, metrics: {metrics}")


# First attempt at simulating federated averaging

A proper federated averaging simulation would randomly sample new clients for each training round, allow for evaluation of training progress on training and testing data, and log training and testing metrics to TensorBoard for reference.

Here we define a function that randomly samples new clients prior to each training round and logs training metrics to TensorBoard. We defer handling testing data until we discuss federated evaluation towards the end of the tutorial.

In [None]:
def simulate_federated_averaging(federated_averaging_process: tff.utils.IterativeProcess,
                                 training_source: tff.simulation.ClientData,
                                 testing_source: tff.simulation.ClientData,
                                 sample_size: typing.Union[float, int],
                                 random_state: np.random.RandomState,
                                 number_rounds: int,
                                 initial_state: typing.Optional[typing.Any] = None,
                                 tensorboard_logging_dir: typing.Optional[str] = None):
    
    state = federated_averaging_process.initialize() if initial_state is None else initial_state
    
    if tensorboard_logging_dir is not None:
        
        if not os.path.isdir(tensorboard_logging_dir):
            os.makedirs(tensorboard_logging_dir)

        summary_writer = (tf.summary
                            .create_file_writer(tensorboard_logging_dir))

        with summary_writer.as_default():
            for n in range(number_rounds):
                federated_data = create_federated_data(training_source,
                                                       testing_source,
                                                       sample_size,
                                                       random_state)
                anonymized_training_data = [dataset for _, (dataset, _) in federated_data.items()]
                state, metrics = federated_averaging_process.next(state, anonymized_training_data)
                print(f"Round: {n}, Training metrics: {metrics}")

                for name, value in metrics._asdict().items():
                    tf.summary.scalar(name, value, step=n)          
    else:
        for n in range(number_rounds):
            federated_data = create_federated_data(training_source,
                                                   testing_source,
                                                   sample_size,
                                                   random_state)
            anonymized_training_data = [dataset for _, (dataset, _) in federated_data.items()]
            state, metrics = federated_averaging_process.next(state, anonymized_training_data)
            print(f"Round: {n}, Training metrics: {metrics}")
    
    return state, metrics

In [None]:
federated_averaging_process = (tff.learning
                                  .build_federated_averaging_process(create_tff_model_fn, 
                                                                     create_client_optimizer,
                                                                     create_server_optimizer,
                                                                     client_weight_fn=None,
                                                                     stateful_delta_aggregate_fn=None,
                                                                     stateful_model_broadcast_fn=None))
_random_state = np.random.RandomState(42)
_tensorboard_logging_dir = "../results/logs/tensorboard"
updated_state, current_metrics = simulate_federated_averaging(federated_averaging_process,
                                                              training_source=emnist_train,
                                                              testing_source=emnist_test,
                                                              sample_size=0.01,
                                                              random_state=_random_state,
                                                              number_rounds=5,
                                                              tensorboard_logging_dir=_tensorboard_logging_dir)

In [None]:
updated_state

In [None]:
current_metrics

# Customizing the model implementation

Keras is the recommended high-level model API for TensorFlow and you should be using Keras models and creating TFF models using [`tff.learning.from_keras_model`](https://www.tensorflow.org/federated/api_docs/python/tff/learning/from_keras_model) whenever possible.

However, [`tff.learning`](https://www.tensorflow.org/federated/api_docs/python/tff/learning) provides a lower-level model interface, [`tff.learning.Model`](https://www.tensorflow.org/federated/api_docs/python/tff/learning/Model), that exposes the minimal functionality necessary for using a model for federated learning. Directly implementing this interface (possibly still using building blocks from [`keras`](https://www.tensorflow.org/guide/keras)) allows for maximum customization without modifying the internals of the federated learning algorithms.

Now we are going to repeat the above from scratch!

## Defining model variables

We start by defining a new Python class that inherits from `tff.learning.Model`. In the class constructor (i.e., the `__init__` method) we will initialize all relevant variables using TF primitives as well as define our "input spec", which defines the shape and types of the tensors that will hold input data.

In [None]:
class MNISTModel(tff.learning.Model):

    def __init__(self):
        
        # initialize some trainable variables
        self._weights = tf.Variable(
            initial_value=lambda: tf.zeros(dtype=tf.float32, shape=(NUMBER_FEATURES, NUMBER_TARGETS)),
            name="weights",
            trainable=True
        )
        self._bias = tf.Variable(
            initial_value=lambda: tf.zeros(dtype=tf.float32, shape=(NUMBER_TARGETS,)),
            name="bias",
            trainable=True
        )
        
        # initialize some variables used in computing metrics
        self._number_examples = tf.Variable(0.0, name='number_examples', trainable=False)
        self._total_loss = tf.Variable(0.0, name='total_loss', trainable=False)
        self._number_true_positives = tf.Variable(0.0, name='number_true_positives', trainable=False)
        
        # define the input spec
        self._input_spec = collections.OrderedDict([
            ('X', tf.TensorSpec([None, NUMBER_FEATURES], tf.float32)),
            ('y', tf.TensorSpec([None, 1], tf.int32))
        ])

    @property
    def input_spec(self):
        return self._input_spec
    
    @property
    def local_variables(self):
        return [self._number_examples, self._total_loss, self._number_true_positives]

    @property
    def non_trainable_variables(self):
        return []
    
    @property
    def trainable_variables(self):
        return [self._weights, self._bias]

    

## Defining the forward pass

With the variables for model parameters and cumulative statistics in place we can now define the `forward_pass` method that computes loss, makes predictions, and updates the cumulative statistics for a single batch of input data.
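As a quick sanity check on the loss, here is a worked example in plain Python, assuming a softmax output with cross-entropy loss and all-zero weights and bias (as in the constructor above). The helpers `softmax` and `cross_entropy` are written for illustration only. With zero parameters every logit is zero, so the predicted distribution is uniform over the 10 classes and the loss equals $\log 10$ regardless of the input.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    largest = max(logits)
    exps = [math.exp(z - largest) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(label, probabilities):
    """Negative log-probability assigned to the true class."""
    return -math.log(probabilities[label])


NUMBER_TARGETS = 10
# With all-zero weights and bias, Z = XW + b is zero for every input.
logits = [0.0] * NUMBER_TARGETS
probabilities = softmax(logits)
loss = cross_entropy(label=7, probabilities=probabilities)
print(f"{loss:.4f}")  # 2.3026, i.e. log(10)
```

A first-round training loss near 2.3 is therefore what one would expect from an untrained model of this form.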

In [None]:
class MNISTModel(tff.learning.Model):

    def __init__(self):
        
        # initialize some trainable variables
        self._weights = tf.Variable(
            initial_value=lambda: tf.zeros(dtype=tf.float32, shape=(NUMBER_FEATURES, NUMBER_TARGETS)),
            name="weights",
            trainable=True
        )
        self._bias = tf.Variable(
            initial_value=lambda: tf.zeros(dtype=tf.float32, shape=(NUMBER_TARGETS,)),
            name="bias",
            trainable=True
        )
        
        # initialize some variables used in computing metrics
        self._number_examples = tf.Variable(0.0, name='number_examples', trainable=False)
        self._total_loss = tf.Variable(0.0, name='total_loss', trainable=False)
        self._number_true_positives = tf.Variable(0.0, name='number_true_positives', trainable=False)
        
        # define the input spec
        self._input_spec = collections.OrderedDict([
            ('X', tf.TensorSpec([None, NUMBER_FEATURES], tf.float32)),
            ('y', tf.TensorSpec([None, 1], tf.int32))
        ])

    @property
    def input_spec(self):
        return self._input_spec
    
    @property
    def local_variables(self):
        return [self._number_examples, self._total_loss, self._number_true_positives]

    @property
    def non_trainable_variables(self):
        return []
    
    @property
    def trainable_variables(self):
        return [self._weights, self._bias]

    @tf.function
    def _count_true_positives(self, y_true, y_pred):
        return tf.reduce_sum(tf.cast(tf.equal(y_true, y_pred), tf.float32))

    @tf.function
    def _linear_transformation(self, batch):
        X = batch['X']
        W, b = self.trainable_variables
        Z = tf.matmul(X, W) + b
        return Z
    
    @tf.function
    def _loss_fn(self, y_true, probabilities):
        return -tf.reduce_mean(tf.reduce_sum(tf.one_hot(y_true, NUMBER_TARGETS) * tf.math.log(probabilities), axis=1))
    
    @tf.function
    def _model_fn(self, batch):
        Z = self._linear_transformation(batch)
        probabilities = tf.nn.softmax(Z)
        return probabilities
    
    @tf.function
    def forward_pass(self, batch, training=True):
        probabilities = self._model_fn(batch)
        y_pred = tf.argmax(probabilities, axis=1, output_type=tf.int32)
        y_true = tf.reshape(batch['y'], shape=[-1])

        # compute local variables
        loss = self._loss_fn(y_true, probabilities)
        true_positives = self._count_true_positives(y_true, y_pred)
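        number_examples = tf.cast(tf.size(y_true), tf.float32)
        
        # update local variables; the batch-mean loss is weighted by the batch size
        # so that total_loss / number_examples is a per-example average
        self._total_loss.assign_add(loss * number_examples)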
        self._number_true_positives.assign_add(true_positives)
        self._number_examples.assign_add(number_examples)

        batch_output = tff.learning.BatchOutput(
            loss=loss,
            predictions=y_pred,
            num_examples=tf.cast(number_examples, tf.int32)
        )
        return batch_output


## Defining the local metrics

Next, we define a method `report_local_outputs` that returns a set of local metrics. These are the values, in addition to model updates (which are handled automatically), that are eligible to be aggregated to the server in a federated learning or evaluation process.

Finally, we need to determine how to aggregate the local metrics emitted by each device by defining `federated_output_computation`. This is the only part of the code that isn't written in TensorFlow - it's a federated computation expressed in TFF.
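The aggregation can be illustrated with plain-Python stand-ins for a federated sum and an example-weighted federated mean. The helper names and the three clients' metric values below are hypothetical and chosen only to make the arithmetic easy to follow.

```python
def toy_federated_sum(values):
    """Plain-Python stand-in for summing one value per client."""
    return sum(values)


def toy_federated_weighted_mean(values, weights):
    """Plain-Python stand-in for a weighted mean across clients."""
    return sum(v * w for v, w in zip(values, weights)) / sum(weights)


# Hypothetical local metrics reported by three clients.
local_metrics = [
    {"number_examples": 100.0, "average_loss": 0.50, "accuracy": 0.90},
    {"number_examples": 300.0, "average_loss": 0.30, "accuracy": 0.94},
    {"number_examples": 600.0, "average_loss": 0.20, "accuracy": 0.96},
]
number_examples = [m["number_examples"] for m in local_metrics]

aggregated_metrics = {
    "number_examples": toy_federated_sum(number_examples),
    "average_loss": toy_federated_weighted_mean(
        [m["average_loss"] for m in local_metrics], number_examples),
    "accuracy": toy_federated_weighted_mean(
        [m["accuracy"] for m in local_metrics], number_examples),
}
print(aggregated_metrics)
```

Weighting by the number of examples means the client with 600 examples influences the aggregate six times as much as the client with 100, so the result (loss 0.26, accuracy 0.948) matches what would be computed over the pooled data.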

In [None]:
class MNISTModel(tff.learning.Model):

    def __init__(self):
        
        # initialize some trainable variables
        self._weights = tf.Variable(
            initial_value=lambda: tf.zeros(dtype=tf.float32, shape=(NUMBER_FEATURES, NUMBER_TARGETS)),
            name="weights",
            trainable=True
        )
        self._bias = tf.Variable(
            initial_value=lambda: tf.zeros(dtype=tf.float32, shape=(NUMBER_TARGETS,)),
            name="bias",
            trainable=True
        )
        
        # initialize some variables used in computing metrics
        self._number_examples = tf.Variable(0.0, name='number_examples', trainable=False)
        self._total_loss = tf.Variable(0.0, name='total_loss', trainable=False)
        self._number_true_positives = tf.Variable(0.0, name='number_true_positives', trainable=False)
        
        # define the input spec
        self._input_spec = collections.OrderedDict([
            ('X', tf.TensorSpec([None, NUMBER_FEATURES], tf.float32)),
            ('y', tf.TensorSpec([None, 1], tf.int32))
        ])

    @property
    def federated_output_computation(self):
        return self._aggregate_metrics_across_clients
    
    @property
    def input_spec(self):
        return self._input_spec
    
    @property
    def local_variables(self):
        return [self._number_examples, self._total_loss, self._number_true_positives]

    @property
    def non_trainable_variables(self):
        return []
    
    @property
    def trainable_variables(self):
        return [self._weights, self._bias]
    
    @tff.federated_computation
    def _aggregate_metrics_across_clients(metrics):
        aggregated_metrics = {
            'number_examples': tff.federated_sum(metrics.number_examples),
            'average_loss': tff.federated_mean(metrics.average_loss, metrics.number_examples),
            'accuracy': tff.federated_mean(metrics.accuracy, metrics.number_examples)
        }
        return aggregated_metrics

    @tf.function
    def _count_true_positives(self, y_true, y_pred):
        return tf.reduce_sum(tf.cast(tf.equal(y_true, y_pred), tf.float32))

    @tf.function
    def _linear_transformation(self, batch):
        X = batch['X']
        W, b = self.trainable_variables
        Z = tf.matmul(X, W) + b
        return Z
    
    @tf.function
    def _loss_fn(self, y_true, probabilities):
        return -tf.reduce_mean(tf.reduce_sum(tf.one_hot(y_true, NUMBER_TARGETS) * tf.math.log(probabilities), axis=1))
    
    @tf.function
    def _model_fn(self, batch):
        Z = self._linear_transformation(batch)
        probabilities = tf.nn.softmax(Z)
        return probabilities
    
    @tf.function
    def forward_pass(self, batch, training=True):
        probabilities = self._model_fn(batch)
        y_pred = tf.argmax(probabilities, axis=1, output_type=tf.int32)
        y_true = tf.reshape(batch['y'], shape=[-1])

        # compute local variables
        loss = self._loss_fn(y_true, probabilities)
        true_positives = self._count_true_positives(y_true, y_pred)
        number_examples = tf.cast(tf.size(y_true), tf.float32)
        
        # update local variables
        # loss is a per-batch mean, so weight it by the batch size before accumulating
        self._total_loss.assign_add(loss * number_examples)
        self._number_true_positives.assign_add(true_positives)
        self._number_examples.assign_add(number_examples)

        batch_output = tff.learning.BatchOutput(
            loss=loss,
            predictions=y_pred,
            num_examples=tf.cast(number_examples, tf.int32)
        )
        return batch_output

    @tf.function
    def report_local_outputs(self):
        local_metrics = collections.OrderedDict([
            ('number_examples', self._number_examples),
            ('average_loss', self._total_loss / self._number_examples),
            ('accuracy', self._number_true_positives / self._number_examples)
        ])
        return local_metrics


Here are a few points worth highlighting:

* All state that your model will use must be captured as TensorFlow variables, as TFF does not use Python at runtime (remember your code should be written such that it can be deployed to mobile devices).
* Your model should describe what form of data it accepts (`input_spec`), because TFF is a strongly-typed environment and wants to determine type signatures for all components. Declaring the format of your model's input is an essential part of that.
* Although technically not required, we recommend wrapping all TensorFlow logic (forward pass, metric calculations, etc.) in `tf.function`s, as this helps ensure the TensorFlow code can be serialized and removes the need for explicit control dependencies.
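
To make the metric aggregation in `_aggregate_metrics_across_clients` more concrete, here is a minimal NumPy sketch of an example-weighted mean, which is what `tff.federated_mean(value, weight)` computes across clients. The per-client numbers and the helper name `weighted_mean` are made up for illustration and are not part of TFF.

```python
import numpy as np


def weighted_mean(values, weights):
    """Example-weighted mean of per-client values (hypothetical helper)."""
    values = np.asarray(values, dtype=np.float64)
    weights = np.asarray(weights, dtype=np.float64)
    return float(np.sum(values * weights) / np.sum(weights))


# Made-up reports from three clients: (number of examples, local accuracy).
number_examples = [10, 30, 60]
accuracy = [0.5, 0.8, 0.9]

aggregated_metrics = {
    'number_examples': int(np.sum(number_examples)),
    'accuracy': weighted_mean(accuracy, number_examples),
}

# The unweighted mean treats every client equally, regardless of data size.
unweighted_accuracy = float(np.mean(accuracy))
```

Weighting by `number_examples` makes the aggregated accuracy equal to the accuracy you would get by pooling every example from all clients, so clients with little data do not dominate the metric.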

The above is sufficient for evaluation and algorithms like Federated SGD. However, for Federated Averaging, we need to specify how the model should train locally on each batch.

In [None]:
class MNISTrainableModel(MNISTModel, tff.learning.TrainableModel):
    
    def __init__(self, optimizer):
        super().__init__()
        self._optimizer = optimizer

    @tf.function
    def train_on_batch(self, batch):
        with tf.GradientTape() as tape:
            output = self.forward_pass(batch)
        gradients = tape.gradient(output.loss, self.trainable_variables)
        self._optimizer.apply_gradients(zip(tf.nest.flatten(gradients), tf.nest.flatten(self.trainable_variables)))
        return output
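
For readers who want to see what a single call to `train_on_batch` amounts to numerically, the following is a rough NumPy sketch of one SGD step for softmax regression. It is not the TFF implementation: the gradient is written out by hand instead of using `tf.GradientTape`, and the function names, shapes, and random data are assumptions made for the example.

```python
import numpy as np


def softmax(Z):
    # subtract the row-wise max for numerical stability
    Z = Z - Z.max(axis=1, keepdims=True)
    exp_Z = np.exp(Z)
    return exp_Z / exp_Z.sum(axis=1, keepdims=True)


def cross_entropy(probabilities, one_hot):
    return float(-np.mean(np.sum(one_hot * np.log(probabilities), axis=1)))


def sgd_step(W, b, X, y_true, learning_rate):
    """One gradient step; returns updated parameters and the pre-update loss."""
    one_hot = np.eye(W.shape[1])[y_true]
    probabilities = softmax(X @ W + b)
    loss = cross_entropy(probabilities, one_hot)
    # gradient of the mean cross-entropy with respect to the logits
    dZ = (probabilities - one_hot) / X.shape[0]
    W_new = W - learning_rate * (X.T @ dZ)
    b_new = b - learning_rate * dZ.sum(axis=0)
    return W_new, b_new, loss


# Tiny made-up batch: 8 examples, 4 features, 3 classes.
rng = np.random.RandomState(0)
X = rng.rand(8, 4)
y = rng.randint(0, 3, size=8)
W, b = np.zeros((4, 3)), np.zeros(3)

W, b, loss_before = sgd_step(W, b, X, y, learning_rate=0.02)
_, _, loss_after = sgd_step(W, b, X, y, learning_rate=0.02)
```

As in `train_on_batch`, the loss returned by a step is the one computed in the forward pass, before the parameters are updated.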


# Simulating federated training with the new model

With all the above in place, the remainder of the process looks like what we've seen already - just replace the model constructor with the constructor of our new model class, and use the two federated computations in the iterative process you created to cycle through training rounds.

In [None]:
def create_custom_tff_model_fn():
    optimizer = keras.optimizers.SGD(learning_rate=0.02)
    return MNISTrainableModel(optimizer)
    
federated_averaging_process = (tff.learning
                                  .build_federated_averaging_process(create_custom_tff_model_fn))

_random_state = np.random.RandomState(42)
updated_state, current_metrics = simulate_federated_averaging(federated_averaging_process,
                                                              training_source=emnist_train,
                                                              testing_source=emnist_test,
                                                              sample_size=0.01,
                                                              random_state=_random_state,
                                                              number_rounds=10)

In [None]:
updated_state

In [None]:
current_metrics

# Evaluation

All of our experiments so far presented only federated training metrics - the average metrics over all batches of data trained across all clients in the round. Should we be concerned about overfitting? Yes! In federated averaging algorithms there are two different ways to overfit.

1. Overfitting the shared model (especially if we use the same set of clients on each round).
2. Overfitting local models on the clients.


## Federated evaluation

To perform evaluation on federated data, you can construct another federated computation designed for just this purpose, using the [`tff.learning.build_federated_evaluation`](https://www.tensorflow.org/federated/api_docs/python/tff/learning/build_federated_evaluation) function, and passing in your model constructor as an argument. Note that evaluation doesn't perform gradient descent and there's no need to construct optimizers.


In [None]:
tff.learning.build_federated_evaluation?

In [None]:
federated_evaluation = (tff.learning
                          .build_federated_evaluation(create_custom_tff_model_fn))

In [None]:
# function type signature: SERVER_MODEL, FEDERATED_DATA -> METRICS
print(federated_evaluation.type_signature)

The `federated_evaluation` function is similar to `tff.utils.IterativeProcess.next` but with two important differences.

1. It does not return the server state. Evaluation doesn't modify the model or any other aspect of state, so you can think of it as stateless.
2. It needs only the model and doesn't require any other part of the server state that might be associated with training, such as optimizer variables.

In [None]:
training_metrics = federated_evaluation(updated_state.model, federated_training_data)

In [None]:
training_metrics

Note the numbers may look marginally better than what was reported by the last round of training. By convention, the training metrics reported by the iterative training process generally reflect the performance of the model at the beginning of the training round, so the evaluation metrics will always be one step ahead.
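
The "one step ahead" effect can be reproduced with a toy example. The sketch below is plain NumPy, not TFF: it assumes a scalar model `w`, a squared-error loss, one local gradient step per client, and made-up client data, and all function names are hypothetical.

```python
import numpy as np


def local_training(w, data, learning_rate):
    """One local gradient step; the loss is measured before the update."""
    loss = float(np.mean((w - data) ** 2))
    w_new = w - learning_rate * 2.0 * float(np.mean(w - data))
    return w_new, loss


def evaluate(w, client_data):
    """Example-weighted mean loss of a fixed model across clients."""
    sizes = [len(data) for data in client_data]
    losses = [np.mean((w - data) ** 2) for data in client_data]
    return float(np.average(losses, weights=sizes))


def federated_round(w, client_data, learning_rate=0.1):
    """Example-weighted average of client models and of their training losses."""
    sizes = [len(data) for data in client_data]
    results = [local_training(w, data, learning_rate) for data in client_data]
    w_new = float(np.average([r[0] for r in results], weights=sizes))
    training_loss = float(np.average([r[1] for r in results], weights=sizes))
    return w_new, training_loss


client_data = [np.array([1.0, 2.0, 3.0]), np.array([4.0, 5.0])]
w_start = 0.0
w_end, training_loss = federated_round(w_start, client_data)

loss_at_round_start = evaluate(w_start, client_data)
loss_after_round = evaluate(w_end, client_data)
```

In this sketch the training loss reported by the round matches an evaluation of the model the round started from, while evaluating the updated model gives a lower loss.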

## Evaluating on client data not used in training

Since we are training a shared model for digit classification we might also want to evaluate the performance of the model on client test datasets where the corresponding training dataset was not used in training.

In [None]:
_random_state = np.random.RandomState(42)
client_datasets = sample_client_datasets(emnist_test, sample_size=0.01, random_state=_random_state)
federated_testing_data = [create_testing_dataset(client_dataset) for _, client_dataset in client_datasets.items()]

In [None]:
testing_metrics = federated_evaluation(updated_state.model, federated_testing_data)

In [None]:
testing_metrics
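
If you want to guarantee that the evaluation clients are disjoint from the clients used for training, one option is to sample the held-out clients from the ids that remain after the training sample is drawn. The sketch below shows the idea with NumPy only. The helper `sample_client_ids` and the made-up client ids are hypothetical, and the helper is not the `sample_client_datasets` function used in this notebook.

```python
import numpy as np


def sample_client_ids(client_ids, sample_size, random_state):
    """Sample ids without replacement (hypothetical helper).

    A float sample_size is read as a fraction of the available clients,
    an int as an absolute number of clients.
    """
    if isinstance(sample_size, float):
        number_clients = max(1, int(sample_size * len(client_ids)))
    else:
        number_clients = sample_size
    sampled = random_state.choice(client_ids, size=number_clients, replace=False)
    return [str(client_id) for client_id in sampled]


# Made-up client ids standing in for the ids of a federated dataset.
all_client_ids = [f"client_{i:04d}" for i in range(1000)]

random_state = np.random.RandomState(42)
training_ids = sample_client_ids(all_client_ids, 0.01, random_state)

# Only clients that were not sampled for training are eligible for evaluation.
remaining_ids = sorted(set(all_client_ids) - set(training_ids))
held_out_ids = sample_client_ids(remaining_ids, 0.01, random_state)
```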

# Adding evaluation to our federated averaging simulation

In [None]:
def simulate_federated_averaging(federated_averaging_process: tff.utils.IterativeProcess,
                                 federated_evaluation,
                                 training_source: tff.simulation.ClientData,
                                 testing_source: tff.simulation.ClientData,
                                 sample_size: typing.Union[float, int],
                                 random_state: np.random.RandomState,
                                 number_rounds: int,
                                 tensorboard_logging_dir: typing.Optional[str] = None):
    
    state = federated_averaging_process.initialize()
    
    if tensorboard_logging_dir is not None:
        
        if not os.path.isdir(tensorboard_logging_dir):
            os.makedirs(tensorboard_logging_dir)

        summary_writer = (tf.summary
                            .create_file_writer(tensorboard_logging_dir))

        with summary_writer.as_default():
            for n in range(number_rounds):
                federated_data = create_federated_data(training_source,
                                                       testing_source,
                                                       sample_size,
                                                       random_state)
                
                # extract the training and testing datasets
                anonymized_training_data = []
                anonymized_testing_data = []
                for training_dataset, testing_dataset in federated_data.values():
                    anonymized_training_data.append(training_dataset)
                    anonymized_testing_data.append(testing_dataset)
        
                state, _ = federated_averaging_process.next(state, anonymized_training_data)
                training_metrics = federated_evaluation(state.model, anonymized_training_data)
                testing_metrics = federated_evaluation(state.model, anonymized_testing_data)
                print(f"Round: {n}, Training metrics: {training_metrics}, Testing metrics: {testing_metrics}")

                # tensorboard logging (prefix the tags so the two metric sets do not overwrite each other)
                for name, value in training_metrics._asdict().items():
                    tf.summary.scalar(f"training/{name}", value, step=n)
                
                for name, value in testing_metrics._asdict().items():
                    tf.summary.scalar(f"testing/{name}", value, step=n)
    else:
        for n in range(number_rounds):
            federated_data = create_federated_data(training_source,
                                                   testing_source,
                                                   sample_size,
                                                   random_state)

            # extract the training and testing datasets
            anonymized_training_data = []
            anonymized_testing_data = []
            for training_dataset, testing_dataset in federated_data.values():
                anonymized_training_data.append(training_dataset)
                anonymized_testing_data.append(testing_dataset)

            state, _ = federated_averaging_process.next(state, anonymized_training_data)
            training_metrics = federated_evaluation(state.model, anonymized_training_data)
            testing_metrics = federated_evaluation(state.model, anonymized_testing_data)
            print(f"Round: {n}, Training metrics: {training_metrics}, Testing metrics: {testing_metrics}")
    
    return state, (training_metrics, testing_metrics)

In [None]:
federated_averaging_process = (tff.learning
                                  .build_federated_averaging_process(create_tff_model_fn, 
                                                                     create_client_optimizer,
                                                                     create_server_optimizer,
                                                                     client_weight_fn=None,
                                                                     stateful_delta_aggregate_fn=None,
                                                                     stateful_model_broadcast_fn=None))

federated_evaluation = (tff.learning
                           .build_federated_evaluation(create_tff_model_fn))

_random_state = np.random.RandomState(42)
updated_state, current_metrics = simulate_federated_averaging(federated_averaging_process,
                                                              federated_evaluation,
                                                              training_source=emnist_train,
                                                              testing_source=emnist_test,
                                                              sample_size=0.01,
                                                              random_state=_random_state,
                                                              number_rounds=15)

# Wrapping up

## Interesting resources

[PySyft](https://github.com/OpenMined/PySyft) is a Python library for secure and private Deep Learning created by [OpenMined](https://www.openmined.org/). PySyft decouples private data from model training, using
[Federated Learning](https://ai.googleblog.com/2017/04/federated-learning-collaborative.html),
[Differential Privacy](https://en.wikipedia.org/wiki/Differential_privacy),
and [Multi-Party Computation (MPC)](https://en.wikipedia.org/wiki/Secure_multi-party_computation) within the main Deep Learning frameworks like PyTorch and TensorFlow.
