In [3]:
import tensorflow as tf
import os
import numpy as np

## SRGAN model
<img src="model.PNG">

In [5]:
def PReLU(_x):
    """Parametric ReLU: identity for positive inputs, learnable slope for negative ones.

    Creates (or, under a reusing variable scope, fetches) a variable named
    "alpha" in the current variable scope, with one entry per element of the
    last axis of `_x`, initialised to 0.

    Args:
        _x: input tensor; its last dimension must be statically known.

    Returns:
        relu(_x) + alpha * min(_x, 0), with the same shape as `_x`.
    """
    # alpha must be a float variable: an int32 slope cannot be multiplied with
    # float activations and could not be trained.
    alpha = tf.get_variable("alpha", _x.get_shape()[-1],
                            initializer = tf.constant_initializer(0.0), dtype = tf.float32)
    # (_x - |_x|) * 0.5 equals min(_x, 0): the negative part of the input.
    return tf.nn.relu(_x) + alpha * (_x - tf.abs(_x)) * 0.5

# pixel shuffle

$$S_s : [0,1]^{sH \times sW \times C} \rightarrow [0,1]^{H \times W \times s^2 C}$$
<br></br>
<br></br>
$$S_s(I)_{i,j,k} = I_{\,si + (k \bmod s),\; sj + (\lfloor k/s \rfloor \bmod s),\; \lfloor k/s^2 \rfloor}$$

In [11]:
def phaseShift(inputs, scale, shape_1, shape_2):
    """Reshape -> swap axes 2 and 3 -> reshape.

    Args:
        inputs: tensor to rearrange.
        scale: not used by this function; kept as part of the call signature.
        shape_1: 5-D shape the input is first reshaped to.
        shape_2: shape of the returned tensor.

    Returns:
        `inputs` reshaped to `shape_1`, transposed with permutation
        [0, 1, 3, 2, 4], then reshaped to `shape_2`.
    """
    blocks = tf.reshape(inputs, shape_1)
    interleaved = tf.transpose(blocks, perm = [0, 1, 3, 2, 4])
    return tf.reshape(interleaved, shape_2)

In [12]:
def pixelShuffler(inputs, scale = 2):
    """Trade channels for spatial resolution (periodic shuffling).

    The input is split along axis 3 into groups of scale**2 channels; each
    group is rearranged by `phaseShift` into a single channel whose axes 1 and
    2 are `scale` times larger, and the results are concatenated on axis 3.

    Args:
        inputs: 4-D tensor whose last dimension is statically known and
            divisible by scale**2.
        scale: integer upscaling factor.

    Returns:
        Tensor of shape [batch, h * scale, w * scale, c // scale**2].

    Raises:
        ValueError: if the channel count is unknown or not divisible by
            scale**2.
    """
    static_shape = inputs.get_shape().as_list()
    dynamic_shape = tf.shape(inputs)

    c = static_shape[-1]
    if c is None or c % (scale ** 2) != 0:
        raise ValueError(
            "pixelShuffler needs a known channel count divisible by scale**2; "
            "got channels=%r, scale=%r" % (c, scale))

    batch_size = dynamic_shape[0]
    # Prefer static sizes when they are known so the result keeps a static
    # height/width (layers such as flatten + dense need them downstream);
    # fall back to the dynamic shape otherwise.
    h = static_shape[1] if static_shape[1] is not None else dynamic_shape[1]
    w = static_shape[2] if static_shape[2] is not None else dynamic_shape[2]

    # Number of output channels; each one is built from scale**2 input channels.
    channel_target = c // (scale ** 2)

    shape_1 = [batch_size, h, w, scale, scale]
    shape_2 = [batch_size, h * scale, w * scale, 1]

    # Reshape and transpose for periodic shuffling, one output channel at a time.
    input_split = tf.split(inputs, channel_target, axis = 3)
    output = tf.concat([phaseShift(x, scale, shape_1, shape_2) for x in input_split], axis = 3)

    return output

In [8]:
def B_residual_block(inputs, output_dim, k_size, is_training, scope = 'G_b_res_block'):
    """Residual block: conv -> batch norm -> PReLU -> conv, plus the block input.

    Args:
        inputs: 4-D input tensor. Because it is added to the block output, its
            channel count must equal `output_dim`.
        output_dim: number of filters of both convolutions.
        k_size: side length of the square convolution kernels.
        is_training: forwarded to `tf.contrib.layers.batch_norm`.
        scope: variable scope that holds the block's variables.

    Returns:
        conv2 + inputs.
    """
    with tf.variable_scope(scope):
        # The filter's input depth comes from the block input itself.
        w1 = tf.get_variable('w1', [k_size, k_size, inputs.get_shape()[-1], output_dim],
                            initializer = tf.truncated_normal_initializer(stddev = 0.2))
        # tf.nn.conv2d only accepts the upper-case padding names 'SAME' / 'VALID'.
        conv1 = tf.nn.conv2d(inputs, w1, strides = [1,1,1,1], padding = 'SAME')
        b1 = tf.get_variable('b1', [output_dim], initializer = tf.constant_initializer(0.0))
        conv1 = tf.nn.bias_add(conv1, b1)

        bn = tf.contrib.layers.batch_norm(conv1, is_training = is_training, scope = 'bn',
                                           decay = 0.9, zero_debias_moving_mean = True)
        prelu = PReLU(bn)

        w2 = tf.get_variable('w2', [k_size, k_size, output_dim, output_dim],
                            initializer = tf.truncated_normal_initializer(stddev = 0.2))
        conv2 = tf.nn.conv2d(prelu, w2, strides = [1,1,1,1], padding = 'SAME')
        b2 = tf.get_variable('b2', [output_dim], initializer = tf.constant_initializer(0.0))
        conv2 = tf.nn.bias_add(conv2, b2)

        # Skip connection around the two convolutions.
        return conv2 + inputs

In [15]:
def last_block(inputs, output_dim, k_size, scope = 'G_last_block'):
    """Upsampling block: conv -> pixel shuffle (default scale 2) -> PReLU.

    Args:
        inputs: 4-D input tensor.
        output_dim: number of convolution filters. The pixel shuffler is called
            with its default scale of 2, so the block output has
            output_dim // 4 channels.
        k_size: side length of the square convolution kernel.
        scope: variable scope that holds the block's variables.

    Returns:
        The activated, pixel-shuffled feature map.
    """
    with tf.variable_scope(scope):
        w = tf.get_variable('w', [k_size, k_size, inputs.get_shape()[-1], output_dim],
                           initializer = tf.truncated_normal_initializer(stddev = 0.2))
        # tf.nn.conv2d only accepts the upper-case padding names 'SAME' / 'VALID'.
        conv = tf.nn.conv2d(inputs, w, strides = [1,1,1,1], padding = 'SAME')
        b = tf.get_variable('b', [output_dim], initializer = tf.constant_initializer(0.0))
        conv = tf.nn.bias_add(conv, b)
        px = pixelShuffler(conv)
        return PReLU(px)

In [18]:
def vgg_block(inputs, output_dims, k_size, s, is_training, scope = 'D_disc_block'):
    """Discriminator block: strided conv -> bias -> leaky ReLU (alpha = 0.2).

    Args:
        inputs: 4-D input tensor.
        output_dims: number of convolution filters.
        k_size: side length of the square convolution kernel.
        s: stride applied along axes 1 and 2.
        is_training: currently unused; kept so existing calls stay valid.
        scope: variable scope that holds the block's variables.

    Returns:
        The activated feature map.
    """
    with tf.variable_scope(scope):
        w = tf.get_variable('w', [k_size, k_size, inputs.get_shape()[-1], output_dims],
                           initializer = tf.truncated_normal_initializer(stddev = 0.2))
        # Stride only the two middle axes; the first and last stride must stay 1.
        conv = tf.nn.conv2d(inputs, w, strides = [1, s, s, 1], padding = 'SAME')
        b = tf.get_variable('b', [output_dims], initializer = tf.constant_initializer(0.0))
        conv = tf.nn.bias_add(conv, b)
        return tf.nn.leaky_relu(conv, alpha=0.2)

# Loss function
<img src="total_loss.PNG"></img>
<br></br>
## Content loss
* MSE based loss
<img src="mse_loss.PNG"></img>
    - 가장 널리 쓰이지만 smooth한 영역에 대해서는 성능이 좋지 않다.
    
* VGG based loss
<img src="vgg_loss.PNG"></img>
    - smooth한 영역에 대해서 성능을 개선시킬 수 있는 loss function
    
## Adversarial loss
<img src="advers_loss.PNG"></img>



In [17]:
class SRGAN:
    def __init__(self, sess, checkpoint_dir, log_dir, img_dir, low_resolution_size, r, channel, 
                 feature_root = 64, batch_size = 1, lr = 0.0002, beta1 = 0.5, beta2 = 0.999,
                 dropout = 0.5, loss_type = 'vgg'):
        self.sess = sess
        self.checkpoint_dir = checkpoint_dir
        self.log_dir = log_dir
        self.img_dir = img_dir
        self.low_resolution_size = low_resolution_size
        self.real_image_size = low_resolution_size * r
        self.r = r
        self.channel = channel
        self.feature_root = 64
        self.batch_size = batch_size
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.dropout = dropout
        self.pool_k_size = 2
        self.pool_s_size = 2
        self.loss_type = loss_type
        # saver 저장
        
    def generator(self, images, is_training = True):
        """Build the generator graph.

        Layout: 9x9 conv + PReLU -> 5 residual blocks -> 3x3 conv + batch norm
        added to the first activation -> two upsampling blocks -> 3x3 conv
        producing `self.channel` channels.

        Each upsampling block pixel-shuffles with scale 2, so the output is 4x
        larger than the input along axes 1 and 2 regardless of `self.r`.

        Args:
            images: 4-D input tensor with `self.channel` channels.
            is_training: forwarded to every batch-norm layer.

        Returns:
            The generated image tensor with `self.channel` channels.
        """
        # conv -> PReLU
        with tf.variable_scope('G_start'):
            w = tf.get_variable('w', [9, 9, self.channel, self.feature_root],
                               initializer = tf.truncated_normal_initializer(stddev = 0.2))
            # tf.nn.conv2d only accepts the upper-case padding names.
            conv = tf.nn.conv2d(images, w, strides = [1,1,1,1], padding = 'SAME')
            b = tf.get_variable('b', [self.feature_root], 
                               initializer = tf.constant_initializer(0.0))
            conv = tf.nn.bias_add(conv, b)
            prelu = PReLU(conv)

        # Kept for the long skip connection around all residual blocks.
        skip_late = prelu
        skip = prelu

        # 5 B residual blocks
        for i in range(5):
            skip = B_residual_block(skip, self.feature_root, 3, is_training = is_training, 
                                    scope = 'G_b_res_block' + str(i))

        with tf.variable_scope('G_middle'):
            w = tf.get_variable('w', [3, 3, self.feature_root, self.feature_root],
                               initializer = tf.truncated_normal_initializer(stddev = 0.2))
            conv = tf.nn.conv2d(skip, w, strides = [1,1,1,1], padding = 'SAME')
            b = tf.get_variable('b', [self.feature_root], 
                               initializer = tf.constant_initializer(0.0))
            conv = tf.nn.bias_add(conv, b)
            bn = tf.contrib.layers.batch_norm(conv, is_training = is_training, scope = 'bn', 
                                           decay = 0.9, zero_debias_moving_mean = True)
            es = bn + skip_late

        # feature_root * 4 filters shrink back to feature_root channels after
        # the scale-2 pixel shuffle inside each block.
        last = last_block(es, self.feature_root*4, 3, scope = 'G_last_block_0')
        last = last_block(last, self.feature_root*4, 3, scope = 'G_last_block_1')

        with tf.variable_scope('G_last'):
            # Project back to image space: the output must have as many
            # channels as the real images it is compared against.
            w = tf.get_variable('w', [3, 3, self.feature_root, self.channel],
                               initializer = tf.truncated_normal_initializer(stddev = 0.2))
            conv = tf.nn.conv2d(last, w, strides = [1,1,1,1], padding = 'SAME')
            b = tf.get_variable('b', [self.channel], 
                               initializer = tf.constant_initializer(0.0))
            conv = tf.nn.bias_add(conv, b)

        return conv
    
    def discriminator(self, images, is_training = True):
        """Build the discriminator graph, sharing variables across calls.

        Layout: 3x3 conv + leaky ReLU -> 7 conv blocks whose stride alternates
        2, 1, 2, ... (every even-indexed block is additionally max-pooled and
        followed by a doubling of the filter count) -> dense(1024) + leaky
        ReLU -> dense(1) + sigmoid.

        Args:
            images: 4-D image tensor with `self.channel` channels and
                statically known spatial dimensions (required by the dense
                layers).
            is_training: forwarded to `vgg_block`.

        Returns:
            (probability, pools): the sigmoid output and the list of feature
            maps captured before each max-pool.
        """
        # AUTO_REUSE: the first call creates the D_* variables, later calls
        # (e.g. real images, then generated images) fetch the same ones.
        with tf.variable_scope(tf.get_variable_scope(), reuse = tf.AUTO_REUSE):
            # Conv -> LeakyReLU
            with tf.variable_scope('D_start'):
                w = tf.get_variable('w', [3, 3, self.channel, self.feature_root],
                                   initializer = tf.truncated_normal_initializer(stddev = 0.2))
                conv = tf.nn.conv2d(images, w, strides = [1,1,1,1], padding = 'SAME')
                b = tf.get_variable('b', [self.feature_root], 
                                   initializer = tf.constant_initializer(0.0))
                conv = tf.nn.bias_add(conv, b)
                output = tf.nn.leaky_relu(conv, alpha = 0.2)

            features = self.feature_root
            pools = []
            for i in range(7):
                stride = (i + 1) % 2 + 1   # 2 for even i, 1 for odd i
                output = vgg_block(output, features, 3, stride, is_training,
                                   scope = 'D_disc_block' + str(i))
                if i % 2 == 0:
                    pools.append(output)
                    output = tf.nn.max_pool(output, ksize = [1, self.pool_k_size, self.pool_k_size, 1],
                            strides = [1, self.pool_s_size, self.pool_s_size, 1], padding = 'SAME')
                    features = features * 2

            with tf.variable_scope('D_dense'):
                flat = tf.contrib.layers.flatten(output)
                # Explicit layer names so repeated calls resolve to the same variables.
                dense = tf.layers.dense(flat, 1024, activation=None, name = 'fc1',
                                        kernel_initializer=tf.contrib.layers.xavier_initializer())
                dense = tf.nn.leaky_relu(dense)
                dense = tf.layers.dense(dense, 1, activation=None, name = 'fc2',
                                        kernel_initializer=tf.contrib.layers.xavier_initializer())
                dense = tf.nn.sigmoid(dense)

        return dense, pools
    
    @staticmethod
    def content_loss_mse(hr, g_lr):
        """Mean squared error between two image batches.

        Declared static because it reads no instance state; without the
        decorator an instance call would bind the instance to `hr`.

        Args:
            hr: reference image tensor.
            g_lr: generated image tensor of the same shape.

        Returns:
            Scalar tensor: mean of the element-wise squared differences.
        """
        hr_flat = tf.layers.flatten(hr)
        g_lr_flat = tf.layers.flatten(g_lr)
        return tf.reduce_mean(tf.square(hr_flat - g_lr_flat))
    
    @staticmethod
    def content_loss_vgg(real_pools, fake_pools):
        """Mean squared error over paired feature maps.

        Each pair is flattened per sample, the squared differences of all
        pairs are concatenated along axis 1, and the mean is taken over
        everything.

        Args:
            real_pools: feature maps computed from real images.
            fake_pools: matching feature maps computed from generated images.

        Returns:
            Scalar tensor with the mean squared feature difference (0.0 when
            no feature maps are given).
        """
        # Collect the per-layer errors and concatenate once, instead of
        # comparing a tensor against a 0 sentinel on every iteration.
        squared_errors = [
            tf.square(tf.layers.flatten(vm_real) - tf.layers.flatten(vm_fake))
            for vm_real, vm_fake in zip(real_pools, fake_pools)
        ]
        if not squared_errors:
            return tf.constant(0.0)

        return tf.reduce_mean(tf.concat(squared_errors, axis = 1))
    
    # ONLY for generator
    @staticmethod
    def adversarial_loss(fake_logits):
        """Generator-side adversarial loss: mean of -log(fake_logits).

        Declared static because it reads no instance state.

        Args:
            fake_logits: discriminator outputs for generated images; the log
                is applied directly, so they are expected to be values in
                (0, 1] rather than raw logits.

        Returns:
            Scalar tensor.
        """
        # Clamp away from zero so a saturated discriminator cannot yield log(0).
        return tf.reduce_mean(-tf.log(tf.maximum(fake_logits, 1e-8)))
            
    def build_model(self):
        """Build placeholders, both networks, the losses and the Adam train ops.

        The discriminator is applied to the real images and to the generator
        output. The generator loss is a content loss (chosen by
        `self.loss_type`) plus the adversarial loss. The tensors and ops that
        a training loop needs are also stored on `self`.
        """
        eps = 1e-8  # keeps the logs finite when the discriminator saturates

        real_images = tf.placeholder(tf.float32, name = 'real_images',
                                    shape = [None, self.real_image_size[0], self.real_image_size[1], self.channel])
        low_resolutions = tf.placeholder(tf.float32, name = 'low_resolutions',
                                    shape = [None, self.low_resolution_size[0], self.low_resolution_size[1], self.channel])

        real_logits, real_pools = self.discriminator(real_images)
        fake_images = self.generator(low_resolutions)
        # Pin the static shape of the generator output to that of the real
        # images: the discriminator's dense layers need known spatial sizes,
        # and a mismatch is reported here rather than deep inside the graph.
        fake_images.set_shape(real_images.get_shape())
        # The second discriminator pass must share the variables of the first.
        with tf.variable_scope(tf.get_variable_scope(), reuse = tf.AUTO_REUSE):
            fake_logits, fake_pools = self.discriminator(fake_images)

        # The discriminator maximises log D(x) + log(1 - D(G(z))); the
        # optimizer minimises, so the objective is negated.
        d_loss = -tf.reduce_mean(tf.log(real_logits + eps) + tf.log(1 - fake_logits + eps))

        # Look the helpers up on the class so they are called with exactly the
        # arguments given here.
        losses = type(self)
        if self.loss_type == 'vgg':
            g_loss = losses.content_loss_vgg(real_pools, fake_pools) + losses.adversarial_loss(fake_logits)
        else:
            g_loss = losses.content_loss_mse(real_images, fake_images) + losses.adversarial_loss(fake_logits)

        # Variable scopes are prefixed 'D_' / 'G_'; match the prefix rather
        # than a single capital letter.
        tvar = tf.trainable_variables()
        dvar = [var for var in tvar if 'D_' in var.name]
        gvar = [var for var in tvar if 'G_' in var.name] 

        d_train_step = tf.train.AdamOptimizer(learning_rate = self.lr, beta1 = self.beta1, 
                                              beta2 = self.beta2).minimize(d_loss, var_list = dvar)
        # Batch-norm moving averages are refreshed by ops collected in
        # UPDATE_OPS; tie them to the generator step so they actually run.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            g_train_step = tf.train.AdamOptimizer(learning_rate = self.lr, beta1 = self.beta1,
                                                 beta2 = self.beta2).minimize(g_loss, var_list = gvar)

        # Expose what a training loop needs.
        self.real_images = real_images
        self.low_resolutions = low_resolutions
        self.fake_images = fake_images
        self.d_loss = d_loss
        self.g_loss = g_loss
        self.d_train_step = d_train_step
        self.g_train_step = g_train_step
        