# 导出特征向量

In [1]:
from keras.models import *
from keras.layers import *
from keras.applications import *
from keras.preprocessing.image import *
import h5py
import os

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [None]:
# gen_feature below builds its Input as (height, width, 3), i.e. it assumes a
# channels-last backend layout; fail fast here instead of deep inside Keras.
print(K.image_data_format())
assert K.image_data_format() == 'channels_last', "gen_feature assumes channels_last image data format"

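可选模型包括 InceptionResNetV2、InceptionV3、ResNet50 和 Xception（对应的预处理模块分别为 inception_resnet_v2、inception_v3、resnet50、xception）。ResNet50 的默认输入尺寸为 224×224，其余三个模型为 299×299：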

In [38]:
def gen_feature(MODEL, model_name: str, input_size: tuple, train_dir: str, val_dir: str, test_dir: str, pre_process_func=None, batch_size=64):
    """Extract pretrained-CNN bottleneck features and save them to an HDF5 file.

    ``MODEL`` is instantiated with ImageNet weights and without its classifier
    head, followed by global average pooling. Every image under ``train_dir``,
    ``val_dir`` and ``test_dir`` is run through it, and the results are written
    to ``feature_<model_name>.h5``. Any existing file of that name is overwritten.

    Args:
        MODEL: Keras application constructor, e.g. ``ResNet50`` or ``Xception``.
        model_name: tag used in the output file name.
        input_size: ``(width, height)`` that images are resized to.
        train_dir, val_dir: directories laid out for ``flow_from_directory``
            (one sub-directory per class); labels are read in 'binary' mode.
        test_dir: directory of test images (also needs at least one sub-directory,
            as required by ``flow_from_directory``).
        pre_process_func: optional model-specific preprocessing (e.g.
            ``xception.preprocess_input``), applied inside the graph via ``Lambda``.
        batch_size: number of images per prediction batch.

    Datasets written:
        ``train``, ``val`` and ``test`` hold the feature matrices. Rows follow the
        respective generator's file order, because ``shuffle=False``.
        ``train_label`` and ``val_label`` hold the class indices from ``generator.classes``.
    """
    img_w = input_size[0]
    img_h = input_size[1]
    # flow_from_directory's target_size is (height, width); pass it in that order
    # so the generated batches match the (img_h, img_w, 3) Input below even for
    # non-square sizes.
    target_size = (img_h, img_w)

    # Keep an explicit handle on the raw input; `x` may be rebound to the
    # preprocessed tensor that feeds the backbone.
    inputs = Input((img_h, img_w, 3))
    x = Lambda(pre_process_func)(inputs) if pre_process_func else inputs

    base_model = MODEL(input_tensor=x, weights='imagenet', include_top=False)
    model = Model(inputs, GlobalAveragePooling2D()(base_model.output))

    img_gen = ImageDataGenerator()
    # shuffle=False keeps feature rows aligned with generator.classes/.filenames.
    train_data_gen = img_gen.flow_from_directory(train_dir, target_size, shuffle=False, batch_size=batch_size, class_mode='binary')
    val_data_gen = img_gen.flow_from_directory(val_dir, target_size, shuffle=False, batch_size=batch_size, class_mode='binary')
    # Test labels are never used, so yield images only.
    test_data_gen = img_gen.flow_from_directory(test_dir, target_size, shuffle=False, batch_size=batch_size, class_mode=None)

    train_feature = model.predict_generator(train_data_gen, len(train_data_gen), verbose=1)
    val_feature = model.predict_generator(val_data_gen, len(val_data_gen), verbose=1)
    test_feature = model.predict_generator(test_data_gen, len(test_data_gen), verbose=1)

    h5_file = "feature_%s.h5" % model_name
    # Mode 'w' creates the file and truncates any previous one. An explicit mode is
    # required: h5py >= 3.0 defaults to read-only 'r'.
    with h5py.File(h5_file, 'w') as h:
        h.create_dataset("train", data=train_feature)
        h.create_dataset("val", data=val_feature)
        h.create_dataset("test", data=test_feature)
        h.create_dataset("train_label", data=train_data_gen.classes)
        h.create_dataset("val_label", data=val_data_gen.classes)
    

In [None]:
# ResNet50 features at its native 224x224 resolution.
gen_feature(
    ResNet50,
    'ResNet50',
    input_size=(224, 224),
    train_dir='train_split',
    val_dir='val_split',
    test_dir='test',
    pre_process_func=resnet50.preprocess_input,
)

In [None]:
# InceptionV3's native input resolution is 299x299 (previously mistyped as 229).
gen_feature(InceptionV3, 'InceptionV3', (299, 299), train_dir='train_split', val_dir='val_split', test_dir='test', 
            pre_process_func=inception_v3.preprocess_input)

In [39]:
# Xception's native input resolution is 299x299 (previously mistyped as 229).
gen_feature(Xception, 'Xception', (299, 299), train_dir='train_split', val_dir='val_split', test_dir='test', 
            pre_process_func=xception.preprocess_input)

Found 19944 images belonging to 2 classes.
Found 4985 images belonging to 2 classes.
Found 12500 images belonging to 1 classes.


In [None]:
# InceptionResNetV2's native input resolution is 299x299 (previously mistyped as 229).
gen_feature(InceptionResNetV2, 'InceptionResNetV2', (299, 299), train_dir='train_split', val_dir='val_split', test_dir='test', 
            pre_process_func=inception_resnet_v2.preprocess_input)