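测试 MXNet 在神经网络训练过程中导出梯度的方法
在调用 trainer.step() 之前，可通过 layer.weight.data().grad 获得当前梯度。
因为纯 SGD（无动量、无权重衰减）的权值更新为 weight_new = weight_old - lr * grad / batch_size，所以 old - new 应等于 grad * lr / batch_size，可据此验证导出的梯度是否正确。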

In [13]:
import mxnet as mx
import numpy as np

# Seed both RNGs for a reproducible experiment:
# - mx.random drives parameter initialization and nd.random ops;
# - NDArrayIter(shuffle=True) shuffles with NumPy's global RNG, so seeding
#   only MXNet would leave the batch order (and hence the gradients inspected
#   later) different on every run.
mx.random.seed(42)
np.random.seed(42)

# Load MNIST (downloaded on first use); indexed below by
# 'train_data' / 'train_label' / 'test_data' / 'test_label'.
mnist = mx.test_utils.get_mnist()

In [14]:
batch_size = 100

# Training samples are shuffled; validation samples keep their stored order.
train_data = mx.io.NDArrayIter(
    data=mnist['train_data'],
    label=mnist['train_label'],
    batch_size=batch_size,
    shuffle=True,
)
val_data = mx.io.NDArrayIter(
    data=mnist['test_data'],
    label=mnist['test_label'],
    batch_size=batch_size,
)

In [15]:
from mxnet import gluon
from mxnet.gluon import nn
from mxnet import autograd as ag
from mxnet import ndarray as nd

# MLP: Dense(128, relu) -> Dense(64, relu) -> Dense(10). The input size of
# the first layer is not given and is inferred on the first forward pass.
net = nn.Sequential()
with net.name_scope():
    net.add(nn.Dense(128, activation='relu'))
    net.add(nn.Dense(64, activation='relu'))
    net.add(nn.Dense(10))

# Use exactly ONE device for this gradient-inspection experiment. With several
# contexts, split_and_load gives each device only part of the batch, and
# `layer.weight.data()` / `.grad` read a single context's copy. That copy
# would then hold only a partial-batch gradient, so the check
# `old - new == grad * lr / batch_size` would not hold.
gpus = mx.test_utils.list_gpus()
ctx = [mx.gpu()] if gpus else [mx.cpu()]
net.initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)

# Plain SGD (no momentum, no weight decay): each step applies
# weight -= learning_rate * grad / batch_size, where batch_size is the value
# passed to trainer.step().
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.02})

# Dummy forward pass to finish deferred shape inference (a 1x28x28 input
# makes the first layer 784 -> 128). Left as the last expression so the
# notebook displays the output.
net(nd.random.uniform(shape=(1, 28, 28), ctx=ctx[0]))


[[-0.2462503   0.00970606 -0.17562176  0.15085515  0.08591089  0.34705222
   0.31139    -0.2723901   0.5233319   0.1465436 ]]
<NDArray 1x10 @gpu(0)>

In [16]:
epoch = 10  # NOTE(review): unused below - the loop trains on one batch only
metric = mx.metric.Accuracy()
smc_loss = gluon.loss.SoftmaxCrossEntropyLoss()
import copy

# Perform exactly one training step so that the weights before and after the
# update can be compared with the recorded gradients in later cells.
train_data.reset()
for batch in train_data:
    # Split the batch along axis 0 across the devices listed in `ctx`.
    data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
    label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
    outputs = []
    # Snapshot the whole network (weights included) BEFORE the update;
    # later cells compute oldnet - net to recover the applied step.
    oldnet = copy.deepcopy(net)
    with ag.record():
        for x,y in zip(data,label):
            z = net(x)
            loss = smc_loss(z, y)
            # Backprop of the per-sample loss vector; the resulting gradient
            # is stored in each parameter's .grad buffer.
            loss.backward()
            outputs.append(z)
    metric.update(label,outputs)
    # SGD step; the batch size passed here is what the trainer divides the
    # gradient by (lr * grad / batch_size).
    trainer.step(batch.data[0].shape[0])
    break  # only the first batch is used

name,acc = metric.get()
metric.reset()
# NOTE(review): this is the accuracy on the single batch above, not a full
# epoch, and no epoch number is formatted into the message.
print('training acc at epoch  %s=%f'%(name,acc))

training acc at epoch  accuracy=0.110000


In [17]:
# Size of the batch just trained on - the value passed to trainer.step().
batch.data[0].shape[0]

100

In [18]:
# Weights of the output Dense layer after the SGD step.
net[2].weight.data()


[[ 1.06219165e-01 -6.81807399e-02  1.66785240e-01 -8.53006691e-02
  -1.79632396e-01  1.00756489e-01  2.25179702e-01  2.09298804e-01
   5.91011718e-02 -2.03805089e-01 -1.27911448e-01 -1.01885289e-01
   1.61775470e-01 -8.95325467e-02 -1.48227587e-01  1.65068448e-01
  -1.09122686e-01  6.48109466e-02 -1.26142323e-01 -1.11803599e-01
  -1.74754485e-01  1.32522881e-01  2.21957713e-01 -8.09206888e-02
  -7.04270825e-02 -1.98768675e-01 -5.51394783e-02 -2.75089424e-02
  -1.16302967e-01 -1.01506067e-02  1.84768200e-01 -5.99520504e-02
   9.56663117e-02 -1.36495829e-01 -2.19197646e-02  3.01616453e-02
  -1.95187569e-01 -1.73100039e-01 -1.52075989e-02 -1.61726385e-01
   7.34412894e-02 -4.16175015e-02 -1.38483286e-01 -1.18397489e-01
  -1.81381062e-01  2.16985449e-01 -1.49193436e-01 -1.39406353e-01
   2.07765326e-01  1.64345443e-01  1.01500511e-01  4.61183861e-02
   2.04222649e-01  1.50770783e-01  9.36110914e-02  3.38063017e-02
  -1.04917996e-01  2.28027448e-01  1.14396214e-01 -1.71315655e-01
  -1.8519

In [19]:
# Recover the update actually applied to the output layer:
# old - new should equal lr * grad / batch_size for plain SGD.
new = net[2].weight.data()
old = oldnet[2].weight.data()
grad_test = old-new
grad_test


[[-1.43870711e-05 -7.04228878e-05  3.80873680e-05 -9.75579023e-05
   1.86562538e-05 -1.07161701e-04  4.76539135e-05  1.67474151e-04
   6.48945570e-06 -6.81728125e-05  4.06801701e-05 -1.35630369e-04
  -1.15990639e-04 -6.34565949e-05 -1.28284097e-04 -9.89437103e-06
   1.41270459e-04 -1.86249614e-04 -1.71184540e-04 -4.44874167e-05
  -1.44690275e-05 -2.14636326e-04 -1.22785568e-05  1.41784549e-05
  -1.10454857e-04 -5.71906567e-05 -8.21426511e-06  7.39675015e-05
   4.51058149e-05  7.05225393e-05 -3.09675932e-04  1.51991844e-06
   3.63536179e-04 -6.58333302e-05  9.22605395e-05 -1.98926777e-04
   1.69277191e-05  3.57776880e-05  5.41610643e-05 -5.49852848e-06
  -3.77744436e-06  3.53343785e-05  2.99215317e-05 -1.35600567e-06
   1.69053674e-04 -6.09904528e-05  2.50011683e-04  1.94907188e-05
   1.14813447e-04 -1.12652779e-05 -6.62803650e-05  1.62646174e-05
  -2.18749046e-05  8.55326653e-05 -7.51018524e-05  6.44251704e-05
   9.58517194e-05  7.83950090e-05  3.25143337e-05  8.62777233e-06
   0.0000

In [20]:
# Expected SGD update lr * grad / batch_size for the output layer. The learning
# rate and batch size come from the trainer and from the batch actually
# passed to trainer.step(), instead of repeating the literals 0.02 and 100.
# This should match grad_test up to small float32 rounding differences.
net[2].weight.data().grad * trainer.learning_rate / batch.data[0].shape[0]


[[-1.43849966e-05 -7.04251433e-05  3.80922902e-05 -9.75586445e-05
   1.86622492e-05 -1.07162428e-04  4.76535497e-05  1.67473423e-04
   6.48813011e-06 -6.81742895e-05  4.06838517e-05 -1.35629452e-04
  -1.15983676e-04 -6.34569151e-05 -1.28288695e-04 -9.89085311e-06
   1.41273573e-04 -1.86247431e-04 -1.71177089e-04 -4.44868492e-05
  -1.44656478e-05 -2.14643514e-04 -1.22851961e-05  1.41790797e-05
  -1.10455177e-04 -5.71845740e-05 -8.21439062e-06  7.39668903e-05
   4.51038468e-05  7.05223283e-05 -3.09681171e-04  1.52023870e-06
   3.63536208e-04 -6.58347708e-05  9.22613763e-05 -1.98926427e-04
   1.69253108e-05  3.57726494e-05  5.41611189e-05 -5.49622928e-06
  -3.77385618e-06  3.53351243e-05  2.99187159e-05 -1.35909318e-06
   1.69049934e-04 -6.09958879e-05  2.50013924e-04  1.94901768e-05
   1.14815251e-04 -1.12708985e-05 -6.62825405e-05  1.62663455e-05
  -2.18713121e-05  8.55277467e-05 -7.51052212e-05  6.44254542e-05
   9.58480159e-05  7.83939176e-05  3.25124856e-05  8.63251353e-06
   0.0000

In [21]:
# First layer; its input size has been resolved by the earlier dummy forward pass.
net[0]

Dense(784 -> 128, Activation(relu))

In [22]:
# Gradient of the first Dense layer's weights (128x784).
# NOTE(review): the printed corner entries are all zero; presumably those
# input columns are MNIST border pixels, which are zero in every image - confirm.
net[0].weight.data().grad


[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
<NDArray 128x784 @gpu(0)>

In [23]:
# Gradient of the second Dense layer's weights (64x128).
net[1].weight.data().grad


[[ 1.00654036e-01  5.50387979e-01  1.31264895e-01 ...  3.24882507e-01
   4.18776751e-01  1.90174177e-01]
 [ 6.96533918e-01  3.26293588e-01 -4.71295476e-01 ... -4.15953845e-01
   7.81149328e-01  4.42714900e-01]
 [-9.49838981e-02 -1.18255347e-01  1.17416136e-01 ...  2.28661239e-01
  -2.25636899e-01 -4.37596887e-02]
 ...
 [ 1.13742739e-01  8.75184089e-02 -2.92807817e-06 ... -1.64443731e-01
   4.40927148e-01  5.63036025e-01]
 [-1.84871554e-01 -2.11241156e-01 -9.79393870e-02 ... -5.06389141e-01
  -7.99530983e-01 -1.12881780e-01]
 [ 1.73685074e-01 -2.31371611e-01  5.14518246e-02 ... -2.14173257e-01
   4.38343346e-01  3.79437983e-01]]
<NDArray 64x128 @gpu(0)>

In [24]:
# Raw (unscaled) gradient of the output layer's weights; multiplied by
# lr / batch_size it reproduces grad_test.
net[2].weight.data().grad


[[-7.19249845e-02 -3.52125704e-01  1.90461457e-01 -4.87793237e-01
   9.33112502e-02 -5.35812140e-01  2.38267764e-01  8.37367117e-01
   3.24406512e-02 -3.40871453e-01  2.03419268e-01 -6.78147316e-01
  -5.79918385e-01 -3.17284554e-01 -6.41443491e-01 -4.94542681e-02
   7.06367910e-01 -9.31237161e-01 -8.55885506e-01 -2.22434238e-01
  -7.23282397e-02 -1.07321763e+00 -6.14259839e-02  7.08954036e-02
  -5.52275896e-01 -2.85922885e-01 -4.10719514e-02  3.69834453e-01
   2.25519240e-01  3.52611661e-01 -1.54840589e+00  7.60119362e-03
   1.81768107e+00 -3.29173863e-01  4.61306900e-01 -9.94632125e-01
   8.46265554e-02  1.78863257e-01  2.70805597e-01 -2.74811462e-02
  -1.88692808e-02  1.76675633e-01  1.49593592e-01 -6.79546595e-03
   8.45249653e-01 -3.04979444e-01  1.25006974e+00  9.74508896e-02
   5.74076295e-01 -5.63544929e-02 -3.31412703e-01  8.13317299e-02
  -1.09356560e-01  4.27638769e-01 -3.75526130e-01  3.22127283e-01
   4.79240060e-01  3.91969621e-01  1.62562430e-01  4.31625694e-02
   0.0000