# Baseline implementation

In [None]:
%pip install wandb --quiet

In [None]:
!curl -O http://storage.googleapis.com/gresearch/federated-vision-datasets/cifar10.zip
!unzip cifar10.zip

In [None]:
import wandb

# Open the W&B run that the wandb.config / wandb.log / wandb.save calls below report to.
wandb.init(entity="aml-federated-learning", project="step-2")

In [None]:
E = 1  # local epochs each selected client trains per round
STEP_SIZE = 5  # StepLR step size (logged to the wandb config below)
GAMMA = 0.1  # StepLR decay factor (logged to the wandb config below)

# Timing reference: K = 1, NUMBER_OF_CLIENTS = 2, MAX_TIME = 3 -> 58 sec

K = 10  # clients sampled per round (to set)
NUMBER_OF_CLIENTS = 100  # (to set)
MAX_TIME = 1000  # number of communication rounds (to set)

batch_size = 50  # mini-batch size of every client's DataLoaders

lr = 0.05  # learning rate of every client's local SGD optimizer

DATA_DISTRIBUTION = "non-iid"  # "iid" | "non-iid"
DIRICHELET_ALPHA = 0.05  # 0.00, 0.05, 0.10, 0.20, 0.50, 1.00, 10.00, 100.0

# Dirichlet concentrations swept by the training loop (one full run per value).
DIRICHELET_ALPHAS = [0.00, 0.05, 0.10, 0.20, 0.50, 1.00, 10.00, 100.00]

# Validate with explicit exceptions: `assert` statements are skipped under `python -O`.
if DATA_DISTRIBUTION not in ("iid", "non-iid"):
    raise ValueError(f'DATA_DISTRIBUTION must be "iid" or "non-iid", got {DATA_DISTRIBUTION!r}')
if DATA_DISTRIBUTION != "iid" and NUMBER_OF_CLIENTS != 100:
    # presumably the pre-computed non-iid CSV splits define 100 users — confirm
    raise ValueError(f"non-iid data requires NUMBER_OF_CLIENTS == 100, got {NUMBER_OF_CLIENTS}")

wandb.config.update({
    "batch-size": batch_size,
    "learning-rate": lr,
    # "momentum": MOMENTUM,
    # "weight_decay": WEIGHT_DECAY,
    "num_epochs": E,
    "step_size": STEP_SIZE,
    "gamma": GAMMA,
    "K": K,
    "number_of_clients": NUMBER_OF_CLIENTS,
    "max_time": MAX_TIME,
    "data_distribution": DATA_DISTRIBUTION,
    # Record the swept values so each run's per-alpha accuracies are interpretable.
    "dirichelet_alphas": DIRICHELET_ALPHAS,
})

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# From: https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html
class Net(nn.Module):
    """Small LeNet-style CNN that maps 3-channel images to 10 class logits.

    Adapted from the PyTorch "Neural Networks" blitz tutorial
    (https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html),
    with the first convolution taking 3 input channels instead of 1.
    The input must be 32x32 spatially so that two (5x5 conv -> 2x2 pool)
    stages leave 16 feature maps of 5x5, matching ``fc1``'s input size.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in this exact order; it fixes both the state_dict
        # key order and the order in which parameters draw their initial values.
        self.conv1 = nn.Conv2d(3, 6, 5)      # 3 -> 6 channels, 5x5 kernel
        self.conv2 = nn.Conv2d(6, 16, 5)     # 6 -> 16 channels, 5x5 kernel
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)         # one logit per class

    def forward(self, x):
        # Feature extractor: conv -> ReLU -> 2x2 max-pool, twice.
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), 2)
        # Keep the batch dimension, flatten everything else.
        x = torch.flatten(x, start_dim=1)
        # Classifier head: ReLU after every hidden layer, raw logits out.
        for hidden in (self.fc1, self.fc2):
            x = F.relu(hidden(x))
        return self.fc3(x)

#net = Net()
#net = net.to("cuda")
#print(net)

In [None]:
import torch.optim as optim

class Client:
  """One federated-learning participant with its own train/validation data.

  Each client owns a local ``Net`` and an SGD optimizer. ``clientUpdate``
  loads the server weights, trains locally and returns the weight difference;
  ``compute_accuracy`` evaluates given weights on the client's validation data.
  """

  def __init__(self, i, train_set, validation_set):
    self.i = i  # client id; also used to name this client's loss series in wandb
    self.train_set = train_set
    # Store the batch size the loaders actually use (the notebook-wide value),
    # so this attribute agrees with what is logged to the wandb config.
    self.batch_size = batch_size
    # num_workers is left at its default; an earlier note reported this as
    # faster than num_workers = 2 for this setup.
    self.train_loader = torch.utils.data.DataLoader(train_set, batch_size=self.batch_size,
                                                    shuffle=False)
    self.validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=self.batch_size,
                                                         shuffle=False)
    self.net = Net()
    #self.net = self.net.to("cuda")
    self.optimizer = optim.SGD(self.net.parameters(), lr=lr)
    self.criterion = nn.CrossEntropyLoss()
    # self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, step_size=STEP_SIZE, gamma=GAMMA)
    # Attach wandb's model watcher to this client's network (log_freq=100).
    wandb.watch(self.net, criterion=self.criterion, log_freq=100, log_graph=False)

  def clientUpdate(self, parameters):
    """Train locally for ``E`` epochs starting from ``parameters``.

    Args:
      parameters: server state_dict to start from; it is only read here.

    Returns:
      dict mapping each key of ``parameters`` to
      ``parameters[key] - trained[key]`` (server weights minus the locally
      trained weights).
    """
    self.net.load_state_dict(parameters)
    self.net.train()
    for _ in range(E):
      for images, labels in self.train_loader:
        #images = images.to("cuda")
        #labels = labels.to("cuda")
        self.optimizer.zero_grad()
        loss = self.criterion(self.net(images), labels)
        loss.backward()
        wandb.log({f"client-loss-{self.i}": loss.item()})
        self.optimizer.step()

    trained = self.net.state_dict()
    # Pair entries by key instead of relying on two dicts iterating in the same order.
    return {key: value - trained[key] for key, value in parameters.items()}

  def compute_accuracy(self, parameters):
    """Load ``parameters`` into the local model and return validation accuracy.

    Returns:
      Fraction in [0, 1] of validation samples whose arg-max prediction equals
      the label.

    Raises:
      ValueError: if the validation loader yields no samples.
    """
    self.net.load_state_dict(parameters)
    self.net.eval()

    correct = 0
    total = 0
    # Inference only: skip building the autograd graph (saves memory and time).
    with torch.no_grad():
      for data, labels in self.validation_loader:
        #data = data.to("cuda")
        #labels = labels.to("cuda")
        _, preds = torch.max(self.net(data), 1)
        correct += (preds == labels).sum().item()
        total += len(preds)

    if total == 0:
      raise ValueError(f"client {self.i}: validation set is empty, accuracy is undefined")
    return correct / total


In [None]:
from collections import defaultdict

def parse_csv(filename):
  """Group image ids by user id from a federated-split CSV file.

  Lines whose first character is not a digit (e.g. a header or a blank line)
  are skipped. Every other line must hold three comma-separated integers; the
  first is the user id, the second the image id, and the third is ignored.

  Returns:
    defaultdict(list) mapping user_id -> list of image_ids in file order;
    users appear in the order they are first seen.
  """
  splits = defaultdict(list)
  with open(filename, encoding="utf-8") as f:
    for line in f:
      # Slicing is safe even on an empty string, unlike line[0].
      if not line[:1].isdigit():
        continue
      user_id, image_id, _ = (int(token) for token in line.split(","))
      splits[user_id].append(image_id)
  return splits


In [None]:
import torchvision
import torchvision.transforms as transforms
import numpy as np
import random

from tqdm.notebook import tqdm

from statistics import mean


def selectClients(k, population=None):
  """Pick ``k`` distinct clients uniformly at random for one round.

  Sampling is without replacement, so no client is selected twice in the same
  round (``random.choices`` would allow duplicates).

  Args:
    k: number of clients to select; must not exceed the population size.
    population: sequence to sample from; defaults to the global ``clients`` list.

  Raises:
    ValueError: if ``k`` is larger than the population.
  """
  pool = clients if population is None else population
  return random.sample(pool, k)

def aggregateClient(deltaThetas, sizes=None):
  """Combine the selected clients' updates into one server update (FedAvg).

  Args:
    deltaThetas: list of dicts, one per selected client, each mapping
      state_dict keys to that client's update.
    sizes: optional per-client weights (e.g. local training-set sizes), in the
      same order as ``deltaThetas``. When omitted, all clients weigh the same,
      which matches FedAvg whenever clients hold equal amounts of data.

  Returns:
    dict with, for every key, sum_i (sizes[i] / sum(sizes)) * deltaThetas[i][key].

  Raises:
    ValueError: if ``deltaThetas`` is empty, ``sizes`` has the wrong length,
      or ``sizes`` does not sum to a positive number.
  """
  if not deltaThetas:
    raise ValueError("deltaThetas must contain at least one client update")
  sizes = [1] * len(deltaThetas) if sizes is None else list(sizes)
  if len(sizes) != len(deltaThetas):
    raise ValueError(f"expected {len(deltaThetas)} sizes, got {len(sizes)}")
  total = sum(sizes)
  if total <= 0:
    raise ValueError("sizes must sum to a positive number")

  aggregated = {}
  for size, delta in zip(sizes, deltaThetas):
    # Normalise over the *selected* clients only, so the weights sum to 1.
    ratio = size / total
    for key, value in delta.items():
      if key in aggregated:
        aggregated[key] += ratio * value
      else:
        # ratio * value is a new object, so the caller's tensors are never mutated.
        aggregated[key] = ratio * value
  return aggregated

random.seed(42)
# random.seed only covers Python's RNG (client sampling); Net initialisation and
# random_split draw from torch's generator, so seed it too for repeatable runs.
torch.manual_seed(42)

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform = transform)

# One complete federated training run per Dirichlet concentration value.
# NOTE(review): in "iid" mode alpha is not used, so every iteration repeats the
# same iid experiment under a different alpha label.
for alpha in DIRICHELET_ALPHAS:
  net = Net()  # fresh server model for each value of alpha

  if DATA_DISTRIBUTION == "iid":
    # Drop the remainder so the training set splits into equal client shards.
    # Lengths are plain Python ints: the `np.int` alias no longer exists in NumPy >= 1.24.
    train_shard = len(trainset) // NUMBER_OF_CLIENTS
    train_pool = torch.utils.data.Subset(trainset, list(range(train_shard * NUMBER_OF_CLIENTS)))
    print(len(train_pool))

    lengths = [train_shard] * NUMBER_OF_CLIENTS
    print(lengths)
    trainsets = torch.utils.data.random_split(dataset=train_pool, lengths=lengths)
  else:
    # Each user listed in the CSV becomes one client's training subset.
    dirichelet_splits = parse_csv(f"cifar10/federated_train_alpha_{alpha:.2f}.csv")
    trainsets = [torch.utils.data.Subset(trainset, indices) for indices in dirichelet_splits.values()]
    if len(trainsets) < NUMBER_OF_CLIENTS:
      raise ValueError(f"split for alpha={alpha} has {len(trainsets)} users, need {NUMBER_OF_CLIENTS}")

  # Split the test set into equal per-client validation shards. Build from the
  # untouched `testset` rather than rebinding it, which would wrap one more
  # Subset layer around it on every alpha.
  test_shard = len(testset) // NUMBER_OF_CLIENTS
  test_pool = torch.utils.data.Subset(testset, list(range(test_shard * NUMBER_OF_CLIENTS)))
  print(len(test_pool))
  testsets = torch.utils.data.random_split(dataset=test_pool, lengths=[test_shard] * NUMBER_OF_CLIENTS)

  clients = [Client(c, trainsets[c], testsets[c]) for c in range(NUMBER_OF_CLIENTS)]

  for step in range(MAX_TIME):
    if step % 100 == 0:  # progress report every 100 rounds
      print(f"{alpha} - {step}")
    selected_clients = selectClients(K)

    # Every selected client starts from the current server weights.
    deltaThetas = [c.clientUpdate(net.state_dict()) for c in selected_clients]
    g = aggregateClient(deltaThetas)

    # Server step: theta <- theta - g (server learning rate fixed at 1).
    # TODO: add a configurable server learning rate.
    server_state = net.state_dict()
    net.load_state_dict({key: value - g[key] for key, value in server_state.items()})

  model_parameters = net.state_dict()
  avg_accuracy = mean(client.compute_accuracy(model_parameters) for client in clients)

  wandb.log({f"accuracy-{alpha}": avg_accuracy})

  print(f"Average accuracy after {MAX_TIME} rounds with DIRICHELET_ALPHA = {alpha} is {avg_accuracy}")

In [None]:
# I think we can remove this
# from collections import Counter

# print(Counter(label for _, label in iter(trainsets[0])))
# print(Counter(label for _, label in iter(trainsets[1])))
# print(Counter(label for _, label in iter(trainsets[2])))

In [None]:
import time

import os

timestr = time.strftime("%Y_%m_%d-%I_%M_%S_%p")
# Make sure the output directory exists; saving into a missing directory fails.
os.makedirs("artifacts", exist_ok=True)
artifact_filename = f"artifacts/server_model-{timestr}.pth"

# Parameters of the trained server model. NOTE: `net` is rebound for every
# alpha in the sweep, so this is the model trained for the last alpha.
server_model = net.state_dict()
# save the model on the local file system
torch.save(server_model, artifact_filename)
# upload the saved file to the wandb run
wandb.save(artifact_filename)

# Finish the wandb session and upload all data
wandb.finish(0, quiet=False)