### MNIST Convolutional Neural Network Model

Import libraries

In [10]:
import mnist_cnn
import numpy as np
import torch.nn as nn
from torchsummary import summary
import copy

Load MNIST Data

In [11]:
train_x, train_y, test_x, test_y = mnist_cnn.load_MNIST_data()
print(train_x.shape)
print(train_y.shape)
print(test_x.shape)
print(test_y.shape)

(60000, 784)
(60000,)
(10000, 784)
(10000,)


We need to rehape the data back into a 1x28x28 image

In [12]:
train_x = np.reshape(train_x, (train_x.shape[0], 1, 28, 28))
test_x = np.reshape(test_x, (test_x.shape[0], 1, 28, 28))

Split into train and dev

In [13]:
dev_split_index = int(9 * len(train_x) / 10)
dev_x = train_x[dev_split_index:]
dev_y = train_y[dev_split_index:]
train_x = train_x[:dev_split_index]
train_y = train_y[:dev_split_index]

permutation = np.array([i for i in range(len(train_x))])
np.random.shuffle(permutation)
train_x = [train_x[i] for i in permutation]
train_y = [train_y[i] for i in permutation]

Split dataset into batches

In [14]:
batch_size = 32
train_batches = mnist_cnn.batchify_data(train_x, train_y, batch_size)
dev_batches = mnist_cnn.batchify_data(dev_x, dev_y, batch_size)
test_batches = mnist_cnn.batchify_data(test_x, test_y, batch_size)

Model specification

In [15]:
model = nn.Sequential(
      nn.Conv2d(1, 32, (3, 3)),
      nn.ReLU(),
      nn.MaxPool2d((2, 2)),
      nn.Conv2d(32, 64, (3, 3)),
      nn.ReLU(),
      nn.MaxPool2d((2, 2)),
      mnist_cnn.Flatten(),
      nn.Linear(1600, 128),
      nn.Dropout(0.5),
      nn.Linear(128, 10),
    )

In [16]:
model_summary = copy.deepcopy(model)
summary(model_summary, (1, 28,28))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 26, 26]             320
              ReLU-2           [-1, 32, 26, 26]               0
         MaxPool2d-3           [-1, 32, 13, 13]               0
            Conv2d-4           [-1, 64, 11, 11]          18,496
              ReLU-5           [-1, 64, 11, 11]               0
         MaxPool2d-6             [-1, 64, 5, 5]               0
           Flatten-7                 [-1, 1600]               0
            Linear-8                  [-1, 128]         204,928
           Dropout-9                  [-1, 128]               0
           Linear-10                   [-1, 10]           1,290
Total params: 225,034
Trainable params: 225,034
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.52
Params size (MB): 0.86
Estimated T

Train model

In [17]:
mnist_cnn.train_model(train_batches, dev_batches, model, nesterov=True)

  0%|          | 7/1687 [00:00<00:25, 65.89it/s]

-------------
Epoch 1:



100%|██████████| 1687/1687 [00:25<00:00, 50.77it/s]
 13%|█▎        | 25/187 [00:00<00:00, 246.96it/s]

Train loss: 0.233922 | Train accuracy: 0.926849


100%|██████████| 187/187 [00:00<00:00, 252.16it/s]
  0%|          | 7/1687 [00:00<00:26, 62.63it/s]

Val loss:   0.058919 | Val accuracy:   0.984124
-------------
Epoch 2:



100%|██████████| 1687/1687 [00:28<00:00, 60.21it/s]
 14%|█▍        | 26/187 [00:00<00:00, 252.92it/s]

Train loss: 0.079176 | Train accuracy: 0.976678


100%|██████████| 187/187 [00:00<00:00, 244.10it/s]
  0%|          | 7/1687 [00:00<00:25, 65.50it/s]

Val loss:   0.053302 | Val accuracy:   0.985461
-------------
Epoch 3:



100%|██████████| 1687/1687 [00:30<00:00, 54.79it/s]
 13%|█▎        | 24/187 [00:00<00:00, 231.36it/s]

Train loss: 0.059469 | Train accuracy: 0.981939


100%|██████████| 187/187 [00:00<00:00, 222.10it/s]
  0%|          | 3/1687 [00:00<00:59, 28.19it/s]

Val loss:   0.046066 | Val accuracy:   0.987299
-------------
Epoch 4:



100%|██████████| 1687/1687 [00:32<00:00, 52.41it/s]
 14%|█▍        | 26/187 [00:00<00:00, 259.91it/s]

Train loss: 0.047528 | Train accuracy: 0.985329


100%|██████████| 187/187 [00:00<00:00, 256.80it/s]
  0%|          | 8/1687 [00:00<00:21, 77.71it/s]

Val loss:   0.040804 | Val accuracy:   0.989472
-------------
Epoch 5:



100%|██████████| 1687/1687 [00:23<00:00, 72.22it/s]
 15%|█▍        | 28/187 [00:00<00:00, 278.61it/s]

Train loss: 0.040010 | Train accuracy: 0.987200


100%|██████████| 187/187 [00:00<00:00, 272.32it/s]
  0%|          | 8/1687 [00:00<00:22, 74.91it/s]

Val loss:   0.041399 | Val accuracy:   0.989138
-------------
Epoch 6:



100%|██████████| 1687/1687 [00:23<00:00, 72.25it/s]
 13%|█▎        | 24/187 [00:00<00:00, 230.50it/s]

Train loss: 0.034654 | Train accuracy: 0.989293


100%|██████████| 187/187 [00:00<00:00, 256.11it/s]
  0%|          | 7/1687 [00:00<00:26, 62.98it/s]

Val loss:   0.039093 | Val accuracy:   0.989806
-------------
Epoch 7:



100%|██████████| 1687/1687 [00:23<00:00, 71.61it/s]
 14%|█▍        | 27/187 [00:00<00:00, 261.09it/s]

Train loss: 0.030590 | Train accuracy: 0.990664


100%|██████████| 187/187 [00:00<00:00, 269.39it/s]
  0%|          | 8/1687 [00:00<00:22, 74.66it/s]

Val loss:   0.034824 | Val accuracy:   0.990475
-------------
Epoch 8:



100%|██████████| 1687/1687 [00:23<00:00, 71.48it/s]
 14%|█▍        | 27/187 [00:00<00:00, 261.37it/s]

Train loss: 0.026449 | Train accuracy: 0.992146


100%|██████████| 187/187 [00:00<00:00, 279.97it/s]
  0%|          | 7/1687 [00:00<00:24, 69.17it/s]

Val loss:   0.037062 | Val accuracy:   0.990642
-------------
Epoch 9:



100%|██████████| 1687/1687 [00:23<00:00, 72.07it/s]
 14%|█▍        | 26/187 [00:00<00:00, 256.79it/s]

Train loss: 0.024167 | Train accuracy: 0.992257


100%|██████████| 187/187 [00:00<00:00, 269.59it/s]
  0%|          | 8/1687 [00:00<00:21, 78.48it/s]

Val loss:   0.041018 | Val accuracy:   0.990140
-------------
Epoch 10:



100%|██████████| 1687/1687 [00:23<00:00, 72.62it/s]
 14%|█▍        | 27/187 [00:00<00:00, 264.94it/s]

Train loss: 0.021148 | Train accuracy: 0.993165


100%|██████████| 187/187 [00:00<00:00, 264.57it/s]


Val loss:   0.037539 | Val accuracy:   0.990642


0.9906417112299465

Evaluate the model on test data

In [18]:
loss, accuracy = mnist_cnn.run_epoch(test_batches, model.eval(), None)

print ("Loss on test set:"  + str(loss) + " Accuracy on test set: " + str(accuracy))

100%|██████████| 312/312 [00:01<00:00, 286.00it/s]

Loss on test set:0.026171966654860858 Accuracy on test set: 0.9917868589743589



