# Baseline implementation

In [1]:
%pip install wandb --quiet

In [2]:
!wget http://storage.googleapis.com/gresearch/federated-vision-datasets/cifar10.zip
!unzip cifar10.zip

--2022-01-10 13:42:27--  http://storage.googleapis.com/gresearch/federated-vision-datasets/cifar10.zip
Resolving storage.googleapis.com (storage.googleapis.com)... 142.251.6.128, 142.250.159.128, 74.125.70.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|142.251.6.128|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1627997 (1.6M) [application/zip]
Saving to: ‘cifar10.zip.1’


2022-01-10 13:42:27 (102 MB/s) - ‘cifar10.zip.1’ saved [1627997/1627997]

Archive:  cifar10.zip
replace cifar10/federated_train_alpha_0.00.csv? [y]es, [n]o, [A]ll, [N]one, [r]ename: A
  inflating: cifar10/federated_train_alpha_0.00.csv  
  inflating: cifar10/test.csv        
  inflating: cifar10/federated_train_alpha_10.00.csv  
  inflating: cifar10/federated_train_alpha_0.05.csv  
  inflating: cifar10/federated_train_alpha_100.00.csv  
  inflating: cifar10/federated_train_alpha_0.10.csv  
  inflating: cifar10/federated_train_alpha_0.20.csv  
  inflating: cifar10/fede

In [3]:
import wandb

# Open the Weights & Biases run that the later wandb.config / wandb.watch /
# wandb.log / wandb.save calls in this notebook report to.
# NOTE(review): mode="disabled" presumably turns all wandb calls into no-ops;
# confirm, and change it if the metrics should actually be uploaded.
wandb.init(project="step-2", entity="aml-federated-learning", mode="disabled")



In [4]:
# Number of passes each selected client makes over its own data per round.
E = 1
# STEP_SIZE and GAMMA are only recorded in the W&B config below; no other
# visible cell reads them.
STEP_SIZE = 5
GAMMA = 0.1

# Timing reference: K = 1, NUMBER_OF_CLIENTS = 2, MAX_TIME = 3 -> 58 sec
K = 5 # clients selected in every round
NUMBER_OF_CLIENTS = 100 # total number of simulated clients
MAX_TIME = 500 # number of rounds run by the training loop

# Mini-batch size used by every client's train and validation DataLoader.
batch_size = 50

# Client learning rate for the first round; the training loop multiplies it by
# LR_DECAY after each round.
lr = 0.25

# "iid" splits the training set randomly into equal parts; any other value
# makes the data cell read the per-user split from the cifar10/ CSV files.
DATA_DISTRIBUTION = "iid" # "iid" | "non-iid"
# Selects cifar10/federated_train_alpha_<value>.csv when the split is not iid.
DIRICHELET_ALPHA = 0.00 # 0.00, 0.05, 0.10, 0.20, 0.50, 1.00, 10.00, 100.0

# Switches the server update in the training loop to the momentum variant.
FED_AVG_M = False # False | True -> False = Fed_Avg
if FED_AVG_M:
    # Factor applied to the previous aggregated update before adding the new one.
    FED_AVG_M_BETA = 0.9  
    # Factor applied to the momentum-corrected update when it is subtracted
    # from the server weights.
    FED_AVG_M_GAMMA = 1

# Per-round multiplicative decay of lr.
LR_DECAY = 0.99

# Device the server model, the client models and every batch are moved to.
device = "cuda" # "cpu" | "cuda"

#assert(DATA_DISTRIBUTION == "iid" or NUMBER_OF_CLIENTS == 100)

# Record the run's hyper-parameters in the W&B run config.
wandb.config.update({
    "batch-size": batch_size,
    "learning-rate": lr,
    # "momentum": MOMENTUM,
    # "weight_decay": WEIGHT_DECAY,
    "num_epochs": E,
    "step_size": STEP_SIZE,
    "gamma": GAMMA,
    "K": K,
    "number_of_clients": NUMBER_OF_CLIENTS,
    "max_time": MAX_TIME,
    "data_distribution": DATA_DISTRIBUTION,
    "dirichelet_alpha": DIRICHELET_ALPHA
})

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# From: https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html
class Net(nn.Module):
    """Small convolutional classifier: 3-channel image in, 10 logits out.

    Two 5x5 convolutions with 64 filters each, both followed by ReLU and 2x2
    max-pooling, feed three fully connected layers (1600 -> 384 -> 192 -> 10).
    The 1600 input features of ``fc1`` correspond to a 64 x 5 x 5 feature map,
    which is what e.g. a 32x32 input yields after the two conv/pool stages.

    Adapted from:
    https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: 3 input channels -> 64 -> 64, 5x5 kernels.
        self.conv1 = nn.Conv2d(3, 64, 5)
        self.conv2 = nn.Conv2d(64, 64, 5)
        # Classifier head: affine maps y = Wx + b.
        self.fc1 = nn.Linear(1600, 384)
        self.fc2 = nn.Linear(384, 192)
        self.fc3 = nn.Linear(192, 10)

    def forward(self, x):
        """Return the unnormalised class scores for a batch of images."""
        # conv -> ReLU -> 2x2 max-pool, twice.
        features = F.max_pool2d(F.relu(self.conv1(x)), 2)
        features = F.max_pool2d(F.relu(self.conv2(features)), 2)
        # Collapse everything but the batch dimension.
        flat = features.flatten(start_dim=1)
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)


# Global (server) model: built once, placed on the configured device, and its
# layer summary printed for reference.
net = Net().to(device)
print(net)

Net(
  (conv1): Conv2d(3, 64, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=1600, out_features=384, bias=True)
  (fc2): Linear(in_features=384, out_features=192, bias=True)
  (fc3): Linear(in_features=192, out_features=10, bias=True)
)


In [10]:
import torch.optim as optim

class Client():
  """One simulated federated-learning participant.

  A client owns a private training set, a private validation set and its own
  copy of ``Net``. The server hands it the current global weights; the client
  trains locally and reports back how far its weights moved.
  """

  def __init__(self, i, train_set, validation_set):
    """Build the data loaders, local model, optimizer and loss.

    Args:
      i: client identifier, used to name the per-client loss in wandb.
      train_set: dataset used for local training.
      validation_set: dataset used by ``compute_accuracy``.
    """
    self.i = i
    self.train_set = train_set
    # Keep the attribute equal to the batch size the loaders really use (the
    # notebook-level ``batch_size``); it used to be a stale hard-coded 32.
    self.batch_size = batch_size
    # NOTE(review): the training loader is not shuffled, so every local epoch
    # visits the samples in the same order - confirm this is intended.
    self.train_loader = torch.utils.data.DataLoader(train_set, batch_size=self.batch_size,
                                         shuffle=False, num_workers=0)
    self.validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=self.batch_size,
                                         shuffle=False, num_workers=0)
    self.net = Net()
    self.net = self.net.to(device)
    # Plain SGD; the learning rate is overwritten on every clientUpdate call.
    self.optimizer = optim.SGD(self.net.parameters(), lr=lr)
    self.criterion = nn.CrossEntropyLoss()
    wandb.watch(self.net, criterion=self.criterion, log_freq=100, log_graph=False)

  def clientUpdate(self, lr, parameters):
    """Train locally for ``E`` epochs starting from the given weights.

    Args:
      lr: learning rate to use for this round.
      parameters: state dict of the global model to start from.

    Returns:
      dict mapping each state-dict key to ``parameters[key] - local[key]``,
      i.e. the global weights minus the locally trained weights.
    """
    self.net.load_state_dict(parameters)
    for group in self.optimizer.param_groups:
      group['lr'] = lr

    for _ in range(E):
      for images, labels in self.train_loader:
        images = images.to(device)
        labels = labels.to(device)
        self.optimizer.zero_grad()   # zero the gradient buffers
        output = self.net(images)
        loss = self.criterion(output, labels)
        loss.backward()
        wandb.log({f"client-loss-{self.i}": loss.item()})
        self.optimizer.step()    # apply the SGD update

    # Pair the tensors by key rather than by iteration order.
    local_state = self.net.state_dict()
    return {name: parameters[name] - local_state[name] for name in parameters}

  def compute_accuracy(self, parameters):
    """Return the fraction of validation samples classified correctly.

    Args:
      parameters: state dict to load into the local model before evaluating.

    Raises:
      ZeroDivisionError: if the validation loader yields no samples.
    """
    self.net.load_state_dict(parameters)

    running_corrects = 0
    n = 0
    # Evaluation needs no gradients; skipping autograd avoids building graphs.
    with torch.no_grad():
      for data, labels in self.validation_loader:
        data = data.to(device)
        labels = labels.to(device)

        outputs = self.net(data)
        _, preds = torch.max(outputs, 1)

        running_corrects += torch.sum(preds == labels).item()
        n += len(preds)

    return running_corrects / n


In [11]:
from collections import defaultdict

def parse_csv(filename):
  """Read a federated split file and group image ids by user.

  Every data line must hold exactly three comma-separated integers; the first
  is used as the user id, the second as the image id and the third is ignored.
  Lines that do not start with a digit (e.g. a header or a blank line) are
  skipped.

  Args:
    filename: path of the CSV file to read.

  Returns:
    defaultdict(list) mapping user id -> list of image ids, in file order.

  Raises:
    ValueError: if a data line does not contain exactly three integers.
  """
  # defaultdict(list) instead of a lambda factory keeps the result picklable.
  splits = defaultdict(list)
  with open(filename) as f:
    for line in f:
      # line[:1] is safe even for an empty string.
      if not line[:1].isdigit():
        continue

      user_id, image_id, _ = (int(token) for token in line.split(","))
      splits[user_id].append(image_id)

  return splits


In [12]:
import torchvision
import torchvision.transforms as transforms
import numpy as np
import random
from statistics import mean

from tqdm.notebook import tqdm, trange

# Seeds Python's RNG, which selectClients uses to pick the clients of a round.
random.seed(42)
# random_split draws from torch's RNG, which random.seed() does not touch, so
# give it its own seeded generator to make the client partitions repeatable.
split_generator = torch.Generator().manual_seed(42)

# Convert images to tensors and map each channel from [0, 1] to [-1, 1].
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform = transform)


if DATA_DISTRIBUTION == "iid":
  # Drop the remainder so that every client receives the same number of samples.
  trainset_len = ( len(trainset) // NUMBER_OF_CLIENTS ) * NUMBER_OF_CLIENTS
  print(trainset_len)
  trainset = torch.utils.data.Subset(trainset, list(range(trainset_len)))

  # The builtin int replaces the np.int alias, which NumPy 1.24 removed.
  lengths = np.full(NUMBER_OF_CLIENTS, len(trainset) // NUMBER_OF_CLIENTS, dtype=int)
  print(lengths)
  # random_split gets plain Python ints rather than NumPy scalars.
  trainsets = torch.utils.data.random_split(dataset=trainset, lengths=lengths.tolist(),
                                            generator=split_generator)
else:
  # One Subset per user listed in the pre-computed split file, in file order.
  dirichelet_splits = parse_csv(f"cifar10/federated_train_alpha_{DIRICHELET_ALPHA:.2f}.csv")
  trainsets = [torch.utils.data.Subset(trainset, indices) for indices in dirichelet_splits.values()]


# Split the test set into equally sized per-client validation sets.
testset_len = ( len(testset) // NUMBER_OF_CLIENTS ) * NUMBER_OF_CLIENTS
print(testset_len)
testset = torch.utils.data.Subset(testset, list(range(testset_len)))

lengths = np.full(NUMBER_OF_CLIENTS, len(testset) // NUMBER_OF_CLIENTS, dtype=int)
testsets = torch.utils.data.random_split(dataset=testset, lengths=lengths.tolist(),
                                         generator=split_generator)


clientsSizes = torch.zeros(NUMBER_OF_CLIENTS)
# Filled with one Client per partition further down in this cell.
clients = list()

def selectClients(k):
  """Pick the clients that take part in one round.

  Clients are drawn without replacement, so a client cannot appear twice in
  the same round (a duplicate would repeat the same local training and count
  its update twice in the average).

  Args:
    k: number of clients to select.

  Returns:
    list of ``k`` entries taken from the module-level ``clients`` list.
  """
  if k > len(clients):
    # More clients requested than exist: duplicates are unavoidable, so keep
    # the previous sampling-with-replacement behaviour instead of failing.
    return random.choices(clients, k=k)
  return random.sample(clients, k)

def aggregateClient(deltaThetas):
  """Average the parameter deltas reported by the selected clients.

  Every client contributes with the same weight ``1 / len(deltaThetas)``.
  The previous weight, ``len(trainsets[i]) / (len(trainsets[i]) * K)``, always
  reduced to ``1 / K`` and indexed ``trainsets`` by position in the list rather
  than by client, so the dataset sizes never influenced the result.

  Args:
    deltaThetas: list of dicts, one per client, each mapping a state-dict key
      to that client's delta for the key. All dicts must share the same keys.

  Returns:
    dict mapping each key to the mean delta, or ``None`` if ``deltaThetas``
    is empty. The input dicts and their values are left untouched.
  """
  if not deltaThetas:
    return None

  ratio = 1 / len(deltaThetas)

  # Start from a scaled copy of the first delta, then accumulate the rest.
  parameters = {k: ratio * v for k, v in deltaThetas[0].items()}
  for d in deltaThetas[1:]:
    for k, v in d.items():
      parameters[k] += ratio * v

  return parameters

# One Client per partition; client c trains on trainsets[c] and is evaluated
# on testsets[c].
for c in range(NUMBER_OF_CLIENTS):
  clients.append(Client(c, trainsets[c], testsets[c]))

if FED_AVG_M:
  # Server-side momentum buffer: state-dict key -> previous momentum term.
  old_parameters = {}

for step in tqdm(range(MAX_TIME)):
  selected_clients = selectClients(K)

  # Every selected client trains from the current global weights and reports
  # (global weights - locally trained weights).
  deltaThetas = [c.clientUpdate(lr, net.state_dict()) for c in selected_clients]

  g = aggregateClient(deltaThetas)

  parameters = {}
  for k1, v1 in net.state_dict().items():
    v2 = g[k1]  # look the averaged delta up by key instead of relying on order

    if FED_AVG_M:
      # Momentum term: beta * previous term + current averaged delta; in the
      # first round there is no previous term and it is just the delta.
      # (This branch used to test the misspelled name `old_paramters` and
      # raised NameError as soon as FED_AVG_M was enabled.)
      previous = old_parameters.get(k1)
      momentum = v2 if previous is None else FED_AVG_M_BETA * previous + v2
      old_parameters[k1] = momentum
      parameters[k1] = v1 - FED_AVG_M_GAMMA * momentum
    else:
      parameters[k1] = v1 - v2 # todo: add server learning rate gamma

  net.load_state_dict(parameters)

  # Decay the client learning rate used in the next round.
  lr *= LR_DECAY

  # Periodically report the global model's mean accuracy over all clients.
  if step % 100 == 0:
    model_parameters = net.state_dict()
    avg_accuracy = mean(client.compute_accuracy(model_parameters) for client in clients)

    print(f"Average accuracy after {step} rounds is {avg_accuracy}")

Files already downloaded and verified
Files already downloaded and verified
50000
[500 500 500 500 500 500 500 500 500 500 500 500 500 500 500 500 500 500
 500 500 500 500 500 500 500 500 500 500 500 500 500 500 500 500 500 500
 500 500 500 500 500 500 500 500 500 500 500 500 500 500 500 500 500 500
 500 500 500 500 500 500 500 500 500 500 500 500 500 500 500 500 500 500
 500 500 500 500 500 500 500 500 500 500 500 500 500 500 500 500 500 500
 500 500 500 500 500 500 500 500 500 500]
10000


  0%|          | 0/500 [00:00<?, ?it/s]

Average accuracy after 0 rounds is 0.1537
Average accuracy after 5 rounds is 0.2195
Average accuracy after 10 rounds is 0.3387
Average accuracy after 15 rounds is 0.3718
Average accuracy after 20 rounds is 0.4101
Average accuracy after 25 rounds is 0.4469
Average accuracy after 30 rounds is 0.4644
Average accuracy after 35 rounds is 0.4817
Average accuracy after 40 rounds is 0.4897
Average accuracy after 45 rounds is 0.5038
Average accuracy after 50 rounds is 0.5169
Average accuracy after 55 rounds is 0.5318
Average accuracy after 60 rounds is 0.5358
Average accuracy after 65 rounds is 0.5529000000000001
Average accuracy after 70 rounds is 0.5624
Average accuracy after 75 rounds is 0.5624
Average accuracy after 80 rounds is 0.583
Average accuracy after 85 rounds is 0.5866
Average accuracy after 90 rounds is 0.5901
Average accuracy after 95 rounds is 0.5947
Average accuracy after 100 rounds is 0.6011
Average accuracy after 105 rounds is 0.6102
Average accuracy after 110 rounds is 0.6167

In [None]:
from collections import Counter

# Label histogram of the first three clients' training sets, to eyeball how
# (un)balanced the partitions are.
for idx in (0, 1, 2):
  print(Counter(label for _, label in trainsets[idx]))

In [13]:
from statistics import mean

# Final evaluation: score the global model on every client's validation set
# and report the unweighted mean of the per-client accuracies.
model_parameters = net.state_dict()
per_client_accuracy = [member.compute_accuracy(model_parameters) for member in clients]
avg_accuracy = mean(per_client_accuracy)

print(f"Average accuracy after {MAX_TIME} rounds is {avg_accuracy}")

Average accuracy after 500 rounds is 0.7041


In [None]:
import os
import time

# Timestamped file name so successive runs do not overwrite each other.
timestr = time.strftime("%Y_%m_%d-%I_%M_%S_%p")
artifact_filename = f"artifacts/server_model-{timestr}.pth"

# torch.save does not create missing parent directories, so make sure the
# target folder exists before writing.
os.makedirs(os.path.dirname(artifact_filename), exist_ok=True)

# parameters of the trained model
server_model = net.state_dict()
# save the model on the local file system
torch.save(server_model, artifact_filename)
# save the model on wandb
wandb.save(artifact_filename)

# Finish the wandb session and upload all data
wandb.finish(0, quiet=False)