# Building your own federated learning algorithm
## TensorFlow

In [1]:
# Install the nightly TensorFlow Federated build and nest-asyncio into the
# kernel's environment. Neither package is version-pinned (nightly build plus
# --upgrade), so re-running this cell later may install a different release.
%pip install --quiet --upgrade tensorflow-federated-nightly
%pip install --quiet --upgrade nest-asyncio

# NOTE(review): presumably applied so asyncio-based code can run inside the
# notebook's already-running event loop -- confirm it is still required.
import nest_asyncio
nest_asyncio.apply()

In [2]:
import tensorflow as tf
import tensorflow_federated as tff

In [3]:
# Load the federated EMNIST dataset as a (train, test) pair. Per this cell's
# output, the call downloaded `emnist_all.sqlite.lzma`. Later cells use the
# train split's `client_ids` and `create_tf_dataset_for_client`.
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

Downloading emnist_all.sqlite.lzma: 100%|██████████| 170507172/170507172 [00:30<00:00, 5633346.62it/s]


In [5]:
# Number of client ids taken from the sorted training client list.
NUM_CLIENTS = 10
# Number of examples grouped into each batch by `preprocess`.
BATCH_SIZE = 20


def preprocess(dataset):
    """Batch `dataset` and reshape every batch into a (features, label) pair.

    Args:
      dataset: an object exposing `batch(size)` whose result exposes
        `map(fn)`; elements are indexed by the keys 'pixels' and 'label'.

    Returns:
      The result of `dataset.batch(BATCH_SIZE).map(...)`.
    """

    def batch_format_fn(element):
        """Flatten a batch of EMNIST data and return a (features, label) tuple."""
        # Pixels are reshaped to rows of 784 values, labels to a single column.
        features = tf.reshape(element['pixels'], [-1, 784])
        labels = tf.reshape(element['label'], [-1, 1])
        return (features, labels)

    batched = dataset.batch(BATCH_SIZE)
    return batched.map(batch_format_fn)

In [6]:
# Take the first NUM_CLIENTS client ids in sorted order.
client_ids = sorted(emnist_train.client_ids)[:NUM_CLIENTS]

# One preprocessed dataset per selected client, in the same order as client_ids.
federated_train_data = [
    preprocess(emnist_train.create_tf_dataset_for_client(client_id))
    for client_id in client_ids
]

In [7]:
def create_keras_model():
    """Build a Keras Sequential model: 784-wide input -> Dense(10) -> Softmax.

    The dense layer's kernel uses a GlorotNormal initializer seeded with 0.

    Returns:
      A new `tf.keras.Sequential` instance.
    """
    glorot_init = tf.keras.initializers.GlorotNormal(seed=0)
    layer_stack = [
        tf.keras.layers.Input(shape=(784,)),
        tf.keras.layers.Dense(10, kernel_initializer=glorot_init),
        tf.keras.layers.Softmax(),
    ]
    return tf.keras.Sequential(layer_stack)

In [8]:
# Wrap the Keras model as a TFF learning model.
def model_fn():
    """Create a new Keras model and wrap it via `tff.learning.from_keras_model`.

    The input spec is read from the first entry of `federated_train_data`.
    The wrapped model uses sparse categorical cross-entropy as its loss and
    sparse categorical accuracy as its only metric.

    Returns:
      Whatever `tff.learning.from_keras_model` returns for these arguments.
    """
    base_model = create_keras_model()
    element_spec = federated_train_data[0].element_spec
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
    accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy()
    return tff.learning.from_keras_model(
        base_model,
        input_spec=element_spec,
        loss=loss_fn,
        metrics=[accuracy_metric],
    )

In [9]:
# Initialization function: builds a tff.learning model and returns its trainable variables.
def initialize_fn():
    """Return the trainable variables of a model freshly built by `model_fn()`."""
    return model_fn().trainable_variables

In [10]:
# Sketch of `next_fn`, which performs one round of the algorithm.
def next_fn(server_weights, federated_dataset):
    """Sketch of one round: broadcast, client update, average, server update.

    NOTE(review): this is an outline rather than runnable code. `broadcast`
    and `mean` are not defined in the visible part of this notebook, and the
    `client_update` / `server_update` functions defined later take extra
    arguments (`model`, `client_optimizer`) that are not passed here.

    Args:
      server_weights: the current server weights.
      federated_dataset: the data handed to the client update step.

    Returns:
      The value produced by the server update step.
    """
    # Step 1: broadcast the server weights to the clients.
    weights_at_clients = broadcast(server_weights)

    # Step 2: each client computes its updated weights.
    per_client_weights = client_update(federated_dataset, weights_at_clients)

    # Step 3: the server averages the client results.
    averaged_weights = mean(per_client_weights)

    # Step 4: the server updates its model from the average.
    updated_server_weights = server_update(averaged_weights)

    return updated_server_weights

## Client update

In [None]:
@tf.function
def client_update(model, dataset, server_weights, client_optimizer):
  """Trains `model` on a client's `dataset`, starting from `server_weights`.

  Args:
    model: object exposing `trainable_variables` and `forward_pass(batch)`;
      the forward-pass result must expose a `loss` attribute.
    dataset: iterable whose elements are passed one by one to
      `model.forward_pass`.
    server_weights: values assigned onto the model's trainable variables
      before training, paired up via `tf.nest.map_structure`.
    client_optimizer: object exposing `apply_gradients(grads_and_vars)`.

  Returns:
    The model's trainable variables after one pass over `dataset`.
  """
  local_vars = model.trainable_variables

  def _overwrite(variable, value):
    # Replace the variable's current value with the server-provided one.
    return variable.assign(value)

  # Start local training from the server's weights.
  tf.nest.map_structure(_overwrite, local_vars, server_weights)

  # One optimizer step per element of the dataset.
  for minibatch in dataset:
    with tf.GradientTape() as tape:
      # The forward pass runs inside the tape so the loss can be differentiated.
      batch_output = model.forward_pass(minibatch)

    gradients = tape.gradient(batch_output.loss, local_vars)
    client_optimizer.apply_gradients(zip(gradients, local_vars))

  return local_vars

## Server update

In [11]:
@tf.function
def server_update(model, mean_cleint_weights):
    """Overwrites the server model's weights with the averaged client weights.

    Args:
      model: object exposing `trainable_variables`.
      mean_cleint_weights: values assigned onto the model's trainable
        variables, paired up via `tf.nest.map_structure`. The parameter
        name's spelling is kept as-is so existing keyword callers still work.

    Returns:
      The model's trainable variables after the assignment.
    """
    server_vars = model.trainable_variables

    def _overwrite(variable, value):
        # Replace the server variable's value with the client average.
        return variable.assign(value)

    tf.nest.map_structure(_overwrite, server_vars, mean_cleint_weights)
    return server_vars