In [None]:
import os
import random
import time
import copy
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import matplotlib.pyplot as plt

from multiprocessing import cpu_count
from datetime import timedelta
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, SubsetRandomSampler
from torchvision import transforms
from torchvision.datasets import CIFAR10

## Fix Random Seeds

In [None]:
SEED = 42

# Seed every RNG the notebook touches so data splits, weight init and
# augmentation are reproducible across kernel restarts.
# NOTE: PYTHONHASHSEED only influences Python processes started after this
# line (e.g. spawned DataLoader workers); the running interpreter's hash seed
# was fixed at startup.
os.environ['PYTHONHASHSEED'] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the default generators on all devices (CPU and CUDA)

# cuDNN autotuning (benchmark=True) selects kernels by timing them, which can
# differ between runs and defeats the seeding above. Prefer deterministic
# kernels; flip both flags if raw speed matters more than bit-exact reruns.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

## Assign Device

In [None]:
GPU = 0  # index of the CUDA device used when CUDA is available

# Fall back to the CPU when no CUDA device is present.
if torch.cuda.is_available():
    device = torch.device(f'cuda:{GPU}')
else:
    device = torch.device('cpu')

## Load Data (with Training-Time Augmentation)

In [None]:
# Per-channel (R, G, B) normalisation statistics, given on the 0-255 scale and
# rescaled to [0, 1] to match ToTensor(). Shared by both pipelines so train and
# test inputs are normalised identically.
CIFAR_MEAN = [x / 255 for x in [125.3, 123, 113.9]]
CIFAR_STD = [x / 255 for x in [63, 62.1, 66.7]]

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),  # random 32x32 crop of the 4-px padded image
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])
# No augmentation at evaluation time: only tensor conversion + normalisation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
])

# download=True on both splits so neither relies on the other having run first;
# it is a no-op when the files are already present under ./data.
train_dataset = CIFAR10(root='./data', train=True, transform=transform_train, download=True)
test_dataset = CIFAR10(root='./data', train=False, transform=transform_test, download=True)

## Define Model

In [None]:
class BottleNeck(nn.Module):
    """DenseNet bottleneck layer: BN-ReLU-Conv1x1 then BN-ReLU-Conv3x3.

    The 1x1 convolution expands to ``4 * growth_rate`` channels and the 3x3
    convolution (padding 1) emits ``growth_rate`` new feature maps. These are
    concatenated in front of the input, so the output has
    ``growth_rate + in_planes`` channels and the input's spatial size.
    """

    def __init__(self, in_planes, growth_rate):
        super().__init__()
        inner_planes = 4 * growth_rate
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, inner_planes, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(inner_planes)
        self.conv2 = nn.Conv2d(inner_planes, growth_rate, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        hidden = self.conv1(F.relu(self.bn1(x)))
        new_features = self.conv2(F.relu(self.bn2(hidden)))
        # Dense connectivity: new features first, followed by the untouched input.
        return torch.cat((new_features, x), dim=1)

In [None]:
class Transition(nn.Module):
    """DenseNet transition layer placed between two dense blocks.

    BN-ReLU-Conv1x1 maps ``in_planes`` to ``out_planes`` channels, then a 2x2
    average pool (stride 2) halves height and width, rounding down.
    """

    def __init__(self, in_planes, out_planes):
        super().__init__()
        self.bn = nn.BatchNorm2d(in_planes)
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False)

    def forward(self, x):
        activated = F.relu(self.bn(x))
        compressed = self.conv(activated)
        return F.avg_pool2d(compressed, kernel_size=2)

In [None]:
class DenseNet(nn.Module):
    """DenseNet with bottleneck layers and compressing transitions, for small images.

    Stem: one 3x3 convolution (stride 1, padding 1) producing ``2 * growth_rate``
    channels. Then four dense blocks of ``BottleNeck`` layers; each of the first
    three is followed by a ``Transition`` that scales the channel count by
    ``reduction`` (rounded down) and halves the spatial size. Head: BN-ReLU,
    global average pooling and a linear classifier.

    Args:
        growth_rate: channels added by every bottleneck layer.
        reduction: channel compression factor of each transition.
        num_classes: number of output logits.
        nblocks: number of bottleneck layers in each of the four dense blocks.
            The default ``(6, 12, 24, 16)`` reproduces the previously
            hard-coded architecture.

    Raises:
        ValueError: if ``nblocks`` does not have exactly four entries.
    """

    def __init__(self, growth_rate=12, reduction=0.5, num_classes=10, nblocks=(6, 12, 24, 16)):
        super().__init__()
        if len(nblocks) != 4:
            raise ValueError(f'nblocks must have exactly 4 entries, got {len(nblocks)}')
        self.growth_rate = growth_rate

        num_planes = 2 * growth_rate
        self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, padding=1, bias=False)

        # Submodules are registered in the same order as before
        # (dense1, trans1, dense2, trans2, ...), so state_dict keys and their
        # order are unchanged and existing checkpoints still load.
        self.dense1 = self._make_dense_layers(BottleNeck, num_planes, nblocks[0])
        num_planes += nblocks[0] * growth_rate
        self.trans1, num_planes = self._make_transition(num_planes, reduction)

        self.dense2 = self._make_dense_layers(BottleNeck, num_planes, nblocks[1])
        num_planes += nblocks[1] * growth_rate
        self.trans2, num_planes = self._make_transition(num_planes, reduction)

        self.dense3 = self._make_dense_layers(BottleNeck, num_planes, nblocks[2])
        num_planes += nblocks[2] * growth_rate
        self.trans3, num_planes = self._make_transition(num_planes, reduction)

        self.dense4 = self._make_dense_layers(BottleNeck, num_planes, nblocks[3])
        num_planes += nblocks[3] * growth_rate

        self.bn = nn.BatchNorm2d(num_planes)
        self.linear = nn.Linear(num_planes, num_classes)

    def _make_dense_layers(self, block, in_planes, nblock):
        """Stack ``nblock`` layers; layer k receives ``in_planes + k * growth_rate`` channels."""
        layers = []
        for _ in range(nblock):
            layers.append(block(in_planes, self.growth_rate))
            in_planes += self.growth_rate
        return nn.Sequential(*layers)

    @staticmethod
    def _make_transition(in_planes, reduction):
        """Build a Transition compressing ``in_planes`` by ``reduction``; return (module, out_planes)."""
        out_planes = int(math.floor(in_planes * reduction))
        return Transition(in_planes, out_planes), out_planes

    def forward(self, x):
        out = self.conv1(x)
        out = self.trans1(self.dense1(out))
        out = self.trans2(self.dense2(out))
        out = self.trans3(self.dense3(out))
        out = self.dense4(out)
        out = F.relu(self.bn(out))
        # Global average pooling. For 32x32 inputs the maps are 4x4 here, so this
        # matches the former avg_pool2d(out, 4), and it also handles other sizes.
        out = F.adaptive_avg_pool2d(out, 1)
        out = torch.flatten(out, 1)
        return self.linear(out)

## Data Loaders and Per-Client Data Split

In [None]:
BATCH_SIZE = 2 ** 8

# Worker heuristic: load in the main process on a single core, use half the
# cores on machines with more than five, and two workers otherwise.
_cpus = cpu_count()
if _cpus < 2:
    NUM_WORKERS = 0
elif _cpus > 5:
    NUM_WORKERS = _cpus // 2
else:
    NUM_WORKERS = 2

# Test loader: fixed order (no shuffling) for evaluation.
testloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

In [None]:
num_models = 10  # number of federated clients

# Partition the training set into `num_models` disjoint, equally sized,
# class-stratified shards, one per client. Each iteration peels off
# 1/(clients still to serve) of what is left, so every shard gets the same
# share, and the last client takes the remainder. (Taking a fixed
# 1/num_models of the *remaining* pool each time would give shrinking shards
# and leave roughly a third of the data unassigned.)
x_idxs = np.arange(len(train_dataset))
ys = train_dataset.targets.copy()
splits = []
for i in range(num_models):
    remaining_shards = num_models - i
    if remaining_shards == 1:
        # train_test_split cannot leave an empty train side; hand over the rest directly.
        splits.append(x_idxs)
        break
    x_idxs, x_splits, ys, _ = train_test_split(
        x_idxs, ys,
        test_size=len(x_idxs) // remaining_shards,
        random_state=SEED, shuffle=True, stratify=ys,
    )
    splits.append(x_splits)

# Every training sample is assigned to exactly one client.
assert len(splits) == num_models
assert sum(len(s) for s in splits) == len(train_dataset)
assert len(np.unique(np.concatenate(splits))) == len(train_dataset)

In [None]:
# One loader per client; each samples (in random order) only from that client's shard.
trainloaders = [
    DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        sampler=SubsetRandomSampler(client_idxs),
        num_workers=NUM_WORKERS,
    )
    for client_idxs in splits
]

## Training

In [None]:
# Server-side global model; its weights are replaced by the client average each round.
fed_model = DenseNet()
fed_model = fed_model.to(device)
fed_weights = fed_model.state_dict()

# Shared loss for local training and evaluation.
criterion = nn.CrossEntropyLoss()
criterion = criterion.to(device)

In [None]:
def average_weights(w, weights=None):
    """Average a list of model state dicts key by key (FedAvg aggregation).

    Args:
        w: non-empty list of state dicts with identical keys and tensor shapes.
        weights: optional per-state-dict weights (e.g. each client's number of
            training samples). ``None`` weights every entry equally, i.e. a
            plain arithmetic mean, as before.

    Returns:
        A new state dict, a deep copy of ``w[0]``'s container so key order and
        metadata are kept, whose tensors are the weighted means of the inputs.
        Each tensor keeps its original dtype: integer buffers are averaged in
        float and rounded back instead of being silently promoted to float.
        The input state dicts are not modified.

    Raises:
        ValueError: if ``w`` is empty, ``weights`` has the wrong length, or the
            weights do not sum to a positive number.
    """
    if len(w) == 0:
        raise ValueError('average_weights() needs at least one state dict')
    if weights is None:
        weights = [1.0] * len(w)
    elif len(weights) != len(w):
        raise ValueError(f'got {len(weights)} weights for {len(w)} state dicts')
    total = float(sum(weights))
    if total <= 0:
        raise ValueError('weights must sum to a positive number')

    w_avg = copy.deepcopy(w[0])
    for key, ref in w_avg.items():
        # Integer tensors are accumulated in float so the division is not truncated.
        acc_dtype = ref.dtype if ref.is_floating_point() else torch.float32
        acc = torch.zeros_like(ref, dtype=acc_dtype)
        for coeff, state in zip(weights, w):
            acc += float(coeff) * state[key].to(acc_dtype)
        acc /= total
        if not ref.is_floating_point():
            acc = acc.round()
        w_avg[key] = acc.to(ref.dtype)
    return w_avg

In [None]:
def train(model, loader, epochs, lr, momentum, weight_decay, criterion, device):
    """Train ``model`` in place with SGD; return its weights and per-epoch losses.

    A fresh SGD optimizer is built on every call, so no momentum state carries
    over between calls.

    Args:
        model: module to train (modified in place); assumed to already live on ``device``.
        loader: iterable of ``(images, labels)`` batches.
        epochs: number of passes over ``loader``.
        lr, momentum, weight_decay: SGD hyper-parameters.
        criterion: loss called as ``criterion(outputs, labels)``.
        device: device each batch is moved to.

    Returns:
        ``(state_dict, losses)``. ``state_dict`` is ``model.state_dict()`` and
        shares storage with ``model``. ``losses[e]`` is the mean of the per-batch
        losses in epoch ``e``, or ``nan`` if ``loader`` produced no batches.
    """
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
    model.train()

    losses = []
    for _ in range(epochs):
        batch_ls = []
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()

            batch_ls.append(loss.item())

        # An empty loader would otherwise divide by zero.
        losses.append(sum(batch_ls) / len(batch_ls) if batch_ls else float('nan'))
    return model.state_dict(), losses

In [None]:
def test(model, testloader, criterion, device):
    loss, correct, total = 0, 0, 0

    model.eval()

    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)

        outputs = model(images)
        batch_ls = criterion(outputs, labels)
        loss += batch_ls.item()

        _, preds = torch.max(outputs, 1)
        preds = preds.view(-1)
        correct += torch.sum(torch.eq(preds, labels)).item()
        total += len(labels)

    accuracy = correct / total
    loss /= len(testloader)
    return accuracy, loss

In [None]:
ROUNDS = 5          # global communication rounds
EPOCHS = 5          # local epochs each client runs per round
LR = 0.01
MOMENTUM = 0.9
WEIGHT_DECAY = 0.0005

# train_losses gets one entry per (round, client): the list of that client's per-epoch losses.
train_losses, test_accs, test_losses = [], [], []

st = time.time()
for r in range(ROUNDS):
    local_weights = []
    print(f'\n | Global Training Round : {r + 1} / {ROUNDS} |')

    # Each client starts from a fresh copy of the current global model.
    # The client number comes from enumerate, not from a loop variable left over in another cell.
    for client_id, tr in enumerate(trainloaders, start=1):
        w, ls = train(copy.deepcopy(fed_model), tr, EPOCHS, LR, MOMENTUM, WEIGHT_DECAY, criterion, device)
        # `w` belongs to the throw-away model copy above, so no extra deepcopy is needed.
        local_weights.append(w)
        train_losses.append(ls)
        print('  |-- [Client {:>2}] Average Train Loss: {:.4f} ... {} local epochs'.format(client_id, sum(ls) / len(ls), EPOCHS))

    # Aggregate: the new global weights are the unweighted mean of the client weights.
    fed_weights = average_weights(local_weights)
    fed_model.load_state_dict(fed_weights)

    test_acc, test_ls = test(fed_model, testloader, criterion, device)
    test_accs.append(test_acc)
    test_losses.append(test_ls)
    print('    |---- Test Accuracy: {:.4f}%'.format(100 * test_acc))
    print('    |---- Test Loss: {:.4f}'.format(test_ls))
    print('    |---- Elapsed time: {}'.format(timedelta(seconds=time.time()-st)))

In [None]:
# Rows are (round, client) pairs in training order; columns are local epochs.
train_loss_arr = np.asarray(train_losses)
n_clients = len(trainloaders)
# View as (rounds, clients, epochs) so we can average across clients.
loss_by_round = train_loss_arr.reshape(-1, n_clients, train_loss_arr.shape[1])

fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(16, 9))
axs = axs.ravel()

axs[0].plot(test_accs, c='orange')
axs[0].set_title('Test Accuracies')
axs[0].set_xlabel('Rounds')
axs[0].set_ylabel('Test Accuracy')

axs[1].plot(test_losses, c='purple')
axs[1].set_title('Test Losses')
axs[1].set_xlabel('Rounds')
axs[1].set_ylabel('Test Loss')

# Client-averaged loss for every local epoch of every round, in training order.
axs[2].plot(loss_by_round.mean(axis=1).reshape(-1), c='red')
axs[2].set_title('Train Average Losses (per local epoch)')
axs[2].set_xlabel('Epochs')
axs[2].set_ylabel('Train Average Loss')

# Mean over all clients and local epochs of each round.
axs[3].plot(loss_by_round.mean(axis=(1, 2)), c='blue')
axs[3].set_title('Train Average Losses (per round)')
axs[3].set_xlabel('Rounds')
axs[3].set_ylabel('Train Average Loss')

plt.show()