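# Keras 2 VGG16 Dogs vs. Cats

Fine-tune an ImageNet-pretrained VGG16 on the Dogs vs. Cats images: freeze every VGG16 layer, replace its 1000-way classifier with a new 2-way softmax layer, and train that layer.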

In [1]:
# import libraries
from keras.models import Model
from keras.layers import Flatten, Dense, Lambda
from keras.applications.vgg16 import VGG16, preprocess_input
from keras.optimizers import Adam
from keras.preprocessing import image

Using TensorFlow backend.


In [2]:
# set variables
batch_size = 64


def vgg_preprocess(img):
    """Apply VGG16's ImageNet input preprocessing to a single image array.

    ImageDataGenerator calls `preprocessing_function` on one image at a time
    (a rank-3 array). Older Keras 2.0.x releases of `preprocess_input` index an
    explicit batch axis, so add a leading batch axis and strip it again; this
    works with both old and new Keras versions.

    Args:
        img: one image array as produced by the generator.

    Returns:
        The preprocessed image, same shape as `img`.
    """
    return preprocess_input(img[None])[0]


# The pretrained VGG16 weights expect inputs prepared by vgg16.preprocess_input;
# without a preprocessing_function the network would see raw 0-255 RGB pixels.
gen = image.ImageDataGenerator(preprocessing_function=vgg_preprocess)

In [3]:
# Training images: resized to 224x224, categorical labels, shuffled, streamed
# from disk in batches of `batch_size`.
batches = gen.flow_from_directory(
    'data/dogscats/train',
    batch_size=batch_size,
    shuffle=True,
    class_mode='categorical',
    target_size=(224, 224),
)

Found 23000 images belonging to 2 classes.


In [4]:
# import validation data
# Validation needs no shuffling: a fixed order makes every epoch's validation
# pass evaluate the same images (comparable metrics) and keeps predictions
# aligned with val_batches.filenames / val_batches.classes.
val_batches = gen.flow_from_directory('data/dogscats/valid',
                                      target_size=(224,224),
                                      class_mode='categorical',
                                      shuffle=False,
                                      batch_size=batch_size)

Found 2000 images belonging to 2 classes.


In [5]:
# Build the complete VGG16 network (include_top=True keeps its fully connected
# classifier) with ImageNet weights, for 224x224 inputs with 3 channels.
vgg = VGG16(weights='imagenet',
            include_top=True,
            input_shape=(224, 224, 3),
            input_tensor=None,
            pooling=None)

In [6]:
# Freeze every pretrained VGG16 layer so its weights are not updated during
# training (must happen before the model is compiled).
for layer in vgg.layers:
    layer.trainable = False

In [7]:
# New 2-way softmax head attached to the output of the layer just before
# VGG16's original classifier (vgg.layers[-2]), bypassing that classifier.
# thanks to joelthchao https://github.com/fchollet/keras/issues/2371
x = vgg.layers[-2].output
output_layer = Dense(units=2, name='predictions', activation='softmax')(x)

In [8]:
# Assemble a model that runs VGG16's input through to the new 2-way head;
# the frozen VGG16 layers are shared with `vgg`, not copied.
vgg2 = Model(outputs=output_layer, inputs=vgg.input)

In [9]:
# Compile: categorical cross-entropy matches the 2-unit softmax head and the
# class_mode='categorical' labels; Adam with lr=0.001; report accuracy.
vgg2.compile(loss='categorical_crossentropy',
             optimizer=Adam(lr=0.001),
             metrics=['accuracy'])

In [10]:
# Train the new head for one epoch.
# Round step counts UP so the final partial batch is included: floor division
# (23000 // 64, 2000 // 64) silently skipped 24 training and 16 validation
# images every epoch. Integer ceil-division is used so this also holds on
# Python 2, where `/` between ints truncates.
# Keep the returned History (per-epoch loss/metrics) instead of echoing its repr.
history = vgg2.fit_generator(
    batches,
    steps_per_epoch=(batches.samples + batch_size - 1) // batch_size,
    validation_data=val_batches,
    validation_steps=(val_batches.samples + batch_size - 1) // batch_size,
    epochs=1)

Epoch 1/1


<keras.callbacks.History at 0x7ff681995748>