In [None]:
# Uncomment the lines below when running on Google Colab
# !pip install hickle
# from google.colab import drive
# drive.mount('/content/drive/')
# %cd drive/MyDrive/PerCom2021-FL-master/

In [None]:
import pickle
import copy
from sklearn.preprocessing import normalize
from scipy.optimize import linear_sum_assignment
from sklearn.model_selection import KFold
from sklearn.model_selection import StratifiedKFold
import hickle as hkl 
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
from lapsolver import solve_dense
import csv
import logging
import numpy as np
import os
from sklearn.metrics import confusion_matrix, f1_score, precision_score, recall_score
import argparse
import json
import torch
import torch.optim as optim
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.utils.data as data
from itertools import product
import math
import copy
import time
import torch.nn.functional as F
from sklearn.metrics import confusion_matrix

# Configure the root logger so INFO-level messages emitted by the helpers below are shown.
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Module-level "args_*" settings.
# NOTE(review): these look like leftovers from a CIFAR-10 experiment script; in the cells
# visible here only args_datadir is referenced (by local_train) — confirm before pruning.
args_logdir = "logs/cifar10"
#args_dataset = "cifar10"
args_datadir = "./data/cifar10"
args_init_seed = 0
args_net_config = [3072, 100, 8]
#args_partition = "hetero-dir"
args_partition = "homo"
args_experiment = ["u-ensemble", "pdm"]
args_trials = 1
#args_lr = 0.01
args_epochs = 10
args_reg = 1e-5
args_alpha = 0.5
args_iter_epochs = None

# PDM matching hyper-parameters (sigma, sigma0, gamma).
args_pdm_sig = 1.0
args_pdm_sig0 = 1.0
args_pdm_gamma = 1.0

In [None]:


# Which GPU(s) to expose to the process, e.g. "-1,0,1"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Model type: "DNN" or "CNN"
modelType = "CNN"

# Prior information about the model's layer types
layerType = [0,1,1]

# Algorithm label; it is embedded in the run name (other labels previously noted here: "FEDAVG", "FEDPER")
algorithm = "FEDMA"

# Dataset name, e.g. 'REALWORLD_CLIENT' ('UCI' is also handled by later cells)
dataSetName = 'REALWORLD_CLIENT'

# Data split configuration: "BALANCED" or "UNBALANCED"
dataConfig = "BALANCED"

# Optimizer label: "ADAM" or "SGD"
optimizer = "SGD"

# Flag (0/1 or bool)
# NOTE(review): not referenced in the cells visible here — confirm its meaning against later cells.
ClientAllTest = True

# Save the client models as a .h5 file
savedClientModel = 0

# Show verbose training output
showTrainVerbose = 0

# Input window size
segment_size = 128

# Number of input channels
num_input_channels = 6

# Client learning rate
learningRate = 0.01

# Model dropout rate
dropout_rate = 0.5

# Number of local epochs
localEpoch = 5

# Number of communication rounds
communicationRound = 200

# CNN kernel size
kernelSize = 16

# Neuron distance measurement
euclid = True

# Seed for data partitioning and PyTorch training
randomSeed = 1

# FedMA parameters

iteration = 5

gammaValue = float(iteration + 2)


In [None]:
# Activity names for the selected dataset; their count is used as the classifier output width.
ACTIVITY_LABEL = (
    ['WALKING', 'WALKING_UPSTAIRS','WALKING_DOWNSTAIRS', 'SITTING', 'STANDING', 'LAYING']
    if dataSetName == 'UCI'
    else ['climbingdown', 'climbingup', 'jumping','lying', 'running', 'sitting', 'standing', 'walking']
)
activityCount = len(ACTIVITY_LABEL)

# Run identifier that encodes the main hyper-parameters; it names the output directory.
if modelType == "DNN":
    architectureType = ''.join([
        str(algorithm), '_', str(learningRate), 'LR_', str(localEpoch), 'LE_',
        str(communicationRound), 'CR_400D_100D_', str(dataSetName), "_IT", str(iteration), "_NOTSAME",
    ])
else:
    architectureType = ''.join([
        str(algorithm), '_', str(learningRate), 'LR_', str(localEpoch), 'LE_',
        str(communicationRound), 'CR_196-16C_4M_1024D_', str(dataSetName), "_IT", str(iteration), "__NOTSAME",
    ])

# Create the directory for saved models up front (no error if it already exists).
mainDir = ''
filepath = ''.join([mainDir, 'savedModels/', architectureType, '/', dataSetName, '/'])
os.makedirs(filepath, exist_ok=True)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1'

# Number of federated clients depends on the dataset.
clientCount = 5 if dataSetName == 'UCI' else 15

# Seed NumPy and PyTorch with the configured seed.
np.random.seed(randomSeed)
torch.manual_seed(randomSeed)

In [None]:

def _str2bool(value):
    """Convert a command-line token to a bool.

    ``type=bool`` is not usable with argparse because ``bool("False")`` is
    ``True`` (any non-empty string is truthy). This converter accepts the usual
    spellings of true/false; an empty string is treated as ``False``, which is
    what ``bool("")`` returned before.

    Raises:
        argparse.ArgumentTypeError: if the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


# Command-line style options reused from the FedMA reference implementation.
# Help texts use %(default)s so the documented default can never drift from the real one.
parser = argparse.ArgumentParser(description='Probabilistic Federated CNN Matching')

parser.add_argument('--model', type=str, default='lenet', metavar='N',
                    help='neural network used in training')
parser.add_argument('--dataset', type=str, default='cifar10', metavar='N',
                    help='dataset used for training')
parser.add_argument('--partition', type=str, default='homo', metavar='N',
                    help='how to partition the dataset on local workers')
parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                    help='input batch size for training (default: %(default)s)')
parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
                    help='learning rate (default: %(default)s)')
parser.add_argument('--retrain_lr', type=float, default=0.1, metavar='RLR',
                    help='learning rate using in specific for local network retrain (default: %(default)s)')
parser.add_argument('--fine_tune_lr', type=float, default=0.1, metavar='FLR',
                    help='learning rate using in specific for fine tuning the softmax layer on the data center (default: %(default)s)')
parser.add_argument('--epochs', type=int, default=5, metavar='EP',
                    help='how many epochs will be trained in a training process')
parser.add_argument('--retrain_epochs', type=int, default=10, metavar='REP',
                    help='how many epochs will be trained in during the locally retraining process')
parser.add_argument('--fine_tune_epochs', type=int, default=10, metavar='FEP',
                    help='how many epochs will be trained in during the fine tuning process')
parser.add_argument('--partition_step_size', type=int, default=6, metavar='PSS',
                    help='how many groups of partitions we will have')
parser.add_argument('--local_points', type=int, default=5000, metavar='LP',
                    help='the approximate fixed number of data points we will have on each local worker')
parser.add_argument('--partition_step', type=int, default=0, metavar='PS',
                    help='how many sub groups we are going to use for a particular training process')
parser.add_argument('--n_nets', type=int, default=2, metavar='NN',
                    help='number of workers in a distributed cluster')
parser.add_argument('--oneshot_matching', type=_str2bool, default=False, metavar='OM',
                    help='if the code is going to conduct one shot matching')
parser.add_argument('--retrain', type=_str2bool, default=False,
                    help='whether to retrain the model or load model locally')
parser.add_argument('--rematching', type=_str2bool, default=False,
                    help='whether to recalculating the matching process (this is for speeding up the debugging process)')
parser.add_argument('--comm_type', type=str, default='layerwise',
                    help='which type of communication strategy is going to be used: layerwise/blockwise')
parser.add_argument('--comm_round', type=int, default=10,
                    help='how many round of communications we should use')


# The notebook has no real command line, so the argument vector is built from the
# configuration values defined in the earlier cells.
temp = ['--model=simple-cnn',
 '--dataset=cifar10',
 '--lr='+str(learningRate),
 '--retrain_lr='+str(learningRate),
 '--batch-size=64',
 '--epochs='+str(localEpoch),
 '--retrain_epochs='+str(localEpoch),
 '--n_nets='+str(clientCount),
 '--partition=hetero-dir',
 '--comm_type=fedma',
 '--comm_round='+str(communicationRound),
 '--retrain=True',
 '--rematching=True']
args = parser.parse_args(temp)

In [None]:
class SimpleCNN(nn.Module):
    """1D CNN: one Conv1d/ReLU/MaxPool block followed by two linear layers.

    Args:
        input_dim: accepted for interface compatibility; not used.
        hidden_dims: accepted for interface compatibility; not used.
        output_dim: number of output classes (defaults to ``activityCount``).

    The width of the first linear layer is derived from the module-level
    ``segment_size`` and ``kernelSize`` instead of being hard-coded, so the
    model still builds if either setting is changed. With the notebook's
    values (128 and 16) it evaluates to 5488, the previous constant.
    """

    def __init__(self, input_dim, hidden_dims, output_dim=activityCount):
        super(SimpleCNN, self).__init__()
        conv_filters = 196
        pool_size = 4
        self.conv_layer = nn.Sequential(
            nn.Conv1d(num_input_channels, conv_filters, kernel_size=kernelSize, bias=True),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(pool_size),
        )
        # Conv1d (stride 1, no padding) shortens the window by kernelSize - 1;
        # MaxPool1d then floor-divides the length by the pool size.
        conv_out_len = segment_size - kernelSize + 1
        flat_features = conv_filters * (conv_out_len // pool_size)
        self.fc_layer = nn.Sequential(
            nn.Linear(flat_features, 1024),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout_rate),
            nn.Linear(1024, output_dim)
        )

    def forward(self, x):
        """Run the conv block, flatten per sample, then apply the dense head."""
        x = self.conv_layer(x)
        x = x.view(x.size(0), -1)
        x = self.fc_layer(x)
        return x


class SimpleCNNContainer(nn.Module):
    """SimpleCNN variant whose filter count and hidden width are configurable.

    Args:
        input_channel: number of input channels for the Conv1d layer.
        num_filters: sequence; ``num_filters[0]`` is the conv output channel count.
        kernel_size: Conv1d kernel size.
        input_dim: accepted for interface compatibility; the flattened size is
            measured from the conv block instead.
        hidden_dims: sequence; ``hidden_dims[0]`` is the hidden linear width.
        output_dim: number of output classes (defaults to ``activityCount``).
    """

    def __init__(self, input_channel, num_filters, kernel_size, input_dim, hidden_dims, output_dim=activityCount):
        super(SimpleCNNContainer, self).__init__()
        # Remembered so the shape probe matches the conv layer's channel count.
        self.input_channel = input_channel
        self.conv_layer = nn.Sequential(
            nn.Conv1d(input_channel, num_filters[0], kernel_size),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(4),
        )
        self.neurons = self.linear_input_neurons()
        self.fc_layer = nn.Sequential(
            nn.Linear(self.neurons, hidden_dims[0]),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout_rate),
            nn.Linear(hidden_dims[0],output_dim)
        )

    def forward(self, x):
        """Run the conv block, flatten per sample, then apply the dense head."""
        x = self.conv_layer(x)
        x = x.view(x.size(0), -1)
        x = self.fc_layer(x)
        return x

    def forward_conv(self, x):
        """Return the flattened output of the conv block only."""
        x = self.conv_layer(x)
        x = x.view(x.size(0), -1)
        return x

    def size_after_relu(self, x):
        """Return the size of the conv block's output for input ``x``."""
        x = self.conv_layer(x)
        return x.size()

    def linear_input_neurons(self):
        """Measure the flattened conv-block output size with a dummy window.

        The probe has shape ``(1, input_channel, segment_size)``. It previously
        used a fixed ``(1, 6, 128)``, which failed for any other channel count.
        ``torch.rand`` is kept so the random-number stream is consumed exactly
        as before.
        """
        size = self.size_after_relu(torch.rand(1, self.input_channel, segment_size))
        m = 1
        for i in size:
            m *= i
        return int(m)

    
class SimpleCNNContainerConvBlocks(nn.Module):
    """Only the convolutional block (Conv1d -> ReLU -> MaxPool1d).

    ``output_dim`` is accepted but not used by this module.
    """

    def __init__(self, input_channel, num_filters, kernel_size, output_dim=activityCount):
        super().__init__()
        first_layer_filters = num_filters[0]
        self.conv_layer = nn.Sequential(
            nn.Conv1d(input_channel, first_layer_filters, kernel_size),
            nn.ReLU(inplace=True),
            nn.MaxPool1d(4),
        )

    def forward(self, x):
        """Return the un-flattened output of the conv block."""
        return self.conv_layer(x)

In [None]:
# Per-client train/test splits; the data-loading cell below fills these.
clientDataTrain, clientLabelTrain = [], []
clientDataTest, clientLabelTest = [], []

# Pooled ("central") train and test sets; rebound by the data-loading cell.
centralTrainData, centralTrainLabel = [], []
centralTestData, centralTestLabel = [], []


In [None]:
# Load the selected dataset and build both per-client and pooled ("central") train/test splits.
if(dataSetName == "UCI"):
    def load_file(filepath):
        """Read a header-less CSV file and return its values as a NumPy array."""
        dataframe = pd.read_csv(filepath, header=None)
        return dataframe.values


    def load_group(filenames, prefix=''):
        """Load every file in `filenames` (each prefixed with `prefix`) and stack them depth-wise."""
        loaded = list()
        for name in filenames:
            data = load_file(prefix + name)
            loaded.append(data)
        loaded = np.dstack(loaded)
        return loaded


    def load_dataset(group, prefix=''):
        """Load the six sensor-axis CSVs and the label CSV for one group ('train' or 'eval').

        Returns (X, y), where X is the depth-stacked sensor data and y the raw label array.
        """
        filepath = mainDir + 'datasetStandardized/'+prefix + '/' + group + '/'
        filenames = list()
        filenames += ['AccX'+prefix+'.csv', 'AccY' +
                      prefix+'.csv', 'AccZ'+prefix+'.csv']
        filenames += ['GyroX'+prefix+'.csv', 'GyroY' +
                      prefix+'.csv', 'GyroZ'+prefix+'.csv']
        X = load_group(filenames, filepath)
        y = load_file(mainDir + 'datasetStandardized/'+prefix +
                      '/' + group + '/Label'+prefix+'.csv')
        return X, y
    # Merge the provided train and eval groups before re-splitting.
    trainData, trainLabel = load_dataset('train', dataSetName)
    evalData, evalLabel = load_dataset('eval', dataSetName)
    allData = np.float32(np.vstack((trainData, evalData)))
    allLabel = np.vstack((trainLabel, evalLabel))

    # Split the data 80/20: five stratified folds; the test indices of the first
    # four folds form the training set and those of the fifth the test set.
    skf = StratifiedKFold(n_splits=5,shuffle = True)
    skf.get_n_splits(allData, allLabel)
    partitionedData = list()
    partitionedLabel = list()
    for train_index, test_index in skf.split(allData, allLabel):
        partitionedData.append(allData[test_index])
        partitionedLabel.append(allLabel[test_index])

    centralTrainData = np.vstack((partitionedData[:4]))
    centralTrainLabel = np.vstack((partitionedLabel[:4]))
    centralTestData = partitionedData[4]
    centralTestLabel = partitionedLabel[4]

    trainData = list()
    trainLabel = list()
    testData = list()
    testLabel = list()

    # Give each client one fold of the central training set.
    # NOTE(review): reshape(-1,6,128) reinterprets memory and does not transpose; if the
    # stacked data is (N, 128, 6) this mixes the time and channel axes — confirm the layout.
    if(dataConfig == "BALANCED"):
        # Stratified folds keep the class proportions similar across clients.
        skf = StratifiedKFold(n_splits=clientCount)
        skf.get_n_splits(centralTrainData, centralTrainLabel)
        for train_index, test_index in skf.split(centralTrainData, centralTrainLabel):
            trainData.append(np.asarray(centralTrainData[test_index]).reshape(-1,6,128))
            trainLabel.append(centralTrainLabel[test_index].ravel().astype(int))
    else:
    # Unbalanced: plain shuffled K-fold, no stratification.
        kf = KFold(n_splits=clientCount, shuffle=True)
        kf.get_n_splits(centralTrainData)
        for train_index, test_index in kf.split(centralTrainData):
            trainData.append(np.asarray(centralTrainData[test_index]).reshape(-1,6,128))
            trainLabel.append(centralTrainLabel[test_index].ravel().astype(int))

    # Split the test set into one fold per client.
    # NOTE(review): in the unbalanced branch `skf` is still the 5-way shuffled splitter
    # created above rather than one built with clientCount — confirm this is intended.
    # NOTE(review): each fold is truncated to 411 samples; presumably to equalise fold sizes — confirm.
    skf.get_n_splits(centralTestData, centralTestLabel)
    for train_index, test_index in skf.split(centralTestData, centralTestLabel):
        testData.append(np.asarray(centralTestData[test_index]).reshape(-1,6,128)[:411])
        testLabel.append(centralTestLabel[test_index].ravel().astype(int)[:411])

    # NOTE(review): np.vstack merges the per-client arrays into one array, while the labels
    # stay a per-client list; get_dataloader indexes clientDataTrain[clientNum] — verify.
    clientDataTrain = np.vstack(trainData) 
    clientLabelTrain = trainLabel
    clientDataTest = np.vstack(testData)
    clientLabelTest = testLabel
    
    centralTrainData = np.float32(np.vstack(clientDataTrain))
    centralTrainLabel = np.vstack(clientLabelTrain).ravel()
    centralTestData = np.float32(np.vstack(clientDataTest))
    centralTestLabel = np.vstack(clientLabelTest).ravel()
    
else:
    clientData = []
    clientLabel = []

    # Every non-UCI setting is treated as REALWORLD_CLIENT; the name is overwritten here.
    dataSetName = 'REALWORLD_CLIENT'
    # One directory per client, 15 in total (hard-coded); each holds six sensor axes plus labels.
    for i in range(0,15):
        accX = hkl.load('datasetStandardized/'+dataSetName+'/'+str(i)+'/AccX'+dataSetName+'.hkl')
        accY = hkl.load('datasetStandardized/'+dataSetName+'/'+str(i)+'/AccY'+dataSetName+'.hkl')
        accZ = hkl.load('datasetStandardized/'+dataSetName+'/'+str(i)+'/AccZ'+dataSetName+'.hkl')
        gyroX = hkl.load('datasetStandardized/'+dataSetName+'/'+str(i)+'/GyroX'+dataSetName+'.hkl')
        gyroY = hkl.load('datasetStandardized/'+dataSetName+'/'+str(i)+'/GyroY'+dataSetName+'.hkl')
        gyroZ = hkl.load('datasetStandardized/'+dataSetName+'/'+str(i)+'/GyroZ'+dataSetName+'.hkl')
        label = hkl.load('datasetStandardized/'+dataSetName+'/'+str(i)+'/Label'+dataSetName+'.hkl')
        clientData.append(np.dstack((accX,accY,accZ,gyroX,gyroY,gyroZ)))
        clientLabel.append(label)
    
    # Per client: five stratified folds; four become the training set, the fifth the test set.
    # NOTE(review): when dataConfig is not "BALANCED" the client lists stay empty, so the
    # np.vstack calls below would receive an empty list — confirm that case is unused.
    # NOTE(review): assumes each loaded axis is (N, 128), so dstack gives (N, 128, 6) and
    # reshape(-1,6,128) reinterprets memory instead of transposing — TODO confirm.
    if(dataConfig == "BALANCED"):
        for i in range (0,int(args.n_nets)):
            skf = StratifiedKFold(n_splits=5, shuffle=True)
            skf.get_n_splits(clientData[i], clientLabel[i])
            partitionedData = list()
            partitionedLabel = list()    
            for train_index, test_index in skf.split(clientData[i], clientLabel[i]):
                partitionedData.append(clientData[i][test_index])
                partitionedLabel.append(clientLabel[i][test_index])
            clientDataTrain.append(np.float32(np.vstack(partitionedData[:4])).reshape(-1,6,128))
            clientLabelTrain.append((np.hstack((partitionedLabel[:4]))))
            clientDataTest.append(np.float32(partitionedData[4]).reshape(-1,6,128))
            clientLabelTest.append((partitionedLabel[4]))
            
    # Pool every client's split into the central train and test sets.
    centralTrainData = np.float32(np.vstack(clientDataTrain))
    centralTrainLabel = (np.hstack(clientLabelTrain))

    centralTestData = np.float32(np.vstack((clientDataTest)))
    centralTestLabel = (np.hstack(clientLabelTest))

In [None]:
# For each client: per-class sample counts, and the index range its samples occupy
# in the client-ordered concatenation of the training data.
traindata_cls_counts = {}
net_dataidx_map = {}
y_train = []

startingIndex = 0
endingIndex = 0
for i in range(args.n_nets):
    # This client's samples occupy [startingIndex, endingIndex).
    startingIndex, endingIndex = endingIndex, endingIndex + clientDataTrain[i].shape[0]
    net_dataidx_map[i] = list(range(startingIndex, endingIndex))

    label_series = pd.Series(clientLabelTrain[i])
    traindata_cls_counts[i] = label_series.value_counts().sort_index().to_dict()
    y_train.append(clientLabelTrain[i])

# Flatten the per-client label arrays into one label vector.
y_train = np.hstack(y_train)


In [None]:
def get_dataloader(dataset, datadir, train_bs, test_bs, dataidxs=None,clientNum=100):
    """Build the train and test DataLoaders for one client or for the pooled data.

    Args:
        dataset: dataset tag; only 'mnist' and 'cifar10' are accepted.
        datadir: passed through to the dataset wrapper as its root.
        train_bs: batch size of the (shuffled) training loader.
        test_bs: batch size of the (unshuffled) test loader.
        dataidxs: accepted for interface compatibility; not used, because the
            per-client arrays are already split.
        clientNum: index into the per-client arrays; the sentinel value 100
            selects the pooled central train/test arrays instead.

    Returns:
        (train_dl, test_dl)

    Raises:
        ValueError: if ``dataset`` is not supported. Previously this case fell
            through to an UnboundLocalError at the return statement.
    """
    if dataset not in ('mnist', 'cifar10'):
        raise ValueError("Unsupported dataset: {!r}".format(dataset))

    if clientNum == 100:
        train_x, train_y = centralTrainData, centralTrainLabel
        test_x, test_y = centralTestData, centralTestLabel
    else:
        train_x, train_y = clientDataTrain[clientNum], clientLabelTrain[clientNum]
        test_x, test_y = clientDataTest[clientNum], clientLabelTest[clientNum]

    train_ds = CIFAR10_truncated(datadir, train_x, train_y, train=True, download=False)
    test_ds = CIFAR10_truncated(datadir, test_x, test_y, train=False, download=False)

    train_dl = data.DataLoader(
        dataset=train_ds, batch_size=train_bs, shuffle=True)
    test_dl = data.DataLoader(
        dataset=test_ds, batch_size=test_bs, shuffle=False)
    return train_dl, test_dl


In [None]:
def train_net(net_id, net, train_dataloader, test_dataloader, epochs, lr, args, device="cpu"):
    """Train ``net`` with SGD (momentum 0.9) and return its final accuracies.

    Accuracy on both loaders is logged before and after training, and the mean
    batch loss is logged after every epoch. When ``args.dataset`` is "cinic10",
    weight decay and a per-epoch StepLR schedule (gamma 0.95) are also used.

    Returns:
        (train_acc, test_acc) measured after the last epoch.
    """
    logger.info('Training network %s' % str(net_id))
    logger.info('n_training: %d' % len(train_dataloader))
    logger.info('n_test: %d' % len(test_dataloader))

    # Baseline accuracy before any update.
    train_acc = compute_accuracy(net, train_dataloader, device=device)
    test_acc, _ = compute_accuracy(
        net, test_dataloader, get_confusion_matrix=True, device=device)
    logger.info('>> Pre-Training Training accuracy: {}'.format(train_acc))
    logger.info('>> Pre-Training Test accuracy: {}'.format(test_acc))

    lr_schedule = None
    if args.dataset == "cinic10":
        sgd = optim.SGD(net.parameters(), lr=lr,
                        momentum=0.9, weight_decay=0.0001)
        lr_schedule = torch.optim.lr_scheduler.StepLR(
            sgd, step_size=1, gamma=0.95)
    else:
        sgd = optim.SGD(net.parameters(), lr=lr, momentum=0.9)

    criterion = nn.CrossEntropyLoss().to(device)

    for epoch in range(epochs):
        batch_losses = []
        for x, target in train_dataloader:
            x, target = x.to(device), target.to(device)

            sgd.zero_grad()
            x.requires_grad = True
            target.requires_grad = False
            target = target.long()

            loss = criterion(net(x), target)
            loss.backward()
            sgd.step()

            batch_losses.append(loss.item())

        epoch_loss = sum(batch_losses) / len(batch_losses)
        logger.info('Epoch: %d Loss: %f' % (epoch, epoch_loss))

        if lr_schedule is not None:
            lr_schedule.step()

    # Accuracy after training.
    train_acc = compute_accuracy(net, train_dataloader, device=device)
    test_acc, _ = compute_accuracy(
        net, test_dataloader, get_confusion_matrix=True, device=device)

    logger.info('>> Training accuracy: %f' % train_acc)
    logger.info('>> Test accuracy: %f' % test_acc)
    logger.info(' ** Training complete **')
    return train_acc, test_acc



In [None]:
def local_train(nets, args, net_dataidx_map, device="cpu"):
    """Train every network on its own client's data when ``args.retrain`` is set.

    Args:
        nets: mapping from net id to model.
        args: parsed options (reads retrain, dataset, batch_size, epochs, lr).
        net_dataidx_map: mapping from net id to that client's sample indices.
        device: device each model is moved to before training.

    Returns:
        The models as a list, in the iteration order of ``nets``.
    """
    local_datasets = []
    for net_id, net in nets.items():
        if not args.retrain:
            continue

        dataidxs = net_dataidx_map[net_id]
        net.to(device)

        # The test loader uses a fixed batch size of 32.
        train_dl_local, test_dl_local = get_dataloader(
            args.dataset, args_datadir, args.batch_size, 32, dataidxs, net_id)
        local_datasets.append((train_dl_local, test_dl_local))

        train_net(net_id, net, train_dl_local, test_dl_local,
                  args.epochs, args.lr, args, device=device)

    return list(nets.values())

In [None]:
# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
# On a CUDA machine the log line below reports a CUDA device.
logger.info("torch.cuda.is_available: {} device: {}".format(torch.cuda.is_available(),device))

# Seed NumPy and PyTorch with 0 from this point on.
seed = 0

torch.manual_seed(seed)
np.random.seed(seed)

In [None]:
def init_models( n_nets, args):
    '''
    Initialize the local networks, one per worker.

    Args:
        n_nets: number of networks to create.
        args: parsed options; ``args.model`` selects the architecture and, for
            "simple-cnn", ``args.dataset`` must be "cifar10" or "cinic10".

    Returns:
        (cnns, model_meta_data, layer_type): a dict from net index to model, the
        shapes of the first model's state-dict tensors, and the matching
        state-dict keys. Both lists are empty when ``n_nets`` is 0.

    Raises:
        ValueError: for an unsupported model/dataset combination. Previously
            this left ``cnn`` unbound, or silently reused the previous model.
    '''
    cnns = {}

    for cnn_i in range(n_nets):
        # NOTE(review): LeNet and vgg11 are not defined in the visible cells —
        # confirm they exist before selecting those models.
        if args.model == "lenet":
            cnn = LeNet()
        elif args.model == "vgg":
            cnn = vgg11()
        elif args.model == "simple-cnn" and args.dataset in ("cifar10", "cinic10"):
            cnn = SimpleCNN(input_dim=(92 * kernelSize),
                            hidden_dims=[120, 84], output_dim=activityCount)
        else:
            raise ValueError(
                "Unsupported model/dataset combination: model={!r}, dataset={!r}".format(
                    args.model, getattr(args, "dataset", None)))
        cnns[cnn_i] = cnn

    # Book-keeping: tensor shapes and names of the first model's weights.
    model_meta_data = []
    layer_type = []
    if cnns:
        for (k, v) in cnns[0].state_dict().items():
            model_meta_data.append(v.shape)
            layer_type.append(k)
    return cnns, model_meta_data, layer_type

In [None]:
# Store every per-class count as np.int64.
for dicts, per_class in traindata_cls_counts.items():
    for keys, count in per_class.items():
        per_class[keys] = np.int64(count)

In [None]:
# Number of distinct labels, and a zeroed (worker x class) weight table.
n_classes = np.unique(y_train).size
averaging_weights = np.zeros(shape=(args.n_nets, n_classes), dtype=np.float32)

In [None]:
# Coefficient per class: each worker's share of the samples of class i.
for i in range(n_classes):
    # Count of class i on every worker (0 when the worker has none).
    worker_class_counts = np.array(
        [traindata_cls_counts[j].get(i, 0) for j in range(args.n_nets)],
        dtype=np.float64)
    total_num_counts = worker_class_counts.sum()
    if total_num_counts > 0:
        averaging_weights[:, i] = worker_class_counts / total_num_counts
    else:
        # No worker holds this class id (e.g. non-contiguous label ids). Keep the
        # column at zero; the old `list / 0` raised a TypeError here.
        averaging_weights[:, i] = 0.0
logger.info("averaging_weights: {}".format(averaging_weights))

In [None]:
def roundNumber(toRoundNb):
    """Return the mean of ``toRoundNb`` rounded to four decimal places."""
    average = np.mean(toRoundNb)
    return round(average, 4)

In [None]:
class CIFAR10_truncated(data.Dataset):
    """Dataset over in-memory sample/label arrays, optionally restricted to ``dataidxs``.

    ``root``, ``train``, ``transform``, ``target_transform`` and ``download``
    are only stored on the instance; no transform is applied in ``__getitem__``.
    """

    def __init__(self, root,trainData,trainLabel, dataidxs=None, train=True, transform=None, target_transform=None, download=False):
        self.root, self.dataidxs, self.train = root, dataidxs, train
        self.transform, self.target_transform = transform, target_transform
        self.download = download

        if dataidxs is None:
            # No subset requested: keep the full arrays.
            self.data, self.target = trainData, trainLabel
        else:
            self.data = trainData[dataidxs]
            self.target = trainLabel[dataidxs]

    def __getitem__(self, index):
        """Return the ``(sample, label)`` pair at ``index``."""
        return self.data[index], self.target[index]

    def __len__(self):
        """Number of samples held by the dataset."""
        return len(self.data)


In [None]:
def compute_accuracy(model, dataloader, get_confusion_matrix=False, device="cpu"):
    """Evaluate ``model`` on ``dataloader`` and return its accuracy.

    The model is switched to eval mode for the evaluation and put back into
    train mode afterwards if it was training, even when the evaluation raises.

    Args:
        model: network to evaluate.
        dataloader: iterable of ``(x, target)`` batches.
        get_confusion_matrix: when True, also return the confusion matrix.
        device: device the batches are moved to.

    Returns:
        ``accuracy`` or, if ``get_confusion_matrix`` is set,
        ``(accuracy, conf_matrix)``.

    Raises:
        ZeroDivisionError: if the dataloader yields no samples.
    """
    was_training = model.training
    if was_training:
        model.eval()

    # Per-batch label arrays, joined once at the end (np.append in the loop
    # re-copied the whole history on every batch).
    true_batches, pred_batches = [], []
    correct, total = 0, 0
    conf_matrix = None
    try:
        with torch.no_grad():
            for x, target in dataloader:
                x, target = x.to(device), target.to(device)
                out = model(x)
                _, pred_label = torch.max(out.data, 1)

                total += x.data.size()[0]
                correct += (pred_label == target.data).sum().item()

                # .cpu() is a no-op for CPU tensors, so one code path serves every device.
                pred_batches.append(pred_label.cpu().numpy())
                true_batches.append(target.data.cpu().numpy())

        if get_confusion_matrix:
            pred_labels_list = np.concatenate(pred_batches) if pred_batches else np.array([])
            true_labels_list = np.concatenate(true_batches) if true_batches else np.array([])
            conf_matrix = confusion_matrix(true_labels_list, pred_labels_list)
    finally:
        if was_training:
            model.train()

    if get_confusion_matrix:
        return correct/float(total), conf_matrix

    return correct/float(total)


In [None]:
def normalize_weights(weights):
    """Divide each weight tensor by the element-wise sum over all of them.

    A small epsilon (1e-6) is added to the sum before dividing. The result is a
    new dict with the same keys as ``weights``.
    """
    eps = 1e-6

    # Element-wise total of all tensors, accumulated in NumPy.
    total = np.array([])
    for tensor in weights.values():
        as_array = tensor.data.numpy()
        total = as_array if len(total) == 0 else total + as_array

    denominator = torch.from_numpy(total + eps)
    return {mi: tensor / denominator for mi, tensor in weights.items()}

In [None]:
def get_weighted_average_pred(models: list, weights: dict, x, device="cpu"):
    """Sum the softmax outputs of ``models`` on ``x``, each scaled by its weight.

    ``weights[i]`` is moved to ``device`` and multiplied with the softmax output
    of ``models[i]``. Returns ``None`` when ``models`` is empty.
    """
    combined = None
    for idx, net in enumerate(models):
        probs = F.softmax(net(x), dim=-1)  # (N, C)
        scaled = probs * weights[idx].to(device)
        if combined is None:
            combined = scaled
        else:
            combined += scaled
    return combined

In [None]:
def compute_ensemble_accuracy(models: list, dataloader, n_classes, train_cls_counts=None, uniform_weights=False, sanity_weights=False, device="cpu"):
    """Evaluate a weighted ensemble of ``models`` on ``dataloader``.

    The weights come from ``prepare_uniform_weights``, ``prepare_sanity_weights``
    or ``prepare_weight_matrix`` (chosen by the flags, in that order of
    precedence) and are normalised with ``normalize_weights``. Models that were
    in train mode are switched to eval for the evaluation and restored
    afterwards, even when the evaluation raises.

    Returns:
        ``(accuracy, conf_matrix)``.

    Raises:
        ZeroDivisionError: if the dataloader yields no samples.
    """
    correct, total = 0, 0
    # Per-batch label arrays, joined once at the end (np.append in the loop
    # re-copied the whole history on every batch).
    true_batches, pred_batches = [], []

    was_training = [False]*len(models)
    for i, model in enumerate(models):
        if model.training:
            was_training[i] = True
            model.eval()

    try:
        if uniform_weights is True:
            weights_list = prepare_uniform_weights(n_classes, len(models))
        elif sanity_weights is True:
            weights_list = prepare_sanity_weights(n_classes, len(models))
        else:
            weights_list = prepare_weight_matrix(n_classes, train_cls_counts)
        weights_norm = normalize_weights(weights_list)

        with torch.no_grad():
            for x, target in dataloader:
                x, target = x.to(device), target.to(device)
                target = target.long()
                out = get_weighted_average_pred(
                    models, weights_norm, x, device=device)

                _, pred_label = torch.max(out, 1)

                total += x.data.size()[0]
                correct += (pred_label == target.data).sum().item()

                # .cpu() is a no-op for CPU tensors, so one code path serves every device.
                pred_batches.append(pred_label.cpu().numpy())
                true_batches.append(target.data.cpu().numpy())

        pred_labels_list = np.concatenate(pred_batches) if pred_batches else np.array([])
        true_labels_list = np.concatenate(true_batches) if true_batches else np.array([])
        conf_matrix = confusion_matrix(true_labels_list, pred_labels_list)
    finally:
        for i, model in enumerate(models):
            if was_training[i]:
                model.train()

    return correct / float(total), conf_matrix

In [None]:
def row_param_cost(global_weights, weights_j_l, global_sigmas, sigma_inv_j):
    """Return, per row of ``global_weights``, the matching cost against ``weights_j_l``.

    Computes ``sum((w + G)^2 / (s + S), axis=1) - sum(G^2 / S, axis=1)`` where
    ``G``/``S`` are the global weights/sigmas and ``w``/``s`` the local ones.
    """
    combined_sq = (weights_j_l + global_weights) ** 2
    matched_term = (combined_sq / (sigma_inv_j + global_sigmas)).sum(axis=1)
    global_term = (global_weights ** 2 / global_sigmas).sum(axis=1)
    return matched_term - global_term



In [None]:
def process_softmax_bias(batch_weights, last_layer_const, sigma, sigma0):
    """Combine the last entry of every worker's weight list into one bias term.

    Each worker's last weight entry is scaled by ``last_layer_const[j] / sigma``
    and summed; a prior contribution ``0.1 / sigma0`` is then added.

    Returns:
        ``(softmax_bias, softmax_inv_sigma)`` where ``softmax_inv_sigma`` is
        ``1 / sigma0`` plus the sum of the per-worker scale factors.
    """
    mu0_bias = 0.1
    sigma_bias, sigma0_bias = sigma, sigma0

    last_entries = [batch_weights[j][-1] for j in range(len(batch_weights))]
    scales = [const / sigma_bias for const in last_layer_const]

    weighted = [entry * scale for entry, scale in zip(last_entries, scales)]
    softmax_bias = sum(weighted) + mu0_bias / sigma0_bias
    softmax_inv_sigma = 1 / sigma0_bias + sum(scales)
    return softmax_bias, softmax_inv_sigma

In [None]:
def row_param_cost_simplified(global_weights, weights_j_l, sij_p_gs, red_term):
    """Variant of ``row_param_cost`` that takes two precomputed terms.

    ``sij_p_gs`` is the denominator and ``red_term`` the value subtracted from
    every row sum; both are supplied by the caller.
    """
    combined_sq = (weights_j_l + global_weights) ** 2
    row_sums = (combined_sq / sij_p_gs).sum(axis=1)
    return row_sums - red_term

In [None]:
def compute_cost(global_weights, weights_j, global_sigmas, sigma_inv_j, prior_mean_norm, prior_inv_sigma,
                 popularity_counts, gamma, J):
    """Build the full cost matrix between local and global neurons.

    The result stacks two parts column-wise: the cost against each of the ``L``
    existing global neurons (``Lj x L``) and the cost of adding a new global
    neuron (``Lj x max_added``, with ``max_added = min(Lj, max(700 - L, 1))``).

    Returns:
        float32 array of shape ``(Lj, L + max_added)``.
    """
    Lj = weights_j.shape[0]
    L = global_weights.shape[0]

    ## Parametric cost: one row per local neuron against every global neuron.
    # Popularity counts are capped at 10 before they enter the log-odds term.
    capped_counts = np.minimum(np.array(popularity_counts, dtype=np.float32), 10)
    sij_p_gs = sigma_inv_j + global_sigmas
    red_term = (global_weights ** 2 / global_sigmas).sum(axis=1)

    cost_rows = [
        row_param_cost_simplified(global_weights, weights_j[l], sij_p_gs, red_term)
        for l in range(Lj)
    ]
    param_cost = np.array(cost_rows, dtype=np.float32)
    param_cost += np.log(capped_counts / (J - capped_counts))

    ## Nonparametric cost: columns for opening new global neurons.
    max_added = min(Lj, max(700 - L, 1))
    prior_term = (prior_mean_norm ** 2 / prior_inv_sigma).sum()
    new_neuron_score = ((weights_j + prior_mean_norm) ** 2 / (prior_inv_sigma + sigma_inv_j)).sum(axis=1) - prior_term
    nonparam_cost = np.outer(new_neuron_score, np.ones(max_added, dtype=np.float32))
    nonparam_cost -= 2 * np.log(np.arange(1, max_added + 1))
    nonparam_cost += 2 * np.log(gamma / J)

    return np.hstack((param_cost, nonparam_cost)).astype(np.float32)



In [None]:
def reconstruct_local_net(weights, args, ori_assignments=None, worker_index=0):
    """Slice a matched weight list back into the layout of one worker.

    The function walks over the parameters of a ``SimpleCNNContainer`` built
    from the shapes found in ``weights`` and, for every parameter, selects the
    rows/columns given by the matching assignments of ``worker_index``.

    Args:
        weights: list of arrays, indexed in the same order as the state-dict
            entries of the container (weight, bias, weight, bias, ...).
            Presumably the matched global weights - confirm against the caller.
        args: object providing ``model`` and ``dataset`` attributes.
        ori_assignments: indexed as ``ori_assignments[layer][worker_index]``;
            each entry is used as a NumPy index into the weights of that layer.
            Although the default is ``None``, the body indexes it
            unconditionally, so a real value must be supplied.
        worker_index: which worker's assignment to apply.

    Returns:
        List of arrays, one per entry of the container's state dict.

    NOTE(review): ``matched_cnn`` and ``estimated_output`` are only created
    when ``args.model == "simple-cnn"``, and ``input_channel`` only when
    ``args.dataset`` is "cifar10" or "cinic10"; any other combination fails
    with an unbound-name error further down.
    """
    if args.model == "simple-cnn":
        if args.dataset in ("cifar10", "cinic10"):
            input_channel = 6

        # Layer sizes are read off the matched weights themselves.
        num_filters = [weights[0].shape[0]]
        input_dim = weights[2].shape[0]
        hidden_dims = [weights[2].shape[1]]
        matched_cnn = SimpleCNNContainer(input_channel=input_channel,
                                         num_filters=num_filters,
                                         kernel_size=kernelSize,
                                         input_dim=input_dim,
                                         hidden_dims=hidden_dims,
                                         output_dim=activityCount)
        # Run a dummy (1, 6, 128) input through the conv blocks to obtain the
        # shape of the conv output; `estimated_output` is used below for the
        # first fc layer.
        shape_estimator = SimpleCNNContainerConvBlocks(input_channel=6, num_filters=num_filters, kernel_size=kernelSize,output_dim=activityCount)
        dummy_input = torch.rand(1,6,128)
        estimated_output = shape_estimator(dummy_input)
        # NOTE(review): this reassigned `input_dim` is not read again in this function.
        input_dim = estimated_output.view(-1).size()[0]

    def __reconstruct_weights(weight, assignment, layer_ori_shape, matched_num_filters=None, weight_type="conv_weight", slice_dim="filter"):
        """Select the entries of ``weight`` listed in ``assignment``.

        ``weight_type`` / ``slice_dim`` choose the slicing rule:
        - "conv_weight" / "filter": rows of ``weight``.
        - "conv_weight" / "channel": rows of the array returned by
          ``trans_next_conv_layer_forward``, converted back with
          ``trans_next_conv_layer_backward`` (both defined elsewhere).
        - "bias": entries of the 1-D ``weight``.
        - "first_fc_weight": ``weight`` is reshaped to one row per matched
          filter, rows are selected, and the result is reshaped back to 2-D.
        - "fc_weight" / "filter": rows of ``weight.T`` (the result stays
          transposed; callers transpose it back).
        - "fc_weight" / "channel": rows of ``weight``.
        """
        if weight_type == "conv_weight":
            if slice_dim == "filter":
                res_weight = weight[assignment, :]
            elif slice_dim == "channel":
                # Same shape as the original layer except that dimension 1 is
                # replaced by the number of matched filters.
                _ori_matched_shape = list(copy.deepcopy(layer_ori_shape))
                _ori_matched_shape[1] = matched_num_filters
                trans_weight = trans_next_conv_layer_forward(
                    weight, _ori_matched_shape)
                logger.info("trans_weight{} assignment {} layer_ori_shape {}".format(
                    np.asarray(trans_weight).shape, np.asarray(assignment).shape, layer_ori_shape))
                sliced_weight = trans_weight[assignment, :]
                res_weight = trans_next_conv_layer_backward(
                    sliced_weight, layer_ori_shape)
        elif weight_type == "bias":
            res_weight = weight[assignment]
        elif weight_type == "first_fc_weight":
            # NOTE: in this case the caller passes the estimated conv-output
            # shape as `layer_ori_shape`, so `layer_ori_shape[2]` is its last
            # dimension.
            __ori_shape = weight.shape
            res_weight = weight.reshape(
                matched_num_filters, layer_ori_shape[2]*__ori_shape[1])[assignment, :]
            res_weight = res_weight.reshape(
                (len(assignment)*layer_ori_shape[2], __ori_shape[1]))
        elif weight_type == "fc_weight":
            if slice_dim == "filter":
                # The result is left transposed on purpose; callers append `.T` of it.
                res_weight = weight.T[assignment, :]
            elif slice_dim == "channel":
                res_weight = weight[assignment, :]
        return res_weight

    reconstructed_weights = []
    # The container's state dict is only used for its keys and parameter
    # shapes; the values come from `weights`.
    for param_idx, (key_name, param) in enumerate(matched_cnn.state_dict().items()):
        _matched_weight = weights[param_idx]
        if param_idx < 1:  # first parameter: only its rows (filters) are sliced
            _assignment = ori_assignments[int(param_idx / 2)][worker_index]
            _res_weight = __reconstruct_weights(weight=_matched_weight, assignment=_assignment,
                                                layer_ori_shape=param.size(), matched_num_filters=None,
                                                weight_type="conv_weight", slice_dim="filter")
            reconstructed_weights.append(_res_weight)
        elif (param_idx >= 1) and (param_idx < len(weights) - 2):
            if "bias" in key_name:  # the very last bias is handled by the final branch below
                _assignment = ori_assignments[int(param_idx / 2)][worker_index]
                _res_bias = __reconstruct_weights(weight=_matched_weight, assignment=_assignment,
                                                  layer_ori_shape=param.size(), matched_num_filters=None,
                                                  weight_type="bias", slice_dim=None)
                reconstructed_weights.append(_res_bias)

            elif "conv" in key_name or "features" in key_name:
                # These weights are sliced twice: first along the input
                # channels (previous layer's assignment), then along the
                # filters (this layer's assignment).
                cur_assignment = ori_assignments[int(
                    param_idx / 2)][worker_index]
                prev_assignment = ori_assignments[int(
                    param_idx / 2)-1][worker_index]
                _matched_num_filters = weights[param_idx - 2].shape[0]
                _layer_ori_shape = list(param.size())
                _layer_ori_shape[0] = _matched_weight.shape[0]

                _temp_res_weight = __reconstruct_weights(weight=_matched_weight, assignment=prev_assignment,
                                                         layer_ori_shape=_layer_ori_shape, matched_num_filters=_matched_num_filters,
                                                         weight_type="conv_weight", slice_dim="channel")

                _res_weight = __reconstruct_weights(weight=_temp_res_weight, assignment=cur_assignment,
                                                    layer_ori_shape=param.size(), matched_num_filters=None,
                                                    weight_type="conv_weight", slice_dim="filter")
                reconstructed_weights.append(_res_weight)

            elif "fc" in key_name or "classifier" in key_name:
                # Same two-step slicing as for conv weights: inputs by the
                # previous layer's assignment, outputs by this layer's.
                cur_assignment = ori_assignments[int(
                    param_idx / 2)][worker_index]
                prev_assignment = ori_assignments[int(
                    param_idx / 2)-1][worker_index]
                _matched_num_filters = weights[param_idx - 2].shape[0]

                if param_idx != 2:  # index 2 is treated as the first fc layer (see else branch)
                    temp_res_weight = __reconstruct_weights(weight=_matched_weight, assignment=prev_assignment,
                                                            layer_ori_shape=param.size(), matched_num_filters=_matched_num_filters,
                                                            weight_type="fc_weight", slice_dim="channel")

                    _res_weight = __reconstruct_weights(weight=temp_res_weight, assignment=cur_assignment,
                                                        layer_ori_shape=param.size(), matched_num_filters=None,
                                                        weight_type="fc_weight", slice_dim="filter")

                    reconstructed_weights.append(_res_weight.T)
                else:
                    # First fc layer, fed by the conv blocks: its inputs are
                    # sliced block-wise using the estimated conv-output shape.
                    temp_res_weight = __reconstruct_weights(weight=_matched_weight, assignment=prev_assignment,
                                                            layer_ori_shape=estimated_output.size(), matched_num_filters=_matched_num_filters,
                                                            weight_type="first_fc_weight", slice_dim=None)

                    _res_weight = __reconstruct_weights(weight=temp_res_weight, assignment=cur_assignment,
                                                        layer_ori_shape=param.size(), matched_num_filters=None,
                                                        weight_type="fc_weight", slice_dim="filter")
                    reconstructed_weights.append(_res_weight.T)
        elif param_idx == len(weights) - 2:
            # Weight of the last layer: only its rows are sliced, using the
            # previous layer's assignment.
            prev_assignment = ori_assignments[int(
                param_idx / 2)-1][worker_index]
            _res_weight = _matched_weight[prev_assignment, :]
            reconstructed_weights.append(_res_weight)
        elif param_idx == len(weights) - 1:
            # Bias of the last layer is copied unchanged.
            reconstructed_weights.append(_matched_weight)

    return reconstructed_weights

In [None]:
def matching_upd_j(weights_j, global_weights, sigma_inv_j, global_sigmas, prior_mean_norm, prior_inv_sigma,
                   popularity_counts, gamma, J):
    """Assign the neurons of one client to global neurons and update the global state.

    A cost matrix is built with ``compute_cost`` (on ``float32`` copies of the
    inputs) and the assignment problem is solved on its negation with
    ``lapsolver.solve_dense``. Columns ``< L`` stand for existing global
    neurons, columns ``>= L`` for newly created ones.

    Side effects:
        ``popularity_counts`` is updated in place (incremented for matched
        neurons, extended with ``1`` for each new neuron). Rows of the
        ``global_weights`` / ``global_sigmas`` arrays that are current at the
        time of a match are updated in place; once a new neuron has been
        appended these names refer to fresh arrays, so callers must use the
        returned values.

    Returns:
        Tuple ``(global_weights, global_sigmas, popularity_counts, assignment_j)``
        where ``assignment_j`` lists, in the order returned by the solver, the
        global neuron index chosen for each local neuron.
    """
    L = global_weights.shape[0]

    full_cost = compute_cost(global_weights.astype(np.float32), weights_j.astype(np.float32),
                             global_sigmas.astype(np.float32), sigma_inv_j.astype(np.float32),
                             prior_mean_norm.astype(np.float32), prior_inv_sigma.astype(np.float32),
                             popularity_counts, gamma, J)

    # The solver is given the negated costs. The original author noted that
    # lapsolver does not run on non-Linux systems and kept
    # scipy's `linear_sum_assignment(-full_cost)` as the commented-out alternative.
    row_ind, col_ind = solve_dense(-full_cost)

    assignment_j = []
    new_L = L

    for l, i in zip(row_ind, col_ind):
        if i < L:
            # Matched to an existing global neuron: accumulate into it.
            popularity_counts[i] += 1
            assignment_j.append(i)
            global_weights[i] += weights_j[l]
            global_sigmas[i] += sigma_inv_j
        else:  # new neuron
            # Any column >= L means "create a new neuron"; new neurons are
            # numbered consecutively in the order they are encountered.
            popularity_counts += [1]
            assignment_j.append(new_L)
            new_L += 1
            global_weights = np.vstack((global_weights, prior_mean_norm + weights_j[l]))
            global_sigmas = np.vstack((global_sigmas, prior_inv_sigma + sigma_inv_j))

    return global_weights, global_sigmas, popularity_counts, assignment_j


In [None]:
def pdm_prepare_full_weights_cnn(nets, device="cpu"):
    """Extract the parameters of every network as lists of NumPy arrays.

    State-dict entries are selected by name:

    * names containing ``fc`` or ``classifier``: weights are returned
      transposed, every other entry (e.g. the bias) as is;
    * names containing ``conv`` or ``features``: weights with three or more
      dimensions are flattened to ``(out_channels, -1)``; lower-rank weight
      entries are skipped; every other entry is returned as is;
    * all other entries are ignored.

    Args:
        nets: iterable of modules exposing ``state_dict()``.
        device: kept for backward compatibility. Tensors are always moved to
            host memory with ``.cpu()`` (a no-op for CPU tensors), so the
            function also works when ``device="cpu"`` is passed for a model
            that actually lives on an accelerator.

    Returns:
        One list of arrays per network, in state-dict order.
    """
    weights = []
    for net in nets:
        net_weights = []
        for name, tensor in net.state_dict().items():
            is_fc = 'fc' in name or 'classifier' in name
            is_conv = 'conv' in name or 'features' in name
            if not (is_fc or is_conv):
                continue

            array = tensor.cpu().numpy()
            if is_fc:
                # fc rule takes precedence when a name matches both groups.
                net_weights.append(array.T if 'weight' in name else array)
            elif 'weight' not in name:
                net_weights.append(array)
            elif array.ndim >= 3:
                # One row per output channel, kernel entries flattened.
                net_weights.append(array.reshape(array.shape[0], -1))
            # Lower-rank "weight" entries under a conv/features name are
            # intentionally left out, as before.
        weights.append(net_weights)
    return weights


In [None]:
def block_patching(w_j, L_next, assignment_j_c, layer_index, model_meta_data, 
                                matching_shapes=None, 
                                layer_type="fc", 
                                dataset="cifar10",
                                network_name="lenet"):
    """Re-arrange the input columns of a layer after its predecessor was matched.

    The columns of ``w_j`` are treated as consecutive blocks, one block per
    output unit of the previous layer. Block ``k`` is copied to block
    ``assignment_j_c[k]`` of a zero-initialised array that has room for
    ``L_next`` blocks (conv) or for the flattened conv output (fc).

    * ``layer_type == "conv"``: block width and block count are the last and
      the second entry of this layer's meta data.
    * ``layer_type == "fc"``: block width is the last dimension of the output
      of a ``SimpleCNNContainerConvBlocks`` built from ``matching_shapes`` and
      run on a dummy ``(1, 6, 128)`` input; the block count is the first entry
      of the previous layer's meta data.

    ``w_j`` is returned untouched when ``assignment_j_c`` is ``None``.
    """
    if assignment_j_c is None:
        return w_j

    this_meta = model_meta_data[2 * layer_index - 2]
    prev_meta = model_meta_data[2 * layer_index - 2 - 2]

    def _column_blocks(n_blocks, width):
        # Column indices of `n_blocks` back-to-back blocks of `width` columns.
        return [np.arange(b * width, (b + 1) * width) for b in range(n_blocks)]

    if layer_type == "conv":
        width = this_meta[-1]
        new_w_j = np.zeros((w_j.shape[0], L_next * width))

        target_blocks = _column_blocks(L_next, width)
        source_blocks = _column_blocks(this_meta[1], width)
        for src_id, src_cols in enumerate(source_blocks):
            new_w_j[:, target_blocks[assignment_j_c[src_id]]] = w_j[:, src_cols]

    elif layer_type == "fc":
        # The size of the flattened conv output is not known analytically, so
        # it is measured by pushing a dummy input through the conv blocks.
        if network_name == "simple-cnn" and dataset in ("cifar10", "cinic10"):
            shape_estimator = SimpleCNNContainerConvBlocks(input_channel=6, num_filters=matching_shapes, kernel_size=kernelSize,output_dim=activityCount)
        if dataset in ("cifar10", "cinic10"):
            dummy_input = torch.rand(1, 6, 128)
        estimated_output = shape_estimator(dummy_input)

        width = estimated_output.size()[-1]
        new_w_j = np.zeros((w_j.shape[0], estimated_output.view(-1).size()[0]))

        target_blocks = _column_blocks(L_next, width)
        source_blocks = _column_blocks(prev_meta[0], width)
        for src_id, src_cols in enumerate(source_blocks):
            new_w_j[:, target_blocks[assignment_j_c[src_id]]] = w_j[:, src_cols]
    return new_w_j

In [None]:
def match_layer(weights_bias, sigma_inv_layer, mean_prior, sigma_inv_prior, gamma, it):
    """Match the neurons of one layer across all clients.

    The client with the most rows seeds the global model; the remaining
    clients are matched against it one after another with ``matching_upd_j``.
    Afterwards ``it`` refinement sweeps are run: in each sweep every client
    (in random order) is removed from the global model and matched again.

    Returns:
        ``(assignment, global_weights, global_sigmas)`` where
        ``assignment[j]`` holds the global neuron index of every neuron of
        client ``j``.
    """
    J = len(weights_bias)

    # Largest client first; ties keep their original relative order.
    group_order = sorted(range(J), key=lambda client: weights_bias[client].shape[0], reverse=True)
    seed_client = group_order[0]

    batch_weights_norm = [w * s for w, s in zip(weights_bias, sigma_inv_layer)]
    prior_mean_norm = mean_prior * sigma_inv_prior

    global_weights = prior_mean_norm + batch_weights_norm[seed_client]
    n_seed = global_weights.shape[0]
    global_sigmas = np.outer(np.ones(n_seed), sigma_inv_prior + sigma_inv_layer[seed_client])

    popularity_counts = [1] * n_seed

    assignment = [[] for _ in range(J)]
    assignment[seed_client] = list(range(n_seed))

    ## Initialize: add the remaining clients one by one
    for j in group_order[1:]:
        global_weights, global_sigmas, popularity_counts, assignment[j] = matching_upd_j(
            batch_weights_norm[j], global_weights, sigma_inv_layer[j], global_sigmas,
            prior_mean_norm, sigma_inv_prior, popularity_counts, gamma, J)

    ## Iterate over groups
    for _sweep in range(it):
        for j in np.random.permutation(J):
            to_delete = []

            ## Remove j. Global indices are visited from high to low so that
            ## deleting an entry never shifts an index that is still pending.
            own_pairs = sorted(enumerate(assignment[j]), key=lambda pair: -pair[1])
            for l, i in own_pairs:
                popularity_counts[i] -= 1
                if popularity_counts[i] == 0:
                    # Nobody else uses this global neuron: drop it and shift
                    # the indices stored by the other clients.
                    del popularity_counts[i]
                    to_delete.append(i)
                    for other in range(J):
                        if other == j:
                            continue
                        for idx, l_ind in enumerate(assignment[other]):
                            if i < l_ind:
                                assignment[other][idx] -= 1
                            elif i == l_ind:
                                logger.info('Warning - weird unmatching')
                else:
                    global_weights[i] = global_weights[i] - batch_weights_norm[j][l]
                    global_sigmas[i] -= sigma_inv_layer[j]

            global_weights = np.delete(global_weights, to_delete, axis=0)
            global_sigmas = np.delete(global_sigmas, to_delete, axis=0)

            ## Match j
            global_weights, global_sigmas, popularity_counts, assignment[j] = matching_upd_j(
                batch_weights_norm[j], global_weights, sigma_inv_layer[j], global_sigmas,
                prior_mean_norm, sigma_inv_prior, popularity_counts, gamma, J)

    logger.info('Number of global neurons is %d, gamma %f' % (global_weights.shape[0], gamma))
    logger.info("***************Shape of global weights after match: {} ******************".format(global_weights.shape))
    return assignment, global_weights, global_sigmas
    


In [None]:
def layer_wise_group_descent(batch_weights, layer_index, batch_frequencies, sigma_layers, 
                                sigma0_layers, gamma_layers, it, 
                                model_meta_data, 
                                model_layer_type,
                                n_layers,
                                matching_shapes,
                                args):
    """Match a single layer across all clients and return its MAP estimate.

    For the requested ``layer_index`` (1-based) every client's weight matrix
    and bias are stacked into one ``[weights | bias]`` matrix with one row per
    neuron/filter, the rows are matched with ``match_layer`` and the matched
    result is split back into weights and bias.

    Args:
        batch_weights: per client, the list of arrays produced by
            ``pdm_prepare_full_weights_cnn`` (weight, bias, weight, bias, ...).
        layer_index: 1-based index of the layer to match.
        batch_frequencies: unused; kept for interface compatibility.
        sigma_layers, sigma0_layers, gamma_layers: hyper-parameters, either a
            list with one value per layer or a scalar applied to all layers.
        it: number of refinement sweeps passed to ``match_layer``.
        model_meta_data: unused; kept for interface compatibility.
        model_layer_type: state-dict key names, used to tell conv layers
            ("conv"/"features") from fc layers ("fc"/"classifier").
        n_layers: number of (weight, bias) pairs in the model.
        matching_shapes: unused; kept for interface compatibility.
        args: unused; kept for interface compatibility.

    Returns:
        ``(map_out, assignment_c, L_next)``: ``map_out`` is
        ``[weights, bias]`` (matched sums divided by the matched inverse
        variances; fc weights are transposed back), ``assignment_c`` holds the
        per-client assignments and ``L_next`` the number of global neurons.

    Raises:
        NotImplementedError: for two-layer networks (that code path referenced
            an undefined name and could never run).
        ValueError: if ``layer_index`` is outside ``1 .. n_layers - 1`` or the
            layer type cannot be recognised.
    """
    if type(sigma_layers) is not list:
        sigma_layers = (n_layers - 1) * [sigma_layers]
    if type(sigma0_layers) is not list:
        sigma0_layers = (n_layers - 1) * [sigma0_layers]
    if type(gamma_layers) is not list:
        gamma_layers = (n_layers - 1) * [gamma_layers]

    # J: number of workers
    J = len(batch_weights)

    # Second dimension of every 2-D array of the first client, in layer order.
    init_channel_kernel_dims = [bw.shape[1] for bw in batch_weights[0] if len(bw.shape) > 1]

    # Biases use the same variances as the weights.
    sigma_bias_layers = sigma_layers
    sigma0_bias_layers = sigma0_layers
    mu0 = 0.
    mu0_bias = 0.1

    sigma = sigma_layers[layer_index - 1]
    sigma_bias = sigma_bias_layers[layer_index - 1]
    gamma = gamma_layers[layer_index - 1]
    sigma0 = sigma0_layers[layer_index - 1]
    sigma0_bias = sigma0_bias_layers[layer_index - 1]

    def _stack(weight_of, transpose):
        # One [weights | bias] matrix per client; rows are neurons/filters.
        stacked = []
        for j in range(J):
            weight = weight_of(j).T if transpose else weight_of(j)
            bias = batch_weights[j][2 * layer_index - 1].reshape(-1, 1)
            stacked.append(np.hstack((weight, bias)))
        return stacked

    def _bias_first_priors(stacked):
        # NOTE(review): the bias-specific entries come first here although the
        # bias is the last column of the stacked matrix. With equal variances
        # only `mu0_bias` is affected. Kept as in the original - confirm.
        n_cols = stacked[0].shape[1]
        prior_inv = np.array([1 / sigma0_bias] + (n_cols - 1) * [1 / sigma0])
        prior_mean = np.array([mu0_bias] + (n_cols - 1) * [mu0])
        layer_inv = [np.array([1 / sigma_bias] + (stacked[j].shape[1] - 1) * [1 / sigma]) for j in range(J)]
        return prior_inv, prior_mean, layer_inv

    if layer_index <= 1:
        if n_layers == 2:
            raise NotImplementedError(
                "layer_wise_group_descent does not support two-layer networks")
        # First layer: weights are already stored one row per filter.
        transpose_output = False
        weights_bias = _stack(lambda j: batch_weights[j][0], transpose=False)

        in_width = init_channel_kernel_dims[layer_index - 1]
        sigma_inv_prior = np.array(in_width * [1 / sigma0] + [1 / sigma0_bias])
        mean_prior = np.array(in_width * [mu0] + [mu0_bias])
        sigma_inv_layer = [np.array(in_width * [1 / sigma] + [1 / sigma_bias]) for j in range(J)]

    elif layer_index == (n_layers - 1) and n_layers > 2:
        # Last hidden layer; assumed to be fully connected, so the stored
        # (in, out) weight is transposed to one row per neuron.
        transpose_output = True
        weights_bias = _stack(lambda j: batch_weights[j][2 * layer_index - 2], transpose=True)
        sigma_inv_prior, mean_prior, sigma_inv_layer = _bias_first_priors(weights_bias)

    elif (layer_index > 1 and layer_index < (n_layers - 1)):
        layer_type = model_layer_type[2 * layer_index - 2]
        if 'conv' in layer_type or 'features' in layer_type:
            transpose_output = False
        elif 'fc' in layer_type or 'classifier' in layer_type:
            transpose_output = True
        else:
            raise ValueError("Unrecognised layer type: {}".format(layer_type))
        weights_bias = _stack(lambda j: batch_weights[j][2 * layer_index - 2], transpose=transpose_output)
        sigma_inv_prior, mean_prior, sigma_inv_layer = _bias_first_priors(weights_bias)

    else:
        raise ValueError(
            "layer_index {} is out of range for a network with {} layers".format(layer_index, n_layers))

    assignment_c, global_weights_c, global_sigmas_c = match_layer(weights_bias, sigma_inv_layer, mean_prior,
                                                                  sigma_inv_prior, gamma, it)

    L_next = global_weights_c.shape[0]

    # Every branch above stacks [weights | bias], so the bias is always the
    # last column of the matched matrices. (The first-fc case used to read the
    # bias from column `-n_classes - 1`, i.e. from a weight column.)
    global_weights_out = [global_weights_c[:, :-1], global_weights_c[:, -1]]
    global_inv_sigmas_out = [global_sigmas_c[:, :-1], global_sigmas_c[:, -1]]
    if transpose_output:
        # fc weights go back to the (in, out) layout used in batch_weights.
        global_weights_out[0] = global_weights_out[0].T
        global_inv_sigmas_out[0] = global_inv_sigmas_out[0].T

    map_out = [g_w / g_s for g_w, g_s in zip(global_weights_out, global_inv_sigmas_out)]
    return map_out, assignment_c, L_next



In [None]:
def rebuild_net(weights):
    """Create a ``SimpleCNNContainer`` whose layer sizes are read from ``weights``.

    The number of filters is the row count of ``weights[0]``; the input and
    hidden sizes of the dense part are the two dimensions of ``weights[2]``.
    The container is built with 6 input channels and the module-level
    ``kernelSize`` / ``activityCount`` settings. No weights are loaded.
    """
    conv_weight = weights[0]
    dense_weight = weights[2]
    return SimpleCNNContainer(
        6,
        [conv_weight.shape[0]],
        kernel_size=kernelSize,
        input_dim=dense_weight.shape[0],
        hidden_dims=[dense_weight.shape[1]],
        output_dim=activityCount,
    )

In [None]:
def compute_model_averaging_accuracy(models, weights, train_dl, test_dl, n_classes, args):
    """Load ``weights`` into a fresh network and log its accuracy on ``test_dl``.

    A variant of federated averaging evaluation: the network architecture is
    chosen from ``args.model`` / ``args.dataset``, the arrays in ``weights``
    (indexed in state-dict order) are converted to tensors and loaded, and the
    number of correct predictions over ``test_dl`` is logged.

    Args:
        models: unused; kept for interface compatibility.
        weights: list of NumPy arrays in state-dict order. Conv weights are
            reshaped to the parameter's shape, fc weights are transposed,
            biases are used as is.
        train_dl: unused; kept for interface compatibility.
        test_dl: iterable of ``(x, target)`` batches.
        n_classes: unused; kept for interface compatibility.
        args: object providing ``model`` and ``dataset`` attributes.

    Raises:
        ValueError: if no architecture is defined for the given
            model/dataset combination.
    """
    avg_cnn = None
    if args.model == "lenet":
        avg_cnn = LeNet()
    elif args.model == "vgg":
        avg_cnn = vgg11()
    elif args.model == "simple-cnn":
        if args.dataset in ("cifar10", "cinic10"):
            avg_cnn = SimpleCNN(input_dim=(16 * 5 * 5),
                                hidden_dims=[120, 84], output_dim=activityCount)
    elif args.model == "moderate-cnn":
        if args.dataset in ("cifar10", "cinic10"):
            avg_cnn = ModerateCNN()
        elif args.dataset == "mnist":
            avg_cnn = ModerateCNNMNIST()

    if avg_cnn is None:
        raise ValueError(
            "Unsupported model/dataset combination: {}/{}".format(args.model, args.dataset))

    new_state_dict = {}
    for param_idx, (key_name, param) in enumerate(avg_cnn.state_dict().items()):
        is_weight = "weight" in key_name
        if not is_weight and "bias" not in key_name:
            # Entries that are neither weight nor bias are not provided by `weights`.
            continue

        if "conv" in key_name or "features" in key_name:
            value = weights[param_idx].reshape(param.size()) if is_weight else weights[param_idx]
        elif "fc" in key_name or "classifier" in key_name:
            # fc weights are stored transposed in `weights`.
            value = weights[param_idx].T if is_weight else weights[param_idx]
        else:
            continue
        new_state_dict[key_name] = torch.from_numpy(value)

    avg_cnn.load_state_dict(new_state_dict)

    # switch to eval mode:
    avg_cnn.eval()

    correct, total = 0, 0
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        for x, target in test_dl:
            out_k = avg_cnn(x)
            _, pred_label = torch.max(out_k, 1)
            total += x.data.size()[0]
            correct += (pred_label == target.data).sum().item()

    logger.info(
        "Accuracy for Fed Averaging correct: {}, total: {}".format(correct, total))

In [None]:
def BBP_MAP(nets_list, model_meta_data, layer_type, net_dataidx_map,
            averaging_weights, args,
            device="cpu"):
    """Run layer-wise neural matching over all client networks.

    For every layer except the last one:

    1. the layer is matched across clients with ``layer_wise_group_descent``;
    2. the input side of each client's next layer is permuted/padded to the
       matched size (``block_patching`` up to the first fc layer,
       ``patch_weights`` afterwards);
    3. every client is retrained with ``local_retrain`` (called with
       ``freezing_index`` pointing at the next layer's weight) and the
       retrained weights become the input of the next round.

    The last layer is not matched; it is combined per class using
    ``averaging_weights[worker][class]``.

    Reads the module-level names ``traindata_cls_counts``, ``activityCount``,
    ``iteration``, ``gammaValue`` and ``args_datadir``.

    Args:
        nets_list: client networks.
        model_meta_data: per-parameter shape information, forwarded to the
            matching and patching helpers.
        layer_type: state-dict key names, used to tell conv from fc layers.
        net_dataidx_map: per-worker data indices passed to ``get_dataloader``.
        averaging_weights: indexed ``[worker][class]``; weights of the final
            per-class average of the last layer.
        args: object providing ``dataset``, ``model`` and ``batch_size``.
        device: forwarded to the weight extraction and retraining helpers.

    Returns:
        ``(matched_weights, assignments_list)``: the weights of the matched
        model (taken from worker 0 for all but the last layer) and the list of
        per-layer assignments.
    """
    n_classes = activityCount
    it = iteration
    assignments_list = []

    batch_weights = pdm_prepare_full_weights_cnn(nets_list, device=device)

    logging.info("=="*15)

    batch_freqs = pdm_prepare_freq(traindata_cls_counts, n_classes)

    # Matching hyper-parameters. The variances are fixed to 1.0 here.
    gamma = gammaValue
    sigma = 1.0
    sigma0 = 1.0

    n_layers = int(len(batch_weights[0]) / 2)
    num_workers = len(nets_list)
    matching_shapes = []

    first_fc_index = None

    for layer_index in range(1, n_layers):
        layer_hungarian_weights, assignment, L_next = layer_wise_group_descent(
            batch_weights=batch_weights,
            layer_index=layer_index,
            sigma0_layers=sigma0,
            sigma_layers=sigma,
            batch_frequencies=batch_freqs,
            it=it,
            gamma_layers=gamma,
            model_meta_data=model_meta_data,
            model_layer_type=layer_type,
            n_layers=n_layers,
            matching_shapes=matching_shapes,
            args=args
        )

        assignments_list.append(assignment)

        # Kind of the *next* layer, whose input side has to be patched.
        type_of_patched_layer = layer_type[2 * (layer_index + 1) - 2]
        if 'conv' in type_of_patched_layer or 'features' in type_of_patched_layer:
            l_type = "conv"
        elif 'fc' in type_of_patched_layer or 'classifier' in type_of_patched_layer:
            l_type = "fc"

        # The first fc layer is an fc layer whose predecessor is a conv layer.
        # (The second test used to look at this layer instead of the previous
        # one, so "features"-style conv names were never recognised.)
        type_of_this_layer = layer_type[2 * layer_index - 2]
        type_of_prev_layer = layer_type[2 * layer_index - 2 - 2]
        first_fc_identifier = (('fc' in type_of_this_layer or 'classifier' in type_of_this_layer) and (
            'conv' in type_of_prev_layer or 'features' in type_of_prev_layer))

        if first_fc_identifier:
            first_fc_index = layer_index

        matching_shapes.append(L_next)

        # Already-processed layers of each worker followed by the freshly
        # matched [weight, bias] of this layer.
        tempt_weights = [([batch_weights[w][i] for i in range(2 * layer_index - 2)] +
                          copy.deepcopy(layer_hungarian_weights)) for w in range(num_workers)]

        # i) permute the next layer's input side according to the matching result
        for worker_index in range(num_workers):
            if first_fc_index is None:
                if l_type == "conv":
                    patched_weight = block_patching(batch_weights[worker_index][2 * (layer_index + 1) - 2],
                                                    L_next, assignment[worker_index],
                                                    layer_index+1, model_meta_data,
                                                    matching_shapes=matching_shapes, layer_type=l_type,
                                                    dataset=args.dataset, network_name=args.model)
                elif l_type == "fc":
                    patched_weight = block_patching(batch_weights[worker_index][2 * (layer_index + 1) - 2].T,
                                                    L_next, assignment[worker_index],
                                                    layer_index+1, model_meta_data,
                                                    matching_shapes=matching_shapes, layer_type=l_type,
                                                    dataset=args.dataset, network_name=args.model).T

            elif layer_index >= first_fc_index:
                patched_weight = patch_weights(
                    batch_weights[worker_index][2 * (layer_index + 1) - 2].T, L_next, assignment[worker_index]).T

            tempt_weights[worker_index].append(patched_weight)

        # ii) append the untouched remainder of each worker's network
        for worker_index in range(num_workers):
            for lid in range(2 * (layer_index + 1) - 1, len(batch_weights[0])):
                tempt_weights[worker_index].append(
                    batch_weights[worker_index][lid])

        # iii) load the weights into each worker's model and retrain it
        retrained_nets = []
        for worker_index in range(num_workers):
            dataidxs = net_dataidx_map[worker_index]
            train_dl_local, test_dl_local = get_dataloader(
                args.dataset, args_datadir, args.batch_size, 32, dataidxs,worker_index)

            retrained_cnn = local_retrain((train_dl_local, test_dl_local), tempt_weights[worker_index], args,
                                          freezing_index=(2 * (layer_index + 1) - 2), device=device)
            retrained_nets.append(retrained_cnn)
        batch_weights = pdm_prepare_full_weights_cnn(
            retrained_nets, device=device)

    num_layers = len(batch_weights[0])

    # Stack each worker's last-layer weight with its bias as an extra row.
    last_layer_weights_collector = []
    for i in range(num_workers):
        bias_shape = batch_weights[i][-1].shape
        last_layer_bias = batch_weights[i][-1].reshape((1, bias_shape[0]))
        last_layer_weights_collector.append(
            np.concatenate((batch_weights[i][-2], last_layer_bias), axis=0))
    last_layer_weights_collector = np.array(last_layer_weights_collector)

    # Class-wise weighted average of the last layer over all workers.
    avg_last_layer_weight = np.zeros(
        last_layer_weights_collector[0].shape, dtype=np.float32)
    for i in range(n_classes):
        avg_weight_collector = np.zeros(
            last_layer_weights_collector[0][:, 0].shape, dtype=np.float32)
        for j in range(num_workers):
            avg_weight_collector += averaging_weights[j][i] * \
                last_layer_weights_collector[j][:, i]
        avg_last_layer_weight[:, i] = avg_weight_collector

    # All layers before the last one are taken from worker 0.
    matched_weights = [batch_weights[0][i] for i in range(num_layers - 2)]
    matched_weights.append(avg_last_layer_weight[0:-1, :])
    matched_weights.append(avg_last_layer_weight[-1, :])
    return matched_weights, assignments_list



In [None]:
def computeWeights(modelWeight):
    """Fuse each layer's weight and bias into one 2-D matrix per layer.

    ``modelWeight`` is a flat sequence ``[w0, b0, w1, b1, ...]``. For layer
    ``i`` the weight is used as-is when the notebook-level ``layerType[i]`` is
    0 and transposed otherwise; the bias is reshaped to a column and appended
    as the last column with ``np.hstack``.

    Args:
        modelWeight: sequence of NumPy arrays, alternating weight / bias.

    Returns:
        A 1-D ``dtype=object`` NumPy array with one fused matrix per layer.
        An object array is built explicitly because the per-layer matrices
        generally differ in shape, and ``np.asarray`` on such a ragged list
        raises ``ValueError`` on NumPy >= 1.24.
    """
    layer_count = len(modelWeight) // 2
    modelWeightsPrep = np.empty(layer_count, dtype=object)
    for i in range(layer_count):
        weightReshaped = np.asarray(modelWeight[i * 2])
        if layerType[i] != 0:
            # Layers flagged non-zero in layerType are stored transposed
            # relative to the (units, inputs) layout the bias column needs.
            weightReshaped = weightReshaped.T
        biasReshaped = np.asarray(modelWeight[i * 2 + 1]).reshape(-1, 1)
        modelWeightsPrep[i] = np.hstack((weightReshaped, biasReshaped))
    return modelWeightsPrep

In [None]:
def fedma_comm(batch_weights, model_meta_data, layer_type, net_dataidx_map,
               averaging_weights, args,
               train_dl_global,
               test_dl_global,
               assignments_list,
               comm_round=2,
               device="cpu"):
    """Run the FedMA communication rounds and record per-round statistics.

    For every round:
      1. each client rebuilds its local network from the current global
         weights with ``reconstruct_local_net`` and retrains it on its own
         data with ``local_retrain``;
      2. each retrained client model is evaluated on its own loaders and,
         when the notebook-level ``ClientAllTest`` flag is true, on the
         global loaders;
      3. when the notebook-level ``euclid`` flag is truthy, per-layer
         absolute differences between the last reconstructed network and
         every client's retrained weights are appended to the module-level
         history lists (``meanHistoryDist`` and friends);
      4. the retrained models are matched with ``BBP_MAP``; the matched model
         is evaluated on the global loaders and becomes every client's
         starting point for the next round.

    Besides its arguments the function reads the notebook globals
    ``activityCount``, ``args_datadir``, ``models``, ``ClientAllTest``,
    ``euclid`` and ``clientCount``.

    Args:
        batch_weights: per-client weight lists used to start the first round.
        model_meta_data, layer_type, averaging_weights: forwarded to
            ``BBP_MAP`` unchanged.
        net_dataidx_map: maps a worker index to that worker's data indices.
        args: namespace providing ``n_nets``, ``dataset`` and ``batch_size``
            (and whatever the called helpers read from it).
        train_dl_global, test_dl_global: loaders used for the "all data"
            client evaluation and for the server evaluation.
        assignments_list: matching assignments passed to
            ``reconstruct_local_net``; replaced by ``BBP_MAP``'s output after
            every round.
        comm_round: number of communication rounds.
        device: torch device used for retraining and evaluation.

    Returns:
        ``(serverStats, clientSingleStats, clientAllStats, batch_weights,
        timePerRound)`` where every stats entry is a list with one value per
        round:

        * ``serverStats``: train acc, train loss, test acc, test loss,
          macro-F1 train, macro-F1 test.
        * ``clientSingleStats``: train acc, train acc std, train loss,
          train loss std, test acc, test acc std, test loss, test loss std,
          macro-F1 train, macro-F1 test.
        * ``clientAllStats``: NOTE the order differs from
          ``clientSingleStats`` -- test acc, test acc std, test loss,
          test loss std, train acc, train acc std, train loss,
          train loss std, macro-F1 train, macro-F1 test.
        * ``batch_weights``: the matched weights after the last round, one
          deep copy per client.
        * ``timePerRound``: wall-clock seconds of the LAST round only.
    """
    n_classes = activityCount
    timePerRound = 0

    serverAccuracyTest = []
    serverLossTest = []
    serverAccuracyTrain = []
    serverLossTrain = []

    serverMacroVal_f1Train = []
    serverMacroVal_f1Test = []

    clientTestSingleAccuracy = []
    clientTestSingleAccuracyStd = []
    clientTestSingleLoss = []
    clientTestSingleLossStd = []

    clientTrainSingleAccuracy = []
    clientTrainSingleAccuracyStd = []
    clientTrainSingleLoss = []
    clientTrainSingleLossStd = []

    clientTestAllAccuracy = []
    clientTestAllAccuracyStd = []
    clientTestAllLoss = []
    clientTestAllLossStd = []

    clientTrainAllAccuracy = []
    clientTrainAllAccuracyStd = []
    clientTrainAllLoss = []
    clientTrainAllLossStd = []

    cleintMacroVal_f1SingleTrain = []
    clientMacroVal_f1SingleTest = []

    cleintMacroVal_f1AllTrain = []
    clientMacroVal_f1AllTest = []

    for cr in range(comm_round):
        startTime = time.time()
        logger.info("Entering communication round: {} ...".format(cr))

        # ---- 1. local reconstruction + retraining --------------------------
        retrained_nets = []
        for worker_index in range(args.n_nets):
            dataidxs = net_dataidx_map[worker_index]
            train_dl_local, test_dl_local = get_dataloader(
                args.dataset, args_datadir, args.batch_size, 32, dataidxs, worker_index)

            recons_local_net = reconstruct_local_net(
                batch_weights[worker_index], args, ori_assignments=assignments_list, worker_index=worker_index)

            retrained_cnn = local_retrain((train_dl_local, test_dl_local), recons_local_net, args,
                                          mode="bottom-up", freezing_index=0, ori_assignments=None, device=device)
            retrained_nets.append(retrained_cnn)

        # From here on batch_weights holds the clients' retrained weights.
        batch_weights = pdm_prepare_full_weights_cnn(retrained_nets, device=device)

        # ---- 2. client evaluation ------------------------------------------
        tempSingleAccuracyTest = []
        tempSingleLossTest = []
        tempSingleAccuracyTrain = []
        tempSingleLossTrain = []

        tempAllAccuracyTest = []
        tempAllLossTest = []
        tempAllAccuracyTrain = []
        tempAllLossTrain = []

        tempSingleMacroVal_f1Train = []
        tempSingleMacroVal_f1Test = []

        tempAllMacroVal_f1Train = []
        tempAllMacroVal_f1Test = []

        for worker_index in range(args.n_nets):
            dataidxs = net_dataidx_map[worker_index]
            train_dl_local, test_dl_local = get_dataloader(
                args.dataset, args_datadir, args.batch_size, 32, dataidxs, worker_index)

            # Evaluate on the same device the server evaluation below uses.
            (tempAccSingleTest, tempLsSingleTest,
             tempAccSingleTrain, tempLsSingleTrain,
             tempMacroVal_f1SingleTrain, tempMacroVal_f1SingleTest) = compute_full_cnn_accuracy(
                models, batch_weights[worker_index], train_dl_local, test_dl_local,
                n_classes, args, device=device, clientServer="Client")

            tempSingleAccuracyTest.append(tempAccSingleTest)
            tempSingleLossTest.append(tempLsSingleTest)

            tempSingleAccuracyTrain.append(tempAccSingleTrain)
            tempSingleLossTrain.append(tempLsSingleTrain)

            tempSingleMacroVal_f1Train.append(tempMacroVal_f1SingleTrain)
            tempSingleMacroVal_f1Test.append(tempMacroVal_f1SingleTest)

            if(ClientAllTest == True):
                (tempAccAllTest, tempLsAllTest,
                 tempAccAllTrain, tempLsAllTrain,
                 tpAllMacroVal_f1Train, tpAllMacroVal_f1Test) = compute_full_cnn_accuracy(
                    models, batch_weights[worker_index], train_dl_global, test_dl_global,
                    n_classes, args, device=device, clientServer="Client")

                tempAllAccuracyTest.append(tempAccAllTest)
                tempAllLossTest.append(tempLsAllTest)

                tempAllAccuracyTrain.append(tempAccAllTrain)
                tempAllLossTrain.append(tempLsAllTrain)

                tempAllMacroVal_f1Train.append(tpAllMacroVal_f1Train)
                tempAllMacroVal_f1Test.append(tpAllMacroVal_f1Test)

        # ---- 3. optional weight-distance bookkeeping -----------------------
        if(euclid):
            meanServerClient = []
            stdServerClient = []
            # Reference weights: the network reconstructed for the LAST worker
            # in the retraining loop above.
            serverShape = np.asarray(computeWeights(recons_local_net))
            localMeanClientLayer = []
            localStdClientLayer = []
            for clientIndex in range(clientCount):
                localMeanServerClient = []
                localStdServerClient = []
                localShape = np.asarray(computeWeights(batch_weights[clientIndex]))
                for i in range(serverShape.shape[0]):
                    # sqrt(d**2) is the element-wise absolute difference.
                    newLayerDist = np.sqrt((serverShape[i] - localShape[i])**2)
                    localMeanServerClient.append(np.mean(newLayerDist))
                    localStdServerClient.append(np.std(newLayerDist))
                localMeanClientLayer.append(localMeanServerClient)
                localStdClientLayer.append(localStdServerClient)

                meanServerClient.append(np.mean(localMeanServerClient))
                stdServerClient.append(np.mean(localStdServerClient))

            # one value per client
            meanHistoryDist.append(np.asarray(meanServerClient))
            stdHistoryDist.append(np.asarray(stdServerClient))

            # one value per layer (averaged over clients)
            meanRoundLayerHistory.append(np.mean(localMeanClientLayer, axis=0))
            stdRoundLayerHistory.append(np.mean(localStdClientLayer, axis=0))

            # one scalar over all clients and layers
            meanRoundGeneralLayerHistory.append(np.mean(localMeanClientLayer))
            stdRoundGeneralLayerHistory.append(np.mean(localStdClientLayer))

        # ---- per-round client aggregates -----------------------------------
        clientTestSingleAccuracy.append(np.mean(tempSingleAccuracyTest))
        clientTestSingleAccuracyStd.append(np.std(tempSingleAccuracyTest))
        clientTestSingleLoss.append(np.mean(tempSingleLossTest))
        # The *LossStd histories store the spread across clients (np.std);
        # they previously stored a second copy of the mean.
        clientTestSingleLossStd.append(np.std(tempSingleLossTest))

        clientTrainSingleAccuracy.append(np.mean(tempSingleAccuracyTrain))
        clientTrainSingleAccuracyStd.append(np.std(tempSingleAccuracyTrain))
        clientTrainSingleLoss.append(np.mean(tempSingleLossTrain))
        clientTrainSingleLossStd.append(np.std(tempSingleLossTrain))

        cleintMacroVal_f1SingleTrain.append(np.mean(tempSingleMacroVal_f1Train))
        clientMacroVal_f1SingleTest.append(np.mean(tempSingleMacroVal_f1Test))

        # The logged values are running means over all rounds so far.
        logger.info("Client single train accuracy: {}".format(np.mean(clientTrainSingleAccuracy)))
        logger.info("Client single test accuracy: {}".format(np.mean(clientTestSingleAccuracy)))

        if(ClientAllTest == True):
            clientTestAllAccuracy.append(np.mean(tempAllAccuracyTest))
            clientTestAllAccuracyStd.append(np.std(tempAllAccuracyTest))
            clientTestAllLoss.append(np.mean(tempAllLossTest))
            clientTestAllLossStd.append(np.std(tempAllLossTest))

            clientTrainAllAccuracy.append(np.mean(tempAllAccuracyTrain))
            clientTrainAllAccuracyStd.append(np.std(tempAllAccuracyTrain))
            clientTrainAllLoss.append(np.mean(tempAllLossTrain))
            clientTrainAllLossStd.append(np.std(tempAllLossTrain))

            logger.info("Client all train accuracy: {}".format(np.mean(clientTrainAllAccuracy)))
            logger.info("Client all test accuracy: {}".format(np.mean(clientTestAllAccuracy)))

            cleintMacroVal_f1AllTrain.append(np.mean(tempAllMacroVal_f1Train))
            clientMacroVal_f1AllTest.append(np.mean(tempAllMacroVal_f1Test))

        # ---- 4. matching (BBP_MAP) + server evaluation ---------------------
        hungarian_weights, assignments_list = BBP_MAP(
            retrained_nets, model_meta_data, layer_type, net_dataidx_map, averaging_weights, args, device=device)

        # Every client starts the next round from its own copy of the
        # matched weights.
        batch_weights = [copy.deepcopy(hungarian_weights)
                         for _ in range(args.n_nets)]

        (serverAccTest, serverLsTest,
         serverAccTrain, serverLsTrain,
         sMacroVal_f1Train, sMacroVal_f1Test) = compute_full_cnn_accuracy(None,
                                                                          hungarian_weights,
                                                                          train_dl_global,
                                                                          test_dl_global,
                                                                          n_classes,
                                                                          device=device,
                                                                          args=args)

        serverAccuracyTest.append(serverAccTest)
        serverLossTest.append(serverLsTest)
        serverAccuracyTrain.append(serverAccTrain)
        serverLossTrain.append(serverLsTrain)

        serverMacroVal_f1Train.append(sMacroVal_f1Train)
        serverMacroVal_f1Test.append(sMacroVal_f1Test)

        logger.info("Server train accuracy: {}".format(serverAccTrain))
        logger.info("Server test accuracy: {}".format(serverAccTest))

        del hungarian_weights
        del retrained_nets
        timePerRound = time.time() - startTime
        print(timePerRound)

    serverStats = [serverAccuracyTrain,serverLossTrain,serverAccuracyTest,serverLossTest,serverMacroVal_f1Train,serverMacroVal_f1Test]
    clientSingleStats = [clientTrainSingleAccuracy,clientTrainSingleAccuracyStd,clientTrainSingleLoss,clientTrainSingleLossStd,clientTestSingleAccuracy,clientTestSingleAccuracyStd,clientTestSingleLoss,clientTestSingleLossStd,cleintMacroVal_f1SingleTrain,clientMacroVal_f1SingleTest]
    # Order kept as-is for existing callers: test metrics first, then train.
    clientAllStats = [clientTestAllAccuracy,clientTestAllAccuracyStd,clientTestAllLoss,clientTestAllLossStd,clientTrainAllAccuracy,clientTrainAllAccuracyStd,clientTrainAllLoss,clientTrainAllLossStd,cleintMacroVal_f1AllTrain,clientMacroVal_f1AllTest]
    return serverStats,clientSingleStats,clientAllStats,batch_weights,timePerRound

In [None]:
def local_retrain(local_datasets, weights, args, mode="bottom-up", freezing_index=0, ori_assignments=None, device="cpu"):
    """Load ``weights`` into a fresh ``SimpleCNNContainer`` and fine-tune it.

    The network is sized from ``weights`` (filter count from ``weights[0]``,
    first FC input/hidden sizes from ``weights[2]``); the input channel count
    (6), kernel size (``kernelSize``) and output size (``activityCount``)
    come from the notebook.

    Args:
        local_datasets: ``(train_loader, test_loader)``; only the train
            loader is iterated here.
        weights: flat list of NumPy arrays in the same order as the model's
            ``state_dict()``. Conv weights are reshaped to the parameter's
            shape; FC weights are stored transposed relative to PyTorch and
            are transposed back on load.
        args: namespace providing ``retrain_lr`` and ``retrain_epochs``.
        mode: freezing strategy.
            * ``"bottom-up"``: parameters with index ``< freezing_index`` are
              frozen (``freezing_index=0`` trains the whole network).
            * ``"per-layer"``: only parameters ``2*freezing_index-2`` and
              ``2*freezing_index-1`` stay trainable.
            * any other value except the two below: nothing is frozen.
            * ``"block-wise"`` / ``"squeezing"``: not supported, see Raises.
        freezing_index: integer index interpreted according to ``mode``.
            Indices count entries of ``model.parameters()``, i.e. weight and
            bias tensors separately.
        ori_assignments: accepted for interface compatibility; unused.
        device: torch device used for training.

    Returns:
        The fine-tuned model, on ``device`` and in training mode.

    Raises:
        NotImplementedError: for ``mode`` ``"block-wise"`` or ``"squeezing"``.
            These modes need a list of parameter indices to reconstruct that
            this notebook never builds; they used to fail with a ``NameError``
            on the first parameter.
        ValueError: if a ``state_dict`` key is not a conv/features or
            fc/classifier weight or bias, since no loading rule exists for it.
    """
    if mode in ("block-wise", "squeezing"):
        raise NotImplementedError(
            "local_retrain: mode '{}' is not supported in this notebook "
            "(no reconstruction indices are defined for it)".format(mode))

    input_channel = 6
    num_filters = [weights[0].shape[0]]
    input_dim = weights[2].shape[0]
    hidden_dims = [weights[2].shape[1]]
    matched_cnn = SimpleCNNContainer(input_channel=input_channel, num_filters=num_filters,
                                     kernel_size=kernelSize,
                                     input_dim=input_dim,
                                     hidden_dims=hidden_dims,
                                     output_dim=activityCount)

    # ---- load the provided weights into the model --------------------------
    new_state_dict = {}
    for param_idx, (key_name, param) in enumerate(matched_cnn.state_dict().items()):
        layer_weight = weights[param_idx]
        loaded = None
        if "conv" in key_name or "features" in key_name:
            if "weight" in key_name:
                loaded = torch.from_numpy(layer_weight.reshape(param.size()))
            elif "bias" in key_name:
                loaded = torch.from_numpy(layer_weight)
        elif "fc" in key_name or "classifier" in key_name:
            if "weight" in key_name:
                loaded = torch.from_numpy(layer_weight.T)
            elif "bias" in key_name:
                loaded = torch.from_numpy(layer_weight)
        if loaded is None:
            raise ValueError(
                "local_retrain: no loading rule for state_dict key '{}'".format(key_name))
        new_state_dict[key_name] = loaded

    matched_cnn.load_state_dict(new_state_dict)

    # ---- freeze parameters according to the requested mode -----------------
    for param_idx, param in enumerate(matched_cnn.parameters()):
        if mode == "bottom-up":
            if param_idx < freezing_index:
                param.requires_grad = False
        elif mode == "per-layer":
            if param_idx not in (2*freezing_index-2, 2*freezing_index-1):
                param.requires_grad = False

    matched_cnn.to(device).train()
    train_dl_local = local_datasets[0]

    # When only the last weight/bias pair (or less) is being trained, use a
    # 10x smaller learning rate with weight decay and train 3x longer.
    trainable_params = filter(lambda p: p.requires_grad, matched_cnn.parameters())
    if freezing_index < (len(weights) - 2):
        optimizer_fine_tune = optim.SGD(
            trainable_params, lr=args.retrain_lr, momentum=0.9)
        retrain_epochs = args.retrain_epochs
    else:
        optimizer_fine_tune = optim.SGD(
            trainable_params, lr=(args.retrain_lr/10), momentum=0.9, weight_decay=0.0001)
        retrain_epochs = int(args.retrain_epochs*3)

    criterion_fine_tune = nn.CrossEntropyLoss().to(device)

    for epoch in range(retrain_epochs):
        epoch_loss_collector = []
        for x, target in train_dl_local:
            x, target = x.to(device), target.to(device)
            target = target.long()

            optimizer_fine_tune.zero_grad()
            out = matched_cnn(x)
            loss = criterion_fine_tune(out, target)
            epoch_loss_collector.append(loss.item())

            loss.backward()
            optimizer_fine_tune.step()

        # An empty loader yields no batches; skip the average instead of
        # dividing by zero.
        if epoch_loss_collector:
            epoch_loss = sum(epoch_loss_collector) / len(epoch_loss_collector)
            logger.debug("local_retrain epoch {} loss {}".format(epoch, epoch_loss))

    return matched_cnn

In [None]:
def save_model(model, model_index):
    """Serialize ``model.state_dict()`` to ``trained_local_model<model_index>``.

    The file is written into the current working directory; nothing is
    returned.
    """
    logger.info("saving local model-{}".format(model_index))
    checkpoint_path = "trained_local_model"+str(model_index)
    with open(checkpoint_path, "wb") as checkpoint_file:
        torch.save(model.state_dict(), checkpoint_file)

In [None]:
def prepare_uniform_weights(n_classes, net_cnt, fill_val=1):
    """Build one constant per-class weight row for each of ``net_cnt`` nets.

    Returns a dict ``{net index: float32 tensor of shape (1, n_classes)}``
    whose entries are all ``fill_val``. Every net gets its own tensor.
    """
    def _constant_row():
        values = np.array([fill_val] * n_classes, dtype=np.float32)
        return torch.from_numpy(values).view(1, -1)

    return {net_i: _constant_row() for net_i in range(net_cnt)}

In [None]:
def pdm_prepare_freq(cls_freqs, n_classes):
    """Turn per-net ``{class: count}`` dicts into dense count vectors.

    Nets are visited in sorted key order. Each result is a length
    ``n_classes`` NumPy array holding the count of every class, with 0 for
    classes missing from that net's dict.
    """
    freqs = []
    for net_i in sorted(cls_freqs):
        dense_counts = [0] * n_classes
        for cls_i, count in cls_freqs[net_i].items():
            dense_counts[cls_i] = count
        freqs.append(np.array(dense_counts))
    return freqs

In [None]:
def trans_next_conv_layer_forward(layer_weight, next_layer_shape):
    """Regroup a flattened 3-D weight so its second axis becomes the rows.

    ``layer_weight`` is viewed as ``next_layer_shape`` (three dimensions),
    the first two axes are swapped, and the result is flattened to
    ``(next_layer_shape[1], -1)``.
    # presumably the axes are (out channels, in channels, kernel) -- TODO confirm
    """
    as_3d = layer_weight.reshape(next_layer_shape)
    second_axis_first = as_3d.transpose((1, 0, 2))
    return second_axis_first.reshape((next_layer_shape[1], -1))


def trans_next_conv_layer_backward(layer_weight, next_layer_shape):
    """Inverse regrouping of ``trans_next_conv_layer_forward``.

    ``layer_weight`` is viewed as
    ``(next_layer_shape[1], next_layer_shape[0], next_layer_shape[2])``, the
    first two axes are swapped, and the result is flattened to
    ``(next_layer_shape[0], -1)``.
    """
    dim0, dim1, dim2 = next_layer_shape[0], next_layer_shape[1], next_layer_shape[2]
    swapped_view = layer_weight.reshape((dim1, dim0, dim2))
    restored = swapped_view.transpose((1, 0, 2))
    return restored.reshape(dim0, -1)

In [None]:
def objective(global_weights, global_sigmas):
    """Return the sum of ``global_weights`` divided element-wise by ``global_sigmas``."""
    scaled_weights = global_weights / global_sigmas
    return scaled_weights.sum()


def patch_weights(w_j, L_next, assignment_j_c):
    """Scatter the columns of ``w_j`` into a zero matrix with ``L_next`` columns.

    Column ``k`` of ``w_j`` is written to column ``assignment_j_c[k]`` of the
    result; all other columns stay zero. When ``assignment_j_c`` is ``None``
    the input is returned unchanged.
    """
    if assignment_j_c is not None:
        padded = np.zeros((w_j.shape[0], L_next))
        padded[:, assignment_j_c] = w_j
        return padded
    return w_j

In [None]:
# Select the model family and create one network per client.
# init_models also returns the metadata / layer-type descriptions that the
# matching cells below pass to BBP_MAP.
args.model = "simple-cnn"
logger.info("Initializing nets")
nets, model_meta_data, layer_type = init_models(args.n_nets, args)
logger.info("Retrain? : {}".format(args.retrain))


In [None]:
# Local training stage: train the initialised nets using the per-client index
# map; the result (nets_list) is what the ensemble, averaging and matching
# cells below consume.
nets_list = local_train(nets, args, net_dataidx_map, device=device)


In [None]:
# Loaders built without per-client indices -- presumably they cover the whole
# dataset (TODO confirm against get_dataloader).
train_dl_global, test_dl_global = get_dataloader(
    args.dataset, args_datadir, args.batch_size, 32)

# ensemble part of experiments
# Baseline: accuracy of the uniformly weighted ensemble of the locally
# trained nets on the global train and test loaders.
logger.info("Computing Uniform ensemble accuracy")
uens_train_acc, _ = compute_ensemble_accuracy(
    nets_list, train_dl_global, n_classes,  uniform_weights=True, device=device)
uens_test_acc, _ = compute_ensemble_accuracy(
    nets_list, test_dl_global, n_classes, uniform_weights=True, device=device)

logger.info("Uniform ensemble (Train acc): {}".format(uens_train_acc))
logger.info("Uniform ensemble (Test acc): {}".format(uens_test_acc))

In [None]:
# Extract the locally trained nets' weights into per-client weight lists.
batch_weights = pdm_prepare_full_weights_cnn(nets_list, device=device)

In [None]:
# Initial matching of the locally trained nets: yields the matched (global)
# weights and the assignments later handed to fedma_comm.
hungarian_weights, assignments_list = BBP_MAP(
    nets_list, model_meta_data, layer_type, net_dataidx_map, averaging_weights, args, device=device)

In [None]:
# Report the shape of every matched weight array.
matched_shapes = [bw.shape for bw in hungarian_weights]
logger.info("Weights shapes: {}".format(matched_shapes))

In [None]:
# FedAvg baseline: average the locally trained nets layer by layer, weighting
# each client by its share of the total number of data points.
batch_weights = pdm_prepare_full_weights_cnn(nets_list, device=device)
total_data_points = sum([len(net_dataidx_map[r])
                         for r in range(args.n_nets)])
# Per-client weight = (client's data points) / (all data points).
fed_avg_freqs = [len(net_dataidx_map[r]) /
                 total_data_points for r in range(args.n_nets)]
logger.info("Total data points: {}".format(total_data_points))
logger.info("Freq of FedAvg: {}".format(fed_avg_freqs))

averaged_weights = []
num_layers = len(batch_weights[0])
for i in range(num_layers):
    # Weighted sum of layer i over all clients.
    avegerated_weight = sum([b[i] * fed_avg_freqs[j]
                             for j, b in enumerate(batch_weights)])
    averaged_weights.append(avegerated_weight)

for aw in averaged_weights:
    logger.info(aw.shape)

# `models` is read as a global by fedma_comm when it evaluates the clients.
models = nets_list

In [None]:
# Evaluate the FedAvg-averaged weights on the global train and test loaders.
compute_model_averaging_accuracy(models,
                                     averaged_weights,
                                     train_dl_global,
                                     test_dl_global,
                                     n_classes,
                                     args)

In [None]:
def _evaluate_cnn_on_loader(model, data_loader, criterion, device):
    """Evaluate ``model`` on ``data_loader``.

    Returns ``(accuracy, mean of the per-batch losses, macro F1)``. The F1
    reference labels are collected from the batches themselves, so the score
    stays correct when the loader shuffles or drops samples.
    """
    correct, total = 0, 0
    batch_losses = []
    predictions = []
    references = []
    # Evaluation only: no autograd graph is needed.
    with torch.no_grad():
        for x, target in data_loader:
            x, target = x.to(device), target.to(device)
            out_k = model(x)
            _, pred_label = torch.max(out_k, 1)
            total += x.size(0)
            correct += (pred_label == target).sum().item()
            predictions.append(pred_label.cpu())
            references.append(target.long().cpu())
            batch_losses.append(criterion(out_k, target.long()).item())

    macro_f1 = f1_score(torch.cat(references).numpy(),
                        torch.cat(predictions).numpy(),
                        average='macro')
    return correct / total, sum(batch_losses) / len(batch_losses), macro_f1


def compute_full_cnn_accuracy(models, weights, train_dl, test_dl, n_classes, args, device="cpu",clientServer="Client"):
    """Load ``weights`` into a CNN and evaluate it on the test and train loaders.

    A ``SimpleCNNContainer`` is sized from ``weights`` (filter count from
    ``weights[0]``, first FC input/hidden sizes from ``weights[2]``; 6 input
    channels, ``kernelSize`` and ``activityCount`` come from the notebook).
    When ``clientServer == "test"`` the object passed as ``models`` is used
    as the network instead.

    Args:
        models: only used when ``clientServer == "test"``.
        weights: flat list of NumPy arrays ordered like the model's
            ``state_dict()``; FC weights are stored transposed relative to
            PyTorch and are transposed back on load.
        train_dl, test_dl: loaders yielding ``(inputs, labels)`` batches.
        n_classes, args: accepted for interface compatibility; unused here.
        device: torch device used for evaluation.
        clientServer: see above.

    Returns:
        ``(test accuracy, test loss, train accuracy, train loss,
        macro-F1 train, macro-F1 test)``. Losses are the mean of the
        per-batch cross-entropy values.

    Raises:
        ValueError: if a ``state_dict`` key is not a conv/features or
            fc/classifier weight or bias, since no loading rule exists for it.
    """
    input_channel = 6

    num_filters = [weights[0].shape[0]]
    input_dim = weights[2].shape[0]
    hidden_dims = [weights[2].shape[1]]
    matched_cnn = SimpleCNNContainer(input_channel=input_channel,
                                     num_filters=num_filters,
                                     kernel_size=kernelSize,
                                     input_dim=input_dim,
                                     hidden_dims=hidden_dims,
                                     output_dim=activityCount)

    if(clientServer == "test"):
        matched_cnn = models

    new_state_dict = {}
    for param_idx, (key_name, param) in enumerate(matched_cnn.state_dict().items()):
        loaded = None
        if "conv" in key_name or "features" in key_name:
            if "weight" in key_name:
                loaded = torch.from_numpy(
                    weights[param_idx].reshape(param.size()))
            elif "bias" in key_name:
                loaded = torch.from_numpy(weights[param_idx])
        elif "fc" in key_name or "classifier" in key_name:
            if "weight" in key_name:
                loaded = torch.from_numpy(weights[param_idx].T)
            elif "bias" in key_name:
                loaded = torch.from_numpy(weights[param_idx])
        if loaded is None:
            raise ValueError(
                "compute_full_cnn_accuracy: no loading rule for state_dict key '{}'".format(key_name))
        new_state_dict[key_name] = loaded

    matched_cnn.load_state_dict(new_state_dict)
    matched_cnn.to(device)
    matched_cnn.eval()

    criterion = nn.CrossEntropyLoss().to(device)

    clientAccuracyTest, clientLossTest, macroVal_f1Test = _evaluate_cnn_on_loader(
        matched_cnn, test_dl, criterion, device)
    clientAccuracyTrain, clientLossTrain, macroVal_f1Train = _evaluate_cnn_on_loader(
        matched_cnn, train_dl, criterion, device)

    return clientAccuracyTest,clientLossTest,clientAccuracyTrain,clientLossTrain,macroVal_f1Train,macroVal_f1Test


In [None]:
# Every client starts the communication rounds from its own deep copy of the
# matched weights, so later per-client changes cannot alias each other.
comm_init_batch_weights = [copy.deepcopy(
    hungarian_weights) for _ in range(args.n_nets)]

In [None]:
# Module-level histories that fedma_comm appends to (once per round) when
# the `euclid` flag is set.
meanHistoryDist = []
stdHistoryDist = []

meanRoundLayerHistory = []
stdRoundLayerHistory = []

meanRoundGeneralLayerHistory = []
stdRoundGeneralLayerHistory = []

roundTime = 0

serverStats,clientSingleStats,clientAllStats,finalWeights,roundTime = fedma_comm(comm_init_batch_weights,
           model_meta_data, layer_type, net_dataidx_map,
           averaging_weights, args,
           train_dl_global,
           test_dl_global,
           assignments_list,
           comm_round=args.comm_round,
           device=device)


# serverStats order: train acc, train loss, test acc, test loss,
# macro-F1 train, macro-F1 test.
serverAccuracy = serverStats[2]
serverLoss = serverStats[3]
serverAccuracyTrain = serverStats[0]
serverLossTrain = serverStats[1]

serverMacroTrain = serverStats[4]
serverMacroTest = serverStats[5]

# clientSingleStats order: train (acc, acc std, loss, loss std) at 0-3,
# test (acc, acc std, loss, loss std) at 4-7, macro-F1 train/test at 8/9.
clientAccuracy = clientSingleStats[4]
clientAccuracyStd = clientSingleStats[5]
clientLoss = clientSingleStats[6]
clientLossStd = clientSingleStats[7]
clientAccuracyTrain = clientSingleStats[0]
clientAccuracyStdTrain = clientSingleStats[1]
clientLossTrain = clientSingleStats[2]
clientLossStdTrain = clientSingleStats[3]

clientMacroTrain = clientSingleStats[8]
clientMacroTest = clientSingleStats[9]


# clientAllStats is ordered differently from clientSingleStats: fedma_comm
# returns the TEST metrics first (0-3) and the TRAIN metrics second (4-7).
# Reading it with the clientSingleStats indices swapped test and train.
clientAllAccuracy = clientAllStats[0]
clientAllAccuracyStd = clientAllStats[1]
clientAllLoss = clientAllStats[2]
clientAllLossStd = clientAllStats[3]
clientAllAccuracyTrain = clientAllStats[4]
clientAllAccuracyStdTrain = clientAllStats[5]
clientAllLossTrain = clientAllStats[6]
clientAllLossStdTrain = clientAllStats[7]

clientAllMacroTrain = clientAllStats[8]
clientAllMacroTest = clientAllStats[9]


In [None]:
# Convert every per-round history from a Python list to a NumPy array
# (one entry per communication round).
serverAccuracyTrain = np.asarray(serverAccuracyTrain)
serverLossTrain = np.asarray(serverLossTrain)
serverAccuracy = np.asarray(serverAccuracy)
serverLoss = np.asarray(serverLoss)

clientAccuracy = np.asarray(clientAccuracy)
clientAccuracyStd = np.asarray(clientAccuracyStd)
clientLoss = np.asarray(clientLoss)
clientLossStd = np.asarray(clientLossStd)
clientAccuracyTrain = np.asarray(clientAccuracyTrain)
clientAccuracyStdTrain = np.asarray(clientAccuracyStdTrain)
clientLossTrain = np.asarray(clientLossTrain)
clientLossStdTrain = np.asarray(clientLossStdTrain)

clientAllAccuracy = np.asarray(clientAllAccuracy)
clientAllAccuracyStd = np.asarray(clientAllAccuracyStd)
clientAllLoss = np.asarray(clientAllLoss)
clientAllLossStd = np.asarray(clientAllLossStd)
clientAllAccuracyTrain = np.asarray(clientAllAccuracyTrain)
clientAllAccuracyStdTrain = np.asarray(clientAllAccuracyStdTrain)
clientAllLossTrain = np.asarray(clientAllLossTrain)
clientAllLossStdTrain = np.asarray(clientAllLossStdTrain)



clientAllMacroTrain = np.asarray(clientAllMacroTrain)
clientAllMacroTest = np.asarray(clientAllMacroTest)

clientMacroTrain = np.asarray(clientMacroTrain)
clientMacroTest = np.asarray(clientMacroTest)

serverMacroTrain = np.asarray(serverMacroTrain)
serverMacroTest = np.asarray(serverMacroTest)

# The distance histories are only filled when `euclid` is set. The
# per-client and per-layer ones are transposed so that the round index
# becomes the last axis.
if(euclid):
    meanHistoryDist = np.asarray(meanHistoryDist).T
    stdHistoryDist = np.asarray(stdHistoryDist).T
    meanRoundLayerHistory = np.asarray(meanRoundLayerHistory).T
    stdRoundLayerHistory = np.asarray(stdRoundLayerHistory).T
    meanRoundGeneralLayerHistory = np.asarray(meanRoundGeneralLayerHistory)
    stdRoundGeneralLayerHistory = np.asarray(stdRoundGeneralLayerHistory)





In [None]:
# Persist every training statistic as its own .hkl file under
# <filepath>trainingStats/, one file per variable, named after the variable.
os.makedirs(filepath+"trainingStats", exist_ok=True)

# (file stem, array) pairs, listed in the order they are written.
stats_to_save = (
    ("serverAccuracy", serverAccuracy),
    ("serverLoss", serverLoss),
    ("serverAccuracyTrain", serverAccuracyTrain),
    ("serverLossTrain", serverLossTrain),
    ("clientAccuracy", clientAccuracy),
    ("clientAccuracyStd", clientAccuracyStd),
    ("clientLoss", clientLoss),
    ("clientLossStd", clientLossStd),
    ("clientAccuracyTrain", clientAccuracyTrain),
    ("clientAccuracyStdTrain", clientAccuracyStdTrain),
    ("clientLossTrain", clientLossTrain),
    ("clientLossStdTrain", clientLossStdTrain),
    ("clientAllAccuracy", clientAllAccuracy),
    ("clientAllAccuracyStd", clientAllAccuracyStd),
    ("clientAllLoss", clientAllLoss),
    ("clientAllLossStd", clientAllLossStd),
    ("clientAllAccuracyTrain", clientAllAccuracyTrain),
    ("clientAllAccuracyStdTrain", clientAllAccuracyStdTrain),
    ("clientAllLossTrain", clientAllLossTrain),
    ("clientAllLossStdTrain", clientAllLossStdTrain),
    ("clientAllMacroTrain", clientAllMacroTrain),
    ("clientAllMacroTest", clientAllMacroTest),
    ("clientMacroTrain", clientMacroTrain),
    ("clientMacroTest", clientMacroTest),
    ("serverMacroTrain", serverMacroTrain),
    ("serverMacroTest", serverMacroTest),
)
for stat_name, stat_values in stats_to_save:
    hkl.dump(stat_values, filepath + "trainingStats/" + stat_name + ".hkl")

# Distance histories are only written when `euclid` is set.
if euclid:
    euclid_stats_to_save = (
        ("meanHistoryDist", meanHistoryDist),
        ("stdHistoryDist", stdHistoryDist),
        ("meanRoundLayerHistory", meanRoundLayerHistory),
        ("stdRoundLayerHistory", stdRoundLayerHistory),
        ("meanRoundGeneralLayerHistory", meanRoundGeneralLayerHistory),
        ("stdRoundGeneralLayerHistory", stdRoundGeneralLayerHistory),
    )
    for stat_name, stat_values in euclid_stats_to_save:
        hkl.dump(stat_values, filepath + "trainingStats/" + stat_name + ".hkl")



In [None]:
# Shared x-axis for every learning-curve plot: one point per communication
# round, numbered from 1.
epoch_range = range(1, args.comm_round+1)

# ---------------------------------------------------------------------------
# Accuracy figure -> <filepath>LearningAccuracy.png
# ---------------------------------------------------------------------------
# Server curves are skipped for FEDPER.
if(algorithm != "FEDPER"):
    plt.plot(epoch_range, serverAccuracyTrain, label = 'Server Train')
    plt.plot(epoch_range, serverAccuracy, label= 'Server Test')

plt.errorbar(epoch_range, clientAccuracyTrain, yerr=clientAccuracyStdTrain, label='Client Own Train', alpha = 0.6)
plt.errorbar(epoch_range, clientAccuracy, yerr=clientAccuracyStd, label='Client Own Test',alpha = 0.6)

if(ClientAllTest  == True):
    plt.errorbar(epoch_range, clientAllAccuracyTrain, yerr=clientAllAccuracyStdTrain, label='Client All Train', alpha = 0.6)
    plt.errorbar(epoch_range, clientAllAccuracy, yerr=clientAllAccuracyStd, label='Client All Test', alpha = 0.6)

    # markevery=[k] with ls="" draws a single dot at the best (highest) round.
    plt.plot(epoch_range, clientAllAccuracyTrain,markevery=[np.argmax(clientAllAccuracyTrain)], ls="", marker="o",color="purple")
    plt.plot(epoch_range, clientAllAccuracy,markevery=[np.argmax(clientAllAccuracy)], ls="", marker="o",color="brown")

# Best-round markers for the server follow the same guard as the server
# curves above, so a FEDPER run does not draw markers for curves that were
# never plotted.
if(algorithm != "FEDPER"):
    plt.plot(epoch_range, serverAccuracyTrain,markevery=[np.argmax(serverAccuracyTrain)], ls="", marker="o",color="blue")
    plt.plot(epoch_range, serverAccuracy,markevery=[np.argmax(serverAccuracy)], ls="", marker="o",color="orange")

plt.plot(epoch_range, clientAccuracyTrain,markevery=[np.argmax(clientAccuracyTrain)], ls="", marker="o",color="green")
plt.plot(epoch_range, clientAccuracy,markevery=[np.argmax(clientAccuracy)], ls="", marker="o",color="red")

plt.title('Model accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Communication Round')
plt.legend(loc='lower right')
plt.savefig(filepath+'LearningAccuracy.png', dpi=100)
# Clear the current figure so the loss curves start from an empty canvas.
plt.clf()

# ---------------------------------------------------------------------------
# Loss figure -> <filepath>LearningLoss.png
# Same layout as the accuracy figure; the best round is the minimum here.
# ---------------------------------------------------------------------------
if(algorithm != "FEDPER"):
    plt.plot(epoch_range, serverLossTrain, label = 'Server Train')
    plt.plot(epoch_range, serverLoss, label= 'Server Test')

plt.errorbar(epoch_range, clientLossTrain, yerr=clientLossStdTrain, label='Client Own Train',alpha = 0.6)
plt.errorbar(epoch_range, clientLoss, yerr=clientLossStd, label='Client Own Test',alpha = 0.6)

if(ClientAllTest  == True):
    plt.errorbar(epoch_range, clientAllLossTrain, yerr=clientAllLossStdTrain, label='Client All Train',alpha = 0.6)
    plt.errorbar(epoch_range, clientAllLoss, yerr=clientAllLossStd, label='Client All Test',alpha = 0.6)
    plt.plot(epoch_range, clientAllLossTrain,markevery=[np.argmin(clientAllLossTrain)], ls="", marker="o",color="purple")
    plt.plot(epoch_range, clientAllLoss,markevery=[np.argmin(clientAllLoss)], ls="", marker="o",color="brown")

# Server markers are guarded exactly like the server loss curves above.
if(algorithm != "FEDPER"):
    plt.plot(epoch_range, serverLossTrain,markevery=[np.argmin(serverLossTrain)], ls="", marker="o",color="blue")
    plt.plot(epoch_range, serverLoss,markevery=[np.argmin(serverLoss)], ls="", marker="o",color="orange")

plt.plot(epoch_range, clientLossTrain,markevery=[np.argmin(clientLossTrain)], ls="", marker="o",color="green")
plt.plot(epoch_range, clientLoss,markevery=[np.argmin(clientLoss)], ls="", marker="o",color="red")

plt.title('Model Loss')
plt.ylabel('Loss')
plt.xlabel('Communication Round')
plt.legend(loc='upper right')

plt.savefig(filepath+'LearningLoss.png', dpi=100)
plt.clf()

In [None]:
# Distance plots are only produced when `euclid` is set. `epoch_range` comes
# from the learning-curve cell above.
if(euclid):
    # Figure 1: one error-bar curve per client (row i of the transposed
    # history arrays) -> <filepath>allClientEuclid.png
    for i in range(clientCount):
        plt.errorbar(epoch_range, meanHistoryDist[i], yerr=stdHistoryDist[i], label='Client '+str(i+1))
    plt.title('Distance between client & server model')
    plt.ylabel('Euclidiance Distance')
    plt.xlabel('Communication Round')
    # The legend was left disabled for this figure in the original cell.
    # If it is re-enabled it must stay before savefig to appear in the PNG.
#     plt.legend(loc='upper right')
    plt.savefig(filepath+'allClientEuclid.png', dpi=100)
    plt.clf()

    # Figure 2: one curve per layer plus the "General" curve
    # -> <filepath>LayerClientEuclid.png
    for i in range(len(layerType)):
        plt.errorbar(epoch_range, meanRoundLayerHistory[i], yerr=stdRoundLayerHistory[i], label='Layer '+str(i+1))
    plt.errorbar(epoch_range, meanRoundGeneralLayerHistory, yerr=stdRoundGeneralLayerHistory, label='General')
    plt.title('Layer distance between client & server model')
    plt.ylabel('Euclidiance Distance')
    plt.xlabel('Communication Round')
    # Draw the legend BEFORE saving: savefig writes the figure as it exists at
    # call time, so a legend added afterwards never reaches the file.
    plt.legend(loc='upper right')
    plt.savefig(filepath+'LayerClientEuclid.png', dpi=100)
    plt.clf()

In [None]:
# Store the weights returned by fedma_comm next to the other training stats.
# The trainingStats/ directory is created by the stats-saving cell above.
final_weights_path = filepath + "trainingStats/finalWeights.hkl"
hkl.dump(finalWeights, final_weights_path)

In [None]:
# Report, for each accuracy series, its maximum and the 1-based communication
# round (CR) at which that maximum occurs.
for message_template, accuracy_per_round in (
    ("Best Server accuracy at {} at CR {}", serverAccuracy),
    ("Best client client accuracy at {} at CR {}", clientAccuracy),
    ("Best all client accuracy at {} at CR {}", clientAllAccuracy),
):
    best_accuracy = accuracy_per_round.max()
    best_round = accuracy_per_round.argmax() + 1  # argmax is 0-based
    logger.info(message_template.format(best_accuracy, best_round))

In [None]:
# Summary of the run: best accuracy per series with its 1-based round, the
# macro scores (labelled F-Measure) and the round time returned by fedma_comm.
# Keys are kept exactly as they were so existing consumers of the CSV still match.
modelStatistics = {
    "Server accuracy" : roundNumber(serverAccuracy.max()),
    "Best server round:": serverAccuracy.argmax()+1,
    "Single client accuracy:" : roundNumber(clientAccuracy.max()),
    "Best single client Round:" : clientAccuracy.argmax()+1,
    "All client accuracy:": roundNumber(clientAllAccuracy.max()),
    "Best all client Round": clientAllAccuracy.argmax()+1,
    "All client F-Measure:":roundNumber(clientAllMacroTest),
    "Single client F-Measure:":roundNumber(clientMacroTest),
    "Server F-Measure:":roundNumber(serverMacroTest),
    "Time per round:":roundTime
}

# One "key,value" row per statistic. newline='' is required by the csv module
# so that the writer controls line endings; without it every row is followed
# by a blank line on platforms that translate "\n" to "\r\n".
with open(filepath +'bestAccuracyStats.csv','w', newline='') as f:
    w = csv.writer(f)
    w.writerows(modelStatistics.items())

In [None]:
modelFinalShape = {}
for i in range (np.int(np.floor(np.asarray(finalWeights[0]).shape[0]/2))):
    modelFinalShape["Layer "+str(i+1)] = np.asarray(finalWeights[0][i*2+1]).shape

In [None]:
# Write one "Layer N,shape" row per entry of modelFinalShape.
# newline='' lets the csv writer control line endings, as the csv module
# requires; without it blank rows appear on platforms that translate "\n".
with open(filepath +'modelFinalShape.csv','w', newline='') as f:
    w = csv.writer(f)
    w.writerows(modelFinalShape.items())