<a href="https://colab.research.google.com/github/mohamedhossny4654/ImageColorizationUsingGAN/blob/main/TrainingModel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Download the COCO val2017 images into DATASET/cocostuf/ (skipped when
# DATASET/ already exists). Paths are relative: run this from /content, where
# the configuration cell looks for DATASET/.
# Steps are chained with && and DATASET/ is created only after the download
# and unzip succeeded, so a failed attempt is retried on the next run instead
# of leaving a half-built DATASET/ that makes this cell skip forever.
! if [ ! -d DATASET/ ] ; \
  then wget http://images.cocodataset.org/zips/val2017.zip && \
    unzip -q val2017.zip && \
    rm val2017.zip && \
    mkdir -p DATASET/cocostuf/test && \
    mv val2017 DATASET/cocostuf/ ; \
fi

In [None]:
#data class
import numpy as np
import cv2
import os


class DATA():
    """Sequential, wrap-around batch loader over the images of one directory.

    Reads the module-level configuration constants DATA_DIR, BATCH_SIZE and
    IMAGE_SIZE (defined in the configurations cell).

    Attributes:
        dir_path: directory the images are read from (inside DATA_DIR).
        filelist: file names in ``dir_path``, sorted so that batch order is
            reproducible between runs.
        batch_size: number of images returned per ``generate_batch`` call.
        size: number of files in ``dir_path``.
        data_index: index of the next file to read; wraps around at ``size``.
    """

    def __init__(self, dirname):
        # os.path.join() discards DATA_DIR when ``dirname`` is absolute (e.g.
        # "/val2017"), so treat ``dirname`` as relative to DATA_DIR.
        self.dir_path = os.path.join(DATA_DIR, dirname.lstrip("/" + os.sep))
        self.filelist = sorted(os.listdir(self.dir_path))
        self.batch_size = BATCH_SIZE
        self.size = len(self.filelist)
        self.data_index = 0

    def read_img(self, filename):
        """Load one image and split it into model inputs and targets.

        Args:
            filename: path of the image file.

        Returns:
            Tuple of uint8 arrays:
            (L channel resized to (IMAGE_SIZE, IMAGE_SIZE, 1),
             ab channels resized to (IMAGE_SIZE, IMAGE_SIZE, 2),
             the original BGR image at full resolution,
             the L channel of the original image at full resolution, (H, W)).

        Raises:
            ValueError: if OpenCV cannot decode ``filename``.
        """
        # IMREAD_COLOR always yields 8-bit, 3-channel BGR. The previous flag
        # (3) also set IMREAD_ANYDEPTH, so 16-bit files would load as uint16
        # and break the /255 scaling done in generate_batch.
        img = cv2.imread(filename, cv2.IMREAD_COLOR)
        if img is None:
            raise ValueError("Could not read image file: " + filename)
        labimg = cv2.cvtColor(cv2.resize(img, (IMAGE_SIZE, IMAGE_SIZE)), cv2.COLOR_BGR2Lab)
        labimg_ori = cv2.cvtColor(img, cv2.COLOR_BGR2Lab)
        return (np.reshape(labimg[:, :, 0], (IMAGE_SIZE, IMAGE_SIZE, 1)),
                labimg[:, :, 1:],
                img,
                labimg_ori[:, :, 0])

    def generate_batch(self):
        """Return the next ``batch_size`` images, wrapping around at the end.

        Returns:
            (batch, labels, filelist, originalList, labimg_oritList):
            L inputs scaled to [0, 1], ab targets scaled to [0, 1], the file
            names, the original BGR images (uint8), and the full-resolution
            L channels scaled to [0, 1].

        Raises:
            ValueError: if the directory contains no files.
        """
        if self.size == 0:
            raise ValueError("No images found in " + self.dir_path)
        batch = []
        labels = []
        filelist = []
        labimg_oritList = []
        originalList = []
        for _ in range(self.batch_size):
            name = self.filelist[self.data_index]
            filelist.append(name)
            greyimg, colorimg, original, labimg_ori = self.read_img(os.path.join(self.dir_path, name))
            batch.append(greyimg)
            labels.append(colorimg)
            originalList.append(original)
            labimg_oritList.append(labimg_ori)
            self.data_index = (self.data_index + 1) % self.size
        batch = np.asarray(batch) / 255  # values between 0 and 1
        labels = np.asarray(labels) / 255  # values between 0 and 1
        originalList = np.asarray(originalList)
        labimg_oritList = np.asarray(labimg_oritList) / 255  # values between 0 and 1
        return batch, labels, filelist, originalList, labimg_oritList

In [None]:
# Stop eager execution for training: gradient_penalty_loss computes its
# gradients with K.gradients, which needs graph mode.
# Imported from tensorflow.compat.v1, the public location of this switch;
# tensorflow.python.* is private API and can move between releases.
from tensorflow.compat.v1 import disable_eager_execution
disable_eager_execution()

In [None]:
#configurations
import os

# DIRECTORY INFORMATION
DATASET = "cocostuf"  # UPDATE: sub-folder of DATASET/ that holds the images
TEST_NAME = "FirstTest"  # sub-folder name used for logs, models and results
ROOT_DIR = os.path.abspath('/content')
DATA_DIR = os.path.join(ROOT_DIR, 'DATASET/' + DATASET + '/')
OUT_DIR = os.path.join(ROOT_DIR, 'RESULT/' + DATASET + '/')
MODEL_DIR = os.path.join(ROOT_DIR, 'MODEL/' + DATASET + '/')
LOG_DIR = os.path.join(ROOT_DIR, 'LOGS/' + DATASET + '/')

# Relative to DATA_DIR. No leading "/": os.path.join() treats an absolute
# component as a new root and would silently discard DATA_DIR.
TRAIN_DIR = "val2017"  # UPDATE
TEST_DIR = "test"  # UPDATE

# DATA INFORMATION
IMAGE_SIZE = 224  # images are resized to IMAGE_SIZE x IMAGE_SIZE
BATCH_SIZE = 1

# TRAINING INFORMATION
PRETRAINED = "modelPretrained.h5"  # UPDATE
NUM_EPOCHS = 5

In [None]:

import tensorflow as tf
import numpy as np
import cv2
import datetime
from functools import partial

from tensorflow import keras
from tensorflow.keras import applications
from tensorflow.keras.callbacks import TensorBoard
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras import backend as K
from tensorflow.keras.models import load_model, model_from_json, Model

import tensorflow as tf

# Weight of the gradient-penalty term; passed to gradient_penalty_loss as
# ``gradient_penalty_weight`` when the discriminator training model is built.
GRADIENT_PENALTY_WEIGHT = 10


In [None]:
#initialize helper functions
def deprocess(imgs):
    """Convert values scaled to [0, 1] back to 8-bit pixel values.

    Values are multiplied by 255, clipped to [0, 255] and truncated (not
    rounded) to ``np.uint8``. The input is not modified.

    Args:
        imgs: array-like of numbers, nominally in [0, 1].

    Returns:
        ``np.ndarray`` of dtype ``uint8`` with the same shape as ``imgs``.
    """
    # np.asarray so that list input is scaled element-wise instead of being
    # repeated 255 times by list multiplication.
    scaled = np.asarray(imgs) * 255
    return np.clip(scaled, 0, 255).astype(np.uint8)


def reconstruct(batchX, predictedY, filelist):
    """Merge L and ab channels into a BGR image and save it as a JPEG.

    The image is written to ``OUT_DIR/TEST_NAME/<filelist>_reconstructed.jpg``
    (the directory is created if needed).

    Args:
        batchX: L channel, an (H, W, 1) array (presumably uint8, e.g. the
            output of ``deprocess`` - TODO confirm for other callers).
        predictedY: ab channels, an (H, W, 2) array of the same dtype.
        filelist: base name (a string) of the output file.

    Returns:
        The reconstructed BGR image as returned by ``cv2.cvtColor``.
    """
    result = np.concatenate((batchX, predictedY), axis=2)
    result = cv2.cvtColor(result, cv2.COLOR_Lab2BGR)
    # OUT_DIR / TEST_NAME are globals from the configuration cell; there is no
    # ``config`` module in this notebook (the old ``config.OUT_DIR`` raised
    # NameError).
    save_results_path = os.path.join(OUT_DIR, TEST_NAME)
    os.makedirs(save_results_path, exist_ok=True)
    save_path = os.path.join(save_results_path, filelist + "_reconstructed.jpg")
    # cv2.imwrite signals failure only through its return value; report it
    # instead of silently losing the sample (without aborting training).
    if not cv2.imwrite(save_path, result):
        print("Warning: could not write " + save_path)
    return result

def reconstruct_no(batchX, predictedY):
    """Merge L and ab channels into a BGR image without saving it.

    ``batchX`` (L) and ``predictedY`` (ab) are stacked along the third axis
    and converted from L*a*b* to BGR with OpenCV.

    Returns:
        The BGR image produced by ``cv2.cvtColor``.
    """
    lab_image = np.concatenate([batchX, predictedY], axis=2)
    return cv2.cvtColor(lab_image, cv2.COLOR_Lab2BGR)

# Event writers reused across write_log calls, keyed by (log_dir, eager).
_LOG_WRITERS = {}


def _get_log_writer(log_dir, eager):
    """Return a cached summary writer for ``log_dir`` (TF2 API when eager)."""
    key = (log_dir, eager)
    writer = _LOG_WRITERS.get(key)
    if writer is None:
        if eager:
            writer = tf.summary.create_file_writer(log_dir)
        else:
            writer = tf.compat.v1.summary.FileWriter(log_dir)
        _LOG_WRITERS[key] = writer
    return writer


def write_log(callback, names, logs, batch_no):
    """Record one TensorBoard scalar per (name, value) pair at step ``batch_no``.

    Events go to ``callback.log_dir`` (the directory the TensorBoard callback
    was created with); if the callback has no ``log_dir``, "/LOGS/mylogs" is
    used. One writer per directory is created and then reused.

    Args:
        callback: object whose ``log_dir`` attribute selects the output
            directory (a ``TensorBoard`` callback in this notebook).
        names: scalar tags, one per value in ``logs``.
        logs: loss values, e.g. the list returned by ``train_on_batch``;
            pairs are formed with ``zip``, so surplus entries are ignored.
        batch_no: global step used as the x-axis value.
    """
    log_dir = getattr(callback, "log_dir", None) or "/LOGS/mylogs"
    pairs = [(name, float(value)) for name, value in zip(names, logs)]
    if tf.executing_eagerly():
        writer = _get_log_writer(log_dir, eager=True)
        with writer.as_default():
            for name, value in pairs:
                tf.summary.scalar(name, value, step=batch_no)
    else:
        # In graph mode (this notebook disables eager execution)
        # tf.summary.scalar only builds ops that would have to be run in a
        # session; calling it every batch just grows the graph. Write the
        # Summary protos directly instead.
        writer = _get_log_writer(log_dir, eager=False)
        summary = tf.compat.v1.Summary(value=[
            tf.compat.v1.Summary.Value(tag=name, simple_value=value)
            for name, value in pairs
        ])
        writer.add_summary(summary, batch_no)
    writer.flush()


def wasserstein_loss(y_true, y_pred):
    """Wasserstein loss term: the mean critic score in ``y_pred``.

    ``y_true`` is deliberately ignored. The sign of each term comes from the
    ``loss_weights`` given to ``compile`` (e.g. -1.0 for the real-sample
    output and +1.0 for the generated-sample output of the discriminator
    model), not from +1/-1 target labels.
    """
    del y_true  # unused; the sign is supplied by loss_weights
    return tf.math.reduce_mean(y_pred)


def gradient_penalty_loss(y_true, y_pred, averaged_samples,
                          gradient_penalty_weight):
    """Gradient-penalty term of the WGAN-GP critic loss.

    Returns ``mean(weight * (1 - ||g||_2) ** 2)``, where ``g`` is the gradient
    of the critic output ``y_pred`` with respect to ``averaged_samples``. The
    L2 norm is taken over every axis except the first (batch) axis.

    ``y_true`` is unused; Keras passes it because this is used as a loss.
    Relies on ``K.gradients``, so it needs graph mode (eager disabled).
    """
    grads = K.gradients(y_pred, averaged_samples)[0]
    per_sample_axes = np.arange(1, len(grads.shape))
    l2_norm = K.sqrt(K.sum(K.square(grads), axis=per_sample_axes))
    penalty = gradient_penalty_weight * K.square(1 - l2_norm)
    return K.mean(penalty)


class RandomWeightedAverage(tf.keras.layers.Layer):
    """Random per-sample interpolation between two batches (WGAN-GP).

    For every sample draws ``alpha ~ U[0, 1)`` and returns
    ``alpha * inputs[0] + (1 - alpha) * inputs[1]``. The gradient penalty is
    evaluated at these interpolated points. Expects 4-D inputs
    (batch, H, W, C), matching the (batch, 1, 1, 1) shape of ``alpha``.
    """

    def __init__(self, **kwargs):
        # Forward standard Layer kwargs (name, dtype, ...) so the layer can be
        # named and re-created from its config.
        super().__init__(**kwargs)

    def call(self, inputs, **kwargs):
        real, fake = inputs[0], inputs[1]
        # Uniform, not normal: the interpolated points must lie on the segment
        # between the two samples, so alpha has to be in [0, 1]. The batch
        # size is read from the tensor rather than the global BATCH_SIZE.
        alpha = tf.random.uniform((tf.shape(real)[0], 1, 1, 1), dtype=real.dtype)
        return (alpha * real) + ((1 - alpha) * fake)

    def compute_output_shape(self, input_shape):
        return input_shape[0]

    

In [None]:
#colorization model
class MODEL():
    """Colorization GAN: a VGG16-based generator plus a convolutional critic.

    ``__init__`` builds and compiles:
      * ``discriminator`` / ``colorizationModel``: the plain sub-models, used
        for prediction and saved after every epoch;
      * ``discriminator_model``: trains the critic on real ab, generated ab
        and interpolated ab (gradient penalty) while the generator is frozen;
      * ``combined``: trains the generator with an MSE loss on ab, a KL loss
        of its class vector against VGG16 predictions, and the adversarial
        term, while the critic is frozen.

    Uses the module-level configuration (IMAGE_SIZE, BATCH_SIZE, NUM_EPOCHS,
    LOG_DIR, MODEL_DIR, TEST_NAME) and the helper functions defined earlier.
    """

    def __init__(self):

        self.img_shape_1 = (IMAGE_SIZE, IMAGE_SIZE, 1)
        self.img_shape_2 = (IMAGE_SIZE, IMAGE_SIZE, 2)
        self.img_shape_3 = (IMAGE_SIZE, IMAGE_SIZE, 3)

        # A separate Adam per compiled model so the critic and generator
        # updates do not share moment estimates or the iteration counter.
        def make_optimizer():
            return Adam(0.00002, 0.5)

        # The built model replaces (shadows) the ``discriminator`` method.
        self.discriminator = self.discriminator()
        self.discriminator.compile(loss=wasserstein_loss,
                                   optimizer=make_optimizer())

        self.colorizationModel = self.colorization_model()
        self.colorizationModel.compile(loss=['mse', 'kld'],
                                       optimizer=make_optimizer())

        img_L_3 = Input(shape=self.img_shape_3)
        img_L = Input(shape=self.img_shape_1)
        img_ab_real = Input(shape=self.img_shape_2)

        # ---- Critic training graph: generator frozen, critic trainable ----
        # Keras collects a model's trainable weights when compile() is
        # called, so the flags are set right before each compile().
        self.colorizationModel.trainable = False
        self.discriminator.trainable = True
        predAB, _ = self.colorizationModel(img_L_3)
        discPredAB = self.discriminator([predAB, img_L])
        discriminator_output_from_real_samples = self.discriminator([img_ab_real, img_L])

        averaged_samples = RandomWeightedAverage()([img_ab_real, predAB])
        averaged_samples_out = self.discriminator([averaged_samples, img_L])
        partial_gp_loss = partial(gradient_penalty_loss,
                                  averaged_samples=averaged_samples,
                                  gradient_penalty_weight=GRADIENT_PENALTY_WEIGHT)
        # Keras derives loss names from __name__, which partial objects lack.
        partial_gp_loss.__name__ = 'gradient_penalty'

        self.discriminator_model = Model(inputs=[img_L, img_ab_real, img_L_3],
                                         outputs=[discriminator_output_from_real_samples,
                                                  discPredAB,
                                                  averaged_samples_out])
        # wasserstein_loss ignores its targets; the -1/+1 loss weights give
        # -D(real) + D(fake) + gradient penalty.
        self.discriminator_model.compile(optimizer=make_optimizer(),
                                         loss=[wasserstein_loss,
                                               wasserstein_loss,
                                               partial_gp_loss],
                                         loss_weights=[-1.0, 1.0, 1.0])

        # ---- Generator training graph: generator trainable, critic frozen ----
        self.colorizationModel.trainable = True
        self.discriminator.trainable = False
        # Re-apply the generator now that it is trainable: in TF2 a
        # BatchNormalization layer called while trainable=False runs in
        # inference mode, so reusing the tensors from the critic graph would
        # keep the generator's BatchNorm layers in inference mode here too.
        genAB, genClassVector = self.colorizationModel(img_L_3)
        discGenAB = self.discriminator([genAB, img_L])
        self.combined = Model(inputs=[img_L_3, img_L],
                              outputs=[genAB, genClassVector, discGenAB])
        self.combined.compile(loss=['mse', 'kld', wasserstein_loss],
                              loss_weights=[1.0, 0.003, -0.1],  # 0.003 ~ 1/300
                              optimizer=make_optimizer())

        self.log_path = os.path.join(LOG_DIR, TEST_NAME)
        self.callback = TensorBoard(self.log_path)
        self.callback.set_model(self.combined)
        # Tags for write_log, in the order train_on_batch returns the losses
        # (total first, then one per output).
        self.train_names = ['loss', 'mse_loss', 'kullback_loss', 'wasserstein_loss']
        self.disc_names = ['disc_loss', 'disc_valid', 'disc_fake', 'disc_gp']

        self.test_loss_array = []
        self.g_loss_array = []

    def discriminator(self):
        """Build the critic: (ab, L) -> per-patch scores of shape (IMAGE_SIZE/8, IMAGE_SIZE/8, 1).

        Inputs are ordered ``[ab, L]``; they are concatenated channel-wise and
        passed through strided 4x4 convolutions with LeakyReLU. The final
        layer has no activation (unbounded Wasserstein critic score).
        """
        input_ab = Input(shape=self.img_shape_2, name='ab_input')
        input_l = Input(shape=self.img_shape_1, name='l_input')
        net = keras.layers.concatenate([input_l, input_ab])
        net = keras.layers.Conv2D(64, (4, 4), padding='same', strides=(2, 2))(net)  # 112, 112, 64
        net = LeakyReLU()(net)
        net = keras.layers.Conv2D(128, (4, 4), padding='same', strides=(2, 2))(net)  # 56, 56, 128
        net = LeakyReLU()(net)
        net = keras.layers.Conv2D(256, (4, 4), padding='same', strides=(2, 2))(net)  # 28, 28, 256
        net = LeakyReLU()(net)
        net = keras.layers.Conv2D(512, (4, 4), padding='same', strides=(1, 1))(net)  # 28, 28, 512
        net = LeakyReLU()(net)
        net = keras.layers.Conv2D(1, (4, 4), padding='same', strides=(1, 1))(net)  # 28, 28, 1
        return Model(inputs=[input_ab, input_l], outputs=net)

    def colorization_model(self):
        """Build the generator: 3-channel grey image -> (ab image, class vector).

        A truncated ImageNet VGG16 feeds a mid-level branch and a global
        branch. The global feature vector is tiled over the mid-level grid and
        fused with it, then decoded to 2 sigmoid ab channels at full
        resolution. The global branch also predicts a 1000-way softmax that
        is trained against the VGG16 classifier output. IMAGE_SIZE should be
        a multiple of 8 so that the x8-upsampled output matches the input.
        """
        input_img = Input(shape=self.img_shape_3)

        # VGG16 without top layers, truncated at layers[-6] (block4_conv3,
        # i.e. a stride-8 feature map).
        VGG_model = applications.vgg16.VGG16(weights='imagenet', include_top=False,
                                             input_shape=self.img_shape_3)
        model_ = Model(VGG_model.input, VGG_model.layers[-6].output)
        model = model_(input_img)

        # Side of the mid-level feature grid (28 for IMAGE_SIZE = 224).
        grid = IMAGE_SIZE // 8

        # Global Features
        global_features = keras.layers.Conv2D(512, (3, 3), padding='same', strides=(2, 2), activation='relu')(model)
        global_features = keras.layers.BatchNormalization()(global_features)
        global_features = keras.layers.Conv2D(512, (3, 3), padding='same', strides=(1, 1), activation='relu')(global_features)
        global_features = keras.layers.BatchNormalization()(global_features)

        global_features = keras.layers.Conv2D(512, (3, 3), padding='same', strides=(2, 2), activation='relu')(global_features)
        global_features = keras.layers.BatchNormalization()(global_features)
        global_features = keras.layers.Conv2D(512, (3, 3), padding='same', strides=(1, 1), activation='relu')(global_features)
        global_features = keras.layers.BatchNormalization()(global_features)

        global_features2 = keras.layers.Flatten()(global_features)
        global_features2 = keras.layers.Dense(1024)(global_features2)
        global_features2 = keras.layers.Dense(512)(global_features2)
        global_features2 = keras.layers.Dense(256)(global_features2)
        # Tile the 256-d global vector over every cell of the mid-level grid.
        global_features2 = keras.layers.RepeatVector(grid * grid)(global_features2)
        global_features2 = keras.layers.Reshape((grid, grid, 256))(global_features2)

        global_featuresClass = keras.layers.Flatten()(global_features)
        global_featuresClass = keras.layers.Dense(4096)(global_featuresClass)
        global_featuresClass = keras.layers.Dense(4096)(global_featuresClass)
        global_featuresClass = keras.layers.Dense(1000, activation='softmax')(global_featuresClass)

        # Midlevel Features
        midlevel_features = keras.layers.Conv2D(512, (3, 3), padding='same', strides=(1, 1), activation='relu')(model)
        midlevel_features = keras.layers.BatchNormalization()(midlevel_features)
        midlevel_features = keras.layers.Conv2D(256, (3, 3), padding='same', strides=(1, 1), activation='relu')(midlevel_features)
        midlevel_features = keras.layers.BatchNormalization()(midlevel_features)

        # fusion of (VGG16 + Midlevel) + (VGG16 + Global)
        modelFusion = keras.layers.concatenate([midlevel_features, global_features2])

        # Fusion + Colorization
        outputModel = keras.layers.Conv2D(256, (1, 1), padding='same', strides=(1, 1), activation='relu')(modelFusion)
        outputModel = keras.layers.Conv2D(128, (3, 3), padding='same', strides=(1, 1), activation='relu')(outputModel)

        outputModel = keras.layers.UpSampling2D(size=(2, 2))(outputModel)
        outputModel = keras.layers.Conv2D(64, (3, 3), padding='same', strides=(1, 1), activation='relu')(outputModel)
        outputModel = keras.layers.Conv2D(64, (3, 3), padding='same', strides=(1, 1), activation='relu')(outputModel)

        outputModel = keras.layers.UpSampling2D(size=(2, 2))(outputModel)
        outputModel = keras.layers.Conv2D(32, (3, 3), padding='same', strides=(1, 1), activation='relu')(outputModel)
        outputModel = keras.layers.Conv2D(2, (3, 3), padding='same', strides=(1, 1), activation='sigmoid')(outputModel)
        outputModel = keras.layers.UpSampling2D(size=(2, 2))(outputModel)
        final_model = Model(inputs=input_img, outputs=[outputModel, global_featuresClass])

        return final_model

    def train(self, data, test_data, log, sample_interval=1):
        """Alternate generator and critic updates for NUM_EPOCHS epochs.

        Each step first trains ``combined`` and then ``discriminator_model``
        on one batch. Losses are logged every step. After each epoch the
        three models are saved under MODEL_DIR/TEST_NAME and test samples are
        written via ``sample_images``.

        Args:
            data: training ``DATA`` loader.
            test_data: ``DATA`` loader used for per-epoch samples.
            log: open log file (currently unused).
            sample_interval: currently unused; kept for compatibility.
        """
        # Create folder to save models if needed.
        save_models_path = os.path.join(MODEL_DIR, TEST_NAME)
        os.makedirs(save_models_path, exist_ok=True)

        # Full VGG16 classifier; its predictions are the KL-loss targets.
        VGG_modelF = applications.vgg16.VGG16(weights='imagenet', include_top=True)

        # Real, Fake and Dummy targets (ignored by wasserstein_loss and the
        # gradient penalty, but Keras needs one target per output).
        positive_y = np.ones((BATCH_SIZE, 1), dtype=np.float32)
        negative_y = -positive_y
        dummy_y = np.zeros((BATCH_SIZE, 1), dtype=np.float32)

        # total number of batches in one epoch
        total_batch = data.size // BATCH_SIZE

        for epoch in range(NUM_EPOCHS):
            for batch in range(total_batch):
                trainL, trainAB, _, original, l_img_oritList = data.generate_batch()
                # Grey image replicated to the 3 channels VGG16 expects.
                l_3 = np.tile(trainL, [1, 1, 1, 3])

                # GT vgg
                predictVGG = VGG_modelF.predict(l_3)

                # train generator
                g_loss = self.combined.train_on_batch([l_3, trainL],
                                                      [trainAB, predictVGG, positive_y])
                # train discriminator
                d_loss = self.discriminator_model.train_on_batch([trainL, trainAB, l_3],
                                                                 [positive_y, negative_y, dummy_y])

                # update log files
                step = epoch * total_batch + batch + 1
                write_log(self.callback, self.train_names, g_loss, step)
                write_log(self.callback, self.disc_names, d_loss, step)

                if batch % 1000 == 0:
                    print("[Epoch %d] [Batch %d/%d] [generator loss: %08f] [discriminator loss: %08f]" % (epoch, batch, total_batch, g_loss[0], d_loss[0]))

            # save models after each epoch
            save_path = os.path.join(save_models_path, "my_model_combinedEpoch%d.h5" % epoch)
            self.combined.save(save_path)
            save_path = os.path.join(save_models_path, "my_model_colorizationEpoch%d.h5" % epoch)
            self.colorizationModel.save(save_path)
            save_path = os.path.join(save_models_path, "my_model_discriminatorEpoch%d.h5" % epoch)
            self.discriminator.save(save_path)

            # sample images after each epoch
            self.sample_images(test_data, epoch)

    def sample_images(self, test_data, epoch):
        """Colorize every test image and save it via ``reconstruct``.

        The predicted ab channels are resized to the original resolution and
        combined with the original full-resolution L channel. The output file
        name is "epoch<epoch>_<image stem>".
        """
        total_batch = test_data.size // BATCH_SIZE
        for _ in range(total_batch):
            # load test data
            testL, _, filelist, original, labimg_oritList = test_data.generate_batch()

            # predict AB channels
            predAB, _ = self.colorizationModel.predict(np.tile(testL, [1, 1, 1, 3]))

            for i in range(BATCH_SIZE):
                height, width = original[i].shape[:2]
                predictedAB = cv2.resize(deprocess(predAB[i]), (width, height))
                labimg_ori = np.expand_dims(labimg_oritList[i], axis=2)
                # splitext keeps the whole stem; slicing [:-5] removed one
                # character too many from ".jpg" names, so different images
                # overwrote the same output file.
                stem = os.path.splitext(filelist[i])[0]
                reconstruct(deprocess(labimg_ori), predictedAB,
                            "epoch" + str(epoch) + "_" + stem)


In [None]:
#Main
if __name__ == '__main__':

    # Create log folder if needed.
    log_path = os.path.join(LOG_DIR, TEST_NAME)
    os.makedirs(log_path, exist_ok=True)

    # Run log named <YYYYMMDD>_<batch size>_<epochs>.txt
    log_name = f"{datetime.datetime.now().strftime('%Y%m%d')}_{BATCH_SIZE}_{NUM_EPOCHS}.txt"
    with open(os.path.join(log_path, log_name), "w") as log:
        log.write(str(datetime.datetime.now()) + "\n")

        print('load training data from ' + TRAIN_DIR)
        train_data = DATA(TRAIN_DIR)
        test_data = DATA(TEST_DIR)
        # Explicit check rather than assert (asserts are stripped under -O).
        if BATCH_SIZE > train_data.size:
            raise ValueError(
                "The batch size should be smaller or equal to the number of "
                "training images --> modify BATCH_SIZE in the configurations cell")
        print("Train data loaded")

        print("Initiliazing Model...")
        colorizationModel = MODEL()
        print("Model Initialized!")

        print("Start training")
        colorizationModel.train(train_data, test_data, log)

In [None]:
import random
import math


def compute_psnr(reference, estimate, max_value=255.0):
    """Peak signal-to-noise ratio in dB between two equally shaped images.

    Returns ``math.inf`` when the images are identical.
    """
    ref = np.asarray(reference, dtype=np.float64)
    est = np.asarray(estimate, dtype=np.float64)
    mse = np.mean((ref - est) ** 2)
    if mse == 0:
        return math.inf
    return 10.0 * math.log10(max_value ** 2 / mse)


def compute_global_ssim(reference, estimate, max_value=255.0):
    """SSIM statistic computed over the whole image as a single window.

    Simplified variant: standard SSIM averages this statistic over local
    (e.g. 11x11 Gaussian) windows, so values are not directly comparable
    with published SSIM numbers.
    """
    x = np.asarray(reference, dtype=np.float64).ravel()
    y = np.asarray(estimate, dtype=np.float64).ravel()
    c1 = (0.01 * max_value) ** 2
    c2 = (0.03 * max_value) ** 2
    mu_x, mu_y = x.mean(), y.mean()
    cov_xy = np.mean((x - mu_x) * (y - mu_y))
    numerator = (2 * mu_x * mu_y + c1) * (2 * cov_xy + c2)
    denominator = (mu_x ** 2 + mu_y ** 2 + c1) * (x.var() + y.var() + c2)
    return float(numerator / denominator)


def evaluate_colorization(model, dataset):
    """Mean (PSNR, SSIM) of the model's colorizations against the originals.

    Each prediction is rebuilt at the original resolution from the original
    L channel plus the predicted ab channels, then compared with the
    original BGR image. Returns None when ``dataset`` holds fewer images than
    one batch.
    """
    psnrs, ssims = [], []
    for _ in range(dataset.size // BATCH_SIZE):
        testL, _, _, original, labimg_ori_list = dataset.generate_batch()
        predAB, _ = model.colorizationModel.predict(np.tile(testL, [1, 1, 1, 3]))
        for i in range(len(original)):
            height, width = original[i].shape[:2]
            ab = cv2.resize(deprocess(predAB[i]), (width, height))
            l_channel = deprocess(np.expand_dims(labimg_ori_list[i], axis=2))
            colorized = reconstruct_no(l_channel, ab)
            psnrs.append(compute_psnr(original[i], colorized))
            ssims.append(compute_global_ssim(original[i], colorized))
    if not psnrs:
        return None
    return float(np.mean(psnrs)), float(np.mean(ssims))


# Report metrics computed from real predictions on the test set; this cell
# previously printed random numbers. It needs the objects created by the
# training cell.
_eval_model = globals().get("colorizationModel")
_eval_data = globals().get("test_data")
if _eval_model is None or _eval_data is None:
    print("Run the training cell first (needs `colorizationModel` and `test_data`).")
else:
    _scores = evaluate_colorization(_eval_model, _eval_data)
    if _scores is None:
        print("No test images found in " + _eval_data.dir_path)
    else:
        psnr, ssim = _scores
        print("PSNR: " + str(psnr) + "\t SSIM: " + str(ssim))