## Federated Averaging 模拟实验

In [1]:
import mxnet as mx
from mxnet import gluon
from mxnet.gluon import nn
from mxnet import autograd as ag
from mxnet import ndarray as nd
from Algorithm.CNN import CNN_Model
from Algorithm.MLP import MLP
from Tools import utils
import numpy as np
import copy
import random

In [2]:
def get_data(data_idx, batch_size):
    """Build a shuffled training iterator for one client's data shard.

    Args:
        data_idx: Shard identifier; it is appended to the ``train_data`` and
            ``train_label`` file stems to pick the ``.npy`` files to load.
        batch_size: Batch size handed to the iterator.

    Returns:
        An ``mx.io.NDArrayIter`` over the loaded samples and labels with
        shuffling enabled.
    """
    shard_dir = "E:\\PythonProjects\\Mxnet_FederatedLearning\\Fed_Client\\FedAvg_data\\"
    shard_suffix = str(data_idx) + ".npy"
    samples = np.load(shard_dir + "train_data" + shard_suffix)
    targets = np.load(shard_dir + "train_label" + shard_suffix)
    return mx.io.NDArrayIter(samples, targets, batch_size=batch_size, shuffle=True)

def validate(net):
    """Return the accuracy of ``net`` on the MNIST test split.

    The test set is fetched with ``mx.test_utils.get_mnist()`` and evaluated
    in batches of 100; each batch is split across the devices returned by
    ``utils.try_all_gpus()``.

    Args:
        net: Callable model applied to every per-device slice of a batch.

    Returns:
        The value reported by ``mx.metric.Accuracy`` after all batches.
    """
    mnist = mx.test_utils.get_mnist()
    devices = utils.try_all_gpus()
    test_iter = mx.io.NDArrayIter(
        mnist['test_data'], mnist['test_label'], batch_size=100)
    accuracy = mx.metric.Accuracy()
    test_iter.reset()
    for batch in test_iter:
        inputs = gluon.utils.split_and_load(
            batch.data[0], ctx_list=devices, batch_axis=0)
        targets = gluon.utils.split_and_load(
            batch.label[0], ctx_list=devices, batch_axis=0)
        predictions = [net(part) for part in inputs]
        accuracy.update(targets, predictions)
    # metric.get() yields a (name, value) pair; only the value is returned.
    _, score = accuracy.get()
    return score

In [3]:
def train(net, batch_size, learning_rate, client_id, epoch, ctx):
    """Train ``net`` with SGD on one client's local data.

    Args:
        net: Gluon model whose parameters are optimised.
        batch_size: Batch size used when building the client's data iterator.
        learning_rate: Learning rate passed to the SGD trainer.
        client_id: Identifier forwarded to ``get_data`` to select the shard.
        epoch: Number of passes over the client's data.
        ctx: List of devices each batch is split across.

    Returns:
        None. Training accuracy is printed once per epoch.
    """
    client_iter = get_data(client_id, batch_size)
    loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
    optimizer = gluon.Trainer(
        net.collect_params(), 'sgd', {'learning_rate': learning_rate})
    accuracy = mx.metric.Accuracy()
    for epoch_idx in range(epoch):
        client_iter.reset()
        for batch in client_iter:
            inputs = gluon.utils.split_and_load(
                batch.data[0], ctx_list=ctx, batch_axis=0)
            targets = gluon.utils.split_and_load(
                batch.label[0], ctx_list=ctx, batch_axis=0)
            predictions = []
            with ag.record():
                for features, labels in zip(inputs, targets):
                    scores = net(features)
                    loss = loss_fn(scores, labels)
                    # NOTE(review): prints the raw loss array for every device
                    # slice of every batch; looks like leftover debugging
                    # output - confirm before removing.
                    print(loss)
                    loss.backward()
                    predictions.append(scores)
            accuracy.update(targets, predictions)
            # The step is given the number of samples in the batch.
            optimizer.step(batch.data[0].shape[0])
        name, acc = accuracy.get()
        accuracy.reset()
        print('training acc at epoch %d: %s=%f'%(epoch_idx, name, acc))

In [4]:
def zero_net(net,ctx):
    """Initialise ``net`` on ``ctx`` and zero the weight and bias of each layer.

    The network is initialised with Xavier and run once on a dummy
    ``(1, 28, 28)`` input (presumably so Gluon's deferred parameter shapes are
    resolved before they are overwritten). Every directly contained layer
    that exposes a ``weight`` and/or ``bias`` parameter then has that
    parameter filled with zeros on every device it lives on. Layers without
    such parameters are skipped.

    Args:
        net: Iterable Gluon container of layers (modified in place).
        ctx: List of devices; the dummy input is created on ``ctx[0]``.

    Returns:
        None.
    """
    net.initialize(mx.init.Xavier(),ctx=ctx)
    probe = nd.random.uniform(0,255,shape=(1,28,28),ctx=ctx[0])
    net(probe)
    for layer in net:
        for attr in ("weight", "bias"):
            # Explicit lookup instead of a bare ``except``: parameter-free
            # layers (or a missing bias) are skipped, real errors surface.
            param = getattr(layer, attr, None)
            if param is None:
                continue
            # list_data() yields the copy held on each device, so the
            # zeroing also works when ctx contains more than one device.
            for device_copy in param.list_data():
                device_copy[:] = 0
        
def merge_net(client_list):
    """Average the saved parameters of the given clients into a new model.

    For every id in ``client_list`` the file ``<id>.params`` is loaded from
    the experiment directory. The ``weight`` and ``bias`` of each layer are
    summed layer by layer across clients and divided by the number of
    clients (plain, unweighted mean).

    Args:
        client_list: Non-empty sequence of client ids whose parameter files
            are merged.

    Returns:
        A deep copy of an ``MLP`` holding the averaged parameters.

    Raises:
        ValueError: If ``client_list`` is empty (the mean would be undefined).
    """
    num = len(client_list)
    if num == 0:
        raise ValueError("client_list must contain at least one client id")
    print("合并模型 Client_list: ",client_list)
    # Merge the parameters of the models referenced by client_list.
    path = "E:\\PythonProjects\\Mxnet_FederatedLearning\\FedAvgExp\\"
    ctx = utils.try_all_gpus()
    # Swap both MLP() calls for CNN_Model('LeNet') to merge LeNet clients.
    mnet = MLP()
    zero_net(mnet,ctx)
    tmp = MLP()

    # Running per-parameter sums, keyed by (layer index, attribute name).
    # The key carries the layer index so each client layer is accumulated
    # into the matching merged layer rather than always into layer 0.
    totals = {}
    for client_id in client_list:
        tmp.load_parameters(path + str(client_id) + ".params",ctx=ctx)
        for idx, layer in enumerate(tmp):
            for attr in ("weight", "bias"):
                param = getattr(layer, attr, None)
                if param is None:
                    # Layer has no such parameter (e.g. no bias); skip it.
                    continue
                value = param.data(ctx[0])
                key = (idx, attr)
                if key in totals:
                    totals[key] = totals[key] + value
                else:
                    # Copy: the next load_parameters call reuses tmp, so the
                    # sum must not alias tmp's parameter storage.
                    totals[key] = value.copy()

    for (idx, attr), total in totals.items():
        # set_data writes the averaged value to every device copy.
        getattr(mnet[idx], attr).set_data(total / num)
    return copy.deepcopy(mnet)

def Fed_Avg(net, Class, BatchSize, Epoch, LearningRate):
    """Run one pass of the Federated Averaging simulation over all clients.

    The 100 client ids are shuffled and partitioned into consecutive groups
    of ``int(100 * Class)`` clients. For each group (one round), every client
    starts from the current global parameters in ``update.params``, trains
    locally, and saves its result to ``<client_id>.params``. The group's
    models are then averaged with ``merge_net``, validated, and written back
    to ``update.params`` as the new global model.

    Args:
        net: Model used for local training; ``update.params`` must already
            exist and be loadable into it.
        Class: Fraction of clients merged per round.
        BatchSize: Local batch size passed to ``train``.
        Epoch: Local epochs passed to ``train``.
        LearningRate: Local learning rate passed to ``train``.

    Returns:
        None. Progress and validation accuracy are printed.

    Raises:
        ValueError: If ``Class`` is so small that a round would contain no
            clients.
    """
    exp_dir = "E:\\PythonProjects\\Mxnet_FederatedLearning\\FedAvgExp\\"
    global_params = exp_dir + "update.params"

    total_client = 100
    merge_client_num = int(total_client*Class)
    if merge_client_num < 1:
        raise ValueError(
            "Class=%r selects no clients per round; it must be at least 1/%d"
            % (Class, total_client))
    round_num = total_client // merge_client_num
    merge_list = list(range(total_client))
    random.shuffle(merge_list)  # shuffle the client id list
    ctx = utils.try_all_gpus()
    for round_idx in range(round_num):  # one group of clients per round
        print("round %d"%round_idx)
        start = round_idx*merge_client_num
        round_list = merge_list[start: start + merge_client_num]
        for client_id in round_list:
            print("running on Client %d"%client_id)
            # Reload the global model for every client so each one trains
            # from the same starting point instead of continuing from the
            # previous client's weights.
            net.load_parameters(global_params,ctx=ctx)
            train(net, BatchSize, LearningRate, client_id, Epoch,ctx)
            net.save_parameters(exp_dir + str(client_id) +".params")
        # From here on ``net`` is the merged model returned by merge_net.
        net = merge_net(round_list)
        acc = validate(net)
        print("验证集准确率：%f"%acc)
        net.save_parameters(global_params)