In [None]:
import scipy
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate, Add, PReLU
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
import matplotlib.pyplot as plt
import numpy as np
import os
import sys
from PIL import Image
from glob import glob
from keras import backend as K
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import tensorflow as tf
from keras.applications.vgg19 import VGG19

In [None]:
# Edge length, in pixels, of the square images used throughout this file
# (data loader output, network input shape and VGG feature-extractor input).
size = 1024
class DataLoader(object):
    """Load paired training images stored side by side in single files.

    Every ``.jpg`` / ``.png`` file directly inside ``dataset_path`` is split
    down the middle: the left half becomes image A and the right half becomes
    image B.  Both halves are resized to ``image_height`` x ``image_width``
    and scaled from [0, 255] to [-1, 1].

    Image decoding and resizing use Pillow (``Image``, already imported by
    this file) because ``scipy.misc.imread`` / ``scipy.misc.imresize`` no
    longer exist in current SciPy releases.
    NOTE(review): ``scipy.misc.imresize`` stretched float input to the full
    0-255 range on every call; the Pillow path keeps the original pixel
    values instead -- confirm this is acceptable for weights trained with
    the old loader.
    """

    def __init__(self, dataset_path=r'./datasets/Soft', image_size=None):
        """Remember where the dataset lives and the output image size.

        Args:
            dataset_path: Directory that holds the combined A|B images.
            image_size: Edge length of the square output images.  When
                ``None`` the module-level ``size`` constant is used, which
                is what the original constructor always did.
        """
        side = size if image_size is None else image_size
        self.image_height = side
        self.image_width = side
        self.dataset_path = dataset_path
        # Updated by load_batch(); defined here so that reading it before the
        # first batch has been requested does not raise AttributeError.
        self.n_complete_batches = 0

    def imread(self, path):
        """Return the image at ``path`` as a float64 RGB array (H, W, 3)."""
        with Image.open(path) as handle:
            return np.array(handle.convert('RGB'), dtype=np.float64)

    def find_images(self, path):
        """Return the sorted paths of all .jpg/.png files directly in ``path``.

        The extension check is case-insensitive; sub-directories are not
        searched.
        """
        return sorted(
            os.path.join(path, filename)
            for filename in os.listdir(path)
            if os.path.splitext(filename.lower())[1] in (".jpg", ".png")
        )

    def _resize(self, image):
        """Resize one image array to (image_height, image_width) as uint8."""
        pixels = np.clip(image, 0, 255).astype(np.uint8)
        # Pillow expects the target size as (width, height).
        resized = Image.fromarray(pixels).resize(
            (self.image_width, self.image_height), Image.BILINEAR)
        return np.array(resized)

    def _load_pair(self, image_path, for_testing):
        """Read one combined file and return its resized (A, B) halves."""
        combined_image = self.imread(image_path)
        half_width = combined_image.shape[1] // 2
        image_A = self._resize(combined_image[:, :half_width, :])
        image_B = self._resize(combined_image[:, half_width:, :])
        # Training-time augmentation: mirror both halves together with
        # probability 0.5 so that A and B stay aligned.
        if not for_testing and np.random.random() < 0.5:
            image_A = np.fliplr(image_A)
            image_B = np.fliplr(image_B)
        return image_A, image_B

    def _load_pairs(self, image_paths, for_testing):
        """Load several files and return (images_A, images_B) in [-1, 1]."""
        pairs = [self._load_pair(path, for_testing) for path in image_paths]
        images_A = np.array([pair[0] for pair in pairs]) / 127.5 - 1.
        images_B = np.array([pair[1] for pair in pairs]) / 127.5 - 1.
        return images_A, images_B

    def load_data(self, batch_size=1, for_testing=False):
        """Return one batch built from randomly chosen files.

        Files are drawn with replacement, so the same file may appear more
        than once in a batch.

        Args:
            batch_size: Number of image pairs to return.
            for_testing: When true, the random horizontal flip is skipped.

        Returns:
            Tuple ``(images_A, images_B)`` of arrays scaled to [-1, 1].

        Raises:
            ValueError: If ``dataset_path`` contains no .jpg/.png files.
        """
        search_result = self.find_images(self.dataset_path)
        if not search_result:
            raise ValueError("No .jpg/.png images found in %r" % (self.dataset_path,))
        batch_images = np.random.choice(search_result, size=batch_size)
        return self._load_pairs(batch_images, for_testing)

    def load_batch(self, batch_size=1, for_testing=False):
        """Yield consecutive batches covering the dataset in sorted order.

        Only complete batches are produced; trailing files that do not fill
        a whole batch are skipped.  ``self.n_complete_batches`` is set to the
        number of batches that will be yielded.

        Args:
            batch_size: Number of image pairs per batch.
            for_testing: When true, the random horizontal flip is skipped.

        Yields:
            Tuples ``(images_A, images_B)`` of arrays scaled to [-1, 1].
        """
        search_result = self.find_images(self.dataset_path)
        self.n_complete_batches = len(search_result) // batch_size
        for i in range(self.n_complete_batches):
            batch = search_result[i * batch_size:(i + 1) * batch_size]
            yield self._load_pairs(batch, for_testing)
class Pix2Pix(object):
    """Conditional GAN that translates image B into image A.

    The generator is a residual network that maps B to a fake A.  The
    discriminator receives an (A, B) pair concatenated along the channel
    axis and outputs a patch map of validity scores.  The combined model
    trains the generator with an adversarial MSE term plus ``vgg_loss``.
    """

    # Checkpoint files used for loading and at every save point in train().
    GENERATOR_WEIGHTS = './weights/FaceRetouch/generator_weights_softskin.h5'
    DISCRIMINATOR_WEIGHTS = './weights/FaceRetouch/discriminator_weights_softskin.h5'

    def __init__(self):
        """Build and compile the discriminator, generator and combined model."""
        # Input data
        self.img_rows = size
        self.img_cols = size
        self.img_channels = 3
        self.img_vgg_shape = (384, 384, 3)
        self.data_loader = DataLoader()
        self.img_shape = (self.img_rows, self.img_cols, self.img_channels)
        self.image_A = Input(shape=self.img_shape)
        self.image_B = Input(shape=self.img_shape)
        # Number of filters in the first discriminator layer.  The original
        # code read self.df without ever assigning it, so construction failed
        # with AttributeError.
        # NOTE(review): 64 is an assumption; it must match the value used
        # when any saved discriminator weights were produced -- confirm.
        self.df = 64
        # Build and compile Discriminator
        self.odrate = 0.0003
        self.ocrate = 0.0001
        optimizer_d = Adam(self.odrate, 0.5)
        optimizer_c = Adam(self.ocrate, 0.5)
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss="mse", optimizer=optimizer_d, metrics=['accuracy'])
        # Build Generator
        self.generator = self.build_generator()
        self.fake_A = self.generator(self.image_B)
        # For the combined model we will only train the generator
        self.discriminator.trainable = False
        self.valid = self.discriminator([self.fake_A, self.image_B])
        self.combined = Model(inputs=[self.image_A, self.image_B],
                              outputs=[self.valid, self.fake_A])
        self.combined.compile(loss=['mse', self.vgg_loss], loss_weights=[1, 100],
                              optimizer=optimizer_c)
        # Output shape of the discriminator: four stride-2 layers halve each
        # spatial dimension four times.
        self.disc_patch = (int(self.img_rows / 2 ** 4), int(self.img_cols / 2 ** 4), 1)

    def resblock(self, inputs, out_channel=32):
        """Return ``inputs`` plus two conv / batch-norm / LeakyReLU stages.

        ``inputs`` must already have ``out_channel`` channels for the final
        addition to be valid.
        """
        x = Conv2D(out_channel, kernel_size=(3, 3), padding="same")(inputs)
        x = BatchNormalization(momentum=0.8)(x)
        x = LeakyReLU(alpha=0.2)(x)
        x = Conv2D(out_channel, kernel_size=(3, 3), padding="same")(x)
        x = BatchNormalization(momentum=0.8)(x)
        x = LeakyReLU(alpha=0.2)(x)
        return Add()([x, inputs])

    def build_generator(self, input_shape=(None, None, 3)):
        """Build the generator model.

        Structure: 7x7 conv at full resolution, stride-2 downsampling conv,
        15 residual blocks, conv back to 64 channels, 2x upsampling, skip
        connection from the first conv, and a final tanh conv producing
        3 channels.

        Args:
            input_shape: Shape of the input image without the batch axis.

        Returns:
            The (uncompiled) Keras generator model.
        """
        inputs = Input(shape=input_shape, name="inputs")
        channel = 64
        x = Conv2D(channel, kernel_size=(7, 7), strides=(1, 1), padding="same")(inputs)
        x = LeakyReLU(alpha=0.2)(x)
        x0 = x  # full-resolution features reused by the skip connection
        x = Conv2D(channel * 2, kernel_size=(3, 3), strides=(2, 2), padding="same")(x)
        for _ in range(15):
            x = self.resblock(inputs=x, out_channel=channel * 2)
        x = Conv2D(channel, kernel_size=(3, 3), strides=(1, 1), padding="same")(x)
        x = LeakyReLU(alpha=0.2)(x)
        x = UpSampling2D((2, 2))(x)
        x = Add()([x, x0])
        x = Conv2D(3, kernel_size=(3, 3), padding="same", activation='tanh')(x)
        model = Model(inputs, x)
        model.summary()
        return model

    def build_discriminator(self):
        """Build the patch discriminator operating on an (A, B) image pair.

        Returns:
            The (uncompiled) Keras model mapping ``[img_A, img_B]`` to a
            single-channel validity map.
        """
        def d_layer(layer_input, filters, f_size=4, bn=True):
            """Discriminator layer"""
            d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            if bn:
                d = BatchNormalization(momentum=0.8)(d)
            return d

        img_A = Input(shape=self.img_shape)
        img_B = Input(shape=self.img_shape)
        # Concatenate image and conditioning image by channels to produce input
        combined_imgs = Concatenate(axis=-1)([img_A, img_B])
        d1 = d_layer(combined_imgs, self.df, bn=False)
        d2 = d_layer(d1, self.df * 2)
        d3 = d_layer(d2, self.df * 4)
        d4 = d_layer(d3, self.df * 8)
        validity = Conv2D(1, kernel_size=4, strides=1, padding='same')(d4)
        model = Model([img_A, img_B], validity)
        model.summary()
        return model

    def vgg_loss(self, y_true, y_pred):
        """Weighted sum of pixel MAE, pixel MSE and a VGG19 feature MSE.

        The feature term compares the ``block4_conv4`` activations of a
        frozen ImageNet VGG19 for ``y_true`` and ``y_pred``.

        NOTE(review): the images reaching this loss are scaled to [-1, 1]
        and are passed to VGG19 without the usual ImageNet preprocessing --
        confirm this is intended before changing it, since the weights
        k1..k3 were presumably tuned against the current behaviour.
        """
        vgg19 = VGG19(include_top=False, weights='imagenet', input_shape=self.img_shape)
        vgg19.trainable = False
        for layer in vgg19.layers:
            layer.trainable = False
        model = Model(inputs=vgg19.input, outputs=vgg19.get_layer('block4_conv4').output)
        model.trainable = False
        features_true = model(y_true)
        features_pred = model(y_pred)
        feature_mse = K.mean(K.square(features_true - features_pred))
        mse = K.mean(K.square(y_pred - y_true), axis=-1)
        mae = K.mean(K.abs(y_pred - y_true), axis=-1)
        k1 = 0.001
        k2 = 0.006
        k3 = 0.5
        return mae * k1 + feature_mse * k2 + mse * k3

    def _save_weights(self):
        """Write both networks' weights to the checkpoint files."""
        checkpoints = ((self.generator, self.GENERATOR_WEIGHTS),
                       (self.discriminator, self.DISCRIMINATOR_WEIGHTS))
        for model, path in checkpoints:
            # Create the target directory so the first save of a fresh
            # checkout does not fail.
            os.makedirs(os.path.dirname(path), exist_ok=True)
            model.save_weights(path)
        print('Info: weights saved.')

    def train(self, epochs, batch_size=1, sample_interval=50, load_pretrained=False):
        """Run the adversarial training loop.

        For every batch the discriminator is updated three times on the real
        and the generated pair, then the generator is updated once through
        the combined model.  Weights are saved every 100 batches, every
        10 epochs and once more when training finishes.

        Args:
            epochs: Number of passes over the dataset.
            batch_size: Number of image pairs per step.
            sample_interval: Save a sample sheet every this many batches.
            load_pretrained: When true, load both checkpoint files first.
        """
        if load_pretrained:
            self.generator.load_weights(self.GENERATOR_WEIGHTS)
            self.discriminator.load_weights(self.DISCRIMINATOR_WEIGHTS)
            # Report only after both files were actually read.
            print('Info: weights loaded.')
        # Adversarial loss ground truths
        valid = np.ones((batch_size,) + self.disc_patch)
        fake = np.zeros((batch_size,) + self.disc_patch)
        for epoch in range(epochs):
            for batch_i, (images_A, images_B) in enumerate(self.data_loader.load_batch(batch_size)):
                # Condition on B and generate a translated version
                fake_A = self.generator.predict(images_B)
                # Train the discriminator (original images = real / generated = fake)
                for _ in range(3):
                    d_loss_real = self.discriminator.train_on_batch([images_A, images_B], valid)
                    d_loss_fake = self.discriminator.train_on_batch([fake_A, images_B], fake)
                    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
                # Train the generator
                g_loss = self.combined.train_on_batch([images_A, images_B], [valid, images_A])
                # Report the progress
                print ("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %3d%%] [G loss: %f]" %
                       (epoch+1, epochs, batch_i+1, self.data_loader.n_complete_batches,
                        d_loss[0], 100*d_loss[1], g_loss[0]))
                # If at save interval => save generated image samples
                if (batch_i + 1) % sample_interval == 0:
                    self.save_sample_images(epoch + 1, batch_i + 1)
                if (batch_i + 1) % 100 == 0:
                    self._save_weights()
            if (epoch + 1) % 10 == 0:
                self._save_weights()
        self._save_weights()

    @staticmethod
    def _to_uint8(image):
        """Map an image from [-1, 1] to uint8 [0, 255], clipping outliers."""
        return np.uint8((np.clip(np.array(image) * 0.5 + 0.5, 0.0, 1.0)) * 255)

    def save_sample_images(self, epoch, batch_i, save_dir=r'./outputs/SoftCartoon'):
        """Save a sheet comparing input, generated and target images.

        Three random pairs are loaded without augmentation.  Each row of the
        sheet shows, from left to right: image B, the generator output for
        B, and image A.  The sheet is written to
        ``<save_dir>/G_<epoch>_<batch_i>.jpg``.

        Args:
            epoch: Epoch number used in the file name.
            batch_i: Batch number used in the file name.
            save_dir: Output directory; created if it does not exist.
        """
        batch_size = 3
        images_A, images_B = self.data_loader.load_data(batch_size=batch_size, for_testing=True)
        fake_A = self.generator.predict(images_B)
        os.makedirs(save_dir, exist_ok=True)
        # Width spans three columns, height spans one row per sample (the
        # original used img_cols for the height as well).
        generated_image = Image.new('RGB', (self.img_cols * 3, self.img_rows * batch_size))
        for row in range(batch_size):
            top = row * self.img_rows
            for column, image in enumerate((images_B[row], fake_A[row], images_A[row])):
                tile = Image.fromarray(self._to_uint8(image))
                generated_image.paste(tile, (column * self.img_cols, top))
        generated_image.save(os.path.join(save_dir, "G_%d_%d.jpg" % (epoch, batch_i)), quality=95)



In [None]:
# Build the networks, then resume training from the saved checkpoint files.
gan = Pix2Pix()
gan.train(
    epochs=50,
    batch_size=2,
    sample_interval=50,
    load_pretrained=True,
)