# LR image 放大4倍

In [1]:
import keras
import numpy as np

Using TensorFlow backend.


# Discriminator
![SRGAN](https://github.com/deepak112/Keras-SRGAN/raw/master/Architecture_images/network.jpg)  

Discriminator <font color='#ff0000'>__非conditionalGAN__</font> (https://github.com/deepak112/Keras-SRGAN/blob/master/train.py 是一般GAN)    

In [2]:
def conv_bn_relu(layer_in,filter_size,filters,stride):
    """Run ``layer_in`` through Conv2D -> BatchNormalization -> LeakyReLU.

    Args:
        layer_in: Keras tensor to transform.
        filter_size: side length of the square convolution kernel.
        filters: number of output channels of the convolution.
        stride: stride handed to the convolution.

    Returns:
        The tensor produced by the final LeakyReLU layer.
    """
    # Layers are instantiated in the same order in which they are applied.
    stack = (
        keras.layers.Conv2D(filters, (filter_size, filter_size), strides=stride, padding='same'),
        keras.layers.BatchNormalization(momentum=0.5),
        keras.layers.LeakyReLU(alpha=0.2),
    )
    out = layer_in
    for layer in stack:
        out = layer(out)
    return out

In [3]:
def discriminator(input_shape):
    """Build and compile the discriminator model.

    The network is a Conv2D + LeakyReLU stem, seven ``conv_bn_relu`` stages,
    then Flatten -> Dense(1024) -> LeakyReLU -> Dense(1) -> sigmoid.

    Args:
        input_shape: shape of one input image, without the batch axis.

    Returns:
        A Keras model named 'discriminator', compiled with binary
        cross-entropy loss and Adam(lr=0.0001, beta_1=0.9).
    """
    image_in = keras.models.Input(shape=input_shape)

    # Stem: convolution + LeakyReLU without batch normalization.
    features = keras.layers.Conv2D(64, (3, 3), strides=1, padding='same')(image_in)
    features = keras.layers.LeakyReLU(alpha=0.2)(features)

    # (filters, stride) of each Conv-BN-LeakyReLU stage, in application order.
    stage_specs = [(64, 2), (128, 1), (128, 2), (256, 1), (256, 2), (512, 1), (512, 2)]
    for n_filters, step in stage_specs:
        features = conv_bn_relu(features, 3, n_filters, step)

    # Dense head ending in a single sigmoid unit.
    features = keras.layers.Flatten()(features)
    features = keras.layers.Dense(1024)(features)
    features = keras.layers.LeakyReLU(alpha=0.2)(features)
    score = keras.layers.Dense(1)(features)
    score = keras.layers.Activation('sigmoid')(score)

    model = keras.models.Model(inputs=image_in, outputs=score, name='discriminator')
    model.compile(
        loss='binary_crossentropy',
        optimizer=keras.optimizers.Adam(lr=0.0001, beta_1=0.9),
    )
    return model

In [4]:
# Build the discriminator for 384x384 RGB inputs and print its layer table.
d_model = discriminator(input_shape=(384, 384, 3))
d_model.summary()

Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where
Model: "discriminator"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 384, 384, 3)       0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 384, 384, 64)      1792      
_________________________________________________________________
leaky_re_lu_1 (LeakyReLU)    (None, 384, 384, 64)      0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 192, 192, 64)      36928     
_________________________________________________________________
batch_normalization_1 (Batch (None, 192, 192, 64)      256       
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 192, 192, 64)      0         
_________________________

# Generator

In [5]:
def residual_block(layer_in,filter_size,filters,stride):
    """Residual unit: Conv-BN-PReLU-Conv-BN branch added back onto its input.

    Args:
        layer_in: Keras tensor entering the block (also the skip connection).
        filter_size: side length of the square kernel of both convolutions.
        filters: number of output channels of both convolutions.
        stride: stride of both convolutions.

    Returns:
        The element-wise sum of the branch output and ``layer_in``.
    """
    kernel = (filter_size, filter_size)

    branch = keras.layers.Conv2D(filters, kernel, strides=stride, padding='same')(layer_in)
    branch = keras.layers.BatchNormalization(momentum=0.5)(branch)
    # One learned PReLU slope per channel, shared over both spatial axes.
    branch = keras.layers.PReLU(shared_axes=[1, 2])(branch)
    branch = keras.layers.Conv2D(filters, kernel, strides=stride, padding='same')(branch)
    branch = keras.layers.BatchNormalization(momentum=0.5)(branch)

    return keras.layers.Add()([branch, layer_in])

In [6]:
def upsampling_block(layer_in,filter_size,filters,stride):
    """Conv2D followed by UpSampling2D(size=2) and a PReLU activation.

    Args:
        layer_in: Keras tensor to upsample.
        filter_size: side length of the square convolution kernel.
        filters: number of output channels of the convolution.
        stride: stride of the convolution.

    Returns:
        The tensor produced by the PReLU layer.
    """
    conv_out = keras.layers.Conv2D(
        filters, (filter_size, filter_size), strides=stride, padding='same'
    )(layer_in)
    enlarged = keras.layers.UpSampling2D(size=2)(conv_out)
    return keras.layers.PReLU(shared_axes=[1, 2])(enlarged)

In [7]:
def generator(input_shape):
    """Build the (uncompiled) generator model.

    Structure: 9x9 Conv + PReLU head, 16 ``residual_block`` units, a
    3x3 Conv + BatchNormalization whose output is added to the head output,
    two ``upsampling_block`` stages, and a 9x9 Conv with 3 filters followed
    by tanh.

    Args:
        input_shape: shape of one input image, without the batch axis.

    Returns:
        A Keras model named 'generator'.
    """
    lr_in = keras.models.Input(shape=input_shape)

    head = keras.layers.Conv2D(64, (9, 9), strides=1, padding='same')(lr_in)
    head = keras.layers.PReLU(shared_axes=[1, 2])(head)

    # Residual trunk; ``head`` is kept for the long skip connection below.
    trunk = head
    for _ in range(16):
        trunk = residual_block(trunk, 3, 64, 1)
    trunk = keras.layers.Conv2D(64, (3, 3), strides=1, padding='same')(trunk)
    trunk = keras.layers.BatchNormalization(momentum=0.5)(trunk)
    trunk = keras.layers.Add()([trunk, head])

    upscaled = trunk
    for _ in range(2):
        upscaled = upsampling_block(upscaled, 3, 256, 1)

    # Project to 3 channels; tanh bounds the output to [-1, 1].
    sr_out = keras.layers.Conv2D(3, (9, 9), strides=1, padding='same')(upscaled)
    sr_out = keras.layers.Activation('tanh')(sr_out)

    return keras.models.Model(inputs=lr_in, outputs=sr_out, name='generator')

In [8]:
# Build the generator for 96x96 RGB inputs and print its layer table.
g_model = generator(input_shape=(96, 96, 3))
g_model.summary()

Model: "generator"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            (None, 96, 96, 3)    0                                            
__________________________________________________________________________________________________
conv2d_9 (Conv2D)               (None, 96, 96, 64)   15616       input_2[0][0]                    
__________________________________________________________________________________________________
p_re_lu_1 (PReLU)               (None, 96, 96, 64)   64          conv2d_9[0][0]                   
__________________________________________________________________________________________________
conv2d_10 (Conv2D)              (None, 96, 96, 64)   36928       p_re_lu_1[0][0]                  
__________________________________________________________________________________________

# Composite model
![SRGAN](https://1.bp.blogspot.com/-Y5bRR5aX4pI/XZMAuwcG5uI/AAAAAAAAEsQ/cDA7XKvQ-CIvNmXGk_SWb_Yg55hEFy5fQCLcBGAsYHQ/s1600/IMG_2377.JPG)  
圖錯，待修正  
<font color='#00aa00'>__D的input只有一個(非conditionalGAN)__  
__vgg的input只有一個，vgg(HR)為ground_truth__  </font>


### import vgg19

In [9]:
import ssl
from keras.applications.vgg19 import VGG19
# NOTE(review): this swaps the default HTTPS context factory for the unverified
# one, so TLS certificates are no longer checked by anything in this process
# that relies on that default. Presumably a workaround so the VGG19 weight
# download below succeeds - confirm, and remove it once the weights are cached.
ssl._create_default_https_context = ssl._create_unverified_context
# VGG19 without its classifier head, ImageNet weights, fixed 384x384x3 input.
vgg19=VGG19(include_top=False,weights='imagenet',input_shape=(384,384,3))
# Feature extractor: same input as VGG19, output taken at layer 'block5_conv4'.
vgg=keras.models.Model(inputs=vgg19.input,outputs=vgg19.get_layer('block5_conv4').output,name='vgg')
# Freeze at the model level only; the per-layer flags are left untouched.
vgg.trainable=False
# Alternative (disabled): freeze every layer individually instead.
#for layer in vgg.layers:
#    layer.trainable=False
vgg.summary()

Model: "vgg"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_3 (InputLayer)         (None, 384, 384, 3)       0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 384, 384, 64)      1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 384, 384, 64)      36928     
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 192, 192, 64)      0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 192, 192, 128)     73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 192, 192, 128)     147584    
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 96, 96, 128)       0       

### composite model = generator + discriminator + vgg19

In [10]:
def composite(g_model,d_model,input_shape_LR):
    """Chain generator -> (discriminator, vgg) into one compiled model.

    Sets ``d_model.trainable`` to False as a side effect, and reads the
    notebook-level ``vgg`` feature extractor.

    Args:
        g_model: generator model applied to the LR input.
        d_model: discriminator model applied to the generator output.
        input_shape_LR: shape of one LR input image, without the batch axis.

    Returns:
        A Keras model named 'composite' with outputs
        ``[discriminator(SR), vgg(SR)]``, compiled with losses
        ['binary_crossentropy', 'mse'] weighted [1, 1000] and
        Adam(lr=0.0001, beta_1=0.9).
    """
    d_model.trainable = False

    lr_in = keras.models.Input(shape=input_shape_LR)
    super_res = g_model(lr_in)
    heads = [d_model(super_res), vgg(super_res)]

    model = keras.models.Model(inputs=lr_in, outputs=heads, name='composite')
    model.compile(
        loss=['binary_crossentropy', 'mse'],
        optimizer=keras.optimizers.Adam(lr=0.0001, beta_1=0.9),
        loss_weights=[1, 1000],
    )
    return model

In [11]:
# Wire generator, discriminator and vgg together for generator training.
c_model = composite(g_model=g_model, d_model=d_model, input_shape_LR=(96, 96, 3))
c_model.summary()

Model: "composite"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_4 (InputLayer)            (None, 96, 96, 3)    0                                            
__________________________________________________________________________________________________
generator (Model)               (None, 384, 384, 3)  2044291     input_4[0][0]                    
__________________________________________________________________________________________________
discriminator (Model)           (None, 1)            306684737   generator[1][0]                  
__________________________________________________________________________________________________
vgg (Model)                     (None, 24, 24, 512)  20024384    generator[1][0]                  
Total params: 328,753,412
Trainable params: 2,040,067
Non-trainable params: 326,713,345
__

### trainable
注意一下各個model的trainable屬性  
<font color='#ff0000'>__model trainable v.s. layer trainable:__</font> https://stackoverflow.com/questions/56675964/what-is-the-difference-between-setting-a-keras-model-trainable-vs-making-each-la/56682979#56682979?newreg=61cf9299fb974eb9999d473a9cac52ea  

In [12]:
# Model-level flag, the flag of the first layer after the input, then all layers.
print(vgg.trainable, vgg.layers[1].trainable, vgg.layers, sep='\n')

False
True
[<keras.engine.input_layer.InputLayer object at 0x7ff1c0468b00>, <keras.layers.convolutional.Conv2D object at 0x7ff1c0468be0>, <keras.layers.convolutional.Conv2D object at 0x7ff1c0468e48>, <keras.layers.pooling.MaxPooling2D object at 0x7ff1c04018d0>, <keras.layers.convolutional.Conv2D object at 0x7ff1c0401588>, <keras.layers.convolutional.Conv2D object at 0x7ff1c0415860>, <keras.layers.pooling.MaxPooling2D object at 0x7ff1c0426a20>, <keras.layers.convolutional.Conv2D object at 0x7ff1c04265f8>, <keras.layers.convolutional.Conv2D object at 0x7ff1c04382b0>, <keras.layers.convolutional.Conv2D object at 0x7ff1c03ceac8>, <keras.layers.convolutional.Conv2D object at 0x7ff1c03de9b0>, <keras.layers.pooling.MaxPooling2D object at 0x7ff1c03f2898>, <keras.layers.convolutional.Conv2D object at 0x7ff1c03f2470>, <keras.layers.convolutional.Conv2D object at 0x7ff1c0389eb8>, <keras.layers.convolutional.Conv2D object at 0x7ff1c039d940>, <keras.layers.convolutional.Conv2D object at 0x7ff1c03b1

In [13]:
# Trainable flags of discriminator, generator and composite model, one per line.
print(d_model.trainable, g_model.trainable, c_model.trainable, sep='\n')

False
True
True


# Fit

In [14]:
def hr_image_generator(dataset,randi):
    """Pick HR images and pair them with "real" labels (like a real_image_generator).

    Args:
        dataset: indexable collection of HR images.
        randi: index (or array of indices) used as ``dataset[randi]``.

    Returns:
        Tuple ``(images, labels)`` where ``labels`` is an array of ones with
        shape ``(len(images), 1)``.
    """
    picked = dataset[randi]
    labels = np.ones(shape=(len(picked), 1))
    return picked, labels

In [15]:
def sr_image_generator(g_model,dataset,randi):
    """Generate SR images and pair them with "fake" labels (like a fake_image_generator).

    Args:
        g_model: object whose ``predict`` method maps LR images to SR images.
        dataset: indexable collection of LR images.
        randi: index (or array of indices) used as ``dataset[randi]``.

    Returns:
        Tuple ``(images, labels)`` where ``images`` is ``g_model.predict`` of
        the selected LR images and ``labels`` is an array of zeros with shape
        ``(len(images), 1)``.
    """
    generated = g_model.predict(dataset[randi])
    labels = np.zeros(shape=(len(generated), 1))
    return generated, labels

In [16]:
def lr_image_generator(dataset,randi):
    """Pick LR images from ``dataset``.

    Args:
        dataset: indexable collection of LR images.
        randi: index (or array of indices) used as ``dataset[randi]``.

    Returns:
        ``dataset[randi]``.
    """
    return dataset[randi]


# The function was originally defined under this misspelled name while its
# callers used the correct spelling; keep the old name working as an alias.
lr_image_egnerator=lr_image_generator

In [17]:
def vgg_of_hr(HRs):
    """Return vgg(HR): the high-level feature maps of the given HR images.

    Uses the notebook-level ``vgg`` model.

    Args:
        HRs: batch of HR images accepted by ``vgg.predict``.

    Returns:
        Whatever ``vgg.predict(HRs)`` returns.
    """
    feature_maps = vgg.predict(HRs)
    return feature_maps

In [18]:
def fit(g_model,d_model,c_model,dataset_lr,dataset_hr,input_shape=(96,96,3),epochs=1000,batch_size=64):
    """Alternate discriminator and generator updates for ``epochs`` epochs.

    Each batch trains ``d_model`` on HR images labelled 1 and on generated
    images labelled 0, then trains ``c_model`` on a freshly sampled set of
    LR images against the targets ``[ones, vgg_of_hr(HR)]``. Losses of the
    last batch are printed after every epoch, and every 10th epoch
    ``summarize_performance`` is called.

    Args:
        g_model: generator model.
        d_model: compiled discriminator model.
        c_model: compiled composite model with two outputs.
        dataset_lr: LR images; must support indexing by an integer array.
        dataset_hr: HR images aligned index-by-index with ``dataset_lr``.
        input_shape: unused; kept so existing calls keep working.
        epochs: number of epochs.
        batch_size: number of indices drawn per training step.

    Raises:
        AssertionError: if the two datasets differ in length.
        ValueError: if the datasets are empty.
    """
    assert len(dataset_hr)==len(dataset_lr)
    n_samples=len(dataset_lr)
    if n_samples==0:
        raise ValueError('fit() needs at least one LR/HR image pair')
    # Indices are drawn with replacement, so a dataset smaller than batch_size
    # can still supply a batch. Running at least one batch per epoch also
    # guarantees the losses printed below are always defined.
    batches=max(1,n_samples//batch_size)
    for epoch in range(1,epochs+1):
        for batch in range(1,batches+1):
            #fit discriminator
            randi=np.random.randint(0,n_samples,batch_size)
            HRs,true_y=hr_image_generator(dataset_hr,randi)
            loss_hr=d_model.train_on_batch(HRs,true_y)
            SRs,false_y=sr_image_generator(g_model,dataset_lr,randi)
            loss_sr=d_model.train_on_batch(SRs,false_y)
            #fit generator
            randi=np.random.randint(0,n_samples,batch_size)
            # Index directly rather than through a helper: the helper this
            # function used to call is not defined under that name.
            LRs=dataset_lr[randi]
            HRs,true_y=hr_image_generator(dataset_hr,randi)
            # Only the first of the three returned loss values is kept.
            loss_perceptual,_,_=c_model.train_on_batch(LRs,[true_y,vgg_of_hr(HRs)])
        print('epoch {}/{}: loss_hr:{:.2f} loss_sr:{:.2f} loss_perceptual:{:.2f}'.format(epoch,epochs,loss_hr,loss_sr,loss_perceptual))
        if epoch%10==0:
            summarize_performance(epoch,g_model,d_model,dataset_lr,dataset_hr)

In [19]:
def summarize_performance(epoch,g_model,d_model,dataset_lr,dataset_hr):
    """Save both models and a 3x5 LR / SR / HR comparison figure.

    Writes 'generator {epoch}.h5', 'discriminator {epoch}.h5' and
    'epoch {epoch}.png' into the 'SRGAN' directory, creating it if needed.

    Args:
        epoch: epoch number used in the output file names.
        g_model: generator model to save and to produce the SR images.
        d_model: discriminator model to save.
        dataset_lr: LR images; must support indexing by an integer array.
        dataset_hr: HR images aligned index-by-index with ``dataset_lr``.
    """
    import os  # local import: this cell is defined before the notebook imports os

    # Saving fails if the target directory does not exist yet.
    os.makedirs('SRGAN',exist_ok=True)
    #save model
    g_model.save('SRGAN//generator {}.h5'.format(epoch))
    d_model.save('SRGAN//discriminator {}.h5'.format(epoch))
    #save picture
    randi=np.random.randint(0,len(dataset_lr),5)
    LRs=dataset_lr[randi]
    # Both helpers return (images, labels); only the images are plotted.
    SRs,_=sr_image_generator(g_model,dataset_lr,randi)
    HRs,_=hr_image_generator(dataset_hr,randi)
    # Shift the SR and HR images by +1 and halve them before display.
    SRs=(SRs+1)/2
    HRs=(HRs+1)/2
    plt.figure(figsize=(15,9))
    for i in range(5):
        # Row 1: LR input, row 2: generated SR, row 3: HR ground truth.
        plt.subplot(3,5,i+1)
        plt.imshow(LRs[i])
        plt.axis('off')
        plt.subplot(3,5,i+6)
        plt.imshow(SRs[i])
        plt.axis('off')
        plt.subplot(3,5,i+11)
        plt.imshow(HRs[i])
        plt.axis('off')
    plt.savefig('SRGAN//epoch {}.png'.format(epoch))
    plt.close()

# load dataset
Using the COCO (2014) dataset  
PIL <-> ndarray https://stackoverflow.com/questions/384759/how-to-convert-a-pil-image-into-a-numpy-array  

In [20]:
import os
from PIL import Image
from keras.preprocessing.image import img_to_array
from keras.preprocessing.image import load_img
import matplotlib.pyplot as plt

In [21]:
def load_image(path,lr_size=(96,96),hr_size=(384,384),pic=1000):
    """Load up to ``pic`` images from ``path`` as paired LR / HR arrays.

    Files are visited in sorted name order so the selection is reproducible.
    Entries that are not regular files (sub-directories such as
    '.ipynb_checkpoints') are skipped, as are images smaller than ``hr_size``.
    Each kept image is resized with bicubic interpolation once to ``lr_size``
    and once to ``hr_size``.

    Args:
        path: directory containing the image files.
        lr_size: ``(height, width)`` of the LR images.
        hr_size: ``(height, width)`` of the HR images; also the minimum size
            an image must have to be used.
        pic: maximum number of images to load.

    Returns:
        Tuple ``(LRs, HRs)`` of numpy arrays; both are empty arrays when no
        image qualifies.
    """
    #(None,96,96,3)
    LRs=[]
    #(None,384,384,3)
    HRs=[]
    min_height,min_width=hr_size
    for filename in sorted(os.listdir(path)):
        if len(HRs)>=pic:
            break
        filepath=os.path.join(path,filename)
        if not os.path.isfile(filepath):
            continue
        # Read only the dimensions here; the resized loads below decode the pixels.
        with Image.open(filepath) as probe:
            width,height=probe.size
        if width<min_width or height<min_height:
            continue
        LRs.append(img_to_array(load_img(filepath,target_size=lr_size,interpolation='bicubic'),dtype='float'))
        HRs.append(img_to_array(load_img(filepath,target_size=hr_size,interpolation='bicubic'),dtype='float'))
    return np.asarray(LRs),np.asarray(HRs)

### load the saved .npz with 800 images for training

In [22]:
# Read the pre-built training arrays: 'arr_0' holds the LR images, 'arr_1' the HR images.
data = np.load('SRGANtrainingset.npz')
trainLR = data['arr_0']
trainHR = data['arr_1']
print(trainLR.shape, trainHR.shape)

(800, 96, 96, 3) (800, 384, 384, 3)


In [23]:
# Top-left 4x4 pixel patch (all channels) of the first LR image.
print(trainLR[0, :4, :4])

[[[34. 17. 10.]
  [32. 16.  9.]
  [27. 14.  7.]
  [20. 10.  7.]]

 [[34. 17. 10.]
  [34. 18. 10.]
  [30. 15.  8.]
  [24. 12.  7.]]

 [[34. 18. 10.]
  [35. 19. 12.]
  [34. 18. 11.]
  [29. 15.  8.]]

 [[36. 18. 11.]
  [36. 18. 11.]
  [27. 14.  8.]
  [27. 16. 10.]]]


### from \[0,255\] to \[-1,1\] or \[0,1\]
The paper says LR should be in \[0,1\] and HR in \[-1,1\]  

In [24]:
# Rescale pixels: LR from [0,255] to [0,1], HR from [0,255] to [-1,1].
# Both names are rebound, so running this cell twice rescales the data twice.
trainLR = np.true_divide(trainLR, 255)
trainHR = np.true_divide(trainHR - 127.5, 127.5)

In [25]:
# The three models were already built in earlier cells, i.e.:
#   g_model=generator((96,96,3))
#   d_model=discriminator((384,384,3))
#   c_model=composite(g_model,d_model,(96,96,3))
# NOTE(review): the recorded run of this cell ended in an out-of-memory error.
fit(
    g_model,
    d_model,
    c_model,
    trainLR,
    trainHR,
    input_shape=(96, 96, 3),
    epochs=1000,
    batch_size=64,
)

  'Discrepancy between trainable weights and collected trainable'


ResourceExhaustedError:  OOM when allocating tensor of shape [32,192,192,128] and type float
	 [[{{node gradients/batch_normalization_2/cond_grad/If/else/_307/zeros_like}}]] [Op:__inference_keras_scratch_graph_18044]

Function call stack:
keras_scratch_graph
