<a href="https://colab.research.google.com/github/mattiadutto/aml_federeted_learning/blob/dev_grouped_convolutions/aml_batched.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Baseline implementation

In [1]:
# Install the wandb package into the kernel's environment; --quiet reduces pip's output.
%pip install wandb --quiet

[K     |████████████████████████████████| 1.7 MB 12.3 MB/s 
[K     |████████████████████████████████| 180 kB 47.2 MB/s 
[K     |████████████████████████████████| 140 kB 51.8 MB/s 
[K     |████████████████████████████████| 97 kB 6.6 MB/s 
[K     |████████████████████████████████| 63 kB 1.8 MB/s 
[?25h  Building wheel for subprocess32 (setup.py) ... [?25l[?25hdone
  Building wheel for pathtools (setup.py) ... [?25l[?25hdone


In [2]:
!wget http://storage.googleapis.com/gresearch/federated-vision-datasets/cifar10.zip
!unzip cifar10.zip

--2022-01-03 17:19:00--  http://storage.googleapis.com/gresearch/federated-vision-datasets/cifar10.zip
Resolving storage.googleapis.com (storage.googleapis.com)... 74.125.133.128, 74.125.140.128, 108.177.15.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|74.125.133.128|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1627997 (1.6M) [application/zip]
Saving to: ‘cifar10.zip’


2022-01-03 17:19:00 (155 MB/s) - ‘cifar10.zip’ saved [1627997/1627997]

Archive:  cifar10.zip
   creating: cifar10/
  inflating: cifar10/federated_train_alpha_0.00.csv  
  inflating: cifar10/test.csv        
  inflating: cifar10/federated_train_alpha_10.00.csv  
  inflating: cifar10/federated_train_alpha_0.05.csv  
  inflating: cifar10/federated_train_alpha_100.00.csv  
  inflating: cifar10/federated_train_alpha_0.10.csv  
  inflating: cifar10/federated_train_alpha_0.20.csv  
  inflating: cifar10/federated_train_alpha_1.00.csv  
  inflating: cifar10/federated_train_al

In [3]:
import wandb

# Start a wandb run for project "step-2" under the "aml-federated-learning" entity.
# NOTE(review): mode="disabled" presumably turns the wandb calls in this notebook
# into no-ops (nothing uploaded) — confirm against the wandb docs and change the
# mode when the run should actually be tracked.
wandb.init(project="step-2", entity="aml-federated-learning", mode="disabled")



In [4]:
# --- Local training ---------------------------------------------------------
E = 1            # local epochs run on the selected clients in every round
STEP_SIZE = 5    # only logged to wandb here (the StepLR scheduler in Client is commented out)
GAMMA = 0.1      # only logged to wandb here (the StepLR scheduler in Client is commented out)

# K = 1, NUMBER_OF_CLIENTS = 2, MAX_TIME = 3 -> 58 sec

# --- Federation -------------------------------------------------------------
K = 10                   # clients sampled per round            (to set)
NUMBER_OF_CLIENTS = 100  # total number of simulated clients    (to set)
MAX_TIME = 1000          # number of communication rounds       (to set)

batch_size = 10

lr = 0.05

DATA_DISTRIBUTION = "non-iid" # "iid" | "non-iid"
DIRICHELET_ALPHA = 0.05 # 0.00, 0.05, 0.10, 0.20, 0.50, 1.00, 10.00, 100.0

# Fail fast on inconsistent settings instead of part-way through data loading
# or training.
assert DATA_DISTRIBUTION in ("iid", "non-iid"), \
    f'DATA_DISTRIBUTION must be "iid" or "non-iid", got {DATA_DISTRIBUTION!r}'
assert DATA_DISTRIBUTION == "iid" or NUMBER_OF_CLIENTS == 100, \
    "the non-iid splits are read from the downloaded CSV files; this notebook requires NUMBER_OF_CLIENTS == 100 for them"
assert 1 <= K <= NUMBER_OF_CLIENTS, \
    "K clients are sampled without replacement each round, so 1 <= K <= NUMBER_OF_CLIENTS is required"


wandb.config.update({
    "batch-size": batch_size,
    "learning-rate": lr,
    # "momentum": MOMENTUM,
    # "weight_decay": WEIGHT_DECAY,
    "num_epochs": E,
    "step_size": STEP_SIZE,
    "gamma": GAMMA,
    "K": K,
    "number_of_clients": NUMBER_OF_CLIENTS,
    "max_time": MAX_TIME,
    "data_distribution": DATA_DISTRIBUTION,
    "dirichelet_alpha": DIRICHELET_ALPHA
})

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Use the GPU when CUDA is available, otherwise fall back to the CPU. The
# models and the preloaded client tensors are moved to this device later on.
device = "cuda" if torch.cuda.is_available() else "cpu"

# From: https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html
class Net(nn.Module):
    """Small LeNet-style CNN for a single client.

    Takes 3-channel images and returns 10 logits per sample. ``fc1`` expects
    16 feature maps of 5x5, which is what two (5x5 conv -> 2x2 max-pool)
    stages produce from a 32x32 input.
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor: 3 -> 6 -> 16 channels, 5x5 kernels.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        # Classifier head: three affine maps (y = Wx + b), 400 -> 120 -> 84 -> 10.
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Return the logits of shape [batch, 10] for a batch of images ``x``."""
        features = x
        # Each convolution is followed by a ReLU and a 2x2 max-pool.
        for conv in (self.conv1, self.conv2):
            features = F.max_pool2d(F.relu(conv(features)), kernel_size=2)
        # Collapse everything except the batch dimension.
        flat = features.flatten(start_dim=1)
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

In [6]:
# From: https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html
class BatchedNet(nn.Module):
    """``parallelism`` independent copies of the small CNN evaluated in one pass.

    The copies are packed along the channel dimension and kept independent by
    grouped convolutions (``groups=parallelism``): copy ``p`` reads input
    channels ``3*p:3*(p+1)`` and writes output channels ``10*p:10*(p+1)``.
    The fully connected layers are expressed as grouped 1x1 ``Conv1d`` layers
    applied to a length-1 sequence.

    Args:
        parallelism: number of model copies (P) packed into this module.

    Shapes:
        input:  [B, 3 * P, 32, 32]  (32x32 is what yields the 5x5 maps the
                first Conv1d expects)
        output: [B, 10 * P, 1]
    """

    def __init__(self, parallelism):
        super(BatchedNet, self).__init__()
        self.P = parallelism
        # Per copy: 3 input channels -> 6 -> 16 feature maps, 5x5 kernels.
        self.ops = nn.Sequential(
            nn.Conv2d(3 * self.P, 6 * self.P, 5, groups=self.P),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),
            nn.Conv2d(6 * self.P, 16 * self.P, 5, groups=self.P),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # Per copy: 400 -> 120 -> 84 -> 10, each linear layer written as a
        # grouped convolution with kernel size 1.
        self.ops2 = nn.Sequential(
            nn.Conv1d(16 * 5 * 5 * self.P, 120 * self.P,
                      kernel_size=1, groups=self.P),
            nn.ReLU(),
            nn.Conv1d(120 * self.P, 84 * self.P, kernel_size=1, groups=self.P),
            nn.ReLU(),
            nn.Conv1d(84 * self.P, 10 * self.P, kernel_size=1, groups=self.P)
        )

    def forward(self, x):
        """Return the packed logits of shape [B, 10 * P, 1] for input ``x``."""
        x = self.ops(x)
        # Flatten every sample to [16*5*5*P] features and add a length-1
        # sequence axis for Conv1d. The batch size is taken from the input
        # itself rather than from a module-level constant, so the model works
        # for any batch size (e.g. a final partial batch or evaluation).
        x = x.flatten(1).unsqueeze(-1)
        return self.ops2(x)


In [7]:
import torch.optim as optim

class Client():
  """Holds one simulated client's id, training split and training data loader.

  The per-client model, optimizer and validation loader are all commented out,
  so in its current form the class only stores data.
  """
  def __init__(self, i, train_set, validation_set):
    """Store the client id and build a loader over its training split.

    Args:
      i: client identifier, stored as ``self.i``.
      train_set: dataset with this client's training samples.
      validation_set: accepted but currently unused (the validation loader
        below is commented out).
    """
    self.i = i
    self.train_set = train_set
    # NOTE(review): not read anywhere in the visible notebook; the loader below
    # uses len(train_set) as its batch size instead.
    self.batch_size = 32
    # batch_size=len(train_set): a single iteration yields the whole training
    # split as one batch.
    self.train_loader = torch.utils.data.DataLoader(train_set, batch_size=len(train_set),
                                         shuffle=False, num_workers=0, pin_memory=True)
    #self.validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=batch_size,
    #                                     shuffle=False, num_workers=0)
    #self.net = Net()
    #self.net = self.net.to(device)
    # create your optimizer
    #self.optimizer = optim.SGD(self.net.parameters(), lr=lr)
    #self.criterion = nn.CrossEntropyLoss()
    # self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, step_size=STEP_SIZE, gamma=GAMMA)
    #wandb.watch(self.net, criterion=self.criterion, log_freq=100, log_graph=True)
    
  # The string literal below holds a per-client clientUpdate / compute_accuracy
  # implementation as inert text. It is never executed, so neither method
  # exists on Client.
  """def clientUpdate(self, parameters):
    self.net.load_state_dict(parameters)
    theta = parameters
    for e in range(E):
      for images, labels in self.train_loader:
        images = images.to(device)
        labels = labels.to(device)
        # in your training loop:
        self.optimizer.zero_grad()   # zero the gradient buffers
        output = self.net(images)
        loss = self.criterion(output, labels)
        loss.backward()
        wandb.log({f"client-loss-{self.i}": loss.item()})
        self.optimizer.step()    # Does the update
    
    return_dict = {}
    for (k1, v1), (k2, v2) in zip(parameters.items(), self.net.state_dict().items()):
      return_dict[k1] = v1 - v2
    return return_dict

  def compute_accuracy(self, parameters):
    self.net.load_state_dict(parameters)

    running_corrects = 0
    n = 0
    for data, labels in self.validation_loader:
        data = data.to(device)
        labels = labels.to(device)

        outputs = self.net(data)

        _, preds = torch.max(outputs.data, 1)

        running_corrects += torch.sum(preds == labels.data).data.item()
        n += len(preds)
                
    return running_corrects / n"""


In [8]:
from collections import defaultdict

def parse_csv(filename):
    """Read a federated split CSV and group image ids by user id.

    Every line whose first character is a digit is parsed as three
    comma-separated integers ``user_id,image_id,<ignored>``; all other lines
    (e.g. a header) are skipped.

    Returns:
        A ``defaultdict`` mapping each user id to the list of its image ids,
        in file order.
    """
    images_by_user = defaultdict(list)
    with open(filename) as csv_file:
        data_rows = (row for row in csv_file if row[0].isdigit())
        for row in data_rows:
            user_id, image_id, _ = map(int, row.split(","))
            images_by_user[user_id].append(image_id)
    return images_by_user


In [9]:
import torchvision
import torchvision.transforms as transforms
import numpy as np
import random
from time import perf_counter_ns

from tqdm.notebook import tqdm, trange

SEED = 42
random.seed(SEED)
# `random` only drives the client sampling. random_split and the initial model
# weights draw from torch's RNG, so seed that as well to make runs repeatable.
torch.manual_seed(SEED)

# Convert images to tensors and rescale every channel from [0, 1] to [-1, 1].
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)


if DATA_DISTRIBUTION == "iid":
    # Drop the remainder so the training set divides evenly, then hand every
    # client a random, equally sized share.
    trainset_len = (len(trainset) // NUMBER_OF_CLIENTS) * NUMBER_OF_CLIENTS
    trainset = torch.utils.data.Subset(trainset, list(range(trainset_len)))

    # Plain Python ints: random_split documents a sequence of ints as lengths.
    lengths = [len(trainset) // NUMBER_OF_CLIENTS] * NUMBER_OF_CLIENTS
    trainsets = torch.utils.data.random_split(
        dataset=trainset, lengths=lengths)
else:
    # Use the pre-computed per-user image ids from the downloaded CSV.
    dirichelet_splits = parse_csv(
        f"cifar10/federated_train_alpha_{DIRICHELET_ALPHA:.2f}.csv")
    # Iterate in sorted user-id order so the client order does not depend on
    # the order in which users first appear in the CSV.
    trainsets = [torch.utils.data.Subset(trainset, dirichelet_splits[user_id])
                 for user_id in sorted(dirichelet_splits)]

assert len(trainsets) >= NUMBER_OF_CLIENTS, \
    f"only {len(trainsets)} training splits for {NUMBER_OF_CLIENTS} clients"


# split the validation set: equally sized random shares, remainder dropped
testset_len = (len(testset) // NUMBER_OF_CLIENTS) * NUMBER_OF_CLIENTS
testset = torch.utils.data.Subset(testset, list(range(testset_len)))

lengths = [len(testset) // NUMBER_OF_CLIENTS] * NUMBER_OF_CLIENTS
testsets = torch.utils.data.random_split(dataset=testset, lengths=lengths)

print(f"{len(trainsets)} training splits, {len(testsets)} validation splits "
      f"({testset_len} validation images in total)")


# NOTE(review): never read anywhere in the visible notebook.
clientsSizes = torch.zeros(NUMBER_OF_CLIENTS)

# server reference model (a BatchedNet holding a single copy of the network)
reference = BatchedNet(1).to(device)

# clients[c].i == c, so a client id can be used directly as a list index below
clients = [Client(c, trainsets[c], testsets[c])
           for c in range(NUMBER_OF_CLIENTS)]


def selectClients(k):
    """Return ``k`` distinct clients sampled uniformly at random."""
    return random.sample(clients, k=k)


def aggregateClient(deltaThetas, client_ids=None):
    """Combine per-client parameter deltas into one weighted delta.

    Args:
        deltaThetas: list of state-dict-like mappings (key -> tensor), one per
            participating client.
        client_ids: ids of the clients that produced ``deltaThetas``, in the
            same order. Each delta is weighted by the size of *that* client's
            training split. When omitted, delta ``j`` is attributed to client
            ``j`` (the behaviour before this parameter existed).

    Returns:
        A dict key -> weighted sum of the deltas, or ``None`` when
        ``deltaThetas`` is empty.

    Raises:
        ValueError: if ``client_ids`` and ``deltaThetas`` differ in length.
    """
    if client_ids is None:
        client_ids = range(len(deltaThetas))
    client_ids = list(client_ids)
    if len(client_ids) != len(deltaThetas):
        raise ValueError("client_ids and deltaThetas must have the same length")

    parameters = None
    for client_id, delta in zip(client_ids, deltaThetas):
        # NOTE(review): the weights are normalised by the size of the whole
        # training set, not by the samples of the participating clients, so
        # they sum to roughly K / NUMBER_OF_CLIENTS rather than 1. Kept as in
        # the original; confirm whether this implicit down-scaling is intended.
        ratio = len(trainsets[client_id]) / len(trainset)

        if parameters is None:
            parameters = {key: ratio * value for key, value in delta.items()}
        else:
            for key, value in delta.items():
                parameters[key] += ratio * value

    return parameters


# One BatchedNet with K packed copies trains the K selected clients at once.
batched_model = BatchedNet(K).to(device)
batched_optimizer = optim.SGD(batched_model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss().to(device)


# move all the images and labels of every client to the device once, up front
client_images, client_labels = [], []
for client in clients:
    # the loader's batch size is the whole split, so one step yields everything
    images, labels = next(iter(client.train_loader))
    # client_images[i] has shape [Ni, 3, 32, 32], client_labels[i] has shape [Ni, 1]
    client_images.append(images.to(device))
    client_labels.append(labels.view((-1, 1)).to(device))


for step in trange(MAX_TIME):
    selected_clients = selectClients(K)
    selected_ids = [c.i for c in selected_clients]

    # Keep data, model copies and ids in the same (selection) order so that
    # copy j of the batched model always belongs to client selected_ids[j].
    selected_client_images = [client_images[cid] for cid in selected_ids]
    selected_client_labels = [client_labels[cid] for cid in selected_ids]

    # load the batched model with K stacked copies of the server parameters
    parameters = {key: torch.stack([params] * K).flatten(0, 1)
                  for key, params in reference.state_dict().items()}
    batched_model.load_state_dict(parameters)

    # All K copies step together, so only as many full batches as the smallest
    # selected client can provide are used (identical to the first client's
    # count when all clients hold the same number of samples).
    n_batches = min(ci.shape[0] for ci in selected_client_images) // batch_size

    for epoch in range(E):
        # for each local epoch
        for i in range(n_batches):
            # for each local batch
            batched_optimizer.zero_grad(set_to_none=True)

            batch = slice(i * batch_size, (i + 1) * batch_size)

            # Concatenate the clients along the channel axis:
            # [batch_size, K * 3, 32, 32] with channels 3*j:3*(j+1) holding
            # client j's images. (Stacking to [K, batch_size, ...] and then
            # viewing as [batch_size, K * 3, ...] would only reinterpret the
            # memory and mix up the client and sample axes.)
            batch_images = torch.cat(
                [ci[batch] for ci in selected_client_images], dim=1)
            batch_labels = [cl[batch] for cl in selected_client_labels]

            # output has shape [batch_size, 10 * K, 1];
            # output[:, 10*j:10*(j+1)] is the model output for client j
            output = batched_model(batch_images)

            # One chunk of 10 logits per client (K chunks). The copies share
            # no parameters, so summing the losses gives every copy exactly
            # the gradient of its own client's loss.
            loss = sum(criterion(_output, _labels)
                       for _output, _labels in zip(torch.chunk(output, K, dim=1), batch_labels))
            loss.backward()

            # apply the gradient descent step
            batched_optimizer.step()

    # extract the per-client deltas (server parameters minus trained parameters)
    client_params = [dict() for _ in range(K)]
    for key, batched_parameters in batched_model.state_dict().items():
        # s is the number of leading rows of this entry owned by one client
        s = batched_parameters.shape[0] // K
        for i in range(K):
            rows = slice(i * s, (i + 1) * s)
            client_params[i][key] = parameters[key][rows] - batched_parameters[rows]

    g = aggregateClient(client_params, selected_ids)

    # Apply the aggregated delta to the server model, matching entries by key
    # rather than by position.
    # todo: add server learning rate gamma
    parameters = {key: value - g[key]
                  for key, value in reference.state_dict().items()}
    reference.load_state_dict(parameters)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified
10000


  0%|          | 0/1000 [00:00<?, ?it/s]

KeyboardInterrupt: ignored

In [None]:
from collections import Counter

# Print the label histogram of the first three clients' training splits.
for client_index in range(3):
    client_split = trainsets[client_index]
    print(Counter(label for _, label in iter(client_split)))

In [None]:
from statistics import mean

def compute_accuracy(model, dataset):
    """Return the fraction of samples in ``dataset`` that ``model`` classifies correctly.

    Args:
        model: network returning per-class scores along dimension 1
            (any trailing singleton dimension is flattened away).
        dataset: dataset yielding ``(image, label)`` pairs.

    Raises:
        ValueError: if ``dataset`` does not contain a single full batch.
    """
    # Evaluation batches are kept at exactly `batch_size` samples (incomplete
    # last batch dropped) so that a model whose reshape relies on the global
    # batch size is always fed a valid shape.
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=False, num_workers=0,
                                         drop_last=True)
    running_corrects = 0
    n = 0
    model.eval()
    with torch.no_grad():
        for data, labels in loader:
            data = data.to(device)
            labels = labels.to(device)

            # [B, 10, 1] -> [B, 10], then pick the highest-scoring class
            preds = model(data).flatten(1).argmax(dim=1)

            running_corrects += (preds == labels).sum().item()
            n += len(preds)

    if n == 0:
        raise ValueError("dataset is smaller than one evaluation batch")
    return running_corrects / n


# Evaluate the trained server model (`reference`) on every client's validation
# split and average the per-client accuracies.
avg_accuracy = mean(compute_accuracy(reference, testsets[client.i]) for client in clients)

print(f"Average accuracy after {MAX_TIME} rounds is {avg_accuracy}")

In [None]:
import time

import os

timestr = time.strftime("%Y_%m_%d-%I_%M_%S_%p")
artifact_filename = f"artifacts/server_model-{timestr}.pth"
# torch.save does not create missing directories
os.makedirs(os.path.dirname(artifact_filename), exist_ok=True)

# parameters of the trained server model
server_model = reference.state_dict()
# save the model on the local file system
torch.save(server_model, artifact_filename)
# save the model on wandb
wandb.save(artifact_filename)

# Finish the wandb session and upload all data
wandb.finish(0, quiet=False)