In [1]:
import glob
import json
import os

import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import Model, Sequential, layers
from tensorflow.keras.preprocessing.image import ImageDataGenerator



In [None]:
# Serialized initializer spec handed to every convolution built by make_features:
# VarianceScaling with scale 2.0 over fan_out, sampled from a truncated normal.
CONV_KERNEL_INITIALIZER = dict(
    class_name='VarianceScaling',
    config=dict(
        scale=2.0,
        mode='fan_out',
        distribution='truncated_normal',
    ),
)

# Serialized initializer spec used by the dense layers of the classifier head:
# VarianceScaling with scale 1/3 over fan_out, sampled uniformly.
DENSE_KERNEL_INITIALIZER = {
    'class_name': 'VarianceScaling',
    'config': {
        'scale': 1. / 3.,
        'mode': 'fan_out',
        # VarianceScaling only accepts "truncated_normal", "untruncated_normal"
        # or "uniform"; the previous value "function" was rejected at build time.
        'distribution': 'uniform'
    }
}


def VGG(feature, im_height=224, im_weight=224, num_classes=1000, im_width=None):
    """Assemble a classifier from a feature extractor plus a dense head.

    Args:
        feature: Callable applied to the input tensor (here the Sequential
            returned by ``make_features``).
        im_height: Height of the input images.
        im_weight: Width of the input images. The name is a misspelling of
            "width"; it is kept so positional and existing keyword callers
            continue to work.
        num_classes: Number of units in the final softmax output.
        im_width: Correctly spelled alias for ``im_weight``. When given, it
            takes precedence over ``im_weight``.

    Returns:
        A Keras ``Model`` taking ``(im_height, width, 3)`` float32 inputs and
        producing a softmax distribution over ``num_classes`` classes.
    """
    # Accept both spellings; callers in this notebook pass ``im_width=``.
    width = im_weight if im_width is None else im_width

    input_image = layers.Input(shape=(im_height, width, 3), dtype="float32")
    x = feature(input_image)
    x = layers.Flatten()(x)

    # Two dropout-regularised hidden layers of 2048 units each.
    x = layers.Dropout(rate=0.5)(x)
    x = layers.Dense(2048, activation='relu',
                     kernel_initializer=DENSE_KERNEL_INITIALIZER)(x)
    x = layers.Dropout(rate=0.5)(x)
    x = layers.Dense(2048, activation='relu',
                     kernel_initializer=DENSE_KERNEL_INITIALIZER)(x)

    # Linear projection to class scores, followed by an explicit softmax so the
    # model emits probabilities rather than logits.
    x = layers.Dense(num_classes,
                     kernel_initializer=DENSE_KERNEL_INITIALIZER)(x)
    output = layers.Softmax()(x)
    return Model(inputs=input_image, outputs=output)


def make_features(cfg):
    """Build the convolutional feature extractor described by ``cfg``.

    Args:
        cfg: Iterable of layer specs. The string ``"M"`` adds a 2x2 max-pool
            with stride 2; any other entry is used as the filter count of a
            3x3, same-padded, ReLU-activated convolution.

    Returns:
        A Keras ``Sequential`` named ``"features"`` containing the layers in
        the order given by ``cfg``.
    """
    feature_layers = []
    for v in cfg:
        if v == "M":
            feature_layers.append(layers.MaxPool2D(pool_size=2, strides=2))
        else:
            # The Keras class is ``Conv2D`` (capital D); ``Conv2d`` does not exist.
            conv2d = layers.Conv2D(v, kernel_size=3, padding="SAME", activation="relu",
                                   kernel_initializer=CONV_KERNEL_INITIALIZER)
            feature_layers.append(conv2d)
    return Sequential(feature_layers, name="features")


# Layer layouts for the supported VGG variants, consumed by make_features:
# integers are convolution filter counts, "M" marks a max-pooling layer.
cfgs = {
    "vgg11": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
    "vgg13": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
    "vgg16": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M",
              512, 512, 512, "M", 512, 512, 512, "M"],
    "vgg19": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M",
              512, 512, 512, 512, "M", 512, 512, 512, 512, "M"],
}


def vgg(model_name="vgg16", im_height=224, im_width=224, num_classes=1000):
    """Create a VGG classifier by variant name.

    Args:
        model_name: Key into ``cfgs`` selecting the layer layout.
        im_height: Height of the input images.
        im_width: Width of the input images.
        num_classes: Number of output classes.

    Returns:
        The Keras model produced by ``VGG`` for the selected layout.

    Raises:
        AssertionError: If ``model_name`` is not a key of ``cfgs``.
    """
    assert model_name in cfgs, "not support model {}".format(model_name)
    cfg = cfgs[model_name]
    # Height and width are passed positionally because VGG's width parameter
    # is spelled ``im_weight``; a keyword ``im_width=`` would not match it.
    model = VGG(make_features(cfg), im_height, im_width, num_classes)
    return model

In [None]:
# The dataset is looked up two directories above the current working directory.
data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
image_path = os.path.join(data_root, "input/flowerdata", "flower_data")

train_dir = os.path.join(image_path, "train")
validation_dir = os.path.join(image_path, "val")
# Fail early, naming the missing path, before any generator is built.
assert os.path.exists(train_dir), "cannot find {}".format(train_dir)
assert os.path.exists(validation_dir), "cannot find {}".format(validation_dir)

# Checkpoints are written here; exist_ok makes the cell safely re-runnable
# without a separate existence check.
os.makedirs("save_weights", exist_ok=True)

In [None]:
# Spatial size the image generators resize to.
im_height, im_width = 224, 224

# Training schedule: images per batch and number of passes over the data.
batch_size = 32
epoch = 10

# Per-channel values subtracted from each image by pre_function,
# listed in the order they are applied: R, G, B.
_R_MEAN, _G_MEAN, _B_MEAN = 123.68, 116.78, 103.94

In [None]:
def pre_function(img):
    """Return ``img`` with the three channel-mean constants subtracted.

    Passed as ``preprocessing_function`` to the image generators.
    """
    # Assumes img is an array with channels on the last axis so the
    # three-element list broadcasts per channel — TODO confirm.
    channel_means = [_R_MEAN, _G_MEAN, _B_MEAN]
    return img - channel_means


In [None]:
# Training images get a random horizontal flip plus mean subtraction;
# validation images only get mean subtraction.
train_image_generator = ImageDataGenerator(horizontal_flip=True,
                                           preprocessing_function=pre_function)

validation_image_generator = ImageDataGenerator(preprocessing_function=pre_function)

train_data_gen = train_image_generator.flow_from_directory(
    directory=train_dir,
    batch_size=batch_size,
    shuffle=True,
    # (height, width): the original passed im_height twice, which only worked
    # because both dimensions happen to be equal.
    target_size=(im_height, im_width),
    class_mode='categorical')
total_train = train_data_gen.n

In [None]:
# Validation batches are read in a fixed order (no shuffling).
val_data_gen = validation_image_generator.flow_from_directory(
    directory=validation_dir,
    batch_size=batch_size,
    shuffle=False,
    target_size=(im_height, im_width),
    class_mode='categorical')

total_val = val_data_gen.n
print("using {} images for training, {} images for validation.".format(total_train, total_val))


In [None]:
# Swap keys and values of the generator's class mapping and save the result
# as JSON, so an index can later be looked up to recover its original key.
class_indices = train_data_gen.class_indices
inverse_dict = {index: label for label, index in class_indices.items()}

json_str = json.dumps(inverse_dict, indent=4)
with open('class_indices.json', 'w') as json_file:
    json_file.write(json_str)

In [None]:
# Build VGG-16 from the shared configuration instead of repeating literals:
# input size comes from im_height/im_width and the output size from the number
# of classes the training generator discovered.
model = vgg("vgg16", im_height, im_width, len(train_data_gen.class_indices))

In [None]:
pre_weights_path = './pretrain_weights.ckpt'
# A checkpoint is stored as several files sharing this prefix, hence the glob.
assert len(glob.glob(pre_weights_path + "*")), "cannot find {}".format(pre_weights_path)
model.load_weights(pre_weights_path)

# Freeze the convolutional backbone so only the dense head is trained.
# The name must match the Sequential created in make_features ("features"),
# and the comparison must be ``==`` (a single ``=`` is a syntax error).
for layer_t in model.layers:
    if layer_t.name == 'features':
        layer_t.trainable = False
        break

In [None]:
# Print the model's layer summary for inspection.
model.summary()

In [None]:
# Adam with a small learning rate; the loss expects probabilities
# (from_logits=False) because the model ends in a Softmax layer.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False),
    metrics=['accuracy'],
)

# Save weights only, and only when the monitored validation loss improves.
# NOTE(review): the "myAlex" file prefix looks carried over from another
# notebook — confirm before renaming, since other code may read these files.
callbacks = [
    tf.keras.callbacks.ModelCheckpoint(
        filepath='./save_weights/myAlex_{epoch}.h5',
        monitor='val_loss',
        save_best_only=True,
        save_weights_only=True,
    )
]


In [None]:
# Train the model; whole batches only, so any remainder images beyond a
# multiple of batch_size are not visited in a given epoch.
history = model.fit(
    x=train_data_gen,
    steps_per_epoch=total_train // batch_size,
    # Keras' keyword is ``epochs``; the notebook's setting is named ``epoch``.
    epochs=epoch,
    validation_data=val_data_gen,
    validation_steps=total_val // batch_size,
    callbacks=callbacks,
)
