Data set pre-processing for SRGAN

In [12]:
import numpy as np
import tensorflow as tf
import tensorlayer as tfl

import matplotlib.pyplot as plt

n_epoch = 200
batch_size = 16

lr_init = 1e-4

hr_images_path = "Data\DIV2K_train_HR\DIV2K_train_HR"
hr_images_file = sorted(tfl.files.load_file_list(path=hr_images_path, regx='.*62.png', printable=False))
hr_images_train = tfl.visualize.read_images(hr_images_file, path=hr_images_path)

hr_images_ds = []
lr_images_ds = []

for img in hr_images_train:
    hr_patch = tf.random_crop(img, [384, 384, 3])
    hr_patch = tf.divide(hr_patch, 255) * 2 - 1
    hr_patch  = tf.image.random_flip_left_right(hr_patch)
    lr_patch  = tf.image.resize_images(hr_patch, size=[96, 96])
    hr_images_ds.append(hr_patch)
    lr_images_ds.append(lr_patch)

train_images_ds = tf.data.Dataset.from_tensor_slices((lr_images_ds, hr_images_ds))
train_images_ds = train_images_ds.repeat(300)
train_images_ds = train_images_ds.shuffle(128)
train_images_ds = train_images_ds.prefetch(2048)
train_images_ds = train_images_ds.batch(16)

print(train_images_ds)


[TL] read 8 from Data\DIV2K_train_HR\DIV2K_train_HR
<BatchDataset shapes: ((?, 96, 96, 3), (?, 384, 384, 3)), types: (tf.float32, tf.float32)>



Define the model 

In [15]:
#from tensorlayer.layers import (INPUT, CONV2D, BATCHNORM, ELEMENTWISE, SUBPIXELCONV2D, FLATTEN, DENSE)
#from tensorlayer.models import MODEL
import tensorflow as tf
import tensorlayer as tfl
from tensorlayer.layers import (InputLayer, Conv2d, BatchNormLayer, ElementwiseLayer, SubpixelConv2d, FlattenLayer, DenseLayer)
from tensorlayer.models import *

def SRGAN_GEN(input_shape):
    w_init = tf.random_normal_initializer(stddev=0.02)
    g_init = tf.random_normal_initializer(1., 0.02)
    
    n0 = InputLayer(input_shape)
    n = Conv2d(64, (3, 3), (1, 1), act=tf.nn.relu, padding='SAME', w_init=w_init)(n0)
    tmp = n
    
    # Residual blocks
    for i in range(16):
        nn = Conv2d(64, (3, 3), (1, 1), padding='SAME', W_init=w_init, b_init=None)(n)
        nn = BatchNormLayer(act=tf.nn.relu, gamma_init=g_init)(nn)
        nn = Conv2d(64, (3,3), (1, 1), padding='SAME', W_init=w_init, b_init=None)(nn)
        nn = BatchNormLayer(gamma_init=g_init)(n)
        nn = ElementwiseLayer(tf.add)([n, nn])
        n  = nn
    
    n = Conv2d(64, (3, 3), (1, 1), padding='SAME', W_init=w_init, b_init=None)(n)
    n = BatchNormLayer(gamma_init=g_init)(n)
    n = ElementwiseLayer(tf.add)([n, tmp])
    
    n = Conv2d(256, (3, 3), (1, 1), padding='SAME', W_init=w_init)(n)
    n = SubPixelConv2D(scale=2, n_out_channels=None, act=tf.nn.relu)(n)
    
    n = Conv2D(256, (3, 3), (1, 1), padding='SAME', W_init=w_init)(n)
    n = SubPixelConv2d(scale=2, n_out_channels=None, act=tf.nn.relu)(n)
    
    nn = Conv2d(3, (1, 1), (1, 1), act=tf.nn.tanh, padding='SAME', W_init=w_init)(n)
    
    GEN = Model(inputs=n0, outputs=nn, name='SRGAN_GENERATOR')
    
    return GEN


def SRGAN_DIS(input_shape):
    w_init = tf.random_normal_initializer(stddev=0.02)
    gamma_init = tf.random_normal_initializer(1., 0.02)
    
    n0 = InputLayer(input_shape)
    n = Conv2d(64, (4, 4), (2, 2), act=tfl.act.lrelu(0.2), padding='SAME', W_init=w_init)(n0)
    
    n = Conv2d(64 * 2, (4, 4), (2, 2), padding='SAME', W_init=w_init, b_init=None)(n)
    n = BatchNormLayer(act=lrelu, gamma_init=gamma_init)(n)
    n = Conv2d(64 * 4, (4, 4), (2, 2), padding='SAME', W_init=w_init, b_init=None)(n)
    n = BatchNormLayer(act=lrelu, gamma_init=gamma_init)(n)
    n = Conv2d(64 * 8, (4, 4), (2, 2), padding='SAME', W_init=w_init, b_init=None)(n)
    n = BatchNormLayer(act=lrelu, gamma_init=gamma_init)(n)
    n = Conv2d(64 * 16, (4, 4), (2, 2), padding='SAME', W_init=w_init, b_init=None)(n)
    n = BatchNormLayer(act=lrelu, gamma_init=gamma_init)(n)
    n = Conv2d(64 * 32, (4, 4), (2, 2), padding='SAME', W_init=w_init, b_init=None)(n)
    n = BatchNormLayer(act=lrelu, gamma_init=gamma_init)(n)
    n = Conv2d(64 * 16, (1, 1), (1, 1), padding='SAME', W_init=w_init, b_init=None)(n)
    n = BatchNormLayer(act=lrelu, gamma_init=gamma_init)(n)
    n = Conv2d(64 * 8, (1, 1), (1, 1), padding='SAME', W_init=w_init, b_init=None)(n)
    nn = BatchNormLayer(gamma_init=gamma_init)(n)
    
    n = Conv2d(64 * 2, (1, 1), (2, 2), padding='SAME', W_init=w_init, b_init=None)(nn)
    n = BatchNormLayer(act=lrelu, gamma_init=gamma_init)(n)
    n = Conv2d(64 * 2, (3, 3), (1, 1), padding='SAME', W_init=w_init, b_init=None)(n)
    n = BatchNormLayer(act=lrelu, gamma_init=gamma_init)(n)
    n = Conv2d(64 * 8, (3, 3), (1, 1), padding='SAME', W_init=w_init, b_init=None)(n)
    n = BatchNormLayer(gamma_init=gamma_init)(n)
    n = ElementwiseLayer(tf.add, act=lrelu)(n, nn)
    
    n = FlattenLayer()(n)
    no = DenseLayer(n_units=1, W_init=w_init)(n)
    DIS = Model(inputs=n0, outputs=no, name="SRGAN_DISCRIMINATOR")
    
    return DIS


GEN = SRGAN_GEN((batch_size, None, None, 3)) # (None, 96, 96, 3)
DIS = SRGAN_DIS((batch_size, None, None, 3)) # (None, 384, 384, 3)

VGG19 = tfl.models.vgg19(pretrained=True, end_with='pool4', mode='static')


AttributeError: 'tuple' object has no attribute 'all_layers'

Training pipeline

In [14]:


print(GEN)
print(DIS)
print(VGG19)

# optimizer
gen_opt = tf.optimizer.Adam(1e-4, beta_1=beta1)
dis_opt = tf.optimizer.Adam(1e-4, beta_1=beta1)


# Initial learning
n_step = round(n_epoch // batch_size)
for step, (lr_patchs, hr_patchs) in enumerate(train_images_ds):
    time_stamp = time.time()
    with tf.GradientTape() as grad:
        pred_hr_patch = GEN(lr_patch)
        mse_loss = tfl.cost.mean_squared_error(pred_hr_patch, hr_patch, is_mean=True)
    w_grad = grad.gradient(mse_loss, GEN.weights)
    g_opt.apply_gradients(zip(grad, GEN.weights))
    step += 1
    epoch = step // n_step
    print("Epoch: [{}/{}] step:[{}/{}] time: {}s, mse: {}".format(epoch, n_epoch, step, n_step, time.time() - time_stamp, mse_loss))


# Generate Adversarial network training (GEN, DIS)
n_step = round(n_epoch // batch_size)
for step, (lr_patchs, hr_patchs) in enumerate(train_images_ds):
    with tf.GradientTape as grad:
        pred_patchs = GEN(lr_patchs)
        pred_logits = DIS(pred_patchs)
        real_logits = DIS(hr_patchs)
        pred_feature = VGG19((pred_patchs+1)/2.)
        real_feature = VGG19((hr_patchs+1)/2.)
        d_loss1 = tfl.cost.sigmoid_cross_entropy(real_logits, tf.ones_like(real_logits))
        d_loss2 = tfl.cost.sigmoid_cross_entropy(pred_logits, tf.zeros_like(pred_logits))
        d_loss = d_loss1 + d_loss2
        g_gan_loss = 1e-3 * tfl.cost.sigmoid_cross_entropy(pred_logits, tf.ones_like(pred_logits))
        mse_loss = tfl.cost.mean_squared_error(pred_patchs, hr_patchs, is_mean=True)
        vgg_loss = 2e-6 * tfl.cost.mean_squared_error(pred_feature, real_feature, is_mean=True)
        g_loss = mse_loss + vgg_loss + g_gan_loss
        
    gen_grad = grad.gradient(g_loss, GEN.trainable_weights)
    gen_opt.apply_gradients(zip(gen_grad, GEN.trainable_weights))
    dis_grad = grad.gradient(d_loss, DIS.trainable_weights)
    dis_optimizer.apply_gradients(zip(grad, DIS.trainable_weights))
    step += 1
    epoch = step//n_step
    print("Epoch: [{}/{}] step: [{}/{}] time: {}s, g_loss(mse:{}, vgg:{}, adv:{}) d_loss: {}".format(
           epoch, n_epoch_init, step, n_step_epoch, time.time() - step_time, mse_loss, vgg_loss, g_gan_loss, d_loss))

    
    # Update the learning rate
    decay_step = n_epoch // 2
    if epoch != 0 and (epoch % decay_step == 0):
        new_lr_decay = (0.1)**(epoch // decay_step)
        lr_v.assign(1e-4 * new_lr_decay)
        log = " ** new learning rate: %f (for GAN)" % (1e-4 * new_lr_decay)
        print(log)
    
    
    
    
    

AttributeError: 'tuple' object has no attribute 'all_layers'