# Baseline implementation

In [1]:
# download the Cifar10 non-iid splits, if not present

from os import path
import urllib.request
import zipfile

if not path.exists("cifar10"):
    save_path = "cifar10.zip"
    urllib.request.urlretrieve("http://storage.googleapis.com/gresearch/federated-vision-datasets/cifar10.zip", save_path)
    
    with zipfile.ZipFile(save_path, 'r') as zip_ref:
        zip_ref.extractall(".")

In [9]:
config = {
    "E": 1, # number of local epochs
    "K": 1, # number of clients selected each round
    "NUMBER_OF_CLIENTS": 1, # total number of clients
    "MAX_TIME": 100,
    "BATCH_SIZE": 50,
    "LR": 0.001,
    "DATA_DISTRIBUTION": "iid", # "iid" | "non-iid"
    #"DIRICHELET_ALPHA": 0.00, # 0.00, 0.05, 0.10, 0.20, 0.50, 1.00, 10.00, 100.0
    #"LR_DECAY": 0.99,
    "WEIGHT_DECAY": 1e-3 #,
    #"LOG_FREQUENCY": 100,
    #"TRANSFORM_CROP": 24,
    #"TRANSFORM_RND_HFLIP_PROB": 0.25,
    #"TRANSFORM_BRIGHTNESS": 0.5,
    #"TRANSFORM_CONTRAST": 0.5,
    #"TRANSFORM_SATURATION": 0.5,
    #"TRANSFORM_HUE": 0.5
}

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"

In [5]:
# From: https://pytorch.org/tutorials/beginner/blitz/neural_networks_tutorial.html
class Net(nn.Module):

    def __init__(self, *, input_size=32):
        super(Net, self).__init__()
        # 1 input image channel, 6 output channels, 5x5 square convolution
        self.conv1 = nn.Conv2d(3, 64, 5)
        self.conv2 = nn.Conv2d(64, 64, 5)
        
        # output of the conv layer is (w', h') = (w - 5 + 1, h - 5 + 1)
        # max_pool2d halves the dimensions (w', h') = (w / 2, h / 2)

        # dynamically compute the image size
        size = input_size // 4 - 3
        self.fc1 = nn.Linear(64 * (size * size), 384)
        self.fc2 = nn.Linear(384, 192)
        self.fc3 = nn.Linear(192, 10)

    def forward(self, x):
        # Max pooling over a (2, 2) window
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        # If the size is a square, you can specify with a single number
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = torch.flatten(x, 1) # flatten all dimensions except the batch dimension
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


net = Net()
net = net.to(device)
print(net)

Net(
  (conv1): Conv2d(3, 64, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=1600, out_features=384, bias=True)
  (fc2): Linear(in_features=384, out_features=192, bias=True)
  (fc3): Linear(in_features=192, out_features=10, bias=True)
)


In [6]:
import torch.optim as optim

class Client():
  def __init__(self, i, train_set, validation_set, *, input_size=32):
    self.i = i
    self.train_loader = torch.utils.data.DataLoader(train_set, batch_size=config["BATCH_SIZE"],
                                         shuffle=True, num_workers=0)
    self.validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=config["BATCH_SIZE"],
                                         shuffle=False, num_workers=0)
    self.net = Net(input_size=input_size)
    self.net = self.net.to(device)
    # create your optimizer
    self.optimizer = optim.SGD(self.net.parameters(), lr=config["LR"], momentum = 0.9, weight_decay = config["WEIGHT_DECAY"])
    self.criterion = nn.CrossEntropyLoss()
    # wandb.watch(self.net, criterion=self.criterion, log_freq=100, log_graph=False)
    
  def clientUpdate(self, lr, parameters):
    self.net.load_state_dict(parameters)
    for g in self.optimizer.param_groups:
      g['lr'] = lr

    for e in range(config["E"]):
      for images, labels in self.train_loader:
        images = images.to(device)
        labels = labels.to(device)
        # in your training loop:
        self.optimizer.zero_grad()   # zero the gradient buffers
        output = self.net(images)
        loss = self.criterion(output, labels)
        loss.backward()
        # wandb.log({f"client-loss-{self.i}": loss.item()})
        self.optimizer.step()    # Does the update
    
    return_dict = {}
    for (k1, v1), (k2, v2) in zip(parameters.items(), self.net.state_dict().items()):
      return_dict[k1] = v1 - v2
    return return_dict

  def compute_accuracy(self, parameters):
    self.net.load_state_dict(parameters)

    running_corrects = 0
    n = 0
    for data, labels in self.validation_loader:
        data = data.to(device)
        labels = labels.to(device)

        outputs = self.net(data)

        _, preds = torch.max(outputs.data, 1)

        running_corrects += torch.sum(preds == labels.data).data.item()
        n += len(preds)

                
    return running_corrects / n


In [7]:
from collections import defaultdict

def parse_csv(filename):
  splits = defaultdict(lambda: [])
  with open(filename) as f:
    for line in f:
      if not line[0].isdigit():
        continue

      user_id, image_id, _ = (int(token) for token in line.split(","))
      splits[user_id].append(image_id)

  return splits


In [11]:
import torchvision
import torchvision.transforms as transforms
import numpy as np
import random
from statistics import mean

from tqdm.notebook import tqdm

random.seed(42)

train_transform = transforms.Compose([
  transforms.ToTensor(),
     #transforms.RandomCrop(config["TRANSFORM_CROP"]),
     #transforms.RandomHorizontalFlip(config["TRANSFORM_RND_HFLIP_PROB"]),
     #transforms.ColorJitter(brightness=config["TRANSFORM_BRIGHTNESS"]),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

test_transform = transforms.Compose([
  transforms.ToTensor(),
  #transforms.CenterCrop(config["TRANSFORM_CROP"]),
  #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=train_transform)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=test_transform)


if config["DATA_DISTRIBUTION"] == "iid":
  # split the training set
  trainset_len = ( len(trainset) // config["NUMBER_OF_CLIENTS"] ) * config["NUMBER_OF_CLIENTS"]
  trainset = torch.utils.data.Subset(trainset, list(range(trainset_len)))

  lengths = len(trainset) // config["NUMBER_OF_CLIENTS"] * np.ones(config["NUMBER_OF_CLIENTS"], dtype=int)
  trainsets = torch.utils.data.random_split(dataset=trainset, lengths=lengths)
else:
  dirichelet_splits = parse_csv(f"cifar10/federated_train_alpha_{config['DIRICHELET_ALPHA']:.2f}.csv")
  trainsets = [torch.utils.data.Subset(trainset, indices) for indices in dirichelet_splits.values()]


# split the validation set
testset_len = ( len(testset) // config["NUMBER_OF_CLIENTS"] ) * config["NUMBER_OF_CLIENTS"]
testset = torch.utils.data.Subset(testset, list(range(testset_len)))

lengths = len(testset) // config["NUMBER_OF_CLIENTS"] * np.ones(config["NUMBER_OF_CLIENTS"], dtype=int)
testsets = torch.utils.data.random_split(dataset=testset, lengths=lengths)


clientsSizes = torch.zeros(config["NUMBER_OF_CLIENTS"])
clients = list()

def selectClients(k):
  return random.choices(clients, k=k)

def aggregateClient(deltaThetas):
  parameters = None
  for i,d in enumerate(deltaThetas):
    #ratio = len(trainsets[i])/len(trainset)
    ratio = len(trainsets[i])/(len(trainsets[i])*config['K'])
    
    if i == 0:
      parameters = {k:ratio*v for k, v in d.items()}
    else:
      for (k, v) in d.items():
        parameters[k] += ratio * v
   
  return parameters

for c in range(config["NUMBER_OF_CLIENTS"]):
  clients.append(Client(c, trainsets[c], testsets[c]))#, input_size=config["TRANSFORM_CROP"]))

# if config["FED_AVG_M"]:
#   old_parameters = {}

# initial learning rate
lr = config["LR"]

# collect the test accuracies over the epochs
test_accuracies = []

for step in tqdm(range(config["MAX_TIME"])):
#for t in range(MAX_TIME):
  selected_clients = selectClients(config["K"])
  #print(f"Client(s) {[client.i for client in selected_clients]} selected")

  deltaThetas = list()
  for i, c in enumerate(selected_clients):
    deltaThetas.append(c.clientUpdate(lr, net.state_dict()))
    
  g = aggregateClient(deltaThetas)
  
  parameters = {}
  for (k1, v1), (k2, v2) in zip(net.state_dict().items(), g.items()):
    
    # if config["FED_AVG_M"]:
    #   if k1 in old_parameters:
    #     parameters[k1] = v1 - config["FED_AVG_M_GAMMA"] * (config["FED_AVG_M_BETA"] * old_parameters[k1] + v2)  
    #     old_parameters[k1] = config["FED_AVG_M_BETA"] * old_parameters[k1] + v2
    #   else:
    #     parameters[k1] = v1 - config["FED_AVG_M_GAMMA"] * v2
    #     old_parameters[k1] = v2
        
    # else:
    parameters[k1] = v1 - v2 # todo: add server learning rate gamma

  net.load_state_dict(parameters)

  # lr *= config["LR_DECAY"]

  # if step % config["LOG_FREQUENCY"] == 0:
  #   model_parameters = net.state_dict()
  #   avg_accuracy = mean(client.compute_accuracy(model_parameters) for client in clients)
  #   test_accuracies.append(avg_accuracy)

  #   print(f"Average accuracy after {step} rounds is {avg_accuracy}")

Files already downloaded and verified
Files already downloaded and verified


  0%|          | 0/100 [00:00<?, ?it/s]

In [12]:
from statistics import mean

model_parameters = net.state_dict()
avg_accuracy = mean(client.compute_accuracy(model_parameters) for client in clients)
test_accuracies.append(avg_accuracy)

print(f"Average accuracy after {config['MAX_TIME']} rounds is {avg_accuracy}")

Average accuracy after 100 rounds is 0.4697


In [None]:
import time
import json

timestr = time.strftime("%Y_%m_%d-%I_%M_%S_%p")
artifact_filename = f"artifacts/server_model-{timestr}"

# parameters of the trained model
server_model = net.state_dict()
# save the model on the local file system
torch.save(server_model, artifact_filename + ".pth")

data = {
    "config": config,
    "test_accuracies": test_accuracies
}

with open(artifact_filename + ".json", "w") as f:
    f.write(json.dumps(data, indent=4))

# save the model on wandb
# wandb.save(artifact_filename)
# Finish the wandb session and upload all data
# wandb.finish(0, quiet=False)