# Baseline implementation

In [1]:
%pip install wandb --quiet

[K     |████████████████████████████████| 1.7 MB 5.3 MB/s 
[K     |████████████████████████████████| 140 kB 48.5 MB/s 
[K     |████████████████████████████████| 97 kB 6.2 MB/s 
[K     |████████████████████████████████| 180 kB 46.1 MB/s 
[K     |████████████████████████████████| 63 kB 1.6 MB/s 
[?25h  Building wheel for subprocess32 (setup.py) ... [?25l[?25hdone
  Building wheel for pathtools (setup.py) ... [?25l[?25hdone


In [50]:
!wget http://storage.googleapis.com/gresearch/federated-vision-datasets/cifar10.zip
!unzip cifar10.zip

--2021-12-29 09:14:02--  http://storage.googleapis.com/gresearch/federated-vision-datasets/cifar10.zip
Resolving storage.googleapis.com (storage.googleapis.com)... 142.250.152.128, 142.250.136.128, 209.85.200.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|142.250.152.128|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1627997 (1.6M) [application/zip]
Saving to: ‘cifar10.zip’


2021-12-29 09:14:02 (189 MB/s) - ‘cifar10.zip’ saved [1627997/1627997]

Archive:  cifar10.zip
   creating: cifar10/
  inflating: cifar10/federated_train_alpha_0.00.csv  
  inflating: cifar10/test.csv        
  inflating: cifar10/federated_train_alpha_10.00.csv  
  inflating: cifar10/federated_train_alpha_0.05.csv  
  inflating: cifar10/federated_train_alpha_100.00.csv  
  inflating: cifar10/federated_train_alpha_0.10.csv  
  inflating: cifar10/federated_train_alpha_0.20.csv  
  inflating: cifar10/federated_train_alpha_1.00.csv  
  inflating: cifar10/federated_train

In [41]:
import wandb

wandb.init(project="step-2", entity="aml-federated-learning")

[34m[1mwandb[0m: Currently logged in as: [33mpeiro98[0m (use `wandb login --relogin` to force relogin)


In [42]:
E = 2
STEP_SIZE = 5
GAMMA = 0.1

# K = 1, NUMBE_OR_CLIENTS = 2, MAX_TIME = 3 -> 58 sec

K = 10 # to set
NUMBER_OF_CLIENTS = 100 # to set
MAX_TIME = 20 #to set

batch_size = 50

lr = 0.05

DATA_DISTRIBUTION = "non-iid" # "iid" | "non-iid"
DIRICHELET_ALPHA = 0.1 # 0.00, 0.05, 0.10, 0.20, 0.50, 1.00, 10.00, 100.0

assert(DATA_DISTRIBUTION == "iid" or NUMBER_OF_CLIENTS == 100)

wandb.config.update({
    "batch-size": batch_size,
    "learning-rate": lr,
    # "momentum": MOMENTUM,
    # "weight_decay": WEIGHT_DECAY,
    "num_epochs": E,
    "step_size": STEP_SIZE,
    "gamma": GAMMA,
    "K": K,
    "number_of_clients": NUMBER_OF_CLIENTS,
    "max_time": MAX_TIME,
    "data_distribution": DATA_DISTRIBUTION,
    "dirichelet_alpha": DIRICHELET_ALPHA
})

In [43]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# From: https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html
class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()
        # 1 input image channel, 6 output channels, 5x5 square convolution
        # kernel
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # an affine operation: y = Wx + b
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5*5 from image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        # Max pooling over a (2, 2) window
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # If the size is a square, you can specify with a single number
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = torch.flatten(x, 1) # flatten all dimensions except the batch dimension
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


net = Net()
net = net.to("cuda")
print(net)

Net(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)


In [44]:
import torch.optim as optim

class Client():
  def __init__(self, i, train_set, validation_set):
    self.i = i
    self.train_set = train_set
    self.batch_size = 32
    self.train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size,
                                         shuffle=False, num_workers=2)
    self.validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=batch_size,
                                         shuffle=False, num_workers=2)
    self.net = Net()
    self.net = self.net.to("cuda")
    # create your optimizer
    self.optimizer = optim.SGD(self.net.parameters(), lr=lr)
    self.criterion = nn.CrossEntropyLoss()
    # self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, step_size=STEP_SIZE, gamma=GAMMA)
    wandb.watch(self.net, criterion=self.criterion, log_freq=100, log_graph=True)
    
  def clientUpdate(self, parameters):
    self.net.load_state_dict(parameters)
    theta = parameters
    for e in range(E):
      for images, labels in self.train_loader:
        images = images.to("cuda")
        labels = labels.to("cuda")
        # in your training loop:
        self.optimizer.zero_grad()   # zero the gradient buffers
        output = self.net(images)
        loss = self.criterion(output, labels)
        loss.backward()
        wandb.log({f"client-loss-{self.i}": loss.item()})
        self.optimizer.step()    # Does the update
    
    return_dict = {}
    for (k1, v1), (k2, v2) in zip(parameters.items(), self.net.state_dict().items()):
      return_dict[k1] = v1 - v2
    return return_dict

  def compute_accuracy(self, parameters):
    self.net.load_state_dict(parameters)

    running_corrects = 0
    n = 0
    for data, labels in self.validation_loader:
        data = data.to("cuda")
        labels = labels.to("cuda")

        outputs = self.net(data)

        _, preds = torch.max(outputs.data, 1)

        running_corrects += torch.sum(preds == labels.data).data.item()
        n += len(preds)
                
    return running_corrects / n


In [45]:
from collections import defaultdict

def parse_csv(filename):
  splits = defaultdict(lambda: [])
  with open(filename) as f:
    for line in f:
      if not line[0].isdigit():
        continue

      user_id, image_id, _ = (int(token) for token in line.split(","))
      splits[user_id].append(image_id)

  return splits


In [46]:
import torchvision
import torchvision.transforms as transforms
import numpy as np
import random

random.seed(42)

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform = transform)


if DATA_DISTRIBUTION == "iid":
  # split the training set
  trainset_len = ( len(trainset) // NUMBER_OF_CLIENTS ) * NUMBER_OF_CLIENTS
  print(trainset_len)
  trainset = torch.utils.data.Subset(trainset, list(range(trainset_len)))

  lengths = len(trainset) // NUMBER_OF_CLIENTS * np.ones(NUMBER_OF_CLIENTS, dtype=np.int)
  print(lengths)
  trainsets = torch.utils.data.random_split(dataset=trainset, lengths=lengths)
else:
  dirichelet_splits = parse_csv(f"cifar10/federated_train_alpha_{DIRICHELET_ALPHA:.2f}.csv")
  trainsets = [torch.utils.data.Subset(trainset, indices) for indices in dirichelet_splits.values()]


# split the validation set
testset_len = ( len(testset) // NUMBER_OF_CLIENTS ) * NUMBER_OF_CLIENTS
print(testset_len)
testset = torch.utils.data.Subset(testset, list(range(testset_len)))

lengths = len(testset) // NUMBER_OF_CLIENTS * np.ones(NUMBER_OF_CLIENTS, dtype=np.int)
# print(lengths)
testsets = torch.utils.data.random_split(dataset=testset, lengths=lengths)


clientsSizes = torch.zeros(NUMBER_OF_CLIENTS)
clients = list()

def selectClients(k):
  return random.choices(clients, k=k)

def aggregateClient(deltaThetas):
  parameters = None
  for i,d in enumerate(deltaThetas):
    ratio = len(trainsets[i])/len(trainset)
    
    if i == 0:
      parameters = {k:ratio*v for k, v in d.items()}
    else:
      for (k, v) in d.items():
        parameters[k] += ratio * v
   
  return parameters

for c in range(NUMBER_OF_CLIENTS):
  clients.append(Client(c, trainsets[c], testsets[c]))


for t in range(MAX_TIME):
  selected_clients = selectClients(K)
  print(f"Client(s) {[client.i for client in selected_clients]} selected")

  deltaThetas = list()
  for i, c in enumerate(selected_clients):
    deltaThetas.append(c.clientUpdate(net.state_dict()))
    
  g = aggregateClient(deltaThetas)
  
  parameters = {}
  for (k1, v1), (k2, v2) in zip(net.state_dict().items(), g.items()):
    parameters[k1] = v1 - v2 # todo: add server learning rate gamma
  net.load_state_dict(parameters)

Files already downloaded and verified
Files already downloaded and verified


[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
[34m[1mwandb[0m: logging 

10000


[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
[34m[1mwandb[0m: logging 

Client(s) [63, 2, 27, 22, 73, 67, 89, 8, 42, 2] selected
Client(s) [21, 50, 2, 19, 64, 54, 22, 58, 80, 0] selected
Client(s) [80, 69, 34, 15, 95, 33, 9, 9, 84, 60] selected
Client(s) [80, 72, 53, 97, 37, 55, 82, 61, 86, 57] selected
Client(s) [70, 4, 22, 28, 7, 23, 10, 27, 63, 36] selected
Client(s) [37, 20, 26, 93, 64, 60, 17, 72, 16, 37] selected
Client(s) [98, 63, 55, 68, 84, 77, 22, 3, 31, 26] selected
Client(s) [21, 94, 87, 31, 65, 39, 91, 45, 26, 24] selected
Client(s) [56, 26, 58, 89, 39, 21, 99, 50, 9, 4] selected
Client(s) [10, 62, 79, 42, 6, 38, 99, 52, 97, 86] selected
Client(s) [1, 72, 68, 53, 26, 64, 11, 43, 45, 95] selected
Client(s) [87, 26, 50, 17, 91, 87, 29, 63, 60, 15] selected
Client(s) [76, 53, 77, 53, 0, 32, 1, 92, 87, 83] selected
Client(s) [30, 5, 87, 94, 8, 48, 6, 76, 76, 12] selected
Client(s) [47, 54, 26, 87, 42, 21, 53, 72, 20, 31] selected
Client(s) [99, 64, 43, 51, 12, 22, 33, 58, 23, 22] selected
Client(s) [7, 63, 22, 90, 85, 7, 23, 66, 21, 13] selected
C

In [47]:
from collections import Counter

print(Counter(label for _, label in iter(trainsets[0])))
print(Counter(label for _, label in iter(trainsets[1])))
print(Counter(label for _, label in iter(trainsets[2])))

Counter({1: 248, 5: 36, 4: 33, 9: 31, 6: 29, 2: 29, 8: 28, 3: 22, 7: 22, 0: 22})
Counter({6: 222, 3: 35, 7: 35, 8: 34, 9: 32, 5: 32, 1: 28, 0: 28, 4: 27, 2: 27})
Counter({6: 199, 5: 41, 3: 40, 9: 39, 8: 37, 7: 34, 1: 33, 0: 29, 4: 25, 2: 23})


In [48]:
from statistics import mean

model_parameters = net.state_dict()
avg_accuracy = mean(client.compute_accuracy(model_parameters) for client in clients)

print(f"Average accuracy after {MAX_TIME} rounds is {avg_accuracy}")

Average accuracy after 20 rounds is 0.1067


In [49]:
import time

timestr = time.strftime("%Y_%m_%d-%I_%M_%S_%p")
artifact_filename = f"artifacts/server_model-{timestr}.pth"

# parameters of the trained model
server_model = net.state_dict()
# save the model on the local file system
torch.save(server_model, artifact_filename)
# save the model on wandb
wandb.save(artifact_filename)

# Finish the wandb session and upload all data
wandb.finish(0, quiet=False)

VBox(children=(Label(value=' 0.35MB of 0.35MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
client-loss-0,▄▄▄▄▄▄▄▄▄▄▄▄▄▄▃▃▄▄▃▂▄▄▄▄▄▃▄▄▂▁█▄▄▄▄▃▄▄▃▃
client-loss-1,▅▅▅▅▄▄▅▅▄▃▅▅▅▅▃▂▆▅▂▁▅▅▅▅▄▄▅▅▄▃▅▅▅▅▂▁█▅▃▂
client-loss-10,████████▇▇████▇▆██▆▅████████▇▆████▆▅██▄▁
client-loss-11,▆▆▆▆▅▅▅▅▄▂█▆▆▅▄▃▆▆▃▁
client-loss-12,▄▄▄▄▄▄▄▄▃▃▅▅▄▄▃▁█▄▃▃▄▄▄▄▄▄▄▄▃▂▅▅▄▄▃▁█▄▃▃
client-loss-13,▄▄▄▄▄▃▃▄▂▁█▄▄▄▄▃▄▄▃▃
client-loss-15,▇██▇▇▇▇█▇▆▇██▇▇▆▇█▆▅▇▇█▇▇▇▇▇▆▅▇█▇█▆▄▆█▅▁
client-loss-16,▇▇▇▇█▇▇▇▆▅▇▇█▇▅▄▇▆▃▁
client-loss-17,▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▅▆▆▅▄▆▆▆▆▆▆▆▆▅▃█▆▆▆▅▄▆▆▃▁
client-loss-19,▅▅▅▅▅▅▄▅▅▅▄▅▄▄▅▅▅▄▅▃█▅▅▄▅▅▁▅▅▅▄▅▃▁▅▅▄▃▅▁

0,1
client-loss-0,1.67262
client-loss-1,1.48022
client-loss-10,1.2643
client-loss-11,0.6838
client-loss-12,1.6971
client-loss-13,1.90821
client-loss-15,1.13326
client-loss-16,1.78265
client-loss-17,0.7658
client-loss-19,1.00165
