In [1]:
from __future__ import print_function
import numpy as np
import mxnet as mx
from mxnet import nd, autograd, gluon
mx.random.seed(1)

In [2]:
ctx = mx.gpu()

In [3]:
batch_size = 64
num_inputs = 784
num_outputs = 10
def transform(data, label):
    return nd.transpose(data.astype(np.float32), (2,0,1))/255, label.astype(np.float32)

train_data = gluon.data.DataLoader(gluon.data.vision.MNIST(train=True, transform=transform),
                                      batch_size, shuffle=True)
test_data = gluon.data.DataLoader(gluon.data.vision.MNIST(train=False, transform=transform),
                                     batch_size, shuffle=False)


  label = np.fromstring(fin.read(), dtype=np.uint8).astype(np.int32)
  data = np.fromstring(fin.read(), dtype=np.uint8)


In [4]:
num_fc = 512
net = gluon.nn.Sequential()
with net.name_scope():
    net.add(gluon.nn.Conv2D(channels=20, kernel_size=5, activation='relu'))
    net.add(gluon.nn.MaxPool2D(pool_size=2, strides=2))
    net.add(gluon.nn.Conv2D(channels=50, kernel_size=5, activation='relu'))
    net.add(gluon.nn.MaxPool2D(pool_size=2, strides=2))
    # The Flatten layer collapses all axis, except the first one, into one axis.
    net.add(gluon.nn.Flatten())
    net.add(gluon.nn.Dense(num_fc, activation="relu"))
    net.add(gluon.nn.Dense(num_outputs))

In [5]:
net.collect_params().initialize(mx.init.Xavier(magnitude=2.24), ctx=ctx)

In [6]:
softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()

In [7]:
trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': .1})

In [14]:
def evaluate_accuracy(data_iterator, net):
    acc = mx.metric.Accuracy()
    for i, (data, label) in enumerate(data_iterator):
        data = data.as_in_context(ctx)
        label = label.as_in_context(ctx)
        output = net(data)
        predictions = nd.argmax(output, axis=1)
        acc.update(preds=predictions, labels=label)
        print(i, type(data))
    return acc.get()[1]

In [15]:
epochs = 1
smoothing_constant = .01

for e in range(epochs):
    for i, (data, label) in enumerate(train_data):
        data = data.as_in_context(ctx)
        label = label.as_in_context(ctx)
        with autograd.record():
            output = net(data)
            loss = softmax_cross_entropy(output, label)
        loss.backward()
        trainer.step(data.shape[0])

        ##########################
        #  Keep a moving average of the losses
        ##########################
        curr_loss = nd.mean(loss).asscalar()
        moving_loss = (curr_loss if ((i == 0) and (e == 0))
                       else (1 - smoothing_constant) * moving_loss + smoothing_constant * curr_loss)

    test_accuracy = evaluate_accuracy(test_data, net)
    train_accuracy = evaluate_accuracy(train_data, net)
    print("Epoch %s. Loss: %s, Train_acc %s, Test_acc %s" % (e, moving_loss, train_accuracy, test_accuracy))

0 <class 'mxnet.ndarray.ndarray.NDArray'>
1 <class 'mxnet.ndarray.ndarray.NDArray'>
2 <class 'mxnet.ndarray.ndarray.NDArray'>
3 <class 'mxnet.ndarray.ndarray.NDArray'>
4 <class 'mxnet.ndarray.ndarray.NDArray'>
5 <class 'mxnet.ndarray.ndarray.NDArray'>
6 <class 'mxnet.ndarray.ndarray.NDArray'>
7 <class 'mxnet.ndarray.ndarray.NDArray'>
8 <class 'mxnet.ndarray.ndarray.NDArray'>
9 <class 'mxnet.ndarray.ndarray.NDArray'>
10 <class 'mxnet.ndarray.ndarray.NDArray'>
11 <class 'mxnet.ndarray.ndarray.NDArray'>
12 <class 'mxnet.ndarray.ndarray.NDArray'>
13 <class 'mxnet.ndarray.ndarray.NDArray'>
14 <class 'mxnet.ndarray.ndarray.NDArray'>
15 <class 'mxnet.ndarray.ndarray.NDArray'>
16 <class 'mxnet.ndarray.ndarray.NDArray'>
17 <class 'mxnet.ndarray.ndarray.NDArray'>
18 <class 'mxnet.ndarray.ndarray.NDArray'>
19 <class 'mxnet.ndarray.ndarray.NDArray'>
20 <class 'mxnet.ndarray.ndarray.NDArray'>
21 <class 'mxnet.ndarray.ndarray.NDArray'>
22 <class 'mxnet.ndarray.ndarray.NDArray'>
23 <class 'mxnet.ndar

34 <class 'mxnet.ndarray.ndarray.NDArray'>
35 <class 'mxnet.ndarray.ndarray.NDArray'>
36 <class 'mxnet.ndarray.ndarray.NDArray'>
37 <class 'mxnet.ndarray.ndarray.NDArray'>
38 <class 'mxnet.ndarray.ndarray.NDArray'>
39 <class 'mxnet.ndarray.ndarray.NDArray'>
40 <class 'mxnet.ndarray.ndarray.NDArray'>
41 <class 'mxnet.ndarray.ndarray.NDArray'>
42 <class 'mxnet.ndarray.ndarray.NDArray'>
43 <class 'mxnet.ndarray.ndarray.NDArray'>
44 <class 'mxnet.ndarray.ndarray.NDArray'>
45 <class 'mxnet.ndarray.ndarray.NDArray'>
46 <class 'mxnet.ndarray.ndarray.NDArray'>
47 <class 'mxnet.ndarray.ndarray.NDArray'>
48 <class 'mxnet.ndarray.ndarray.NDArray'>
49 <class 'mxnet.ndarray.ndarray.NDArray'>
50 <class 'mxnet.ndarray.ndarray.NDArray'>
51 <class 'mxnet.ndarray.ndarray.NDArray'>
52 <class 'mxnet.ndarray.ndarray.NDArray'>
53 <class 'mxnet.ndarray.ndarray.NDArray'>
54 <class 'mxnet.ndarray.ndarray.NDArray'>
55 <class 'mxnet.ndarray.ndarray.NDArray'>
56 <class 'mxnet.ndarray.ndarray.NDArray'>
57 <class '

225 <class 'mxnet.ndarray.ndarray.NDArray'>
226 <class 'mxnet.ndarray.ndarray.NDArray'>
227 <class 'mxnet.ndarray.ndarray.NDArray'>
228 <class 'mxnet.ndarray.ndarray.NDArray'>
229 <class 'mxnet.ndarray.ndarray.NDArray'>
230 <class 'mxnet.ndarray.ndarray.NDArray'>
231 <class 'mxnet.ndarray.ndarray.NDArray'>
232 <class 'mxnet.ndarray.ndarray.NDArray'>
233 <class 'mxnet.ndarray.ndarray.NDArray'>
234 <class 'mxnet.ndarray.ndarray.NDArray'>
235 <class 'mxnet.ndarray.ndarray.NDArray'>
236 <class 'mxnet.ndarray.ndarray.NDArray'>
237 <class 'mxnet.ndarray.ndarray.NDArray'>
238 <class 'mxnet.ndarray.ndarray.NDArray'>
239 <class 'mxnet.ndarray.ndarray.NDArray'>
240 <class 'mxnet.ndarray.ndarray.NDArray'>
241 <class 'mxnet.ndarray.ndarray.NDArray'>
242 <class 'mxnet.ndarray.ndarray.NDArray'>
243 <class 'mxnet.ndarray.ndarray.NDArray'>
244 <class 'mxnet.ndarray.ndarray.NDArray'>
245 <class 'mxnet.ndarray.ndarray.NDArray'>
246 <class 'mxnet.ndarray.ndarray.NDArray'>
247 <class 'mxnet.ndarray.ndarra

417 <class 'mxnet.ndarray.ndarray.NDArray'>
418 <class 'mxnet.ndarray.ndarray.NDArray'>
419 <class 'mxnet.ndarray.ndarray.NDArray'>
420 <class 'mxnet.ndarray.ndarray.NDArray'>
421 <class 'mxnet.ndarray.ndarray.NDArray'>
422 <class 'mxnet.ndarray.ndarray.NDArray'>
423 <class 'mxnet.ndarray.ndarray.NDArray'>
424 <class 'mxnet.ndarray.ndarray.NDArray'>
425 <class 'mxnet.ndarray.ndarray.NDArray'>
426 <class 'mxnet.ndarray.ndarray.NDArray'>
427 <class 'mxnet.ndarray.ndarray.NDArray'>
428 <class 'mxnet.ndarray.ndarray.NDArray'>
429 <class 'mxnet.ndarray.ndarray.NDArray'>
430 <class 'mxnet.ndarray.ndarray.NDArray'>
431 <class 'mxnet.ndarray.ndarray.NDArray'>
432 <class 'mxnet.ndarray.ndarray.NDArray'>
433 <class 'mxnet.ndarray.ndarray.NDArray'>
434 <class 'mxnet.ndarray.ndarray.NDArray'>
435 <class 'mxnet.ndarray.ndarray.NDArray'>
436 <class 'mxnet.ndarray.ndarray.NDArray'>
437 <class 'mxnet.ndarray.ndarray.NDArray'>
438 <class 'mxnet.ndarray.ndarray.NDArray'>
439 <class 'mxnet.ndarray.ndarra

607 <class 'mxnet.ndarray.ndarray.NDArray'>
608 <class 'mxnet.ndarray.ndarray.NDArray'>
609 <class 'mxnet.ndarray.ndarray.NDArray'>
610 <class 'mxnet.ndarray.ndarray.NDArray'>
611 <class 'mxnet.ndarray.ndarray.NDArray'>
612 <class 'mxnet.ndarray.ndarray.NDArray'>
613 <class 'mxnet.ndarray.ndarray.NDArray'>
614 <class 'mxnet.ndarray.ndarray.NDArray'>
615 <class 'mxnet.ndarray.ndarray.NDArray'>
616 <class 'mxnet.ndarray.ndarray.NDArray'>
617 <class 'mxnet.ndarray.ndarray.NDArray'>
618 <class 'mxnet.ndarray.ndarray.NDArray'>
619 <class 'mxnet.ndarray.ndarray.NDArray'>
620 <class 'mxnet.ndarray.ndarray.NDArray'>
621 <class 'mxnet.ndarray.ndarray.NDArray'>
622 <class 'mxnet.ndarray.ndarray.NDArray'>
623 <class 'mxnet.ndarray.ndarray.NDArray'>
624 <class 'mxnet.ndarray.ndarray.NDArray'>
625 <class 'mxnet.ndarray.ndarray.NDArray'>
626 <class 'mxnet.ndarray.ndarray.NDArray'>
627 <class 'mxnet.ndarray.ndarray.NDArray'>
628 <class 'mxnet.ndarray.ndarray.NDArray'>
629 <class 'mxnet.ndarray.ndarra

794 <class 'mxnet.ndarray.ndarray.NDArray'>
795 <class 'mxnet.ndarray.ndarray.NDArray'>
796 <class 'mxnet.ndarray.ndarray.NDArray'>
797 <class 'mxnet.ndarray.ndarray.NDArray'>
798 <class 'mxnet.ndarray.ndarray.NDArray'>
799 <class 'mxnet.ndarray.ndarray.NDArray'>
800 <class 'mxnet.ndarray.ndarray.NDArray'>
801 <class 'mxnet.ndarray.ndarray.NDArray'>
802 <class 'mxnet.ndarray.ndarray.NDArray'>
803 <class 'mxnet.ndarray.ndarray.NDArray'>
804 <class 'mxnet.ndarray.ndarray.NDArray'>
805 <class 'mxnet.ndarray.ndarray.NDArray'>
806 <class 'mxnet.ndarray.ndarray.NDArray'>
807 <class 'mxnet.ndarray.ndarray.NDArray'>
808 <class 'mxnet.ndarray.ndarray.NDArray'>
809 <class 'mxnet.ndarray.ndarray.NDArray'>
810 <class 'mxnet.ndarray.ndarray.NDArray'>
811 <class 'mxnet.ndarray.ndarray.NDArray'>
812 <class 'mxnet.ndarray.ndarray.NDArray'>
813 <class 'mxnet.ndarray.ndarray.NDArray'>
814 <class 'mxnet.ndarray.ndarray.NDArray'>
815 <class 'mxnet.ndarray.ndarray.NDArray'>
816 <class 'mxnet.ndarray.ndarra

In [12]:
net.save_params("models/cnn_1gpu_mnist.par")

In [13]:
data.shape[0]

32