<a href="https://colab.research.google.com/github/harikuts/federated-learning-trials/blob/master/FederatedLearningRepro.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Overview

This notebook reproduces the results of the original paper on federated learning.

## Plan

The roadmap for development is as follows:
*   Construct standard MNIST example.
*   To be continued.




# Standard MNIST Example

A standard MNIST example from Keras (https://keras.io/examples/mnist_cnn/) is used as the baseline against which our federated model is compared.

In [0]:
from __future__ import print_function
import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras import backend as K

import pdb

# Training hyperparameters for the baseline run
batch_size = 128
num_classes = 10
epochs = 12

# Height and width of each input image
img_rows, img_cols = 28, 28

# Load the dataset, already divided into train and test portions
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Put the single channel axis where the active backend expects it,
# then reshape both image arrays to (samples,) + input_shape.
if K.image_data_format() == 'channels_first':
    input_shape = (1, img_rows, img_cols)
else:
    input_shape = (img_rows, img_cols, 1)
x_train = x_train.reshape((x_train.shape[0],) + input_shape)
x_test = x_test.reshape((x_test.shape[0],) + input_shape)

# Cast to float32 and divide by 255
x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# Turn the integer class vectors into binary (one-hot) class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

# Baseline CNN: two 3x3 convolutions, max-pooling, dropout, then a dense
# classifier ending in a softmax over num_classes outputs.
model = Sequential([
    Conv2D(32, kernel_size=(3, 3),
           activation='relu',
           input_shape=input_shape),
    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D(pool_size=(2, 2)),
    Dropout(0.25),
    Flatten(),
    Dense(128, activation='relu'),
    Dropout(0.5),
    Dense(num_classes, activation='softmax'),
])

model.compile(
    loss=keras.losses.categorical_crossentropy,
    optimizer=keras.optimizers.Adadelta(),
    metrics=['accuracy'],
)

# Train on the training set; the test set is passed as validation data.
model.fit(
    x_train,
    y_train,
    batch_size=batch_size,
    epochs=epochs,
    verbose=1,
    validation_data=(x_test, y_test),
)

# evaluate() returns the loss followed by the compiled metrics (accuracy).
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])

# Federated Model

In [10]:
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Flatten
from tensorflow.keras.layers import Conv2D, MaxPooling2D

import random
import pdb

# Configuration
batch_size = 128
num_classes = 10
epochs = 12

# Federated configuration
num_clients = 10        # number of simulated clients (one data shard each)
num_server_rounds = 8   # iterations of the outer (server) loop
num_client_rounds = 2   # iterations of the inner (per-client) loop

# Alternative tfds-based loading path, currently disabled:
# mnist_train = tfds.load(name="mnist", split="train")
# mnist_train = mnist_train.repeat().shuffle(1024).batch(32)
# mnist_train = mnist_train.prefetch(tf.data.experimental.AUTOTUNE)
# mnist_test, info = tfds.load("mnist", split="test", with_info=True)

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# Splitting the dataset for different clients.
# percentageMarkers holds the cumulative end fraction of each client's shard,
# in ascending order, with the last marker fixed at 1.0.
#   nonIID=True  -> random cut points, so shards are contiguous and unequal in
#                   size (no seed is set in this cell, so they differ per run)
#   nonIID=False -> num_clients evenly spaced cut points
nonIID = True
if nonIID:
  percentageMarkers = []
  for i in range(num_clients-1):
    percentageMarkers.append(random.random())
  percentageMarkers.append(1.0)
  percentageMarkers = sorted(percentageMarkers)
else:
  percentageMarkers = [1/num_clients * (n+1) for n in range(num_clients)]

# Convert the fractional markers into end indices for images and labels.
xMarkers = [int(marker * len(x_train)) for marker in percentageMarkers]
yMarkers = [int(marker * len(y_train)) for marker in percentageMarkers]

client_x_trains = []
client_y_trains = []
for j in range(len(percentageMarkers)):
  # Each shard starts where the previous one ended (index 0 for the first).
  x_start = xMarkers[j-1] if j > 0 else 0
  y_start = yMarkers[j-1] if j > 0 else 0
  client_x_trains.append(x_train[x_start:xMarkers[j]])
  # Labels must be sliced from y_train; slicing x_train here would hand every
  # client images in place of labels.
  client_y_trains.append(y_train[y_start:yMarkers[j]])

# Every client must hold exactly one label per image.
assert all(len(cx) == len(cy)
           for cx, cy in zip(client_x_trains, client_y_trains)), \
    "client image/label shard sizes differ"

# Model creation function
def createCNN():
  """Build and compile a fresh CNN classifier.

  The network stacks two 3x3 convolutions (32 and 64 filters), 2x2
  max-pooling, dropout, a 128-unit dense layer, dropout, and a softmax
  output layer whose width is read from the global ``num_classes``.
  No input shape is passed to the first layer.

  Returns:
    A ``Sequential`` model compiled with categorical cross-entropy loss,
    a default-constructed Adadelta optimizer and the accuracy metric.

  NOTE(review): categorical cross-entropy is normally paired with one-hot
  targets -- confirm the labels are encoded that way before training.
  """
  conv_stack = [
      Conv2D(32, kernel_size=(3, 3), activation='relu'),
      Conv2D(64, (3, 3), activation='relu'),
      MaxPooling2D(pool_size=(2, 2)),
      Dropout(0.25),
  ]
  classifier_head = [
      Flatten(),
      Dense(128, activation='relu'),
      Dropout(0.5),
      Dense(num_classes, activation='softmax'),
  ]
  cnn = Sequential(conv_stack + classifier_head)

  cnn.compile(
      loss=tf.keras.losses.categorical_crossentropy,
      optimizer=tf.keras.optimizers.Adadelta(),
      metrics=['accuracy'],
  )
  return cnn

# Global model initialization
# The server-side model; each client's local model below is cloned from it.
global_model = createCNN()

# Server action
# Outer loop: one iteration per server round (num_server_rounds in total).
for server_round in range(num_server_rounds):
  print("SERVER ROUND ", server_round, ":")
  # Clients' actions
  # Clients are visited one after another within each server round.
  for client in range(num_clients):
    print("\tCLIENT ", client, "...")
    # Accept the global model
    # NOTE(review): per the Keras docs, clone_model() instantiates new layers
    # with newly initialized weights rather than copying them -- confirm, and
    # copy the weights explicitly (get_weights/set_weights) once the models
    # are built.
    local_model = tf.keras.models.clone_model(global_model)
    # Placeholder: nothing is appended to this list in this loop yet.
    local_weights = []
    # Per each round
    for client_round in range(num_client_rounds):
      print("\t\tRound ", client_round)
      # Train on the local model
      # NOTE(review): only a clone is created here; this loop contains no
      # training or aggregation call, and the per-client data shards are not
      # used in it.
      round_model = tf.keras.models.clone_model(local_model)




SERVER ROUND  0 :
	CLIENT  0 ...
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor


Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor


Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor


Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor


		Round  0
		Round  1
	CLIENT  1 ...
		Round  0
		Round  1
	CLIENT  2 ...
		Round  0
		Round  1
	CLIENT  3 ...
		Round  0
		Round  1
	CLIENT  4 ...
		Round  0
		Round  1
	CLIENT  5 ...
		Round  0
		Round  1
	CLIENT  6 ...
		Round  0
		Round  1
	CLIENT  7 ...
		Round  0
		Round  1
	CLIENT  8 ...
		Round  0
		Round  1
	CLIENT  9 ...
		Round  0
		Round  1
SERVER ROUND  1 :
	CLIENT  0 ...
		Round  0
		Round  1
	CLIENT  1 ...
		Round  0
		Round  1
	CLIENT  2 ...
		Round  0
		Round  1
	CLIENT  3 ...
		Round  0
		Round  1
	CLIENT  4 ...
		Round  0
		Round  1
	CLIENT  5 ...
		Round  0
		Round  1
	CLIENT  6 ...
		Round  0
		Round  1
	CLIENT  7 ...
		Round  0
		Round  1
	CLIENT  8 ...
		Round  0
		Round  1
	CLIENT  9 ...
		Round  0
		Round  1
SERVER ROUND  2 :
	CLIENT  0 ...
		Round  0
		Round  1
	CLIENT  1 ...
		Round  0
		Round  1
	CLIENT  2 ...
		Round  0
		Round  1
	CLIENT  3 ...
		Round  0
		Round  1
	CLIENT  4 ...
		Round  0
		Round  1
	CLIENT  5 ...
		Round  0
		Round  1
	CLIENT  6 ...
		