In [1]:
# https://mxnet.incubator.apache.org/faq/finetune.html
# https://mxnet.incubator.apache.org/tutorials/python/predict_image.html

import os
import mxnet as mx
import urllib.request
import logging
import numpy as np
import matplotlib.pyplot as plt

# Configuration shared by the cells below.
num_classes = 200 # Number of bird categories; used as the output size of the new fc layer.
num_epoch = 20 # Number of training epochs; also the checkpoint epoch that is reloaded after training.
batch_size = 50 # Number of images per batch for both record iterators.
data_shape = (3, 224, 224) # Data shape when fed into the model; presumably (channels, height, width) -- confirm.
ctx = mx.gpu() # Run on the GPU (original note: the p2 instance has 1 GPU).
prefix = 'resnet-18' # Base name of the pretrained model files and of the saved checkpoints.
epoch_load = 30 # The epoch number at which to load the model. NOTE(review): not referenced by any cell visible here -- confirm it is still needed.
logging.getLogger().setLevel(logging.DEBUG) # Lower the root logger threshold so the INFO messages below are emitted.


In [2]:
# Download model.
def download(url):
    """Download ``url`` into the current working directory unless it is already there.

    The local filename is the last path component of ``url``. The data is
    first written to ``<filename>.part`` and only renamed to ``filename`` once
    the transfer has finished, so an interrupted download can never be
    mistaken for a complete file on the next run.

    url: address of the file to fetch.

    Raises whatever ``urllib.request.urlretrieve`` raises; the partial file is
    removed before the error propagates.
    """
    filename = url.split("/")[-1]
    if os.path.exists(filename):
        logging.info('Model file %s exist.', filename)
        return
    partial = filename + '.part'
    try:
        urllib.request.urlretrieve(url, partial)
        # Atomic rename: `filename` only ever appears fully written.
        os.replace(partial, filename)
    finally:
        # After a successful rename the partial file no longer exists, so this
        # only fires when the transfer (or the rename) failed.
        if os.path.exists(partial):
            os.remove(partial)
def get_model(prefix, epoch):
    """Fetch the two files that make up a saved model.

    prefix: URL prefix shared by both files.
    epoch: epoch number, zero-padded to four digits in the parameter file name.
    """
    # Network definition first.
    symbol_url = prefix + '-symbol.json'
    download(symbol_url)
    # Then the weights saved for the requested epoch.
    params_suffix = '-%04d.params' % (epoch,)
    download(prefix + params_suffix)

# Fetch the pretrained symbol file and the epoch-0 parameter file into the
# current working directory (download() skips files that already exist).
get_model('http://data.mxnet.io/models/imagenet/resnet/18-layers/'+prefix, 0)
# Load the files fetched above; the epoch argument (0) must match the one passed to get_model.
symbol, arg_params, aux_params = mx.model.load_checkpoint(prefix, 0)
# symbol (Symbol): the symbol configuration of the computation network.
# arg_params (dict of str to NDArray): the network's weights, keyed by name.
# aux_params (dict of str to NDArray): the network's auxiliary states, keyed by name.


In [3]:
# Replace the last fully-connected layer with a new fc layer whose output dimension equals the number of categories.
def get_fine_tune_model(symbol, arg_params, num_classes, layer_name='flatten0'):
    """
    Build a fine-tuning network on top of a pre-trained one.

    symbol: the pre-trained network symbol
    arg_params: the argument parameters of the pre-trained model
    num_classes: the number of classes for the fine-tune datasets
    layer_name: the layer name before the last fully-connected layer

    Returns (net, new_args): the new network symbol and the pre-trained
    parameters with every entry whose name contains 'fc1' left out.
    """
    # Cut the pre-trained graph at the output of `layer_name`.
    backbone = symbol.get_internals()[layer_name + '_output']
    # Attach a new classifier head sized for the target dataset.
    classifier = mx.symbol.FullyConnected(data=backbone, num_hidden=num_classes, name='fc1')
    net = mx.symbol.SoftmaxOutput(data=classifier, name='softmax')
    # The new head is also named 'fc1', so the old 'fc1' parameters are not carried over.
    new_args = {name: value for name, value in arg_params.items() if 'fc1' not in name}
    return net, new_args
# Build the fine-tuning symbol: keep the pretrained graph up to 'flatten0' and
# add a new num_classes-way fc layer; new_args holds the pretrained weights
# without any 'fc1' entries.
new_sym, new_args = get_fine_tune_model(symbol, arg_params, num_classes, layer_name='flatten0')
net = mx.mod.Module(symbol=new_sym, context=ctx) # Create the module from the new symbol on the configured context.
# For comparison, the new and old network structures can be visualized with:
# mx.viz.plot_network(new_sym)
# and
# mx.viz.plot_network(symbol)

In [4]:
# https://mxnet.incubator.apache.org/api/python/io/io.html#mxnet.io.ImageRecordIter
def get_iterators():
    """Create the training and validation record iterators.

    Both iterators use the module-level ``batch_size`` and ``data_shape`` and
    have random cropping and mirroring switched off; only the training
    iterator is asked to shuffle.

    Returns a ``(train_itr, test_itr)`` tuple.
    """
    # Settings common to both iterators. A dict literal is used on purpose:
    # a later cell rebinds the name `dict`.
    shared = {
        'batch_size': batch_size,
        'data_shape': data_shape,
        'rand_crop': False,
        'rand_mirror': False,
    }
    train_itr = mx.io.ImageRecordIter(
        path_imgrec='./dataset/io_data/bird_data_train.rec',
        shuffle=True,
        **shared)
    test_itr = mx.io.ImageRecordIter(
        path_imgrec='./dataset/io_data/bird_data_val.rec',
        **shared)
    return train_itr, test_itr
# Create the training and validation iterators used by the cells below.
train_itr, test_itr = get_iterators()
# Bind the module using the data and label shapes reported by the training iterator.
net.bind(data_shapes=train_itr.provide_data, label_shapes=train_itr.provide_label)

In [None]:
# Model training.
# Fine-tune on train_itr and evaluate on test_itr after every epoch.
net.fit(train_itr, test_itr,
        num_epoch = num_epoch,
        # Start from the pretrained weights (minus the old 'fc1' entries) and auxiliary states.
        arg_params = new_args,
        aux_params = aux_params,
        # new_args has no entries for the new 'fc1' layer, so missing parameters must be allowed.
        allow_missing = True,
        # Log speed and running accuracy every 20 batches.
        batch_end_callback = mx.callback.Speedometer(batch_size, 20),
        # Save a checkpoint named 'bird_classification_<prefix>-NNNN.params' after every epoch.
        epoch_end_callback = mx.callback.do_checkpoint(prefix='bird_classification_'+prefix, period=1),
        optimizer = 'sgd',
        # 'wd' is presumably the weight-decay coefficient -- confirm against the optimizer docs.
        optimizer_params = {'learning_rate':0.01, 'wd':0.01},
        eval_metric = 'acc')

INFO:root:Epoch[0] Batch [0-20]	Speed: 159.68 samples/sec	accuracy=0.010476
INFO:root:Epoch[0] Batch [20-40]	Speed: 158.60 samples/sec	accuracy=0.019000
INFO:root:Epoch[0] Batch [40-60]	Speed: 160.38 samples/sec	accuracy=0.030000
INFO:root:Epoch[0] Batch [60-80]	Speed: 158.63 samples/sec	accuracy=0.043000
INFO:root:Epoch[0] Batch [80-100]	Speed: 158.84 samples/sec	accuracy=0.055000
INFO:root:Epoch[0] Batch [100-120]	Speed: 158.80 samples/sec	accuracy=0.074000
INFO:root:Epoch[0] Batch [120-140]	Speed: 159.58 samples/sec	accuracy=0.102000
INFO:root:Epoch[0] Batch [140-160]	Speed: 158.43 samples/sec	accuracy=0.118000
INFO:root:Epoch[0] Batch [160-180]	Speed: 158.68 samples/sec	accuracy=0.102000
INFO:root:Epoch[0] Batch [180-200]	Speed: 158.42 samples/sec	accuracy=0.146000
INFO:root:Epoch[0] Train-accuracy=0.074623
INFO:root:Epoch[0] Time cost=69.428
INFO:root:Saved checkpoint to "bird_classification_resnet-18-0001.params"
INFO:root:Epoch[0] Validation-accuracy=0.175833
INFO:root:Epoch[1] 

In [None]:
# Load the model.
# Read back the checkpoint numbered num_epoch and rebuild a module around it.
symbol, arg_params, aux_params = mx.model.load_checkpoint(
    'bird_classification_' + prefix, num_epoch)
net = mx.mod.Module(symbol=symbol, context=ctx)
# Keyword arguments mirror the bind call made before training.
net.bind(data_shapes=train_itr.provide_data,
         label_shapes=train_itr.provide_label)
net.set_params(arg_params=arg_params, aux_params=aux_params)

In [None]:
# Construct dictionary for category names: 0-based line number in labels.txt
# -> that line's text. Each value keeps its trailing newline exactly as read.
# NOTE(review): the name `dict` shadows the builtin; it is kept because the
# code further down this cell indexes the mapping under this name.
with open('labels.txt', 'r') as f:
    lines = f.readlines()
dict = {index: text for index, text in enumerate(lines)}

# Inference demonstration.
def img_transform(data):
    """Convert an MXNet image array into a layout matplotlib can plot.

    data: object exposing ``asnumpy()``; the result is cast to uint8 and its
    axes are reordered from (0, 1, 2) to (1, 2, 0).

    Returns the reordered uint8 NumPy array.
    """
    pixels = data.asnumpy()
    # Cast first, then move the leading axis to the end.
    return np.transpose(pixels.astype(np.uint8), (1, 2, 0))

# Fetch image and do inference.
# Build fresh iterators so the validation iterator starts from its beginning.
_, test_itr = get_iterators()
# NOTE(review): `first_batch` is not part of the documented iterator interface;
# presumably it is the same batch predict() reads first after its reset -- confirm.
batch = test_itr.first_batch
pred_prob = net.predict(test_itr, num_batch=1, reset=True).asnumpy()
# Indices of the five highest scores for the first image, best first.
pred_index = pred_prob[0].argsort()[-5:][::-1]

# Display results.
img = img_transform(batch.data[0][0])
plt.figure()
plt.imshow(img)
# Use the builtin int: the `np.int` alias was removed in NumPy 1.24 and
# raises AttributeError there.
true_index = int(batch.label[0].asnumpy()[0])
# Label text may carry the trailing newline read from labels.txt, so it is
# stripped here and an explicit space separates the name from the rest of the
# message. Arguments are passed to the logger for lazy formatting.
logging.info('True label is: %s', dict[true_index].strip())
logging.info('Prediction results: ')
for rank, index in enumerate(pred_index, start=1):
    logging.info('Top %d is: %s with probability: %s',
                 rank, dict[index].strip(), pred_prob[0][index])
