In [None]:
from __future__ import print_function
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply
from keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D
from keras.layers import MaxPooling2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
from keras import losses
from keras.utils import to_categorical
import keras.backend as K
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from keras.utils import multi_gpu_model

import numpy as np
import cv2
import os
from scipy import signal, misc
os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [None]:
import tensorflow as tf
from keras.backend.tensorflow_backend import set_session
import keras.backend.tensorflow_backend as K
def get_session():
    """Create a TensorFlow session configured for shared-GPU use.

    The session's GPU options request on-demand memory growth, the 'BFC'
    allocator, and a per-process GPU memory fraction of 0.8.

    Returns:
        A new ``tf.Session`` built from that configuration.
    """
    gpu_options = tf.GPUOptions(
        allow_growth=True,
        allocator_type='BFC',
        per_process_gpu_memory_fraction=0.8,
    )
    return tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
# NOTE: `K` here is keras.backend.tensorflow_backend (imported in this cell),
# which rebinds the `K` alias of keras.backend imported in the first cell.
# Clear the backend state first, then install a session created with the GPU
# options from get_session(); the session is created after the clear on purpose.
K.clear_session()
K.set_session(get_session())

In [None]:
class CCGAN():
    """Paired image-to-image GAN (hazed image -> clean image).

    The class owns three Keras models:

    * ``generator``     - encoder/decoder CNN mapping an input image to an
      output image of the same shape (tanh output, so values lie in [-1, 1]).
    * ``discriminator`` - CNN producing a single sigmoid "is real" score.
    * ``combined``      - generator followed by the frozen discriminator; it is
      trained with MSE against the target image (weight 1) plus binary
      cross-entropy on the discriminator score (weight 0.01).

    ``mask_height``, ``mask_width`` and ``num_classes`` are stored but are not
    used by any model in this class.
    """

    def __init__(self, img_rows = 256, img_cols = 256, mask_height = 15, mask_width = 15, channels = 3, num_classes = 5):
        """Build and compile the discriminator, the generator and the combined model.

        Args:
            img_rows, img_cols: Spatial size of the images. The generator halves
                the size five times and doubles it five times, so its output
                only matches the input when both are multiples of 32.
            mask_height, mask_width: Stored only; unused by the models.
            channels: Number of image channels.
            num_classes: Stored only; unused by the models.
        """
        self.img_rows = img_rows
        self.img_cols = img_cols
        self.mask_height = mask_height
        self.mask_width = mask_width
        self.channels = channels
        self.num_classes = num_classes
        # Histories filled by train() at every save interval.
        self.generator_mse = []
        self.discriminator_loss = []
        self.img_shape = (self.img_rows, self.img_cols, self.channels)

        optimizer = Adam(0.0002, 0.5)

        # The discriminator is compiled BEFORE it is frozen below, so its own
        # train_on_batch keeps updating its weights.
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss = 'binary_crossentropy', optimizer = optimizer,  metrics = ['accuracy'])

        # The generator is only ever trained through `combined`; this compile is
        # kept so the standalone model stays in the same state as before.
        self.generator = self.build_generator()
        self.generator.compile(loss = ['binary_crossentropy'], optimizer = optimizer)

        # The generator takes the degraded (hazed) image as input.
        masked_img = Input(shape = self.img_shape)
        gen_img = self.generator(masked_img)

        # In the combined model only the generator is trained.
        self.discriminator.trainable = False
        valid = self.discriminator(gen_img)

        # Combined model: input image -> [generated image, validity score].
        # Loss = 1 * MSE(generated, target) + 0.01 * BCE(validity, label).
        self.combined = Model(masked_img , [gen_img, valid])
        self.combined.compile(loss=['mse', 'binary_crossentropy'],
            loss_weights=[1, 0.01],
            optimizer=optimizer)

    def build_generator(self):
        """Build the encoder/decoder generator.

        Encoder: five 3x3 stride-2 'same' convolutions (64, 128, 256, 256, 256
        filters) with ReLU, followed by batch normalisation. With 'same'
        padding each one yields ceil(size / 2), e.g. 256 -> 128 -> 64 -> 32 ->
        16 -> 8.
        Decoder: five 2x up-samplings, each followed by a 'same' convolution;
        the first four use 256 filters with ReLU, the last one maps to
        ``self.channels`` with a 4x4 kernel and tanh (8 -> 16 -> ... -> 256).

        Returns:
            A Keras ``Model`` mapping an image of ``self.img_shape`` to the
            generated image.
        """
        model = Sequential()

        # Encoder
        for position, filters in enumerate((64, 128, 256, 256, 256)):
            if position == 0:
                model.add(Conv2D(filters, kernel_size = 3, strides = 2, input_shape = self.img_shape, padding = 'same'))
            else:
                model.add(Conv2D(filters, kernel_size = 3, strides = 2, padding = 'same'))
            model.add(Activation('relu'))
        model.add(BatchNormalization())

        # Decoder
        for _ in range(4):
            model.add(UpSampling2D())
            model.add(Conv2D(256, kernel_size=3, padding="same"))
            model.add(Activation('relu'))
        model.add(UpSampling2D())
        model.add(Conv2D(self.channels, kernel_size=4, padding="same"))
        model.add(Activation('tanh'))

        masked_img = Input(shape=self.img_shape)
        img = model(masked_img)

        return Model(masked_img, img)

    def build_discriminator(self):
        """Build the binary real/fake discriminator.

        A 3x3 convolution (64 filters) + ReLU + 2x2 max-pooling stem is followed
        by five plain blocks of two 3x3 'same' convolutions with ReLU and a
        max-pooling (filters 64, 64, 128, 128, 128). These are ordinary stacked
        convolutions; there are no residual connections. For a 256x256 input
        the spatial size goes 256 -> 128 -> 64 -> 32 -> 16 -> 8 -> 4 before
        flattening into a single sigmoid unit.

        Returns:
            A Keras ``Model`` mapping an image of ``self.img_shape`` to a
            validity score of shape (1,).
        """
        model = Sequential()

        model.add(Conv2D(64, kernel_size = 3, input_shape= self.img_shape,padding = 'same'))
        model.add(Activation('relu'))
        model.add(MaxPooling2D())

        for filters in (64, 64, 128, 128, 128):
            model.add(Conv2D(filters, kernel_size=3, padding="same"))
            model.add(Activation('relu'))
            model.add(Conv2D(filters, kernel_size=3, padding="same"))
            model.add(Activation('relu'))
            model.add(MaxPooling2D())

        model.add(Flatten())

        img = Input(shape=self.img_shape)
        features = model(img)
        valid = Dense(1, activation="sigmoid")(features)
        return Model(img, valid)

    def create_depth_map(self, imgs):
        """Return, for every image, the per-pixel minimum over axis 2 (the channel axis of an HxWxC image).

        Args:
            imgs: Iterable of arrays with at least three dimensions.

        Returns:
            ``np.ndarray`` stacking one reduced map per input image.
        """
        return np.array([np.amin(img, axis = 2) for img in imgs])

    def mask_randomly(self, imgs, time = 1):
        """Blur every image with a 5x5 Gaussian kernel.

        Args:
            imgs: Array of images; the result has the same shape and dtype.
            time: Value passed to ``cv2.GaussianBlur`` as its sigma argument.

        Returns:
            A new array holding the blurred images.
        """
        masked_imgs = np.empty_like(imgs)
        for idx, image in enumerate(imgs):
            masked_imgs[idx] = cv2.GaussianBlur(image, (5,5),time)
        return masked_imgs

    def load_all_images(self, dir_path, files_extension = None):
        """List the files directly inside ``dir_path`` in sorted order.

        Args:
            dir_path: Directory to scan (sub-directories are skipped).
            files_extension: Optional suffix; when given, only files whose
                lower-cased name ends with it are kept.

        Returns:
            ``np.ndarray`` of paths (``dir_path`` joined with each file name).
        """
        file_names = sorted(
            s for s in os.listdir(dir_path)
            if not os.path.isdir(os.path.join(dir_path, s))
        )
        if files_extension:
            file_names = [s for s in file_names if s.lower().endswith(files_extension)]
        return np.array([os.path.join(dir_path, s) for s in file_names])

    def load_images_equivalance_to_names(self, list_names):
        """Read every path in ``list_names`` with OpenCV and convert it from BGR to RGB.

        Args:
            list_names: Iterable of image file paths.

        Returns:
            ``np.ndarray`` of the loaded RGB images, in the same order.

        Raises:
            IOError: If a file cannot be read (``cv2.imread`` returned ``None``).
        """
        array_images = []
        for full_name in list_names:
            bgr = cv2.imread(full_name,3)
            if bgr is None:
                # Fail with the offending path instead of an opaque cvtColor error.
                raise IOError("could not read image file: {}".format(full_name))
            array_images.append(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
        return np.array(array_images)

    def normalize_imgs(self, imgs):
        """Map pixel values from [0, 255] to [-1, 1] (the range of the generator's tanh output).

        The division uses a float literal so integer arrays are never
        floor-divided, whatever the Python version.
        """
        imgs = imgs / 255.0
        imgs = 2*imgs -1
        return imgs

    def _sample_pairs(self, clean_names, hazed_names, count):
        """Draw ``count`` random indices and load the normalised (clean, hazed) image pairs at them."""
        idx = np.random.randint(0, clean_names.shape[0], count)
        imgs = self.load_images_equivalance_to_names(clean_names[idx])
        masked_imgs = self.load_images_equivalance_to_names(hazed_names[idx])
        return self.normalize_imgs(imgs), self.normalize_imgs(masked_imgs)

    def train(self, epochs, batch_size=128, save_interval=50):
        """Alternate discriminator and generator updates on paired images.

        Clean targets are read from the directory 'smaller origin' and the
        degraded inputs from 'smaller hazed' ('jpg' files only). Pairs are
        matched by position in the sorted file lists.

        Args:
            epochs: Number of training iterations. Each iteration performs one
                discriminator update on a real half batch, one on a generated
                half batch, and one generator update on a full batch; it is not
                a pass over the whole dataset.
            batch_size: Generator batch size; the discriminator uses half of it.
            save_interval: Every ``save_interval`` iterations the losses are
                recorded and sample images and model files are written.

        Raises:
            ValueError: If ``batch_size`` is below 2, if no training image is
                found, or if the two directories hold different numbers of images.
        """
        clean_names = self.load_all_images('smaller origin', 'jpg')
        hazed_names = self.load_all_images('smaller hazed', 'jpg')

        half_batch = batch_size // 2
        if half_batch < 1:
            raise ValueError("batch_size must be at least 2")
        if clean_names.shape[0] == 0:
            raise ValueError("no 'jpg' images found in 'smaller origin'")
        if hazed_names.shape[0] != clean_names.shape[0]:
            raise ValueError("'smaller origin' and 'smaller hazed' must contain the same number of images")

        for epoch in range(epochs):
            # ---------------------
            #  Train Discriminator
            # ---------------------
            imgs, masked_imgs = self._sample_pairs(clean_names, hazed_names, half_batch)
            gen_imgs = self.generator.predict(masked_imgs)

            valid = np.ones((half_batch, 1))
            fake = np.zeros((half_batch, 1))

            # Real images are labelled 1 and generated images 0. No class
            # weights are passed: the discriminator has a single binary output.
            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator
            # ---------------------
            imgs, masked_imgs = self._sample_pairs(clean_names, hazed_names, batch_size)
            # The generator wants the discriminator to label its output as valid.
            valid = np.ones((batch_size, 1))
            g_loss = self.combined.train_on_batch(masked_imgs, [imgs, valid])

            print ("%d [D loss: %f, acc: %.2f%%] [G loss: %f, mse: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss[0], g_loss[1]))

            # At every save interval: record losses, dump samples and models.
            if epoch % save_interval == 0:
                self.generator_mse.append(g_loss[1])
                self.discriminator_loss.append(d_loss[0])

                imgs, masked_imgs = self._sample_pairs(clean_names, hazed_names, 3)
                gen_imgs = self.generator.predict(masked_imgs)
                # Back from [-1, 1] to [0, 1] for saving and plotting.
                imgs = 0.5 * imgs + 0.5
                masked_imgs = 0.5 * masked_imgs + 0.5
                gen_imgs = 0.5 * gen_imgs + 0.5
                self.save_imgs(imgs, masked_imgs, gen_imgs,epoch)
                self.save_model(epoch)

    @staticmethod
    def _ensure_dir(path):
        """Create ``path`` (and missing parents) unless it already is a directory."""
        if not os.path.isdir(path):
            os.makedirs(path)

    @staticmethod
    def _to_uint8_bgr(img):
        """Convert an RGB float image in [0, 1] to a uint8 BGR image for ``cv2.imwrite``."""
        scaled = np.clip(img, 0.0, 1.0) * 255
        return cv2.cvtColor(np.array(scaled, dtype = 'uint8'), cv2.COLOR_RGB2BGR)

    def save_imgs(self,valid_imgs, masked_imgs, generated_images, epoch = -1):
        """Write sample images and a comparison figure under 'dcgan_exp/images 2'.

        For every generated image the generated, input and target images are
        written as JPEG files into the 'generated', 'masked' and 'valid'
        sub-directories. A 5-row figure (target, input, generated,
        |input - target|, |generated - target|) is saved as a PNG.

        Args:
            valid_imgs, masked_imgs, generated_images: RGB image batches with
                values in [0, 1]; the number of columns is taken from
                ``generated_images``.
            epoch: Iteration number; files are named with ``epoch + 25``.
        """
        r = 5
        c = generated_images.shape[0]
        if c == 0:
            # Nothing to draw; plt.subplots cannot build a grid with 0 columns.
            return
        name = epoch+25

        # The directories created here are exactly the ones written to below.
        for out_dir in ('dcgan_exp/images 2/generated',
                        'dcgan_exp/images 2/masked',
                        'dcgan_exp/images 2/valid'):
            self._ensure_dir(out_dir)

        # squeeze=False keeps `axs` two-dimensional even for a single column.
        fig, axs = plt.subplots(r, c, squeeze=False)
        fig.subplots_adjust(hspace = .3)
        for i in range(c):
            cv2.imwrite('dcgan_exp/images 2/generated/{epoch}_generated_{image_at_idx}.jpg'.format(epoch = str(name).zfill(5), image_at_idx = i),
                        self._to_uint8_bgr(generated_images[i]))
            cv2.imwrite('dcgan_exp/images 2/masked/{epoch}_masked_{image_at_idx}.jpg'.format(epoch = str(name).zfill(5), image_at_idx = i),
                        self._to_uint8_bgr(masked_imgs[i]))
            cv2.imwrite('dcgan_exp/images 2/valid/{epoch}_valid_{image_at_idx}.jpg'.format(epoch = str(name).zfill(5), image_at_idx = i),
                        self._to_uint8_bgr(valid_imgs[i]))

            panels = (
                valid_imgs[i],
                masked_imgs[i],
                generated_images[i],
                np.abs(masked_imgs[i] - valid_imgs[i]),
                np.abs(generated_images[i] - valid_imgs[i]),
            )
            for row, panel in enumerate(panels):
                axs[row, i].imshow(panel)
                axs[row, i].axis('off')
        fig.savefig("dcgan_exp/images 2/natural_image_%d.png" % name)
        plt.close(fig)

    def save_model(self,epoch):
        """Save architecture (JSON) and weights (HDF5) of generator and discriminator.

        Files are written to 'dcgan_exp/saved_models_2' and are named with
        ``epoch + 25``, the same number ``save_imgs`` uses for its image files
        (the offset is applied once).

        Args:
            epoch: Iteration number of the snapshot.
        """
        name = epoch + 25
        self._ensure_dir('dcgan_exp/saved_models_2')

        for model, model_name in ((self.generator, "dcgan_generator"),
                                  (self.discriminator, "dcgan_discriminator")):
            model_path = "dcgan_exp/saved_models_2/{}_epoch_{}.json".format(model_name, name)
            weights_path = "dcgan_exp/saved_models_2/{}_weights_epoch_{}.hdf5".format(model_name , name)
            # The context manager closes the file even if the write fails.
            with open(model_path, 'w') as arch_file:
                arch_file.write(model.to_json())
            model.save_weights(weights_path)

    def load_model(self, file_name1, file_name2):
        """Load saved weights into the existing generator and discriminator.

        The models built in ``__init__`` are reused rather than rebuilt or
        recompiled: ``self.combined`` wraps these exact objects, so replacing
        the generator would leave ``combined`` training a different network,
        and recompiling the discriminator after it has been marked
        non-trainable would stop its own updates.

        Args:
            file_name1: Path of the generator weights file.
            file_name2: Path of the discriminator weights file.
        """
        self.generator.load_weights(file_name1)
        self.discriminator.load_weights(file_name2)


In [None]:
# Train a fresh model. `epochs` counts training iterations (one discriminator
# step and one generator step each), and samples plus model files are written
# every `save_interval` iterations.
if __name__ == '__main__':
    ccgan = CCGAN()
    ccgan.train(epochs=500000, batch_size=30, save_interval=10000)

In [None]:
# Inference: restore saved weights and run the generator on every 'jpg' file in 'test data'.
ccgan = CCGAN()
for required_dir in ('dcgan_exp', 'dcgan_exp/saved_models'):
    if not os.path.isdir(required_dir):
        os.mkdir(required_dir)

# NOTE: these files use the 'ccgan_' prefix inside 'saved_models', whereas
# CCGAN.save_model writes 'dcgan_'-prefixed files into 'saved_models_2'.
gener_name = "dcgan_exp/saved_models/ccgan_generator_weights_epoch_249000.hdf5"
discri_name = "dcgan_exp/saved_models/ccgan_discriminator_weights_epoch_249000.hdf5"
ccgan.load_model(gener_name, discri_name)

test_list = ccgan.load_all_images('test data', 'jpg')
print(test_list)
test_img = ccgan.load_images_equivalance_to_names(test_list)
masked_imgs = ccgan.normalize_imgs(test_img)
# Generator output is in [-1, 1]; bring it back to [0, 1].
gen_imgs = 0.5 * ccgan.generator.predict(masked_imgs) + 0.5

c = gen_imgs.shape[0]
for i in range(c):
    out_name = '{epoch}_generated_now_{image_at_idx}.jpg'.format(epoch = str(i).zfill(5),
                                                               image_at_idx = i)
    out_bgr = cv2.cvtColor(np.array(gen_imgs[i]*255, dtype = 'uint8'), cv2.COLOR_RGB2BGR)
    cv2.imwrite(out_name, out_bgr)

In [None]:
# Resume training: restore saved weights into a fresh CCGAN, then keep training.
ccgan = CCGAN()
for required_dir in ('dcgan_exp', 'dcgan_exp/saved_models'):
    if not os.path.isdir(required_dir):
        os.mkdir(required_dir)

# NOTE: these files use the 'ccgan_' prefix inside 'saved_models', whereas
# CCGAN.save_model writes 'dcgan_'-prefixed files into 'saved_models_2'.
gener_name = "dcgan_exp/saved_models/ccgan_generator_weights_epoch_249000.hdf5"
discri_name = "dcgan_exp/saved_models/ccgan_discriminator_weights_epoch_249000.hdf5"
ccgan.load_model(gener_name, discri_name)
ccgan.train(epochs=250000, batch_size=30, save_interval=1000)