In [None]:
# `%pip` (unlike `!pip`) installs into the environment of the running kernel.
%pip install --upgrade wandb

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn.functional as F
import torch.nn as nn
import wandb
import os
from typing import Any, Dict, List
import copy
import random
import wandb

In [None]:
# Device used for the global model and for every training / evaluation batch.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the RNGs this notebook relies on (`random` for client sampling and the
# non-IID partition, torch for weight init and DataLoader shuffling) so that a
# Restart & Run All reproduces the same experiment.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Never hardcode credentials in a notebook. W&B reads WANDB_API_KEY from the
# environment (or a credential saved by `wandb login`). Any key that was
# previously committed here should be treated as leaked and revoked.
if not os.environ.get("WANDB_API_KEY"):
    print("WANDB_API_KEY is not set; wandb will fall back to a saved `wandb login` credential or prompt.")

In [None]:
class Client:
    """A federated-learning participant that trains its own copy of the model.

    Args:
        client_id: Identifier, used only for logging.
        model: The client's own model instance (the server passes a deep copy).
        loss: Loss module, called as ``loss(logits, labels)``.
        optimizer: Optimizer *class* (e.g. ``torch.optim.SGD``). It is
            instantiated here over ``model.parameters()``.
        optimizer_conf: Keyword arguments forwarded to ``optimizer``.
        batch_size: Local mini-batch size.
        epochs: Number of local passes over the data per ``update_weights`` call.
        server: Optional owning server. When given, ``setData`` adds this
            client's sample count to ``server.total_data``.
    """

    def __init__(self,
                 client_id: Any,
                 model: torch.nn.Module,
                 loss: torch.nn.modules.loss._Loss,
                 optimizer: torch.optim.Optimizer,
                 optimizer_conf: Dict,
                 batch_size: int,
                 epochs: int,
                 server=None) -> None:
        self.client_id = client_id
        self.model = model
        self.loss = loss
        # NOTE(review): optimizer state (e.g. SGD momentum buffers) persists
        # across rounds even though the server overwrites the weights via
        # broadcast(). Confirm this is intended.
        self.optimizer = optimizer(self.model.parameters(), **optimizer_conf)
        self.batch_size = batch_size
        self.epochs = epochs
        self.server = server
        self.accuracy = None      # training accuracy of the last local epoch
        self.total_loss = None    # mean per-batch loss of the last local epoch

        self.data = None
        self.data_loader = None

    def setData(self, data):
        """Attach a local dataset and build a shuffled DataLoader over it.

        Args:
            data: A map-style dataset or sequence of ``(feature, label)`` pairs.
        """
        self.data = data
        self.data_loader = torch.utils.data.DataLoader(self.data,
                                                       batch_size=self.batch_size,
                                                       shuffle=True)
        # `server` defaults to None; only do the bookkeeping when there is one.
        if self.server is not None:
            self.server.total_data += len(self.data)

    def update_weights(self):
        """Train the local model for ``self.epochs`` epochs on the local data.

        Afterwards ``self.total_loss`` holds the mean per-batch loss and
        ``self.accuracy`` the fraction of correctly classified samples of the
        *last* epoch. Both are measured on the fly, before each optimizer step.
        With an empty dataset both are 0.0.

        Raises:
            RuntimeError: if ``setData`` has not been called yet.
        """
        if self.data_loader is None:
            raise RuntimeError(f"Client {self.client_id} has no data; call setData() first")

        # Train on whatever device the model lives on instead of relying on a
        # notebook-global `device`.
        first_param = next(self.model.parameters(), None)
        model_device = first_param.device if first_param is not None else torch.device("cpu")
        self.model.train()

        for _ in range(self.epochs):
            loss_sum = 0.0
            n_batches = 0
            n_correct = 0
            n_samples = 0

            for feature, label in self.data_loader:
                feature = feature.to(model_device)
                label = label.to(model_device)

                y_pred = self.model(feature)
                n_correct += y_pred.argmax(dim=1).eq(label).sum().item()
                # Count real samples: the last batch may be smaller than batch_size.
                n_samples += label.size(0)
                loss = self.loss(y_pred, label)

                self.optimizer.zero_grad(set_to_none=True)
                loss.backward()
                self.optimizer.step()

                loss_sum += loss.item()
                n_batches += 1

            self.total_loss = loss_sum / max(n_batches, 1)
            self.accuracy = n_correct / max(n_samples, 1)


class Server:
    """Federated-averaging coordinator that owns the global model and a client pool.

    Args:
        model: Global model. Each client receives an independent deep copy.
        loss: Loss module passed to every client.
        optimizer: Optimizer *class* (e.g. ``torch.optim.SGD``). Each client
            instantiates its own optimizer over its own parameters.
        optimizer_conf: Keyword arguments for ``optimizer``.
        n_client: Number of clients created in the pool.
        chosen_prob: Fraction of the pool sampled (without replacement) in each
            ``aggregate`` call. At least one client is always used.
        local_batch_size: Mini-batch size used by every client.
        local_epochs: Local epochs each sampled client runs per round.
    """

    def __init__(self,
                 model: torch.nn.Module,
                 loss: torch.nn.modules.loss._Loss,
                 optimizer: torch.optim.Optimizer,
                 optimizer_conf: Dict,
                 n_client: int = 10,
                 chosen_prob: float = 0.8,
                 local_batch_size: int = 8,
                 local_epochs: int = 10) -> None:

        # global model info
        self.model = model
        self.loss = loss
        self.optimizer = optimizer
        self.optimizer_conf = optimizer_conf
        self.n_client = n_client
        self.local_batch_size = local_batch_size
        self.local_epochs = local_epochs
        self.total_data = 0  # sum of len(data) over clients, kept by Client.setData

        # create clients
        self.client_pool: List[Client] = []
        self.create_client()
        self.chosen_prob = chosen_prob
        self.avg_loss = 0
        self.avg_acc = 0

    def create_client(self):
        """(Re)build the client pool. Each client gets a deep copy of the global model."""
        # this function is reusable, so reset client pool is needed
        self.client_pool: List[Client] = []
        self.total_data = 0

        for i in range(self.n_client):
            model = copy.deepcopy(self.model)
            new_client = Client(client_id=i,
                                model=model,
                                loss=self.loss,
                                optimizer=self.optimizer,
                                optimizer_conf=self.optimizer_conf,
                                batch_size=self.local_batch_size,
                                epochs=self.local_epochs,
                                server=self)
            self.client_pool.append(new_client)

    def broadcast(self):
        """Copy the current global weights into every client's model."""
        global_state = self.model.state_dict()
        for client in self.client_pool:
            # load_state_dict copies values into the client's own tensors,
            # so clients never alias the global parameters.
            client.model.load_state_dict(global_state)

    def aggregate(self):
        """Run one round: sample clients, train them locally, and average their weights.

        ``max(1, int(len(pool) * chosen_prob))`` clients (capped at the pool size)
        are sampled without replacement. Each one runs ``update_weights()``. The
        global model is then replaced by the average of their state dicts,
        weighted by each client's local sample count (n_k / n, as in FedAvg).
        The average is uniform when no sampled client has data. The same weights
        produce ``self.avg_loss`` and ``self.avg_acc``. An empty pool leaves the
        global model unchanged.
        """
        self.avg_loss = 0
        self.avg_acc = 0
        pool_size = len(self.client_pool)
        if pool_size == 0:
            return

        # int() can round down to 0, which used to overwrite the global model
        # with all-zero weights; always train at least one client.
        n_chosen = min(pool_size, max(1, int(pool_size * self.chosen_prob)))
        chosen_clients = random.sample(self.client_pool, n_chosen)

        sizes = [len(c.data) if c.data is not None else 0 for c in chosen_clients]
        total_size = sum(sizes)
        if total_size > 0:
            coefs = [s / total_size for s in sizes]
        else:
            coefs = [1 / n_chosen] * n_chosen

        global_state = self.model.state_dict()
        # Accumulate in at least float32 so integer buffers (e.g. BatchNorm's
        # num_batches_tracked) can be averaged without an in-place dtype error.
        accum = {
            key: torch.zeros_like(t, dtype=torch.promote_types(t.dtype, torch.float32))
            for key, t in global_state.items()
        }

        for client, coef in zip(chosen_clients, coefs):
            client.update_weights()
            print(f"Client {client.client_id}: Acc {client.accuracy}, Loss: {client.total_loss}")
            self.avg_loss += coef * client.total_loss
            self.avg_acc += coef * client.accuracy
            client_state = client.model.state_dict()
            for key, acc in accum.items():
                acc += coef * client_state[key].to(device=acc.device, dtype=acc.dtype)

        new_state = {}
        for key, ref in global_state.items():
            averaged = accum[key]
            if not ref.is_floating_point():
                averaged = averaged.round()  # integer buffers: round before casting back
            new_state[key] = averaged.to(ref.dtype)
        self.model.load_state_dict(new_state)

In [None]:
class Net(nn.Module):
    """LeNet-style CNN mapping 3x32x32 images (CIFAR-10 here) to 10 class logits.

    Two stages of 5x5 convolution + ReLU + 2x2 max-pool, followed by three
    fully connected layers (16*5*5 -> 120 -> 84 -> 10).
    """

    def __init__(self):
        super().__init__()
        # Layers are registered in this exact order so that, for a given torch
        # seed, parameter initialisation is reproducible.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Return raw logits (no softmax) for a batch of images."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        x = x.flatten(start_dim=1)  # keep the batch dimension, flatten the rest
        for fc in (self.fc1, self.fc2):
            x = F.relu(fc(x))
        return self.fc3(x)

In [None]:
# ToTensor scales PIL images to [0, 1]; Normalize(0.5, 0.5) then maps each
# channel to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

DATA_ROOT = 'cifar_data'  # download / cache directory for CIFAR-10

trainset = torchvision.datasets.CIFAR10(root=DATA_ROOT,
                                        train=True,
                                        download=True,
                                        transform=transform)
testset = torchvision.datasets.CIFAR10(root=DATA_ROOT,
                                       train=False,
                                       download=True,
                                       transform=transform)

# IID alternative (kept for reference): 500-sample batches, one per client.
# trainloader = torch.utils.data.DataLoader(trainset,
#                                           batch_size=500,
#                                           shuffle=True)

# Evaluation only: sample order does not affect accuracy, so iterate
# deterministically instead of shuffling.
testloader = torch.utils.data.DataLoader(testset,
                                         batch_size=100,
                                         shuffle=False)

In [None]:
# Federated setup: size of the client pool, fraction sampled per round, and
# the local training schedule every sampled client runs.
n_client, chosen_prob = 100, 0.8
local_batch_size, local_epochs = 32, 10

In [None]:
# Global schedule plus the loss/optimiser every client instantiates locally.
epochs = 200                        # communication rounds (aggregate + broadcast)
criteria = nn.CrossEntropyLoss()
optimizer = optim.SGD               # the class; each Client builds its own instance
optimizer_conf = {"lr": 0.001, "momentum": 0.9}

In [None]:
# Global model on `device`; the server hands every client its own deep copy.
model = Net().to(device)

# Arguments listed in the same order as Server's signature.
server = Server(
    model=model,
    loss=criteria,
    optimizer=optimizer,              # optimizer class, instantiated per client
    optimizer_conf=optimizer_conf,
    n_client=n_client,
    chosen_prob=chosen_prob,
    local_batch_size=local_batch_size,
    local_epochs=local_epochs,
)

In [None]:
# IID alternative (kept for reference; pairs with the commented-out trainloader):
# for batch_idx, (batch_feature, batch_label) in enumerate(trainloader):
#     server.client_pool[batch_idx].setData(list(zip(batch_feature, batch_label)))

# Non-IID partition: each client receives SAMPLES_PER_CLASS samples from each
# of exactly two classes (one pair from `classes_pair`), interleaved A, B, A, B, ...
from collections import Counter

import numpy as np
from tqdm.auto import tqdm

classes_pair = [(0, 1), (2, 3), (4, 5), (6, 7), (8, 9)]
SAMPLES_PER_CLASS = 500
# NOTE(review): with n_client == 100 this cap can never bind; a balanced split
# would use n_client // len(classes_pair). Kept at 100 to preserve the setup.
MAX_CLIENTS_PER_PAIR = 100

data_label = np.array(trainset.targets)
# Per-class index pools, computed once instead of once per client.
class_sample_idx = {
    cls: np.flatnonzero(data_label == cls).tolist()
    for pair in classes_pair
    for cls in pair
}

chosen_counter = Counter()
for client in tqdm(server.client_pool, desc="partitioning"):
    # Draw uniformly among pairs that still have room. Unlike rejection
    # sampling, this cannot loop forever once every pair is full.
    open_pairs = [p for p in classes_pair if chosen_counter[p] < MAX_CLIENTS_PER_PAIR]
    if not open_pairs:
        raise ValueError("More clients than len(classes_pair) * MAX_CLIENTS_PER_PAIR allows")
    class_pair = random.choice(open_pairs)
    chosen_counter[class_pair] += 1

    first_class, second_class = class_pair
    client_first_class_sample_idx = random.sample(class_sample_idx[first_class], k=SAMPLES_PER_CLASS)
    client_second_class_sample_idx = random.sample(class_sample_idx[second_class], k=SAMPLES_PER_CLASS)

    client_data = []
    for first_idx, second_idx in zip(client_first_class_sample_idx, client_second_class_sample_idx):
        client_data.append(trainset[first_idx])
        client_data.append(trainset[second_idx])

    client.setData(client_data)

In [None]:
wandb.init(
    project="fl",
    name="CNN_CIFAR_10_noniid",
    # Store the hyper-parameters with the run so it can be reproduced.
    config=dict(
        n_client=n_client,
        chosen_prob=chosen_prob,
        local_batch_size=local_batch_size,
        local_epochs=local_epochs,
        rounds=epochs,
        **optimizer_conf,
    ),
)
n_test = len(testloader.dataset)  # rather than a hard-coded 10000

try:
    for i in range(epochs):
        # One communication round: sampled clients train locally and are
        # averaged into the global model, which is then pushed to every client.
        server.aggregate()
        server.broadcast()

        # Only server.model is switched to eval mode; clients own separate copies.
        server.model.eval()
        total_correct = 0
        with torch.no_grad():
            for test_feature, test_label in testloader:
                test_feature = test_feature.to(device)
                test_label = test_label.to(device)
                y_pred_decode = server.model(test_feature).argmax(dim=1)
                total_correct += y_pred_decode.eq(test_label).sum().item()
        test_acc = total_correct / n_test

        print("Overall acc: {}, overall_loss: {}, test_acc: {}".format(server.avg_acc, server.avg_loss, test_acc))
        wandb.log({"acc": server.avg_acc, "loss": server.avg_loss, "test_acc": test_acc})
finally:
    # Close the run even if training is interrupted.
    wandb.finish()