In [1]:
# python train_model.py model={iphone,sony,blackberry} dped_dir=dped vgg_dir=vgg_pretrained/imagenet-vgg-verydeep-19.mat

import tensorflow as tf
from scipy import misc
import numpy as np
import sys

from load_dataset import load_test_data, load_batch
from ssim import MultiScaleSSIM
import models
import utils
import vgg

from tensorflow.python.keras.applications import VGG16
from tensorflow.python.keras.layers import Dropout, Dense
from tensorflow.python.keras.models import Model
from tensorflow.python.keras import backend as K

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
#config.gpu_options.per_process_gpu_memory_fraction = 0.3
#set_session(tf.Session(config=config))
session = tf.Session(config=config)

def mean_score(dis):
    score_range = tf.range(0, 10, 1, dtype=tf.float32)
    score = tf.matmul(dis, tf.reshape(score_range, [-1, 1]))
    return score
def nima_preprocess_input(img):
    img = img / 127.5
    img = img - 1.0
    return img

# defining size of the training image patches

PATCH_WIDTH = 100
PATCH_HEIGHT = 100
PATCH_SIZE = PATCH_WIDTH * PATCH_HEIGHT * 3

# processing command arguments

batch_size = 20
train_size = 30000
learning_rate = 5e-4
num_train_iters = 20000

w_content = 100.0
w_color = 0.5
w_texture = 1
w_tv = 2000
w_nima = 10

dped_dir = '/home/public/hw/dataset/dped/dped/'
vgg_dir = 'vgg_pretrained/imagenet-vgg-verydeep-19.mat'
eval_step = 1000

phone = "my"

np.random.seed(0)

# loading training and test data

print("Loading test data...")
test_data, test_answ = load_test_data(phone, dped_dir, PATCH_SIZE)
print("Test data was loaded\n")

print("Loading training data...")
train_data, train_answ = load_batch(phone, dped_dir, train_size, PATCH_SIZE)
print("Training data was loaded\n")

TEST_SIZE = test_data.shape[0]
num_test_batches = int(test_data.shape[0]/batch_size)

  from ._conv import register_converters as _register_converters


Loading test data...
Test data was loaded

Loading training data...
Training data was loaded



In [2]:
# defining system architecture

with tf.Graph().as_default(), tf.Session(config=config) as sess:
    
    # placeholders for training data

    phone_ = tf.placeholder(tf.float32, [None, PATCH_SIZE])
    phone_image = tf.reshape(phone_, [-1, PATCH_HEIGHT, PATCH_WIDTH, 3])

    dslr_ = tf.placeholder(tf.float32, [None, PATCH_SIZE])
    dslr_image = tf.reshape(dslr_, [-1, PATCH_HEIGHT, PATCH_WIDTH, 3])

    adv_ = tf.placeholder(tf.float32, [None, 1])

    # get processed enhanced image

    enhanced = models.resnet(phone_image)

    # transform both dslr and enhanced images to grayscale

    enhanced_gray = tf.reshape(tf.image.rgb_to_grayscale(enhanced), [-1, PATCH_WIDTH * PATCH_HEIGHT])
    dslr_gray = tf.reshape(tf.image.rgb_to_grayscale(dslr_image),[-1, PATCH_WIDTH * PATCH_HEIGHT])

    # push randomly the enhanced or dslr image to an adversarial CNN-discriminator

    adversarial_ = tf.multiply(enhanced_gray, 1 - adv_) + tf.multiply(dslr_gray, adv_)
    adversarial_image = tf.reshape(adversarial_, [-1, PATCH_HEIGHT, PATCH_WIDTH, 1])

    discrim_predictions = models.adversarial(adversarial_image)

    # losses
    # 1) texture (adversarial) loss

    discrim_target = tf.concat([adv_, 1 - adv_], 1)

    loss_discrim = -tf.reduce_sum(discrim_target * tf.log(tf.clip_by_value(discrim_predictions, 1e-10, 1.0)))
    loss_texture = -loss_discrim

    correct_predictions = tf.equal(tf.argmax(discrim_predictions, 1), tf.argmax(discrim_target, 1))
    discim_accuracy = tf.reduce_mean(tf.cast(correct_predictions, tf.float32))

    # 2) content loss

    CONTENT_LAYER = 'pool3'

    enhanced_vgg = vgg.net(vgg_dir, vgg.preprocess(enhanced * 255))
    dslr_vgg = vgg.net(vgg_dir, vgg.preprocess(dslr_image * 255))

    content_size = utils._tensor_size(dslr_vgg[CONTENT_LAYER]) * batch_size
    loss_content = 2 * tf.nn.l2_loss(enhanced_vgg[CONTENT_LAYER]/content_size - dslr_vgg[CONTENT_LAYER]/content_size) 

    # 3) color loss

    enhanced_blur = utils.blur(enhanced)
    dslr_blur = utils.blur(dslr_image)

    loss_color = tf.reduce_sum(tf.pow(dslr_blur - enhanced_blur, 2))/(2 * batch_size)

    # 4) total variation loss

    batch_shape = (batch_size, PATCH_WIDTH, PATCH_HEIGHT, 3)
    tv_y_size = utils._tensor_size(enhanced[:,1:,:,:])
    tv_x_size = utils._tensor_size(enhanced[:,:,1:,:])
    y_tv = tf.nn.l2_loss(enhanced[:,1:,:,:] - enhanced[:,:batch_shape[1]-1,:,:])
    x_tv = tf.nn.l2_loss(enhanced[:,:,1:,:] - enhanced[:,:,:batch_shape[2]-1,:])
    loss_tv = 2 * (x_tv/tv_x_size + y_tv/tv_y_size) / batch_size
    
    # 5) nima loss
    base_model = VGG16(input_shape=(None, None, 3), include_top=False, pooling='avg', weights=None)
    x = Dropout(0.75)(base_model.output)
    x = Dense(10, activation='softmax')(x)
    nima_model = Model(base_model.input, x)
    nima_model.trainable = False
    for layer in nima_model.layers:
        layer.trainable = False
    nima_outputs = nima_model(nima_preprocess_input(enhanced * 255.))
    nima_score = tf.reduce_mean(mean_score(nima_outputs))
    loss_nima = 9.0 - nima_score

    # final loss

    loss_generator = w_content * loss_content + w_texture * loss_texture + w_color * loss_color + w_tv * loss_tv + w_nima*loss_nima

    # psnr loss

    enhanced_flat = tf.reshape(enhanced, [-1, PATCH_SIZE])

    loss_mse = tf.reduce_sum(tf.pow(dslr_ - enhanced_flat, 2))/(PATCH_SIZE * batch_size)
    loss_psnr = 20 * utils.log10(1.0 / tf.sqrt(loss_mse))

    # optimize parameters of image enhancement (generator) and discriminator networks

    generator_vars = [v for v in tf.global_variables() if v.name.startswith("generator")]
    discriminator_vars = [v for v in tf.global_variables() if v.name.startswith("discriminator")]

    train_step_gen = tf.train.AdamOptimizer(learning_rate).minimize(loss_generator, var_list=generator_vars)
    train_step_disc = tf.train.AdamOptimizer(learning_rate).minimize(loss_discrim, var_list=discriminator_vars)

    saver = tf.train.Saver(var_list=generator_vars, max_to_keep=100)

    print('Initializing variables')
    sess.run(tf.global_variables_initializer())
    print('Loading weights of nima')
    nima_model.load_weights('../myproject/NIMA_weights/vgg16_weights.h5')

    print('Training network')

    train_loss_gen = 0.0
    train_acc_discrim = 0.0

    all_zeros = np.reshape(np.zeros((batch_size, 1)), [batch_size, 1])
    test_index = np.random.randint(0, TEST_SIZE, 5);
    test_crops = test_data[test_index, :]
    test_dslr_crops = test_answ[test_index, :]

    logs = open('mymodels/' + phone + '.txt', "w+")
    logs.write('w_content:{}, w_texture:{}, w_color:{}, w_tv:{}, w_nima:{}'.format(w_content, w_texture, w_color, w_tv, w_nima))
    logs.close()

    for i in range(num_train_iters):

        # train generator

        idx_train = np.random.randint(0, train_size, batch_size)

        phone_images = train_data[idx_train]
        dslr_images = train_answ[idx_train]

        [loss_temp, temp] = sess.run([loss_generator, train_step_gen],
                                        feed_dict={phone_: phone_images, dslr_: dslr_images, adv_: all_zeros, K.learning_phase():0})
        train_loss_gen += loss_temp / eval_step

        # train discriminator

        idx_train = np.random.randint(0, train_size, batch_size)

        # generate image swaps (dslr or enhanced) for discriminator
        swaps = np.reshape(np.random.randint(0, 2, batch_size), [batch_size, 1])

        phone_images = train_data[idx_train]
        dslr_images = train_answ[idx_train]

        [accuracy_temp, temp] = sess.run([discim_accuracy, train_step_disc],
                                        feed_dict={phone_: phone_images, dslr_: dslr_images, adv_: swaps, K.learning_phase():0})
        train_acc_discrim += accuracy_temp / eval_step

        if i % eval_step == 0:

            # test generator and discriminator CNNs

            test_losses_gen = np.zeros((1, 7))
            test_accuracy_disc = 0.0
            loss_ssim = 0.0
            nima_mean_scores = 0.0

            for j in range(num_test_batches):

                be = j * batch_size
                en = (j+1) * batch_size

                swaps = np.reshape(np.random.randint(0, 2, batch_size), [batch_size, 1])

                phone_images = test_data[be:en]
                dslr_images = test_answ[be:en]

                [enhanced_crops, accuracy_disc, losses, nima_mean] = sess.run([enhanced, discim_accuracy, \
                                [loss_generator, loss_content, loss_color, loss_texture, loss_tv, loss_psnr, loss_nima], nima_score], \
                                feed_dict={phone_: phone_images, dslr_: dslr_images, adv_: swaps, K.learning_phase():0})

                test_losses_gen += np.asarray(losses) / num_test_batches
                test_accuracy_disc += accuracy_disc / num_test_batches
                nima_mean_scores += nima_mean / num_test_batches
                
                loss_ssim += MultiScaleSSIM(np.reshape(dslr_images * 255, [batch_size, PATCH_HEIGHT, PATCH_WIDTH, 3]),
                                                    enhanced_crops * 255) / num_test_batches

            logs_disc = "step %d, %s | discriminator accuracy | train: %.4g, test: %.4g" % \
                  (i, phone, train_acc_discrim, test_accuracy_disc)

            logs_gen = "generator losses | train: %.4g, test: %.4g | content: %.4g, color: %.4g, texture: %.4g, tv: %.4g | \
                        psnr: %.4g, nima: %.4g ssim: %.4g\n nima_mean: %.4g" % \
                  (train_loss_gen, test_losses_gen[0][0], test_losses_gen[0][1], test_losses_gen[0][2],
                   test_losses_gen[0][3], test_losses_gen[0][4], test_losses_gen[0][5], test_losses_gen[0][6],loss_ssim, nima_mean_scores)

            print(logs_disc)
            print(logs_gen)

            # save the results to log file

            logs = open('mymodels/' + phone + '.txt', "a")
            logs.write(logs_disc)
            logs.write('\n')
            logs.write(logs_gen)
            logs.write('\n')
            logs.close()

            # save visual results for several test image crops

            enhanced_crops = sess.run(enhanced, feed_dict={phone_: test_crops, dslr_: dslr_images, adv_: all_zeros, K.learning_phase():0})

            idx = 0
            for crop in enhanced_crops:
                before_after = np.hstack((np.reshape(test_crops[idx], [PATCH_HEIGHT, PATCH_WIDTH, 3]), crop, 
                                          np.reshape(test_dslr_crops[idx], [PATCH_HEIGHT, PATCH_WIDTH, 3])))
                misc.imsave('myresults/' + str(phone)+ "_" + str(idx) + '_iteration_' + str(i) + '.jpg', before_after)
                idx += 1

            train_loss_gen = 0.0
            train_acc_discrim = 0.0

            # save the model that corresponds to the current iteration

            saver.save(sess, 'mymodels/' + str(phone) + '_iteration_' + str(i) + '.ckpt', write_meta_graph=False)

            # reload a different batch of training data

            del train_data
            del train_answ
            train_data, train_answ = load_batch(phone, dped_dir, train_size, PATCH_SIZE)

Initializing variables
Loading weights of nima
Training network
step 0, my | discriminator accuracy | train: 0.00035, test: 0.4923
generator losses | train: 0.5818, test: 553.2 | content: 1.509, color: 796.3, texture: -52.91, tv: 0.0001055 |                         psnr: 11.25, nima: 5.689 ssim: 0.3962
 nima_mean: 3.311


`imsave` is deprecated in SciPy 1.0.0, and will be removed in 1.2.0.
Use ``imageio.imwrite`` instead.


step 1000, my | discriminator accuracy | train: 0.5013, test: 0.5052
generator losses | train: 109.3, test: 74.28 | content: 0.2128, color: 29.69, texture: -14.55, tv: 0.004962 |                         psnr: 22.98, nima: 4.278 ssim: 0.937
 nima_mean: 4.722
step 2000, my | discriminator accuracy | train: 0.5004, test: 0.5003
generator losses | train: 70.63, test: 65.54 | content: 0.1796, color: 22.64, texture: -14.31, tv: 0.004178 |                         psnr: 23.47, nima: 4.221 ssim: 0.9396
 nima_mean: 4.779
step 3000, my | discriminator accuracy | train: 0.4977, test: 0.4984
generator losses | train: 64.79, test: 69.76 | content: 0.1715, color: 31.37, texture: -14.02, tv: 0.004493 |                         psnr: 23.03, nima: 4.197 ssim: 0.9412
 nima_mean: 4.803
step 4000, my | discriminator accuracy | train: 0.4964, test: 0.4988
generator losses | train: 63.19, test: 60.15 | content: 0.1472, color: 18.32, texture: -14.79, tv: 0.005263 |                         psnr: 24.1, nima: 4.0