In [None]:
import os
import math
import json
import pickle
import random
import numpy as np
import tensorflow as tf
from tensorflow.data import Dataset
from tensorflow import keras

from datetime import datetime
from matplotlib import pyplot as plt
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
from tensorflow.keras.optimizers import RMSprop, SGD, Adam
from tensorflow.keras.applications import MobileNet, ResNet50, InceptionV3
from tensorflow.keras.applications.mobilenet import preprocess_input as mobilenet_preprocess
from tensorflow.keras.applications.resnet50 import preprocess_input as resnet_preprocess
from tensorflow.keras.applications.inception_v3 import preprocess_input as inception_preprocess
from tensorflow.keras.regularizers import l2
from tensorflow.keras.preprocessing import image
from tensorflow.keras import backend as K
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import Callback, LearningRateScheduler
from tensorflow.keras.layers import Input, Flatten, Dense, Dropout, Lambda, Conv1D, Attention, GlobalAveragePooling1D, BatchNormalization, Layer
from keras_facenet import FaceNet

# Seed every RNG the notebook touches (numpy, Python's `random` used for the
# train/val split shuffle, and TensorFlow) so reruns are repeatable.
np.random.seed(123)
random.seed(123)
tf.random.set_seed(12)  # NOTE(review): 12 differs from the 123 used above — confirm intent

In [None]:
# Load the precomputed target embeddings: a mapping
# person -> {image file name -> embedding array}.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open('data/train_img_embeddings.pkl', 'rb') as f:
    train_embeddings = pickle.load(f)
print(f'The keys examples: {list(train_embeddings.keys())[:5]}')

# Peek at the first embedding of the first person to learn the target shape.
first_person_embs = list(train_embeddings.values())[0]
first_embedding = list(first_person_embs.values())[0]
embedding_shape = first_embedding.shape
print(f'Embeddings shape: {embedding_shape}')

In [None]:
# Total number of images: each person maps image file name -> embedding.
cnt = sum(len(person_imgs) for person_imgs in train_embeddings.values())

print(f'Total imgs: {cnt}')

In [None]:
# Model input shape (resize target is its first two entries) and the root
# directory holding one sub-folder of images per person.
input_shape, train_path = (224, 224, 3), './data/train'

In [None]:
def mobilenet(input_shape, l2_value, dropout):
    """Build a MobileNet embedding model with an L2-normalised 512-d output.

    Args:
        input_shape: Image input shape handed to ``MobileNet``.
        l2_value: L2 factor applied to the backbone's ``kernel`` weights and
            to the Dense head's kernel.
        dropout: Dropout rate applied to the pooled backbone features before
            the Dense head.

    Returns:
        A ``Model`` mapping the MobileNet input to a unit-length 512-d vector.
    """
    import functools

    mobile = MobileNet(
        input_shape=input_shape,
        dropout=dropout,  # only used by MobileNet's own top, excluded here
        include_top=False,
        pooling='avg',
        alpha=1.,
        weights='imagenet'
    )

    # Assigning `layer.kernel_regularizer` on an already-built layer has no
    # effect: Keras registers regularization losses when the weights are
    # created. Collect the built kernels and register penalties explicitly.
    regularizer = l2(l2_value)
    backbone_kernels = []
    for layer in mobile.layers:
        layer.trainable = True
        # DepthwiseConv2D keeps its weights in `depthwise_kernel`, so only
        # standard/pointwise conv kernels are penalised; the MobileNet paper
        # advises little or no weight decay on depthwise filters.
        kernel = getattr(layer, 'kernel', None)
        if kernel is not None:
            backbone_kernels.append(kernel)

    # MobileNet ignores `dropout` when include_top=False, so apply it here.
    x = Dropout(dropout)(mobile.output)
    x = Dense(512, kernel_regularizer=l2(l2_value), activation='relu')(x)
    x = Lambda(lambda x: K.l2_normalize(x, axis=1))(x)
    model = Model(mobile.input, x)

    # Losses are attached to the returned model (not `mobile`), because the
    # new functional Model does not include `mobile` itself as a layer.
    for kernel in backbone_kernels:
        # Zero-argument callable so the penalty tracks the current weights.
        model.add_loss(functools.partial(regularizer, kernel))
    return model

In [None]:
def batching(embeddings, batch_size, input_shape, preprocess, root=None, drop_remainder=True):
    """Yield ``(preprocessed images, target embeddings)`` batches read from disk.

    Args:
        embeddings: Mapping person -> {image file name -> target embedding};
            images are read from ``<root>/<person>/<image file name>``.
        batch_size: Number of images per yielded batch.
        input_shape: Model input shape; only its first two entries are used
            as the ``target_size`` for ``image.load_img``.
        preprocess: Callable applied to the stacked float32 image array.
        root: Directory holding the per-person image folders. Defaults to the
            module-level ``train_path`` (read at call time).
        drop_remainder: If True (default), a trailing batch smaller than
            ``batch_size`` is discarded, so one pass yields exactly
            ``total_images // batch_size`` batches.

    Yields:
        ``(preprocess(images), labels)`` with ``labels`` cast to float64.
    """
    if root is None:
        root = train_path
    target_size = (input_shape[0], input_shape[1])

    imgs, labels = [], []
    for person, embs in embeddings.items():
        person_path = os.path.join(root, person)
        for img_name, emb in embs.items():
            img = image.load_img(os.path.join(person_path, img_name), target_size=target_size)
            imgs.append(np.array(img).astype('float32'))
            labels.append(emb)
            if len(labels) == batch_size:
                yield (preprocess(np.array(imgs)), np.array(labels).astype(float))
                imgs, labels = [], []

    # Only reached when the caller explicitly asks for the partial tail batch.
    if labels and not drop_remainder:
        yield (preprocess(np.array(imgs)), np.array(labels).astype(float))

def repeat_generator(embeddings, batch_size, input_shape, preprocess):
    """Cycle over ``batching(...)`` forever, restarting it after each full pass.

    Used as an endless data source for ``model.fit`` together with
    ``steps_per_epoch`` / ``validation_steps``.

    Raises:
        ValueError: If one complete pass produces no batch (e.g. fewer images
            than ``batch_size``). Without this check the ``while True`` loop
            would spin forever without yielding and hang the caller.
    """
    while True:
        produced_any = False
        for batch in batching(embeddings, batch_size, input_shape, preprocess):
            produced_any = True
            yield batch
        if not produced_any:
            raise ValueError(
                f'repeat_generator: no batch of size {batch_size} could be formed '
                f'from {len(embeddings)} people'
            )

In [None]:
# Training-Validation split, done over people (the dict keys), not images,
# so no person's images appear in both sets.
VAL_FACTOR = 0.12
keys = list(train_embeddings.keys())
random.shuffle(keys)
keys_length = len(keys)
val_factor = int(keys_length * VAL_FACTOR)  # number of people held out for validation
val_keys, train_keys = keys[:val_factor], keys[val_factor:]
print(f'Total keys: {keys_length}, train keys: {len(train_keys)}, valid keys: {len(val_keys)}')

val_embs = {person: train_embeddings[person] for person in val_keys}
train_embs = {person: train_embeddings[person] for person in train_keys}

# Image counts per split (each person maps image file name -> embedding).
train_len = sum(len(person_imgs) for person_imgs in train_embs.values())
val_len = sum(len(person_imgs) for person_imgs in val_embs.values())

print(f'Total - train imgs: {train_len}, valid imgs: {val_len}')

In [None]:
# Training hyper-parameters.
lr = 7e-5
l2_value = 1e-5
dropout = 0.25
optimizer_name = 'Adam'
batch_size = 12
epochs = 1000

# Explicit name -> class lookup instead of eval()-ing a string.
OPTIMIZERS = {'Adam': Adam, 'SGD': SGD, 'RMSprop': RMSprop}

model = mobilenet(input_shape, l2_value, dropout)
optimizer = OPTIMIZERS[optimizer_name](learning_rate=lr)
# Keras' 'cosine_similarity' loss is the negated cosine similarity, so
# minimising it pulls predictions toward the target embeddings.
model.compile(loss='cosine_similarity', optimizer=optimizer)

In [None]:
model_name = 'model_002_mobile_512'
ckpt_dir = os.path.join('pretrained/checkpoints', model_name)
log_dir = os.path.join('pretrained/logs', model_name)

# exist_ok keeps the cell safe to re-run.
os.makedirs(ckpt_dir, exist_ok=True)
os.makedirs(log_dir, exist_ok=True)

# batching() only yields full batches, so one pass over each split produces
# exactly this many batches.
steps_per_epoch = train_len // batch_size
validation_steps = val_len // batch_size
if steps_per_epoch == 0 or validation_steps == 0:
    raise ValueError(
        f'batch_size={batch_size} exceeds the train ({train_len}) or '
        f'validation ({val_len}) image count; no full batch can be formed.'
    )

CKPT_EVERY_EPOCHS = 3
ckpt_callback = keras.callbacks.ModelCheckpoint(
    filepath=os.path.join(ckpt_dir, 'weights.{epoch:02d}.hdf5'),
    save_weights_only=True,
    # `period` is deprecated in tf.keras and not accepted by newer Keras;
    # an integer `save_freq` counts batches, so this saves every 3 epochs.
    save_freq=CKPT_EVERY_EPOCHS * steps_per_epoch,
)
tb_callback = keras.callbacks.TensorBoard(log_dir=log_dir)

model.fit(
    repeat_generator(train_embs, batch_size, input_shape, mobilenet_preprocess),
    epochs=epochs,
    steps_per_epoch=steps_per_epoch,
    validation_data=repeat_generator(val_embs, batch_size, input_shape, mobilenet_preprocess),
    validation_steps=validation_steps,
    callbacks=[ckpt_callback, tb_callback]
)