In [4]:
import mxnet as mx
import logging
logging.getLogger().setLevel(logging.DEBUG)  

mnist = mx.test_utils.get_mnist()
ctx = mx.gpu() if mx.test_utils.list_gpus() else mx.cpu()

batch_size = 128
train_iter = mx.io.NDArrayIter(mnist['train_data'], mnist['train_label'], batch_size, shuffle=True)
val_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size)

data = mx.sym.var('data')
data = mx.sym.flatten(data=data)

fc1  = mx.sym.FullyConnected(data=data, num_hidden = 512)
act1 = mx.sym.Activation(data=fc1, act_type="relu")

fc2  = mx.sym.FullyConnected(data=act1, num_hidden = 512)
act2 = mx.sym.Activation(data=fc2, act_type="relu")

fc3  = mx.sym.FullyConnected(data=act2, num_hidden=10)
mlp  = mx.sym.SoftmaxOutput(data=fc3, name='softmax')

mlp_model = mx.mod.Module(symbol=mlp, context=ctx)


import matplotlib.pyplot as plt

mx.viz.plot_network(mlp).view()  


mlp_model.fit(train_iter, 
              eval_data=val_iter,  
              optimizer='sgd',  
              optimizer_params={'learning_rate':0.1}, 
              eval_metric='acc', 
              batch_end_callback = mx.callback.Speedometer(batch_size, 100),
              num_epoch=20) 

test_iter = mx.io.NDArrayIter(mnist['test_data'], mnist['test_label'], batch_size)
acc = mx.metric.Accuracy()
mlp_model.score(test_iter, acc)
print(acc)

INFO:root:train-labels-idx1-ubyte.gz exists, skipping download
INFO:root:train-images-idx3-ubyte.gz exists, skipping download
INFO:root:t10k-labels-idx1-ubyte.gz exists, skipping download
INFO:root:t10k-images-idx3-ubyte.gz exists, skipping download
INFO:root:Epoch[0] Batch [100]	Speed: 10419.20 samples/sec	accuracy=0.114093
INFO:root:Epoch[0] Batch [200]	Speed: 9022.10 samples/sec	accuracy=0.151953
INFO:root:Epoch[0] Batch [300]	Speed: 11532.59 samples/sec	accuracy=0.399531
INFO:root:Epoch[0] Batch [400]	Speed: 12049.41 samples/sec	accuracy=0.721797
INFO:root:Epoch[0] Train-accuracy=0.792969
INFO:root:Epoch[0] Time cost=5.601
INFO:root:Epoch[0] Validation-accuracy=0.810819
INFO:root:Epoch[1] Batch [100]	Speed: 13381.66 samples/sec	accuracy=0.840347
INFO:root:Epoch[1] Batch [200]	Speed: 12863.71 samples/sec	accuracy=0.865078
INFO:root:Epoch[1] Batch [300]	Speed: 12712.07 samples/sec	accuracy=0.880000
INFO:root:Epoch[1] Batch [400]	Speed: 14564.47 samples/sec	accuracy=0.895078
INFO:root

INFO:root:Epoch[18] Batch [200]	Speed: 14403.11 samples/sec	accuracy=0.995234
INFO:root:Epoch[18] Batch [300]	Speed: 14243.63 samples/sec	accuracy=0.995625
INFO:root:Epoch[18] Batch [400]	Speed: 15612.54 samples/sec	accuracy=0.996250
INFO:root:Epoch[18] Train-accuracy=0.996783
INFO:root:Epoch[18] Time cost=4.087
INFO:root:Epoch[18] Validation-accuracy=0.978145
INFO:root:Epoch[19] Batch [100]	Speed: 13764.55 samples/sec	accuracy=0.995204
INFO:root:Epoch[19] Batch [200]	Speed: 14664.64 samples/sec	accuracy=0.996094
INFO:root:Epoch[19] Batch [300]	Speed: 13529.28 samples/sec	accuracy=0.996328
INFO:root:Epoch[19] Batch [400]	Speed: 13847.31 samples/sec	accuracy=0.996719
INFO:root:Epoch[19] Train-accuracy=0.997013
INFO:root:Epoch[19] Time cost=4.309
INFO:root:Epoch[19] Validation-accuracy=0.978738


EvalMetric: {'accuracy': 0.9787381329113924}


In [4]:
mlp_model.save_checkpoint('models/mxnet.parms',20) 

INFO:root:Saved checkpoint to "models/mxnet.parms-0020.params"
