In [1]:
from Network import Generator, Discriminator
import Utils_model, Utils
from Utils_model import VGG_LOSS

from keras.models import Model
from keras.layers import Input
from tqdm import tqdm, tqdm_notebook
import numpy as np
import argparse

Using TensorFlow backend.


In [2]:
np.random.seed(10)
# Better to use downscale factor as 4
downscale_factor = 3.75
# Remember to change image shape if you are having different size of images
image_shape = (1920,1080,3)

### GAN

In [3]:
def get_gan_network(discriminator, shape, generator, optimizer, vgg_loss):
    discriminator.trainable = False
    gan_input = Input(shape=shape)
    x = generator(gan_input)
    gan_output = discriminator(x)
    gan = Model(inputs=gan_input, outputs=[x,gan_output])
    gan.compile(loss=[vgg_loss, "binary_crossentropy"],
                loss_weights=[1., 1e-3],
                optimizer=optimizer)

    return gan

### Preparing training

In [4]:
epochs = 1
batch_size = 1
input_dir = '../images/photo_fullhd'
output_dir = 'output'
model_save_dir = 'trained_model'
number_of_images = 320
train_test_ratio = 0.8

In [5]:
x_train_lr, x_train_hr, x_test_lr, x_test_hr = Utils.load_training_data(input_dir, '.jpg', number_of_images, train_test_ratio, downscale_factor) 
loss = VGG_LOSS(image_shape)  

Loading files::   9%|▉         | 318/3518 [00:13<02:28, 21.61it/s]
Converting to low-res:   0%|          | 0/256 [00:00<?, ?it/s][A
Converting to low-res:   2%|▏         | 6/256 [00:00<00:04, 57.31it/s][A
Converting to low-res:   5%|▌         | 13/256 [00:00<00:04, 59.53it/s][A
Converting to low-res:   8%|▊         | 20/256 [00:00<00:03, 61.41it/s][A
Converting to low-res:  11%|█         | 27/256 [00:00<00:03, 62.14it/s][A
Converting to low-res:  13%|█▎        | 34/256 [00:00<00:03, 62.55it/s][A
Converting to low-res:  16%|█▌        | 41/256 [00:00<00:03, 63.21it/s][A
Converting to low-res:  19%|█▉        | 48/256 [00:00<00:03, 63.70it/s][A
Converting to low-res:  21%|██▏       | 55/256 [00:00<00:03, 63.64it/s][A
Converting to low-res:  24%|██▍       | 62/256 [00:00<00:03, 63.42it/s][A
Converting to low-res:  27%|██▋       | 69/256 [00:01<00:02, 63.68it/s][A
Converting to low-res:  30%|██▉       | 76/256 [00:01<00:02, 63.62it/s][A
Converting to low-res:  32%|███▏      | 83/

In [6]:
batch_count = int(x_train_hr.shape[0] / batch_size)
shape = (int(image_shape[0]//downscale_factor), int(image_shape[1]//downscale_factor), image_shape[2])

In [7]:
generator = Generator(shape).generator()
discriminator = Discriminator(image_shape).discriminator()

optimizer = Utils_model.get_optimizer()
generator.compile(loss=loss.vgg_loss, optimizer=optimizer)
discriminator.compile(loss="binary_crossentropy", optimizer=optimizer)
    
gan = get_gan_network(discriminator, shape, generator, optimizer, loss.vgg_loss)

W0915 20:33:53.562437 4569007552 deprecation_wrapper.py:119] From /Users/kjedrzejewski/miniconda3/envs/gan_upscaling/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:74: The name tf.get_default_graph is deprecated. Please use tf.compat.v1.get_default_graph instead.

W0915 20:33:53.591164 4569007552 deprecation_wrapper.py:119] From /Users/kjedrzejewski/miniconda3/envs/gan_upscaling/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:517: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.

W0915 20:33:53.593585 4569007552 deprecation_wrapper.py:119] From /Users/kjedrzejewski/miniconda3/envs/gan_upscaling/lib/python3.7/site-packages/keras/backend/tensorflow_backend.py:4138: The name tf.random_uniform is deprecated. Please use tf.random.uniform instead.

W0915 20:33:53.630385 4569007552 deprecation_wrapper.py:119] From /Users/kjedrzejewski/miniconda3/envs/gan_upscaling/lib/python3.7/site-packages/keras/backend/tensorflow_backend.p

In [8]:
loss_file = open(model_save_dir + '/losses.txt' , 'w+')
loss_file.close()

### Training

In [None]:
for e in range(1, epochs+1):
    print ('-'*15, 'Epoch %d' % e, '-'*15)
    for _ in tqdm(range(batch_count)):
        
        rand_nums = np.random.randint(0, x_train_hr.shape[0], size=batch_size)
        
        image_batch_hr = x_train_hr[rand_nums]
        image_batch_lr = x_train_lr[rand_nums]
        #generated_images_sr = generator.predict(image_batch_lr)


  0%|          | 0/256 [00:00<?, ?it/s][A

--------------- Epoch 1 ---------------


In [None]:
for e in range(1, epochs+1):
    print ('-'*15, 'Epoch %d' % e, '-'*15)
    for _ in tqdm(range(batch_count)):

        rand_nums = np.random.randint(0, x_train_hr.shape[0], size=batch_size)

        image_batch_hr = x_train_hr[rand_nums]
        image_batch_lr = x_train_lr[rand_nums]
        generated_images_sr = generator.predict(image_batch_lr)

        real_data_Y = np.ones(batch_size) - np.random.random_sample(batch_size)*0.2
        fake_data_Y = np.random.random_sample(batch_size)*0.2

        discriminator.trainable = True

        d_loss_real = discriminator.train_on_batch(image_batch_hr, real_data_Y)
        d_loss_fake = discriminator.train_on_batch(generated_images_sr, fake_data_Y)
        discriminator_loss = 0.5 * np.add(d_loss_fake, d_loss_real)

        rand_nums = np.random.randint(0, x_train_hr.shape[0], size=batch_size)
        image_batch_hr = x_train_hr[rand_nums]
        image_batch_lr = x_train_lr[rand_nums]

        gan_Y = np.ones(batch_size) - np.random.random_sample(batch_size)*0.2
        discriminator.trainable = False
        gan_loss = gan.train_on_batch(image_batch_lr, [image_batch_hr,gan_Y])


    print("discriminator_loss : %f" % discriminator_loss)
    print("gan_loss :", gan_loss)
    gan_loss = str(gan_loss)

    loss_file = open(model_save_dir + 'losses.txt' , 'a')
    loss_file.write('epoch%d : gan_loss = %s ; discriminator_loss = %f\n' %(e, gan_loss, discriminator_loss) )
    loss_file.close()

    if e == 1 or e % 5 == 0:
        Utils.plot_generated_images(output_dir, e, generator, x_test_hr, x_test_lr)
    if e % 500 == 0:
        generator.save(model_save_dir + 'gen_model%d.h5' % e)
        discriminator.save(model_save_dir + 'dis_model%d.h5' % e)

HBox(children=(IntProgress(value=0, max=1), HTML(value='')))



  0%|          | 0/4 [00:00<?, ?it/s][A[A

--------------- Epoch 1 ---------------
