In [1]:
%matplotlib inline

dataset = 'mnist'
num_clients = 100

filename = '{}_diri{}a01_42_fedavg_ep'.format(dataset, num_clients)

gpu = 0

In [2]:
import copy
import json
import logging
import os
import random
import time
from datetime import timedelta

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, Subset, DataLoader, SubsetRandomSampler
from torchvision import transforms, datasets

In [3]:
seed = 42

os.environ['PYTHONHASHSEED'] = str(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.backends.cudnn.benchmark = True
try:
    torch.use_deterministic_algorithms(False)
except AttributeError:
    torch.set_deterministic(False)

In [4]:
path_root = os.path.dirname(os.getcwd())
path_data = os.path.join(path_root, 'data')
path_logs = os.path.join(path_root, 'logs')
path_models = os.path.join(path_root, 'models', filename)
path_results = os.path.join(path_root, 'results', filename)

for p in [path_data, path_logs, path_models, path_results]:
    os.makedirs(p, exist_ok=True)

In [5]:
logger = logging.getLogger(filename)
logger.setLevel(logging.INFO)

streamformatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
streamhandler = logging.StreamHandler()
streamhandler.setFormatter(streamformatter)
logger.addHandler(streamhandler)

fileformatter = logging.Formatter('%(message)s')
filehandler = logging.FileHandler(os.path.join(path_logs, filename + '.log'), mode='w')
filehandler.setFormatter(fileformatter)
logger.addHandler(filehandler)

In [6]:
class CustomDataset(Dataset):
    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset

    def __getitem__(self, idx):
        data, target = self.dataset[idx]
        return data, target, idx

    def __len__(self):
        return len(self.dataset)

In [7]:
if dataset == 'cifar10':
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([x / 255 for x in [125.3, 123, 113.9]], [x / 255 for x in [63, 62.1, 66.7]])
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([x / 255 for x in [125.3, 123, 113.9]], [x / 255 for x in [63, 62.1, 66.7]])
    ])

    train_dataset = datasets.CIFAR10(path_data, train=True, transform=train_transform, download=True)
    custom_dataset = CustomDataset(train_dataset)
    test_dataset = datasets.CIFAR10(path_data, train=False, transform=test_transform)
    
elif dataset == 'svhn':
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4376821, 0.4437697, 0.47280442), (0.19803012, 0.20101562, 0.19703614))
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4376821, 0.4437697, 0.47280442), (0.19803012, 0.20101562, 0.19703614))
    ])

    train_dataset = datasets.SVHN(path_data, split='train', transform=train_transform, download=True)
    custom_dataset = CustomDataset(train_dataset)
    test_dataset = datasets.SVHN(path_data, split='test', transform=test_transform, download=True)
    
elif dataset == 'fmnist':
    train_transform = transforms.Compose([
        transforms.RandomCrop(28, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5), (0.5))
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5), (0.5))
    ])

    train_dataset = datasets.FashionMNIST(path_data, train=True, transform=train_transform, download=True)
    custom_dataset = CustomDataset(train_dataset)
    test_dataset = datasets.FashionMNIST(path_data, train=False, transform=test_transform)

elif dataset == 'mnist':
    train_transform = transforms.Compose([
        transforms.RandomCrop(28, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5), (0.5))
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5), (0.5))
    ])

    train_dataset = datasets.MNIST(path_data, train=True, transform=train_transform, download=True)
    custom_dataset = CustomDataset(train_dataset)
    test_dataset = datasets.MNIST(path_data, train=False, transform=test_transform)

In [8]:
with open(os.path.join(path_data, f'{dataset}_diri{num_clients}a01_42.json')) as f:
    indices = json.load(f)

In [9]:
batch_size, num_workers = 128, 0

inferloaders, subset_indices = {}, []
for k, v in indices.items():
    infersubset = Subset(custom_dataset, v['index'])
    inferloaders[k] = DataLoader(infersubset, batch_size=batch_size, num_workers=num_workers)
    subset_indices.extend(v['index'])

train_subset = Subset(train_dataset, indices=subset_indices)
fed_trainloader = DataLoader(train_subset, batch_size=batch_size, num_workers=num_workers)
    
try:
    train_labels = np.asarray(custom_dataset.dataset.targets)
    test_labels = np.asarray(test_dataset.targets)
except AttributeError:
    train_labels = np.asarray(custom_dataset.dataset.labels)
    test_labels = np.asarray(test_dataset.labels)
subset_classes = np.unique(train_labels[subset_indices])
boolarr = [True if y in subset_classes else False for y in test_labels]
subset_indices = np.arange(len(test_dataset))[boolarr]
test_subset = Subset(test_dataset, indices=subset_indices)
testloader = DataLoader(test_subset, batch_size=batch_size, num_workers=num_workers)

In [10]:
# class ResidualBlock(nn.Module):
#     expansion = 1

#     def __init__(self, in_channel, out_channel, kernel_size=3, stride=1, padding=1, bias=False):
#         super().__init__()

#         self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
#         self.bn1 = nn.BatchNorm2d(out_channel)
#         self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=kernel_size, stride=1, padding=padding, bias=bias)
#         self.bn2 = nn.BatchNorm2d(out_channel)

#         self.shortcut = nn.Sequential()
#         if (stride != 1) or (in_channel != self.expansion * out_channel):
#             self.shortcut = nn.Sequential(
#                 nn.Conv2d(in_channel, self.expansion * out_channel, kernel_size=1, stride=stride, bias=bias),
#                 nn.BatchNorm2d(self.expansion * out_channel)
#             )

#     def forward(self, x):
#         out = F.relu(self.bn1(self.conv1(x)))
#         out = self.bn2(self.conv2(out))
#         out += self.shortcut(x)
#         out = F.relu(out)
#         return out

# class Bottleneck(nn.Module):
#     expansion = 4
    
#     def __init__(self, in_channel, out_channel, kernel_size=3, stride=1, padding=1, bias=False):
#         super().__init__()

#         self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=1, bias=bias)
#         self.bn1 = nn.BatchNorm2d(out_channel)
#         self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
#         self.bn2 = nn.BatchNorm2d(out_channel)
#         self.conv3 = nn.Conv2d(out_channel, self.expansion * out_channel, kerenel_size=1, bias=bias)
#         self.bn3 = nn.BatchNorm2d(self.expansion * out_channel)

#         self.shortcut = nn.Sequential()
#         if (stride != 1) or (in_channel != self.expansion * out_channel):
#             self.shortcut = nn.Sequential(
#                 nn.Conv2d(in_channel, self.expansion * out_channel, kernel_size=1, stride=stride, bias=bias),
#                 nn.BatchNorm2d(self.expansion * out_channel)
#             )

#     def forward(self, x):
#         out = F.relu(self.bn1(self.conv1(x)))
#         out = F.relu(self.bn2(self.conv2(out)))
#         out = self.bn3(self.conv2(out))
#         out += self.shortcut(x)
#         out = F.relu(out)
#         return out
    
# class ResNet(nn.Module):
#     def __init__(self, block, in_channel=3, out_channels=[64, 128, 256, 512], num_blocks=[2, 2, 2, 2], strides=[1, 2, 2, 2], num_classes=10):
#         super().__init__()
#         self.in_channel = out_channels[0]

#         self.conv1 = nn.Conv2d(in_channel, out_channels[0], kernel_size=3, stride=strides[0], padding=1, bias=False)
#         self.bn1 = nn.BatchNorm2d(out_channels[0])
        
#         self.block1 = self._make_layer(block, out_channels[0], num_blocks[0], strides[0])
#         self.block2 = self._make_layer(block, out_channels[1], num_blocks[1], strides[1])
#         self.block3 = self._make_layer(block, out_channels[2], num_blocks[2], strides[2])
#         self.block4 = self._make_layer(block, out_channels[3], num_blocks[3], strides[3])
        
#         self.linear = nn.Linear(out_channels[-1] * block.expansion, num_classes)

#     def _make_layer(self, block, out_channel, num_blocks, stride):
#         strides = [stride] + [1] * (num_blocks - 1)
#         layers = []
#         for stride in strides:
#             layers.append(block(self.in_channel, out_channel, stride=stride))
#             self.in_channel = out_channel * block.expansion
#         return nn.Sequential(*layers)

#     def forward(self, x):
#         out = F.relu(self.bn1(self.conv1(x)))
#         out = self.block1(out)
#         out = self.block2(out)
#         out = self.block3(out)
#         out = self.block4(out)
#         out = F.avg_pool2d(out, 4)
#         out = out.view(out.size(0), -1)
#         out = self.linear(out)
#         return out

In [11]:
class MLP(nn.Module):
    def __init__(self, in_feature=784, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(in_feature, 200)
        self.fc2 = nn.Linear(200, 200)
        self.fc3 = nn.Linear(200, num_classes)
        
    def forward(self, x):
        x = x.view(-1, 784)
        out = F.relu(self.fc1(x))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        return out

In [12]:
# class CNN(nn.Module):
#     def __init__(self, in_channel=3, num_classes=10):
#         super().__init__()
#         self.conv1 = nn.Conv2d(in_channel, 64, 5)
#         self.conv2 = nn.Conv2d(64, 64, 5)
#         self.fc1 = nn.Linear(64 * 5 * 5, 384)
#         self.fc2 = nn.Linear(384, 192)
#         self.fc3 = nn.Linear(192, num_classes)
        
#     def forward(self, x):
#         out = F.max_pool2d(F.relu(self.conv1(x)), 2, 2)
#         out = F.max_pool2d(F.relu(self.conv2(out)), 2, 2)
#         out = out.view(-1, 64 * 5 * 5)
#         out = F.relu(self.fc1(out))
#         out = F.relu(self.fc2(out))
#         out = self.fc3(out)
#         return out

In [13]:
def update_avg(w):
    w_avg = copy.deepcopy(w[0])
    
    for key in w_avg.keys():
        for i in range(1, len(w)):
            w_avg[key] += w[i][key]
        w_avg[key] = torch.div(w_avg[key], float(len(w)))
        
    return w_avg

In [14]:
# def simplenorm(x):
#     return x / np.sum(x)

# def weighting(forget_cnt_per_client, standardize=0.1):
#     if (standardize == 0.0) or (np.sum(forget_cnt_per_client) == 0):
#         weights = np.ones(len(forget_cnt_per_client))
#     else:
#         weights = 1 - standardize + len(forget_cnt_per_client) * standardize * simplenorm(forget_cnt_per_client)
#     return weights

In [15]:
# def update_wavg(w, standardized_weights):
#     w_avg = copy.deepcopy(w[0])
    
#     w_avg.update((k, v * standardized_weights[0]) for k, v in w_avg.items())
    
#     for key in w_avg.keys():
#         for i in range(1, len(w)):
#             w_avg[key] += (w[i][key] * standardized_weights[i])
#         w_avg[key] = torch.div(w_avg[key], float(len(w)))
        
#     return w_avg

In [16]:
def train(model, loader, epochs, lr, weight_decay, criterion, device):
    epoch_losses = []
    
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    
    model.train()
    
    for ep in range(epochs):
        batch_losses = []
        
        for inputs, labels, _ in loader:
            inputs, labels = inputs.to(device), labels.to(device)

            model.zero_grad()
            
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss = loss.mean()
            
            loss.backward()
            optimizer.step()
            
            batch_losses.append(loss.item())

        epoch_losses.append(sum(batch_losses) / len(batch_losses))
        
    local_weights = model.state_dict()

    return local_weights, epoch_losses

In [17]:
def inference(model, loader, criterion, device):
    avg_loss, correct, num_samples = 0, 0, 0

    model.eval()
    
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)

            loss = criterion(outputs, labels)
            loss = loss.mean()
            avg_loss += loss.item()
            
            _, predicted = torch.max(outputs, 1)
            correct += torch.sum(torch.eq(predicted, labels)).item()
            num_samples += len(labels)

    acc = correct / num_samples
    avg_loss /= len(loader)
    
    return acc, avg_loss

In [18]:
def infer_train(model, loader, device, match_history, flag=False):
    forgettables = []
    correct, total = 0, 0

    model.eval()
    
    with torch.no_grad():
        for inputs, labels, indices in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            match = predicted.eq(labels)
            match = match.type(torch.IntTensor)
            total += len(indices)
            
            for j, idx in enumerate(indices):
                sample_history = match_history.get(idx.item(), [])
                sample_history.append(match[j].item())
                match_history[idx.item()] = sample_history
                correct += match[j].item()
                if flag is True:
                    try:
                        if match[j].item() - sample_history[-2] == -1:
                            forgettables.append(idx.item())
                    except IndexError:
                        continue
            
    acc = correct / total
            
    return match_history, forgettables, acc

In [19]:
def formatfloats(li):
    new = [float(f'{e:>8.4f}') for e in li]
    return new

In [20]:
# resnet9 = {'block': ResidualBlock, 'num_blocks': [1, 1, 1, 1]}
# resnet18 = {'block': ResidualBlock, 'num_blocks': [2, 2, 2, 2]}
# resnet34 = {'block': ResidualBlock, 'num_blocks': [3, 4, 6, 3]}
# resnet50 = {'block': Bottleneck, 'num_blocks': [3, 4, 6, 3]}
# resnet101 = {'block': Bottleneck, 'num_blocks': [3, 4, 23, 3]}
# resnet152 = {'block': Bottleneck, 'num_blocks': [3, 8, 36, 3]}

In [21]:
device = torch.device(f'cuda:{gpu}' if torch.cuda.is_available() else 'cpu')

In [22]:
num_classes = 10

fed_model = MLP()
fed_weights = fed_model.state_dict()

fed_model.to(device)
criterion = nn.CrossEntropyLoss(reduction='none').to(device)

In [23]:
rounds = 1000
epochs = 10
lr = 0.01
wdecay = 0.0001

logger.info(f'\nAlgorithm: FedAvg Dynamic Epoch\nClients: {num_clients}\nDataset: {dataset}\nModel: MLP | Rounds: {rounds} | Epochs: {epochs} | LR: {lr}\n')

train_accs, train_losses, test_accs, test_losses = [], [], [], []
match_history, round_forget_history = {}, []
# forget_cnt_per_client = [0] * len(indices.keys())

st = time.time()
curr_participants = np.random.choice(num_clients, 10)

for r in range(rounds):
    
    forget_cnt_per_client = [0] * 10
    
    round_forget, round_samples = 0, 0
    forget_history, forgettables = {}, {}
    local_weights, local_losses = [], []
    logger.info(f' | Global Training Round : {r + 1} / {rounds} |')
    logger.info(f' | Current Participants : {sorted(curr_participants.tolist())} |')

    next_participants = np.random.choice(num_clients, 10)
    logger.info(f' |    Next Participants : {sorted(next_participants.tolist())} |')
    
    tmp_cnt = 0
    for i, k in enumerate(indices.keys()):
        if int(k) in curr_participants:
            match_history, forgettables, global_acc = infer_train(fed_model, inferloaders[k], device, match_history, flag=True)
            
            forget_cnt = len(forgettables)
            round_forget += forget_cnt
            forget_cnt_per_client[tmp_cnt] = forget_cnt
            tmp_cnt += 1
            round_samples += len(indices[k]['index'])

            fed_model.train()

            sampler_idx = indices[k]['index'].copy()
            sampler = SubsetRandomSampler(sampler_idx)

            trainloader = DataLoader(custom_dataset, batch_size=batch_size, sampler=sampler, num_workers=num_workers)

            local_model = copy.deepcopy(fed_model)
            if r < 1:
                dyn_epochs = epochs
            elif forget_cnt > 0:
                dyn_epochs = int(epochs * forget_cnt // len(sampler_idx))
            else:
                dyn_epochs = 1
            w, ls = train(local_model, trainloader, dyn_epochs, lr, wdecay, criterion, device)

            local_weights.append(copy.deepcopy(w))
            train_losses.append(ls)

            match_history, _, local_acc = infer_train(local_model, inferloaders[k], device, match_history, flag=False)
            train_accs.append(local_acc)
            
            logger.info('  |-- [Party {:>2}] Average Train Loss: {:>8.4f} Train Accuracy: global {:>6.2f}% local {:>6.2f}% ... {:>4} forgettables out of {:>4} ({:>5.2f}%) ... total data used {:>4}'.format(
                k, sum(ls) / len(ls), 100 * global_acc, 100 * local_acc, forget_cnt, len(indices[k]['index']), 100 * forget_cnt / len(indices[k]['index']), len(sampler_idx)
            ))
            logger.info('  |--    Epoch Losses ({:>2}): {}'.format(dyn_epochs, formatfloats(ls)))
            
        elif int(k) in next_participants:
            match_history, forgettables, global_acc = infer_train(fed_model, inferloaders[k], device, match_history, flag=True)
            
            fed_model.train()

            sampler_idx = indices[k]['index'].copy()
            sampler = SubsetRandomSampler(sampler_idx)

            trainloader = DataLoader(custom_dataset, batch_size=batch_size, sampler=sampler, num_workers=num_workers)

            local_model = copy.deepcopy(fed_model)
            if r < 1:
                dyn_epochs = epochs
            elif forget_cnt > 0:
                dyn_epochs = epochs * forget_cnt // len(indices[k]['index'])
            else:
                dyn_epochs = 1
            _, _ = train(local_model, trainloader, dyn_epochs, lr, wdecay, criterion, device)
            
            match_history, _, local_acc = infer_train(local_model, inferloaders[k], device, match_history, flag=False)

            logger.info('  |-- [Party {:>2}] Training for forgettable counting in the next round'.format(k))
        
#     standardized_weights = weighting(forget_cnt_per_client, standardize)
    fed_weights = update_avg(local_weights)
    fed_model.load_state_dict(fed_weights)
#     train_acc, _ = inference(fed_model, fed_trainloader, criterion, device)
    
    curr_participants = next_participants.copy()
    
    if (r + 1) % 100 == 0:
        torch.save(fed_model.state_dict(), os.path.join(path_models, filename + f'_round{r+1}.pth'))
    
    test_acc, test_ls = inference(fed_model, testloader, criterion, device)
    test_accs.append(test_acc)
    test_losses.append(test_ls)
    round_forget_history.append(round_forget)
    logger.info('    |---- Number of Forgettables: {} ({:.2f}%)'.format(round_forget, 100 * round_forget / round_samples))
#     logger.info('    |---- Train Accuracy: {:>.2f}%'.format(100 * train_acc))
    logger.info('    |---- Test Accuracy: {:>.2f}%'.format(100 * test_acc))
    logger.info('    |---- Test Loss: {:.4f}'.format(test_ls))
    logger.info('    |---- Elapsed time: {}'.format(timedelta(seconds=time.time()-st)))
    logger.info(f'\nTest Acc: Highest {np.max(test_accs) * 100:.4f}% ({np.argmax(test_accs)+1} round) | Avg {np.mean(test_accs) * 100:.4f}% ({np.argmax(test_accs > np.mean(test_accs))+1} round) | Curr {test_accs[-1] * 100:.4f}%')

2021-09-23 01:05:07,404 - INFO - 
Algorithm: FedAvg Dynamic Epoch
Clients: 100
Dataset: mnist
Model: MLP | Rounds: 1000 | Epochs: 10 | LR: 0.001

2021-09-23 01:05:07,410 - INFO -  | Global Training Round : 1 / 1000 |
2021-09-23 01:05:07,411 - INFO -  | Current Participants : [14, 20, 51, 60, 71, 74, 74, 82, 86, 92] |
2021-09-23 01:05:07,412 - INFO -  |    Next Participants : [1, 2, 21, 23, 29, 37, 52, 87, 87, 99] |
2021-09-23 01:05:10,598 - INFO -   |-- [Party  1] Training for forgettable counting in the next round
2021-09-23 01:05:13,288 - INFO -   |-- [Party  2] Training for forgettable counting in the next round
2021-09-23 01:05:15,971 - INFO -   |-- [Party 14] Average Train Loss:   0.7118 Train Accuracy: global  14.83% local  78.50% ...    0 forgettables out of  600 ( 0.00%) ... total data used  600
2021-09-23 01:05:15,976 - INFO -   |--    Epoch Losses (10): [1.465, 0.7606, 0.7266, 0.6833, 0.639, 0.592, 0.5845, 0.5644, 0.5597, 0.5426]


10
600


2021-09-23 01:05:18,528 - INFO -   |-- [Party 20] Average Train Loss:   0.9734 Train Accuracy: global   7.67% local  62.33% ...    0 forgettables out of  600 ( 0.00%) ... total data used  600
2021-09-23 01:05:18,529 - INFO -   |--    Epoch Losses (10): [1.6949, 1.0947, 0.9721, 0.9122, 0.8864, 0.8673, 0.8558, 0.8161, 0.8366, 0.7983]


10
600


2021-09-23 01:05:21,086 - INFO -   |-- [Party 21] Training for forgettable counting in the next round
2021-09-23 01:05:23,632 - INFO -   |-- [Party 23] Training for forgettable counting in the next round
2021-09-23 01:05:26,295 - INFO -   |-- [Party 29] Training for forgettable counting in the next round
2021-09-23 01:05:28,957 - INFO -   |-- [Party 37] Training for forgettable counting in the next round
2021-09-23 01:05:31,218 - INFO -   |-- [Party 51] Average Train Loss:   1.3379 Train Accuracy: global   0.00% local  63.33% ...    0 forgettables out of  600 ( 0.00%) ... total data used  600
2021-09-23 01:05:31,220 - INFO -   |--    Epoch Losses (10): [2.0357, 1.554, 1.4501, 1.3679, 1.3132, 1.2396, 1.1694, 1.1429, 1.053, 1.0535]


10
600


2021-09-23 01:05:33,441 - INFO -   |-- [Party 52] Training for forgettable counting in the next round
2021-09-23 01:05:35,738 - INFO -   |-- [Party 60] Average Train Loss:   1.1755 Train Accuracy: global  17.50% local  65.33% ...    0 forgettables out of  600 ( 0.00%) ... total data used  600
2021-09-23 01:05:35,741 - INFO -   |--    Epoch Losses (10): [1.6845, 1.3057, 1.2061, 1.1637, 1.1228, 1.1152, 1.0566, 1.0542, 1.0338, 1.0129]


10
600


2021-09-23 01:05:37,971 - INFO -   |-- [Party 71] Average Train Loss:   0.6470 Train Accuracy: global  11.67% local  87.00% ...    0 forgettables out of  600 ( 0.00%) ... total data used  600
2021-09-23 01:05:37,973 - INFO -   |--    Epoch Losses (10): [1.484, 0.7927, 0.7059, 0.6565, 0.5421, 0.5363, 0.4898, 0.4471, 0.425, 0.3901]


10
600


2021-09-23 01:05:40,224 - INFO -   |-- [Party 74] Average Train Loss:   0.6190 Train Accuracy: global  37.50% local  85.83% ...    0 forgettables out of  600 ( 0.00%) ... total data used  600
2021-09-23 01:05:40,225 - INFO -   |--    Epoch Losses (10): [1.2638, 0.7031, 0.5763, 0.5601, 0.5218, 0.5252, 0.5151, 0.5123, 0.5051, 0.5069]


10
600


2021-09-23 01:05:42,386 - INFO -   |-- [Party 82] Average Train Loss:   0.9861 Train Accuracy: global  25.67% local  77.00% ...    0 forgettables out of  600 ( 0.00%) ... total data used  600
2021-09-23 01:05:42,387 - INFO -   |--    Epoch Losses (10): [1.4475, 1.1044, 1.0293, 0.9623, 0.9474, 0.9285, 0.8895, 0.8752, 0.8531, 0.8239]


10
600


2021-09-23 01:05:44,891 - INFO -   |-- [Party 86] Average Train Loss:   1.0572 Train Accuracy: global   2.17% local  69.67% ...    0 forgettables out of  600 ( 0.00%) ... total data used  600
2021-09-23 01:05:44,893 - INFO -   |--    Epoch Losses (10): [1.8943, 1.1681, 1.1294, 1.0321, 0.9677, 0.9301, 0.897, 0.8576, 0.8628, 0.8332]


10
600


2021-09-23 01:05:47,437 - INFO -   |-- [Party 87] Training for forgettable counting in the next round
2021-09-23 01:05:49,987 - INFO -   |-- [Party 92] Average Train Loss:   0.3427 Train Accuracy: global   0.00% local  95.17% ...    0 forgettables out of  600 ( 0.00%) ... total data used  600
2021-09-23 01:05:49,990 - INFO -   |--    Epoch Losses (10): [1.1386, 0.4112, 0.343, 0.2485, 0.2231, 0.2335, 0.2128, 0.2129, 0.2039, 0.1993]


10
600


2021-09-23 01:05:52,523 - INFO -   |-- [Party 99] Training for forgettable counting in the next round
2021-09-23 01:05:54,116 - INFO -     |---- Number of Forgettables: 0 (0.00%)
2021-09-23 01:05:54,117 - INFO -     |---- Test Accuracy: 15.46%
2021-09-23 01:05:54,119 - INFO -     |---- Test Loss: 2.3968
2021-09-23 01:05:54,120 - INFO -     |---- Elapsed time: 0:00:46.711908
2021-09-23 01:05:54,123 - INFO - 
Test Acc: Highest 15.4600% (1 round) | Avg 15.4600% (1 round) | Curr 15.4600%
2021-09-23 01:05:54,124 - INFO -  | Global Training Round : 2 / 1000 |
2021-09-23 01:05:54,125 - INFO -  | Current Participants : [1, 2, 21, 23, 29, 37, 52, 87, 87, 99] |
2021-09-23 01:05:54,127 - INFO -  |    Next Participants : [1, 20, 21, 32, 48, 57, 59, 63, 75, 88] |
2021-09-23 01:05:56,136 - INFO -   |-- [Party  1] Average Train Loss:   0.8204 Train Accuracy: global   0.17% local  80.17% ...  481 forgettables out of  600 (80.17%) ... total data used  600
2021-09-23 01:05:56,138 - INFO -   |--    Epoch

8
600


2021-09-23 01:05:58,268 - INFO -   |-- [Party  2] Average Train Loss:   0.4621 Train Accuracy: global   0.33% local  91.67% ...  548 forgettables out of  600 (91.33%) ... total data used  600
2021-09-23 01:05:58,269 - INFO -   |--    Epoch Losses ( 9): [1.2624, 0.5375, 0.4409, 0.3847, 0.3593, 0.3387, 0.2978, 0.2934, 0.2438]


9
600


2021-09-23 01:06:00,595 - INFO -   |-- [Party 20] Training for forgettable counting in the next round
2021-09-23 01:06:02,072 - INFO -   |-- [Party 21] Average Train Loss:   1.3387 Train Accuracy: global   7.67% local  58.50% ...  353 forgettables out of  600 (58.83%) ... total data used  600
2021-09-23 01:06:02,073 - INFO -   |--    Epoch Losses ( 5): [1.8624, 1.2826, 1.2437, 1.1725, 1.1325]


5
600


2021-09-23 01:06:03,985 - INFO -   |-- [Party 23] Average Train Loss:   0.8866 Train Accuracy: global   0.50% local  78.17% ...  469 forgettables out of  600 (78.17%) ... total data used  600
2021-09-23 01:06:03,987 - INFO -   |--    Epoch Losses ( 7): [1.6278, 1.0021, 0.8177, 0.7819, 0.7137, 0.647, 0.6159]


7
600


2021-09-23 01:06:06,107 - INFO -   |-- [Party 29] Average Train Loss:   0.7478 Train Accuracy: global   0.00% local  82.67% ...  487 forgettables out of  600 (81.17%) ... total data used  600
2021-09-23 01:06:06,109 - INFO -   |--    Epoch Losses ( 8): [1.7712, 0.7178, 0.689, 0.6412, 0.5851, 0.5475, 0.5209, 0.5096]


8
600


2021-09-23 01:06:08,642 - INFO -   |-- [Party 32] Training for forgettable counting in the next round
2021-09-23 01:06:10,207 - INFO -   |-- [Party 37] Average Train Loss:   0.5370 Train Accuracy: global  52.33% local  92.50% ...  249 forgettables out of  600 (41.50%) ... total data used  600
2021-09-23 01:06:10,209 - INFO -   |--    Epoch Losses ( 4): [0.8707, 0.4683, 0.4349, 0.374]


4
600


2021-09-23 01:06:11,298 - INFO -   |-- [Party 48] Training for forgettable counting in the next round
2021-09-23 01:06:13,256 - INFO -   |-- [Party 52] Average Train Loss:   0.6537 Train Accuracy: global   0.17% local  84.17% ...  493 forgettables out of  600 (82.17%) ... total data used  600
2021-09-23 01:06:13,257 - INFO -   |--    Epoch Losses ( 8): [1.3311, 0.7384, 0.5975, 0.567, 0.5222, 0.5096, 0.482, 0.4822]


8
600


2021-09-23 01:06:15,273 - INFO -   |-- [Party 57] Training for forgettable counting in the next round
2021-09-23 01:06:17,500 - INFO -   |-- [Party 59] Training for forgettable counting in the next round
2021-09-23 01:06:19,728 - INFO -   |-- [Party 63] Training for forgettable counting in the next round
2021-09-23 01:06:21,911 - INFO -   |-- [Party 75] Training for forgettable counting in the next round
2021-09-23 01:06:23,488 - INFO -   |-- [Party 87] Average Train Loss:   0.8408 Train Accuracy: global  15.50% local  79.17% ...  400 forgettables out of  600 (66.67%) ... total data used  600
2021-09-23 01:06:23,489 - INFO -   |--    Epoch Losses ( 6): [1.3311, 0.8745, 0.7934, 0.7247, 0.6972, 0.6238]


6
600


2021-09-23 01:06:24,926 - INFO -   |-- [Party 88] Training for forgettable counting in the next round
2021-09-23 01:06:26,078 - INFO -   |-- [Party 99] Average Train Loss:   0.1815 Train Accuracy: global  50.83% local 100.00% ...  295 forgettables out of  600 (49.17%) ... total data used  600
2021-09-23 01:06:26,079 - INFO -   |--    Epoch Losses ( 4): [0.7217, 0.0042, 0.0, 0.0]


4
600


2021-09-23 01:06:27,737 - INFO -     |---- Number of Forgettables: 3775 (69.91%)
2021-09-23 01:06:27,739 - INFO -     |---- Test Accuracy: 10.17%
2021-09-23 01:06:27,740 - INFO -     |---- Test Loss: 2.8718
2021-09-23 01:06:27,740 - INFO -     |---- Elapsed time: 0:01:20.331970
2021-09-23 01:06:27,742 - INFO - 
Test Acc: Highest 15.4600% (1 round) | Avg 12.8150% (1 round) | Curr 10.1700%
2021-09-23 01:06:27,743 - INFO -  | Global Training Round : 3 / 1000 |
2021-09-23 01:06:27,744 - INFO -  | Current Participants : [1, 20, 21, 32, 48, 57, 59, 63, 75, 88] |
2021-09-23 01:06:27,745 - INFO -  |    Next Participants : [14, 41, 46, 58, 59, 61, 61, 79, 90, 91] |
2021-09-23 01:06:29,248 - INFO -   |-- [Party  1] Average Train Loss:   0.7238 Train Accuracy: global   9.83% local  80.17% ...  423 forgettables out of  600 (70.50%) ... total data used  600
2021-09-23 01:06:29,250 - INFO -   |--    Epoch Losses ( 7): [1.1389, 0.7134, 0.6733, 0.6497, 0.6322, 0.6416, 0.6178]


7
600


2021-09-23 01:06:30,821 - INFO -   |-- [Party 14] Training for forgettable counting in the next round
2021-09-23 01:06:32,170 - INFO -   |-- [Party 20] Average Train Loss:   1.1301 Train Accuracy: global   1.50% local  62.17% ...  384 forgettables out of  600 (64.00%) ... total data used  600
2021-09-23 01:06:32,172 - INFO -   |--    Epoch Losses ( 6): [2.0571, 1.1206, 0.9426, 0.9291, 0.8723, 0.859]


6
600


2021-09-23 01:06:33,477 - INFO -   |-- [Party 21] Average Train Loss:   1.2511 Train Accuracy: global   2.00% local  58.67% ...  351 forgettables out of  600 (58.50%) ... total data used  600
2021-09-23 01:06:33,479 - INFO -   |--    Epoch Losses ( 5): [1.6573, 1.2238, 1.174, 1.1194, 1.0808]


5
600


2021-09-23 01:06:35,168 - INFO -   |-- [Party 32] Average Train Loss:   0.8408 Train Accuracy: global   2.67% local  79.00% ...  473 forgettables out of  600 (78.83%) ... total data used  600
2021-09-23 01:06:35,170 - INFO -   |--    Epoch Losses ( 7): [1.5818, 0.826, 0.7837, 0.728, 0.684, 0.646, 0.6361]


7
600


2021-09-23 01:06:36,814 - INFO -   |-- [Party 41] Training for forgettable counting in the next round
2021-09-23 01:06:38,495 - INFO -   |-- [Party 46] Training for forgettable counting in the next round
2021-09-23 01:06:40,151 - INFO -   |-- [Party 48] Average Train Loss:   0.8618 Train Accuracy: global   9.67% local  79.50% ...  427 forgettables out of  600 (71.17%) ... total data used  600
2021-09-23 01:06:40,153 - INFO -   |--    Epoch Losses ( 7): [1.2383, 0.8947, 0.8303, 0.8013, 0.7686, 0.7565, 0.743]


7
600


2021-09-23 01:06:41,799 - INFO -   |-- [Party 57] Average Train Loss:   1.1627 Train Accuracy: global   0.00% local  71.33% ...  435 forgettables out of  600 (72.50%) ... total data used  600
2021-09-23 01:06:41,800 - INFO -   |--    Epoch Losses ( 7): [2.5621, 1.2554, 0.977, 0.8983, 0.8631, 0.8301, 0.7529]


7
600


2021-09-23 01:06:43,474 - INFO -   |-- [Party 58] Training for forgettable counting in the next round
2021-09-23 01:06:45,211 - INFO -   |-- [Party 59] Average Train Loss:   0.7509 Train Accuracy: global   8.67% local  83.00% ...  447 forgettables out of  600 (74.50%) ... total data used  600
2021-09-23 01:06:45,212 - INFO -   |--    Epoch Losses ( 7): [1.1002, 0.7738, 0.7039, 0.6853, 0.6602, 0.6552, 0.6773]


7
600


2021-09-23 01:06:47,128 - INFO -   |-- [Party 61] Training for forgettable counting in the next round
2021-09-23 01:06:49,466 - INFO -   |-- [Party 63] Average Train Loss:   0.3140 Train Accuracy: global   0.00% local  98.67% ...  592 forgettables out of  600 (98.67%) ... total data used  600
2021-09-23 01:06:49,468 - INFO -   |--    Epoch Losses ( 9): [1.9066, 0.1875, 0.1171, 0.1163, 0.1273, 0.1234, 0.0978, 0.0764, 0.0737]


9
600


2021-09-23 01:06:51,389 - INFO -   |-- [Party 75] Average Train Loss:   1.1263 Train Accuracy: global   5.17% local  70.67% ...  463 forgettables out of  600 (77.17%) ... total data used  600
2021-09-23 01:06:51,390 - INFO -   |--    Epoch Losses ( 7): [2.3419, 1.2356, 0.9714, 0.9029, 0.8619, 0.8062, 0.7642]


7
600


2021-09-23 01:06:53,295 - INFO -   |-- [Party 79] Training for forgettable counting in the next round
2021-09-23 01:06:55,217 - INFO -   |-- [Party 88] Average Train Loss:   0.8063 Train Accuracy: global   0.00% local  74.67% ...  426 forgettables out of  600 (71.00%) ... total data used  600
2021-09-23 01:06:55,219 - INFO -   |--    Epoch Losses ( 7): [1.3941, 0.7834, 0.7634, 0.716, 0.681, 0.6814, 0.6246]


7
600


2021-09-23 01:06:57,137 - INFO -   |-- [Party 90] Training for forgettable counting in the next round
2021-09-23 01:06:59,053 - INFO -   |-- [Party 91] Training for forgettable counting in the next round
2021-09-23 01:07:00,639 - INFO -     |---- Number of Forgettables: 4421 (73.68%)
2021-09-23 01:07:00,642 - INFO -     |---- Test Accuracy: 11.66%
2021-09-23 01:07:00,643 - INFO -     |---- Test Loss: 2.6040
2021-09-23 01:07:00,644 - INFO -     |---- Elapsed time: 0:01:53.235548
2021-09-23 01:07:00,645 - INFO - 
Test Acc: Highest 15.4600% (1 round) | Avg 12.4300% (1 round) | Curr 11.6600%
2021-09-23 01:07:00,646 - INFO -  | Global Training Round : 4 / 1000 |
2021-09-23 01:07:00,647 - INFO -  | Current Participants : [14, 41, 46, 58, 59, 61, 61, 79, 90, 91] |
2021-09-23 01:07:00,648 - INFO -  |    Next Participants : [2, 6, 20, 38, 50, 50, 54, 61, 63, 72] |
2021-09-23 01:07:02,310 - INFO -   |-- [Party  2] Training for forgettable counting in the next round
2021-09-23 01:07:04,151 - INFO

8
600


2021-09-23 01:07:07,949 - INFO -   |-- [Party 20] Training for forgettable counting in the next round
2021-09-23 01:07:09,803 - INFO -   |-- [Party 38] Training for forgettable counting in the next round
2021-09-23 01:07:10,702 - INFO -   |-- [Party 41] Average Train Loss:   1.8329 Train Accuracy: global  24.33% local  38.50% ...  220 forgettables out of  600 (36.67%) ... total data used  600
2021-09-23 01:07:10,704 - INFO -   |--    Epoch Losses ( 3): [2.2623, 1.7284, 1.5079]


3
600


2021-09-23 01:07:12,554 - INFO -   |-- [Party 46] Average Train Loss:   0.5707 Train Accuracy: global   1.83% local  86.00% ...  501 forgettables out of  600 (83.50%) ... total data used  600
2021-09-23 01:07:12,555 - INFO -   |--    Epoch Losses ( 8): [0.956, 0.5999, 0.5353, 0.5183, 0.4932, 0.5058, 0.4783, 0.4785]


8
600


2021-09-23 01:07:14,695 - INFO -   |-- [Party 50] Training for forgettable counting in the next round
2021-09-23 01:07:16,820 - INFO -   |-- [Party 54] Training for forgettable counting in the next round
2021-09-23 01:07:18,731 - INFO -   |-- [Party 58] Average Train Loss:   1.0063 Train Accuracy: global   1.00% local  78.67% ...  450 forgettables out of  600 (75.00%) ... total data used  600
2021-09-23 01:07:18,733 - INFO -   |--    Epoch Losses ( 7): [1.559, 1.0859, 1.0071, 0.9273, 0.8583, 0.8098, 0.7966]


7
600
0
600


ZeroDivisionError: division by zero

In [None]:
train_losses = np.asarray(train_losses)
train_accs = np.asarray(train_accs)

with open(os.path.join(path_results, f'{filename}_tr_ls.npy'), 'wb') as f:
    np.save(f, train_losses)
with open(os.path.join(path_results, f'{filename}_tr_acc.npy'), 'wb') as f:
    np.save(f, train_accs)
with open(os.path.join(path_results, f'{filename}_te_ls.npy'), 'wb') as f:
    np.save(f, test_losses)
with open(os.path.join(path_results, f'{filename}_te_acc.npy'), 'wb') as f:
    np.save(f, test_accs)

In [None]:
fig, axs = plt.subplots(nrows=3, ncols=2, figsize=(50, 30))
axs = axs.ravel()

axs[0].plot(test_accs, c='orange')
axs[0].set_title('Test Accuracies')
axs[0].set_xlabel('Rounds')
axs[0].set_ylabel('Test Accuracy')
axs[1].plot(test_losses, c='blue')
axs[1].set_title('Test Losses')
axs[1].set_xlabel('Rounds')
axs[1].set_ylabel('Test Loss')
axs[2].plot(train_accs, c='red')
axs[2].set_title('Train Average Accuracies by Epochs')
axs[2].set_xlabel('Epochs')
axs[2].set_ylabel('Train Average Accuracy')
axs[3].plot(train_losses.mean(axis=1), c='turquoise')
axs[3].set_title('Train Average Losses by Epochs')
axs[3].set_xlabel('Epochs')
axs[3].set_ylabel('Train Average Loss')
axs[4].plot(np.mean(train_accs.reshape(-1, 10), axis=1), c='green')
axs[4].set_title('Train Average Accuracies by Rounds')
axs[4].set_xlabel('Rounds')
axs[4].set_ylabel('Train Average Accuracy')
axs[5].plot(train_losses.mean(axis=1).reshape(-1, 10).mean(axis=1), c='lightpink')
axs[5].set_title('Train Average Losses by Rounds')
axs[5].set_xlabel('Rounds')
axs[5].set_ylabel('Train Average Loss')

plt.show()