In [1]:
# Preprocess the images beforehand to generate the .lst and .rec files: run the following commands in a Linux shell.
# sudo python3 /home/carsmart/mxnet/tools/im2rec.py --list TRUE --recursive TRUE --train-ratio 0.8 --test-ratio 0 /home/carsmart/users/xiaoming/Distracted_Driver_Detection/train_lst/train /home/carsmart/users/xiaoming/Distracted_Driver_Detection/train/
# sudo python3 /home/carsmart/mxnet/tools/im2rec.py --pass-through TRUE /home/carsmart/users/xiaoming/Distracted_Driver_Detection/train_lst/train /home/carsmart/users/xiaoming/Distracted_Driver_Detection/train/
# Labels can be read from train_train.lst or train_val.lst
# sudo python3 /home/carsmart/mxnet/tools/im2rec.py --list TRUE --recursive TRUE --train-ratio 0 --test-ratio 1 /home/carsmart/users/xiaoming/Distracted_Driver_Detection/test_lst/test /home/carsmart/users/xiaoming/Distracted_Driver_Detection/test/
# sudo python3 /home/carsmart/mxnet/tools/im2rec.py --pass-through TRUE /home/carsmart/users/xiaoming/Distracted_Driver_Detection/test_lst/test /home/carsmart/users/xiaoming/Distracted_Driver_Detection/test/

In [2]:
import mxnet as mx
from mxnet import autograd as ag
from mxnet import gluon
from mxnet import init
from mxnet import nd
from time import time
from mxnet.gluon import nn
from mxnet.gluon.model_zoo import vision as models
import os
import numpy as np

import warnings
warnings.filterwarnings('ignore')

## Parameters

In [3]:
# Project root. Defaults to the original author's machine; set the
# DDD_PROJECT_PATH environment variable to relocate it. A trailing separator
# is enforced because every path below is built by plain string concatenation.
project_path = os.path.join(
    os.environ.get('DDD_PROJECT_PATH',
                   '/home/carsmart/users/xiaoming/Distracted_Driver_Detection/'),
    '')

# .lst / .rec pairs produced by im2rec.py (see the commands in the first cell).
path_list = project_path + 'train_lst/train_train.lst'
path_rec = project_path + 'train_lst/train_train.rec'
val_path_list = project_path + 'train_lst/train_val.lst'
val_path_rec = project_path + 'train_lst/train_val.rec'
test_path_list = project_path + 'test_lst/test_test.lst'
test_path_rec = project_path + 'test_lst/test_test.rec'

# Directory the extracted feature files are written to.
save_path = project_path + 'features/'

# Per-channel (R, G, B) ImageNet normalization statistics used by the Gluon
# model-zoo pretrained networks, expressed in [0, 1] pixel units.
imagenet_mean = (0.485, 0.456, 0.406)
imagenet_std = (0.229, 0.224, 0.225)
# ImageRecordIter applies mean_*/std_* to raw 0-255 pixels, so scale both by 255.
# Deriving all six values from one table fixes the former `0.229 * 225` typo
# on std_r, which mis-scaled the red channel.
(mean_r, mean_g, mean_b) = (m * 255 for m in imagenet_mean)
(std_r, std_g, std_b) = (s * 255 for s in imagenet_std)

# Number of images per mini-batch fed to the feature extractors.
batch_size = 8
# Devices each batch is split across via split_and_load; edit the id tuple
# to use other GPUs.
ctx = [mx.gpu(device_id) for device_id in (0,)]

## Data iterators

In [4]:
def _make_record_iter(path_imglist, path_imgrec, kv, resize, data_shape):
    """Build one deterministic ImageRecordIter over a .lst/.rec pair.

    Uses the module-level ``batch_size`` and ``mean_*`` / ``std_*`` settings.
    Random cropping, mirroring and shuffling are all disabled, so batches come
    out in record-file order.
    """
    return mx.io.ImageRecordIter(
        path_imglist=path_imglist,
        path_imgrec=path_imgrec,
        resize=resize,
        data_shape=data_shape,
        batch_size=batch_size,
        rand_mirror=False,
        rand_crop=False,
        mean_r=mean_r,
        mean_g=mean_g,
        mean_b=mean_b,
        std_r=std_r,
        std_g=std_g,
        std_b=std_b,
        num_parts=kv.num_workers,
        part_index=kv.rank,
        # Explicit for every split; False is also ImageRecordIter's default.
        shuffle=False,
    )


def get_iter(kv, resize, data_shape):
    """Create the train / val / test record iterators for one input size.

    All three iterators share identical, non-random preprocessing (see
    ``_make_record_iter``). The training iterator is additionally wrapped in
    ``mx.io.PrefetchingIter``.

    Parameters
    ----------
    kv : mx.kvstore.KVStore
        Only ``num_workers`` and ``rank`` are read, to shard the records.
    resize : int
        Passed to ImageRecordIter's ``resize`` argument.
    data_shape : tuple of int
        (channels, height, width) of each output image.

    Returns
    -------
    tuple
        ``(train_iter, val_iter, test_iter)``.
    """
    train_iter = mx.io.PrefetchingIter(
        _make_record_iter(path_list, path_rec, kv, resize, data_shape))
    val_iter = _make_record_iter(val_path_list, val_path_rec, kv, resize, data_shape)
    test_iter = _make_record_iter(test_path_list, test_path_rec, kv, resize, data_shape)
    return (train_iter, val_iter, test_iter)

## Export features

In [5]:
def get_batch(batch, ctx):
    """Split a batch's data and labels across the devices in ``ctx``.

    Parameters
    ----------
    batch : mx.io.DataBatch or (data, label) pair
        For a DataBatch, the first data array and the first label array are used.
    ctx : list of mx.Context
        Devices to shard the batch across.

    Returns
    -------
    tuple
        ``(data_shards, label_shards, n)``: one shard per context for data and
        label, plus ``n = data.shape[0]``, the number of samples in the batch
        (padded samples included).
    """
    if isinstance(batch, mx.io.DataBatch):
        data = batch.data[0]
        label = batch.label[0]
    else:
        # Previously unpacked into a misspelled ``lable``, which left ``label``
        # unbound and made this branch raise UnboundLocalError.
        data, label = batch
    return (gluon.utils.split_and_load(data, ctx),
            gluon.utils.split_and_load(label, ctx),
            data.shape[0])

def save_features(model_name, data_iter, save_path, ignore = True, prefix = 'train'):
    """Extract pretrained ``features`` outputs for every sample and save them.

    Loads ``model_name`` from the Gluon model zoo with pretrained weights on
    the module-level ``ctx`` and runs its ``features`` stage over
    ``data_iter``. SqueezeNet outputs are globally average-pooled first. All
    outputs are flattened, padded samples are dropped, and the rows are
    stacked in iterator order. The result is written with ``nd.save`` to
    ``<save_path>features_<prefix>_<model_name>.nd``. Labels are not saved.

    Parameters
    ----------
    model_name : str
        Name accepted by ``models.get_model``.
    data_iter : mx.io.DataIter
        Iterator yielding DataBatch objects. It is NOT reset here: pass a
        freshly created iterator. Resetting a round-batch ImageRecordIter
        after a padded epoch would shift the sample order.
    save_path : str
        Output location prefix, string-concatenated with the file name, so it
        should end with '/'. Missing directories are created.
    ignore : bool
        If True and the output file already exists, return without loading
        the model.
    prefix : str
        Split tag embedded in the file name (e.g. 'train', 'val', 'test').

    Raises
    ------
    ValueError
        If ``data_iter`` yields no batches (e.g. it was already consumed).
    """
    out_file = save_path + 'features_%s_%s.nd' % (prefix, model_name)
    if ignore and os.path.exists(out_file):
        return

    net = models.get_model(name = model_name, pretrained = True, ctx = ctx)
    # SqueezeNet's feature stage ends in a spatial map, so pool it to 1x1.
    needs_pool = 'squeezenet' in model_name
    # Parameter-free blocks: build them once rather than once per batch.
    global_pool = gluon.nn.GlobalAvgPool2D()
    flatten = gluon.nn.Flatten()

    features = []
    for batch in data_iter:
        data, _, n = get_batch(batch, ctx)
        # The final batch may be padded; keep only the real samples.
        real_size = n - batch.pad
        outputs = [net.features(x) for x in data]
        if needs_pool:
            outputs = [global_pool(x) for x in outputs]
        outputs = [flatten(x) for x in outputs]
        merged = nd.concatenate([x.as_in_context(mx.cpu()) for x in outputs])
        features.append(merged[0:real_size, :])
        # Synchronize each batch so pending GPU work does not pile up.
        nd.waitall()

    if not features:
        raise ValueError('data_iter yielded no batches for %s (%s); '
                         'pass a freshly created iterator' % (model_name, prefix))

    out_dir = os.path.dirname(out_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok = True)
    nd.save(out_file, nd.concatenate(features))

In [6]:
from mxnet.gluon.model_zoo.model_store import _model_sha1

# Dump train/val/test features for every pretrained architecture in the Gluon
# model zoo, in alphabetical order. Inception v3 takes 299x299 inputs; every
# other zoo model takes 224x224.
kv = mx.kvstore.create("local")
for model_name in sorted(_model_sha1):
    print(model_name)
    if model_name != 'inceptionv3':
        train_iter_224, val_iter_224, test_iter_224 = get_iter(
            kv, resize=224, data_shape=(3, 224, 224))
        save_features(model_name, train_iter_224, save_path=save_path, prefix='train')
        save_features(model_name, val_iter_224, save_path=save_path, prefix='val')
        save_features(model_name, test_iter_224, save_path=save_path, prefix='test')
    else:
        train_iter_299, val_iter_299, test_iter_299 = get_iter(
            kv, resize=299, data_shape=(3, 299, 299))
        save_features(model_name, train_iter_299, save_path=save_path, prefix='train')
        save_features(model_name, val_iter_299, save_path=save_path, prefix='val')
        save_features(model_name, test_iter_299, save_path=save_path, prefix='test')

alexnet
densenet121
densenet161
densenet169
densenet201
inceptionv3
resnet101_v1
resnet152_v1
resnet18_v1
resnet18_v2
resnet34_v1
resnet34_v2
resnet50_v1
resnet50_v2
squeezenet1.0
squeezenet1.1
vgg11
vgg11_bn
vgg13
vgg13_bn
vgg16
vgg16_bn
vgg19
vgg19_bn
