# Handwritten Digit & Parity Recognition
MNIST is a database containing a total of 70,000 images of handwritten digits collected from a combination of Census Bureau employees and high school students. The database also includes labels for each of the digits. For example, a handwritten "7" has a label 7 to serve as the "ground truth". In other words, we can train the model to classify digits given the images because we have the correct answer. Because MNIST is sanitized and has these labels, it makes for a good dataset to run image recognition algorithms on - something deep learning performs well at.

The most basic use case for MNIST is to train a model that classifies digits. For example, we can build a model that classifies each image as one of 0, 1,..., or 9. To view a tutorial on this basic use case, see [Handwritten Digit Recognition](https://github.com/dmlc/mxnet-notebooks/blob/master/python/tutorials/mnist.ipynb). This tutorial will take it up a notch by showing how to train a model that can classify digits *as well* as distinguish between even and odd numbers. For example, given an image that looks like a "7", the model should classify the image as digit 7 and parity odd.

## Prepare data
First we download the dataset and obtain data iterators for MNIST training and validation sets. Read more about the [MNISTIter API](http://mxnet.io/api/python/io.html#mxnet.io.MNISTIter).

In [199]:
import sys
import os
import mxnet as mx
import mxnet.metric
import numpy as np

# Finds MXNet source directory so we can import some utility functions
mxnet_path = os.path.dirname(os.path.abspath(os.path.expanduser(mxnet.__file__)))
sys.path.append(os.path.join(mxnet_path, "../../tests/python/common"))
sys.path.append(os.path.join(mxnet_path, "../../example/python-howto"))

# Calls a utility function to provide two MNISTIter objects
from data import mnist_iterator
train, val = mnist_iterator(batch_size=100, input_shape = (784,))

## Data iterator

Now we will create a new iterator that will allow us to multitask, classifying digits *and* parity. We can build off of the MNIST iterator, which returns batches of input data and labels for digits 0, 1,...,9. We need this iterator to return batches of data with both labels for digits 0, 1,...,9 as well as labels for even and odd. To create this second set of labels, we simply calculate the parity based off of the ground truth we already have - the digit labels.

* For **provide_data()** we are going to return MNIST iterator's data description since we want to keep the input data format the same.
* For **provide_label()** we are returning a list of 2 data descriptions: the first item is the name and shape of the digit labels, the second item is the name and shape of the parity labels.
* For **next()** we are returning a DataBatch object whose data field is set to the batch input, and the label field is set to a list of 2 lists: one with labels for the digit and the other with labels for the parity. This way the model can evaluate its accuracy in predicting both labels.

[Read more about the DataIter API](http://mxnet.io/api/python/io.html?highlight=dataiter#mxnet.io.DataIter) and see a tutorial on [developing new iterators](https://github.com/dmlc/mxnet-notebooks/blob/master/python/basic/data.ipynb).

In [200]:
class DigitParity_iterator(mx.io.DataIter): # Extend DataIter
    '''Multi label iterator'''

    def __init__(self, data_iter):
        super(DigitParity_iterator, self).__init__()
        self.data_iter = data_iter
        self.batch_size = self.data_iter.batch_size

    @property
    def provide_data(self):
        # Return the name and shape of the input data as it is
        return self.data_iter.provide_data

    @property
    def provide_label(self):
        # Return the digit and parity labels name/shape/data type as a list
        # You will see later that we name our Softmax symbols "softmaxdigit" and "softmaxparity"
        # MXNet will create labels called "softmaxdigit_label" and "softmaxparity_label"
        # So those are the names we need to provide here
        return [mx.io.DataDesc('softmaxdigit_label', (self.batch_size,), np.float32),
                mx.io.DataDesc('softmaxparity_label', (self.batch_size,), np.float32)]

    def hard_reset(self):
        self.data_iter.hard_reset()

    def reset(self):
        self.data_iter.reset()

    def next(self):
        # grab the next batch
        batch = self.data_iter.next()
        
        # initialize "labels" with the original digit ground truth we have
        labels = []
        labels.append(batch.label[0]) 
        eolabels = []
        
        # calculate parity labels from the digit labels
        for i in batch.label[0].asnumpy():
            eolabels.append(i%2)
        eolabels = mx.nd.array(np.array(eolabels))
        
        # append new parity labels to "labels"
        labels.append(eolabels)
        
        return mx.io.DataBatch(data=batch.data, label=labels, pad=batch.pad, index=batch.index)


Now we just pass the MNIST iterators to iterator we just created.

In [201]:
_train = DigitParity_iterator(train)
_val = DigitParity_iterator(val)

## Evaluation

We create a class to help evaluate how accurately the model is classifying digits *and* parity. We know that for each example, the model will make 2 predictions (one for each of the above). So for a batch of n examples, the model makes 2n predictions. We want to keep track of how many of the 2n predictions were correct. 

Our class can extend the mx.metric.EvalMetric class. We will set **num_inst** to 2n and increment **sum_metric** every time the model predicted correctly. Then, EvalMetric can use those to numbers to return an evaluation.

[Read more about EvalMetric API](http://mxnet.io/api/python/model.html?highlight=evalmetric#mxnet.metric.EvalMetric).

To show the status of training, we will also write a function to be called after every epoch. What would be helpful in those status reports is a simple visualization of how many correct and incorrect predictions there are for each label. Take parity for example:

```
[[90, 10]
 [20, 80]]
```
 
The first row tells us that the model correctly labeled even examples 90 times, and incorrectly labeled even examples as odd 10 times. The second row tells us that the model correctly labeled odd examples 80 times, and incorrectly labeled even examples as odd 20 times. What we want to see are big numbers on the main diagonal, which tells us that each label is guessed correctly often.

So, we will also build this state into the evaluation class and have the status function print this state out after every epoch.

In [202]:
class MNISTPerDigitAccuracy(mx.metric.EvalMetric): # Extend EvalMetric
    """Calculate accuracy"""

    def reset(self):
        if hasattr(self, 'cms'):
            for cm in self.cms:
                cm.fill(0)
        

    def __init__(self,sizes):
        super(MNISTPerDigitAccuracy, self).__init__('perdigitaccuracy')
        
        # Create matrices for each label type
        # Each matrix is x*x, where x is the number of classes in the label type
        # For digit label it will be 10*10
        # For parity label it will be 2*2
        self.cms = map( lambda x: np.zeros((x,x), dtype=int), sizes )
        self.reset()
        return super(MNISTPerDigitAccuracy, self).reset()

    def update(self, labels, preds):
        mx.metric.check_label_shapes(labels, preds)
        
        # Go through each label type: digit and parity
        for i in range(len(labels)):
            # For the current label type, zip up the predictions versus the ground truths
            for label, pred_label in zip(labels[i].asnumpy(), preds[i].asnumpy()):
                # Predictions come in the form of probability distributions
                # For digit labels, it will be a vector of 10 numbers. Ex:
                # [0, .01, .09, 0, 0, .06, .04, .8, 0, 0] shows "7" has the highest probability
                # So we take the maximum from the prediction
                pred_label = int(np.argmax(pred_label))
                
                # Keep track of what predictions the models makes for what ground truths
                # It is a matrix where big numbers on the main diagonal is our goal
                label = int(label)
                self.cms[i][label,pred_label] += 1
                
                # Keep track of how many correct predictions the model makes
                self.sum_metric += (pred_label == label)
                
                # Keep track of how many predictions are made over all
                self.num_inst += 1
                
def perDigitMetric(params):
    # Print the state of what predictions the models makes for what ground truths
    for i, cm in enumerate(params.eval_metric.cms):
        print "Label=", i
        print cm

# Model

Now we build the model. It will be similar to the one we build for digit classification, except we add another layer, fc3eo, to predict the parity.

In [203]:
data = mx.symbol.Variable('data')
fc1 = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128)
act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu")
fc2 = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64)
act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu")
fc3digit = mx.symbol.FullyConnected(data = act2, name='fc3digit', num_hidden=10)

# A fully connected layer for the parity. The hidden size is 2 to account for even and odd.
fc3parity = mx.symbol.FullyConnected(data = act2, name='fc3parity', num_hidden=2)

Now we create the output symbols. We will group those two symbols together so we can multitask our learning.

In [204]:
mlp1 = mx.symbol.SoftmaxOutput(data = fc3digit, name = 'softmaxdigit')
mlp2 = mx.symbol.SoftmaxOutput(data = fc3parity, name = 'softmaxparity')

# Group the digit and parity symbols to multitask
mlp = mx.symbol.Group([mlp1,mlp2])

optimizer_params = (('learning_rate', 0.1),('momentum', 0.9), ('wd', 0.00001))

# Training
Now let's train the model.

In [205]:
# Because we have two sets of labels, one for each of digit and parity SoftmaxOutput,
# we have to explicitly call them out here
mod = mx.mod.Module(mlp, label_names=['softmaxdigit_label','softmaxparity_label'])
mod.fit(_train, 
        eval_data=_val,
        optimizer_params=optimizer_params,
        eval_metric=MNISTPerDigitAccuracy([10,2]),
        eval_end_callback=perDigitMetric,
        num_epoch=3, batch_end_callback=mx.callback.Speedometer(100,100))

INFO:root:Epoch[0] Batch [100]	Speed: 33022.90 samples/sec	Train-perdigitaccuracy=0.390198
INFO:root:Epoch[0] Batch [200]	Speed: 34679.12 samples/sec	Train-perdigitaccuracy=0.603881
INFO:root:Epoch[0] Batch [300]	Speed: 33373.04 samples/sec	Train-perdigitaccuracy=0.712957
INFO:root:Epoch[0] Batch [400]	Speed: 30580.56 samples/sec	Train-perdigitaccuracy=0.770224
INFO:root:Epoch[0] Batch [500]	Speed: 32021.01 samples/sec	Train-perdigitaccuracy=0.807515
INFO:root:Epoch[0] Train-perdigitaccuracy=0.832458
INFO:root:Epoch[0] Time cost=1.826
INFO:root:Epoch[0] Validation-perdigitaccuracy=0.851321


Label= 0
[[ 965    0    0    3    0    2    4    3    3    0]
 [   0 1104    1   10    0    1    1    0   18    0]
 [  12    1  981   16    3    1    9    4    5    0]
 [   0    0    5  990    0    5    0    4    6    0]
 [   2    0    2    0  940    0   14    2    2   20]
 [   5    0    0   56    4  796    7    2   17    5]
 [  10    3    0    1    4    9  926    0    5    0]
 [   1    8   20    9    3    0    0  968    2   17]
 [   7    1    6   22    3    1    8    2  920    4]
 [   3    5    0   10   27    6    2    8    2  946]]
Label= 1
[[4791  135]
 [ 111 4963]]


INFO:root:Epoch[1] Batch [100]	Speed: 32412.70 samples/sec	Train-perdigitaccuracy=0.865418
INFO:root:Epoch[1] Batch [200]	Speed: 36032.52 samples/sec	Train-perdigitaccuracy=0.877009
INFO:root:Epoch[1] Batch [300]	Speed: 34819.55 samples/sec	Train-perdigitaccuracy=0.886309
INFO:root:Epoch[1] Batch [400]	Speed: 33539.16 samples/sec	Train-perdigitaccuracy=0.894074
INFO:root:Epoch[1] Batch [500]	Speed: 34900.47 samples/sec	Train-perdigitaccuracy=0.900733
INFO:root:Epoch[1] Train-perdigitaccuracy=0.906185
INFO:root:Epoch[1] Time cost=1.754
INFO:root:Epoch[1] Validation-perdigitaccuracy=0.911046


Label= 0
[[ 974    0    0    1    0    0    2    1    2    0]
 [   0 1122    4    3    1    0    1    0    4    0]
 [   7    1 1011    3    4    0    3    2    1    0]
 [   5    2   10  972    0    8    0    4    6    3]
 [   1    0    3    0  953    0    6    1    0   18]
 [   8    1    1   13    4  843    7    1    9    5]
 [  11    3    3    0    4    6  925    0    6    0]
 [   2    7   20    7    2    0    0  968    2   20]
 [  23    0   21    3    3    1   10    2  905    6]
 [   4    2    0    5   21    6    1    2    1  967]]
Label= 1
[[4885   41]
 [ 114 4960]]


INFO:root:Epoch[2] Batch [100]	Speed: 32870.39 samples/sec	Train-perdigitaccuracy=0.915363
INFO:root:Epoch[2] Batch [200]	Speed: 36177.10 samples/sec	Train-perdigitaccuracy=0.919403
INFO:root:Epoch[2] Batch [300]	Speed: 34680.55 samples/sec	Train-perdigitaccuracy=0.922878
INFO:root:Epoch[2] Batch [400]	Speed: 33414.76 samples/sec	Train-perdigitaccuracy=0.925980
INFO:root:Epoch[2] Batch [500]	Speed: 33036.99 samples/sec	Train-perdigitaccuracy=0.928848
INFO:root:Epoch[2] Train-perdigitaccuracy=0.931305
INFO:root:Epoch[2] Time cost=1.770
INFO:root:Epoch[2] Validation-perdigitaccuracy=0.933336


Label= 0
[[ 969    0    3    0    0    0    7    1    0    0]
 [   0 1111    8    3    0    0    5    2    5    1]
 [   7    0 1014    0    2    0    3    2    4    0]
 [   1    0   15  959    0   14    0    6    9    6]
 [   1    0    5    0  961    0    8    0    0    7]
 [   6    0    0    7    1  856   11    1    6    4]
 [   6    3    1    0    5    2  934    0    7    0]
 [   1    4   18    2    3    0    0  986    3   11]
 [  15    0   28    0    6    3   15    2  899    6]
 [   5    1    0    2   29    6    1    4    1  960]]
Label= 1
[[4902   24]
 [ 146 4928]]
