In [1]:
import functools
import os
import random

import mxnet as mx
import numpy as np
from mxnet import autograd as ag
from mxnet import gluon
from mxnet import ndarray as nd
from mxnet.gluon import nn

from Algorithm.CNN import CNN_Model
from Algorithm.MLP import MLP
from Tools import utils

In [2]:
# Devices for training/evaluation. It is passed as ctx_list to split_and_load,
# so presumably a list of mx.Context (GPUs when available) — TODO confirm in Tools.utils.
ctx = utils.try_all_gpus()
# Candidate models, constructed in the original order (MLP first, then LeNet).
# Only `lenet` is trained in the cells shown here.
mlp, lenet = MLP(), CNN_Model('LeNet')


# Directory holding the per-client shards train_data<k>.npy / train_label<k>.npy.
# Pass `data_dir=` to get_data instead of editing this machine-specific path.
FED_DATA_DIR = "E:\\PythonProjects\\Mxnet_FederatedLearning\\Fed_Client\\FedAvg_data_Non-IID"


def get_data(data_idx, batch_size, data_dir=None, shuffle=True):
    """Build a training iterator over one client's data shard.

    Parameters
    ----------
    data_idx : int
        Client index; selects ``train_data<data_idx>.npy`` and
        ``train_label<data_idx>.npy`` inside ``data_dir``.
    batch_size : int
        Mini-batch size of the returned iterator.
    data_dir : str, optional
        Directory containing the shards. Defaults to ``FED_DATA_DIR``.
    shuffle : bool, optional
        Whether the iterator shuffles samples (default True, as before).

    Returns
    -------
    mx.io.NDArrayIter
        Iterator yielding (data, label) batches from the shard.
    """
    if data_dir is None:
        data_dir = FED_DATA_DIR
    data_path = os.path.join(data_dir, "train_data" + str(data_idx) + ".npy")
    label_path = os.path.join(data_dir, "train_label" + str(data_idx) + ".npy")
    features = np.load(data_path)
    labels = np.load(label_path)
    return mx.io.NDArrayIter(features, labels, batch_size=batch_size, shuffle=shuffle)

@functools.lru_cache(maxsize=None)
def _mnist_test_arrays():
    """Load the MNIST test split once per kernel session and reuse it on later calls."""
    mnist = mx.test_utils.get_mnist()
    return mnist['test_data'], mnist['test_label']


def validate(net):
    """Return `net`'s classification accuracy on the MNIST test split.

    The test arrays are loaded once and cached. This function is called after
    every training round, and re-reading the dataset each time was pure overhead.

    Parameters
    ----------
    net : gluon Block
        Model to evaluate. Batches are split across the global `ctx` device list.

    Returns
    -------
    float
        Accuracy reported by ``mx.metric.Accuracy``.
    """
    test_data, test_label = _mnist_test_arrays()
    # A freshly constructed iterator is already at the start; no reset() needed.
    val_data = mx.io.NDArrayIter(test_data, test_label, batch_size=100)
    metric = mx.metric.Accuracy()
    for batch in val_data:
        data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
        label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
        outputs = [net(x) for x in data]
        metric.update(label, outputs)
    _, acc = metric.get()
    return acc

In [3]:
def _train_local(net, train_data, trainer, loss_fn, epoch):
    """Run `epoch` passes of SGD over one client's iterator, printing per-epoch training accuracy."""
    metric = mx.metric.Accuracy()
    for i in range(epoch):
        train_data.reset()
        for batch in train_data:
            data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
            label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
            outputs = []
            with ag.record():
                for x, y in zip(data, label):
                    z = net(x)
                    loss = loss_fn(z, y)
                    loss.backward()
                    outputs.append(z)
            metric.update(label, outputs)
            # Normalise the update by the full (pre-split) batch size.
            trainer.step(batch.data[0].shape[0])
        name, acc = metric.get()
        metric.reset()
        print('training acc at epoch %d: %s=%f' % (i, name, acc))


def train_till_acc(net, stop_acc, learning_rate, batch_size, epoch,
                   num_clients=100, max_rounds=None):
    """Train `net` one client shard at a time until test accuracy reaches `stop_acc`.

    Each round loads one client's shard (`get_data`) and trains on it for
    `epoch` epochs. The model is then evaluated on the MNIST test set
    (`validate`). Clients are visited in a freshly shuffled order on each pass.

    Parameters
    ----------
    net : gluon Block
        Model to train in place.
    stop_acc : float
        Target test accuracy; training stops as soon as it is reached.
    learning_rate : float
        SGD learning rate.
    batch_size : int
        Local mini-batch size.
    epoch : int
        Local epochs per client round.
    num_clients : int, optional
        Number of client shards, indexed 0..num_clients-1 (default 100).
    max_rounds : int or None, optional
        Stop after this many rounds even if `stop_acc` was not reached.
        None (the default) means no limit.

    Returns
    -------
    float or None
        Test accuracy after the last completed round, or None if no round ran.

    Raises
    ------
    ValueError
        If `num_clients` is not positive; otherwise the loop would spin forever.
    """
    if num_clients <= 0:
        raise ValueError("num_clients must be positive, got %r" % (num_clients,))
    # Loop-invariant: plain SGD keeps no optimizer state, so a single Trainer
    # behaves like the per-round instances the original code created.
    loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()
    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': learning_rate})
    clients = list(range(num_clients))
    rounds = 0
    acc = None
    while True:
        random.shuffle(clients)
        for client in clients:
            if max_rounds is not None and rounds >= max_rounds:
                return acc
            train_data = get_data(client, batch_size)
            print("Round %d starting" % (rounds))
            _train_local(net, train_data, trainer, loss_fn, epoch)
            acc = validate(net)
            print("accuracy in test set %f" % (acc))
            if acc >= stop_acc:
                # A bare `break` here only left the client loop, and the outer
                # `while True` kept training forever. Return instead.
                return acc
            rounds += 1
        print("round: ", rounds)

In [4]:
# Warm-start LeNet from the saved weights onto the context(s) in `ctx`, then
# train client by client until MNIST test accuracy reaches 97%.
lenet.load_parameters("LeNet.params", ctx=ctx)
train_till_acc(lenet, stop_acc=0.97, learning_rate=0.01, batch_size=600, epoch=1)

Round 2776 starting
training acc at epoch 0: accuracy=0.983333
accuracy in test set 0.963800
Round 2777 starting
training acc at epoch 0: accuracy=0.983333
accuracy in test set 0.962600
Round 2778 starting
training acc at epoch 0: accuracy=0.936667
accuracy in test set 0.965500
Round 2779 starting
training acc at epoch 0: accuracy=0.963333
accuracy in test set 0.967100
Round 2780 starting
training acc at epoch 0: accuracy=0.941667
accuracy in test set 0.968600
Round 2781 starting
training acc at epoch 0: accuracy=0.953333
accuracy in test set 0.966100
Round 2782 starting
training acc at epoch 0: accuracy=0.975000
accuracy in test set 0.966600
Round 2783 starting
training acc at epoch 0: accuracy=0.983333
accuracy in test set 0.964700
Round 2784 starting
training acc at epoch 0: accuracy=0.940000
accuracy in test set 0.966100
Round 2785 starting
training acc at epoch 0: accuracy=0.978333
accuracy in test set 0.968200
Round 2786 starting
training acc at epoch 0: accuracy=0.970000
accuracy in

KeyboardInterrupt: 