## Setup

In [None]:
%pip install -q -U "tensorflow-text==2.8.*"

In [None]:
%pip install -q tf-models-official==2.7.0

In [None]:
# NOTE(review): unpinned — choose the tensorflow_addons release that matches the TF 2.8 installed above.
%pip install tensorflow_addons

In [None]:
%pip install -U --no-cache-dir gdown --pre

In [None]:
# Current gdown takes the Drive file ID positionally; the old `--id` flag is deprecated/removed.
!gdown --no-cookies 1f_BAk6dzyXW5MX4E8S0OMWV1k1mqef1X

In [None]:
# -o overwrites existing files without prompting, so re-running this cell never blocks on stdin.
!unzip -q -o demake_up_data.zip

## Import Modules and Libraries

In [None]:
import cv2
import os
import numpy as np
import matplotlib.pyplot as plt 

import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.models import Model
import tensorflow_addons as tfa

## Config Parameters

In [None]:
BATCH_SIZE = 64
# Network input resolution. The misspelled names (HEGIHT / WIDHT) are kept
# because later cells reference them.
IMG_HEGIHT  = 224
IMG_WIDHT   = 224
# Super-resolution target size: 4x the input resolution.
SR_IMG_HEGIHT = IMG_HEGIHT*4
SR_IMG_WIDHT = IMG_WIDHT*4

IMG_CHANNEL = 3
# tf.data shuffle buffer size, in elements.
BUFFER_SIZE = BATCH_SIZE*10

IMG_PATH = "./demake_up_data"
EPOCHS=300

# Seed NumPy (file-list shuffling) AND TensorFlow (weight initialisation,
# random flips, tf.data shuffling); seeding only NumPy leaves TF runs
# non-repeatable.
SEED = 25
np.random.seed(SEED)
tf.random.set_seed(SEED)

## Load Data

In [None]:
def load(image_file):
    """Read and decode one (makeup, non-makeup, high-res) image triple.

    Args:
        image_file: length-3 string tensor (or list of 3 paths) ordered as
            [makeup path, non-makeup path, high-res path].

    Returns:
        Tuple of three float32 tensors of shape (H, W, IMG_CHANNEL) holding
        raw pixel values in [0, 255] (normalisation happens later).
    """
    def _read(path):
        # decode_image detects the real file format. decode_jpeg rejects PNG
        # data in recent TF, and the dataset files use a .png extension.
        # expand_animations=False guarantees a rank-3 (H, W, C) result.
        img = tf.io.decode_image(tf.io.read_file(path),
                                 channels=IMG_CHANNEL,
                                 expand_animations=False)
        # Cast to float32 for the arithmetic done downstream.
        return tf.cast(img, tf.float32)

    makeup_img_file, non_img_file, hr_img_file = tf.split(image_file, 3)

    makeup_img = _read(makeup_img_file[0])
    non_img = _read(non_img_file[0])
    hr_img = _read(hr_img_file[0])

    return makeup_img, non_img, hr_img

In [None]:

# Sanity-check the loader on the first training triple.
makeup_img, non_img, hr_img = load([IMG_PATH + '/train/makeup/0.png',
                                    IMG_PATH + '/train/non-makeup/0.png',
                                    IMG_PATH + '/train/high_r/0.png'])
print(makeup_img.shape)
print(non_img.shape)
print(hr_img.shape)

# load() returns float32 values in [0, 255]; matplotlib expects floats in
# [0, 1], so rescale before showing.
for title, img in [('makeup', makeup_img),
                   ('non-makeup', non_img),
                   ('high-res', hr_img)]:
    plt.figure()
    plt.title(title)
    plt.imshow(img / 255.0)
plt.show()

In [None]:
@tf.function()
def random_flip(makeup_img, non_img, hr_img):
    """Mirror all three images left-right together with probability 0.5.

    A single random draw decides for the whole triple, so the input and both
    targets stay spatially aligned.
    """
    def _mirrored():
        return (tf.image.flip_left_right(makeup_img),
                tf.image.flip_left_right(non_img),
                tf.image.flip_left_right(hr_img))

    def _unchanged():
        return (makeup_img, non_img, hr_img)

    return tf.cond(tf.random.uniform(()) > 0.5, _mirrored, _unchanged)


def processing_image(makeup_img, non_img, hr_img):
    """Divide each of the three images by 255.0 (8-bit range -> [0, 1])."""
    scale = 255.0
    return makeup_img / scale, non_img / scale, hr_img / scale


def load_image_train(image_file):
    """Training pipeline for one path triple: decode, random flip, scale to [0, 1]."""
    triple = load(image_file)
    triple = random_flip(*triple)
    return processing_image(*triple)


def load_image_val(image_file):
    """Evaluation pipeline for one path triple: decode and scale (no augmentation)."""
    return processing_image(*load(image_file))

In [None]:
def prep_data(path):
    """List (makeup, non-makeup, high-res) path triples for one data split.

    Args:
        path: the split's ``makeup`` directory, e.g. ``<root>/train/makeup/``.
            Its sibling directories ``non-makeup`` and ``high_r`` are expected
            to contain files with the same names.

    Returns:
        List of ``[makeup_path, non_makeup_path, high_res_path]`` lists,
        ordered by file name so a subsequent seeded shuffle is reproducible
        (``os.listdir`` order is arbitrary).
    """
    # Derive sibling directories from the parent directory instead of using
    # str.replace('makeup', ...), which would also rewrite any 'makeup'
    # substring in the root path or in file names.
    split_dir = os.path.dirname(path.rstrip(os.sep + (os.altsep or '')))
    non_dir = os.path.join(split_dir, 'non-makeup')
    hr_dir = os.path.join(split_dir, 'high_r')

    return [[os.path.join(path, f), os.path.join(non_dir, f), os.path.join(hr_dir, f)]
            for f in sorted(os.listdir(path))]

# Path triples for each split, built first and then shuffled in place with
# the NumPy RNG seeded in the config cell (same draw order: train, val, test).
train_data_list, val_data_list, test_data_list = (
    prep_data(IMG_PATH + '/train/makeup/'),
    prep_data(IMG_PATH + '/val/makeup/'),
    prep_data(IMG_PATH + '/test/makeup/'),
)

for split_list in (train_data_list, val_data_list, test_data_list):
    np.random.shuffle(split_list)

In [None]:
train_dataset = tf.data.Dataset.from_tensor_slices(train_data_list)
# Shuffle the cheap path triples *before* decoding. Shuffling after map would
# keep BUFFER_SIZE fully decoded float32 image triples (including the
# high-res targets) in memory. reshuffle_each_iteration defaults to True, so
# each epoch still sees a new order.
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.map(load_image_train,
                                  num_parallel_calls=tf.data.AUTOTUNE)
train_dataset = train_dataset.batch(BATCH_SIZE)
# Overlap input preparation with the training step.
train_dataset = train_dataset.prefetch(tf.data.AUTOTUNE)


In [None]:
val_dataset = tf.data.Dataset.from_tensor_slices(val_data_list)
# Parallel decoding; tf.data keeps element order deterministic by default.
val_dataset = val_dataset.map(load_image_val,
                              num_parallel_calls=tf.data.AUTOTUNE)
val_dataset = val_dataset.batch(BATCH_SIZE)
val_dataset = val_dataset.prefetch(tf.data.AUTOTUNE)

In [None]:
test_dataset = tf.data.Dataset.from_tensor_slices(test_data_list)
# Parallel decoding; tf.data keeps element order deterministic by default.
test_dataset = test_dataset.map(load_image_val,
                                num_parallel_calls=tf.data.AUTOTUNE)
test_dataset = test_dataset.batch(BATCH_SIZE)
test_dataset = test_dataset.prefetch(tf.data.AUTOTUNE)

## Build Model 

In [None]:
class BuildResNet34():
    """Functional-API builder for a ResNet-34-style network.

    Uses basic residual blocks (two 3x3 convs) in stages of 3, 4, 6 and 3
    blocks, with InstanceNormalization and GELU in place of the usual
    BatchNorm and ReLU.

    NOTE(review): BuildRes34Unet taps this model's layers by integer index
    (``backbone.get_layer(index=...)``). Adding, removing or reordering layers
    here silently shifts which tensors it uses as skip connections.
    """
    def __init__(self):
        # Stateless builder; nothing to initialise.
        pass

    def identity_block(self, inputs, filters):
        """Residual basic block that preserves spatial size and channel count.

        Args:
            inputs: feature tensor whose channel count must equal ``filters``,
                otherwise the residual ``add`` is shape-incompatible.
            filters: output channels of both 3x3 convolutions.

        Returns:
            GELU(inputs + IN(conv(GELU(IN(conv(inputs)))))).
        """
        # No conv bias: the following InstanceNormalization subtracts the
        # per-channel mean, which would cancel it.
        x = layers.Conv2D(filters=filters, kernel_size=(3,3), padding='same', use_bias=False)(inputs)
        x = tfa.layers.InstanceNormalization()(x)
        x = layers.Activation('gelu')(x)
        
        x = layers.Conv2D(filters=filters, kernel_size=(3,3), padding='same', use_bias=False)(x)
        x = tfa.layers.InstanceNormalization()(x)
        
        # Identity shortcut, then the post-addition activation.
        skip_connection = layers.add([inputs, x])     
        x = layers.Activation('gelu')(skip_connection)
        
        return x
    
    
    def projection_block(self, inputs, filters, strides=2):
        """Residual basic block that changes resolution and/or channel count.

        The main path's first conv and the 1x1 shortcut conv both use
        ``strides``, so with the default of 2 the spatial size is halved
        ('same' padding) and the channel count becomes ``filters``.

        Args:
            inputs: feature tensor with any channel count.
            filters: output channels.
            strides: stride of the first 3x3 conv and the shortcut conv.

        Returns:
            GELU(IN(conv1x1(inputs)) + main_path(inputs)).
        """
        x = layers.Conv2D(filters=filters, kernel_size=(3,3), padding='same', strides=strides, use_bias=False)(inputs)
        x = tfa.layers.InstanceNormalization()(x)
        x = layers.Activation('gelu')(x)
        
        x = layers.Conv2D(filters=filters, kernel_size=(3,3), padding='same', use_bias=False)(x)
        x = tfa.layers.InstanceNormalization()(x)
        
        # Projection shortcut so the shortcut matches the main path's shape.
        shortcut = layers.Conv2D(filters=filters, kernel_size=(1,1), padding='same', 
                                 strides=strides, use_bias=False)(inputs)
        shortcut = tfa.layers.InstanceNormalization()(shortcut)
        skip_connection = layers.add([shortcut, x])        
        x = layers.Activation('gelu')(skip_connection)
        
        return x
    
    def build_model(self, classes, inputs):        
        """Build the full classifier on top of ``inputs``.

        Args:
            classes: number of units in the final softmax layer.
            inputs: Keras input tensor (BuildRes34Unet passes a
                ``layers.Input``).

        Returns:
            ``Model(inputs, softmax_probabilities)``. BuildRes34Unet uses only
            intermediate layers of this model; the pooling/Dense head is not
            part of the U-Net it builds.
        """
        # conv1 (stem): 7x7 stride-2 conv, then 3x3 stride-2 max-pool.
        # For a 224x224 input this gives 112x112, then 56x56.
        x = layers.ZeroPadding2D(padding=((3, 3)))(inputs)
        x = layers.Conv2D(64, 7, strides=2, use_bias=False)(x)
        x = tfa.layers.InstanceNormalization()(x)
        x = layers.Activation('gelu')(x)
        x = layers.ZeroPadding2D(padding=((1, 1)))(x)
        x = layers.MaxPooling2D(3, strides=2)(x)
        
        # conv2_X: 3 blocks, 64 channels, no downsampling
        x = self.identity_block(x, filters=64)
        x = self.identity_block(x, filters=64)
        x = self.identity_block(x, filters=64)
        
        # conv3_x: 4 blocks, 128 channels, first block halves resolution
        x = self.projection_block(x, filters=128)
        x = self.identity_block(x, filters=128)
        x = self.identity_block(x, filters=128)
        x = self.identity_block(x, filters=128)
        
        # conv4_x: 6 blocks, 256 channels, first block halves resolution
        x = self.projection_block(x, filters=256)
        x = self.identity_block(x, filters=256)
        x = self.identity_block(x, filters=256)
        x = self.identity_block(x, filters=256)
        x = self.identity_block(x, filters=256)
        x = self.identity_block(x, filters=256)
        
        # conv5_x: 3 blocks, 512 channels, first block halves resolution
        x = self.projection_block(x, filters=512)
        x = self.identity_block(x, filters=512)
        x = self.identity_block(x, filters=512)
        
        # Classification head
        x = layers.GlobalAveragePooling2D()(x)
        outputs = layers.Dense(classes, activation='softmax')(x)
        
        model = Model(inputs=inputs, outputs=outputs)

        return model
        

In [None]:
class BuildRes34Unet():
    """U-Net with a ResNet-34-style encoder and two output heads.

    Head 1 predicts the non-makeup image at input resolution. Head 2 predicts
    a 4x super-resolved image; it uses two stride-2 transposed convs on the
    input-resolution decoder features.
    """

    def __init__(self):
        # Indices into the backbone's ``layers`` list for the skip-connection
        # taps (shallowest first) and for the bottleneck ("bridge").
        # Presumably these are the GELU outputs of the stem, conv2_x, conv3_x
        # and conv4_x, plus the conv5_x output as the bridge. They depend on
        # BuildResNet34's exact layer construction order.
        # TODO confirm against backbone.summary().
        self.encoder_blocks_id = [4, 27, 57, 101]
        self.bridge_block_id = 124

    def conv_block(self, inputs, num_filters):
        """3x3 conv ('same' padding) -> InstanceNorm -> GELU."""
        out = layers.Conv2D(filters=num_filters, kernel_size=(3,3), padding="same")(inputs)
        out = tfa.layers.InstanceNormalization()(out)
        return layers.Activation('gelu')(out)

    def upsample_concate_block(self, inputs, skip_connection, num_filters):
        """Upsample 2x with a transposed conv, concat the skip tensor, then conv_block."""
        upsampled = layers.Conv2DTranspose(filters=num_filters, kernel_size=(2,2),
                                           strides=2, padding='same')(inputs)
        merged = layers.Concatenate()([skip_connection, upsampled])
        return self.conv_block(merged, num_filters)

    def build_model(self, input_shape):
        """Build the two-output model.

        Args:
            input_shape: (height, width, channels) of the makeup input.

        Returns:
            ``Model(inputs, [non_makeup_pred, sr_pred])``. Both heads end in a
            sigmoid over 3 channels.
        """
        inputs = layers.Input(shape=input_shape)

        # Encoder: reuse intermediate tensors of a freshly built ResNet-34.
        # The classification head of that model is not part of this model.
        backbone = BuildResNet34().build_model(classes=1000, inputs=inputs)
        skips = [backbone.get_layer(index=idx).output for idx in self.encoder_blocks_id]
        decoded = backbone.get_layer(index=self.bridge_block_id).output

        # Decoder: consume the skips deepest-first with 512/256/128/64 filters.
        for skip, filters in zip(reversed(skips), (512, 256, 128, 64)):
            decoded = self.upsample_concate_block(inputs=decoded, skip_connection=skip,
                                                  num_filters=filters)

        # Last upsampling step back to input resolution, fused with a shallow
        # feature map computed directly from the input.
        first_feature = layers.Conv2D(filters=64, kernel_size=(3,3), padding='same')(inputs)
        final_feature = self.upsample_concate_block(inputs=decoded, skip_connection=first_feature,
                                                    num_filters=64)

        # Head 1: de-makeup image at input resolution.
        non_feature = final_feature
        for filters in (32, 16):
            non_feature = self.conv_block(non_feature, num_filters=filters)
        outputs1 = layers.Conv2D(filters=3, kernel_size=(1,1), activation='sigmoid')(non_feature)

        # Head 2: two stride-2 transposed convs -> 4x super-resolution.
        sr_feature = final_feature
        for filters in (32, 16):
            sr_feature = layers.Conv2DTranspose(filters=filters, kernel_size=(2,2),
                                                strides=2, padding='same')(sr_feature)
        outputs2 = layers.Conv2D(filters=3, kernel_size=(1,1), activation='sigmoid')(sr_feature)

        return Model(inputs=inputs, outputs=[outputs1, outputs2])

In [None]:
model = BuildRes34Unet()
res34Unet = model.build_model(input_shape=(IMG_HEGIHT, IMG_WIDHT, IMG_CHANNEL))

In [None]:
res34Unet.summary()

## Configure Training (custom training loop, no `model.compile`)

### Show Generated Images and Evaluation Function

In [None]:
def evaluate(model, epoch, dataset):
    """Compute the mean PSNR of both model heads over a dataset and print it.

    Args:
        model: two-output model returning (de-makeup prediction,
            super-resolution prediction) for a batch of makeup images.
        epoch: epoch number; only used in the printed log line.
        dataset: iterable of (makeup_img, non_img, hr_img) batches. PSNR uses
            max_val=1.0, which assumes pixel values in [0, 1].

    Returns:
        (psnr_non_mean, psnr_sr_mean) as scalar tensors: the per-image PSNR
        averaged over every image in ``dataset``.

    Raises:
        ValueError: if ``dataset`` yields no batches.
    """
    psnr_non_sum = 0.0
    psnr_sr_sum = 0.0
    num_images = 0
    count = 0  # number of batches, reported in the log line
    for makeup_img, non_img, hr_img in dataset:
        pred_non, pred_sr = model([makeup_img], training=False)

        # tf.image.psnr returns one value per image. Summing per image,
        # instead of averaging per-batch means, stops a smaller final batch
        # from being over-weighted.
        psnr_non_sum += tf.reduce_sum(tf.image.psnr(pred_non, non_img, max_val=1.0))
        psnr_sr_sum += tf.reduce_sum(tf.image.psnr(pred_sr, hr_img, max_val=1.0))
        num_images += int(tf.shape(makeup_img)[0])
        count += 1

    if num_images == 0:
        raise ValueError('evaluate() received an empty dataset')

    psnr_non_mean = psnr_non_sum / num_images
    psnr_sr_mean = psnr_sr_sum / num_images
    print('-------- psnr_non: ', psnr_non_mean.numpy(),  'psnr_sr: ', psnr_sr_mean.numpy(), '   ----- epoch: ', epoch, '  count: ', count)

    return psnr_non_mean, psnr_sr_mean
    

def generate_images(model, makeup_img, non_img, hr_img):
    """Plot the model's predictions for the first example of a batch.

    Shows two figures: (input, non-makeup target, de-makeup prediction) and
    (high-res target, super-resolution prediction).
    """
    pred_non, pred_sr = model([makeup_img], training=False)

    rows = [
        (['Input', 'Non-makeup', 'Predicted'], [makeup_img[0], non_img[0], pred_non[0]]),
        (['Target', 'Pred_SR'], [hr_img[0], pred_sr[0]]),
    ]
    for titles, images in rows:
        # One explicitly sized figure per row, so neither row ends up on an
        # oversized or default-sized implicit canvas.
        fig, axes = plt.subplots(1, len(images), figsize=(5 * len(images), 5))
        for ax, title, image in zip(axes, titles, images):
            ax.set_title(title)
            ax.imshow(image)
            ax.axis('off')
        plt.show()

### Optimizer

In [None]:
from official.nlp import optimization  # provides create_optimizer (AdamW + LR schedule)

# One optimizer step per training batch, for EPOCHS epochs.
steps_per_epoch = tf.data.experimental.cardinality(train_dataset).numpy()
num_train_steps = EPOCHS * steps_per_epoch
# Warm-up covers the first 10% of all steps; the schedule itself is defined
# by official.nlp.optimization.create_optimizer.
num_warmup_steps = int(num_train_steps * 0.1)

init_lr = 1e-2
generator_optimizer = optimization.create_optimizer(
    init_lr=init_lr,
    num_train_steps=num_train_steps,
    num_warmup_steps=num_warmup_steps,
    optimizer_type='adamw',
)

## Training

In [None]:

@tf.function
def train_step(model, makeup_img, non_img, hr_img):
    """Run one optimisation step on a batch and return its loss.

    loss = 2 * (100 * MSE(de-makeup)) + 100 * MSE(super-resolution). Gradients
    are applied with the global ``generator_optimizer``.
    """
    with tf.GradientTape() as tape:
        pred_non, pred_sr = model([makeup_img], training=True)
        scaled_mse_non = 100 * tf.reduce_mean(tf.square(pred_non - non_img))
        scaled_mse_sr = 100 * tf.reduce_mean(tf.square(pred_sr - hr_img))
        loss = 2 * scaled_mse_non + scaled_mse_sr

    variables = model.trainable_variables
    grads = tape.gradient(loss, variables)
    generator_optimizer.apply_gradients(zip(grads, variables))

    return loss

    
    
def fit(model, train_ds, epochs, val_ds):
    """Train ``model`` for ``epochs`` epochs, validating after each one.

    After every epoch, the mean training loss is printed and ``evaluate`` runs
    on ``val_ds``. When the mean of the two validation PSNRs improves, sample
    predictions for the first validation batch are plotted.

    Args:
        model: two-output model (de-makeup, super-resolution).
        train_ds: batched dataset of (makeup_img, non_img, hr_img).
        epochs: number of passes over ``train_ds``.
        val_ds: batched validation dataset with the same structure.
    """
    best_psnr = 0.0
    for epoch in range(epochs):
        # Train. Both accumulators are reset every epoch, so the printed value
        # is this epoch's mean loss. Without the reset, the epoch sum would be
        # divided by the number of steps since training began.
        total_loss = 0.0
        num_steps = 0
        for makeup_img, non_img, hr_img in train_ds:
            total_loss += train_step(model, makeup_img, non_img, hr_img)
            num_steps += 1
        mean_loss = total_loss / max(num_steps, 1)
        print('epoch: {}   loss: {}'.format(epoch, float(mean_loss)))

        # Model-selection score: mean of (de-makeup PSNR, SR PSNR).
        psnr = tf.reduce_mean(evaluate(model, epoch, val_ds))
        if best_psnr < psnr:
            best_psnr = psnr

            for makeup_img, non_img, hr_img in val_ds.take(1):
                generate_images(model, makeup_img, non_img, hr_img)


In [None]:
from time import time

In [None]:
# Train end to end and report total wall-clock time in seconds.
a = time()
fit(model=res34Unet, train_ds=train_dataset, epochs=EPOCHS, val_ds=val_dataset)
print(time() - a)