# Comparison of the FedAvg and FedProx algorithms on a non-IID dataset with 16 clients

## Importing the libraries

In [1]:
import collections
from matplotlib import pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
# Seed NumPy's global RNG up front so any NumPy-based randomness is repeatable.
np.random.seed(seed=0)

2022-10-03 02:41:11.368876: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-10-03 02:41:11.722048: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2022-10-03 02:41:11.722073: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.
2022-10-03 02:41:11.779037: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2022-10-03 02:41:12.654857: W tensorflow/stream_executor/platform/de

## Loading the data

In [34]:
# Federated EMNIST, partitioned by writer. `only_digits=True` (the default)
# selects the 10-class digits variant, which the 10-unit output layer relies on.
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(
    only_digits=True)

In [35]:
# Federated training configuration.
NUM_CLIENTS = 16      # number of clients taken from the EMNIST training split
NUM_EPOCHS = 10       # passes over each client's local data per round
NUM_ROUNDS = 20       # number of federated training rounds
BATCH_SIZE = 20       # examples per local batch
SHUFFLE_BUFFER = 100  # tf.data shuffle buffer size (examples, shuffled pre-batch)
PREFETCH_BUFFER = 10  # tf.data prefetch buffer size (batches, prefetched post-batch)

def preprocess(dataset):
  """Turn one client's raw EMNIST dataset into batched training inputs.

  The dataset is repeated NUM_EPOCHS times, shuffled with a fixed seed,
  grouped into batches of BATCH_SIZE, reformatted into the `x`/`y` structure
  the Keras model consumes, and prefetched.

  Args:
    dataset: a tf.data.Dataset whose elements have 'pixels' and 'label' keys.

  Returns:
    A tf.data.Dataset of OrderedDict(x=..., y=...) batches.
  """

  def _to_model_inputs(batch):
    # Flatten each image to a 784-value vector and make labels a column.
    flat_pixels = tf.reshape(batch['pixels'], [-1, 784])
    labels = tf.reshape(batch['label'], [-1, 1])
    return collections.OrderedDict(x=flat_pixels, y=labels)

  repeated = dataset.repeat(NUM_EPOCHS)
  shuffled = repeated.shuffle(SHUFFLE_BUFFER, seed=1)
  batched = shuffled.batch(BATCH_SIZE)
  return batched.map(_to_model_inputs).prefetch(PREFETCH_BUFFER)

def make_federated_data(client_data, client_ids):
  """Return a list holding one preprocessed dataset per id in `client_ids`.

  Args:
    client_data: a TFF ClientData exposing `create_tf_dataset_for_client`.
    client_ids: iterable of client ids, processed in the given order.
  """
  datasets = []
  for client_id in client_ids:
    raw_dataset = client_data.create_tf_dataset_for_client(client_id)
    datasets.append(preprocess(raw_dataset))
  return datasets

# The federation is the first NUM_CLIENTS clients of the training split.
sample_clients = emnist_train.client_ids[:NUM_CLIENTS]
federated_train_data = make_federated_data(emnist_train, sample_clients)

## Creating the Keras Model

In [36]:
# Preprocess one client's data; `model_fn` below uses its element_spec to tell
# TFF the structure and types of the model inputs.
example_client_id = emnist_train.client_ids[0]
example_dataset = emnist_train.create_tf_dataset_for_client(example_client_id)
preprocessed_example_dataset = preprocess(example_dataset)

In [37]:
# Keras architecture for the shared (server and client) model.
def create_keras_model():
  """Build a fresh, uncompiled single-layer softmax classifier.

  784-value flattened inputs pass through a zero-initialised 10-unit dense
  layer, followed by a softmax that outputs one probability per class.
  """
  layers = [
      tf.keras.layers.InputLayer(input_shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax(),
  ]
  return tf.keras.models.Sequential(layers)

def model_fn():
  """Wrap a newly built Keras model as a TFF model.

  TFF calls this to construct the model, so the Keras model, loss and metric
  objects are created fresh on every call rather than shared.
  """
  model = create_keras_model()
  # from_logits defaults to False, which matches the model's final Softmax.
  loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
  metric_list = [tf.keras.metrics.SparseCategoricalAccuracy()]
  return tff.learning.from_keras_model(
      model,
      input_spec=preprocessed_example_dataset.element_spec,
      loss=loss_fn,
      metrics=metric_list)

## Training the model
An iterative process is a pair of federated computations. `initialize` builds the initial server state, and `next` runs a single round of training.

Here the round is FedProx, a variant of Federated Averaging that adds a proximal term to each client's local objective. Each round:

- pushes the server state (including the model parameters) to the clients;
- trains on each client's local data;
- collects and averages the model updates;
- produces a new model at the server.

In [38]:
# FedProx: each client minimises its local loss plus a proximal term,
# (mu / 2) * ||w - w_server||^2, which keeps local models close to the server
# model. `proximal_strength` is that coefficient mu, not the term itself.
# With proximal_strength=0.0 the algorithm reduces to (weighted) FedAvg.
CLIENT_LEARNING_RATE = 0.02  # local SGD step size on each client
SERVER_LEARNING_RATE = 1.0   # server SGD step size applied to the aggregated update
PROXIMAL_STRENGTH = 2.0      # FedProx mu

iterative_process = tff.learning.algorithms.build_weighted_fed_prox(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(
        learning_rate=CLIENT_LEARNING_RATE),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(
        learning_rate=SERVER_LEARNING_RATE),
    proximal_strength=PROXIMAL_STRENGTH)

# Run NUM_ROUNDS rounds of training, starting from the initial server state.
# Keep the accuracy from every round, not only the last one, so the run can be
# compared with or plotted against other algorithms.
state = iterative_process.initialize()
fedprox_train_accuracy = []
for round_num in range(1, NUM_ROUNDS + 1):
  result = iterative_process.next(state, federated_train_data)
  state = result.state
  metrics = result.metrics
  # NOTE(review): this is the training metric reported by the client-work step,
  # presumably aggregated over the clients' local training in this round. It
  # is not held-out test accuracy.
  fedprox_train_accuracy.append(
      metrics['client_work']['train']['sparse_categorical_accuracy'])

if fedprox_train_accuracy:
  print(fedprox_train_accuracy[-1])
else:
  # Avoid a NameError on `metrics` when no rounds are configured.
  print('NUM_ROUNDS < 1: no training rounds were run.')

0.89252275
