In [None]:
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
from tensorflow.keras.optimizers import *
from tensorflow.keras.losses import *
from tensorflow.keras.applications import VGG19
import tensorflow.keras.backend as K
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
import joblib, os, cv2, time

In [None]:
# Load the training data: the arrays stored in lrImgs.sav (generator inputs)
# and hrImgs.sav (targets).
# NOTE(review): joblib.load unpickles its input, which can execute arbitrary
# code -- only load .sav files that come from a trusted source.

# Build the paths explicitly instead of os.chdir-ing into 'train' and back, so
# a failed load cannot leave the kernel stranded in the wrong directory.
rawX = joblib.load(os.path.join('train', 'lrImgs.sav'))
rawX2 = joblib.load(os.path.join('train', 'hrImgs.sav'))

m = rawX.shape[0]  # number of training examples
batchSize = 16
X = tf.data.Dataset.from_tensor_slices(rawX).batch(batchSize)
X2 = tf.data.Dataset.from_tensor_slices(rawX2).batch(batchSize)

# Materialise the batches as NumPy arrays so train() can pick them by index.
listX = list(X.as_numpy_iterator())
# The original iterated over `y`, a name that is never defined (NameError);
# the target batches come from X2.
listY = list(X2.as_numpy_iterator())

In [None]:
def buildGBlock(inp):
  """Residual generator block with batch normalisation.

  Applies conv -> BN -> LeakyReLU -> conv -> BN and adds the result to `inp`.
  Both convolutions emit 64 channels with 'same' padding, so `inp` must itself
  have 64 channels for the Add to be shape-compatible.

  Returns the output tensor of the Add layer.
  """
  x = Conv2D(64, 3, padding='same')(inp)
  x = BatchNormalization()(x)
  x = LeakyReLU()(x)
  x = Conv2D(64, 3, padding='same')(x)
  x = BatchNormalization()(x)
  # Skip connection around the whole block.
  return Add()([inp, x])

def buildGBlockNoBN(inp):
  """Residual generator block without batch normalisation.

  Applies conv -> LeakyReLU -> conv and adds the result to `inp`. Both
  convolutions emit 64 channels with 'same' padding, so `inp` must have 64
  channels for the Add to be shape-compatible.

  Returns the output tensor of the Add layer.
  """
  x = Conv2D(64, 3, padding='same')(inp)
  x = LeakyReLU()(x)
  x = Conv2D(64, 3, padding='same')(x)
  # Skip connection around the whole block.
  return Add()([inp, x])

def buildDBlock(inp):
  """Discriminator block: strided conv -> BN -> LeakyReLU -> dropout.

  The 32-filter convolution uses strides=2 with 'same' padding, so each block
  halves the spatial size of `inp`. Dropout rate is 0.2.

  Returns the output tensor of the Dropout layer.
  """
  x = Conv2D(32, 3, strides=2, padding='same')(inp)
  x = BatchNormalization()(x)
  x = LeakyReLU()(x)
  return Dropout(0.2)(x)

In [None]:
'''
Building model architectures - a far cry from that in the paper, but it worked for me.

Changes:
No PReLU - too many parameters and not much improvement
No PixelShuffle - there isn't a Keras layer for that and PixelShuffle didn't
help too much anyway
Conv2DTranspose and UpSampling2D layers - Conv2DTranspose gave ugly dots on the SR image,
UpSampling2D didn't give much detail, both worked when together
No BN after upsampling - Make sure that the dots from the deconv. layers 
wouldn't be as noticeable in the final SR image
'''

def genGen():
  """Build the generator network.

  A 5x5 conv stem feeds two batch-normalised residual blocks, then two
  stride-2 Conv2DTranspose stages. Each transpose stage is summed with a plain
  UpSampling2D copy of the stem features at the matching scale (2x, then 4x).
  Two residual blocks without BN follow, each re-adding the 4x skip, and a
  final 5x5 conv with sigmoid produces the 3-channel image.

  NOTE(review): the two stride-2 stages enlarge the input 4x per side, so the
  declared 64x64 input gives a 256x256 output, while genDisc() and build_vgg()
  declare 128x128 inputs -- confirm the actual image resolutions in the data.

  Returns:
    A Keras Model named 'generator'.
  """
  inp = Input((64, 64, 3))
  stem = Conv2D(64, 5, padding='same')(inp)
  stem = LeakyReLU()(stem)
  # Non-learned upsamplings of the stem, used below as skip connections.
  skip2x = UpSampling2D()(stem)
  skip4x = UpSampling2D()(skip2x)

  x = stem
  for _ in range(2):
    x = buildGBlock(x)

  # Learned upsampling, stage 1 (2x).
  x = Conv2DTranspose(64, 3, strides=2, padding='same')(x)
  x = LeakyReLU()(x)
  x = Add()([x, skip2x])
  # Learned upsampling, stage 2 (4x).
  x = Conv2DTranspose(64, 3, strides=2, padding='same')(x)
  x = LeakyReLU()(x)
  x = Add()([x, skip4x])

  for _ in range(2):
    x = buildGBlockNoBN(x)
    x = Add()([x, skip4x])

  srImage = Conv2D(3, 5, padding='same', activation='sigmoid')(x)

  return Model(inp, srImage, name='generator')

def genDisc():
  """Build the discriminator network.

  A 3x3 conv stem is followed by five buildDBlock stages (each halves the
  spatial size), then Flatten and a single sigmoid unit.

  Returns:
    A Keras Model named 'discriminator' taking 128x128x3 inputs and emitting
    one value in (0, 1) per image.
  """
  inp = Input((128, 128, 3))
  x = Conv2D(32, 3, padding='same')(inp)
  x = LeakyReLU()(x)

  for _ in range(5):
    x = buildDBlock(x)

  x = Flatten()(x)
  prob = Dense(1, activation='sigmoid')(x)

  return Model(inp, prob, name='discriminator')

def build_vgg():
  """Build a truncated ImageNet VGG19 used as a feature extractor.

  The returned model maps a 128x128x3 image to the output of VGG19's layer at
  index 6 (presumably block2_pool -- confirm with `VGG19(...).summary()`).

  Returns:
    A Keras Model from the VGG19 input to that intermediate activation.
  """
  backbone = VGG19(weights="imagenet", include_top=False, input_shape=(128, 128, 3))
  # Re-point the backbone's outputs at the intermediate layer we want.
  backbone.outputs = [backbone.layers[6].output]
  featureInput = backbone.layers[0].output

  return Model(featureInput, backbone.outputs)

In [None]:
# Shared binary cross-entropy loss object, read by discLoss() and genLoss().
bce = BinaryCrossentropy()

def discLoss(ytrue, yfake):
  """Discriminator loss.

  Args:
    ytrue: discriminator predictions for real images (target label 1).
    yfake: discriminator predictions for generated images (target label 0).

  Returns:
    Mean of the summed real and fake binary cross-entropy terms.
  """
  realTerm = bce(tf.ones_like(ytrue), ytrue)
  fakeTerm = bce(tf.zeros_like(yfake), yfake)
  return K.mean(realTerm + fakeTerm)

def genLoss(yfake, y, fakeImgs, alpha=0.2):
  """Generator loss: weighted sum of a reconstruction and an adversarial term.

  Args:
    yfake: discriminator predictions for the generated images.
    y: target tensor for the reconstruction term (step() passes the VGG
      features of the real images).
    fakeImgs: generated counterpart of `y` (step() passes the VGG features of
      the generated images).
    alpha: weight of the adversarial term; the reconstruction term gets
      (1 - alpha). Not found in the paper, but it controls how much the loss
      depends on the adversarial versus the perceptual part. Defaults to 0.2.

  Returns:
    (1 - alpha) * MSE(y, fakeImgs) + alpha * BCE(ones, yfake).
  """
  recLoss = (1 - alpha) * MeanSquaredError()(y, fakeImgs)
  # The generator wants the discriminator to label its images as real (1).
  adLoss = alpha * K.mean(bce(tf.ones_like(yfake), yfake))
  return recLoss + adLoss


In [None]:
def step(batch, y):
  """Run one adversarial training step on a single batch.

  Uses the module-level genModel, discModel, vgg, genOpt and discOpt.

  Args:
    batch: batch of generator inputs.
    y: matching batch of real target images.

  Returns:
    (dloss, gloss): the discriminator and generator losses for this batch.
  """
  # Only the forward pass and the losses are recorded. Gradients and optimizer
  # updates happen after the tapes are closed so the backward ops are not
  # themselves traced by the still-active tapes.
  with tf.GradientTape() as gtape, tf.GradientTape() as dtape:
    fakes = genModel(batch, training=True)
    truePreds = discModel(y, training=True)
    fakePreds = discModel(fakes, training=True)
    # Feature maps used by the reconstruction term of the generator loss.
    trueVGG = vgg(y, training=False)
    fakeVGG = vgg(fakes, training=False)

    dloss = discLoss(truePreds, fakePreds)
    gloss = genLoss(fakePreds, trueVGG, fakeVGG)

  fakeMean, trueMean = K.mean(fakePreds).numpy(), K.mean(truePreds).numpy()
  print('Fake Preds: {} | True Preds: {}'.format(fakeMean, trueMean))

  # Skip a network's update when it is already winning; this made adversarial
  # training more stable.
  if fakeMean < 0.7:
    gradGen = gtape.gradient(gloss, genModel.trainable_variables)
    genOpt.apply_gradients(zip(gradGen, genModel.trainable_variables))
  if trueMean < 0.7 or fakeMean > 0.3:
    gradDisc = dtape.gradient(dloss, discModel.trainable_variables)
    discOpt.apply_gradients(zip(gradDisc, discModel.trainable_variables))

  return dloss, gloss
 
def train(X, y, epochs, steps=1000):
  """Train the GAN for `epochs` epochs of `steps` randomly chosen batches.

  Args:
    X, y: kept for interface compatibility but not read; batches are drawn
      from the module-level listX / listY.
    epochs: number of epochs to run.
    steps: number of randomly sampled batches per epoch.

  Prints the per-batch losses and, after each epoch, the summed losses.
  """
  # m // batchSize counts the full batches only (as before). max(1, ...) keeps
  # randint valid when the dataset is smaller than one batch, in which case
  # the single partial batch sits at index 0.
  numBatches = max(1, m // batchSize)
  for i in range(epochs):
    dcost = 0
    gcost = 0
    for batch in range(steps):
      batchInd = np.random.randint(low=0, high=numBatches)
      batchX = listX[batchInd]
      batchY = listY[batchInd]
      # step() takes exactly (batch, y); the extra third argument passed
      # previously raised a TypeError on the first call.
      dloss, gloss = step(batchX, batchY)
      print('Batch: {} | Discriminator Batch Loss: {} | Generator Batch Loss: {}'.format(batch, dloss, gloss))

      dcost += dloss
      gcost += gloss

    print('\n-----Epoch: {} | Discriminator Cost: {} | Generator Cost: {}-----\n'.format(i, dcost, gcost))

In [None]:
# Feature extractor for the generator's reconstruction loss; frozen so it is
# never updated during training.
vgg = build_vgg()
vgg.trainable = False

# Fresh, untrained generator and discriminator.
genModel = genGen()
discModel = genDisc()

# load in trained model
# The block below is disabled by wrapping it in a string literal; remove the
# triple quotes to load previously saved models instead of training from scratch.
'''
if tf.__version__ == '2.2.0':
  genModel = tf.keras.models.load_model('models/tf_220/srGAN/gen')
  discModel = tf.keras.models.load_model('models/tf_220/srGAN/disc')
else:
  genModel = tf.keras.models.load_model('models/tf_230/srGAN/gen')
  discModel = tf.keras.models.load_model('models/tf_230/srGAN/disc')
'''

In [None]:
# The generator and the discriminator each get their own Adam optimizer, both
# at a learning rate of 1e-4.
genOpt, discOpt = Adam(learning_rate=1e-4), Adam(learning_rate=1e-4)

In [None]:
# Alternate forever between previewing the generator's output and one epoch of
# training. Interrupt the kernel to stop.

while True:
  rows, cols = 3, 5
  fig = plt.figure(figsize=(30, 15))
  axes = fig.subplots(rows, cols)
  for i in range(cols):
    # Even columns show fixed samples (so progress is comparable between
    # rounds); odd columns show a random sample.
    idx = i if i % 2 == 0 else np.random.randint(low=0, high=m)
    predInput = np.array([rawX[idx]])
    pred = genModel.predict(predInput)[0]

    # Rows: input image, generator output, target image.
    axes[0][i].imshow(rawX[idx])
    axes[1][i].imshow(pred)
    axes[2][i].imshow(rawX2[idx])

  plt.show()
  # Release the figure explicitly; the loop is unbounded, so figures must not
  # accumulate on backends that do not close them after show().
  plt.close(fig)
  train(X, X2, 1)

In [None]:
# Save both networks under a directory tagged with the TensorFlow version, the
# run mode and the current timestamp.
now = time.time()
versionDir = 'tf_220' if tf.__version__ == '2.2.0' else 'tf_230'
# `mode` is not assigned anywhere in the visible notebook, so referencing it
# directly raised NameError. Use it when an earlier cell defined it, otherwise
# fall back to a placeholder tag.
# NOTE(review): 'default' is an arbitrary fallback -- replace with the intended tag.
modeTag = globals().get('mode', 'default')
saveRoot = 'models/{}/srGAN_{}_{}'.format(versionDir, modeTag, now)
genModel.save(saveRoot + '/gen')
discModel.save(saveRoot + '/disc')