# 导出特征向量

In [1]:
from keras.models import *
from keras.layers import *
from keras.applications import *
from keras.preprocessing.image import *
import h5py
import os

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [2]:
# gen_feature builds its model inputs as (height, width, 3), i.e. channels-last.
# Fail fast here rather than deep inside model construction if the backend disagrees.
data_format = K.image_data_format()
print(data_format)
assert data_format == 'channels_last', "gen_feature expects channels_last, got %r" % data_format

channels_last


In [2]:
def gen_feature(MODEL, model_name: str, input_size: tuple, pre_process_func, batch_size=64,
                pretrain_dir='pretrain', finetune_dir='finetune', val_dir='val', test_dir='test'):
    """Extract globally pooled ImageNet features for every data split and save them to HDF5.

    Builds ``MODEL`` with ImageNet weights and no classification head. It appends
    ``GlobalAveragePooling2D`` so that each image maps to one feature vector. It then runs
    the model over four image directories and writes the results to
    ``feature_<model_name>.h5``.

    Args:
        MODEL: a ``keras.applications``-style constructor (e.g. ``InceptionV3``) accepting
            ``input_tensor``, ``weights`` and ``include_top`` keyword arguments.
        model_name: used only to build the output file name.
        input_size: ``(height, width)`` that images are resized to. Model inputs are
            channels-last, ``(height, width, 3)``.
        pre_process_func: model-specific preprocessing (e.g. ``xception.preprocess_input``).
            It is applied inside the model through a ``Lambda`` layer.
        batch_size: number of images per prediction batch.
        pretrain_dir, finetune_dir, val_dir: labelled image directories, read with
            ``class_mode='binary'``. Their class indices are saved as ``<split>_label``.
        test_dir: image directory read with ``class_mode=None``. No labels are saved for it.

    Side effects:
        Writes the datasets ``pretrain``, ``pretrain_label``, ``finetune``,
        ``finetune_label``, ``val``, ``val_label`` and ``test`` to
        ``feature_<model_name>.h5``, replacing any existing file of that name.

    Returns:
        None.
    """
    inputs = Input((input_size[0], input_size[1], 3))
    x = Lambda(pre_process_func)(inputs)

    base_model = MODEL(input_tensor=x, weights='imagenet', include_top=False)
    model = Model(inputs, GlobalAveragePooling2D()(base_model.output))

    # (dataset name, directory, class_mode). The test split is unlabelled.
    splits = [
        ('pretrain', pretrain_dir, 'binary'),
        ('finetune', finetune_dir, 'binary'),
        ('val', val_dir, 'binary'),
        ('test', test_dir, None),
    ]

    # No augmentation or rescaling here: preprocessing is done by the Lambda layer above.
    img_gen = ImageDataGenerator()
    # Create every generator (a cheap directory scan) before the expensive predictions.
    # shuffle=False keeps feature rows in the same order as gen.classes.
    generators = [
        (name, class_mode,
         img_gen.flow_from_directory(directory, target_size=input_size, shuffle=False,
                                     batch_size=batch_size, class_mode=class_mode))
        for name, directory, class_mode in splits
    ]
    # NOTE(review): test features follow the test generator's `filenames` order; consider
    # also saving those filenames so predictions can be mapped back to image ids.
    features = [model.predict_generator(gen, len(gen), verbose=1) for _, _, gen in generators]

    h5_file = "feature_%s.h5" % model_name
    # Mode 'w' creates the file or truncates an existing one. Opening without a mode is
    # deprecated and defaults to read-only 'r' in h5py >= 3, where create_dataset fails.
    with h5py.File(h5_file, 'w') as h:
        for (name, class_mode, gen), feature in zip(generators, features):
            h.create_dataset(name, data=feature)
            if class_mode is not None:
                h.create_dataset(name + "_label", data=gen.classes)
        
    

In [3]:
# Export InceptionV3 features (299x299 inputs, Inception-style preprocessing).
gen_feature(MODEL=InceptionV3, model_name='InceptionV3', input_size=(299, 299),
            pre_process_func=inception_v3.preprocess_input)

Found 1991 images belonging to 2 classes.
Found 17925 images belonging to 2 classes.
Found 4979 images belonging to 2 classes.
Found 12500 images belonging to 1 classes.


In [4]:
# Export Xception features (299x299 inputs, Xception preprocessing, batches of 64).
gen_feature(MODEL=Xception, model_name='Xception', input_size=(299, 299),
            pre_process_func=xception.preprocess_input, batch_size=64)

Found 1991 images belonging to 2 classes.
Found 17925 images belonging to 2 classes.
Found 4979 images belonging to 2 classes.
Found 12500 images belonging to 1 classes.


In [4]:
# Export InceptionResNetV2 features (299x299 inputs, matching preprocessing).
gen_feature(MODEL=InceptionResNetV2, model_name='InceptionResNetV2', input_size=(299, 299),
            pre_process_func=inception_resnet_v2.preprocess_input)

Found 1991 images belonging to 2 classes.
Found 17925 images belonging to 2 classes.
Found 4979 images belonging to 2 classes.
Found 12500 images belonging to 1 classes.
