# Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network

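Although convolutional neural networks have achieved major breakthroughs in the accuracy and speed of single image super-resolution, one problem remains unsolved: how to recover finer texture details when super-resolving at large upscaling factors. Recent work has focused mainly on minimizing the mean squared reconstruction error. The resulting estimates have a high peak signal-to-noise ratio (PSNR), but they usually lack high-frequency details and fall short of the fidelity expected at the higher resolution, so they are perceptually unsatisfying.

The paper presents SRGAN, a generative adversarial network (GAN) for image super-resolution (SR). It is the first framework capable of inferring photo-realistic natural images for 4× upscaling factors.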
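The link between the two metrics mentioned above is direct: PSNR is a monotonically decreasing function of MSE, so a network trained to minimize MSE is in effect maximizing PSNR, a score that does not measure whether textures look realistic. Below is a minimal NumPy sketch of the standard definition, PSNR = 10 · log10(MAX² / MSE), assuming 8-bit images (MAX = 255); the helper names are illustrative.

```python
import numpy as np

def mse(a: np.ndarray, b: np.ndarray) -> float:
    """Mean squared error over all pixels and channels."""
    return float(np.mean((a.astype(np.float64) - b.astype(np.float64)) ** 2))

def psnr(a: np.ndarray, b: np.ndarray, max_val: float = 255.0) -> float:
    """PSNR in dB. Lower MSE always means higher PSNR."""
    err = mse(a, b)
    if err == 0:
        return float("inf")  # identical images
    return float(10.0 * np.log10(max_val ** 2 / err))

rng = np.random.default_rng(0)
hr = rng.integers(0, 256, size=(32, 32, 3)).astype(np.float64)
# A constant offset of 5 gray levels gives MSE = 25 exactly
shifted = hr + 5.0
print(round(psnr(hr, shifted), 2))  # 34.15
```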

## Summary  
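1 Input data and labels: the original image, downsampled by a factor of 4, is the input x; the original image itself is the corresponding label  
2 A residual network serves as the Generator: it maps the input x to the corresponding high-resolution image  
3 A classification network serves as the Discriminator: it classifies images as either original or generated high-resolution images  
4 A VGG network extracts high-level features from both the original and the generated high-resolution images, and a perceptual loss defined on these features is used to optimize the Generator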
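The adversarial setup in items 2 to 4 follows the SRGAN paper (Ledig et al.): the Generator $G_{\theta_G}$ and the Discriminator $D_{\theta_D}$ solve the min-max problem

$$
\min_{\theta_G}\max_{\theta_D}\;
\mathbb{E}_{I^{HR}\sim p_{\text{train}}(I^{HR})}\big[\log D_{\theta_D}(I^{HR})\big]
+\mathbb{E}_{I^{LR}\sim p_{G}(I^{LR})}\big[\log\big(1-D_{\theta_D}(G_{\theta_G}(I^{LR}))\big)\big]
$$

and the Generator is trained with a perceptual loss that adds a weighted adversarial term to a content loss $l^{SR}_X$ (either a pixel-wise MSE or an MSE on VGG feature maps):

$$
l^{SR} = l^{SR}_X + 10^{-3}\, l^{SR}_{Gen},\qquad
l^{SR}_{Gen} = \sum_{n=1}^{N} -\log D_{\theta_D}\big(G_{\theta_G}(I^{LR}_n)\big).
$$

This restates the paper's formulation; the implementation in Section 3 sums a pixel MSE term, a VGG feature term and the adversarial term.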

## 1 Data Preparation  

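1 t_target_image: the original images, randomly cropped to a uniform size (384×384)  
2 t_image: the input images, obtained by downsampling t_target_image by a factor of 4 (96×96)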

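```python
import tensorflow as tf  # TensorFlow 1.x API; batch_size is defined elsewhere in the script
# Low-resolution input (96x96) and high-resolution label (384x384): a 4x scale factor
t_image = tf.placeholder('float32', [batch_size, 96, 96, 3], name='t_image_input_to_SRGAN_generator')
t_target_image = tf.placeholder('float32', [batch_size, 384, 384, 3], name='t_target_image')
```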
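The helpers that fill these placeholders during training (crop_sub_imgs_fn and downsample_fn) are not shown in this section. Below is a minimal NumPy sketch of one (image, label) pair, under stated assumptions: the originals are 8-bit RGB arrays, pixels are scaled to [-1, 1] (which the (x + 1) / 2 step before VGG in Section 2 implies), and block averaging stands in for the interpolation-based resize a real pipeline would typically use. random_crop, scale_to_pm1 and downsample_box are hypothetical helper names; a batch is simply a stack of such pairs.

```python
import numpy as np

def random_crop(img: np.ndarray, size: int, rng: np.random.Generator) -> np.ndarray:
    """Cut a random size x size patch out of an H x W x C image."""
    h, w = img.shape[:2]
    top = int(rng.integers(0, h - size + 1))
    left = int(rng.integers(0, w - size + 1))
    return img[top:top + size, left:left + size]

def scale_to_pm1(img: np.ndarray) -> np.ndarray:
    """Map 8-bit pixels [0, 255] to [-1, 1] (assumed normalization)."""
    return img.astype(np.float64) / 127.5 - 1.0

def downsample_box(img: np.ndarray, factor: int = 4) -> np.ndarray:
    """Average non-overlapping factor x factor blocks: a simple stand-in
    for the interpolation-based resize a real pipeline would use."""
    h, w, c = img.shape
    assert h % factor == 0 and w % factor == 0, "size must be divisible by factor"
    return img.reshape(h // factor, factor, w // factor, factor, c).mean(axis=(1, 3))

rng = np.random.default_rng(0)
original = rng.integers(0, 256, size=(500, 600, 3), dtype=np.uint8)  # hypothetical HR photo
target = scale_to_pm1(random_crop(original, 384, rng))  # plays the role of t_target_image
lr_input = downsample_box(target, factor=4)             # plays the role of t_image
print(target.shape, lr_input.shape)  # (384, 384, 3) (96, 96, 3)
```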

## 2 Applying the Generator and Discriminator  

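1 SRGAN_g: generates the high-resolution image corresponding to each low-resolution input  
2 SRGAN_d: outputs classification logits for both the generated images and the original high-resolution images (the second call uses reuse=True, so both share the same weights)  
3 Vgg19_simple_api: extracts high-level features from the generated and the original high-resolution images, after both are resized to fit VGG's input format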

```python
# SRGAN_g, SRGAN_d and Vgg19_simple_api come from the project's model definitions (not shown)
# Generator: LR input -> SR output (net_g.outputs)
net_g = SRGAN_g(t_image, is_train=True, reuse=False)
# Discriminator on real HR images, then on generated images with shared weights (reuse=True)
net_d, logits_real = SRGAN_d(t_target_image, is_train=True, reuse=False)
_,     logits_fake = SRGAN_d(net_g.outputs, is_train=True, reuse=True)

# Resize both images to VGG19's 224x224 input (method=0 is bilinear)
# http://tensorlayer.readthedocs.io/en/latest/_modules/tensorlayer/layers.html#UpSampling2dLayer
t_target_image_224 = tf.image.resize_images(t_target_image, size=[224, 224], method=0, align_corners=False)
t_predict_image_224 = tf.image.resize_images(net_g.outputs, size=[224, 224], method=0, align_corners=False)

# (x + 1) / 2 maps images from [-1, 1] to [0, 1] before feature extraction; both calls share VGG weights
net_vgg, vgg_target_emb = Vgg19_simple_api((t_target_image_224+1)/2, reuse=False)
_, vgg_predict_emb = Vgg19_simple_api((t_predict_image_224+1)/2, reuse=True)
```
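Both Vgg19_simple_api calls receive (x + 1) / 2 rather than x. Below is a minimal NumPy sketch of that step, assuming the generator output lies in [-1, 1] (as a tanh output layer would produce; the generator definition is not shown here). A random tanh tensor stands in for net_g.outputs, and to_vgg_range is a hypothetical helper name.

```python
import numpy as np

def to_vgg_range(x: np.ndarray) -> np.ndarray:
    """Map images from [-1, 1] to [0, 1], as done before each VGG call."""
    return (x + 1.0) / 2.0

rng = np.random.default_rng(0)
fake_output = np.tanh(rng.normal(size=(2, 384, 384, 3)))  # stand-in for net_g.outputs
vgg_in = to_vgg_range(fake_output)
print(bool(vgg_in.min() >= 0.0), bool(vgg_in.max() <= 1.0))  # True True
print(to_vgg_range(np.array([-1.0, 0.0, 1.0])).tolist())     # [0.0, 0.5, 1.0]
```

Because the real and generated images go through the same resize, the same rescaling and the same shared-weight VGG network, the perceptual loss compares features computed in exactly the same way.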

## 3 Building the Loss Functions  

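1 d_loss: the loss for the Discriminator. logits_real are its outputs on the real high-resolution images and logits_fake its outputs on the generated images; the Discriminator is trained to tell these two kinds of images apart as well as possible  
2 g_loss: the loss for the Generator, a sum of three terms. g_gan_loss feeds the generated images to the Discriminator labeled as real, so minimizing it trains the Generator to fool the Discriminator; mse_loss is the pixel-wise mean squared error between the generated and the real images; vgg_loss is the perceptual loss, the mean squared error between their VGG features, which pushes the Generator toward more natural-looking images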
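Written out, with $z^{\text{real}}_n$ and $z^{\text{fake}}_n$ the Discriminator logits (logits_real, logits_fake) over a batch of $N$ images, $\sigma$ the sigmoid, and $\phi$ the VGG19 features returned by Vgg19_simple_api (computed on the 224×224, [0, 1]-scaled images), the two losses are:

$$
\mathcal{L}_D = -\frac{1}{N}\sum_{n=1}^{N}\log\sigma\big(z^{\text{real}}_n\big)
-\frac{1}{N}\sum_{n=1}^{N}\log\big(1-\sigma(z^{\text{fake}}_n)\big)
$$

$$
\mathcal{L}_G = \operatorname{mean}\big[(G(I^{LR})-I^{HR})^2\big]
+ 2\times10^{-6}\,\operatorname{mean}\big[(\phi(G(I^{LR}))-\phi(I^{HR}))^2\big]
- 10^{-3}\,\frac{1}{N}\sum_{n=1}^{N}\log\sigma\big(z^{\text{fake}}_n\big)
$$

where mean is taken over all elements. The adversarial term uses the non-saturating form $-\log\sigma(z^{\text{fake}})$, the same form the SRGAN paper uses for its adversarial loss (averaged here instead of summed).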

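```python
import tensorlayer as tl  # TensorLayer 1.x (tl.cost)

# Discriminator: real HR images labeled 1, generated images labeled 0
d_loss1 = tl.cost.sigmoid_cross_entropy(logits_real, tf.ones_like(logits_real), name='d1')
d_loss2 = tl.cost.sigmoid_cross_entropy(logits_fake, tf.zeros_like(logits_fake), name='d2')
d_loss = d_loss1 + d_loss2

# Generator adversarial term: generated images labeled 1 to fool D, weighted by 1e-3
g_gan_loss = 1e-3 * tl.cost.sigmoid_cross_entropy(logits_fake, tf.ones_like(logits_fake), name='g')
# Pixel-wise MSE between generated and target HR images
mse_loss = tl.cost.mean_squared_error(net_g.outputs, t_target_image, is_mean=True)
# Perceptual loss: MSE between VGG features, weighted by 2e-6
vgg_loss = 2e-6 * tl.cost.mean_squared_error(vgg_predict_emb.outputs, vgg_target_emb.outputs, is_mean=True)

g_loss = mse_loss + vgg_loss + g_gan_loss
```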
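To check what these expressions compute, here is a NumPy sketch of the same arithmetic. sigmoid_cross_entropy and mse are hypothetical stand-ins for tl.cost.sigmoid_cross_entropy and tl.cost.mean_squared_error(..., is_mean=True), assuming both reduce to a plain mean over all elements; the cross-entropy uses the numerically stable formula TensorFlow documents for tf.nn.sigmoid_cross_entropy_with_logits. Zero arrays stand in for the real logits, images and VGG features.

```python
import numpy as np

def sigmoid_cross_entropy(logits, labels) -> float:
    """Mean sigmoid cross-entropy on logits, using the stable form
    max(z, 0) - z * y + log(1 + exp(-|z|))."""
    z = np.asarray(logits, dtype=np.float64)
    y = np.asarray(labels, dtype=np.float64)
    return float(np.mean(np.maximum(z, 0.0) - z * y + np.log1p(np.exp(-np.abs(z)))))

def mse(a, b) -> float:
    """Mean squared error over all elements."""
    diff = np.asarray(a, dtype=np.float64) - np.asarray(b, dtype=np.float64)
    return float(np.mean(diff ** 2))

def srgan_losses(logits_real, logits_fake, sr, hr, feat_sr, feat_hr):
    """Return (d_loss, g_loss) with the weights used in the TensorFlow code."""
    d_loss = (sigmoid_cross_entropy(logits_real, np.ones_like(logits_real))
              + sigmoid_cross_entropy(logits_fake, np.zeros_like(logits_fake)))
    g_gan_loss = 1e-3 * sigmoid_cross_entropy(logits_fake, np.ones_like(logits_fake))
    mse_loss = mse(sr, hr)
    vgg_loss = 2e-6 * mse(feat_sr, feat_hr)
    return d_loss, mse_loss + vgg_loss + g_gan_loss

# An undecided Discriminator (all logits 0) and a perfect reconstruction
logits = np.zeros((4, 1))
images = np.zeros((4, 8, 8, 3))
features = np.zeros((4, 7, 7, 16))
d, g = srgan_losses(logits, logits, images, images, features, features)
print(round(d, 4), round(g, 6))  # 1.3863 0.000693
```

With all-zero logits the Discriminator assigns probability 0.5 to every image, so each cross-entropy term equals log 2: d_loss is 2 · log 2, and the adversarial term contributes only 1e-3 · log 2 to g_loss, which shows how small its weight is relative to the reconstruction terms.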

## 4 Training the Network  

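1 g_optim_init: pretrains (initializes) the Generator using mse_loss alone  
2 g_optim, d_optim: train the full GAN, alternating one Discriminator update and one Generator update per batch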

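```python
import time

import tensorflow as tf
import tensorlayer as tl

# Assumed defined earlier in the script (not shown): sess, lr_v, lr_init, beta1,
# g_vars, d_vars, n_epoch_init, n_epoch, batch_size, train_hr_imgs,
# crop_sub_imgs_fn, downsample_fn

## Pretrain: Generator only, MSE loss
g_optim_init = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(mse_loss, var_list=g_vars)
## SRGAN: separate optimizers for G (full g_loss) and D
g_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(g_loss, var_list=g_vars)
d_optim = tf.train.AdamOptimizer(lr_v, beta1=beta1).minimize(d_loss, var_list=d_vars)
# Before the loops: initialize all variables (including Adam slots) and load the
# pretrained VGG19 weights; this setup is not shown here.

# The placeholders have a fixed batch dimension, so iterate over full batches only
batch_starts = range(0, len(train_hr_imgs) - batch_size + 1, batch_size)

###========================= initialize G ====================###
sess.run(tf.assign(lr_v, lr_init))
print(" ** fixed learning rate: %f (for init G)" % lr_init)
for epoch in range(0, n_epoch_init+1):
    epoch_time = time.time()
    total_mse_loss, n_iter = 0, 0
    ## If your machine has enough memory, pre-load the whole training set.
    for idx in batch_starts:
        # Random 384x384 crops (labels) and their 4x-downsampled 96x96 inputs
        b_imgs_384 = tl.prepro.threading_data(
                train_hr_imgs[idx : idx + batch_size],
                fn=crop_sub_imgs_fn, is_random=True)
        b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)
        ## update G
        errM, _ = sess.run([mse_loss, g_optim_init], {t_image: b_imgs_96, t_target_image: b_imgs_384})
        total_mse_loss += errM
        n_iter += 1
    print("[init G] epoch %d: %.2fs, mse: %.8f" % (epoch, time.time() - epoch_time, total_mse_loss / n_iter))

###========================= train GAN (SRGAN) =========================###
for epoch in range(0, n_epoch+1):
    epoch_time = time.time()
    total_d_loss, total_g_loss, n_iter = 0, 0, 0
    ## If your machine has enough memory, pre-load the whole training set.
    for idx in batch_starts:
        b_imgs_384 = tl.prepro.threading_data(
                train_hr_imgs[idx : idx + batch_size],
                fn=crop_sub_imgs_fn, is_random=True)
        b_imgs_96 = tl.prepro.threading_data(b_imgs_384, fn=downsample_fn)
        ## update D
        errD, _ = sess.run([d_loss, d_optim], {t_image: b_imgs_96, t_target_image: b_imgs_384})
        ## update G
        errG, errM, errV, errA, _ = sess.run([g_loss, mse_loss, vgg_loss, g_gan_loss, g_optim], {t_image: b_imgs_96, t_target_image: b_imgs_384})
        total_d_loss += errD
        total_g_loss += errG
        n_iter += 1
    print("[SRGAN] epoch %d: %.2fs, d_loss: %.8f, g_loss: %.8f" % (epoch, time.time() - epoch_time, total_d_loss / n_iter, total_g_loss / n_iter))
```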
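The order of parameter updates in the two training phases can be hard to see through the data-loading code. Below is a small pure-Python sketch that only enumerates that order, with hypothetical function names and toy sizes: phase 1 runs n_epoch_init + 1 epochs of Generator-only (MSE) updates, and phase 2 runs n_epoch + 1 epochs with one Discriminator update followed by one Generator update per full batch. A trailing partial batch is skipped, matching the fixed batch dimension of the placeholders.

```python
from typing import Iterator, Tuple

def batch_starts(n_images: int, batch_size: int) -> range:
    """Start indices of full batches only; a trailing partial batch is dropped."""
    return range(0, n_images - batch_size + 1, batch_size)

def srgan_schedule(n_images: int, batch_size: int,
                   n_epoch_init: int, n_epoch: int) -> Iterator[Tuple[str, int, int]]:
    """Yield (update, epoch, batch_start) in the order the training loops run them."""
    for epoch in range(n_epoch_init + 1):        # phase 1: pretrain G with MSE
        for idx in batch_starts(n_images, batch_size):
            yield ("G_init", epoch, idx)
    for epoch in range(n_epoch + 1):             # phase 2: alternate D and G
        for idx in batch_starts(n_images, batch_size):
            yield ("D", epoch, idx)
            yield ("G", epoch, idx)

steps = list(srgan_schedule(n_images=10, batch_size=4, n_epoch_init=1, n_epoch=0))
print(steps)
# [('G_init', 0, 0), ('G_init', 0, 4), ('G_init', 1, 0), ('G_init', 1, 4),
#  ('D', 0, 0), ('G', 0, 0), ('D', 0, 4), ('G', 0, 4)]
```

With 10 images and a batch size of 4, only the batches starting at 0 and 4 are used; the last 2 images are skipped in every epoch. Since crops are random, different image regions are still seen across epochs.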
