In [1]:
import os
import random
import numpy as np
import torch
import json

# Configuration

In [2]:
# Index of the CUDA device to use when CUDA is available.
GPU = 0

# Each setting below is picked by index from a list of candidate values.
SEEDS = [0, 1, 42, 2 ** 8, 2 ** 16]
SEED = SEEDS[2]    # Default 42

# DATASETS = ['cifar10', 'mnist', 'permuted_mnist', 'fmnist']
DATASETS = ['cifar10', 'mnist', 'fmnist']
DATASET = DATASETS[0]    # Default cifar10

PSET = [16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56]
NUM_PARTIES = PSET[0]    # index 0 (16) ~ 10 (56) # Default 16

# Leave exactly one (IFD, STRATIFY) combination uncommented.
IFD, STRATIFY = True, True
# IFD, STRATIFY = True, False
# IFD, STRATIFY = False, False
# IFD, STRATIFY = False, True

BOOST_FRACS = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
# BOOST_FRACS = [0.25, 0.27, 0.29, 0.31, 0.33, 0.35]
BOOST_FRAC = BOOST_FRACS[0]    # Default 0.3 (index 3); index 0 selects 0.0

BATCH_SIZES = [64, 128, 256, 512, 1024]
BATCH_SIZE = BATCH_SIZES[0]    # Default 64

ROUNDSET = [30, 50, 100]
ROUNDS = ROUNDSET[1]    # Default 50

EPOCHSET = [10, 20, 30, 40, 50]
EPOCHS = EPOCHSET[1]    # Default 20

MODELS = ['MNISTLeNet5', 'LeNet5', 'VGG9', 'ResNet18', 'MobileNetV2', 'DenseNet', 'EfficientNet']
MODEL = MODELS[3]    # Default VGG9 (index 2); index 3 selects ResNet18

OPTIMIZERS = ['sgd', 'adam']
OPTIMIZER = OPTIMIZERS[0]    # Default sgd

LRS = [0.1, 0.01, 0.001]
LR = LRS[1]    # Default 0.01

# Momentum only applies to SGD; None is the marker used for Adam.
if OPTIMIZER == 'sgd':
    MOMENTUM = 0.9
elif OPTIMIZER == 'adam':
    MOMENTUM = None

WDECAYS = [0, 0.0001]
WEIGHT_DECAY = WDECAYS[0]    # Default 0.0001 (index 1); index 0 selects 0

# Echo the selected configuration, one "label value" pair per line.
for _label, _value in (
    ('Dataset: ', DATASET.upper()),
    ('Parties: ', NUM_PARTIES),
    ('Identical Forgettable Distribution: ', IFD),
    ('Stratification: ', STRATIFY),
    ('Data Boost Fraction: ', BOOST_FRAC),
    ('Batch Size: ', BATCH_SIZE),
    ('Communicative Rounds: ', ROUNDS),
    ('Local Epochs: ', EPOCHS),
    ('Model: ', MODEL),
    ('Optimizer: ', OPTIMIZER.upper()),
    ('Learning Rate: ', LR),
    ('Momentum: ', MOMENTUM),
    ('Weight Decay: ', WEIGHT_DECAY),
    ('Seed: ', SEED),
):
    print(_label, _value)

# Short tags for the two data-distribution switches, used in the file name.
forget_dist = 'FSTR' if IFD else 'FRND'
target_dist = 'LSTR' if STRATIFY else 'LRND'

# NOTE(review): int(BOOST_FRAC * 10) truncates, so fractions that differ only
# in the second decimal (e.g. the commented-out 0.25 / 0.27 / 0.29 list) would
# map to the same BT tag and therefore the same file name.
FNAME = f'P{NUM_PARTIES}_{MODEL}_BT{int(BOOST_FRAC * 10)}_BS{BATCH_SIZE}_R{ROUNDS}_E{EPOCHS}_{forget_dist}_{target_dist}_S{SEED}'

print('\nFile name: ', FNAME)

Dataset:  CIFAR10
Parties:  16
Identical Forgettable Distribution:  True
Stratification:  True
Data Boost Fraction:  0.0
Batch Size:  64
Communicative Rounds:  50
Local Epochs:  20
Model:  ResNet18
Optimizer:  SGD
Learning Rate:  0.01
Momentum:  0.9
Weight Decay:  0
Seed:  42

File name:  P16_ResNet18_BT0_BS64_R50_E20_FSTR_LSTR_S42


# Basic Setup

In [3]:
# Seed every RNG used by the notebook so that runs are reproducible.
# The seeding functions must be *called*: assigning to them (`random.seed = SEED`)
# only replaces the function with an int and seeds nothing.
# NOTE: if an earlier version of this cell already ran in this kernel, restart
# the kernel first, because `random.seed` etc. may still be bound to an int.
# NOTE(review): PYTHONHASHSEED is presumably only honoured at interpreter start-up,
# so setting it here would affect child processes only - confirm if hash
# determinism of this kernel matters.
os.environ['PYTHONHASHSEED'] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [4]:
# Use the configured CUDA device when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device(f'cuda:{GPU}')
else:
    device = torch.device('cpu')
print('Device: ', device)

Device:  cuda:0


In [5]:
# The project root is the parent of the current working directory.
PATH_ROOT = os.path.dirname(os.getcwd())

# Input data layout: data/<dataset>/<N>parties/<ifd|non_ifd>/<stratified|random>
PATH_DATA = os.path.join(PATH_ROOT, 'data')
PATH_DSET = os.path.join(PATH_DATA, DATASET)
PATH_PART = os.path.join(PATH_DSET, '{}parties'.format(NUM_PARTIES))
PATH_FGD = os.path.join(PATH_PART, 'ifd' if IFD else 'non_ifd')
PATH_TGD = os.path.join(PATH_FGD, 'stratified' if STRATIFY else 'random')

# Output layout: saves/<dataset>
PATH_SAVE = os.path.join(PATH_ROOT, 'saves')
SAVE_DSET = os.path.join(PATH_SAVE, DATASET)
# makedirs also creates the parent `saves` directory when it is missing and,
# with exist_ok=True, avoids the check-then-create race of exists() + mkdir().
os.makedirs(SAVE_DSET, exist_ok=True)

PATH_CONF = os.path.join(PATH_ROOT, 'configs')

print('    Project Root: ', PATH_ROOT)
print('        Data Dir: ', PATH_DATA)
print(' {:>11} Dir: '.format(DATASET.upper()), PATH_DSET)
print(' Party Split Dir: ', PATH_PART)
print(' Forget Dist Dir: ', PATH_FGD)
print('  Label Dist Dir: ', PATH_TGD)
print('        Save Dir: ', PATH_SAVE)
print('{:>7} Save Dir: '.format(DATASET.upper()), SAVE_DSET)
print('      Config Dir: ', PATH_CONF)

    Project Root:  /home/dev/projects/aaai20
        Data Dir:  /home/dev/projects/aaai20/data
     CIFAR10 Dir:  /home/dev/projects/aaai20/data/cifar10
 Party Split Dir:  /home/dev/projects/aaai20/data/cifar10/16parties
 Forget Dist Dir:  /home/dev/projects/aaai20/data/cifar10/16parties/ifd
  Label Dist Dir:  /home/dev/projects/aaai20/data/cifar10/16parties/ifd/stratified
        Save Dir:  /home/dev/projects/aaai20/saves
CIFAR10 Save Dir:  /home/dev/projects/aaai20/saves/cifar10
      Config Dir:  /home/dev/projects/aaai20/configs


In [6]:
# Snapshot of the run configuration, written next to the results as JSON.
config = {
    'seed': SEED,
    'dataset': DATASET,
    'num_parties': NUM_PARTIES,
    'ifd': IFD,
    # NOTE(review): key is misspelled ('stratifiy'); kept as-is because existing
    # JSON files / readers may depend on it - confirm before renaming.
    'stratifiy': STRATIFY,
    'boost_frac': BOOST_FRAC,
    'batch_size': BATCH_SIZE,
    'comm_round': ROUNDS,
    'local_epoch': EPOCHS,
    'model': MODEL,
    'optimizer': OPTIMIZER,
    'lr': LR,
    'momentum': MOMENTUM,
    'weight_decay': WEIGHT_DECAY
}

# The config directory is not created anywhere else in the notebook; make sure
# it exists so that open(..., 'w') does not fail on a fresh checkout.
os.makedirs(PATH_CONF, exist_ok=True)

with open('{}.json'.format(os.path.join(PATH_CONF, FNAME)), 'w', encoding='utf-8') as f:
    json.dump(config, f, indent=4, separators=(',', ': '))

print('File name: ', FNAME)

File name:  P16_ResNet18_BT0_BS64_R50_E20_FSTR_LSTR_S42


# Data Setup

In [7]:
import pandas as pd

from torch.utils.data import Subset, DataLoader
from torchvision import datasets, transforms

In [8]:
# Load one DataFrame per party from '<dataset>_p<i>.csv' in the split directory.
parties = []

for i in range(1, NUM_PARTIES + 1):
    df = pd.read_csv(os.path.join(PATH_TGD, '{}_p{}.csv'.format(DATASET, i)))
    if BOOST_FRAC > 0:
        # Oversample: duplicate a random fraction of the rows flagged forgettable.
        boost = df.loc[df['forgettable'] == 1].sample(frac=BOOST_FRAC)
        # DataFrame.append was removed in pandas 2.0; pd.concat is the replacement.
        df = pd.concat([df, boost], ignore_index=True)
    parties.append(df)
    print('[Party {:>2}] Samples {}, Forgettables {}, Unforgettables {}'.format(
        i, len(df.index), len(df.loc[df['forgettable'] == 1].index), len(df.loc[df['forgettable'] == 0].index)
    ))

[Party  1] Samples 3125, Forgettables 2218, Unforgettables 907
[Party  2] Samples 3125, Forgettables 2218, Unforgettables 907
[Party  3] Samples 3125, Forgettables 2218, Unforgettables 907
[Party  4] Samples 3125, Forgettables 2218, Unforgettables 907
[Party  5] Samples 3125, Forgettables 2218, Unforgettables 907
[Party  6] Samples 3125, Forgettables 2218, Unforgettables 907
[Party  7] Samples 3125, Forgettables 2218, Unforgettables 907
[Party  8] Samples 3125, Forgettables 2218, Unforgettables 907
[Party  9] Samples 3125, Forgettables 2218, Unforgettables 907
[Party 10] Samples 3125, Forgettables 2218, Unforgettables 907
[Party 11] Samples 3125, Forgettables 2218, Unforgettables 907
[Party 12] Samples 3125, Forgettables 2218, Unforgettables 907
[Party 13] Samples 3125, Forgettables 2218, Unforgettables 907
[Party 14] Samples 3125, Forgettables 2218, Unforgettables 907
[Party 15] Samples 3125, Forgettables 2218, Unforgettables 907
[Party 16] Samples 3125, Forgettables 2218, Unforgettab

In [9]:
# torchvision dataset class plus the (mean, std) passed to Normalize, per dataset.
_DATASET_SPECS = {
    'cifar10': (datasets.CIFAR10, (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    'mnist': (datasets.MNIST, (0.1307,), (0.3081,)),
    'fmnist': (datasets.FashionMNIST, (0.5,), (0.5,)),
}

# Any other DATASET value leaves transform / train_dataset / test_dataset undefined.
if DATASET in _DATASET_SPECS:
    _dataset_cls, _mean, _std = _DATASET_SPECS[DATASET]
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(_mean, _std)
    ])
    # Only the training split triggers a download; the test split is read from disk.
    train_dataset = _dataset_cls(PATH_DATA, train=True, transform=transform, download=True)
    test_dataset = _dataset_cls(PATH_DATA, train=False, transform=transform, download=False)

Files already downloaded and verified


In [10]:
NUM_WORKERS = 0
print('Num workers: ', NUM_WORKERS)

# One shuffled loader per party, restricted to the rows listed in that party's
# 'indices' column.
trainloaders = [
    DataLoader(
        Subset(train_dataset, party_df['indices'].to_numpy()),
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
    )
    for party_df in parties
]

# The shared test set is iterated in a fixed order.
testloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

Num workers:  0


# Model

In [11]:
import torch.nn as nn
import torch.nn.functional as F

In [12]:
class MNISTLeNet5(nn.Module):  # CNN
    """LeNet-5 style CNN for single-channel images, producing 10 logits.

    The flattened feature size (4 * 4 * 50) corresponds to 28x28 inputs.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5, stride=1)
        self.conv2 = nn.Conv2d(in_channels=20, out_channels=50, kernel_size=5, stride=1)
        self.fc1 = nn.Linear(in_features=4 * 4 * 50, out_features=500)
        self.fc2 = nn.Linear(in_features=500, out_features=10)

    def forward(self, x):
        """Return class logits for a batch of images."""
        # Each conv stage is: convolution -> 2x2 max-pool -> ReLU.
        feats = F.relu(F.max_pool2d(self.conv1(x), kernel_size=2, stride=2))
        feats = F.relu(F.max_pool2d(self.conv2(feats), kernel_size=2, stride=2))
        flat = feats.view(-1, 4 * 4 * 50)
        # The two linear layers are applied back to back (no activation between).
        return self.fc2(self.fc1(flat))

In [13]:
class LeNet5(nn.Module):
    """LeNet-5 style CNN for three-channel images, producing 10 logits.

    The flattened feature size (5 * 5 * 50) corresponds to 32x32 inputs.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=20, kernel_size=5, stride=1)
        self.conv2 = nn.Conv2d(in_channels=20, out_channels=50, kernel_size=5, stride=1)
        self.fc1 = nn.Linear(in_features=5 * 5 * 50, out_features=500)
        self.fc2 = nn.Linear(in_features=500, out_features=10)

    def forward(self, x):
        """Return class logits for a batch of images."""
        # Each conv stage is: convolution -> 2x2 max-pool -> ReLU.
        feats = F.relu(F.max_pool2d(self.conv1(x), kernel_size=2, stride=2))
        feats = F.relu(F.max_pool2d(self.conv2(feats), kernel_size=2, stride=2))
        flat = feats.view(-1, 5 * 5 * 50)
        # The two linear layers are applied back to back (no activation between).
        return self.fc2(self.fc1(flat))

In [14]:
class VGG9(nn.Module):
    """VGG-style CNN: six 3x3 conv layers in three pooled stages, then three
    linear layers producing 10 logits.

    The first linear layer expects 4096 flattened features.
    """

    def __init__(self):
        super().__init__()

        def conv_relu(c_in, c_out):
            # 3x3 convolution with padding 1, followed by an in-place ReLU.
            return [
                nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
            ]

        def halve():
            # 2x2 max-pool with stride 2.
            return nn.MaxPool2d(kernel_size=2, stride=2)

        # Module order (and therefore the state-dict indices) is kept unchanged.
        feature_modules = [
            *conv_relu(3, 32),
            *conv_relu(32, 64),
            halve(),

            *conv_relu(64, 128),
            *conv_relu(128, 128),
            halve(),
            nn.Dropout2d(p=0.05),

            *conv_relu(128, 256),
            *conv_relu(256, 256),
            halve(),
        ]
        self.conv_layer = nn.Sequential(*feature_modules)

        classifier_modules = [
            nn.Dropout(p=0.1),
            nn.Linear(4096, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.1),
            nn.Linear(512, 10),
        ]
        self.fc_layer = nn.Sequential(*classifier_modules)

    def forward(self, x):
        """Return class logits for a batch of images."""
        feats = self.conv_layer(x)
        # Flatten everything except the batch dimension.
        flat = feats.view(feats.size(0), -1)
        return self.fc_layer(flat)

In [15]:
class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers with an additive shortcut connection."""

    # Ratio between the block's output channels and `planes`.
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        """
        Args:
            in_planes: number of input channels.
            planes: number of channels of both conv layers.
            stride: stride of the first conv layer (and of the projection shortcut).
        """
        super().__init__()
        out_planes = self.expansion * planes

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # A 1x1 conv + batch-norm projection is used whenever the shortcut would
        # otherwise not match the main path; an empty Sequential acts as identity.
        if stride != 1 or in_planes != out_planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes)
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return relu(main_path(x) + shortcut(x))."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(hidden))
        out += self.shortcut(x)
        return F.relu(out)

In [16]:
class ResNet18(nn.Module):
    """Residual network with four stages of two ResidualBlocks each and a
    10-way linear classifier."""

    def __init__(self):
        super().__init__()
        # Running channel count, advanced by _make_layer as stages are built.
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(ResidualBlock, 64, 2, stride=1)
        self.layer2 = self._make_layer(ResidualBlock, 128, 2, stride=2)
        self.layer3 = self._make_layer(ResidualBlock, 256, 2, stride=2)
        self.layer4 = self._make_layer(ResidualBlock, 512, 2, stride=2)
        self.linear = nn.Linear(512 * ResidualBlock.expansion, 10)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage of `num_blocks` blocks.

        Only the first block uses `stride`; the remaining blocks use stride 1.
        Updates `self.in_planes` to the stage's output channel count.
        """
        stage = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            stage.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        """Return class logits for a batch of images."""
        out = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        pooled = F.avg_pool2d(out, 4)
        # Flatten everything except the batch dimension before the classifier.
        flat = pooled.view(pooled.size(0), -1)
        return self.linear(flat)

In [17]:
class MobileBlock(nn.Module):
    """Inverted-residual style block: 1x1 expand -> 3x3 depthwise -> 1x1 project."""

    def __init__(self, in_planes, out_planes, expansion, stride):
        """
        Args:
            in_planes: number of input channels.
            out_planes: number of output channels.
            expansion: multiplier giving the hidden width (expansion * in_planes).
            stride: stride of the depthwise 3x3 conv.
        """
        super().__init__()
        self.stride = stride

        hidden = expansion * in_planes
        self.conv1 = nn.Conv2d(in_planes, hidden, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(hidden)
        # groups == channels makes this a depthwise convolution.
        self.conv2 = nn.Conv2d(hidden, hidden, kernel_size=3, stride=stride, padding=1, groups=hidden, bias=False)
        self.bn2 = nn.BatchNorm2d(hidden)
        self.conv3 = nn.Conv2d(hidden, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        # Projection shortcut only when the residual is used (stride 1) and the
        # channel counts differ; otherwise an empty Sequential (identity).
        if stride == 1 and in_planes != out_planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Apply the block; the shortcut is added only when stride is 1."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        if self.stride == 1:
            out = out + self.shortcut(x)
        return out

In [18]:
class MobileNetV2(nn.Module):
    """MobileNetV2-style network assembled from MobileBlock stages."""

    # One tuple per stage: (expansion, out_planes, num_blocks, stride).
    cfg = [(1, 16, 1, 1),
           (6, 24, 2, 1),  # NOTE: change stride 2 -> 1 for CIFAR10
           (6, 32, 3, 2),
           (6, 64, 4, 2),
           (6, 96, 3, 1),
           (6, 160, 3, 2),
           (6, 320, 1, 1)]

    def __init__(self, num_classes=10):
        """
        Args:
            num_classes: size of the final linear layer's output.
        """
        super().__init__()
        # NOTE: change conv1 stride 2 -> 1 for CIFAR10
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.layers = self._make_layers(in_planes=32)
        self.conv2 = nn.Conv2d(320, 1280, kernel_size=1, stride=1, padding=0, bias=False)
        self.bn2 = nn.BatchNorm2d(1280)
        self.linear = nn.Linear(1280, num_classes)

    def _make_layers(self, in_planes):
        """Build all stages listed in `cfg` as one Sequential.

        Within a stage only the first block uses the configured stride; the
        remaining blocks use stride 1.
        """
        blocks = []
        for expansion, out_planes, num_blocks, first_stride in self.cfg:
            for block_stride in [first_stride] + [1] * (num_blocks - 1):
                blocks.append(MobileBlock(in_planes, out_planes, expansion, block_stride))
                in_planes = out_planes
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Return class logits for a batch of images."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layers(out)
        out = F.relu(self.bn2(self.conv2(out)))
        # NOTE: change pooling kernel_size 7 -> 4 for CIFAR10
        pooled = F.avg_pool2d(out, 4)
        flat = pooled.view(pooled.size(0), -1)
        return self.linear(flat)

In [19]:
import math

In [20]:
class BottleNeck(nn.Module):
    """Dense-layer bottleneck: BN-ReLU-1x1 conv, BN-ReLU-3x3 conv, then
    concatenation of the new features with the input along the channel axis."""

    def __init__(self, in_planes, growth_rate):
        """
        Args:
            in_planes: number of input channels.
            growth_rate: number of channels the block adds to its input.
        """
        super().__init__()
        inner_planes = 4 * growth_rate
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, inner_planes, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(inner_planes)
        self.conv2 = nn.Conv2d(inner_planes, growth_rate, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        """Return the new feature maps followed by `x`, joined on dim 1."""
        squeezed = self.conv1(F.relu(self.bn1(x)))
        new_features = self.conv2(F.relu(self.bn2(squeezed)))
        return torch.cat([new_features, x], 1)

In [21]:
class Transition(nn.Module):
    """BN-ReLU-1x1 conv that changes the channel count, followed by 2x2
    average pooling."""

    def __init__(self, in_planes, out_planes):
        """
        Args:
            in_planes: number of input channels.
            out_planes: number of output channels of the 1x1 conv.
        """
        super().__init__()
        self.bn = nn.BatchNorm2d(in_planes)
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False)

    def forward(self, x):
        """Return the channel-reduced, average-pooled feature maps."""
        activated = F.relu(self.bn(x))
        return F.avg_pool2d(self.conv(activated), 2)

In [22]:
class DenseNet(nn.Module):
    """Densely connected network with four dense stages of 6, 12, 24 and 16
    BottleNeck layers, separated by three Transition layers."""

    def __init__(self, growth_rate=12, reduction=0.5, num_classes=10):
        """
        Args:
            growth_rate: channels added by every BottleNeck layer.
            reduction: factor applied to the channel count in each Transition
                (the result is floored to an int).
            num_classes: size of the final linear layer's output.
        """
        super().__init__()
        self.growth_rate = growth_rate

        def reduced(planes):
            # Channel count after a Transition layer.
            return int(math.floor(planes * reduction))

        channels = 2 * growth_rate
        self.conv1 = nn.Conv2d(3, channels, kernel_size=3, padding=1, bias=False)

        # Stage 1: 6 dense layers, then transition.
        self.dense1 = self._make_dense_layers(BottleNeck, channels, 6)
        channels += 6 * growth_rate
        self.trans1 = Transition(channels, reduced(channels))
        channels = reduced(channels)

        # Stage 2: 12 dense layers, then transition.
        self.dense2 = self._make_dense_layers(BottleNeck, channels, 12)
        channels += 12 * growth_rate
        self.trans2 = Transition(channels, reduced(channels))
        channels = reduced(channels)

        # Stage 3: 24 dense layers, then transition.
        self.dense3 = self._make_dense_layers(BottleNeck, channels, 24)
        channels += 24 * growth_rate
        self.trans3 = Transition(channels, reduced(channels))
        channels = reduced(channels)

        # Stage 4: 16 dense layers, no transition afterwards.
        self.dense4 = self._make_dense_layers(BottleNeck, channels, 16)
        channels += 16 * growth_rate

        self.bn = nn.BatchNorm2d(channels)
        self.linear = nn.Linear(channels, num_classes)

    def _make_dense_layers(self, block, in_planes, nblock):
        """Stack `nblock` blocks; each one sees `growth_rate` more input
        channels than the previous one."""
        stage = [
            block(in_planes + idx * self.growth_rate, self.growth_rate)
            for idx in range(nblock)
        ]
        return nn.Sequential(*stage)

    def forward(self, x):
        """Return class logits for a batch of images."""
        out = self.conv1(x)
        for dense, trans in (
            (self.dense1, self.trans1),
            (self.dense2, self.trans2),
            (self.dense3, self.trans3),
        ):
            out = trans(dense(out))
        out = self.dense4(out)
        pooled = F.avg_pool2d(F.relu(self.bn(out)), 4)
        flat = pooled.view(pooled.size(0), -1)
        return self.linear(flat)

# Federated Learning

In [23]:
import copy
import time
import torch.optim as optim

from datetime import timedelta

In [24]:
def average_weights(w):
    """Return the element-wise mean of a list of model state dicts.

    Args:
        w: non-empty sequence of state dicts that share the same keys; the
            inputs are not modified.

    Returns:
        A deep copy of ``w[0]`` whose tensors hold the per-key average. Each
        tensor keeps its original dtype, so integer buffers (for example
        BatchNorm's ``num_batches_tracked``) are not silently turned into floats.

    Raises:
        ValueError: if ``w`` is empty.
    """
    if len(w) == 0:
        raise ValueError('average_weights() needs at least one state dict')

    num_models = len(w)
    # Deep copy so the accumulation below never touches the callers' tensors.
    w_avg = copy.deepcopy(w[0])
    for key in w_avg.keys():
        total = w_avg[key]
        for other in w[1:]:
            total += other[key]
        mean = torch.div(total, float(num_models))
        # Dividing an integer tensor by a float yields a float tensor; cast back.
        w_avg[key] = mean.to(total.dtype)
    return w_avg

In [25]:
def test(model, testloader, criterion, device):
    loss, correct, total = 0, 0, 0

    model.eval()

    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)

        outputs = model(images)
        batch_ls = criterion(outputs, labels)
        loss += batch_ls.item()

        _, preds = torch.max(outputs, 1)
        preds = preds.view(-1)
        correct += torch.sum(torch.eq(preds, labels)).item()
        total += len(labels)

    accuracy = correct / total
    loss /= len(testloader)
    return accuracy, loss

In [26]:
def train(model, loader, epochs, lr, momentum, weight_decay, criterion, device):
    """Train `model` in place for `epochs` passes over `loader`.

    Args:
        model: network to train; it is not moved here, only the batches are.
        loader: iterable of (images, labels) batches.
        epochs: number of passes over `loader`.
        lr: learning rate.
        momentum: SGD momentum, or None to train with Adam instead of SGD.
        weight_decay: weight decay passed to the optimizer.
        criterion: loss function called as criterion(outputs, labels).
        device: device the batches are moved to.

    Returns:
        (state_dict, losses): the trained model's state dict, and a list with
        the mean batch loss of each epoch.

    Raises:
        ZeroDivisionError: if `loader` yields no batches.
    """
    losses = []

    # `momentum is None` selects Adam. A plain truthiness check would also send
    # momentum=0.0 (momentum-free SGD) down the Adam branch.
    if momentum is not None:
        optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
    else:
        optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    model.train()

    for ep in range(epochs):
        batch_ls = []

        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)

            model.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            batch_ls.append(loss.item())

        loss_avg = sum(batch_ls) / len(batch_ls)
        losses.append(loss_avg)
    return model.state_dict(), losses

In [27]:
from torchsummary import summary

In [28]:
# Instantiate the global (federated) model.
# For 'mnist' the model is always MNISTLeNet5, regardless of MODEL.
# NOTE(review): for 'fmnist' in_shape below is (1, 28, 28), but every model other
# than MNISTLeNet5 is built with a 3-channel first conv - confirm which model is
# intended for fmnist.
if DATASET == 'mnist':
    fed_model = MNISTLeNet5()
elif MODEL == 'LeNet5':
    fed_model = LeNet5()
elif MODEL == 'VGG9':
    fed_model = VGG9()
elif MODEL == 'ResNet18':
    fed_model = ResNet18()
elif MODEL == 'MobileNetV2':
    fed_model = MobileNetV2()
elif MODEL == 'DenseNet':
    fed_model = DenseNet()
elif MODEL == 'EfficientNet':
    # NOTE(review): no EfficientNet class is defined in the visible part of this
    # notebook - selecting it presumably raises NameError; confirm.
    fed_model = EfficientNet()
else:
    # Fail loudly instead of leaving `fed_model` undefined (or, in a reused
    # kernel, silently reusing the model from a previous run).
    raise ValueError('Unsupported model {!r} for dataset {!r}'.format(MODEL, DATASET))

if DATASET == 'cifar10':
    in_shape = (3, 32, 32)
else:
    in_shape = (1, 28, 28)

# The summary is computed on the CPU before the model is moved to `device`.
summary(fed_model, in_shape, device='cpu')

fed_model.to(device)
fed_weights = fed_model.state_dict()
criterion = nn.CrossEntropyLoss().to(device)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 64, 32, 32]           1,728
       BatchNorm2d-2           [-1, 64, 32, 32]             128
            Conv2d-3           [-1, 64, 32, 32]          36,864
       BatchNorm2d-4           [-1, 64, 32, 32]             128
            Conv2d-5           [-1, 64, 32, 32]          36,864
       BatchNorm2d-6           [-1, 64, 32, 32]             128
     ResidualBlock-7           [-1, 64, 32, 32]               0
            Conv2d-8           [-1, 64, 32, 32]          36,864
       BatchNorm2d-9           [-1, 64, 32, 32]             128
           Conv2d-10           [-1, 64, 32, 32]          36,864
      BatchNorm2d-11           [-1, 64, 32, 32]             128
    ResidualBlock-12           [-1, 64, 32, 32]               0
           Conv2d-13          [-1, 128, 16, 16]          73,728
      BatchNorm2d-14          [-1, 128,

In [None]:
# Per-run epoch losses (one entry per round and party) and per-round test metrics.
train_losses, test_accs, test_losses = [], [], []

started_at = time.time()
for round_no in range(1, ROUNDS + 1):
    print(f'\n | Global Training Round : {round_no} / {ROUNDS} |')

    fed_model.train()

    # Every party trains its own copy of the current global model.
    local_weights = []
    for party_no, party_loader in enumerate(trainloaders, start=1):
        party_state, epoch_losses = train(
            copy.deepcopy(fed_model), party_loader, EPOCHS, LR, MOMENTUM, WEIGHT_DECAY, criterion, device
        )
        local_weights.append(copy.deepcopy(party_state))
        train_losses.append(epoch_losses)
        mean_loss = sum(epoch_losses) / len(epoch_losses)
        print('  |-- [Party {:>2}] Average Train Loss: {:.4f} ... {} local epochs'.format(party_no, mean_loss, EPOCHS))

    # Aggregate: the new global weights are the average of the local weights.
    fed_weights = average_weights(local_weights)
    fed_model.load_state_dict(fed_weights)

    # Evaluate the aggregated model on the shared test set.
    test_acc, test_ls = test(fed_model, testloader, criterion, device)
    test_accs.append(test_acc)
    test_losses.append(test_ls)
    print('    |---- Test Accuracy: {:.4f}%'.format(100 * test_acc))
    print('    |---- Test Loss: {:.4f}'.format(test_ls))
    print('    |---- Elapsed time: {}'.format(timedelta(seconds=time.time() - started_at)))


 | Global Training Round : 1 / 50 |
  |-- [Party  1] Average Train Loss: 0.4759 ... 20 local epochs
  |-- [Party  2] Average Train Loss: 0.4311 ... 20 local epochs
  |-- [Party  3] Average Train Loss: 0.4519 ... 20 local epochs
  |-- [Party  4] Average Train Loss: 0.4376 ... 20 local epochs
  |-- [Party  5] Average Train Loss: 0.4714 ... 20 local epochs
  |-- [Party  6] Average Train Loss: 0.4236 ... 20 local epochs
  |-- [Party  7] Average Train Loss: 0.4348 ... 20 local epochs
  |-- [Party  8] Average Train Loss: 0.4726 ... 20 local epochs
  |-- [Party  9] Average Train Loss: 0.4819 ... 20 local epochs
  |-- [Party 10] Average Train Loss: 0.4132 ... 20 local epochs
  |-- [Party 11] Average Train Loss: 0.4266 ... 20 local epochs
  |-- [Party 12] Average Train Loss: 0.4545 ... 20 local epochs
  |-- [Party 13] Average Train Loss: 0.4314 ... 20 local epochs
  |-- [Party 14] Average Train Loss: 0.4273 ... 20 local epochs
  |-- [Party 15] Average Train Loss: 0.4794 ... 20 local epochs
  |

# Save

In [None]:
# Convert the list of per-run epoch-loss lists collected by the training loop
# into an array (one row per local training run, one column per local epoch),
# so that the plotting cell can call `.mean(axis=1)` on it.
train_losses = np.asarray(train_losses)

In [None]:
# Write each metric to '<FNAME>_<suffix>.npy' inside the dataset's save directory.
for name_template, values in (
    ('{}_tr_ls.npy', train_losses),
    ('{}_te_ls.npy', test_losses),
    ('{}_te_acc.npy', test_accs),
):
    out_path = os.path.join(SAVE_DSET, name_template.format(FNAME))
    with open(out_path, 'wb') as f:
        np.save(f, values)

# Plot

In [None]:
%matplotlib inline

import matplotlib.pyplot as plt

In [None]:
fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(16, 9))
axs = axs.ravel()

# `train_losses` has one row per local training run (round x party) and one
# column per local epoch; average over the epochs of each run.
run_losses = train_losses.mean(axis=1)

# Panel 0: test accuracy after each communication round.
axs[0].plot(test_accs, c='orange')
axs[0].set_title('Test Accuracies')
axs[0].set_xlabel('Rounds')
axs[0].set_ylabel('Test Accuracy')  # was set on axs[1] and then overwritten

# Panel 1: test loss after each communication round.
axs[1].plot(test_losses, c='purple')
axs[1].set_title('Test Losses')
axs[1].set_xlabel('Rounds')
axs[1].set_ylabel('Test Loss')

# Panel 2: mean loss of every local training run, in execution order.
axs[2].plot(run_losses, c='red')
axs[2].set_title('Train Average Losses')
axs[2].set_xlabel('Local Training Runs')  # each point is one run, not one epoch
axs[2].set_ylabel('Train Average Loss')

# Panel 3: mean loss per round, i.e. averaged over the parties of that round.
# Group by NUM_PARTIES rather than a hard-coded 10 so each group is one round.
axs[3].plot(run_losses.reshape(-1, NUM_PARTIES).mean(axis=1), c='blue')
axs[3].set_title('Train Average Losses')
axs[3].set_xlabel('Rounds')
axs[3].set_ylabel('Train Average Loss')

plt.show()