测试梯度更新算法可行性
以及for循环遍历神经网络层可行性

In [1]:
from mxnet import gluon
from mxnet.gluon import nn
from mxnet import autograd as ag
from mxnet import ndarray as nd
from mxnet.gluon import loss 

def LeNet_(activation='relu'):
    # 获取一个结构定义完整的LeNet卷积神经网络
    # 激活函数可自选 默认为sigmoid
    net = nn.Sequential()
    net.add(nn.Conv2D(channels=6, kernel_size=(5,5), activation=activation),
            nn.MaxPool2D(pool_size=(2,2), strides=(2,2)),
            nn.Conv2D(channels=16, kernel_size=(5,5), activation=activation),
            nn.MaxPool2D(pool_size=(2,2), strides=(2,2)),
            # Dense会默认将(批量大小， 通道， 高， 宽)形状的输入转换成
            # (批量大小， 通道 * 高 * 宽)形状的输入
            nn.Dense(120, activation=activation),
            nn.Dense(84, activation=activation),
            nn.Dense(10))
    return net

In [2]:
import mxnet as mx
net = LeNet_()
input_shape = (1,1,28,28)
ctx = [mx.gpu()]
mx.random.seed(42)
net.initialize(mx.init.Xavier(magnitude=2.24),ctx=ctx)
_ = net(nd.random.uniform(shape=input_shape,ctx=ctx[0]))

In [3]:
#验证
#val_x,val_y = val_data_set[0],val_data_set[1]
#val_data = mx.io.NDArrayIter(val_x,val_y,batch_size=100)
mnist = mx.test_utils.get_mnist()
val_data = mx.io.NDArrayIter(mnist['test_data'],mnist['test_label'],batch_size=100)    
for batch in val_data:
    data = gluon.utils.split_and_load(batch.data[0],ctx_list=ctx,batch_axis=0)
    label = gluon.utils.split_and_load(batch.label[0],ctx_list=ctx,batch_axis=0)
    outputs = []
    metric = mx.metric.Accuracy()
    for x in data:
        outputs.append(net(x))
    metric.update(label,outputs)
print('验证集准确率 validation acc:%s=%f'%metric.get())

验证集准确率 validation acc:accuracy=0.100000


In [4]:

def init_gradient(net,local_gradient=[]):
    local_gradient['weight'].clear()
    local_gradient['bias'].clear()
    for layer in net:
        try:
            shape_w = layer.weight.data().shape
            shape_b = layer.bias.data().shape
        except:
            continue
        local_gradient['weight'].append(nd.zeros(shape=shape_w,ctx=ctx[0]))
        local_gradient['bias'].append(nd.zeros(shape=shape_b,ctx=ctx[0]))
    
def collect_gradient(net, local_gradient,batch_size):
    idx = 0
    for layer in net:
        try:
            grad_w = layer.weight.data().grad
            grad_b = layer.bias.data().grad
        except:
            continue
        local_gradient['weight'][idx] = local_gradient['weight'][idx] + grad_w.as_in_context(local_gradient['weight'][idx].context)/batch_size
        local_gradient['bias'][idx] = local_gradient['bias'][idx] + grad_b.as_in_context(local_gradient['bias'][idx].context)/batch_size
        idx+=1

def updata_gradient(net,gradient_info,learning_rate):
    # 由Client回传的梯度信息 更新Server模型
    idx = 0
    grad_w = gradient_info['weight']
    grad_b = gradient_info['bias']
    update_flag = False
    for layer in net:
        try:
            layer.weight.data()[:] = layer.weight.data()[:] - learning_rate*grad_w[idx]
            #layer.weight.set_data(layer.weight.data()[:] - learning_rate*gradient_info[idx])
            layer.bias.data()[:] = layer.bias.data()[:] - learning_rate*grad_b[idx]
        except:
            continue
        idx += 1
    """
    if update_flag:
        print("-gradient successfully updated-")
    else:
        print("-gradient failure-")
    """


In [5]:
#训练
train_data = mx.io.NDArrayIter(mnist['train_data'],mnist['train_label'],batch_size=100) 
epoch = 10
metric = mx.metric.Accuracy()
smc_loss = gluon.loss.SoftmaxCrossEntropyLoss()
trainer = gluon.Trainer(net.collect_params(),'sgd',{'learning_rate':0.02})
#初始化梯度
gradient_info = {'weight':[],'bias':[]}


In [6]:
# 正常训练
for i in range(epoch):
    train_data.reset()
    for batch in train_data:
        init_gradient(net,gradient_info)
        data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
        label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
        outputs = []
        with ag.record():
            for x,y in zip(data,label):
                z = net(x)
                loss = smc_loss(z, y)
                loss.backward()
                outputs.append(z)
        #收集梯度
        collect_gradient(net,gradient_info,batch_size=batch.data[0].shape[0])
        updata_gradient(net,gradient_info,learning_rate=0.02)
        metric.update(label,outputs)
        #trainer.step(batch.data[0].shape[0])
    name,acc = metric.get()
    metric.reset()
    print('training acc at epoch %d, %s=%f'%(i,name,acc))
    # YA DA ZE

training acc at epoch 0, accuracy=0.702617
training acc at epoch 1, accuracy=0.939650
training acc at epoch 2, accuracy=0.959250
training acc at epoch 3, accuracy=0.967967
training acc at epoch 4, accuracy=0.973167
training acc at epoch 5, accuracy=0.976783
training acc at epoch 6, accuracy=0.979517
training acc at epoch 7, accuracy=0.981517
training acc at epoch 8, accuracy=0.983017
training acc at epoch 9, accuracy=0.984200


In [6]:
# 梯度采集 模型更新测试
train_data.reset()
for batch in train_data:
    init_gradient(net,gradient_info)
    data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
    label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
    outputs = []
    with ag.record():
        for x,y in zip(data,label):
            z = net(x)
            loss = smc_loss(z, y)
            loss.backward()
            outputs.append(z)
    #收集梯度
    collect_gradient(net,gradient_info,batch_size=batch.data[0].shape[0])
    #updata_gradient(gradient_info,learning_rate=0.02)
    metric.update(label,outputs)
    trainer.step(batch.data[0].shape[0])

name,acc = metric.get()
metric.reset()
print('training acc at epoch %s=%f'%(name,acc))

training acc at epoch accuracy=0.702000


In [7]:
val_data.reset()
for batch in val_data:
    data = gluon.utils.split_and_load(batch.data[0],ctx_list=ctx,batch_axis=0)
    label = gluon.utils.split_and_load(batch.label[0],ctx_list=ctx,batch_axis=0)
    outputs = []
    metric = mx.metric.Accuracy()
    for x in data:
        outputs.append(net(x))
    metric.update(label,outputs)
print('验证集准确率 validation acc:%s=%f'%metric.get())

验证集准确率 validation acc:accuracy=0.990000
