[TUTORIAL](https://colab.research.google.com/github/tensorflow/federated/blob/v0.21.0/docs/tutorials/federated_learning_for_image_classification.ipynb#scrollTo=Q3ynrxd53HzY)

In [1]:
import nest_asyncio
nest_asyncio.apply()
import tensorflow as tf
import tensorflow_federated as tff

## Preparing Input Data

In [2]:
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()
weights_hid = []

In [3]:
NUM_CLIENTS = 10
BATCH_SIZE = 20

def preprocess(dataset):

  def batch_format_fn(element):
    """Flatten a batch of EMNIST data and return a (features, label) tuple."""
    return (tf.reshape(element['pixels'], [-1, 784]), 
            tf.reshape(element['label'], [-1, 1]))

  return dataset.batch(BATCH_SIZE).map(batch_format_fn)
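
To make the shapes concrete, here is a minimal plain-Python sketch of what batching and flattening do, without TensorFlow. The helper names `make_batches` and `flatten_batch` and the placeholder data are hypothetical, not part of TFF or `tf.data`. It assumes 28x28 images, consistent with the 784 used above, and it mirrors the fact that `dataset.batch` keeps a smaller final batch by default.

```python
def make_batches(items, batch_size):
    """Group items into consecutive batches; the last batch may be smaller."""
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]


def flatten_batch(pixels_batch, labels_batch):
    """Turn each 28x28 image into a 784-long row and each label into a 1-long row."""
    features = [[p for row in image for p in row] for image in pixels_batch]
    labels = [[label] for label in labels_batch]
    return features, labels


# Placeholder data: 45 blank 28x28 "images" with made-up digit labels.
images = [[[0.0] * 28 for _ in range(28)] for _ in range(45)]
labels = [i % 10 for i in range(45)]

image_batches = make_batches(images, 20)
label_batches = make_batches(labels, 20)
batch_sizes = [len(b) for b in image_batches]

# Flatten the final (partial) batch.
features, flat_labels = flatten_batch(image_batches[-1], label_batches[-1])
```

Because the final batch can be smaller than `BATCH_SIZE`, the batch dimension is not fixed, which is why it appears as `?` in the type signatures printed later.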

In [4]:
client_ids = sorted(emnist_train.client_ids)[:NUM_CLIENTS]
federated_train_data = [preprocess(emnist_train.create_tf_dataset_for_client(x))
  for x in client_ids
]

## Preparing the Model

In [5]:
def create_keras_model():
  initializer = tf.keras.initializers.GlorotNormal(seed=0)
  return tf.keras.models.Sequential([
      tf.keras.layers.Input(shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer=initializer),
      tf.keras.layers.Softmax(),
  ])

In [6]:
def model_fn():
  keras_model = create_keras_model()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=federated_train_data[0].element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

## Building your own Federated Learning algorithm

While the tff.learning API allows one to create many variants of Federated Averaging, there are other federated algorithms that do not fit neatly into this framework. For example, you may want to add regularization, clipping, or more complicated algorithms such as federated GAN training. You may instead be interested in federated analytics.

For these more advanced algorithms, we'll have to write our own custom algorithm using TFF. In many cases, federated algorithms have 4 main components:

1. A server-to-client broadcast step.
2. A local client update step.
3. A client-to-server upload step.
4. A server update step.

In TFF, we generally represent federated algorithms as a tff.templates.IterativeProcess (which we refer to as just an IterativeProcess throughout). This is a class that contains initialize and next functions. Here, initialize is used to initialize the server, and next will perform one communication round of the federated algorithm. Let's write a skeleton of what our iterative process for FedAvg should look like.

First, we have an initialize function that simply creates a tff.learning.Model, and returns its trainable weights.

In [7]:
def initialize_fn():
  model = model_fn()
  return model.trainable_variables

This function looks good, but as we will see later, we will need to make a small modification to make it a "TFF computation".

We also want to sketch the next_fn:

In [8]:
def next_fn(server_weights, federated_dataset):
  # Broadcast the server weights to the clients.
  server_weights_at_client = broadcast(server_weights)

  # Each client computes their updated weights.
  client_weights = client_update(federated_dataset, server_weights_at_client)

  # The server averages these updates.
  mean_client_weights = mean(client_weights)

  # The server updates its model.
  server_weights = server_update(mean_client_weights)

  return server_weights
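
The skeleton above calls helpers that do not exist yet. As a rough illustration only, the same four steps can be run end to end in plain Python with lists of floats standing in for model weights. Every `toy_*` function here is a hypothetical stand-in, and the "client update" is a deliberately trivial placeholder rather than real training.

```python
def toy_broadcast(server_weights, num_clients):
    """Step 1: every client receives its own copy of the server weights."""
    return [list(server_weights) for _ in range(num_clients)]


def toy_client_update(dataset, weights):
    """Step 2: placeholder for local training; shift weights by the local data mean."""
    shift = sum(dataset) / len(dataset)
    return [w + shift for w in weights]


def toy_mean(client_weights):
    """Step 3: element-wise, unweighted mean of the uploaded client weights."""
    n = len(client_weights)
    return [sum(ws) / n for ws in zip(*client_weights)]


def toy_server_update(mean_client_weights):
    """Step 4: vanilla FedAvg; the server adopts the mean as its new weights."""
    return list(mean_client_weights)


def toy_next_fn(server_weights, federated_dataset):
    at_clients = toy_broadcast(server_weights, len(federated_dataset))
    client_weights = [
        toy_client_update(data, weights)
        for data, weights in zip(federated_dataset, at_clients)
    ]
    return toy_server_update(toy_mean(client_weights))


# Two clients with different local data.
new_weights = toy_next_fn([0.0, 1.0], [[1.0, 3.0], [4.0]])
```

Here the first client shifts its weights by 2.0 and the second by 4.0, so the server ends the round with the average of `[2.0, 3.0]` and `[4.0, 5.0]`.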

## TensorFlow Block

### Client Update
We will use our tff.learning.Model to do client training in essentially the same way you would train a TensorFlow model. In particular, we will use tf.GradientTape to compute the gradient on batches of data, then apply these gradients using a client_optimizer. We focus only on the trainable weights.

In [9]:
@tf.function
def client_update(model, dataset, server_weights, client_optimizer):
  """Performs training (using the server model weights) on the client's dataset."""
  # Get a reference to the client model's trainable variables.
  client_weights = model.trainable_variables
  # Assign the server weights to the client model.
  tf.nest.map_structure(lambda x, y: x.assign(y),
                        client_weights, server_weights)

  # Use the client_optimizer to update the local model.
  for batch in dataset:
    with tf.GradientTape() as tape:
      # Compute a forward pass on the batch of data
      outputs = model.forward_pass(batch)

    # Compute the corresponding gradient
    grads = tape.gradient(outputs.loss, client_weights)
    grads_and_vars = zip(grads, client_weights)

    # Apply the gradient using a client optimizer.
    client_optimizer.apply_gradients(grads_and_vars)
  
  print(client_weights)

  return client_weights
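
For intuition about what the loop above computes, here is a minimal plain-Python sketch of the same pattern: start from the server weights, then take one gradient step per batch. It assumes a one-dimensional linear model `w * x + b` with mean squared error, which is not the EMNIST model used in this notebook, and the function name `toy_local_sgd` is hypothetical.

```python
def toy_local_sgd(batches, server_weights, learning_rate=0.1):
    """Run one pass of per-batch SGD, starting from the server weights."""
    # Start from the server weights, as client_update does via assign.
    w, b = server_weights
    for xs, ys in batches:
        n = len(xs)
        # Forward pass: per-example prediction errors.
        errors = [w * x + b - y for x, y in zip(xs, ys)]
        # Gradients of the mean squared error with respect to w and b.
        grad_w = 2.0 * sum(e * x for e, x in zip(errors, xs)) / n
        grad_b = 2.0 * sum(errors) / n
        # Plain SGD step, playing the role of client_optimizer.apply_gradients.
        w -= learning_rate * grad_w
        b -= learning_rate * grad_b
    return [w, b]


# Two batches drawn from y = 2x, starting from zero weights.
batches = [([1.0, 2.0], [2.0, 4.0]), ([3.0], [6.0])]
server = [0.0, 0.0]
trained = toy_local_sgd(batches, server)
```

The server weights passed in are left unchanged; only the client's local copy moves, which is the behavior the upload step relies on.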

### Server Update
The server update for FedAvg is simpler than the client update. We will implement "vanilla" federated averaging, in which we simply replace the server model weights by the average of the client model weights. Again, we only focus on the trainable weights.

In [10]:
@tf.function
def server_update(model, mean_client_weights):
  """Updates the server model weights as the average of the client model weights."""
  model_weights = model.trainable_variables
  # Assign the mean client weights to the server model.
  tf.nest.map_structure(lambda x, y: x.assign(y),
                        model_weights, mean_client_weights)
  return model_weights

The snippet could be simplified by simply returning the mean_client_weights. However, more advanced implementations of Federated Averaging use mean_client_weights with more sophisticated techniques, such as momentum or adaptivity.
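
As one possible illustration of such a technique, the sketch below applies server-side momentum in plain Python. It treats the difference between the mean client weights and the current server weights as an update direction. The function name, the `server_lr` and `momentum` parameters, and the exact update rule are assumptions made for this example, not something defined by TFF or by this notebook.

```python
def server_update_with_momentum(server_weights, mean_client_weights, velocity,
                                server_lr=1.0, momentum=0.9):
    """Move the server weights along a momentum-smoothed update direction."""
    # Direction suggested by this round's clients.
    deltas = [m - s for s, m in zip(server_weights, mean_client_weights)]
    # Blend it with the direction accumulated over earlier rounds.
    new_velocity = [momentum * v + d for v, d in zip(velocity, deltas)]
    new_weights = [s + server_lr * v
                   for s, v in zip(server_weights, new_velocity)]
    return new_weights, new_velocity


weights, velocity = [0.0, 0.0], [0.0, 0.0]
weights, velocity = server_update_with_momentum(weights, [1.0, -2.0], velocity)

# With momentum=0 and server_lr=1, the rule reduces to vanilla averaging.
plain_weights, _ = server_update_with_momentum(
    [5.0, 5.0], [1.0, -2.0], [3.0, 3.0], momentum=0.0)
```

With zero initial velocity the first round matches the vanilla update; later rounds differ because the velocity carries information from previous rounds, which means the server state would need to hold it alongside the weights.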


## Building your own Federated Learning Algorithm
Now that we've gotten a glimpse of the Federated Core, we can build our own federated learning algorithm. Remember that above, we defined an initialize_fn and next_fn for our algorithm. The next_fn will make use of the client_update and server_update we defined using pure TensorFlow code.

However, in order to make our algorithm a federated computation, we will need both the next_fn and initialize_fn to each be a tff.federated_computation.

### Creating the initialization computation
The initialize function will be quite simple: We will create a model using model_fn. However, remember that we must separate out our TensorFlow code using tff.tf_computation.

In [11]:
@tff.tf_computation
def server_init():
  model = model_fn()
  return model.trainable_variables

We can then pass this directly into a federated computation using 
`tff.federated_value`.

In [12]:
@tff.federated_computation
def initialize_fn():
  return tff.federated_value(server_init(), tff.SERVER)

### Creating the `next_fn`

We now use our client and server update code to write the actual algorithm. We will first turn our `client_update` into a `tff.tf_computation` that accepts a client dataset and the server weights, and outputs the updated client weights.

We will need the corresponding types to properly decorate our function. Luckily, the type of the server weights can be extracted directly from our model.

In [13]:
whimsy_model = model_fn()
tf_dataset_type = tff.SequenceType(whimsy_model.input_spec)

print(str(tf_dataset_type))
model_weights_type = server_init.type_signature.result
print(model_weights_type)


<float32[?,784],int32[?,1]>*
<float32[784,10],float32[10]>


In [14]:
@tff.tf_computation(tf_dataset_type, model_weights_type)
def client_update_fn(tf_dataset, server_weights):
  model = model_fn()
  client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
  return client_update(model, tf_dataset, server_weights, client_optimizer)

[<tf.Variable 'dense/kernel:0' shape=(784, 10) dtype=float32>, <tf.Variable 'dense/bias:0' shape=(10,) dtype=float32>]


The `tff.tf_computation` version of the server update can be defined in a similar way, using types we've already extracted.

In [15]:
@tff.tf_computation(model_weights_type)
def server_update_fn(mean_client_weights):
  model = model_fn()
  print('hello server')
  return server_update(model, mean_client_weights)

hello server


Last, but not least, we need to create the `tff.federated_computation` that brings this all together. This function will accept two *federated values*, one corresponding to the server weights (with placement `tff.SERVER`), and the other corresponding to the client datasets (with placement `tff.CLIENTS`).

Note that both these types were defined above! We simply need to give them the proper placement using `tff.FederatedType`.

In [16]:
federated_server_type = tff.FederatedType(model_weights_type, tff.SERVER)
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)

Remember the 4 elements of an FL algorithm?

1. A server-to-client broadcast step.
2. A local client update step.
3. A client-to-server upload step.
4. A server update step.

Now that we've built up the above, each part can be compactly represented as a single line of TFF code. This simplicity is why we had to take extra care to specify things such as federated types!

In [17]:
@tff.federated_computation(federated_server_type, federated_dataset_type)
def next_fn(server_weights, federated_dataset):
  # Broadcast the server weights to the clients.
  server_weights_at_client = tff.federated_broadcast(server_weights)

  # Each client computes their updated weights.
  client_weights = tff.federated_map(
      client_update_fn, (federated_dataset, server_weights_at_client))
  
  # The server averages these updates.
  mean_client_weights = tff.federated_mean(client_weights)

  # The server updates its model.
  server_weights = tff.federated_map(server_update_fn, mean_client_weights)

  return server_weights

We now have a tff.federated_computation for both the algorithm initialization, and for running one step of the algorithm. To finish our algorithm, we pass these into tff.templates.IterativeProcess.


In [18]:
federated_algorithm = tff.templates.IterativeProcess(
    initialize_fn=initialize_fn,
    next_fn=next_fn
)
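
As a rough mental model of what the resulting object exposes, the plain-Python sketch below mimics the two-function shape: `initialize()` returns a starting state, and `next(state, data)` returns the state after one round. `ToyIterativeProcess` and the toy functions are hypothetical stand-ins written for this example; they do no type checking and run no federated computation.

```python
from typing import Callable, List

Weights = List[float]


class ToyIterativeProcess:
    """Hypothetical stand-in that only holds an `initialize` and a `next` function."""

    def __init__(self,
                 initialize_fn: Callable[[], Weights],
                 next_fn: Callable[[Weights, List[float]], Weights]):
        self.initialize = initialize_fn
        self.next = next_fn


def toy_initialize() -> Weights:
    """Create the initial server state."""
    return [0.0, 0.0]


def toy_next(state: Weights, client_data: List[float]) -> Weights:
    """Placeholder round: move each weight halfway toward the mean client value."""
    target = sum(client_data) / len(client_data)
    return [w + 0.5 * (target - w) for w in state]


process = ToyIterativeProcess(toy_initialize, toy_next)
state = process.initialize()
for _ in range(3):
    # The state returned by one round is the input to the next round.
    state = process.next(state, [1.0, 3.0])
```

The point of the pattern is that the caller owns the loop and threads the state through it, so the number of rounds and the data passed to each round stay under the caller's control.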

## Evaluating the algorithm

Let's run a few rounds, and see how the loss changes. First, we will define an evaluation function using the centralized approach discussed in the second tutorial.

We first create a centralized evaluation dataset, and then apply the same preprocessing we used for the training data.

In [19]:
central_emnist_test = emnist_test.create_tf_dataset_from_all_clients()
central_emnist_test = preprocess(central_emnist_test)

Next, we write a function that accepts a server state, and uses Keras to evaluate on the test dataset. If you're familiar with `tf.keras`, this will all look familiar, though note the use of `set_weights`!

In [20]:
def evaluate(server_state):
  keras_model = create_keras_model()
  keras_model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]  
  )
  keras_model.set_weights(server_state)
  keras_model.evaluate(central_emnist_test)
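
To clarify what the accuracy metric reported by `evaluate` measures, here is a minimal plain-Python sketch of sparse categorical accuracy: the fraction of examples whose highest-probability class equals the integer label. The function below is a hypothetical re-implementation for illustration, not the Keras metric itself, and the sample probabilities are made up.

```python
def sparse_categorical_accuracy(labels, probabilities):
    """Fraction of examples whose most probable class matches the integer label."""
    correct = 0
    for label, probs in zip(labels, probabilities):
        # Index of the largest probability, i.e. the predicted class.
        predicted = max(range(len(probs)), key=probs.__getitem__)
        if predicted == label:
            correct += 1
    return correct / len(labels)


# Three examples over three classes; the last prediction is wrong.
sample_labels = [0, 2, 1]
sample_probs = [
    [0.7, 0.2, 0.1],
    [0.1, 0.3, 0.6],
    [0.5, 0.4, 0.1],
]
accuracy = sparse_categorical_accuracy(sample_labels, sample_probs)
```

The labels are plain class indices rather than one-hot vectors, which is what "sparse" refers to in both the loss and the metric used above.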

In [21]:
server_state = federated_algorithm.initialize()
evaluate(server_state)



Let's train for a few rounds and see if anything changes.

In [22]:
# Use round_num rather than round to avoid shadowing the built-in round().
for round_num in range(30):
  server_state = federated_algorithm.next(server_state, federated_train_data)

In [23]:
evaluate(server_state)



In [24]:
print(server_state[1])

[-0.00193115 -0.00283541  0.00199257  0.00015633  0.00357123  0.00100648
 -0.00112897  0.00109023  0.00087481 -0.00279613]
