## Train the CNN model for age and gender estimation

Note: adapted from `train.py`. Its command-line arguments are replaced by the variables in the "Arguments" cell below.

In [None]:
import pandas as pd
import logging
import argparse
from pathlib import Path
import numpy as np
from keras.callbacks import LearningRateScheduler, ModelCheckpoint
from keras.optimizers import SGD, Adam
from keras.utils import np_utils
from wide_resnet import WideResNet
from utils import load_data
from keras.preprocessing.image import ImageDataGenerator
from mixup_generator import MixupGenerator
from random_eraser import get_random_eraser

logging.basicConfig(level=logging.DEBUG)

import os
from keras.callbacks import TensorBoard, ReduceLROnPlateau, EarlyStopping
from keras.layers import core

In [None]:
# Restrict CUDA to the first GPU (device index 0).
# NOTE(review): keras is imported in an earlier cell; this only takes effect if CUDA
# has not been initialised yet — confirm before relying on it.
os.environ.update(CUDA_VISIBLE_DEVICES="0")

In [None]:
class Schedule:
    """Step learning-rate schedule usable as a ``LearningRateScheduler`` callback.

    The run is split into four quarters. The learning rate is ``initial_lr`` in the
    first quarter, then ``initial_lr`` scaled by 0.2, 0.04 and 0.008.
    """

    def __init__(self, nb_epochs, initial_lr):
        self.epochs = nb_epochs
        self.initial_lr = initial_lr

    def __call__(self, epoch_idx):
        """Return the learning rate for the (0-based) epoch ``epoch_idx``."""
        if epoch_idx < self.epochs * 0.25:
            # First quarter: undecayed rate.
            return self.initial_lr
        for fraction, decay in ((0.50, 0.2), (0.75, 0.04)):
            if epoch_idx < self.epochs * fraction:
                return self.initial_lr * decay
        # Final quarter.
        return self.initial_lr * 0.008


def get_optimizer(opt_name, lr):
    """Build the Keras optimizer selected by *opt_name*.

    Args:
        opt_name: ``"sgd"`` (Nesterov momentum 0.9) or ``"adam"``.
        lr: Initial learning rate handed to the optimizer.

    Raises:
        ValueError: if *opt_name* is neither ``"sgd"`` nor ``"adam"``.
    """
    # Factories are lazy so that only the requested optimizer is constructed.
    factories = (
        ("sgd", lambda: SGD(lr=lr, momentum=0.9, nesterov=True)),
        ("adam", lambda: Adam(lr=lr)),
    )
    for name, make_optimizer in factories:
        if opt_name == name:
            return make_optimizer()
    raise ValueError("optimizer name should be 'sgd' or 'adam'")

### Arguments: default values

In [None]:
# Default values standing in for train.py's command-line options.

# --- data -------------------------------------------------------------
args_input = "data/imdb_db.mat"   # path to the input database .mat file
args_validation_split = 0.1       # fraction of the data held out for validation
args_aug = True                   # enable data augmentation when True

# --- optimisation -----------------------------------------------------
args_batch_size = 32              # batch size
args_nb_epochs = 30               # number of training epochs
args_lr = 0.1                     # initial learning rate
args_opt = "sgd"                  # optimizer name: 'sgd' or 'adam'

# --- network ----------------------------------------------------------
args_depth = 16                   # network depth (10, 16, 22, 28, ...)
args_width = 8                    # network width

# --- output -----------------------------------------------------------
args_output_path = "checkpoints"  # checkpoint directory name

In [None]:
# Map the notebook "arguments" onto the names train.py used after `args = get_args()`.
input_path, batch_size, nb_epochs = args_input, args_batch_size, args_nb_epochs
lr, opt_name = args_lr, args_opt
depth, k = args_depth, args_width
validation_split, use_augmentation = args_validation_split, args_aug

# The output directory is created next to (not inside) the current working directory.
current_nb_path = os.getcwd()
output_path = Path(current_nb_path).resolve().parent / args_output_path
output_path.mkdir(parents=True, exist_ok=True)

In [None]:
logging.debug("Loading data...")
# Only the images, gender/age labels and image size are used from load_data's 6-tuple.
image, gender, age, _, image_size, _ = load_data(input_path)
X_data = image

# One-hot targets: 2 gender classes and 101 age classes (class indices 0-100).
y_data_g, y_data_a = (
    np_utils.to_categorical(gender, 2),
    np_utils.to_categorical(age, 101),
)

In [None]:
model = WideResNet(image_size, depth=depth, k=k)()
opt = get_optimizer(opt_name, lr)

# Detect the output head(s): the model is treated as age-only when exactly one of the
# last two layers is a Dense layer and it is named "pred_age"; otherwise it is assumed
# to have two heads (gender, age).
# isinstance (rather than `type(...) is`) also accepts Dense subclasses.
classification_layers = [layer.name for layer in model.layers[-2:] if isinstance(layer, core.Dense)]
age_only = classification_layers == ["pred_age"]

# One categorical cross-entropy loss per output head.
if age_only:
    model_loss = "categorical_crossentropy"
else:
    model_loss = ["categorical_crossentropy", "categorical_crossentropy"]
model.compile(optimizer=opt, loss=model_loss,
              metrics=['accuracy'])

logging.debug("Model summary...")
# Previously count_params() was called and its result discarded; log it instead.
logging.debug("Number of parameters: %d", model.count_params())
model.summary()

In [None]:
log_dir = "logs/001/"  # TensorBoard logs and checkpoints (relative to the working directory; keep the trailing slash)

In [None]:
# train.py instead used LearningRateScheduler(schedule=Schedule(nb_epochs, lr)) together
# with a ModelCheckpoint writing "weights.{epoch:02d}-{val_loss:.2f}.hdf5" into output_path.
# This notebook reduces the LR on validation-loss plateaus and stops early instead.
tb_logging = TensorBoard(log_dir=log_dir)
checkpoint = ModelCheckpoint(
    log_dir + 'ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}.h5',
    monitor='val_loss',
    save_weights_only=True,
    save_best_only=True,  # save only when val_loss improves
    period=3,             # check the checkpoint condition every 3 epochs
)
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=3, verbose=1)
early_stopping = EarlyStopping(monitor='val_loss', min_delta=0, patience=10, verbose=1)

callbacks = [tb_logging, checkpoint, reduce_lr, early_stopping]

In [None]:
logging.debug("Running training...")

# Shuffle all samples once, then hold out the last `validation_split` fraction.
data_num = len(X_data)
indexes = np.arange(data_num)
np.random.shuffle(indexes)
X_data, y_data_g, y_data_a = X_data[indexes], y_data_g[indexes], y_data_a[indexes]
train_num = int(data_num * (1 - validation_split))

X_train, X_test = X_data[:train_num], X_data[train_num:]
y_train_g, y_test_g = y_data_g[:train_num], y_data_g[train_num:]
y_train_a, y_test_a = y_data_a[:train_num], y_data_a[train_num:]

# An age-only model was compiled with a single loss, so it gets a bare label array;
# otherwise labels are passed as [gender, age].
train_labels = y_train_a if age_only else [y_train_g, y_train_a]
test_labels = y_test_a if age_only else [y_test_g, y_test_a]

if use_augmentation:
    # Random shifts, horizontal flips and random erasing, fed through mixup (alpha=0.2).
    datagen = ImageDataGenerator(
        width_shift_range=0.1,
        height_shift_range=0.1,
        horizontal_flip=True,
        preprocessing_function=get_random_eraser(v_l=0, v_h=255),
    )
    training_generator = MixupGenerator(
        X_train, train_labels, batch_size=batch_size, alpha=0.2, datagen=datagen
    )()
    hist = model.fit_generator(
        generator=training_generator,
        steps_per_epoch=train_num // batch_size,
        validation_data=(X_test, test_labels),
        epochs=nb_epochs,
        verbose=1,
        callbacks=callbacks,
    )
else:
    hist = model.fit(
        X_train,
        train_labels,
        batch_size=batch_size,
        epochs=nb_epochs,
        callbacks=callbacks,
        validation_data=(X_test, test_labels),
    )

logging.debug("Saving history...")
history_file = output_path / "history_{}_{}.h5".format(depth, k)
pd.DataFrame(hist.history).to_hdf(history_file, "history")

### Load the saved training history and plot learning curves

In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt

In [None]:
# Directory of a finished training run and the history file saved there.
# NOTE(review): machine-specific absolute path — adjust to your own run directory.
input_dir = "/home/biometrics/Krit/__backup_age-gender-estimation/01_ageonly_lr_1e-1_ep30"
hdf_filename = "history_16_8.h5"

history_path = os.path.join(input_dir, hdf_filename)
history_df = pd.read_hdf(history_path)

In [None]:
history_df.tail(n=5)  # last five rows (one row per logged epoch)

In [None]:
def _plot_metric(history_df, train_key, metric_label, output, save_path):
    """Plot one metric's train/validation curves on the current axes and save them."""
    # Omit the "(output)" legend suffix when no output name is given (avoids "loss ()").
    suffix = " (" + output + ")" if output else ""
    plt.cla()  # reuse the current pyplot axes; drop the previously drawn curves
    plt.plot(history_df[train_key], label=metric_label + suffix)
    plt.plot(history_df["val_" + train_key], label="val_" + metric_label + suffix)
    plt.xlabel("number of epochs")
    plt.ylabel(metric_label)
    plt.legend()
    plt.savefig(save_path)


def save_figure(history_df, output, key_prefix, save_dir):
    """Save train/validation loss and accuracy curves from a Keras history table.

    Args:
        history_df: DataFrame built from ``hist.history`` (one column per logged metric).
        output: Name of the model output being plotted (e.g. "gender", "age"). It is
            used in legend labels and, when non-empty, in the file names.
        key_prefix: Prefix of that output's history columns (e.g. "pred_gender");
            empty for a single-output model.
        save_dir: Directory the PNG files are written to.

    File names are ``loss.png`` / ``accuracy.png`` when *output* is empty.
    Otherwise they are ``loss_<output>.png`` / ``accuracy_<output>.png``, so
    plotting several outputs into the same directory no longer overwrites
    earlier figures.

    Raises:
        KeyError: if the expected loss/accuracy columns are missing from *history_df*.
    """
    if key_prefix:
        key_prefix = key_prefix + "_"
    file_suffix = "_" + output if output else ""

    _plot_metric(history_df, key_prefix + "loss", "loss", output,
                 os.path.join(save_dir, "loss" + file_suffix + ".png"))

    # The accuracy column is "acc" or "accuracy" depending on the Keras version.
    key_acc = "acc" if key_prefix + "acc" in history_df.columns else "accuracy"
    _plot_metric(history_df, key_prefix + key_acc, "accuracy", output,
                 os.path.join(save_dir, "accuracy" + file_suffix + ".png"))

In [None]:
# Empty prefix: history of a single-output model (e.g. the age-only run loaded above).
# For a two-output model, plot each head instead:
#   output, key_prefix = "gender", "pred_gender"
#   output, key_prefix = "age", "pred_age"
output, key_prefix = "", ""
save_figure(history_df, output, key_prefix, input_dir)