In [1]:
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator


In [2]:
datagen = ImageDataGenerator(
    # Hold out 20% of the images; flow_from_directory picks a side via `subset`.
    validation_split=0.2,
    # Apply Xception's own input preprocessing so inputs match the pretrained weights.
    preprocessing_function=tf.keras.applications.xception.preprocess_input,
)

In [3]:
# Both subsets read the same directory through the same `datagen`;
# `subset` selects the side of the split configured by `validation_split`.
IMAGES_DIR = '../cub/CUB_200_2011/images'
IMAGE_SIZE = (299, 299)  # Xception's default input resolution
BATCH_SIZE = 50
SEED = 42  # fixes the training shuffle order so runs are reproducible


def make_flow(subset, shuffle):
    """Return a DirectoryIterator over one subset of IMAGES_DIR.

    Args:
        subset: 'training' or 'validation' (requires validation_split on datagen).
        shuffle: whether to reshuffle batches each epoch.
    """
    return datagen.flow_from_directory(
        IMAGES_DIR,
        target_size=IMAGE_SIZE,
        batch_size=BATCH_SIZE,
        subset=subset,
        shuffle=shuffle,
        seed=SEED,
    )


train_gen = make_flow('training', shuffle=True)
# Keep validation order fixed so predictions line up with validation_gen.classes;
# metrics computed over the whole set do not depend on order.
validation_gen = make_flow('validation', shuffle=False)


Found 9465 images belonging to 200 classes.
Found 2323 images belonging to 200 classes.


In [4]:
# Sanity-check one batch. Indexing (Sequence API) reads batch 0 without
# advancing the iterator the way the legacy `.next()` method does.
sample_images, sample_labels = train_gen[0]
assert sample_labels.shape[-1] == train_gen.num_classes, "labels are not one-hot over all classes"
sample_images.shape

(50, 299, 299, 3)

In [17]:
from tensorflow.keras.layers import Dense, Flatten, GlobalAveragePooling2D, Activation, Dropout, BatchNormalization

# ImageNet-pretrained Xception used as a frozen feature extractor;
# pooling='avg' makes it emit one pooled feature vector per image.
base_model = tf.keras.applications.Xception(include_top=False, pooling='avg', weights="imagenet")
base_model.trainable = False  # only the new head below is trained

model = tf.keras.models.Sequential()
model.add(base_model)
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dropout(0.5))
# One output unit per class directory found by flow_from_directory,
# instead of a hard-coded 200.
model.add(Dense(train_gen.num_classes, activation='softmax'))


In [18]:
%%time
model.compile(
    optimizer='Adam',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

CPU times: user 3.68 ms, sys: 50 µs, total: 3.73 ms
Wall time: 3.47 ms


In [None]:
%%time
history = model.fit(
    train_gen,
    epochs=100,
    validation_data=validation_gen,
    callbacks=[
        tf.keras.callbacks.EarlyStopping(
            monitor='val_accuracy',
            patience=5,
            min_delta=0.001,
            mode='max',
        ),
        tf.keras.callbacks.ModelCheckpoint(
            filepath='checkpoints3',
            save_weights_only=True,
        ),
    ]
)

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100

In [None]:
%%time
model.evaluate(validation_gen)

In [None]:
import os
import platform

# Tag the saved model with the host name. platform.node() is portable,
# whereas os.uname() is only available on Unix.
model.save(f'model.{platform.node()}.003')

In [None]:
from matplotlib import pyplot as plt

# Training vs. validation curves, one panel per tracked quantity. The keys come
# from the compiled loss, metrics=['accuracy'], and validation_data passed to fit.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
for ax, metric in zip(axes, ['loss', 'accuracy']):
    ax.plot(history.history[metric])
    ax.plot(history.history[f'val_{metric}'])
    ax.set(title=f'model {metric}', xlabel='epoch', ylabel=metric)
    ax.legend(['train', 'val'], loc='best')
plt.show()