# 预训练——数据增强的特征提取
- 直接在卷积基后面重新定义全连接层，然后对整个模型进行训练
- 这种方式可以使用数据增强技术。与直接使用特征提取的方式相比，这种方法计算量更大，精度也更高：[kaggle链接](https://www.kaggle.com/liuyixi/dogs-vs-cats-pre-train2)

In [1]:
# Reuse the dataset paths from the previous model; only the path names need
# to be declared here.
import os

original_dataset_dir = '../../data/cats_and_dogs/train'
base_dir = '../../data/cats_and_dogs_small'

# One sub-directory per split under the small-dataset root.
train_dir, valid_dir, test_dir = (
    os.path.join(base_dir, split) for split in ('train', 'valid', 'test')
)

# Per-class folders inside each split directory.
train_dir_dogs, valid_dir_dogs, test_dir_dogs = (
    os.path.join(split_dir, 'dogs')
    for split_dir in (train_dir, valid_dir, test_dir)
)
train_dir_cats, valid_dir_cats, test_dir_cats = (
    os.path.join(split_dir, 'cats')
    for split_dir in (train_dir, valid_dir, test_dir)
)

In [None]:
# Define the convolutional base: VGG16 loaded with ImageNet weights, built
# without its top layers (include_top=False) for 150x150 RGB inputs.
from keras.applications import vgg16

conv_base = vgg16.VGG16(
    weights='imagenet',
    include_top=False,
    input_shape=(150, 150, 3),
)

In [None]:
# Stack a new fully connected classifier on top of the convolutional base.
from keras import models, layers

model = models.Sequential([
    conv_base,                              # convolutional feature extractor
    layers.Flatten(),                       # flatten feature maps for the dense layers
    layers.Dense(256, activation='relu'),   # hidden fully connected layer
    layers.Dense(1, activation='sigmoid'),  # single sigmoid output unit
])
model.summary()

In [None]:
# Mark the convolutional base as non-trainable, reporting the length of
# model.trainable_weights before and after the change.
weights_before = len(model.trainable_weights)
print(f'冻结之前，模型参数的数量为：{weights_before}')

conv_base.trainable = False

weights_after = len(model.trainable_weights)
print(f'冻结之后，模型参数的数量为：{weights_after}')

In [None]:
# Prepare the data generators, then compile and train the model.
from keras.preprocessing.image import ImageDataGenerator
from keras import optimizers  # NOTE(review): imported but not used in this cell

# Training images are rescaled to [0, 1] and randomly augmented.
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
)
# Validation images are only rescaled, never augmented.
test_datagen = ImageDataGenerator(rescale=1./255)

# Both generators yield batches of 20 images resized to 150x150 with binary labels.
train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(150, 150),
    batch_size=20,
    class_mode='binary',
)
valid_generator = test_datagen.flow_from_directory(
    valid_dir,
    target_size=(150, 150),
    batch_size=20,
    class_mode='binary',
)

model.compile(
    loss='binary_crossentropy',
    optimizer='rmsprop',
    metrics=['acc'],
)

# NOTE(review): fit_generator is deprecated in newer Keras releases in favour
# of fit — confirm against the installed Keras version before switching.
history = model.fit_generator(
    train_generator,
    steps_per_epoch=100,    # 100 batches of 20 images per epoch
    epochs=30,
    validation_data=valid_generator,
    validation_steps=50,    # 50 validation batches of 20 images
)