In [None]:
import os
import sys
# Make the notebook's parent directory importable (added once only),
# presumably so the `lmmnn` package imported further down can be resolved.
module_path = os.path.abspath('..')
if not any(entry == module_path for entry in sys.path):
    sys.path.append(module_path)

In [None]:
import time
import numpy as np
import pandas as pd
from sklearn.model_selection import KFold

from lmmnn.layers import NLL
from lmmnn.callbacks import EarlyStoppingWithSigmasConvergence, PrintSigmas
from lmmnn.menet import menet_fit_generator, menet_predict_generator

import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras import Model
from tensorflow.keras.callbacks import EarlyStopping, Callback
from tensorflow.keras.layers import (Concatenate, Conv2D, Dense, Dropout,
                                     Embedding, Flatten, Input, MaxPool2D,
                                     Reshape, GlobalAveragePooling2D, Layer)
from tensorflow.keras.preprocessing.image import ImageDataGenerator

In [None]:
# Ask TensorFlow to grow GPU memory on demand for every visible GPU instead of
# reserving it all up front. With no GPU present the loop body never runs.
gpus = tf.config.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, enable=True)

In [None]:
class Count:
    """Running integer counter whose state lives on the class itself.

    ``curr`` is a class attribute, so every instance (and every generator
    obtained from ``gen``) reads and advances the same value.
    """

    # Last value handed out; shared by all instances.
    curr = 0

    def __init__(self, startWith=None):
        """Optionally rewind/advance the shared counter.

        When ``startWith`` is given, the next value produced by ``gen`` is
        exactly ``startWith``; when it is ``None`` the counter is left as is.
        """
        if startWith is None:
            return
        Count.curr = startWith - 1

    def gen(self):
        """Return an endless generator that increments the counter, then yields it."""
        while True:
            Count.curr = Count.curr + 1
            yield Count.curr

    def __call__(self):
        """Return the most recently produced value without advancing."""
        return Count.curr

In [None]:
# --- Image geometry (pixels) fed to the networks ---
IMG_WIDTH = 178
IMG_HEIGHT = 218

# --- Training hyper-parameters ---
batch_size = 20
epochs = 100
patience = 10

# --- Data locations ---
images_dir = 'data/img_align_celeba_png/'
out_file = 'results/res_celeba.csv'

# --- Landmark table: one row per image ---
images_df = pd.read_csv('data/list_landmarks_align_celeba_processed.csv')
images_df['nose_x'] = images_df['nose_x'].astype(np.float64)
# Shift celebrity ids down by one (presumably 1-based in the file) so they can
# be used directly as 0-based indices.
images_df['celeb'] = images_df['celeb'] - 1
images_df = images_df.sort_values(by='celeb')

# Number of distinct celebrity ids, assuming ids are contiguous from 0.
n_cats = images_df['celeb'].max() + 1
# Lookup: image file name -> celebrity id.
imgfile2celeb = images_df.set_index('img_file')['celeb'].to_dict()

# --- Experiment bookkeeping ---
res_df = pd.DataFrame(columns=['exp_type', 'mse', 'mae', 'sigma_e_est', 'sigma_b_est', 'n_epochs', 'time'])
counter = Count().gen()
kf = KFold(n_splits=5, shuffle=True, random_state=42)

In [None]:
def sample_split(seed, train_index_subj, valid_frac = 0.1):
    """Carve a random validation subset out of an array of training indices.

    Seeds NumPy's global RNG with ``seed``, then draws
    ``int(valid_frac * len(train_index_subj))`` positions without replacement.

    Returns:
        (remaining_train_indices, validation_indices)
    """
    np.random.seed(seed)
    n_total = len(train_index_subj)
    n_valid = int(valid_frac * n_total)
    held_out = np.random.choice(n_total, n_valid, replace=False)
    remaining = np.delete(train_index_subj, held_out)
    return remaining, train_index_subj[held_out]

def cnn_ignore():
    """Build the plain CNN regressor that takes only the image as input.

    Four Conv2D(5x5, relu) + MaxPool2D(2x2) stages with 32, 64, 32 and 16
    filters, followed by Flatten -> Dropout(0.5) -> Dense(100, relu) ->
    Dense(1). The returned model is not compiled.
    """
    image_in = Input((IMG_HEIGHT, IMG_WIDTH, 3))
    features = image_in
    for n_filters in (32, 64, 32, 16):
        features = Conv2D(n_filters, (5, 5), activation='relu')(features)
        features = MaxPool2D((2, 2))(features)
    features = Flatten()(features)
    features = Dropout(0.5)(features)
    features = Dense(100, activation='relu')(features)
    prediction = Dense(1)(features)
    return Model(inputs=[image_in], outputs=prediction)

def cnn_ignore_inception():
    """Build an InceptionV3-based regressor that takes only the image as input.

    An ImageNet-pretrained InceptionV3 (without its classification top) is
    followed by global average pooling and a single linear output unit. All
    layers except the last ``train_top`` (55) are frozen. The returned model
    is not compiled.
    """
    # InceptionV3 is not imported by the notebook's import cell, so the
    # original body raised NameError; import it where it is needed.
    from tensorflow.keras.applications import InceptionV3

    base = InceptionV3(weights='imagenet', include_top=False, input_shape = (IMG_HEIGHT, IMG_WIDTH, 3))
    x = base.output
    x = GlobalAveragePooling2D()(x)
    output = Dense(1)(x)
    model = Model(inputs = base.input, outputs = output)
    # Fine-tune only the top of the network; freeze everything below it.
    train_top = 55
    for layer in model.layers[:-train_top]:
        layer.trainable = False
    for layer in model.layers[-train_top:]:
        layer.trainable = True
    return model

def cnn_lmmnn():
    """Build the LMMNN variant of the CNN, whose output is the NLL layer.

    The model takes three inputs - the image, the true response and the
    integer cluster index Z - because the ``NLL`` layer is called with
    (y_true, y_pred, Z). The convolutional trunk is the same as in
    ``cnn_ignore``. The returned model is not compiled.
    """
    image_in = Input((IMG_HEIGHT, IMG_WIDTH, 3))
    y_true_input = Input(shape=(1, ),)
    Z_input = Input(shape=(1, ), dtype=tf.int64)
    features = image_in
    for n_filters in (32, 64, 32, 16):
        features = Conv2D(n_filters, (5, 5), activation='relu')(features)
        features = MaxPool2D((2, 2))(features)
    features = Flatten()(features)
    features = Dropout(0.5)(features)
    features = Dense(100, activation='relu')(features)
    y_pred_output = Dense(1)(features)
    # NOTE(review): the two 10.0 arguments are presumably the initial values
    # of the two variance components - confirm against lmmnn.layers.NLL.
    nll = NLL(10.0, 10.0)(y_true_input, y_pred_output, Z_input)
    return Model(inputs=[image_in, y_true_input, Z_input], outputs=nll)

def cnn_lmmnn_inception():
    """Build the LMMNN model on top of an ImageNet-pretrained InceptionV3.

    Inputs are the image (InceptionV3's own input), the true response and the
    integer cluster index Z; the output is the ``NLL`` layer applied to
    (y_true, y_pred, Z). All layers except the last ``train_top`` (58) are
    frozen. The returned model is not compiled.
    """
    # InceptionV3 is not imported by the notebook's import cell, so the
    # original body raised NameError; import it where it is needed.
    from tensorflow.keras.applications import InceptionV3

    y_true_input = Input(shape=(1, ),)
    Z_input = Input(shape=(1, ), dtype=tf.int64)
    base = InceptionV3(weights='imagenet', include_top=False, input_shape = (IMG_HEIGHT, IMG_WIDTH, 3))
    x = base.output
    x = GlobalAveragePooling2D()(x)
    y_pred_output = Dense(1)(x)
    # NOTE(review): the two 1.0 arguments are presumably the initial values
    # of the two variance components - confirm against lmmnn.layers.NLL.
    nll = NLL(1.0, 1.0)(y_true_input, y_pred_output, Z_input)
    model = Model(inputs=[base.input, y_true_input, Z_input], outputs=nll)
    # Fine-tune only the top of the network; freeze everything below it.
    train_top = 58
    for layer in model.layers[:-train_top]:
        layer.trainable = False
    for layer in model.layers[-train_top:]:
        layer.trainable = True
    return model

def cnn_embedding(n_cats, embed_dim):
    """Build the CNN regressor that also consumes the cluster id via an embedding.

    Args:
        n_cats: size of the embedding vocabulary (number of cluster ids).
        embed_dim: length of each embedding vector.

    The image goes through the same convolutional trunk as ``cnn_ignore``;
    its 100-unit dense representation is concatenated with the embedded
    cluster id before the final Dense(1). The returned model is not compiled.
    """
    image_in = Input((IMG_HEIGHT, IMG_WIDTH, 3))
    Z_input = Input(shape=(1,))
    cluster_vec = Embedding(n_cats, embed_dim, input_length = 1)(Z_input)
    # Drop the length-1 sequence axis so the embedding can be concatenated.
    cluster_vec = Reshape(target_shape = (embed_dim,))(cluster_vec)
    features = image_in
    for n_filters in (32, 64, 32, 16):
        features = Conv2D(n_filters, (5, 5), activation='relu')(features)
        features = MaxPool2D((2, 2))(features)
    features = Flatten()(features)
    features = Dropout(0.5)(features)
    features = Dense(100, activation='relu')(features)
    merged = Concatenate()([features, cluster_vec])
    prediction = Dense(1)(merged)
    return Model(inputs=[image_in, Z_input], outputs=prediction)

def cnn_embedding_inception(n_cats, embed_dim):
    """Build an InceptionV3-based regressor that also embeds the cluster id.

    Args:
        n_cats: size of the embedding vocabulary (number of cluster ids).
        embed_dim: length of each embedding vector.

    The globally average-pooled InceptionV3 features are concatenated with the
    embedded cluster id and fed to a single linear output unit. All layers
    except the last ``train_top`` (55) are frozen. The returned model is not
    compiled.
    """
    # InceptionV3 is not imported by the notebook's import cell, so the
    # original body raised NameError; import it where it is needed.
    from tensorflow.keras.applications import InceptionV3

    base = InceptionV3(weights='imagenet', include_top=False, input_shape = (IMG_HEIGHT, IMG_WIDTH, 3))
    Z_input = Input(shape=(1,))
    embed = Embedding(n_cats, embed_dim, input_length = 1)(Z_input)
    # Drop the length-1 sequence axis so the embedding can be concatenated.
    embed = Reshape(target_shape = (embed_dim,))(embed)
    x = base.output
    x = GlobalAveragePooling2D()(x)
    concat = Concatenate()([x, embed])
    output = Dense(1)(concat)
    model = Model(inputs = [base.input, Z_input], outputs = output)
    # Fine-tune only the top of the network; freeze everything below it.
    train_top = 55
    for layer in model.layers[:-train_top]:
        layer.trainable = False
    for layer in model.layers[-train_top:]:
        layer.trainable = True
    return model

def calc_b_hat(Z_train, y_train, y_pred_tr, n_cats, sig2e, sig2b):
    """Predict one random intercept per cluster from training residuals.

    For each cluster ``i`` in ``range(n_cats)`` with ``n_i > 0`` training rows:

        b_i = n_i * sig2b * (mean(y_i) - mean(y_pred_i)) / (sig2e + n_i * sig2b)

    Clusters with no training rows get 0.

    Args:
        Z_train: array of cluster ids, one per training row.
        y_train: array of observed responses.
        y_pred_tr: array of fixed-part predictions for the same rows.
        n_cats: number of clusters.
        sig2e: residual variance.
        sig2b: random-intercept variance.

    Returns:
        NumPy array of length ``n_cats``.
    """
    def _cluster_effect(cluster):
        members = Z_train == cluster
        size = members.sum()
        if not size > 0:
            # Nothing observed for this cluster: fall back to a zero effect.
            return 0
        obs_mean = y_train[members].mean()
        pred_mean = y_pred_tr[members].mean()
        # Mean residual shrunk by n_i * sig2b / (sig2e + n_i * sig2b).
        return size * sig2b * (obs_mean - pred_mean) / (sig2e + size * sig2b)

    return np.array([_cluster_effect(cluster) for cluster in range(n_cats)])

def custom_train_generator_lmmnn(train_generator, epochs):
    """Re-shape batches from an image iterator for the three-input LMMNN model.

    Each batch from ``train_generator.next()`` is expected to be
    ``(images, labels)`` with label column 0 = response and column 1 =
    cluster id. It is re-emitted as ``([images, y_true, Z], None)``.

    The wrapper counts ``batch_size`` samples per batch and stops - resetting
    the wrapped iterator - once the count equals ``n * epochs`` exactly.
    """
    n_served = 0
    while n_served != train_generator.n * epochs:
        n_served += train_generator.batch_size
        batch = train_generator.next()
        labels = batch[1]
        yield [batch[0], labels[:, 0], labels[:, 1]], None
    train_generator.reset()

def custom_valid_generator_lmmnn(valid_generator, epochs):
    """Validation counterpart of the LMMNN batch adapter.

    Turns each ``(images, labels)`` batch of ``valid_generator`` into
    ``([images, y_true, Z], None)``, where ``y_true`` and ``Z`` are label
    columns 0 and 1. Stops and resets the wrapped iterator when the running
    sample count (``batch_size`` per batch) equals ``n * epochs``.
    """
    samples_seen = 0
    while not samples_seen == valid_generator.n * epochs:
        samples_seen += valid_generator.batch_size
        images, targets = valid_generator.next()[0:2]
        response = targets[:, 0]
        cluster_ids = targets[:, 1]
        yield [images, response, cluster_ids], None
    valid_generator.reset()

def custom_test_generator_lmmnn(test_generator, epochs):
    """Prediction-time counterpart of the LMMNN batch adapter.

    Yields ``([images, y_true, Z], None)`` for every ``(images, labels)``
    batch of ``test_generator`` (label column 0 -> ``y_true``, column 1 ->
    ``Z``). Terminates, after resetting the wrapped iterator, when the running
    sample count (``batch_size`` per batch) equals ``n * epochs``.
    """
    consumed = 0
    while True:
        finished = consumed == test_generator.n * epochs
        if finished:
            test_generator.reset()
            return
        consumed = consumed + test_generator.batch_size
        batch = test_generator.next()
        pixels = batch[0]
        label_matrix = batch[1]
        yield [pixels, label_matrix[:, 0], label_matrix[:, 1]], None

def custom_train_generator_embed(train_generator, epochs):
    """Re-shape batches from an image iterator for the two-input embedding model.

    Each ``(images, labels)`` batch, with label column 0 = response and
    column 1 = cluster id, is re-emitted as ``([images, Z], y_true)``.

    The wrapper counts ``batch_size`` samples per batch and stops - resetting
    the wrapped iterator - once the count equals ``n * epochs`` exactly.
    """
    n_served = 0
    while n_served != train_generator.n * epochs:
        n_served += train_generator.batch_size
        batch = train_generator.next()
        labels = batch[1]
        yield [batch[0], labels[:, 1]], labels[:, 0]
    train_generator.reset()

def custom_valid_generator_embed(valid_generator, epochs):
    """Validation counterpart of the embedding-model batch adapter.

    Turns each ``(images, labels)`` batch of ``valid_generator`` into
    ``([images, Z], y_true)``, where ``y_true`` and ``Z`` are label columns 0
    and 1. Stops and resets the wrapped iterator when the running sample count
    (``batch_size`` per batch) equals ``n * epochs``.
    """
    samples_seen = 0
    while not samples_seen == valid_generator.n * epochs:
        samples_seen += valid_generator.batch_size
        images, targets = valid_generator.next()[0:2]
        response = targets[:, 0]
        cluster_ids = targets[:, 1]
        yield [images, cluster_ids], response
    valid_generator.reset()

def custom_test_generator_embed(test_generator, epochs):
    """Prediction-time counterpart of the embedding-model batch adapter.

    Yields ``([images, Z], y_true)`` for every ``(images, labels)`` batch of
    ``test_generator`` (label column 0 -> ``y_true``, column 1 -> ``Z``).
    Terminates, after resetting the wrapped iterator, when the running sample
    count (``batch_size`` per batch) equals ``n * epochs``.
    """
    consumed = 0
    while True:
        finished = consumed == test_generator.n * epochs
        if finished:
            test_generator.reset()
            return
        consumed = consumed + test_generator.batch_size
        batch = test_generator.next()
        pixels = batch[0]
        label_matrix = batch[1]
        yield [pixels, label_matrix[:, 1]], label_matrix[:, 0]

def factors(n):    # (cf. https://stackoverflow.com/a/15703327/849891)
    """Yield the prime factors of ``n`` in non-decreasing order, with multiplicity.

    Nothing is yielded for ``n <= 1``. For an integer ``n`` every yielded
    factor is an integer.

    The previous implementation divided with ``/=``, which turned ``n`` into a
    float: the last factor came out as a float (e.g. ``3.0``) and large inputs
    could lose precision. Floor division keeps the arithmetic exact.
    """
    divisor = 2
    # Trial division: a composite n must have a factor no larger than sqrt(n).
    while divisor * divisor <= n:
        if n % divisor == 0:
            n //= divisor
            yield divisor
        else:
            divisor += 1
    # Whatever remains above 1 is itself prime.
    if n > 1:
        yield n

def get_batchsize_steps(n):
    """Pick a batch size that divides ``n`` exactly, and the matching step count.

    The batch size is the second-to-last entry of ``factors(n)``; when ``n``
    has fewer than two prime factors the batch size is 1.

    Returns:
        (batch_size, steps) with ``steps = n // batch_size``.
    """
    prime_factors = list(factors(n))
    batch_size = prime_factors[-2] if len(prime_factors) > 1 else 1
    return batch_size, n // batch_size

def get_generators(images_df, images_dir, train_samp_subj, valid_samp_subj, test_samp_subj, batch_size, reg_type,
                   x_col='img_file', target_col='nose_x', cluster_col='celeb'):
    """Create train / validation / test image iterators split by cluster id.

    The original body hard-coded the columns ``'image_id'``, ``'age'`` and
    ``'subject_id2'``, while the rest of this notebook prepares
    ``'img_file'``, ``'nose_x'`` and ``'celeb'``. The column names are now
    parameters whose defaults match this notebook.

    Args:
        images_df: one row per image.
        images_dir: directory containing the image files.
        train_samp_subj, valid_samp_subj, test_samp_subj: collections of
            cluster ids selecting the rows of each split.
        batch_size: batch size for all three iterators.
        reg_type: for ``'ignore'`` and ``'menet'`` the labels hold only the
            target column; for any other value the cluster id is appended as
            a second label column.
        x_col: column holding the image file names.
        target_col: column holding the regression target.
        cluster_col: column holding the cluster id.

    Returns:
        (train_generator, valid_generator, test_generator). Only the training
        iterator shuffles.
    """
    if reg_type in ['ignore', 'menet']:
        y_cols = [target_col]
    else:
        y_cols = [target_col, cluster_col]

    def _flow(subjects, shuffle):
        # Pixels are rescaled to [0, 1]; for InceptionV3 use
        # preprocessing_function = preprocess_input instead.
        datagen = ImageDataGenerator(rescale = 1./255)
        return datagen.flow_from_dataframe(
            images_df[images_df[cluster_col].isin(subjects)],
            directory = images_dir,
            x_col = x_col,
            y_col = y_cols,
            target_size = (IMG_HEIGHT, IMG_WIDTH),
            class_mode = 'raw',
            batch_size = batch_size,
            shuffle = shuffle,
            validate_filenames = False
        )

    train_generator = _flow(train_samp_subj, True)
    valid_generator = _flow(valid_samp_subj, False)
    test_generator = _flow(test_samp_subj, False)
    return train_generator, valid_generator, test_generator

def reg_nn_ignore(train_generator, valid_generator, test_generator, n_cats, epochs, patience):
    """Train the image-only CNN with MSE loss and predict the test set.

    Early stopping watches ``val_loss``; a ``patience`` of ``None`` is
    replaced by ``epochs``. ``n_cats`` is accepted but not used.

    Returns:
        (y_pred, (None, None), n_epochs_run), with ``y_pred`` flattened to
        length ``test_generator.n``.
    """
    model = cnn_ignore()
    model.compile(loss='mse', optimizer='adam')
    es_patience = patience if patience is not None else epochs
    early_stop = EarlyStopping(monitor='val_loss', patience=es_patience)
    history = model.fit(train_generator, validation_data = valid_generator, epochs=epochs,
                        callbacks=[early_stop], verbose=1)
    raw_pred = model.predict(test_generator, verbose=1)
    n_epochs_run = len(history.history['loss'])
    return raw_pred.reshape(test_generator.n), (None, None), n_epochs_run

def reg_nn_lmm(train_generator, valid_generator, test_generator, n_cats, epochs, patience):
    """Train the LMMNN CNN and predict the test set, adding predicted random effects.

    Steps:
      1. Fit ``cnn_lmmnn`` (whose output is its own loss layer) with early
         stopping; a ``patience`` of ``None`` is replaced by ``epochs``.
      2. Read the two variance estimates from the last layer.
      3. Predict the training set and compute per-cluster effects with
         ``calc_b_hat``.
      4. Predict the test set and add ``b_hat`` for each row's cluster.

    Side effects: ``batch_size`` of the train and test iterators is overwritten
    with a divisor of their length so predictions cover every sample exactly
    once, and shuffling is switched off on the train iterator.

    Returns:
        (y_pred, (sig2e_est, sig2b_est), n_epochs_run).
    """
    model = cnn_lmmnn()
    model.compile(optimizer= 'adam')

    patience = epochs if patience is None else patience
    callbacks = [EarlyStoppingWithSigmasConvergence(patience=patience), PrintSigmas()]
    step_size_train = train_generator.n // train_generator.batch_size
    step_size_valid = valid_generator.n // valid_generator.batch_size
    history = model.fit(custom_train_generator_lmmnn(train_generator, epochs), steps_per_epoch = step_size_train,
        validation_data = custom_valid_generator_lmmnn(valid_generator, epochs), validation_steps = step_size_valid,
        epochs=epochs, callbacks=callbacks, verbose=1)

    sig2e_est, sig2b_est = model.layers[-1].get_vars()

    batch_size_train, steps_train = get_batchsize_steps(train_generator.n)
    # The train iterator is built with shuffle=True, but `labels` below is in
    # dataframe order. Turn shuffling off before predicting so y_pred_tr lines
    # up row-for-row with y_train / Z_train.
    train_generator.shuffle = False
    train_generator.reset()
    train_generator.batch_size = batch_size_train
    y_pred_tr = model.predict(custom_train_generator_lmmnn(train_generator, 1),
                              steps = steps_train,
                              verbose=1).reshape(train_generator.n)
    y_train = train_generator.labels[:, 0]
    # `np.int` was removed in NumPy 1.24; the builtin `int` is the equivalent dtype.
    Z_train = train_generator.labels[:, 1].astype(int)
    Z_test = test_generator.labels[:, 1].astype(int)
    b_hat = calc_b_hat(Z_train, y_train, y_pred_tr, n_cats, sig2e_est, sig2b_est)
    batch_size_test, steps_test = get_batchsize_steps(test_generator.n)
    test_generator.batch_size = batch_size_test
    y_pred = model.predict(custom_test_generator_lmmnn(test_generator, 1),
                           steps = steps_test, verbose=0).reshape(test_generator.n) + b_hat[Z_test]
    return y_pred, (sig2e_est, sig2b_est), len(history.history['loss'])

def reg_nn_embed(train_generator, valid_generator, test_generator, n_cats, epochs, patience):
    """Train the embedding CNN (10-dimensional embedding) and predict the test set.

    Early stopping watches ``val_loss``; a ``patience`` of ``None`` is
    replaced by ``epochs``. Before predicting, the test iterator's
    ``batch_size`` is overwritten with a divisor of its length so that
    ``steps`` batches cover every sample exactly once.

    Returns:
        (y_pred, (None, None), n_epochs_run).
    """
    embed_dim = 10
    model = cnn_embedding(n_cats, embed_dim)
    model.compile(loss='mse', optimizer='adam')
    es_patience = patience if patience is not None else epochs
    callbacks = [EarlyStopping(monitor='val_loss', patience=es_patience)]

    train_steps = train_generator.n // train_generator.batch_size
    valid_steps = valid_generator.n // valid_generator.batch_size
    train_stream = custom_train_generator_embed(train_generator, epochs)
    valid_stream = custom_valid_generator_embed(valid_generator, epochs)

    history = model.fit(train_stream, steps_per_epoch = train_steps,
        validation_data = valid_stream,
        validation_steps = valid_steps,
        epochs=epochs, callbacks=callbacks, verbose=1)

    batch_size_test, steps_test = get_batchsize_steps(test_generator.n)
    test_generator.batch_size = batch_size_test
    test_stream = custom_test_generator_embed(test_generator, 1)
    raw_pred = model.predict(test_stream, steps = steps_test, verbose=0)
    return raw_pred.reshape(test_generator.n), (None, None), len(history.history['loss'])

def reg_nn_menet(train_generator, valid_generator, test_generator, n_cats, epochs, patience):
    """Train the image-only CNN through the MeNet routines and predict the test set.

    Cluster ids for each split are looked up from the module-level
    ``imgfile2celeb`` mapping using the iterators' ``filenames``.

    Returns:
        (y_pred, (sig2e_est, None), n_epochs), where ``sig2e_est`` and
        ``n_epochs`` are taken from ``menet_fit_generator``'s return value.
    """
    def _clusters_of(generator):
        return np.array([imgfile2celeb[image_id] for image_id in generator.filenames])

    model = cnn_ignore()
    model.compile(loss='mse', optimizer='adam')
    es_patience = patience if patience is not None else epochs
    callbacks = [EarlyStopping(monitor='val_loss', patience=es_patience)]

    clusters_train = _clusters_of(train_generator)
    clusters_valid = _clusters_of(valid_generator)
    clusters_test = _clusters_of(test_generator)

    # The caller's `patience` (possibly None) is forwarded unchanged here.
    fit_result = menet_fit_generator(model, train_generator, valid_generator,
        clusters_train, clusters_valid, n_cats,
        epochs=epochs, callbacks=callbacks, patience=patience, verbose=1)
    model, b_hat, sig2e_est, n_epochs, _ = fit_result
    raw_pred = menet_predict_generator(model, test_generator, clusters_test, n_cats, b_hat)
    return raw_pred.reshape(test_generator.n), (sig2e_est, None), n_epochs

def reg_nn(images_df, images_dir, train_samp_subj, valid_samp_subj, test_samp_subj,
    n_cats, batch_size=20, epochs=100, patience=10, reg_type='ignore'):
    """Run one experiment of the requested type and score it on the test split.

    Builds the three image iterators, dispatches to the trainer matching
    ``reg_type`` (``'ignore'``, ``'lmm'``, ``'embed'`` or ``'menet'``) and
    computes test MSE and MAE against label column 0.

    Raises:
        ValueError: if ``reg_type`` is not one of the four known types
            (raised after the iterators have been created).

    Returns:
        (mse, mae, sigmas, n_epochs, elapsed_seconds); the elapsed time covers
        iterator creation, training and prediction.
    """
    start = time.time()
    train_generator, valid_generator, test_generator = get_generators(
        images_df, images_dir, train_samp_subj, valid_samp_subj, test_samp_subj, batch_size, reg_type)
    trainers = {
        'ignore': reg_nn_ignore,
        'lmm': reg_nn_lmm,
        'embed': reg_nn_embed,
        'menet': reg_nn_menet,
    }
    if reg_type not in trainers:
        raise ValueError(reg_type + ' is an unknown reg_type')
    y_pred, sigmas, n_epochs = trainers[reg_type](
        train_generator, valid_generator, test_generator, n_cats, epochs, patience)
    end = time.time()
    errors = y_pred - test_generator.labels[:, 0]
    mse = np.mean(errors**2)
    mae = np.mean(np.abs(errors))
    return mse, mae, sigmas, n_epochs, end - start

def iterate_reg_types(images_df, images_dir, res_df, counter, n_cats, train_samp_subj, valid_samp_subj, test_samp_subj,
    out_file, batch_size, epochs, patience):
    """Run all four experiment types on one split and record the results.

    Experiments run in the order lmm, ignore, embed, menet. Their rows are
    appended to ``res_df`` (mutated in place, indexed by ``next(counter)``) in
    the order ignore, lmm, embed, menet, and the whole frame is then written
    to ``out_file``.
    """
    def _run(reg_type):
        return reg_nn(images_df, images_dir, train_samp_subj, valid_samp_subj, test_samp_subj, n_cats,
                      batch_size=batch_size, epochs=epochs, patience=patience, reg_type=reg_type)

    mse_lmm, mae_lmm, sigmas, n_epochs_lmm, time_lmm = _run('lmm')
    print(' finished lmm, mse: %.2f, mae: %.2f' % (mse_lmm, mae_lmm))
    mse_ig, mae_ig, _, n_epochs_ig, time_ig = _run('ignore')
    print(' finished ignore, mse: %.2f, mae: %.2f' % (mse_ig, mae_ig))
    mse_em, mae_em, _, n_epochs_em, time_em = _run('embed')
    print(' finished embed, mse: %.2f, mae: %.2f' % (mse_em, mae_em))
    mse_me, mae_me, sigmas_me, n_epochs_me, time_me = _run('menet')
    print(' finished menet, mse: %.2f, mae: %.2f' % (mse_me, mae_me))

    # Only lmm reports both variance estimates; menet reports the first only.
    rows = [
        ['ignore', mse_ig, mae_ig, np.nan, np.nan, n_epochs_ig, time_ig],
        ['lmm', mse_lmm, mae_lmm, sigmas[0], sigmas[1], n_epochs_lmm, time_lmm],
        ['embed', mse_em, mae_em, np.nan, np.nan, n_epochs_em, time_em],
        ['menet', mse_me, mae_me, sigmas_me[0], np.nan, n_epochs_me, time_me],
    ]
    for row in rows:
        res_df.loc[next(counter)] = row
    res_df.to_csv(out_file)


In [None]:
# 5-fold cross-validation over cluster ids (not over images): each fold holds
# out a set of ids for testing, and a further 10% of the remaining ids (seeded
# by the fold number) is split off for validation.
subject_placeholder = np.zeros(n_cats)
for i, (train_index_subj, test_index_subj) in enumerate(kf.split(subject_placeholder, subject_placeholder)):
    print('iteration: %d' % i)
    train_index_subj, valid_index_subj = sample_split(i, train_index_subj)
    iterate_reg_types(images_df, images_dir, res_df, counter, n_cats,
                      train_index_subj, valid_index_subj, test_index_subj,
                      out_file, batch_size=batch_size, epochs=epochs, patience=patience)