# Hierarchical Loss
The most basic use case for MNIST is to train a model that classifies digits. This tutorial shows how to train a single model with two outputs: one that classifies the digit and one that predicts whether the digit is even or odd.

## Prepare data
First we download the dataset and obtain data iterators for MNIST training and validation sets.

In [61]:
import sys
import os
import mxnet as mx
import mxnet.metric
import numpy as np
mxnet_path = os.path.dirname(os.path.abspath(os.path.expanduser(mxnet.__file__)))
sys.path.append(os.path.join(mxnet_path, "../../tests/python/common"))
sys.path.append(os.path.join(mxnet_path, "../../example/python-howto"))
from data import mnist_iterator
train, val = mnist_iterator(batch_size=100, input_shape = (784,))

## Data iterator

Now we will create a new iterator that lets us multitask, classifying digits *and* parity. We can build on the MNIST iterator, which returns batches of input data and labels for the digits 0, 1, ..., 9. Our iterator must return each batch with two sets of labels: the digit labels and the even/odd labels. To create the second set, we compute the parity from the ground truth we already have, the digit labels.

* For **provide_data** we return the MNIST iterator's data description, since the input data format stays the same.
* For **provide_label** we return a list of two data descriptions: the first item is the name and shape of the digit labels, the second item is the name and shape of the parity labels.
* For **next()** we return a DataBatch object whose data field is the batch input and whose label field is a list of two label arrays: one for the digit and one for the parity.
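
The parity rule behind the second label set is just the remainder after dividing by two. A minimal NumPy sketch, using a made-up batch of digit labels (the values and variable names are assumptions for illustration, and the labels are assumed to arrive as float32):

```python
import numpy as np

# Hypothetical batch of digit labels, stored as floats.
digit_labels = np.array([0, 1, 2, 7, 8, 9], dtype=np.float32)

# Parity label: 0 for even digits, 1 for odd digits.
parity_labels = digit_labels % 2

print(parity_labels)
```

Because the remainder is computed element-wise, the parity array has the same shape as the digit array, which matches the two identically shaped label descriptions listed above.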

[Read more about the DataIter API.](http://mxnet.io/api/python/io.html?highlight=dataiter#mxnet.io.DataIter)

In [62]:
class EvenOdd_iterator(mx.io.DataIter):
    '''Wraps an MNIST iterator so each batch carries digit and parity labels.'''

    def __init__(self, data_iter):
        super(EvenOdd_iterator, self).__init__()
        self.data_iter = data_iter
        self.batch_size = self.data_iter.batch_size

    @property
    def provide_data(self):
        return self.data_iter.provide_data

    @property
    def provide_label(self):
        return [mx.io.DataDesc('softmaxdigit_label', (self.batch_size,), np.float32),
                mx.io.DataDesc('softmaxeo_label', (self.batch_size,), np.float32)]

    def hard_reset(self):
        self.data_iter.hard_reset()

    def reset(self):
        self.data_iter.reset()

    def next(self):
        batch = self.data_iter.next()
        digit_labels = batch.label[0]
        # Derive parity from the digit labels: 0 for even, 1 for odd.
        eo_labels = mx.nd.array(digit_labels.asnumpy() % 2)
        return mx.io.DataBatch(data=batch.data, label=[digit_labels, eo_labels],
                               pad=batch.pad, index=batch.index)


Now we just pass the MNIST iterators to the iterator we just created.

In [63]:
_train = EvenOdd_iterator(train)
_val = EvenOdd_iterator(val)

## Evaluation

We create a class to help evaluate how accurately the model is classifying digits *and* parity. We know that for each example, the model will make 2 predictions (one for each of the above). So for a batch of n examples, the model makes 2n predictions. We want to keep track of how many of the 2n predictions were correct. 

Our class can extend the mx.metric.EvalMetric class. For each batch we increase **num_inst** by 2n and increment **sum_metric** every time the model predicts correctly. EvalMetric then uses those two numbers to return the accuracy.
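
The counting scheme can be sketched in plain NumPy, outside of MXNet. The labels and predictions below are invented for illustration; `sum_metric` and `num_inst` mirror the EvalMetric attribute names, but this is only a sketch of the bookkeeping, not the metric class itself:

```python
import numpy as np

# Hypothetical batch of n = 4 examples (values invented for illustration).
digit_labels = np.array([3, 8, 5, 0])
parity_labels = digit_labels % 2

# Hypothetical predicted classes, i.e. the argmax over each softmax output.
digit_preds = np.array([3, 8, 6, 0])   # one digit mistake (5 predicted as 6)
parity_preds = np.array([1, 0, 1, 1])  # one parity mistake (0 predicted as odd)

sum_metric = 0  # number of correct predictions
num_inst = 0    # number of predictions made
for labels, preds in [(digit_labels, digit_preds), (parity_labels, parity_preds)]:
    sum_metric += int(np.sum(labels == preds))
    num_inst += len(labels)

accuracy = sum_metric / float(num_inst)
print(num_inst, sum_metric, accuracy)
```

With n = 4 the model makes 2n = 8 predictions, 6 of which are correct here, so the combined accuracy is 0.75.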

[Read more about EvalMetric API](http://mxnet.io/api/python/model.html?highlight=evalmetric#mxnet.metric.EvalMetric)

To show the status of training, we will also write a function that is called after each evaluation pass. A helpful addition to those status reports is a confusion matrix showing how many correct and incorrect predictions there are for each label. Take parity for example:

[[90 10]
 [20 80]]

The first row tells us that the model correctly labeled even examples 90 times, and incorrectly labeled even examples as odd 10 times. The second row tells us that the model correctly labeled odd examples 80 times, and incorrectly labeled odd examples as even 20 times. What we want to see are big numbers on the main diagonal.
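
As a small worked example, per-class and overall accuracy can be read off such a matrix with NumPy, assuming rows are true labels and columns are predicted labels (the variable names are illustrative):

```python
import numpy as np

# Example parity confusion matrix: rows = true label, columns = prediction.
cm = np.array([[90, 10],
               [20, 80]])

# Fraction of each true class that was predicted correctly.
per_class_accuracy = cm.diagonal() / cm.sum(axis=1).astype(float)

# Fraction of all examples that lie on the main diagonal.
overall_accuracy = cm.trace() / float(cm.sum())

print(per_class_accuracy, overall_accuracy)
```

Here the even class is right 90 times out of 100 and the odd class 80 times out of 100, giving an overall accuracy of 170 out of 200.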

So, we will also build this state into the evaluation class and have the status function print this state out after every epoch.

In [64]:
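class MNISTPerDigitAccuracy(mx.metric.EvalMetric):
    """Accuracy over all outputs, plus one confusion matrix per output."""

    def __init__(self, sizes):
        super(MNISTPerDigitAccuracy, self).__init__('mnistperdigiaccuracy')
        # One confusion matrix per output: rows are true labels,
        # columns are predicted labels.
        self.cms = [np.zeros((size, size), dtype=int) for size in sizes]
        # Initialise sum_metric and num_inst through the base class.
        super(MNISTPerDigitAccuracy, self).reset()

    def reset(self):
        # The base constructor calls reset() before self.cms exists.
        # Only the confusion matrices are cleared here; sum_metric and
        # num_inst are not, so the reported accuracy accumulates across calls.
        if hasattr(self, 'cms'):
            for cm in self.cms:
                cm.fill(0)

    def update(self, labels, preds):
        mx.metric.check_label_shapes(labels, preds)
        for i in range(len(labels)):
            for label, pred_label in zip(labels[i].asnumpy(), preds[i].asnumpy()):
                pred_label = int(np.argmax(pred_label))
                label = int(label)
                self.cms[i][label, pred_label] += 1
                self.sum_metric += (pred_label == label)
                self.num_inst += 1


def perDigitMetric(params):
    """Print each confusion matrix; passed to fit() as eval_end_callback."""
    for i, cm in enumerate(params.eval_metric.cms):
        print("Label= %d" % i)
        print(cm)
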

## Model

Now we build the model. It is similar to a model for plain digit classification, except that we add a second output layer, fc3eo, which branches off the same hidden layer (act2) to predict the parity.

In [65]:
data = mx.symbol.Variable('data')
fc1 = mx.symbol.FullyConnected(data = data, name='fc1', num_hidden=128)
act1 = mx.symbol.Activation(data = fc1, name='relu1', act_type="relu")
fc2 = mx.symbol.FullyConnected(data = act1, name = 'fc2', num_hidden = 64)
act2 = mx.symbol.Activation(data = fc2, name='relu2', act_type="relu")
fc3digit = mx.symbol.FullyConnected(data = act2, name='fc3', num_hidden=10)
fc3eo = mx.symbol.FullyConnected(data = act2, name='fc3eo', num_hidden=2)

Now we create the two output symbols and group them together so that both tasks are learned at the same time.

In [66]:
mlp1 = mx.symbol.SoftmaxOutput(data = fc3digit, name = 'softmaxdigit')
mlp2 = mx.symbol.SoftmaxOutput(data = fc3eo, name = 'softmaxeo')
mlp = mx.symbol.Group([mlp1,mlp2])

optimizer_params = (('learning_rate', 0.1),('momentum', 0.9), ('wd', 0.00001))
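
To make the two-headed structure concrete, here is a minimal NumPy sketch of a shared feature batch feeding two softmax heads. It assumes that the two outputs contribute equally weighted cross-entropy terms, omits the bias terms, and uses random stand-ins for the act2 output and the fc3 and fc3eo weights, so it illustrates the idea rather than reproducing MXNet's internals:

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)  # stabilise the exponentials
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def cross_entropy(probs, labels):
    # Mean negative log-probability assigned to the true class.
    return -np.mean(np.log(probs[np.arange(len(labels)), labels]))

rng = np.random.RandomState(0)
batch, hidden = 4, 64
shared = rng.randn(batch, hidden)       # stand-in for the act2 output
w_digit = rng.randn(hidden, 10) * 0.01  # stand-in for the fc3 weights
w_eo = rng.randn(hidden, 2) * 0.01      # stand-in for the fc3eo weights

digit_labels = np.array([3, 8, 5, 0])   # hypothetical labels
eo_labels = digit_labels % 2

p_digit = softmax(shared.dot(w_digit))  # shape (4, 10)
p_eo = softmax(shared.dot(w_eo))        # shape (4, 2)

# Assumed combination: both tasks weighted equally.
total_loss = cross_entropy(p_digit, digit_labels) + cross_entropy(p_eo, eo_labels)
print(total_loss)
```

Both heads read the same `shared` features, so in training the errors from both tasks flow back into the common layers.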

## Training
Now let's train the model.

In [67]:
mod = mx.mod.Module(mlp,label_names=['softmaxdigit_label','softmaxeo_label'])
mod.fit(_train, 
        eval_data=_val,
        optimizer_params=optimizer_params,
        eval_metric=MNISTPerDigitAccuracy([10,2]),
        eval_end_callback=perDigitMetric,
        num_epoch=3, batch_end_callback=mx.callback.Speedometer(100,100))

INFO:root:Epoch[0] Batch [100]	Speed: 31839.65 samples/sec	Train-mnistperdigiaccuracy=0.403762
INFO:root:Epoch[0] Batch [200]	Speed: 35066.44 samples/sec	Train-mnistperdigiaccuracy=0.605572
INFO:root:Epoch[0] Batch [300]	Speed: 27333.91 samples/sec	Train-mnistperdigiaccuracy=0.712475
INFO:root:Epoch[0] Batch [400]	Speed: 27062.64 samples/sec	Train-mnistperdigiaccuracy=0.769825
INFO:root:Epoch[0] Batch [500]	Speed: 33326.43 samples/sec	Train-mnistperdigiaccuracy=0.807036
INFO:root:Epoch[0] Train-mnistperdigiaccuracy=0.831658
INFO:root:Epoch[0] Time cost=1.945
INFO:root:Epoch[0] Validation-mnistperdigiaccuracy=0.850236


Label= 0
[[ 959    0    2    2    0    3    8    5    1    0]
 [   0 1118    5    2    0    1    4    1    4    0]
 [   9    0  996    7    4    1    2    8    5    0]
 [   0    0   10  974    0    2    0   11    4    9]
 [   1    0    3    0  918    0   14    2    3   41]
 [   3    1    1   34    1  819   10    9    5    9]
 [  10    3    3    1    5    8  925    0    3    0]
 [   1    8   16    4    1    0    0  971    0   27]
 [   9    2   11   19    6   15   24    8  869   11]
 [   4    5    0   10   17    8    0    6    0  959]]
Label= 1
[[4728  198]
 [  76 4998]]


INFO:root:Epoch[1] Batch [100]	Speed: 29954.20 samples/sec	Train-mnistperdigiaccuracy=0.864182
INFO:root:Epoch[1] Batch [200]	Speed: 35452.56 samples/sec	Train-mnistperdigiaccuracy=0.875816
INFO:root:Epoch[1] Batch [300]	Speed: 35527.01 samples/sec	Train-mnistperdigiaccuracy=0.885080
INFO:root:Epoch[1] Batch [400]	Speed: 27720.73 samples/sec	Train-mnistperdigiaccuracy=0.892893
INFO:root:Epoch[1] Batch [500]	Speed: 31716.02 samples/sec	Train-mnistperdigiaccuracy=0.899563
INFO:root:Epoch[1] Train-mnistperdigiaccuracy=0.905000
INFO:root:Epoch[1] Time cost=1.858
INFO:root:Epoch[1] Validation-mnistperdigiaccuracy=0.909675


Label= 0
[[ 966    0    0    0    0    2    9    1    2    0]
 [   0 1121    3    2    0    1    2    3    3    0]
 [  12    0  995    4    2    2    1    7    7    2]
 [   0    0    5  934    0   54    0    4    7    6]
 [   2    0    1    0  952    0   14    2    1   10]
 [   3    1    1    1    3  866    8    2    3    4]
 [   6    2    1    0    4    5  938    0    2    0]
 [   1    5    9    5    4    0    0  969    3   32]
 [  15    0    9    7   10   16   25    3  879   10]
 [   4    2    0    7   32   10    0    3    2  949]]
Label= 1
[[4852   74]
 [  86 4988]]


INFO:root:Epoch[2] Batch [100]	Speed: 34112.68 samples/sec	Train-mnistperdigiaccuracy=0.913931
INFO:root:Epoch[2] Batch [200]	Speed: 32506.58 samples/sec	Train-mnistperdigiaccuracy=0.917936
INFO:root:Epoch[2] Batch [300]	Speed: 33970.97 samples/sec	Train-mnistperdigiaccuracy=0.921379
INFO:root:Epoch[2] Batch [400]	Speed: 34072.36 samples/sec	Train-mnistperdigiaccuracy=0.924339
INFO:root:Epoch[2] Batch [500]	Speed: 33754.04 samples/sec	Train-mnistperdigiaccuracy=0.927194
INFO:root:Epoch[2] Train-mnistperdigiaccuracy=0.929848
INFO:root:Epoch[2] Time cost=1.775
INFO:root:Epoch[2] Validation-mnistperdigiaccuracy=0.931976


Label= 0
[[ 974    1    2    1    0    0    0    0    1    1]
 [   0 1125    3    0    0    0    3    0    4    0]
 [   6    1 1009    2    1    0    0    4    9    0]
 [   1    2   14  954    0   17    0    4   11    7]
 [   2    0    2    0  947    0    5    2    0   24]
 [   8    0    1    3    0  847   13    7    4    9]
 [  12    2   12    0    3    4  921    1    3    0]
 [   2    4   11    3    1    0    0  966    4   37]
 [   9    1   10    5    4    4    7    2  928    4]
 [   2    2    1    4    7    6    1    1    3  982]]
Label= 1
[[4870   56]
 [ 106 4968]]
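
As a rough check on the final epoch, the per-task accuracies can be computed from the confusion matrices printed above. The sketch below copies the Epoch[2] validation matrices into NumPy arrays (the variable names are illustrative, not part of the tutorial's code) and divides the diagonal sum by the total count:

```python
import numpy as np

# Epoch[2] validation confusion matrix for the digit output (Label= 0).
digit_cm = np.array([
    [974,    1,    2,   1,   0,   0,   0,   0,   1,   1],
    [  0, 1125,    3,   0,   0,   0,   3,   0,   4,   0],
    [  6,    1, 1009,   2,   1,   0,   0,   4,   9,   0],
    [  1,    2,   14, 954,   0,  17,   0,   4,  11,   7],
    [  2,    0,    2,   0, 947,   0,   5,   2,   0,  24],
    [  8,    0,    1,   3,   0, 847,  13,   7,   4,   9],
    [ 12,    2,   12,   0,   3,   4, 921,   1,   3,   0],
    [  2,    4,   11,   3,   1,   0,   0, 966,   4,  37],
    [  9,    1,   10,   5,   4,   4,   7,   2, 928,   4],
    [  2,    2,    1,   4,   7,   6,   1,   1,   3, 982],
])

# Epoch[2] validation confusion matrix for the parity output (Label= 1).
parity_cm = np.array([[4870,   56],
                      [ 106, 4968]])

digit_accuracy = digit_cm.trace() / float(digit_cm.sum())
parity_accuracy = parity_cm.trace() / float(parity_cm.sum())
combined_accuracy = (digit_cm.trace() + parity_cm.trace()) / float(
    digit_cm.sum() + parity_cm.sum())

print(digit_accuracy, parity_accuracy, combined_accuracy)
```

By this calculation the digit output is correct on 9,653 of 10,000 validation examples and the parity output on 9,838, for a combined 0.97455. That is higher than the logged Validation-mnistperdigiaccuracy of 0.931976, which suggests the logged figure also includes predictions counted earlier in training.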
