## Import libraries

In [1]:
from keras.models import Model
from keras.layers import Flatten, Dense, Lambda
from keras.applications.vgg16 import VGG16
from keras.optimizers import Adam
from keras.preprocessing import image
from keras import backend as K

Using TensorFlow backend.


## Set variables and import data

In [2]:
# Set variables.
# The generator applies no augmentation and no pixel preprocessing, so images reach the
# network as raw 0-255 RGB values.
# NOTE(review): the ImageNet VGG16 weights were trained on inputs transformed by
# keras.applications.vgg16.preprocess_input; consider passing
# preprocessing_function=preprocess_input here (and applying the same transform wherever
# the saved model is used for inference) - confirm before changing.
batch_size = 64  # images per batch, shared by the training and validation generators
gen = image.ImageDataGenerator()

In [3]:
# Training images: one sub-folder per class under data/train, resized to VGG16's
# 224x224 input, labelled one-hot and served in shuffled order.
batches = gen.flow_from_directory(
    directory='data/train',
    batch_size=batch_size,
    shuffle=True,
    class_mode='categorical',
    target_size=(224, 224),
)

Found 8721 images belonging to 3 classes.


In [4]:
# Validation images: same folder layout, resizing and one-hot labels as the training set.
# shuffle=False: metrics over a full validation pass do not depend on order, and a fixed
# order keeps model predictions aligned with val_batches.classes / val_batches.filenames
# for later per-image analysis (e.g. a confusion matrix).
val_batches = gen.flow_from_directory('data/valid',
                                      target_size=(224, 224),
                                      class_mode='categorical',
                                      shuffle=False,
                                      batch_size=batch_size)

Found 2100 images belonging to 3 classes.


## Import and configure model

In [6]:
# Retrieve the full Keras VGG16 model, including its fully connected top and ImageNet weights.
vgg = VGG16(include_top=True, weights='imagenet',
            input_tensor=None, input_shape=(224, 224, 3), pooling=None)

# The training and validation folders must define the same classes in the same order,
# otherwise the one-hot labels produced by the two generators would not line up.
assert batches.class_indices == val_batches.class_indices, \
    'data/train and data/valid define different classes'

# Define a new output layer connected to the last fully connected layer of VGG16
# (layers[-2], just before the 1000-way ImageNet classifier). Its width follows the number
# of class folders found by the training generator: 3 here - forward, left, right.
x = vgg.layers[-2].output
output_layer = Dense(len(batches.class_indices), activation='softmax', name='predictions')(x)

# Combine the original VGG model with the new output layer. No layer is frozen here, so
# training updates every VGG16 layer together with the new classifier.
vgg2 = Model(inputs=vgg.input, outputs=output_layer)

# Compile the new model with a small Adam learning rate for fine-tuning pretrained weights.
vgg2.compile(optimizer=Adam(lr=0.0001),  # was 0.001
             loss='categorical_crossentropy', metrics=['accuracy'])

## Train it!

In [9]:
# Train for 1 epoch. Step counts are rounded UP so each epoch covers every training image
# and each validation pass scores every validation image exactly once. Flooring
# (samples // batch_size) dropped the final partial batch and let the epoch boundary drift
# through the data from one call to the next.
vgg2.fit_generator(batches,
                   steps_per_epoch=(batches.samples + batches.batch_size - 1) // batches.batch_size,
                   validation_data=val_batches,
                   validation_steps=(val_batches.samples + val_batches.batch_size - 1) // val_batches.batch_size,
                   epochs=1)

Epoch 1/1


<keras.callbacks.History at 0x7f574145bd68>

In [10]:
# Train for 1 more epoch. Step counts are rounded UP so each epoch covers every training
# image and each validation pass scores every validation image exactly once (flooring
# dropped the final partial batch and shifted epoch boundaries between calls).
vgg2.fit_generator(batches,
                   steps_per_epoch=(batches.samples + batches.batch_size - 1) // batches.batch_size,
                   validation_data=val_batches,
                   validation_steps=(val_batches.samples + val_batches.batch_size - 1) // val_batches.batch_size,
                   epochs=1)

Epoch 1/1


<keras.callbacks.History at 0x7f57397bcda0>

In [11]:
# Train for 1 more epoch. Step counts are rounded UP so each epoch covers every training
# image and each validation pass scores every validation image exactly once (flooring
# dropped the final partial batch and shifted epoch boundaries between calls).
vgg2.fit_generator(batches,
                   steps_per_epoch=(batches.samples + batches.batch_size - 1) // batches.batch_size,
                   validation_data=val_batches,
                   validation_steps=(val_batches.samples + val_batches.batch_size - 1) // val_batches.batch_size,
                   epochs=1)

Epoch 1/1


<keras.callbacks.History at 0x7f5738f5ab70>

In [12]:
# Train for 3 more epochs. Step counts are rounded UP so each epoch covers every training
# image and each validation pass scores every validation image exactly once (flooring
# dropped the final partial batch and shifted epoch boundaries between calls).
vgg2.fit_generator(batches,
                   steps_per_epoch=(batches.samples + batches.batch_size - 1) // batches.batch_size,
                   validation_data=val_batches,
                   validation_steps=(val_batches.samples + val_batches.batch_size - 1) // val_batches.batch_size,
                   epochs=3)

Epoch 1/3
Epoch 2/3
Epoch 3/3


<keras.callbacks.History at 0x7f5738f5aef0>

In [13]:
# Persist the trained model to HDF5 so it can be reloaded with load_model() later.
# Create the target folder first: on a fresh checkout it may not exist and the save would fail.
import os
if not os.path.isdir('keras_models'):
    os.makedirs('keras_models')
vgg2.save('keras_models/gtavc1.h5')

In [12]:
from keras.models import load_model

In [13]:
# Reload the model previously written to keras_models/gtavc1.h5 (e.g. after a kernel
# restart) so training can continue without rebuilding and retraining it.
vgg2 = load_model(filepath='keras_models/gtavc1.h5')

## Further training (overtraining experiment) - not used for the final model

In [14]:
# Train the reloaded model for 3 more epochs. Step counts are rounded UP so each epoch
# covers every training image and each validation pass scores every validation image
# exactly once (flooring dropped the final partial batch and shifted epoch boundaries).
vgg2.fit_generator(batches,
                   steps_per_epoch=(batches.samples + batches.batch_size - 1) // batches.batch_size,
                   validation_data=val_batches,
                   validation_steps=(val_batches.samples + val_batches.batch_size - 1) // val_batches.batch_size,
                   epochs=3)

Epoch 1/3
Epoch 2/3
Epoch 3/3


<keras.callbacks.History at 0x7f608a59e128>

In [15]:
# Train the reloaded model for 3 more epochs. Step counts are rounded UP so each epoch
# covers every training image and each validation pass scores every validation image
# exactly once (flooring dropped the final partial batch and shifted epoch boundaries).
vgg2.fit_generator(batches,
                   steps_per_epoch=(batches.samples + batches.batch_size - 1) // batches.batch_size,
                   validation_data=val_batches,
                   validation_steps=(val_batches.samples + val_batches.batch_size - 1) // val_batches.batch_size,
                   epochs=3)

Epoch 1/3
Epoch 2/3
Epoch 3/3


<keras.callbacks.History at 0x7f608a59e2e8>

In [None]:
# Show the optimizer's current learning rate (read from its backend variable).
print(K.get_value(vgg2.optimizer.lr))

In [None]:
# Overwrite the optimizer's learning-rate variable in place; later fit_generator calls use it.
# NOTE(review): 0.1 is 1000x the 0.0001 the model was compiled with and is very high for
# Adam when fine-tuning pretrained weights - confirm this value is intended.
K.set_value(x=vgg2.optimizer.lr, value=0.1)