### 开始

In [1]:
# 检查Python版本
from sys import version_info
if version_info.major != 3:
    
    raise Exception('请使用Python3来完成此项目')

### 数据预处理

#### 获取数据
[Dogs vs. Cats Redux: Kernels Edition
](https://www.kaggle.com/c/dogs-vs-cats-redux-kernels-edition)

In [2]:
# download data and unzip it
# from urllib.request import urlretrieve
import subprocess
import os
from tqdm import tqdm
from zipfile import ZipFile

train_url = ['kaggle','competitions','download','-c','dogs-vs-cats-redux-kernels-edition',
            '-f','train.zip','-p','./']
test_url = ['kaggle','competitions','download','-c','dogs-vs-cats-redux-kernels-edition',
            '-f','test.zip','-p','./']
sample_csv_url = ['kaggle','competitions','download','-c','dogs-vs-cats-redux-kernels-edition',
                  '-f','sample_submission.csv','-p','./']


def download_unzip_dataset(url, zip_file_path, folder_path, unzip=True):
    if os.path.exists(zip_file_path):
        print("file is exist, no need download")
    else:
        print("download now")
#         urlretrieve(url, zip_file_path)
        subp = subprocess.run(url)
#         subp.wait()
    
    if unzip:
        if os.path.exists(folder_path):
            print("files found")
        else:
            print("unzip now")
            zipf = ZipFile(zip_file_path)
            zipf.extractall()
            print("unzip end")

download_unzip_dataset(train_url, 'train.zip', 'train/')
download_unzip_dataset(test_url, 'test.zip', 'test/')

file is exist, no need download
files found
file is exist, no need download
files found


#### 分离数据集，dog和cat图片分别放入train2/dogs, train2/cats

In [3]:
import os
import shutil

def split_train_set(old_dir, new_dir):
    file_list = os.listdir(old_dir)
    file_cats = filter(lambda x:x[:3] == 'cat', file_list)
    file_dogs = filter(lambda x:x[:3] == 'dog', file_list)

    if os.path.exists(new_dir):
        shutil.rmtree(new_dir)
    os.mkdir(new_dir)
    
    dogs_path = os.path.join(new_dir, 'dogs')
    cats_path = os.path.join(new_dir, 'cats')
    os.mkdir(dogs_path)
    os.mkdir(cats_path)
    
    # 此处要注意： os.symlink(src, dst)
    # dst是从它所在的目录去选择src,所以src必须是相对于dst的relative path
    for filename in file_cats:
        os.symlink('../../'+old_dir+filename, cats_path+'/'+filename)
    
    for filename in file_dogs:
        os.symlink(old_dir+filename, dogs_path+'/'+filename)
        
    print("split over")

split_train_set('train/', 'pre-train/')

# preprocess test image folder
if os.path.exists('pre-test'):
    shutil.rmtree('pre-test')
os.mkdir('pre-test')
os.symlink('../test', 'pre-test/test')


split over


### 提取特征

利用pre-trained model提取特征

#### 准备训练集和测试集
对于样本数非常多的数据集，可以利用generator函数来减少计算的次数

In [16]:
from keras.preprocessing.image import *

image_size = (224,224)

image_gen = ImageDataGenerator()
train_generator = image_gen.flow_from_directory('pre-train', 
                                                target_size=image_size, 
                                                shuffle=False, # our data will be in order
                                                batch_size=16)
test_generator = image_gen.flow_from_directory('pre-test', 
                                               target_size=image_size, 
                                               shuffle=False, # our data will be in order
                                               batch_size=16, 
                                               class_mode=None, # this means our generator will only yield batches of data, no labels
                                              )


Found 25000 images belonging to 2 classes.
Found 0 images belonging to 0 classes.


#### 提取特征
- 利用 pre-trained 模型从train/test dataset中提取出特征，然后使用自定义的fully-connected层在这些提取的特征集上训练

In [None]:
from keras.applications import resnet50
import h5py

x = Input((image_size[0], image_size[1], 3)) # shape: width, height, channel
base_model = resnet50.ResNet50(input_tensor=resnet50.preprocess_input(x) weights='imagenet')

# the predict_generator method returns the output of a model, given
# a generator that yields batches of numpy data
bottleneck_features_train = base_model.predict_generator(train_generator, train_generator.nb_sample)
# save the output as a Numpy array
# np.save(open('bottleneck_features_train.npy', 'w'), bottleneck_features_train)

bottleneck_features_test = model.predict_generator(test_generator, test_generator.nb_sample)
# np.save(open('bottleneck_features_test.npy', 'w'), bottleneck_features_test)

with h5py.File("pre_out") as h:
        h.create_dataset("train", data=bottleneck_features_train)
        h.create_dataset("label", data=train_generator.classes)
        h.create_dataset("test", data=bottleneck_features_test)

#### 为了增加代码复用，方便调试其它的pre-trained model，将上面的2个步骤封装为一个函数

In [4]:
from keras.preprocessing.image import *
from keras.applications import resnet50
from keras.applications import xception
from keras.applications import inception_v3
from keras.layers import Input, GlobalAveragePooling2D
from keras.models import Model
import h5py

def get_pre_features_from_images(MODEL, image_size, preprocess_input, model_name):
    image_gen = ImageDataGenerator()
    train_generator = image_gen.flow_from_directory('pre-train', 
                                                target_size=image_size, 
                                                shuffle=False, # our data will be in order
                                                batch_size=16)
    
    test_generator = image_gen.flow_from_directory('pre-test', 
                                               target_size=image_size, 
                                               shuffle=False, # our data will be in order
                                               batch_size=16, 
                                               class_mode=None, # this means our generator will only yield batches of data, no labels
                                              )
    
    ## use pre-trained model to get features from image generator
    x = Input((image_size[0], image_size[1], 3)) # shape: width, height, channel
    x = preprocess_input(x)
    base_model = MODEL(input_tensor=x, weights='imagenet', include_top=False)
    model = Model(base_model.input, GlobalAveragePooling2D()(base_model.output))
    
    # the predict_generator method returns the output of a model, given
    # a generator that yields batches of numpy data
    pre_features_train = model.predict_generator(train_generator)
    pre_features_test = model.predict_generator(test_generator)
    
    # save the output to h5 file
    out_filename = model_name + "_pre_out.h5"
    with h5py.File(out_filename) as h:
        h.create_dataset("train", data=pre_features_train)
        h.create_dataset("label", data=train_generator.classes)
        h.create_dataset("test", data=pre_features_test)

In [5]:
get_pre_features_from_images(resnet50.ResNet50, (224,224), resnet50.preprocess_input, "ResNet50")

Found 25000 images belonging to 2 classes.
Found 12500 images belonging to 1 classes.


ValueError: ('Error when checking model input: expected no data, but got:', array([[[[203., 164.,  87.],
         [206., 167.,  90.],
         [209., 170.,  93.],
         ...,
         [245., 203., 119.],
         [241., 202., 123.],
         [239., 200., 121.]],

        [[203., 164.,  87.],
         [206., 167.,  90.],
         [209., 170.,  93.],
         ...,
         [245., 205., 120.],
         [242., 203., 124.],
         [240., 201., 122.]],

        [[203., 164.,  87.],
         [206., 167.,  90.],
         [209., 170.,  93.],
         ...,
         [245., 204., 122.],
         [243., 204., 125.],
         [241., 202., 123.]],

        ...,

        [[154., 123.,  56.],
         [155., 124.,  57.],
         [156., 125.,  58.],
         ...,
         [  3.,   3.,   1.],
         [  3.,   3.,   1.],
         [  3.,   3.,   1.]],

        [[153., 122.,  55.],
         [153., 122.,  55.],
         [154., 123.,  56.],
         ...,
         [  2.,   2.,   0.],
         [  2.,   2.,   0.],
         [  2.,   2.,   0.]],

        [[151., 120.,  53.],
         [152., 121.,  54.],
         [153., 122.,  55.],
         ...,
         [  1.,   1.,   0.],
         [  1.,   1.,   0.],
         [  1.,   1.,   0.]]],


       [[[ 39.,  44.,  40.],
         [ 40.,  44.,  43.],
         [ 41.,  45.,  46.],
         ...,
         [210., 209., 181.],
         [207., 204., 171.],
         [201., 199., 161.]],

        [[ 40.,  45.,  41.],
         [ 40.,  44.,  43.],
         [ 41.,  45.,  46.],
         ...,
         [207., 203., 176.],
         [203., 200., 169.],
         [197., 195., 157.]],

        [[ 39.,  44.,  40.],
         [ 38.,  42.,  41.],
         [ 37.,  41.,  42.],
         ...,
         [195., 191., 166.],
         [198., 193., 164.],
         [205., 200., 168.]],

        ...,

        [[ 29.,  27.,  28.],
         [ 25.,  23.,  24.],
         [ 22.,  20.,  21.],
         ...,
         [ 50.,  37.,  31.],
         [ 41.,  28.,  22.],
         [ 49.,  38.,  32.]],

        [[ 32.,  30.,  31.],
         [ 26.,  24.,  25.],
         [ 22.,  20.,  21.],
         ...,
         [ 44.,  31.,  23.],
         [ 42.,  29.,  21.],
         [ 55.,  45.,  36.]],

        [[ 32.,  30.,  31.],
         [ 25.,  23.,  24.],
         [ 21.,  19.,  20.],
         ...,
         [ 59.,  46.,  38.],
         [ 51.,  38.,  30.],
         [ 40.,  30.,  21.]]],


       [[[ 29.,  33.,  42.],
         [ 19.,  23.,  32.],
         [  8.,  12.,  23.],
         ...,
         [130., 162., 159.],
         [128., 160., 157.],
         [125., 157., 154.]],

        [[ 36.,  40.,  49.],
         [ 41.,  45.,  54.],
         [ 31.,  35.,  46.],
         ...,
         [131., 163., 160.],
         [129., 161., 158.],
         [126., 158., 155.]],

        [[ 38.,  45.,  51.],
         [ 41.,  48.,  54.],
         [ 37.,  44.,  52.],
         ...,
         [132., 164., 159.],
         [129., 161., 156.],
         [127., 159., 154.]],

        ...,

        [[178., 165., 120.],
         [171., 158., 114.],
         [169., 156., 112.],
         ...,
         [190., 188.,   9.],
         [179., 174.,  12.],
         [170., 163.,  13.]],

        [[169., 156., 111.],
         [170., 157., 112.],
         [156., 143.,  99.],
         ...,
         [189., 186.,   9.],
         [178., 173.,  11.],
         [168., 161.,  11.]],

        [[159., 145., 110.],
         [147., 133.,  94.],
         [150., 137.,  93.],
         ...,
         [189., 187.,   6.],
         [180., 177.,   4.],
         [172., 167.,   3.]]],


       ...,


       [[[175., 139., 107.],
         [177., 141., 109.],
         [178., 142., 110.],
         ...,
         [ 75.,  68.,  76.],
         [ 69.,  62.,  70.],
         [ 69.,  62.,  70.]],

        [[175., 139., 107.],
         [177., 141., 109.],
         [178., 142., 110.],
         ...,
         [ 75.,  68.,  76.],
         [ 73.,  66.,  74.],
         [ 73.,  66.,  74.]],

        [[175., 139., 107.],
         [177., 141., 109.],
         [178., 142., 110.],
         ...,
         [ 75.,  68.,  76.],
         [ 75.,  68.,  76.],
         [ 75.,  68.,  76.]],

        ...,

        [[220., 168., 111.],
         [198., 146.,  89.],
         [182., 130.,  73.],
         ...,
         [178., 126.,  76.],
         [184., 132.,  82.],
         [187., 135.,  85.]],

        [[189., 137.,  80.],
         [198., 146.,  89.],
         [205., 153.,  96.],
         ...,
         [161., 109.,  59.],
         [160., 108.,  58.],
         [163., 111.,  61.]],

        [[188., 136.,  79.],
         [193., 141.,  84.],
         [198., 146.,  89.],
         ...,
         [166., 114.,  64.],
         [172., 120.,  70.],
         [169., 117.,  67.]]],


       [[[214., 203., 157.],
         [214., 203., 157.],
         [215., 204., 158.],
         ...,
         [207., 197., 136.],
         [207., 197., 136.],
         [207., 197., 136.]],

        [[213., 202., 156.],
         [214., 203., 157.],
         [214., 203., 157.],
         ...,
         [207., 197., 136.],
         [207., 197., 136.],
         [207., 197., 136.]],

        [[212., 201., 155.],
         [213., 202., 156.],
         [214., 203., 157.],
         ...,
         [207., 197., 136.],
         [207., 197., 136.],
         [207., 197., 136.]],

        ...,

        [[120., 106.,  80.],
         [115.,  99.,  66.],
         [118.,  94.,  56.],
         ...,
         [150., 131.,  62.],
         [151., 131.,  68.],
         [144., 123.,  66.]],

        [[118., 105.,  71.],
         [114.,  98.,  62.],
         [114.,  90.,  52.],
         ...,
         [150., 130.,  59.],
         [148., 127.,  60.],
         [144., 123.,  60.]],

        [[118., 105.,  71.],
         [114.,  98.,  62.],
         [114.,  90.,  52.],
         ...,
         [150., 130.,  59.],
         [148., 127.,  60.],
         [144., 123.,  60.]]],


       [[[ 56.,  54.,   6.],
         [ 57.,  52.,  14.],
         [ 60.,  51.,  20.],
         ...,
         [ 41.,  42.,  37.],
         [ 37.,  37.,  35.],
         [ 17.,  17.,  15.]],

        [[ 50.,  47.,   2.],
         [ 51.,  45.,   9.],
         [ 53.,  44.,  13.],
         ...,
         [ 38.,  28.,  16.],
         [ 22.,  12.,   0.],
         [ 57.,  44.,  35.]],

        [[ 53.,  50.,   7.],
         [ 51.,  45.,  11.],
         [ 52.,  43.,  14.],
         ...,
         [ 53.,  38.,  19.],
         [ 91.,  74.,  56.],
         [ 49.,  30.,  13.]],

        ...,

        [[ 50., 168., 107.],
         [ 50., 168., 107.],
         [ 50., 168., 107.],
         ...,
         [ 24., 152., 101.],
         [ 30., 152., 103.],
         [ 34., 145., 102.]],

        [[ 50., 168., 107.],
         [ 50., 168., 107.],
         [ 50., 168., 107.],
         ...,
         [ 22., 149.,  98.],
         [ 29., 149., 103.],
         [ 35., 144., 102.]],

        [[ 44., 164., 102.],
         [ 45., 165., 103.],
         [ 44., 164., 102.],
         ...,
         [ 32., 161., 104.],
         [ 29., 158., 101.],
         [ 26., 158., 100.]]]], dtype=float32))

In [None]:
get_pre_features_from_images(xception.Xception, (224,224), resnet50.preprocess_input, "Xception")

In [None]:
get_pre_features_from_images(inception_v3.InceptionV3, (224,224), resnet50.preprocess_input, "InceptionV3")