In [1]:
from keras.models import Model
from keras.optimizers import RMSprop
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import ModelCheckpoint, EarlyStopping
from keras.models import load_model

Using TensorFlow backend.


In [2]:
datagen = ImageDataGenerator(rescale=1./255, 
                             brightness_range=[0.7,1.3],
                             rotation_range=30,
                             zoom_range=[0.7,1.3],
                             fill_mode='nearest',
                             validation_split=0.01)

train_datagen = datagen.flow_from_directory('data/train/', seed=42, class_mode='categorical', subset='training', target_size=(512,512))
# prints "75750 images belonging to 101 classes"
val_datagen = datagen.flow_from_directory('data/train/', seed=42, class_mode='categorical', subset='validation', target_size=(512,512)) 
# prints "15150 images belonging to 101 classes"

train_steps = len(train_datagen) #1894
val_steps = len(val_datagen) #474
classes = len(list(train_datagen.class_indices.keys())) #101

Found 75043 images belonging to 102 classes.
Found 707 images belonging to 102 classes.


In [3]:
model = load_model('tuned_resnet.h5', compile=True)

In [4]:
model.summary()

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
resnet50v2 (Model)           (None, 2048)              23564800  
_________________________________________________________________
batch_normalization_1 (Batch (None, 2048)              8192      
_________________________________________________________________
dropout_1 (Dropout)          (None, 2048)              0         
_________________________________________________________________
dense_1 (Dense)              (None, 512)               1049088   
_________________________________________________________________
batch_normalization_2 (Batch (None, 512)               2048      
_________________________________________________________________
dropout_2 (Dropout)          (None, 512)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 102)              

In [5]:
# for layer in model.get_layer('resnet50v2').layers:
    # print(layer.name)

In [6]:
len(model.trainable_weights)

42

In [7]:
set_trainable = False
for layer in model.get_layer('resnet50v2').layers:
        if layer.name == 'conv5_block1_preact_bn':
            set_trainable = True
        if set_trainable == True:
            layer.trainable = True
        else:
            layer.trainable = False

In [8]:
len(model.trainable_weights)

42

In [9]:
model.compile(loss='categorical_crossentropy',
                      optimizer=RMSprop(lr=0.00005),
                      metrics=['acc','top_k_categorical_accuracy'])

In [10]:
callbacks = [
    ModelCheckpoint(
        filepath='fully_trained_resnet.h5',
        monitor='acc',
        save_best_only=True,
    ),
    EarlyStopping(
        monitor='val_acc',
        patience=2,
    )
]

In [11]:
history = model.fit_generator(
    train_datagen,
    steps_per_epoch=train_steps,
    epochs=15,
    verbose=2,
    validation_data=val_datagen,
    validation_steps=val_steps,
    callbacks=callbacks
)

Epoch 1/15
 - 6469s - loss: 0.6032 - acc: 0.8556 - top_k_categorical_accuracy: 0.9618 - val_loss: 0.5635 - val_acc: 0.7638 - val_top_k_categorical_accuracy: 0.9335
Epoch 2/15
 - 6297s - loss: 0.5430 - acc: 0.8686 - top_k_categorical_accuracy: 0.9683 - val_loss: 0.0549 - val_acc: 0.7779 - val_top_k_categorical_accuracy: 0.9321
Epoch 3/15
 - 6371s - loss: 0.5002 - acc: 0.8792 - top_k_categorical_accuracy: 0.9721 - val_loss: 0.0263 - val_acc: 0.7864 - val_top_k_categorical_accuracy: 0.9307
Epoch 4/15
 - 6145s - loss: 0.4642 - acc: 0.8882 - top_k_categorical_accuracy: 0.9752 - val_loss: 1.6273 - val_acc: 0.7610 - val_top_k_categorical_accuracy: 0.9264
Epoch 5/15
 - 5909s - loss: 0.4313 - acc: 0.8960 - top_k_categorical_accuracy: 0.9781 - val_loss: 0.1141 - val_acc: 0.7511 - val_top_k_categorical_accuracy: 0.9307
