This tutorial is based on the paper
"Communication-Efficient Learning of Deep Networks from Decentralized Data",
which introduced the concept of Federated Learning in 2017.

This paper introduces a new decentralized approach to Machine Learning in which many separate devices each learn a local model from their own local data, so those chunks of local data never leave the device. The local models are sent to a global server, which aggregates them to create a global model.

This framework was introduced as a way of training the Google Keyboard predictive text feature across the many Google (Android) smartphones that used the keyboard. The main driving force for the development of this framework was privacy: it makes it possible to train a global model without having to copy user data
from every smartphone.

The federated learning experiment runs like this:

1. A fraction of devices are selected from a population of devices
2. Those devices receive the current "global model", which can be an uninitialized model
3. Those devices run a certain number of local epochs using their data and this global model
4. Each device will have its own local model which will be sent to the global server
5. The global server will aggregate all those local models by means of a simple weighted averaging
6. The updated global model will be distributed again to a new random selection of devices (back to step 1)


The above loop will run for a certain number of global epochs, and then the training will be done and a final global model obtained.

The code in this notebook will demonstrate how this process works using convolutional neural networks and fully connected networks with the MNIST, FMNIST and CIFAR datasets.

In [None]:
# First, importing all the necessary Python modules
import os
import copy
import time
import numpy as np
import torch
import matplotlib
import matplotlib.pyplot as plt

# This module is where the actual local learning of each device will happen
from update import LocalUpdate, test_inference

# This module is defining the CNN's and FC networks that will be used in the tutorial
from models import MLP, CNNMnist, CNNFashion_Mnist, CNNCifar

# Those are utility functions to get the data, average model weights, etc
from utils import get_dataset, average_weights, exp_details


# To fix potential issues with matplotlib
import os    
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

In [None]:
# Experiment configuration: every tunable of the federated run is a class
# attribute of Args, described right above its value
class Args:
    """Hyper-parameters and switches for the federated learning experiment."""

    # Number of global epochs (communication rounds) the experiment runs
    global_epochs = 10

    # Population size: how many local devices may take part in the experiment
    num_users = 100

    # Fraction of the population selected to train in every global epoch
    frac = 0.1

    # Local epochs a selected device runs before sending its local model
    # to the global server
    local_ep = 5

    # Batch size each device uses during its local training epochs
    local_bs = 4

    # Learning rate used by the devices
    lr = 0.1

    # Optimizer used by the devices: "sgd" or "adam"
    optimizer = "sgd"

    # Model family: "mlp" (fully connected) or "cnn" (convolutional)
    model = "cnn"

    # Number of channels of the input data.
    # NOTE(review): the previous note claimed every dataset has 1 channel,
    # yet the dataset selected below is "cifar" - confirm whether the CIFAR
    # model actually reads this value
    num_channels = 1

    # Dataset to train on: "cifar", "mnist" or "fmnist"
    dataset = "cifar"

    # Number of target classes; all datasets in the experiment have 10
    num_classes = 10

    # Data partitioning switch. Non-IID data means that the data held by a
    # device reflects that device's particular usage, so its local dataset is
    # not representative of the whole population distribution.
    # NOTE(review): presumably 1 selects an IID split and 0 a non-IID one -
    # verify against get_dataset
    iid = 1

    # Set to 1 to give every device selected for training an unequal
    # (random) amount of data
    unequal = 1

    # Verbose logging: 1 turns it on, 0 turns it off
    verbose = 1

    # Seed value meant for the PyTorch random generator, so that runs can be
    # reproduced
    seed = 1

    # GPU switch: None keeps the training on the CPU
    gpu = None

In [None]:
# Reference point used to report the total run time
start_time = time.time()

# Instantiates the experiment parameters and prints their summary
args = Args()
exp_details(args)

# Seed the NumPy and PyTorch random generators with the configured seed, so
# that code drawing from them (e.g. the np.random.choice device sampling in
# the training loop) gives the same results on every run
np.random.seed(args.seed)
torch.manual_seed(args.seed)

# This variable is used by Pytorch to decide if it is going to use GPU or not:
# the CPU is used unless a GPU was requested in the parameters
if args.gpu is None:
    device = "cpu"
else:
    device = "cuda:0"

# Load the selected dataset and the device population.
# The device population is the population from which the devices selected for
# training will be sampled; each device has its own local data
train_dataset, test_dataset, device_population = get_dataset(args)

In [None]:
# Builds the selected model; an unsupported model/dataset choice raises a
# ValueError here instead of leaving global_model undefined
if args.model == 'cnn':
    # Convolutional neural network: one architecture per supported dataset
    if args.dataset == 'mnist':
        global_model = CNNMnist(args=args)
    elif args.dataset == 'fmnist':
        global_model = CNNFashion_Mnist(args=args)
    elif args.dataset == 'cifar':
        global_model = CNNCifar(args=args)
    else:
        raise ValueError(
            f"Unsupported dataset for the cnn model: {args.dataset!r}")

elif args.model == 'mlp':
    # Multi-layer perceptron (fully connected network).
    # Its input size is the product of the dimensions of the first
    # training sample
    img_size = train_dataset[0][0].shape
    len_in = 1
    for x in img_size:
        len_in *= x
    # Build the network once, after len_in holds the complete product
    # (building it inside the loop would create throw-away models sized
    # from a partial product)
    global_model = MLP(dim_in=len_in, dim_hidden=64,
                       dim_out=args.num_classes)

else:
    raise ValueError(f"Unrecognized model: {args.model!r}")

# Prints the model architecture on the screen
print(global_model)

In [None]:
# Move the global model onto the selected compute device (GPU or CPU)
global_model.to(device)

# History containers filled during training, one entry per global epoch:
# the average local training loss and the average training accuracy
train_loss = []
train_accuracy = []

In [None]:
# Number of devices taking part in each global round: a fraction of the
# population, but never fewer than one. It does not depend on the round,
# so it is computed once before the loop
m = max(int(args.frac * args.num_users), 1)

# Runs the experiment for the selected number of global epochs (rounds)
for epoch in range(args.global_epochs):

    print(f'Global Training Round : {epoch+1}\n')

    # Sample a fraction of the population (without repetition).
    # Each device will be represented by an integer number (ID) in the
    # "selected_devices" list
    selected_devices = np.random.choice(range(args.num_users), m, replace=False)
    print("Devices selected in global round {}: {}".format(epoch+1,selected_devices))

    # Sets the model to train
    global_model.train()

    # Local weights and local losses are lists that keep track of
    # each selected device model weights and training loss
    local_weights, local_losses = [], []

    # For each selected device, runs the local training.
    # The loop variable is deliberately named "device_id" and not "device":
    # "device" is the module-level name holding the "cpu"/"cuda:0" string and
    # must not be overwritten with an integer ID
    for device_id in selected_devices:

        # Creates an object of the class "LocalUpdate" which will represent
        # each device selected for training
        local_device = LocalUpdate(args=args, dataset=train_dataset,
                                   idxs=device_population[device_id],
                                   device_id=device_id)

        # This function will run the device local training and return the
        # obtained local model. Note that we are passing a full copy of the
        # global model to the device, so the global model itself is not
        # modified by the local training
        device_weights, device_loss = local_device.update_weights(
            model=copy.deepcopy(global_model), global_round=epoch)

        local_weights.append(copy.deepcopy(device_weights))
        local_losses.append(copy.deepcopy(device_loss))

    # This function will take all local models weights and average them
    # to obtain the new global model weights
    global_weights = average_weights(local_weights)

    # Load the obtained weights into the global model structure
    global_model.load_state_dict(global_weights)

    # Calculates the average loss of this global epoch by averaging all the
    # local losses of every selected device
    loss_avg = sum(local_losses) / len(local_losses)

    train_loss.append(loss_avg)

    # Calculate avg training accuracy over all users at every global epoch
    list_acc, list_loss = [], []
    global_model.eval()
    for device_id in range(args.num_users):

        # Creates an object to represent each local device in the whole population
        local_device = LocalUpdate(args=args, dataset=train_dataset,
                                   idxs=device_population[device_id],
                                   device_id=device_id)

        # Calculates the device accuracy and loss in its local data using the global model
        acc, loss = local_device.inference(model=global_model)

        list_acc.append(acc)
        list_loss.append(loss)

    train_accuracy.append(sum(list_acc)/len(list_acc))

    # The reported loss is the mean of the per-round average losses of all
    # rounds run so far; the reported accuracy is the one of the current round
    print(f'\nAvg Training Stats after {epoch+1} global epochs:')
    print('Global Training Loss: {:.3f}'.format(np.mean(np.array(train_loss))))
    print('Global Train Accuracy: {:.2f}%\n'.format(100*train_accuracy[-1]))

In [None]:
# Evaluate the global model obtained after training on test_dataset
test_acc, test_loss = test_inference(args, global_model, test_dataset)

# Report the training accuracy of the last global epoch next to the test
# accuracy, both expressed as percentages
print(f' \n Results after {args.global_epochs} global epochs of training:')
final_train_accuracy = train_accuracy[-1]
print("|---- Avg Train Accuracy: {:.2f}%".format(100*final_train_accuracy))
print("|---- Test Accuracy: {:.2f}%".format(100*test_acc))

# Time elapsed since start_time was recorded
elapsed_seconds = time.time()-start_time
print('\n Total Run Time: {0:0.4f} seconds'.format(elapsed_seconds))

In [None]:
# Plot the loss curve (one point per global epoch) and store it as a PNG

# Make sure the output directory exists, otherwise savefig fails
os.makedirs('save', exist_ok=True)

# Draw on a fresh figure so nothing from a previous plot is mixed in
plt.figure()
plt.title('Training Loss vs Communication rounds (global epochs)')
plt.plot(range(len(train_loss)), train_loss, color='r')
plt.ylabel('Training loss')
plt.xlabel('Communication Rounds')

# Save BEFORE showing: once plt.show() has displayed the figure, the current
# figure may be released, and a later savefig would write an empty image
plt.savefig('save/fed_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]_loss.png'.
            format(args.dataset, args.model, args.global_epochs, args.frac,
                   args.iid, args.local_ep, args.local_bs))
plt.show()

In [None]:
# Plot Average Accuracy vs Global Epochs (one point per global epoch)
# and store it as a PNG

# Make sure the output directory exists, otherwise savefig fails
os.makedirs('save', exist_ok=True)

plt.figure()
plt.title('Average Accuracy vs Global Epochs')
plt.plot(range(len(train_accuracy)), train_accuracy, color='r')
plt.ylabel('Average Accuracy')
plt.xlabel('Communication Rounds')
plt.savefig('save/fed_{}_{}_{}_C[{}]_iid[{}]_E[{}]_B[{}]_acc.png'.
            format(args.dataset, args.model, args.global_epochs, args.frac,
                   args.iid, args.local_ep, args.local_bs))