Training: MobileNetV2 face embeddings on CelebA identity pairs

In [None]:
# Pinned dependency versions; -I (--ignore-installed) reinstalls the OpenCV headless build even if another version is already present.
!pip install albumentations==1.3.0
!pip install -I opencv-python-headless==4.1.2.30

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive so the dataset archives and training outputs on Drive are reachable from this runtime.
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [None]:
import math
import os
import cv2
import numpy as np
import tempfile
import threading
import pandas
from itertools import cycle
from imutils.paths import list_images
import albumentations as A
from tqdm.auto import tqdm

from matplotlib import pyplot as plt
from matplotlib.ticker import AutoMinorLocator

import tensorflow as tf
from tensorflow.keras.applications.mobilenet_v2 import MobileNetV2, preprocess_input
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
from tensorflow.keras.optimizers import *
from tensorflow.keras.losses import binary_crossentropy
from tensorflow import math
from tensorflow.keras import backend as K

In [None]:
def center_crop(path, size=160):
    """Center-crop every image found under ``path`` to ``size`` x ``size``, overwriting it.

    The aligned CelebA images are 218x178 (height x width), so the default crop
    keeps their central 160x160 region.

    Args:
        path: directory passed to ``imutils.paths.list_images``.
        size: side length of the square output; 160 matches the model input.

    Raises:
        IOError: if OpenCV cannot read one of the images.
    """
    images = list(list_images(path))

    for filename in tqdm(images):
        original = cv2.imread(filename)
        if original is None:
            raise IOError("could not read image: {}".format(filename))

        height, width = original.shape[:2]
        # Clamp offsets at 0: a negative offset would be interpreted as an index
        # from the end and silently produce a thin strip instead of a centred crop.
        top = max(0, (height - size) // 2)
        left = max(0, (width - size) // 2)
        cropped = original[top:top + size, left:left + size]
        # Only images smaller than `size` along some side still need (up)scaling.
        if cropped.shape[:2] != (size, size):
            cropped = cv2.resize(cropped, (size, size))
        cv2.imwrite(filename, cropped)

if not os.path.isfile("/content/drive/MyDrive/celeba/cropped_images.zip"):
  !unzip -qq /content/drive/MyDrive/celeba/img_align_celeba.zip
  center_crop("img_align_celeba")
  !zip -r /content/drive/MyDrive/celeba/cropped_images.zip img_align_celeba

else:
  !unzip -qq /content/drive/MyDrive/celeba/cropped_images.zip

In [None]:
def euclidean(p):
    """Half the squared Euclidean distance between every pair of rows of ``p``.

    Entry ``[i, j]`` equals ``sum((p[j] - p[i]) ** 2) / 2``, giving a square,
    symmetric matrix with a zero diagonal.
    """
    # assumes p is (batch, dim): rows -> (1, batch, dim), cols -> (batch, 1, dim)
    rows = tf.expand_dims(p, axis=0)
    cols = tf.expand_dims(p, axis=1)
    diff = rows - cols
    return tf.reduce_sum(diff * diff, axis=-1) * 0.5

def loss(t, p):
    """Pairwise embedding loss: pull positive pairs together, push the closest negative away.

    Args:
        t: label matrix over the batch -- 1 for positive pairs, 2 on the diagonal
            (self-pairs), 0 for negatives (see ``ImageGen.create_static_label``).
        p: embeddings, one row per image.

    Returns:
        Binary cross-entropy of the positive distances against target 0 and of
        the rescaled hardest-negative distance of each row against target 1.
    """
    # unpack euclidean matrix (half squared distance between every pair of rows)
    euclidean_2d = euclidean(p)
    # Adding t raises positive entries by 1 and the diagonal by 2, so the row
    # minimum is normally taken over the negatives (the hardest negative per row).
    neg = tf.reduce_min(euclidean_2d + t, axis=-1)
    # Cap at 0.5 and rescale to [0, 1]: negatives at least 0.5 away count as fully separated.
    neg = tf.minimum(neg, 0.5) / 0.5
    # positive distances, flattened to 1-D
    pos = tf.reshape(tf.boolean_mask(euclidean_2d, t == 1), [-1])
    # targets: 0 for positive distances, 1 for the rescaled negative distances
    pos_lbls, neg_lbls = tf.zeros_like(pos), tf.ones_like(neg)
    # NOTE(review): both lists are stacked into single tensors, so this assumes
    # len(pos) == len(neg), i.e. exactly one positive per row -- confirm for all batch layouts.
    # NOTE(review): unlike neg, pos is not clipped to [0, 1]; confirm positive
    # distances stay inside that range (e.g. thanks to the ReLU + L2-norm head).
    return binary_crossentropy([pos_lbls, neg_lbls], [pos, neg])

def mean_pos(t, p):
    """Metric: average ``euclidean`` distance over positive pairs (entries where ``t == 1``)."""
    positives = tf.boolean_mask(euclidean(p), tf.equal(t, 1))
    return tf.reduce_mean(positives)

def mean_neg(t, p):
    """Metric: average ``euclidean`` distance over negative pairs (entries where ``t == 0``)."""
    negatives = tf.boolean_mask(euclidean(p), tf.equal(t, 0))
    return tf.reduce_mean(negatives)

def accuracy(t, p):
    """Fraction of rows whose nearest other embedding is the row's positive partner.

    Args:
        t: label matrix with 1 at positive pairs and 2 on the diagonal (self-pairs).
        p: embeddings, one row per image.

    Returns:
        Scalar mean, in ``p.dtype``, of the per-row hits.
    """
    euclidean_2d = euclidean(p)
    # Exclude self-distances with +inf. A +1 offset is not enough: real distances
    # reach 1 for orthogonal unit embeddings (and up to 2 in general), so the
    # masked diagonal could tie with or beat the true nearest neighbour.
    masked = tf.where(t == 2, tf.constant(float("inf"), dtype=p.dtype), euclidean_2d)
    # a row is a hit when its closest entry is the column holding its positive
    hits = tf.equal(tf.argmin(masked, axis=-1), tf.argmax(t == 1, axis=-1))
    return tf.reduce_mean(tf.cast(hits, p.dtype))

In [None]:
class ImageGen:
    """Infinite, lock-protected batch generator of augmented same-identity image pairs.

    Every batch covers ``batch_size`` identities with two images each: row ``i``
    and row ``i + batch_size`` of ``x`` belong to the same identity. ``next()``
    returns ``(x, y)`` where ``x`` is a float32 array of shape
    ``(2 * batch_size, 160, 160, 3)`` passed through MobileNetV2's
    ``preprocess_input`` and ``y`` is the label matrix from
    ``create_static_label``.

    Args:
        batch_size: number of identities per batch.
        path: directory holding the images; joined with the CSV ``filename`` column.
        identites: CSV file with ``filename`` and ``identity`` columns (the
            misspelled name is kept so existing keyword callers still work).
        subset: ``"train"`` keeps identities with at least two images and draws
            two of them per batch; any other value keeps identities with exactly
            one image and pairs it with itself, so the two rows differ only by
            augmentation.
    """

    def __init__(self, batch_size, path, identites, subset="train"):
        self.path = path
        self.subset = subset
        self.bs = batch_size

        # serialises next() so several threads can pull from one generator
        self.lock = threading.Lock()
        self.data = self.unpack_data(identites, path)
        self.rand_gen = self.get_rand_generator()
        self.it = self.get_iterator()

    def unpack_data(self, identities_path, unzip_path):
        """Build ``{identity: infinite iterator of [path1, path2] pairs}`` from the CSV."""
        df = pandas.read_csv(identities_path)

        # identity 0 marks bad images; drop them
        df = df.loc[df["identity"] != 0]

        identity_dict = {}
        # A single groupby pass replaces a boolean scan of the whole frame per
        # identity (O(identities x rows)); groups come out sorted by identity and
        # keep the CSV's file order within each group.
        for class_name, filenames in tqdm(df.groupby("identity")["filename"]):
            imgs = [os.path.join(unzip_path, name) for name in filenames]

            if self.subset == "train":
                # a positive pair needs at least two images of the identity
                if len(imgs) >= 2:
                    identity_dict[class_name] = self.chunker(imgs, 2)
            elif len(imgs) == 1:
                # evaluation identities have a single image, paired with itself
                identity_dict[class_name] = cycle([[imgs[0], imgs[0]]])

        return identity_dict

    @staticmethod
    def get_rand_generator():
        """Random augmentation pipeline applied independently to every loaded image."""
        transform = A.Compose([
            A.HorizontalFlip(p=0.5),
            A.GaussNoise(p=0.5),
            A.HueSaturationValue(hue_shift_limit=20,
                                 sat_shift_limit=30,
                                 val_shift_limit=20,
                                 p=0.5),
            A.ShiftScaleRotate(rotate_limit=15,
                               scale_limit=(-0.16, 0.16),
                               shift_limit=0.15,
                               border_mode=cv2.BORDER_CONSTANT,
                               p=1.0),
            A.RandomBrightnessContrast(brightness_limit=0.25, p=1.0)
        ])
        return transform

    def get_image(self, path):
        """Load ``path`` as RGB, augment it and return it with a leading batch axis of 1.

        Raises:
            IOError: if OpenCV cannot read the file.
        """
        img = cv2.imread(path)
        if img is None:
            raise IOError("could not read image: {}".format(path))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = self.rand_gen(image=img)["image"]
        return np.expand_dims(img, axis=0)

    def create_static_label(self):
        """Return the ``(2 * bs, 2 * bs)`` label matrix shared by every batch.

        Entry ``[i, j]`` is 1 when rows ``i`` and ``j`` are the two images of one
        identity, 2 on the diagonal (an image with itself) and 0 otherwise.
        """
        y = np.tile(np.arange(self.bs), 2)
        y = np.equal(y, y[:, None]).astype(float)
        y[np.diag_indices(self.bs * 2)] = 2
        return y

    def get_iterator(self):
        """Yield ``(x, y)`` batches forever; see the class docstring for the layout."""
        # data chunker for sampling classes
        chunker = self.chunker(list(self.data.keys()), self.bs)
        y = self.create_static_label()

        for batch in chunker:
            # (class1_img1, class1_img2), (class2_img1, class2_img2)...
            all_paths = [next(self.data[cls]) for cls in batch]

            # Fresh buffer per batch: preprocess_input may work in place, and a
            # batch that was already yielded must not be overwritten by the next
            # one while a consumer (e.g. a prefetching pipeline) still holds it.
            x = np.zeros((self.bs * 2, 160, 160, 3), dtype=np.float32)
            for index, (path1, path2) in enumerate(all_paths):
                x[index] = self.get_image(path1)
                x[index + self.bs] = self.get_image(path2)

            yield preprocess_input(x), y

    @staticmethod
    def chunker(lst, bs):
        """Yield ``bs``-sized windows of ``lst`` forever, reshuffling before every pass.

        The caller's list is shuffled in place on the first pass. Each pass yields
        ``ceil(len(lst) / bs)`` windows by rolling the array by ``bs``, so every
        element appears at least once per pass (the last window wraps around).
        Windows hold fewer than ``bs`` items when ``len(lst) < bs``.
        """
        while True:
            np.random.shuffle(lst)
            for _ in np.arange(np.ceil(len(lst) / bs)):
                lst = np.roll(lst, bs)
                yield lst[:bs]

    def __len__(self):
        """Number of identities kept for this subset."""
        return len(self.data)

    def __iter__(self):
        return self

    def __next__(self):
        with self.lock:
            return next(self.it)

In [None]:
def plot_history(csv_path, save_path):
    """Plot each training metric against its validation counterpart and save one PNG per metric.

    Args:
        csv_path: the training log -- a path to a CSV written by
            ``tf.keras.callbacks.CSVLogger``, a ``pandas.DataFrame`` with the same
            columns, or a Keras ``History`` (any object with a ``.history`` dict).
        save_path: directory that receives ``<metric>.png`` files.
    """
    if isinstance(csv_path, pandas.DataFrame):
        contents = csv_path
    elif hasattr(csv_path, "history"):
        contents = pandas.DataFrame(csv_path.history)
    else:
        contents = pandas.read_csv(csv_path)

    if contents.empty:
        return

    epochs = np.arange(len(contents))
    # "epoch" is CSVLogger's index column; "val_*" columns are paired with their train key below.
    train_keys = [key for key in contents.keys()
                  if not key.startswith("val_") and key != "epoch"]

    for key in train_keys:
        train = contents[key]
        val_key = "val_" + key
        val = contents[val_key] if val_key in contents else None

        y_max = train.max() if val is None else np.maximum(train.max(), val.max())
        y_max = np.maximum(y_max, 1.1)

        fig, ax = plt.subplots()
        ax.grid(which='minor', alpha=0.2)
        ax.grid(which='major', alpha=0.5)

        # x ticks: about ten major ticks; the step stays >= 1 so runs shorter than 10 epochs work
        step = max(1, len(epochs) // 10)
        ax.set_xticks(np.arange(0.0, len(epochs) + 1, step), minor=False)
        ax.xaxis.set_minor_locator(AutoMinorLocator(5))

        # y ticks
        step = 0.1 if y_max <= 1.1 else 0.2
        ax.set_yticks(np.arange(0, y_max, step), minor=False)
        ax.yaxis.set_minor_locator(AutoMinorLocator(5))
        ax.set_ylim(0, y_max)

        ax.plot(epochs, train, color='g')
        legend = ['train']
        if val is not None:
            ax.plot(epochs, val, color='b')
            legend.append('test')

        ax.set_title('model {}'.format(key))
        ax.set_ylabel(key)
        ax.set_xlabel('epoch')
        ax.legend(legend, loc='upper left')
        fig.savefig(os.path.join(save_path, "{}.png".format(key)))
        plt.show()
        # one figure is created per metric; release it once shown and saved
        plt.close(fig)

In [None]:
def get_model_v2(weights=None, size=256, alpha=1.0):
    """Build a MobileNetV2 embedding network mapping 160x160x3 images to unit-length vectors.

    Args:
        weights: optional weights file loaded by layer name, skipping mismatched
            layers; any falsy value (None, "") skips loading.
        size: embedding dimension (units of the ReLU dense head).
        alpha: width multiplier forwarded to MobileNetV2.

    Returns:
        A Keras ``Model`` whose output is the L2-normalised ``size``-d embedding.
    """
    backbone = MobileNetV2(input_shape=(160, 160, 3),
                           alpha=alpha,
                           include_top=False,
                           pooling="max")

    # max-pooled backbone features -> ReLU projection -> L2 normalisation
    embedding = Dense(size, activation="relu")(backbone.output)
    embedding = Lambda(lambda v: tf.math.l2_normalize(v, axis=1), name="l2-norm")(embedding)

    embedder = Model(inputs=[backbone.input], outputs=[embedding])
    if not weights:
        return embedder

    embedder.load_weights(weights, by_name=True, skip_mismatch=True)
    return embedder

In [None]:
# data paths -- same Drive root ("MyDrive") as the archives and outputs used elsewhere in this notebook
identities = r"/content/drive/MyDrive/celeba/updated_identities.csv"
unzip_path = r"img_align_celeba"

bs_train = 100
bs_val = 100

tr_gen = ImageGen(bs_train, unzip_path, identities, subset="train")
val_gen = ImageGen(bs_val, unzip_path, identities, subset="test")

print("train has {} classes and validation has {} classes".format(len(tr_gen), len(val_gen)))

# Each batch draws bs distinct identities; with fewer, rows of every batch stay
# unfilled while the label matrix still assumes bs pairs.
assert len(tr_gen) >= bs_train, "fewer training identities than bs_train"
assert len(val_gen) >= bs_val, "fewer validation identities than bs_val"

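Save a sample training batch to `test.png` for a visual check: the left column shows the first image of each pair, the right column the second.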

In [None]:
def save_sample(x, path="test.png"):
    """Write a preprocessed pair batch as one PNG grid for visual inspection.

    Args:
        x: batch from ``ImageGen`` -- RGB images in MobileNetV2's [-1, 1] range,
            the first image of every pair in the first half and the second images
            in the second half.
        path: output file; defaults to ``"test.png"``.

    The first half is stacked vertically into the left column and the second half
    into the right column, so each row of the grid shows one pair.
    """
    half = len(x) // 2
    # Undo preprocessing ([-1, 1] -> [0, 255]); round rather than truncate so float
    # error cannot knock a pixel down by one, and clip before the uint8 cast.
    pixels = np.clip(np.rint((np.asarray(x) + 1) * 127.5), 0, 255).astype(np.uint8)
    left = np.vstack(list(pixels[:half]))
    right = np.vstack(list(pixels[half:]))
    # cv2.imwrite expects BGR while the batch is RGB: swap channels before saving.
    cv2.imwrite(path, cv2.cvtColor(np.hstack([left, right]), cv2.COLOR_RGB2BGR))


x, y = next(tr_gen)
save_sample(x)

In [None]:
K.clear_session()

# Optional warm start: set to a weights-file path to load matching layers by name;
# None (or "") builds the model without loading any extra weights.
warmed_model = None
model = get_model_v2(warmed_model, 256, alpha=0.75)

model.summary()

In [None]:
epochs = 750
optimizer = Adam(learning_rate=0.001)
save_path = r"/content/drive/MyDrive/"
# CSVLogger target; plotted after training.
history_csv = os.path.join(save_path, "history.csv")

# Keep only checkpoints that improve val_loss; epoch and val_loss are encoded in the filename.
mcp = tf.keras.callbacks.ModelCheckpoint(
    filepath=os.path.join(save_path, "{epoch:02d}_{val_loss:.3f}.h5"),
    monitor="val_loss",
    verbose=1,
    save_best_only=True,
    mode="auto")

# append=True: re-running this cell adds rows to the existing log, so plots cover every logged run.
csv_log = tf.keras.callbacks.CSVLogger(
    filename=history_csv,
    separator=",",
    append=True)

model.compile(optimizer=optimizer, loss=loss, metrics=[mean_pos, mean_neg, accuracy])
# one epoch = one pass over the identities (each step consumes bs identities)
history = model.fit(x=tr_gen,
                    steps_per_epoch=int(np.ceil(len(tr_gen) / bs_train)),
                    validation_data=val_gen,
                    validation_steps=int(np.ceil(len(val_gen) / bs_val)),
                    epochs=epochs, callbacks=[mcp, csv_log])

# plot_history reads the CSV log; passing the History object returned by fit() fails in read_csv.
plot_history(history_csv, save_path)
model.save(os.path.join(save_path, "final_model.h5"))

In [None]:
# Re-plot from the CSV log (CSVLogger appends, so it holds every logged run) and save another copy of the model.
plot_history("{}/history.csv".format(save_path), save_path)
model.save(os.path.join(save_path, "final_model_2.h5"))