In [None]:
import json
import os
import time
import pathlib
import matplotlib.pyplot as plt
import numpy as np
import PIL
import cv2
import imageio
from sklearn.utils import shuffle
from glob import glob
from ast import literal_eval

import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import layers
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import model_from_json
from keras.preprocessing import image

# Restrict TensorFlow to a single GPU (physical device index 1)

In [None]:
# Expose only physical GPU 1 to CUDA/TensorFlow.
# NOTE(review): this only takes effect if set before TensorFlow first initializes the GPU.
os.environ.update({'CUDA_VISIBLE_DEVICES': '1'})

# Training label: the image category to train on, also used as the folder name for loading and saving

In [None]:
label = "castle"  # image category; also the run folder name passed to train()

# Glob pattern matching the training images

In [None]:
FILE_PATH = './resize_512_512/%s/*.jpg' % label  # all JPEGs in the label's pre-resized folder

In [None]:
def read_imgs(file_path, counts):
    """Load up to ``counts`` images matching a glob pattern as RGB float32 in [0, 1].

    Args:
        file_path: glob pattern, e.g. './resize_512_512/castle/*.jpg'.
        counts: maximum number of files to read; ``None`` reads every match.

    Returns:
        np.ndarray stacking the decoded images. Assumes all matched images share
        the same size so they stack into (N, H, W, 3) — TODO confirm for new data.

    Raises:
        IOError: if OpenCV cannot decode one of the matched files.
    """
    # sorted() makes the selected subset reproducible; raw glob order is arbitrary.
    imgs_list = sorted(glob(file_path))[:counts]
    imgs = []
    for path in imgs_list:
        bgr = cv2.imread(path)
        if bgr is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise IOError('could not read image: {}'.format(path))
        # OpenCV decodes as BGR; flip channels to RGB and scale to [0, 1].
        imgs.append(bgr[:, :, ::-1].astype(np.float32) / 255.)
    return np.array(imgs)

In [None]:
def build_generator(latent_dim, output_size, filter_num=(256, 128, 64, 32)):
    """Build the generator: latent vector -> (height, width, 3) image.

    A Dense layer projects the latent vector to a
    (height / 2**k, width / 2**k, filter_num[0]) feature map, where k = len(filter_num).
    Each of the k stages then doubles the spatial size with UpSampling2D, followed by
    two 3x3 stride-1 'same' Conv2DTranspose + LeakyReLU(0.2) layers with
    ``filter_num[i]`` filters. A final 1x1 Conv2DTranspose with linear activation
    maps to 3 channels. With the default filter_num this is the original 4-stage
    (x16 upsampling) network, built with the same layers in the same order.

    Args:
        latent_dim: length of the input latent vector.
        output_size: (height, width) of the generated image.
        filter_num: filters per upsampling stage.

    Returns:
        keras Model mapping (batch, latent_dim) -> (batch, height, width, 3).

    Raises:
        ValueError: if filter_num is empty, or if height/width is not divisible by
            2**len(filter_num) (the output would otherwise silently be smaller).
    """
    filter_num = list(filter_num)
    if not filter_num:
        raise ValueError('filter_num must contain at least one stage')
    height, width = output_size
    factor = 2 ** len(filter_num)
    if height % factor or width % factor:
        raise ValueError(
            'output_size {} must be divisible by {} (2**len(filter_num))'.format(
                output_size, factor))
    base_h, base_w = int(height // factor), int(width // factor)

    generator_input = keras.Input(shape=(latent_dim,))

    # Project and reshape to the smallest feature map.
    x = layers.Dense(filter_num[0] * base_h * base_w)(generator_input)
    x = layers.LeakyReLU(0.2)(x)
    x = layers.Reshape((base_h, base_w, filter_num[0]))(x)

    # Each stage: x2 spatial upsampling, then two 3x3 conv + LeakyReLU layers.
    for filters in filter_num:
        x = layers.UpSampling2D(size=(2, 2))(x)
        for _ in range(2):
            x = layers.Conv2DTranspose(filters, (3, 3), strides=(1, 1), padding='same',
                                       kernel_initializer='he_normal')(x)
            x = layers.LeakyReLU(0.2)(x)

    # 1x1 projection to 3 channels; linear (unbounded) output activation.
    x = layers.Conv2DTranspose(3, (1, 1), strides=(1, 1), padding='same', activation='linear',
                               kernel_initializer='he_normal')(x)

    return keras.models.Model(generator_input, x)

In [None]:
def build_discriminator(input_size):
    """Build the discriminator: (height, width, channels) image -> sigmoid score.

    Four stages of 3x3 'same' Conv2D + LeakyReLU(0.2) layers with 32/64/128/256
    filters (three convs in the first stage, two in each later one). Each stage
    ends in a default AveragePooling2D. The result is flattened into a single
    sigmoid Dense unit.

    Args:
        input_size: (height, width, channels) of the input images.

    Returns:
        keras Model mapping (batch, height, width, channels) -> (batch, 1).
    """
    height, width, channels = input_size
    # (filters, number of conv layers) for each stage, applied in order.
    stages = ((32, 3), (64, 2), (128, 2), (256, 2))

    inputs = layers.Input(shape=(height, width, channels))
    features = inputs
    for filters, n_convs in stages:
        for _ in range(n_convs):
            features = layers.Conv2D(filters, 3, strides=1, padding='same',
                                     kernel_initializer='he_normal')(features)
            features = layers.LeakyReLU(0.2)(features)
        features = layers.AveragePooling2D()(features)

    features = layers.Flatten()(features)
    score = layers.Dense(1, activation='sigmoid')(features)
    return keras.models.Model(inputs, score)

In [None]:
def build_GAN(G, D):
    """Chain generator G into discriminator D as a single model (latent -> D score).

    Side effect: sets ``D.trainable = False`` before building, so that when the
    returned model is compiled, D's weights are excluded from its trainable set.

    Args:
        G: generator model; its input becomes the stacked model's input.
        D: discriminator model applied to G's output.

    Returns:
        keras Model mapping G's input to D's output.
    """
    D.trainable = False
    latent_input = G.input
    return keras.models.Model(latent_input, D(G(latent_input)))

In [None]:
def save_model(root_folder_name, model_dict):
    """Write each model's ``to_json()`` output to ./<root_folder_name>/model/<key>.json.

    The target directory is created if missing. Existing files with the same name
    are overwritten.

    Args:
        root_folder_name: run folder under the current working directory.
        model_dict: mapping of file stem -> object exposing ``to_json()``.
    """
    target_dir = './{}/model/'.format(root_folder_name)
    os.makedirs(target_dir, exist_ok=True)

    for name in model_dict:
        architecture = model_dict[name].to_json()
        with open(target_dir + '{}.json'.format(name), 'w') as handle:
            handle.write(architecture)

In [None]:
def _build_and_compile_models(root_folder_name, latent_dim, height, width,
                              is_load_weight, pre_low_step):
    """Build (or reload) G1/D1, compile D1, then build and compile the stacked GAN1."""
    if is_load_weight:
        model_path = './{}/model/'.format(root_folder_name)
        record_path = './{}/weight/record/'.format(root_folder_name)
        with open(model_path + 'G1.json', 'r') as json_file:
            G1 = model_from_json(json_file.read())
        G1.load_weights(record_path + 'g1_{}.h5'.format(pre_low_step))
        with open(model_path + 'D1.json', 'r') as json_file:
            D1 = model_from_json(json_file.read())
        D1.load_weights(record_path + 'd1_{}.h5'.format(pre_low_step))
    else:
        G1 = build_generator(latent_dim, (height, width))
        D1 = build_discriminator((height, width, 3))

    # D1 must be trainable at the moment it is compiled: Keras fixes the trainable
    # set at compile time. Compiled this way, D1.train_on_batch still updates D1 after
    # build_GAN sets D.trainable = False for the stacked model. The saved D1.json is
    # written after build_GAN, so a reloaded D1 arrives frozen and must be re-enabled.
    D1.trainable = True
    D1.compile(loss='binary_crossentropy',
               optimizer=keras.optimizers.Adam(learning_rate=0.0001, beta_1=0.5))
    GAN1 = build_GAN(G1, D1)
    GAN1.compile(loss='binary_crossentropy',
                 optimizer=keras.optimizers.Adam(learning_rate=0.0001, beta_1=0.5))
    return G1, D1, GAN1


def _save_preview_grid(generated_images, save_dir, step_indicator, height, width, n_preview=4):
    """Stack up to ``n_preview`` generated images vertically and save them as one PNG."""
    n_preview = min(n_preview, len(generated_images))
    display_grid = np.zeros((n_preview * height, width, 3))
    for j in range(n_preview):
        display_grid[j * height:(j + 1) * height, 0:width, :] = generated_images[j]
    # Map [-1, 1] back to [0, 255]. Channels are NOT flipped: the generator is trained on
    # RGB data (assuming imgs came from read_imgs, which converts BGR->RGB), and PIL expects
    # RGB. Clip because the linear output layer can leave [-1, 1]; uint8 casting would wrap.
    pixels = np.clip(display_grid * 127.5 + 127.5, 0, 255)
    img = image.array_to_img(pixels, scale=False)
    img.save(os.path.join(save_dir, 'low_generated_' + str(step_indicator) + '.png'))


def train(root_folder_name, imgs, iterations=10000, bs=8, is_load_weight=False, pre_low_step=0):
    """Train the 512x512 G1/D1 GAN on ``imgs`` and checkpoint under ./<root_folder_name>/.

    Every 100 batches of an epoch, the latest weights are saved to weight/latest/ and a
    preview PNG to result_image/. Whenever the global step is a multiple of 1000, the
    weights are also saved to weight/record/g1_<step>.h5 / d1_<step>.h5.

    Args:
        root_folder_name: run folder for model JSON, weights and preview images.
        imgs: training images. Assumed to be RGB floats in [0, 1] with shape
            (N, 512, 512, 3), as produced by read_imgs — TODO confirm for other callers.
        iterations: number of passes (epochs) over the data.
        bs: batch size.
        is_load_weight: if True, rebuild G1/D1 from the saved JSON and load the
            recorded weights for step ``pre_low_step``.
        pre_low_step: global batch step to resume from. It offsets the step counter
            used in checkpoint and preview filenames.

    Raises:
        ValueError: if ``bs`` is not positive or ``imgs`` holds fewer than ``bs`` images.
    """
    batch_size = bs
    if batch_size <= 0:
        raise ValueError('bs must be positive, got {}'.format(bs))
    batch_num = len(imgs) // batch_size
    if batch_num == 0:
        raise ValueError(
            'need at least bs={} images to form one batch, got {}'.format(bs, len(imgs)))
    imgs = np.asarray(imgs)

    latent_dim = 200
    height_1, width_1 = 512, 512

    G1, D1, GAN1 = _build_and_compile_models(
        root_folder_name, latent_dim, height_1, width_1, is_load_weight, pre_low_step)
    save_model(root_folder_name, {'G1': G1, 'D1': D1, 'GAN1': GAN1})

    save_dir = './{}/result_image/'.format(root_folder_name)
    weight_path = './{}/weight/latest/'.format(root_folder_name)
    weight_record_path = './{}/weight/record/'.format(root_folder_name)
    for folder in (save_dir, weight_path, weight_record_path):
        os.makedirs(folder, exist_ok=True)

    start_time = time.time()
    for step in range(iterations):
        # Shuffle indices rather than the image array: copying the full dataset every
        # epoch is very expensive. The trailing partial batch is dropped, as before.
        order = np.random.permutation(batch_num * batch_size)

        for low_step in range(batch_num):
            batch_idx = order[low_step * batch_size:(low_step + 1) * batch_size]
            real_images = (imgs[batch_idx] - 0.5) * 2  # [0, 1] -> [-1, 1]

            random_latent_vectors = np.random.normal(size=(batch_size, latent_dim))
            generated_images = G1.predict(random_latent_vectors, verbose=0)

            # One-sided label smoothing: real targets 0.9, fake targets 0.
            labels_real = 0.9 * np.ones((batch_size, 1))
            labels_fake = np.zeros((batch_size, 1))
            d_loss_real = D1.train_on_batch(real_images, labels_real)
            d_loss_fake = D1.train_on_batch(generated_images, labels_fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # Generator update: GAN1 is trained to make D1 label fresh fakes as real.
            random_latent_vectors = np.random.normal(size=(batch_size, latent_dim))
            misleading_targets = np.ones((batch_size, 1))
            g_loss = GAN1.train_on_batch(random_latent_vectors, misleading_targets)

            if low_step % 100 == 0:
                G1.save_weights(weight_path + 'g1.h5')
                D1.save_weights(weight_path + 'd1.h5')

                # Global batch counter across epochs (and across resumed runs).
                step_indicator = step * batch_num + low_step + pre_low_step

                if step_indicator % 1000 == 0:
                    G1.save_weights(weight_record_path + 'g1_{}.h5'.format(step_indicator))
                    D1.save_weights(weight_record_path + 'd1_{}.h5'.format(step_indicator))

                print('low resolution, discriminator loss at step %s: %s' % (step_indicator, d_loss))
                print('low resolution, adversarial loss at step %s: %s' % (step_indicator, g_loss))

                _save_preview_grid(generated_images, save_dir, step_indicator, height_1, width_1)

                print("--- %s seconds ---" % (time.time() - start_time))
                start_time = time.time()

In [None]:
if __name__ == '__main__':
    # Cap on how many images to load; None (or len(glob(FILE_PATH))) loads every match.
    imgs_length = 4000
    imgs = read_imgs(FILE_PATH, counts=imgs_length)

    # Fresh training run; all outputs go under ./<label>/.
    train(root_folder_name=label, imgs=imgs)

    # To resume from a recorded checkpoint instead, call e.g.:
    #   train(label, imgs, is_load_weight=True, pre_low_step=60000)