# Baseline implementation

In [1]:
# download the Cifar10 non-iid splits, if not present

from os import path
import urllib.request
import zipfile

# Fetch and unpack the federated CIFAR-10 split files only when the "cifar10"
# directory is not already present, so re-running the cell is cheap.
if not path.exists("cifar10"):
    save_path = "cifar10.zip"
    # Use HTTPS: the archive is extracted right below, so it should not be
    # retrievable/modifiable over an unauthenticated plain-HTTP connection.
    urllib.request.urlretrieve(
        "https://storage.googleapis.com/gresearch/federated-vision-datasets/cifar10_v1.1.zip",
        save_path,
    )

    with zipfile.ZipFile(save_path, 'r') as zip_ref:
        zip_ref.extractall(".")

In [98]:
# numpy and torch are referenced while the dict is being built, so they are
# imported in this cell; without this a fresh kernel fails here with NameError.
import numpy as np
import torch

# Experiment settings read by the cells below.
config = {
    "E": 1, # number of local epochs
    "K": 5, # number of clients selected each round # [5, 10, 20]
    "NUMBER_OF_CLIENTS": 100, # total number of clients
    "MAX_TIME": 10000, # number of rounds executed by the training loop
    "BATCH_SIZE": 50,
    "VALIDATION_BATCH_SIZE": 500,
    "LR": 0.01, # learning rate of each client's SGD optimizer
    "DATA_DISTRIBUTION": "non-iid", # "iid" | "non-iid"
    "DIRICHELET_ALPHA": [0.00, 0.05, 0.10, 0.20, 0.50, 1.00, 10.00, 100.0],
    "AVERAGE_ACCURACY": np.zeros(8), # presumably one slot per DIRICHELET_ALPHA entry
    "FED_AVG_M": False,
    "FED_AVG_M_BETA": 0.9,
    "FED_AVG_M_GAMMA": 1,
    "LR_DECAY": 0.99,
    "LOG_FREQUENCY": 5, # evaluate every LOG_FREQUENCY rounds
    "DEVICE": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    "AUGMENTATION_PROB": 0.0, # probability of applying the random augmentations
    "ALPHA": 0.01 # for FedDyn
}

In [99]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [100]:
# From: https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html
class Net(nn.Module):
    """Small CNN: two (5x5 conv -> ReLU -> 2x2 max-pool) stages, then three linear layers.

    Args:
        input_size: side length of the square input images (keyword-only,
            default 32). It determines the width of the first linear layer.
    """

    def __init__(self, *, input_size=32):
        super().__init__()
        # 3 input channels, 64 feature maps, 5x5 kernels (no padding)
        self.conv1 = nn.Conv2d(3, 64, 5)
        self.conv2 = nn.Conv2d(64, 64, 5)

        # Every unpadded 5x5 convolution removes 4 pixels per side and every
        # 2x2 pooling halves the side; after two such stages the feature map
        # has side `input_size // 4 - 3`.
        side = input_size // 4 - 3
        flat_features = 64 * (side * side)

        self.fc1 = nn.Linear(flat_features, 384)
        self.fc2 = nn.Linear(384, 192)
        self.fc3 = nn.Linear(192, 10)

    def forward(self, x):
        """Return the 10 output scores for each element of the batch ``x``."""
        # first stage: conv -> ReLU -> 2x2 max pooling
        x = self.conv1(x)
        x = F.relu(x)
        x = F.max_pool2d(x, (2, 2))

        # second stage (a single int is enough for a square pooling window)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)

        # keep the batch dimension, collapse everything else
        x = torch.flatten(x, 1)

        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


# Global model instance, moved to the selected device and printed for inspection.
net = Net().to(device)
print(net)

In [101]:
import torch.optim as optim

class Client():
  """One federated participant: private data loaders, a local model and FedDyn state.

  Attributes:
    i: identifier given by the caller.
    train_loader / validation_loader: loaders over the client's own splits.
    net, optimizer, criterion: local model, its SGD optimizer and the loss.
    previous_gradient: per-tensor state (initially zero) used by the linear
      penalty of the local objective and refreshed after every update.
  """

  def __init__(self, i, train_set, validation_set, *, input_size=32):
    self.i = i
    self.train_loader = torch.utils.data.DataLoader(train_set, batch_size=config["BATCH_SIZE"],
                                         shuffle=True, num_workers=0)
    # Evaluation needs no gradients, so use the dedicated validation batch size
    # when the config provides one (falls back to the training batch size).
    self.validation_loader = torch.utils.data.DataLoader(
        validation_set,
        batch_size=config.get("VALIDATION_BATCH_SIZE", config["BATCH_SIZE"]),
        shuffle=False, num_workers=0)
    self.net = Net(input_size=input_size)
    self.net = self.net.to(device)
    # create your optimizer
    self.optimizer = optim.SGD(self.net.parameters(), lr=config["LR"])
    self.criterion = nn.CrossEntropyLoss()

    self.previous_gradient = {key:torch.zeros(params.shape, device=device) for key, params in self.net.state_dict().items()}

  def clientUpdate(self, alpha, parameters):
    """Train locally for config["E"] epochs starting from ``parameters``.

    The minimised objective is
      criterion(output, labels) - <previous_gradient, theta>
        + (alpha / 2) * ||theta - parameters||^2

    Args:
      alpha: weight of the proximal penalty.
      parameters: state dict of the server model to start from.

    Returns:
      The state dict of the local model after training. Its tensors are the
      live tensors of ``self.net``; copy them if they must outlive later calls.
    """
    self.net.load_state_dict(parameters)

    # Fixed snapshot of the starting weights, so the penalty reference cannot
    # move even if `parameters` aliases tensors that change during training.
    server_params = {key: value.detach().clone() for key, value in parameters.items()}

    for e in range(config["E"]):
      for images, labels in self.train_loader:
        images = images.to(device)
        labels = labels.to(device)

        self.optimizer.zero_grad()   # zero the gradient buffers
        output = self.net(images)

        # compute the loss of the model
        loss = self.criterion(output, labels)

        # Both penalties must be built from the live parameters: the tensors
        # returned by state_dict() are detached, so penalties computed from
        # them would change the loss value but contribute no gradient.
        linear_term = 0.0
        proximal_term = 0.0
        for name, param in self.net.named_parameters():
          linear_term = linear_term + torch.sum(self.previous_gradient[name] * param)
          proximal_term = proximal_term + torch.sum((param - server_params[name]) ** 2)

        loss = loss - linear_term + (alpha / 2) * proximal_term

        loss.backward()

        self.optimizer.step()    # Does the update

    # refresh the stored gradient state from the drift w.r.t. the server weights
    final_state = self.net.state_dict()
    self.previous_gradient = {
      key: old_grad - alpha * (final_state[key] - server_params[key])
      for key, old_grad in self.previous_gradient.items()
    }

    return final_state

  def compute_accuracy(self, parameters):
    """Load ``parameters`` into the local model and return its validation accuracy.

    Returns:
      Fraction of correctly classified validation samples (a float in [0, 1]).

    Raises:
      ZeroDivisionError: if the validation loader yields no samples.
    """
    self.net.load_state_dict(parameters)

    self.net.train(False)

    running_corrects = 0
    n = 0
    try:
      # inference only: no need to record the autograd graph
      with torch.no_grad():
        for data, labels in self.validation_loader:
          data = data.to(device)
          labels = labels.to(device)

          outputs = self.net(data)

          _, preds = torch.max(outputs, 1)

          running_corrects += torch.sum(preds == labels).item()
          n += len(preds)
    finally:
      # always hand the model back in training mode
      self.net.train(True)

    return running_corrects / n


In [104]:
from collections import defaultdict

def parse_csv(filename):
  """Read a federated split file and group image ids by user id.

  Lines whose first character is not a digit (e.g. a header) are skipped.
  Every other line must consist of exactly three comma-separated integers:
  the user id, the image id and a third value that is ignored.

  Returns:
    A defaultdict mapping each user id to the list of its image ids, in file order.
  """
  splits = defaultdict(list)

  with open(filename) as csv_file:
    data_rows = (row for row in csv_file if row[0].isdigit())
    for row in data_rows:
      user_id, image_id, _ = map(int, row.split(","))
      splits[user_id].append(image_id)

  return splits


In [None]:
import time
import json
import numpy
from copy import deepcopy

def listToString(l):
    """Render ``l`` as a space-separated string of its items.

    Scalars and strings (anything that is not an iterable collection) are
    returned as ``str(l)``, since callers also pass single numbers.
    """
    if isinstance(l, (str, bytes)) or not hasattr(l, "__iter__"):
        return str(l)
    # join the items themselves; joining str(l) would put a space between
    # every single character of the list's repr
    return " ".join(str(item) for item in l)

def printJSON(alpha, acc, net):
    """Save the model weights and a JSON summary of the run under ``artifacts/``.

    Two files sharing a timestamped name are written: ``<name>.pth`` with
    ``net.state_dict()`` and ``<name>.json`` with a copy of ``config`` (made
    JSON-friendly), ``alpha`` and ``acc``.

    Args:
        alpha: value stored (as a string) under the "alpha" key.
        acc: accuracies stored under the "accuracy" key; must be JSON-serialisable.
        net: model whose state dict is saved.
    """
    import os  # local import: only needed to create the output directory

    # neither torch.save nor open() creates a missing parent directory
    os.makedirs("artifacts", exist_ok=True)

    timestr = time.strftime("%Y_%m_%d-%I_%M_%S_%p")
    artifact_filename = f"artifacts/server_model-{timestr}"

    # parameters of the trained model
    server_model = net.state_dict()
    # save the model on the local file system
    torch.save(server_model, artifact_filename + ".pth")

    # work on a copy so the live config keeps its list / array / device objects
    config_copy = deepcopy(config)
    config_copy["DIRICHELET_ALPHA"] = listToString(config_copy["DIRICHELET_ALPHA"])
    config_copy["AVERAGE_ACCURACY"] = numpy.array2string(config_copy["AVERAGE_ACCURACY"])
    config_copy["DEVICE"] = ""  # torch.device is not JSON-serialisable
    data = {
        "config": config_copy,
        "alpha": listToString(alpha),
        "accuracy": acc
    }

    with open(artifact_filename + ".json", "w") as f:
        f.write(json.dumps(data, indent=4))

    # Printing the JSON file in the notebook (e.g. with `!cat`) is a fairly
    # heavy operation, so it is left disabled:
    #artifact_filename += ".json"
    #!cat artifact_filename


In [None]:
import torchvision
import torchvision.transforms as transforms
import numpy as np
import random
from statistics import mean

from tqdm.notebook import tqdm

# Seed Python's `random` module (used for client selection).
random.seed(42)

# Augmentations applied together, with probability config["AUGMENTATION_PROB"].
random_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(p=1),
    transforms.ColorJitter(brightness=0.9, contrast=0.9),
])

# Training pipeline: tensor conversion, optional augmentation, per-channel normalisation.
train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomApply([random_transform], p=config["AUGMENTATION_PROB"]),
    transforms.Normalize(mean=[0.491, 0.482, 0.447], std=[0.247, 0.243, 0.262]),
])

# Evaluation pipeline: same normalisation, no augmentation.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.491, 0.482, 0.447], std=[0.247, 0.243, 0.262]),
])

# CIFAR-10 train / test sets, downloaded into ./data when missing.
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=train_transform,
)

testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=test_transform,
)


# Build one training subset per client.
if config["DATA_DISTRIBUTION"] == "iid":
  # drop the remainder so that every client receives the same number of samples
  train_per_client = len(trainset) // config["NUMBER_OF_CLIENTS"]
  trainset_len = train_per_client * config["NUMBER_OF_CLIENTS"]
  trainset = torch.utils.data.Subset(trainset, list(range(trainset_len)))

  lengths = [train_per_client] * config["NUMBER_OF_CLIENTS"]
  trainsets = torch.utils.data.random_split(dataset=trainset, lengths=lengths)
else:
  # config["DIRICHELET_ALPHA"] is a list of candidate values; formatting the
  # list itself with ":.2f" raises TypeError. Use a single value: the first
  # entry, which is also the one the notebook stores in `alpha_i` further down.
  dirichelet_alpha = config["DIRICHELET_ALPHA"]
  if isinstance(dirichelet_alpha, (list, tuple)):
    dirichelet_alpha = dirichelet_alpha[0]

  dirichelet_splits = parse_csv(f"cifar10/federated_train_alpha_{dirichelet_alpha:.2f}.csv")
  trainsets = [torch.utils.data.Subset(trainset, indices) for indices in dirichelet_splits.values()]


# split the validation set evenly across the clients (remainder dropped)
test_per_client = len(testset) // config["NUMBER_OF_CLIENTS"]
testset_len = test_per_client * config["NUMBER_OF_CLIENTS"]
testset = torch.utils.data.Subset(testset, list(range(testset_len)))

lengths = [test_per_client] * config["NUMBER_OF_CLIENTS"]
testsets = torch.utils.data.random_split(dataset=testset, lengths=lengths)


# Registry of Client objects, filled once the data splits are available.
clients = []
# One zero entry per client; not read anywhere in the visible code
# (presumably reserved for per-client dataset sizes).
clientsSizes = torch.zeros(config["NUMBER_OF_CLIENTS"])

def selectClients(k):
  """Pick ``k`` distinct clients uniformly at random from ``clients``.

  Raises:
    ValueError: if ``k`` is larger than the number of registered clients.
  """
  # Sample WITHOUT replacement: drawing the same client twice in one round
  # would train it twice and make both returned state dicts refer to the same
  # live tensors.
  return random.sample(clients, k)

def aggregateClient(deltaThetas):
  """Sum the given per-client parameter dicts, each weighted by 1 / config['K'].

  Args:
    deltaThetas: iterable of dicts that share the same keys.

  Returns:
    A new dict with the weighted sum for every key, or None when
    ``deltaThetas`` is empty.
  """
  # The previous weight len(trainsets[i]) / (len(trainsets[i]) * K) always
  # reduces to 1 / K; computing it directly avoids indexing `trainsets` by the
  # position in this list and a division by zero for an empty split.
  ratio = 1 / config['K']

  parameters = None
  for d in deltaThetas:
    if parameters is None:
      parameters = {k: ratio * v for k, v in d.items()}
    else:
      for (k, v) in d.items():
        parameters[k] += ratio * v

  return parameters

# Create one Client per index, pairing each training split with its validation split.
for c in range(config["NUMBER_OF_CLIENTS"]):
  clients.append(Client(i=c, train_set=trainsets[c], validation_set=testsets[c]))

# hyper-parameters read from the config
m = config["NUMBER_OF_CLIENTS"]
K = config["K"]
alpha = config["ALPHA"]

# initial learning rate
lr = config["LR"]

# first entry of the Dirichlet alpha list
alpha_i = config["DIRICHELET_ALPHA"][0]

# initialize h_0: one zero tensor per entry of the global model's state dict
h = {
  name: torch.zeros(tensor.shape, device=device)
  for name, tensor in net.state_dict().items()
}

# containers for the accuracies collected during training
test_accuracies = []
accuracies = []

# Main federated loop: every round trains K clients from the current server
# weights, updates the server state h and derives the new server weights.
for step in tqdm(range(config["MAX_TIME"])):
  selected_clients = selectClients(K)

  # server weights of the previous round; they stay untouched until
  # load_state_dict below
  server_state = net.state_dict()

  thetas = list()
  for c in selected_clients:
    client_state = c.clientUpdate(alpha, server_state)
    # clientUpdate hands back the client's live tensors; copy them so the
    # result stays valid if that client's model is loaded or trained again
    thetas.append({key: value.detach().clone() for key, value in client_state.items()})

  # h^t = h^{t-1} - alpha * (1 / m) * sum_k (theta_k^t - theta^{t-1})
  h = {
    key: prev_h - alpha * 1 / m * sum(theta[key] - server_state[key] for theta in thetas)
    for key, prev_h in h.items()
  }

  # theta^t = (1 / K) * sum_k theta_k^t - (1 / alpha) * h^t
  new_parameters = {
    key: (1 / K) * sum(theta[key] for theta in thetas) - (1 / alpha) * h[key]
    for key in server_state.keys()
  }

  net.load_state_dict(new_parameters)

  if step % config["LOG_FREQUENCY"] == 0:
    # compute_accuracy returns one accuracy value per client (no loss), so
    # the results are averaged directly
    client_accuracies = [client.compute_accuracy(new_parameters) for client in clients]

    avg_client_accuracy = mean(client_accuracies)
    accuracies.append(avg_client_accuracy * 100)
    print(f"Average accuracy after {step} rounds is {avg_client_accuracy*100}")

# Final evaluation and artifact dump run once, after the last round; inside
# the loop they would re-evaluate every client and write a new model file on
# every single round.
avg_accuracy = mean(float(client.compute_accuracy(net.state_dict())) for client in clients)

# AVERAGE_ACCURACY is a numpy array and needs an integer position: use the
# position of the current Dirichlet alpha value inside the configured list.
alpha_index = config["DIRICHELET_ALPHA"].index(alpha_i)
config["AVERAGE_ACCURACY"][alpha_index] = avg_accuracy
print(f"Average accuracy with alpha = {alpha} after {config['MAX_TIME']} rounds is {avg_accuracy*100}")
printJSON(alpha, accuracies, net)

!zip -r artifacts.zip ./artifacts # Archive the artifacts folder so it can be saved on Kaggle