In [None]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.image import resize,ResizeMethod
from tensorflow.keras import Input
from tensorflow.keras.layers import Conv2D, BatchNormalization,LeakyReLU,PReLU,Dense,Flatten,add,Conv2DTranspose,UpSampling2D
from tensorflow.keras.optimizers import Adam
from google.colab import drive
# Mount Google Drive at /content/gdrive; the data and weight paths used by the
# later cells all live under this mount point.
drive.mount('/content/gdrive')

In [None]:
import os
# Project folder on the mounted Drive; later cells build every data/weights path from it.
file_path = '/content/gdrive/My Drive/Colab Notebooks/SRGAN'
# Directory listing of the project folder.
a = os.listdir(file_path)

In [None]:
#model architecture
def D_block(D,output_size,stride):
  """Discriminator building block: Conv2D -> BatchNormalization -> LeakyReLU.

  Args:
    D: input tensor the block is applied to.
    output_size: number of filters of the 3x3 convolution.
    stride: stride of the convolution.

  Returns:
    The tensor produced by the LeakyReLU activation.
  """
  D = Conv2D(filters=output_size, kernel_size=3, strides=stride, padding="same")(D)
  D = BatchNormalization()(D)
  # Use the same negative slope (0.2) as every other LeakyReLU in the
  # discriminator instead of silently falling back to the Keras default (0.3).
  D = LeakyReLU(alpha=0.2)(D)
  return D

def res_block(G):
  """Residual block of the generator.

  Runs Conv(64, 3x3) -> BatchNorm -> PReLU -> Conv(64, 3x3) -> BatchNorm on the
  input and adds the untouched input back as a skip connection.

  Args:
    G: input tensor of the block.

  Returns:
    The element-wise sum of the input and the transformed branch.
  """
  shortcut = G
  branch = Conv2D(filters=64, kernel_size=3, strides=1, padding="same")(shortcut)
  branch = BatchNormalization()(branch)
  # The PReLU slope is shared across the two spatial axes (one slope per channel).
  branch = PReLU(alpha_initializer='zeros', shared_axes=[1,2])(branch)
  branch = Conv2D(filters=64, kernel_size=3, strides=1, padding="same")(branch)
  branch = BatchNormalization()(branch)
  return add([shortcut, branch])

class GAN:
  """Builds the two SRGAN networks and keeps them as attributes.

  After construction, `D` holds the discriminator model and `G` the generator
  model. The discriminator is built first, then the generator.
  """

  def __init__(self,image_shape,noise_shape):
    """Store the input shapes and build both models.

    Args:
      image_shape: input shape of the discriminator.
      noise_shape: input shape of the generator.
    """
    self.image_shape = image_shape
    self.noise_shape = noise_shape
    self.D = self.discriminator()
    self.G = self.generator()

  def discriminator(self):
    """Create a new discriminator model.

    Returns:
      A tf.keras.Model mapping an input of shape `self.image_shape` to a
      single sigmoid output.
    """
    image_in = Input(shape=self.image_shape)

    x = Conv2D(filters=64, kernel_size=3, strides=1, padding="same")(image_in)
    x = LeakyReLU(alpha=0.2)(x)

    # (filters, stride) of the seven conv blocks; every stride-2 block halves
    # the spatial resolution.
    block_specs = ((64,2),(128,1),(128,2),(256,1),(256,2),(512,1),(512,2))
    for n_filters, block_stride in block_specs:
      x = D_block(x, n_filters, block_stride)

    x = Flatten()(x)
    x = Dense(1024)(x)
    x = LeakyReLU(alpha=0.2)(x)
    validity = Dense(1, activation='sigmoid')(x)

    return tf.keras.Model(inputs=image_in, outputs=validity)

  def generator(self):
    """Create a new generator model.

    Returns:
      A tf.keras.Model mapping an input of shape `self.noise_shape` to a
      3-channel output whose height and width are four times larger (two
      2x up-sampling stages).
    """
    gen_in = Input(shape=self.noise_shape)

    x = Conv2D(filters=64, kernel_size=9, strides=1, padding="same")(gen_in)
    x = PReLU(alpha_initializer='zeros', shared_axes=[1,2])(x)

    # Kept for the long skip connection around the residual trunk.
    trunk_in = x

    for _ in range(16):
      x = res_block(x)

    x = Conv2D(filters=64, kernel_size=3, strides=1, padding="same")(x)
    x = BatchNormalization()(x)
    x = add([trunk_in, x])

    # Two up-sampling stages, each doubling height and width.
    for _ in range(2):
      x = Conv2D(filters=256, kernel_size=3, strides=1, padding="same")(x)
      x = UpSampling2D(size=(2,2))(x)
      x = PReLU(alpha_initializer='zeros', shared_axes=[1,2])(x)

    sr_out = Conv2D(filters=3, kernel_size=9, strides=1, padding="same")(x)

    return tf.keras.Model(inputs=gen_in, outputs=sr_out)

def combined(discriminator,generator):
  """Stack generator and discriminator into the model used to train the generator.

  The returned model maps a low-resolution input to two outputs: the
  super-resolved image and the discriminator's score for that image. It is
  compiled with `content_loss` on the image (weight 1) and binary
  cross-entropy on the score (weight 1e-3).

  Args:
    discriminator: discriminator model. The caller is expected to have set
      `discriminator.trainable = False` beforehand so that only the generator
      is updated through this model.
    generator: generator model.

  Returns:
    The compiled combined tf.keras.Model.
  """
  gan_input = Input(shape=(None,None,3))
  SR = generator(gan_input)
  gan_output = discriminator(SR)
  gan = tf.keras.Model(inputs=gan_input,outputs=[SR,gan_output])
  # 1e-4 for the first 1e5 steps, then 1e-5 for the next 1e5 steps;
  # staircase=True makes the decay a step function rather than a continuous one.
  lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
      initial_learning_rate = 1e-4,
      decay_steps=100000,
      decay_rate=0.1,
      staircase=True)
  adam = Adam(learning_rate = lr_schedule)
  # Pass the optimizer object: the string 'adam' would create a default Adam
  # and silently ignore the learning-rate schedule built above.
  gan.compile(loss=[content_loss,'binary_crossentropy'], loss_weights=[1,1e-3],optimizer=adam)

  return gan


In [None]:
# Content loss: the MSE between VGG19 block5_conv4 feature maps is used as the loss.
from tensorflow.keras.applications import VGG19

# ImageNet VGG19 without the classifier head, used only as a fixed feature extractor.
vgg19 = VGG19(weights='imagenet', include_top=False, input_shape=(256,256,3))
# Freeze every layer as well as the model itself so no VGG weight is ever trained.
for layer in vgg19.layers:
  layer.trainable = False
vgg19.trainable = False
# Sub-model that maps an image to the output of the 'block5_conv4' layer.
feature_map = tf.keras.Model(
    inputs=vgg19.input,
    outputs=vgg19.get_layer('block5_conv4').output,
)
feature_map.trainable = False

@tf.function()
def content_loss(y_true, y_pred):
  """VGG19 content loss: MSE between block5_conv4 feature maps.

  Both inputs are expected in the [-1, 1] range used for the HR images in this
  notebook. They are mapped back to [0, 255] and run through the VGG19
  `preprocess_input` before feature extraction, because the ImageNet weights
  were trained on inputs preprocessed that way.

  Args:
    y_true: batch of target (high-resolution) images in [-1, 1].
    y_pred: batch of generated images, on the same scale.

  Returns:
    Scalar tensor: mean squared difference of the two feature maps.
  """
  def vgg_features(images):
    # [-1, 1] -> [0, 255], then the preprocessing VGG19 was trained with.
    pixels = (tf.cast(images, tf.float32) + 1.0) * 127.5
    return feature_map(tf.keras.applications.vgg19.preprocess_input(pixels))

  return K.mean(K.square(vgg_features(y_true) - vgg_features(y_pred)))




In [None]:
# Data preparation
def hr_to_lr(hr):
  """Downscale each HR image to 64x64 with bicubic interpolation.

  Helper for regenerating the LR set from the HR set; it is not called in
  this cell because the LR images are loaded from disk below.

  Args:
    hr: sequence of HR images.

  Returns:
    List of 64x64 numpy arrays, one per input image.
  """
  return [resize(image,[64,64],method=ResizeMethod.BICUBIC).numpy() for image in hr]


# Load the training arrays. Previously these loads only existed inside a dead
# string literal, so hr_train / lr_train were undefined on a fresh kernel.
hr_train = np.load(file_path+'/hr_train.npy')
lr_train = np.load(file_path+'/lr_train.npy')

# HR images are scaled to the [-1, 1] range.
hr_train_norm = hr_train.astype(np.float32)/127.5 - 1
# LR images are scaled to the [0, 1] range.
lr_train_norm = lr_train.astype(np.float32)/255.0

In [None]:
noise_shape=(64,64,3)
image_shape=(256,256,3)

model = GAN(image_shape,noise_shape)

# Pre-train the generator (SRResNet) with the content loss only.
# The paper pre-trains for 1e6 iterations.
generator = model.G
adam = Adam(learning_rate=1e-4,beta_1= 0.9)
generator.compile(loss=content_loss,optimizer=adam)
# Reuse the discriminator GAN.__init__ already built instead of constructing a
# second, identical network.
discriminator = model.D
# Each model gets its own optimizer instance: an optimizer keeps per-variable
# state, so one instance must not be shared between two models.
discriminator.compile(loss='binary_crossentropy',optimizer = Adam(learning_rate=1e-4,beta_1= 0.9))
discriminator.trainable = False
# Resume from earlier pre-training when a checkpoint is present.
if os.path.isfile(file_path + '/generator_weights.h5'):
  generator.load_weights(file_path + '/generator_weights.h5')

for i in range(1000):
  print('iterate : ', i)
  generator.fit(x=lr_train_norm,y=hr_train_norm,batch_size=16,epochs=100,verbose=1)
  generator.save_weights(file_path + '/generator_weights.h5')
  sr_predict = generator.predict(np.array([lr_train_norm[0]]))
  # predict() returns a batch of shape (1, H, W, 3); imshow needs one image.
  # Map [-1, 1] -> [0, 1] and clip, since the generator output is unbounded.
  plt.imshow(np.clip((sr_predict[0]+1)/2, 0, 1))
  plt.show()

In [None]:
# SRGAN training
noise_shape=(64,64,3)
image_shape=(256,256,3)

gan = combined(discriminator,generator)

epochs = 100
batch_size = 8

# save_weights does not create missing directories, so make sure the target exists.
weights_dir = file_path + '/weights'
os.makedirs(weights_dir, exist_ok=True)

num_samples = hr_train_norm.shape[0]
# At least one step per epoch, so the loss variables printed below are always defined.
steps_per_epoch = max(1, round(num_samples/batch_size))
real_labels = np.ones((batch_size,1))
fake_labels = np.zeros((batch_size,1))

# Intended schedule: learning rate 1e-4 for 1e5 steps,
# then learning rate 1e-5 for another 1e5 steps.
for i in range(1,epochs+1):
  print('#'*15, "Epochs : ", i, '#'*15)
  for count in range(steps_per_epoch):
    # Batches are drawn at random with replacement, so there is no partial
    # batch to handle at the end of an epoch.
    batch=np.random.randint(0,num_samples,size=batch_size)
    hr_batch = hr_train_norm[batch]
    lr_batch = lr_train_norm[batch]

    # Calling the model directly avoids the per-call overhead of predict()
    # inside the training loop; training=False keeps inference behaviour.
    sr_batch = generator(lr_batch, training=False).numpy()

    # Train the discriminator on real (HR) and generated (SR) images.
    discriminator.trainable = True
    d_loss_real = discriminator.train_on_batch(hr_batch,real_labels)
    d_loss_fake = discriminator.train_on_batch(sr_batch,fake_labels)

    discriminator.trainable = False

    # Train the generator through the combined model: the discriminator output
    # is compared against the label 1, and G is updated to reduce that loss
    # together with the content loss against the HR batch.
    gan_loss = gan.train_on_batch(lr_batch,[hr_batch,real_labels])

  print('d_loss_real : ',d_loss_real)
  print('d_loss_fake : ',d_loss_fake)
  print('gan_loss : ',gan_loss)

  generator.save_weights(weights_dir + '/gen_weights.h5')
  discriminator.save_weights(weights_dir + '/dis_weights.h5')
  gan.save_weights(weights_dir + '/gan_weights.h5')