<a href="https://colab.research.google.com/github/lfbarba/Compatible-Embeddings/blob/master/EMNIST_averaged_delayed_gradients.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
!pip3 install torch torchvision

!pip3 install --upgrade wandb
!wandb login d6f99b98acf9c1a284aa2ba5830f3eca60fde2f0

Requirement already up-to-date: wandb in /usr/local/lib/python3.6/dist-packages (0.8.20)
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[32mSuccessfully logged in to Weights & Biases![0m


In [0]:
import torch
import random
import math 
import gc
import copy
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F
from torch import nn
from torchvision import datasets, transforms
# Init wandb
import wandb

In [0]:
# Mount Google Drive at /gdrive (Colab only; prompts for an authorization code).
from google.colab import drive
drive.mount('/gdrive')
# Make the Drive root the working directory. The dataset cell loads EMNIST with
# root='' and download=False, i.e. relative to the current directory, so the
# dataset files are presumably expected here -- TODO confirm the folder layout.
%cd /gdrive/My\ Drive/

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /gdrive
/gdrive/My Drive


In [0]:
# Preprocessing applied to every image: resize to 28x28, convert to a tensor,
# then normalise the single channel with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# Train and test splits of EMNIST "byclass", read from the current directory
# (download=False, so the files must already be present).
training_dataset, validation_dataset = (
    datasets.EMNIST(root='', split='byclass', train=is_train,
                    download=False, transform=transform)
    for is_train in (True, False)
)

In [0]:
def im_convert(tensor):
  """Turn a normalised channel-first image tensor into a displayable array.

  The tensor is copied to a NumPy array, its axes are reordered from
  (channel, row, column) to (row, column, channel), the 0.5/0.5 normalisation
  is undone and the result is clipped to [0, 1]. Because the scale and offset
  are 3-vectors, a one-channel input is broadcast to three channels.
  """
  half = np.array((0.5, 0.5, 0.5))
  pixels = tensor.cpu().clone().detach().numpy().transpose(1, 2, 0)
  return (pixels * half + half).clip(0, 1)

In [0]:
class CNN(nn.Module):
    """Small convolutional classifier whose layout is chosen by dataset name.

    Supported names:
      * "cifar10" (3 input channels, 32x32 images, 10 classes)
      * "mnist" (1 channel, 28x28, 10 classes)
      * "emnist-byclass" (1 channel, 28x28, 62 classes)
      * "emnist-balanced" / "emnist-bymerge" (1 channel, 28x28, 47 classes)
      * "digits" (10 classes) or "digits<N>" (N classes): a smaller network
        whose first linear layer expects 16 * 4 * 4 features, i.e. 8x8
        single-channel inputs.

    Args:
        dataset: name of the dataset the network is built for.

    Raises:
        ValueError: if `dataset` is not one of the supported names. Previously
            such names produced a module without `net` (or failed with an
            unrelated AttributeError on `num_classes`).
    """

    # Output size for every dataset that shares the three-convolution layout.
    _NUM_CLASSES = {
        "cifar10": 10,
        "mnist": 10,
        "emnist-byclass": 62,
        "emnist-balanced": 47,
        "emnist-bymerge": 47,
    }

    def __init__(self, dataset):
        super(CNN, self).__init__()
        self.dataset = dataset

        if dataset in self._NUM_CLASSES:
            self.num_classes = self._NUM_CLASSES[dataset]

            if dataset == "cifar10":
                side_size = 32
                n_filters = 3
            else:
                n_filters = 1
                side_size = 28

            # Each of the three unpadded 3x3 convolutions trims 2 pixels per side.
            flatten_side_size = side_size - 2 * 3
            self.net = nn.Sequential(
                nn.Conv2d(n_filters, 32, (3, 3)),
                nn.ReLU(),
                nn.Conv2d(32, 64, (3, 3)),
                nn.ReLU(),
                nn.Conv2d(64, 64, (3, 3)),
                nn.ReLU(),
                nn.Flatten(),
                nn.Linear(64 * flatten_side_size * flatten_side_size, 64),
                nn.ReLU(),
                nn.Linear(64, self.num_classes)
            )

        elif dataset.startswith("digits"):
            if dataset == "digits":
                self.num_classes = 10
            else:
                # "digits<N>": the suffix after the 6-character prefix is the class count.
                self.num_classes = int(dataset[6:])
            self.net = nn.Sequential(
                nn.Conv2d(1, 8, (3, 3)),
                nn.ReLU(),
                nn.Conv2d(8, 16, (3, 3)),
                nn.ReLU(),
                nn.Flatten(),
                nn.Linear(16 * 4 * 4, 32),
                nn.ReLU(),
                nn.Linear(32, self.num_classes)
            )

        else:
            raise ValueError(
                "Unsupported dataset {!r}; expected one of {} or 'digits[<N>]'".format(
                    dataset, sorted(self._NUM_CLASSES)))

    def forward(self, x):
        """Return the class scores produced by the network for the batch `x`."""
        return self.net(x)

In [0]:
class Trainer():
  """Trains one shared model on batches drawn from per-worker data loaders.

  Two round types are implemented:
    * 'FedLearn' (runRoundFedLern): gradients of `num_local_steps` batches are
      summed at fixed parameters and a single optimizer step is taken with
      their average.
    * 'DAC' (runRound): every local step takes an optimizer step using the
      correction terms `delta + avg_error` instead of the fresh gradient,
      while the difference `grad - delta` is accumulated for the next round.

  Args:
    splitter: object exposing `data_loaders` (one loader per worker) and
      `full_training_loader`.
    validation_loader: loader used for the periodic validation check.
    num_workers: number of workers, i.e. number of loaders to iterate over.
  """

  def __init__(self, splitter, validation_loader, num_workers):
    self.num_workers = num_workers
    self.splitter = splitter
    self.validation_loader = validation_loader

  def setUpExperiment(self, model, device, config):
    """Configure the run from `config` and execute the whole training loop.

    Reads algorithm, num_rounds_to_run, num_local_steps, sampling_rate,
    with_training_loss, criterion and step_size from `config`, creates an Adam
    optimizer, and then runs rounds while `num_rounds <= num_rounds_to_run`
    (so num_rounds_to_run + 1 rounds). Every second round the model is
    evaluated and the metrics are logged to wandb. Returns early, after a
    printed message, if the algorithm name is not 'FedLearn' or 'DAC'.
    """
    self.config = config
    self.algorithm = self.config.algorithm
    self.max_num_rounds = self.config.num_rounds_to_run
    self.num_local_steps = self.config.num_local_steps
    self.device = device
    self.sampling_rate  = self.config.sampling_rate
    self.with_training_loss = self.config.with_training_loss

    self.training_iterators = [self.getLoaderIter(worker)
                               for worker in range(0, self.num_workers)]

    self.criterion = self.config.criterion
    self.model = model

    wandb.watch(self.model)

    self.step_size = self.config.step_size
    self.optimizer = torch.optim.Adam(self.model.parameters(), lr = self.step_size)
    # self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer, milestones=[2000, 4000], gamma=0.1)

    # Per-parameter correction buffers. zeros_like yields plain tensors that
    # are detached from autograd; building them as `param * 0` would keep them
    # (and everything later derived from them) attached to a growing graph.
    self.delta = [torch.zeros_like(param) for param in self.model.parameters()]
    self.sum_error = [torch.zeros_like(param) for param in self.model.parameters()]
    self.avg_error = [torch.zeros_like(param) for param in self.model.parameters()]

    self.num_rounds = 0

    while self.num_rounds <= self.max_num_rounds:
      if self.algorithm == 'FedLearn':
        self.runRoundFedLern()
      elif self.algorithm == 'DAC':
        self.runRound()
      else:
        print('No valid algorithm provided')
        return
      self.num_rounds += 1
      if self.num_rounds % 2 == 0:
        print('Num_local_steps : ', self.num_local_steps, ', rounds communication : ', self.num_rounds)
        if self.with_training_loss:
          # Check training data
          avg_loss, avg_acc = self.checkModel('Training', self.splitter.full_training_loader)
          wandb.log({"Training Accuracy": avg_acc, "Training Loss": avg_loss}, step = self.num_rounds)
        # Check validation data
        avg_loss, avg_acc = self.checkModel('Validation', self.validation_loader)
        wandb.log({"Test Accuracy": avg_acc, "Test Loss": avg_loss}, step = self.num_rounds)

  def getLoaderIter(self, worker):
    """Return a fresh iterator over the data loader of `worker`."""
    return iter(self.splitter.data_loaders[worker])

  def getInputsLabels(self, training_list):
    """Fetch one batch from each listed worker and concatenate them.

    A worker whose iterator is exhausted gets a fresh iterator and contributes
    its first batch. Only StopIteration triggers the restart; any other error
    raised while loading propagates to the caller.

    Returns:
      (inputs, labels): batches concatenated along dimension 0, on self.device.
    """
    inputs = []
    labels = []
    for worker in training_list:
      try:
        # Built-in next(): loader iterators are not guaranteed to have a
        # Python-2 style `.next()` method.
        batch_inputs, batch_labels = next(self.training_iterators[worker])
      except StopIteration:
        self.training_iterators[worker] = self.getLoaderIter(worker)
        batch_inputs, batch_labels = next(self.training_iterators[worker])

      inputs.append(batch_inputs)
      labels.append(batch_labels)

    joined_inputs = torch.cat(inputs, 0).to(self.device)
    joined_labels = torch.cat(labels, 0).to(self.device)
    return joined_inputs, joined_labels

  def computeJoinedGradient(self):
    """Backpropagate the loss of one joint batch from a random worker subset.

    ceil(sampling_rate * num_workers) workers are picked at random, their
    batches are joined, and after the call every `param.grad` holds the
    gradient of the criterion on that joint batch.
    """
    training_list = list(np.arange(0, self.config.num_workers))
    random.shuffle(training_list)
    training_list = training_list[0:math.ceil(self.config.sampling_rate * self.config.num_workers)]

    joined_inputs, joined_labels = self.getInputsLabels(training_list)

    outputs = self.model(joined_inputs)
    loss = self.criterion(outputs, joined_labels)

    self.optimizer.zero_grad()
    loss.backward()

  def runRoundFedLern(self):
    """Run one 'FedLearn' round.

    Several batches are evaluated at the same parameters and a single
    optimizer step is applied with the average of their gradients.
    """
    for i, param in enumerate(self.model.parameters()):
      self.sum_error[i] = torch.zeros_like(self.sum_error[i])

    for current_local_step in range(0, self.num_local_steps):
      # Computes the average joined gradient of all the sampled workers
      self.computeJoinedGradient()
      for i, param in enumerate(self.model.parameters()):
        self.sum_error[i] += param.grad
        # This line is only useful in the last step, when the gradient is set
        # to be the average of the joined gradients
        param.grad = self.sum_error[i]/self.num_local_steps

    # After going through all the local steps, take just one optimizer step
    self.optimizer.step()

  def runRound(self):
    """Run one 'DAC' round.

    Before the round, avg_error becomes the previous round's sum_error divided
    by num_local_steps, delta is increased by avg_error, and sum_error is
    reset. Each local step then accumulates `grad - delta` into sum_error and
    steps the optimizer with `delta + avg_error` as the gradient.
    """
    # Update the error and sum before starting a new round
    for i, param in enumerate(self.model.parameters()):
      self.avg_error[i] = self.sum_error[i]/self.num_local_steps
      self.delta[i] = self.delta[i] + self.avg_error[i]
      self.sum_error[i] = torch.zeros_like(self.sum_error[i])

    for current_local_step in range(0, self.num_local_steps):
      self.computeJoinedGradient()
      for i, param in enumerate(self.model.parameters()):
        self.sum_error[i] += param.grad - self.delta[i]
        param.grad = self.delta[i] + self.avg_error[i]

      self.optimizer.step()

  def checkModel(self, data_label, loader):
    """Evaluate the model on `loader` without tracking gradients.

    Args:
      data_label: text used in the printed summary (e.g. 'Validation').
      loader: iterable of (inputs, labels) batches.

    Returns:
      (avg_loss, avg_acc): avg_loss is the mean of the per-batch criterion
      values (a float). avg_acc is a 0-dim tensor holding the percentage of
      correctly classified samples, computed over the samples actually seen.
      It was previously the correct count divided by the number of batches,
      which only resembled a percentage for batches of 100.

    Raises:
      ValueError: if `loader` yields no samples.
    """
    sum_loss = 0.0
    sum_corrects = 0
    num_batches = 0
    num_samples = 0
    with torch.no_grad():
      for inputs, labels in loader:
        batch_inputs = inputs.to(self.device)
        batch_labels = labels.to(self.device)
        outputs = self.model(batch_inputs)
        loss = self.criterion(outputs, batch_labels)

        _, preds = torch.max(outputs, 1)
        sum_loss += loss.item()
        sum_corrects += torch.sum(preds == batch_labels.data)
        num_batches += 1
        num_samples += batch_labels.size(0)

    if num_samples == 0:
      raise ValueError('checkModel received an empty loader')

    avg_loss = sum_loss/num_batches
    avg_acc = 100.0 * sum_corrects.float()/num_samples
    print('{} loss: {:.4f}, {} acc {:.4f} '.format(data_label, avg_loss, data_label, avg_acc.item()))

    return avg_loss, avg_acc

In [0]:
class Partition(object):
    """Read-only view of a dataset restricted to a chosen list of entries.

    `index` is a sequence whose items carry the position in `data` as their
    first element (for example `(dataset_index, label)` tuples).
    """

    def __init__(self, data, index):
        self.data, self.index = data, index

    def __len__(self):
        # The view is as long as the list of selected entries.
        return len(self.index)

    def __getitem__(self, index):
        # Translate the local position into a position in the wrapped dataset.
        entry = self.index[index]
        return self.data[entry[0]]

In [0]:
class SplitLoader():
  """Splits a dataset into equally sized per-worker partitions with loaders.

  The (index, label) pairs of the dataset are either shuffled (iid_data=True)
  or sorted by label (iid_data=False) and then cut into `num_workers`
  consecutive slices of floor(len(dataset) / num_workers) entries each;
  leftover entries at the end are not assigned to any worker.

  Args:
    training_dataset: indexable dataset whose items are (input, label) pairs.
    num_workers: number of partitions / loaders to create.
    batch_size: batch size of every per-worker loader.
    iid_data: shuffle the entries (True) or sort them by label (False).

  Attributes set: batch_size, shuffle, num_workers, full_data, size,
  indices_labels, indices_per_worker, full_training_loader, data_loaders.
  """

  def __init__(self, training_dataset, num_workers, batch_size = 100, iid_data = False):
    # split the data
    self.batch_size = batch_size
    self.shuffle = iid_data
    self.num_workers = num_workers
    self.full_data = training_dataset
    self.size = math.floor(len(self.full_data)/self.num_workers)
    # sort the data or shuffle it
    self.indices_labels = []
    self.sorting()

    self.indices_per_worker = [
        self.indices_labels[i*self.size:i*self.size+self.size]
        for i in range(0, self.num_workers)
    ]
    self.setWorkerLoaders()

  def setWorkerLoaders(self):
    """Create the loader over the full dataset and one shuffled loader per worker."""
    # The full-data loader always uses batches of 100, independent of batch_size.
    self.full_training_loader = torch.utils.data.DataLoader(self.full_data, 100, shuffle=True)
    self.data_loaders = []
    for worker in range(0, self.num_workers):
      partition = Partition(self.full_data, self.indices_per_worker[worker])
      self.data_loaders.append(torch.utils.data.DataLoader(partition, self.batch_size, shuffle=True))

  def takeLabel(self, elem):
    """Sort key: the label stored in an (index, label) pair."""
    return elem[1]

  def _readLabels(self):
    """Return the label of every dataset entry, in dataset order.

    When the dataset exposes a `targets` sequence of matching length and has
    no `target_transform`, labels are read from it directly, which avoids
    loading and transforming every sample just to look at its label. This
    assumes `targets[i]` equals the label returned by `dataset[i]`, as is the
    case for torchvision's built-in datasets -- verify for custom datasets.
    Otherwise every item is fetched and its second element is used.
    """
    targets = getattr(self.full_data, 'targets', None)
    has_target_transform = getattr(self.full_data, 'target_transform', None) is not None
    if targets is not None and not has_target_transform and len(targets) == len(self.full_data):
      if hasattr(targets, 'tolist'):
        # Tensors / arrays: convert to plain Python numbers.
        return targets.tolist()
      return list(targets)
    return [data[1] for data in self.full_data]

  def sorting(self):
    """Fill `indices_labels` with (index, label) pairs, shuffled or sorted by label."""
    self.indices_labels.extend(enumerate(self._readLabels()))
    if self.shuffle:
      random.shuffle(self.indices_labels)
    else:
      self.indices_labels.sort(key=self.takeLabel)

In [0]:
class Configuration():
  """Attribute-style access to the entries of an experiment settings mapping.

  The mapping itself is kept as `self.dict`; each required key is also copied
  to an attribute of the same name. A missing key raises KeyError.
  """

  def __init__(self, dict):
    self.dict = dict
    required_keys = (
        'experiment_id',
        'dataset',
        'num_workers',
        'worker_batch_size',
        'validation_batch_size',
        'iid_data',
        'num_local_steps',
        'step_size',
        'sampling_rate',
        'with_training_loss',
        'num_rounds_to_run',
        'algorithm',
        'criterion',
    )
    for key in required_keys:
      setattr(self, key, dict[key])

In [0]:
def runExperiment(validation_loader, splitter, config):
  """Start a wandb run, build the trainer and model, and train.

  The wandb project name is `config.experiment_id` and the raw settings
  mapping `config.dict` is stored as the run configuration. The model is
  placed on the first CUDA device when one is available, otherwise on the CPU.
  """
  # Create the logger experiment in wandb
  wandb.init(project = config.experiment_id, config = config.dict)

  # Disabled experiment: scale the DAC step size down by
  # 1.2 * sqrt(num_local_steps) before training.
  device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
  device = torch.device(device_name)
  trainer = Trainer(splitter, validation_loader, config.num_workers)
  model = CNN(config.dataset).to(device)

  summary = (
      "num_local_steps ", config.num_local_steps,
      " num_workers ", config.num_workers,
      " batch_size per worker ", config.worker_batch_size,
      " sampling_rate ", config.sampling_rate,
      ' step_size ', config.step_size,
  )
  print(*summary)

  # setUpExperiment also runs the complete training loop.
  trainer.setUpExperiment(model, device, config)
  # Disabled: wandb.save('model.h5') to upload the trained model.

In [0]:
# Experiment settings; each key is written exactly once.
config_dict = dict(
    dataset='emnist-byclass',
    experiment_id='delayed_avg_correction',
    num_workers=698,
    worker_batch_size=25,
    validation_batch_size=100,
    iid_data=False,
    num_local_steps=1,
    step_size=0.001,
    sampling_rate=1.0/35,
    with_training_loss=False,
    num_rounds_to_run=300,
    algorithm='FedLearn',  # 'DAC' or 'FedLearn'
    criterion=nn.CrossEntropyLoss(),
)

config = Configuration(config_dict)

# Validation batches are served in dataset order.
validation_loader = torch.utils.data.DataLoader(
    validation_dataset,
    batch_size=config.validation_batch_size,
    shuffle=False,
)
# One data loader per worker over the training split.
splitter = SplitLoader(
    training_dataset,
    config.num_workers,
    config.worker_batch_size,
    iid_data=config.iid_data,
)
runExperiment(validation_loader, splitter, config)

num_local_steps  1  num_workers  698  batch_size per worker  25  sampling_rate  0.02857142857142857  step_size  0.001
Num_local_steps :  1 , rounds communication :  2
Validation loss: 4.0469, Validation acc 1.9115 
Num_local_steps :  1 , rounds communication :  4
Validation loss: 3.9286, Validation acc 8.2483 
Num_local_steps :  1 , rounds communication :  6
Validation loss: 3.8858, Validation acc 5.4055 
Num_local_steps :  1 , rounds communication :  8
Validation loss: 3.7104, Validation acc 8.0481 
Num_local_steps :  1 , rounds communication :  10
Validation loss: 3.6207, Validation acc 14.6443 
Num_local_steps :  1 , rounds communication :  12
Validation loss: 3.3897, Validation acc 21.0146 
Num_local_steps :  1 , rounds communication :  14
Validation loss: 3.3191, Validation acc 23.7311 
Num_local_steps :  1 , rounds communication :  16
Validation loss: 3.0546, Validation acc 31.9278 
Num_local_steps :  1 , rounds communication :  18
Validation loss: 2.9928, Validation acc 36.2070 

In [0]:
# for images, labels in splitter.data_loaders[0]:
#   break
# fig = plt.figure(figsize=(25, 4))

# for idx in np.arange(20):
#   ax = fig.add_subplot(2, 10, idx+1, xticks=[], yticks=[])
#   plt.imshow(im_convert(images[idx]))
#   ax.set_title(labels[idx].item())