In [None]:
import os
import glob
import numpy as np
import tensorflow as tf

from collections import defaultdict
from random import shuffle
from tensorflow import keras
from tensorflow.keras import Model, layers
import tensorflow.keras.backend as K
from tensorflow.keras.preprocessing import image
from tensorflow.keras.regularizers import l2
from tensorflow.keras.applications import MobileNet
from tensorflow.keras.applications.mobilenet import preprocess_input as mobilenet_preprocess

In [None]:
TRAIN_PATH = './data/train'
input_shape = (224, 224, 3)

# Images are read from TRAIN_PATH/*/*/*.jpg (two directory levels); the identity
# label of an image is its parent directory. Both glob order and set iteration
# order are arbitrary, so they are sorted to keep the class-id mapping and the
# train/validation split reproducible across runs.
img_paths = sorted(glob.glob(os.path.join(TRAIN_PATH, '*', '*', '*.jpg')))
labels = [os.path.dirname(p) for p in img_paths]
persons = sorted(set(labels))
persons_enc = {person: class_id for class_id, person in enumerate(persons)}
labels = [persons_enc[person] for person in labels]
num_classes = len(persons)

# Create validation set manually by holding out one image (the first one in
# sorted path order) for each person; all remaining images are used for training.
train_x, train_y, val_x, val_y = [], [], [], []
idx_examples = defaultdict(list)

for path, label in zip(img_paths, labels):
    idx_examples[label].append(path)

for k, v in idx_examples.items():
    val_y.append(k)
    train_y.extend([k] * (len(v) - 1))
    val_x.append(v[0])
    train_x.extend(v[1:])

In [None]:
def batching(img_paths, labels, preprocess, batch_size, input_shape, num_classes, shuffle=True):
    """Infinite generator of (preprocessed images, one-hot labels) batches.

    Each pass over the data visits every sample once (in a freshly shuffled
    order when `shuffle` is True) and yields only full batches of exactly
    `batch_size` samples; the incomplete tail of each pass is dropped.

    Args:
        img_paths: Sequence of image file paths.
        labels: Sequence of integer class ids, aligned with `img_paths`.
        preprocess: Callable applied to the stacked float32 image array
            (e.g. a Keras `preprocess_input` function).
        batch_size: Number of samples per yielded batch.
        input_shape: Model input shape; only its first two entries
            (height, width) are used, to resize each loaded image.
        num_classes: Number of classes for the one-hot encoding of `labels`.
        shuffle: Whether to reshuffle the sample order at every pass.

    Yields:
        Tuples `(preprocess(images), one_hot_labels)`.

    Raises:
        ValueError: On the first `next()` if `img_paths` and `labels` differ in
            length, if `batch_size` is not positive, or if there are fewer
            samples than `batch_size`. In the last case no full batch can ever
            be formed, so the generator would otherwise spin forever without
            yielding anything.
    """
    if len(img_paths) != len(labels):
        raise ValueError(
            f'img_paths and labels must have the same length, '
            f'got {len(img_paths)} and {len(labels)}')
    if batch_size < 1:
        raise ValueError(f'batch_size must be positive, got {batch_size}')
    if len(img_paths) < batch_size:
        raise ValueError(
            f'Need at least batch_size={batch_size} samples to form a batch, '
            f'got {len(img_paths)}')

    target_size = (input_shape[0], input_shape[1])
    indexes = np.arange(len(img_paths))
    while True:
        if shuffle:
            np.random.shuffle(indexes)

        # Any partial batch left from the previous pass is discarded here.
        X, y = [], []
        for idx in indexes:
            img = image.load_img(img_paths[idx], target_size=target_size)
            X.append(np.array(img).astype(np.float32))
            y.append(labels[idx])

            if len(y) == batch_size:
                yield (preprocess(np.array(X)),
                       keras.utils.to_categorical(np.array(y), num_classes=num_classes))
                X, y = [], []

In [None]:
def mobilenet(input_shape, l2_value, dropout, model_name='mobilenet'):
    """Build an ImageNet-pretrained MobileNet embedding network.

    The network is MobileNet (alpha=1.0, no classification top, global max
    pooling) followed by a ReLU Dense(1024) layer whose output is
    L2-normalized along axis 1.

    Args:
        input_shape: Input image shape passed to `MobileNet`, e.g. (224, 224, 3).
        l2_value: L2 factor for the kernel regularizer of every backbone layer
            that has a `kernel_regularizer` attribute, and of the Dense head.
            The penalties are exposed through `model.losses`; they only affect
            training if the training loop adds those losses.
        dropout: Forwarded to `MobileNet`. NOTE: Keras' MobileNet only uses it
            inside the classification top, which `include_top=False` removes,
            so it has no effect here.
        model_name: Name given to the returned `Model`.

    Returns:
        A `Model` mapping the MobileNet input to the L2-normalized 1024-d embedding.
    """
    mobile = MobileNet(
        input_shape=input_shape,
        dropout=dropout,
        include_top=False,
        pooling='max',
        alpha=1.,
        weights='imagenet'
    )

    # Keras turns a regularizer into a loss term when the layer creates its
    # weights, so assigning `kernel_regularizer` on the already-built backbone
    # does nothing by itself. Cloning rebuilds every layer from its (updated)
    # config, which attaches the regularizers; the pretrained weights are then
    # copied back onto the rebuilt model.
    needs_rebuild = False
    for layer in mobile.layers:
        if hasattr(layer, 'kernel_regularizer'):
            layer.kernel_regularizer = l2(l2_value)
            needs_rebuild = True
    if needs_rebuild:
        pretrained_weights = mobile.get_weights()
        mobile = keras.models.clone_model(mobile)
        mobile.set_weights(pretrained_weights)

    x = layers.Dense(1024, kernel_regularizer=l2(l2_value), activation='relu')(mobile.output)
    x = layers.Lambda(lambda t: K.l2_normalize(t, axis=1))(x)
    return Model(mobile.input, x, name=model_name)

In [None]:
# --- Optimisation ---
learning_rate = 1e-4
num_epochs = 1000
batch_size = 48

# --- Regularisation ---
l2_value = 1e-9
dropout = 0.1

# --- Contrastive objective ---
temperature = 0.1
# Projection-head width; also the input size of the head, so it has to match
# the encoder's output dimension (the Dense(1024) in `mobilenet`).
width = 1024
contrastive_augmentation = dict(min_area=0.25, brightness=0.6, jitter=0.2)

# Learning rate scheduler
# Multiplicative learning-rate factors, keyed by the (0-based) epoch at which
# each one is applied once to the current learning rate.
_LR_FACTOR_AT_EPOCH = {
    10: 0.5,
    180: 0.8,
    250: 0.5,
    300: 0.7,
    400: 0.3,
    700: 0.6,
}


def scheduler(epoch, lr):
    """Return the learning rate to use for `epoch` given the current rate `lr`.

    At the epochs listed in `_LR_FACTOR_AT_EPOCH` the current rate is scaled by
    the corresponding factor (so the decays compound); at every other epoch
    `lr` is returned unchanged.
    """
    factor = _LR_FACTOR_AT_EPOCH.get(epoch)
    if factor is None:
        return lr
    return factor * lr
    
# Keras invokes `scheduler(epoch, current_lr)` at the start of every epoch and
# writes the returned value back to the model's optimizer.
lr_callback = keras.callbacks.LearningRateScheduler(schedule=scheduler)

In [None]:
# Training batches are reshuffled every pass. Validation batches are drawn in a
# fixed order (shuffle=False), so `val_c_loss`, which SaveCheckpointCallback
# uses to pick the best encoder, is computed over the same batches every epoch
# and stays comparable from one epoch to the next.
train_gen = batching(train_x, train_y, mobilenet_preprocess, batch_size, input_shape, num_classes)
val_gen = batching(val_x, val_y, mobilenet_preprocess, batch_size, input_shape, num_classes,
                   shuffle=False)

In [None]:
# Embedding backbone; ContrastiveModel uses this object as its encoder.
base_network = mobilenet(input_shape, l2_value, dropout)
print('NN number of parameters: {}'.format(base_network.count_params()))

In [None]:
# Randomly distorts the color distributions of images
class RandomColorAffine(layers.Layer):
    """Random per-image affine color distortion (brightness + channel mixing).

    For every image a 3x3 color matrix ``(1 + b) * I + J`` is drawn: one
    brightness offset ``b ~ U(-brightness, brightness)`` shared by all channels,
    and an independent ``J[i, j] ~ U(-jitter, jitter)`` for every entry. Each
    pixel's 3-channel row vector is right-multiplied by that matrix and the
    result is clipped to [0, 1], so the layer assumes inputs already in [0, 1].

    Args:
        brightness: Maximum absolute brightness offset.
        jitter: Maximum absolute value of each color-mixing entry.
    """

    def __init__(self, brightness=0, jitter=0, **kwargs):
        super().__init__(**kwargs)

        self.brightness = brightness
        self.jitter = jitter

    def call(self, images, training=True):
        """Distort `images` when `training` is truthy, else return them unchanged.

        NOTE: `training` defaults to True (unlike most Keras layers), so calling
        the layer directly without the argument applies the distortion.
        """
        if training:
            batch_size = tf.shape(images)[0]

            # Same for all colors
            brightness_scales = 1 + tf.random.uniform(
                (batch_size, 1, 1, 1), minval=-self.brightness, maxval=self.brightness
            )
            # Different for all colors
            jitter_matrices = tf.random.uniform(
                (batch_size, 1, 3, 3), minval=-self.jitter, maxval=self.jitter
            )

            # (batch, 1, 3, 3) matrices broadcast against images assumed to be
            # (batch, H, W, 3) in tf.matmul, i.e. one color matrix per image.
            color_transforms = (
                tf.eye(3, batch_shape=[batch_size, 1]) * brightness_scales
                + jitter_matrices
            )
            images = tf.clip_by_value(tf.matmul(images, color_transforms), 0, 1)
        return images

    def get_config(self):
        """Include the constructor arguments so the layer can be serialized or cloned."""
        config = super().get_config()
        config.update({"brightness": self.brightness, "jitter": self.jitter})
        return config

# Image augmentation module
def get_augmenter(min_area, brightness, jitter):
    """Build the random augmentation pipeline used for contrastive pretraining.

    The pipeline expects batches shaped like the notebook-level `input_shape`
    that are already scaled to [-1, 1], the output range of MobileNet's
    `preprocess_input`, which `batching` applies to every batch. Images are
    mapped to [0, 1] for the flip/translation/zoom and color augmentation
    (`RandomColorAffine` clips to [0, 1]). They are mapped back to [-1, 1] at
    the end, which is the input range MobileNet expects.

    Args:
        min_area: Fraction of the image area that the strongest random zoom
            still keeps; it also bounds the random translation.
        brightness: Forwarded to `RandomColorAffine`.
        jitter: Forwarded to `RandomColorAffine`.

    Returns:
        A `keras.Sequential` model whose random layers are only active when it
        is called with `training=True`; the rescaling steps always apply.
    """
    # Zooming in by up to zoom_factor per side keeps (1 - zoom_factor) ** 2 == min_area
    # of the image. A plain Python float keeps the layer configs serializable.
    zoom_factor = 1.0 - float(np.sqrt(min_area))
    return keras.Sequential(
        [
            keras.Input(shape=input_shape),
            # [-1, 1] -> [0, 1]. A 1/255 rescale here would squeeze the already
            # preprocessed images into [-1/255, 1/255], and the [0, 1] clip in
            # RandomColorAffine would then wipe out every negative value.
            layers.experimental.preprocessing.Rescaling(0.5, offset=0.5),
            layers.experimental.preprocessing.RandomFlip("horizontal"),
            layers.experimental.preprocessing.RandomTranslation(zoom_factor / 2, zoom_factor / 2),
            layers.experimental.preprocessing.RandomZoom((-zoom_factor, 0.0), (-zoom_factor, 0.0)),
            RandomColorAffine(brightness, jitter),
            # [0, 1] -> [-1, 1], the input range expected by the MobileNet encoder.
            layers.experimental.preprocessing.Rescaling(2.0, offset=-1.0),
        ]
    )

In [None]:
# Implementation from Keras code examples
# Define the contrastive model with model-subclassing
class ContrastiveModel(keras.Model):
    """SimCLR-style contrastive pretraining wrapper around `base_network`.

    Every batch is augmented twice by `get_augmenter(**contrastive_augmentation)`.
    Both views pass through the encoder (`base_network`) and a non-linear
    projection head, and the model is trained with a symmetrized NT-Xent
    (InfoNCE) loss at temperature `temperature`.

    Configuration is read from the notebook globals `temperature`,
    `contrastive_augmentation`, `base_network` and `width`. `width` is the input
    size of the projection head, so it must equal the encoder's output size.
    Labels in the incoming `(images, labels)` batches are ignored.
    """

    def __init__(self):
        super().__init__()

        self.temperature = temperature
        self.contrastive_augmenter = get_augmenter(**contrastive_augmentation)
        self.encoder = base_network

        # Non-linear MLP as projection head
        self.projection_head = keras.Sequential(
            [
                keras.Input(shape=(width,)),
                layers.Dense(width, activation="relu"),
                layers.Dense(width),
            ],
            name="projection_head",
        )
        self.encoder.summary()
        self.projection_head.summary()

    def compile(self, contrastive_optimizer, **kwargs):
        """Store the optimizer used by `train_step` and create the metrics.

        Extra keyword arguments are forwarded to `keras.Model.compile`.
        """
        super().compile(**kwargs)

        self.contrastive_optimizer = contrastive_optimizer

        self.contrastive_loss_tracker = keras.metrics.Mean(name="c_loss")
        self.contrastive_accuracy = keras.metrics.SparseCategoricalAccuracy(
            name="c_acc"
        )

    @property
    def metrics(self):
        # Listed here so Keras resets them at the start of every epoch / evaluation.
        return [
            self.contrastive_loss_tracker,
            self.contrastive_accuracy
        ]

    def contrastive_loss(self, projections_1, projections_2):
        """Per-example symmetrized NT-Xent loss for two batches of projections.

        Row i of `projections_1` and row i of `projections_2` are the two views
        of the same image (the positive pair); every other row in the batch is
        a negative. Also updates `contrastive_accuracy` in both directions.

        Returns:
            A vector with one loss value per example in the batch.
        """
        # InfoNCE loss (information noise-contrastive estimation)
        # NT-Xent loss (normalized temperature-scaled cross entropy)

        # Cosine similarity: the dot product of the l2-normalized feature vectors
        projections_1 = tf.math.l2_normalize(projections_1, axis=1)
        projections_2 = tf.math.l2_normalize(projections_2, axis=1)
        similarities = (
            tf.matmul(projections_1, projections_2, transpose_b=True) / self.temperature
        )

        # The similarity between the representations of two augmented views of the
        # same image should be higher than their similarity with other views
        batch_size = tf.shape(projections_1)[0]
        contrastive_labels = tf.range(batch_size)
        self.contrastive_accuracy.update_state(contrastive_labels, similarities)
        self.contrastive_accuracy.update_state(
            contrastive_labels, tf.transpose(similarities)
        )

        # The temperature-scaled similarities are used as logits for cross-entropy
        # a symmetrized version of the loss is used here
        loss_1_2 = keras.losses.sparse_categorical_crossentropy(
            contrastive_labels, similarities, from_logits=True
        )
        loss_2_1 = keras.losses.sparse_categorical_crossentropy(
            contrastive_labels, tf.transpose(similarities), from_logits=True
        )
        return (loss_1_2 + loss_2_1) / 2

    def train_step(self, data):
        """One optimization step on the contrastive loss plus regularization losses."""
        images, _ = data

        # Each image is augmented twice, differently
        augmented_images_1 = self.contrastive_augmenter(images, training=True)
        augmented_images_2 = self.contrastive_augmenter(images, training=True)
        trainable_weights = (
            self.encoder.trainable_weights + self.projection_head.trainable_weights
        )
        with tf.GradientTape() as tape:
            features_1 = self.encoder(augmented_images_1, training=True)
            features_2 = self.encoder(augmented_images_2, training=True)
            # The representations are passed through a projection mlp
            projections_1 = self.projection_head(features_1, training=True)
            projections_2 = self.projection_head(features_2, training=True)
            contrastive_loss = self.contrastive_loss(projections_1, projections_2)
            # The weight penalties registered on sub-layers (e.g. the encoder's
            # kernel_regularizers) only influence training if they are added to
            # the optimized objective explicitly in a custom train_step. The
            # batch-mean follows Keras' usual "mean loss + regularization" form.
            total_loss = tf.reduce_mean(contrastive_loss)
            regularization_losses = self.losses
            if regularization_losses:
                total_loss += tf.add_n(regularization_losses)
        gradients = tape.gradient(total_loss, trainable_weights)
        self.contrastive_optimizer.apply_gradients(zip(gradients, trainable_weights))
        # Only the contrastive part is reported as c_loss.
        self.contrastive_loss_tracker.update_state(contrastive_loss)

        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        """Evaluate the contrastive loss/accuracy on two random views of `images`."""
        images, _ = data

        # Each image is augmented twice, differently. The augmenter must run in
        # training mode here too: with training=False its random layers are
        # identities, both views would be the same image, and the validation
        # loss would only compare every image with itself. The encoder and head
        # still run in inference mode.
        augmented_images_1 = self.contrastive_augmenter(images, training=True)
        augmented_images_2 = self.contrastive_augmenter(images, training=True)

        features_1 = self.encoder(augmented_images_1, training=False)
        features_2 = self.encoder(augmented_images_2, training=False)
        # The representations are passed through a projection mlp
        projections_1 = self.projection_head(features_1, training=False)
        projections_2 = self.projection_head(features_2, training=False)
        contrastive_loss = self.contrastive_loss(projections_1, projections_2)

        self.contrastive_loss_tracker.update_state(contrastive_loss)

        return {m.name: m.result() for m in self.metrics}

    def call(self, images):
        """Unused: training and evaluation go through `train_step` / `test_step`."""
        pass
    
    

In [None]:
model_name = 'mobilenet_contrastive'
ckpt_dir = os.path.join('pretrained/checkpoints', model_name)

# Contrastive pretraining.
# The same Adam instance is also given to Keras as `optimizer`: the
# LearningRateScheduler in `lr_callback` adjusts `model.optimizer`, so it then
# schedules the optimizer that `train_step` actually applies gradients with.
contrastive_optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
pretraining_model = ContrastiveModel()
pretraining_model.compile(
    contrastive_optimizer=contrastive_optimizer,
    optimizer=contrastive_optimizer,
)


class SaveCheckpointCallback(keras.callbacks.Callback):
    """Save the encoder whenever the monitored validation loss improves.

    Args:
        ckpt_dir: Directory that receives `model.hdf5`; created if missing.
        monitor: Log key to minimize; defaults to the validation contrastive
            loss `'val_c_loss'`.
    """

    def __init__(self, ckpt_dir, monitor='val_c_loss'):
        super(SaveCheckpointCallback, self).__init__()
        os.makedirs(ckpt_dir, exist_ok=True)
        self.ckpt_dir = ckpt_dir
        self.monitor = monitor
        self.val_loss = float('inf')  # best monitored value seen so far

    def on_epoch_end(self, epoch, logs=None):
        curr_loss = (logs or {}).get(self.monitor)
        if curr_loss is None:
            # e.g. no validation ran this epoch: nothing to compare against.
            return

        if curr_loss < self.val_loss:
            self.val_loss = curr_loss
            # `self.model` is the ContrastiveModel; only its encoder is saved.
            self.model.encoder.save(os.path.join(self.ckpt_dir, 'model.hdf5'), overwrite=True)


pretraining_model.fit(train_gen,
                      validation_data=val_gen,
                      epochs=num_epochs,
                      steps_per_epoch=len(train_x)//batch_size,
                      validation_steps=len(val_x)//batch_size,
                      callbacks=[SaveCheckpointCallback(ckpt_dir), lr_callback])