In [None]:
import glob
import os
import random
import subprocess

import numpy as np
import tensorflow as tf
import wandb
from PIL import Image
from tensorflow.keras import applications
from tensorflow.keras import backend as K
from tensorflow.keras import layers
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.models import Model, Sequential, load_model
from tensorflow.keras.optimizers import Adam
from wandb.keras import WandbCallback

In [None]:
# Locations of the "<name>-in.jpg" / "<name>-out.jpg" image pairs (Google Drive).
train_dir = '/content/drive/MyDrive/superres/data/train'
val_dir = '/content/drive/MyDrive/superres/data/test'

# Hyper-parameters live on the wandb run config so they are logged with the run.
run = wandb.init(project='superres')
config = run.config
config.num_epochs = 50
config.batch_size = 32  # other values tried: 16, 64, 128
config.input_height = 32
config.input_width = 32
config.output_height = 256
config.output_width = 256

# Fetch the flower dataset into ./data the first time the notebook runs.
# NOTE(review): the archive is unpacked into ./data, but train_dir / val_dir
# point at Google Drive, so the generators never read this copy -- confirm.
download_cmd = "mkdir data && curl https://storage.googleapis.com/wandb/flower-enhance.tar.gz | tar xzf - -C data"
if not os.path.exists("data"):
    print("Downloading flower dataset...")
    subprocess.check_output(download_cmd, shell=True)

# Whole batches available per epoch (any remainder is not counted).
n_train_inputs = len(glob.glob(train_dir + "/*-in.jpg"))
n_val_inputs = len(glob.glob(val_dir + "/*-in.jpg"))
config.steps_per_epoch = n_train_inputs // config.batch_size
config.val_steps_per_epoch = n_val_inputs // config.batch_size




In [None]:
def image_generator(batch_size, img_dir):
    """A generator that returns small images and large images.  DO NOT ALTER the validation set

    Yields ``(small_images, large_images)`` float arrays forever. Every input
    ``<name>-in.jpg`` found in ``img_dir`` is paired with ``<name>-out.jpg``
    and pixel values are divided by 255. The file list is shuffled once, then
    walked in order, wrapping back to the start only when fewer than
    ``batch_size`` files remain.

    Args:
        batch_size: number of image pairs per yielded batch.
        img_dir: directory holding the ``*-in.jpg`` / ``*-out.jpg`` pairs.

    Raises:
        ValueError: on the first ``next()`` if ``img_dir`` contains fewer than
            ``batch_size`` input images (previously an IndexError mid-batch).
    """
    input_filenames = glob.glob(img_dir + "/*-in.jpg")
    if len(input_filenames) < batch_size:
        raise ValueError(
            "need at least %d '*-in.jpg' images in %r, found %d"
            % (batch_size, img_dir, len(input_filenames)))
    counter = 0
    random.shuffle(input_filenames)
    while True:
        # np.array(<PIL RGB image>) is (height, width, 3), so allocate
        # (batch, height, width, 3) to match it.
        small_images = np.zeros(
            (batch_size, config.input_height, config.input_width, 3))
        large_images = np.zeros(
            (batch_size, config.output_height, config.output_width, 3))
        # Wrap only when the next batch would run past the end. The previous
        # `>=` also wrapped when exactly one full batch was left, so the final
        # batch of the list was never yielded.
        if counter + batch_size > len(input_filenames):
            counter = 0
        for i, img in enumerate(input_filenames[counter:counter + batch_size]):
            # `with` closes each image file as soon as its pixels are copied.
            with Image.open(img) as small:
                small_images[i] = np.array(small) / 255.0
            with Image.open(img.replace("-in.jpg", "-out.jpg")) as large:
                large_images[i] = np.array(large) / 255.0
        yield (small_images, large_images)
        counter += batch_size
        

In [None]:
def perceptual_distance(y_true, y_pred):
    """Calculate perceptual distance, DO NOT ALTER

    Mean over all pixels of a red-mean weighted colour difference between two
    channels-last batches (indexing ``[:, :, :, c]`` requires rank 4, with
    channel 0/1/2 treated as R/G/B). Inputs are presumably in [0, 1] and are
    rescaled to [0, 255] before comparison -- confirm for new callers.

    The metric formula is unchanged. The only difference from the earlier
    ``y *= 255`` form is that scaling now builds new values, so NumPy arrays
    passed in by callers are no longer modified in place.
    """
    y_true = y_true * 255
    y_pred = y_pred * 255
    rmean = (y_true[:, :, :, 0] + y_pred[:, :, :, 0]) / 2
    r = y_true[:, :, :, 0] - y_pred[:, :, :, 0]
    g = y_true[:, :, :, 1] - y_pred[:, :, :, 1]
    b = y_true[:, :, :, 2] - y_pred[:, :, :, 2]

    return K.mean(K.sqrt((((512+rmean)*r*r)/256) + 4*g*g + (((767-rmean)*b*b)/256)))

In [None]:
class ImageLogger(Callback):
    """Per-epoch callback: log example images to wandb, early-stop, and step the LR.

    Reads the module-level ``in_sample_images`` / ``out_sample_images`` batch,
    so both must be defined before training starts.
    """

    # Stop training once val_perceptual_distance drops below this value.
    STOP_BELOW = 45
    # epoch index -> learning rate written into the optimizer at that epoch's end.
    LR_STEPS = {30: 5e-5, 300: 2e-5}

    def on_epoch_end(self, epoch, logs=None):
        """Log [upsampled input | prediction | target] strips, then apply stop/LR rules.

        Args:
            epoch: epoch index supplied by Keras.
            logs: metric dict supplied by Keras. It may be None or lack
                ``val_perceptual_distance`` (e.g. when no validation data is
                given); early stopping is then skipped instead of raising KeyError.
        """
        logs = logs or {}
        preds = self.model.predict(in_sample_images)
        # Nearest-neighbour 8x upsampling (repeat rows, then columns) so the
        # small inputs can be shown beside the full-size outputs.
        in_resized = [arr.repeat(8, axis=0).repeat(8, axis=1) for arr in in_sample_images]
        strips = [
            wandb.Image(np.concatenate(
                [in_resized[i] * 255, pred * 255, out_sample_images[i] * 255], axis=1))
            for i, pred in enumerate(preds)
        ]
        wandb.log({"examples": strips}, commit=False)

        val_distance = logs.get('val_perceptual_distance')
        if val_distance is not None and val_distance < self.STOP_BELOW:
            self.model.stop_training = True
        if epoch in self.LR_STEPS:
            # `learning_rate` is the canonical optimizer attribute; `lr` is a
            # deprecated alias.
            K.set_value(self.model.optimizer.learning_rate, self.LR_STEPS[epoch])


In [None]:
train_generator = image_generator(config.batch_size, train_dir)
val_generator = image_generator(config.batch_size, val_dir)
# One validation batch kept aside; ImageLogger renders these as wandb examples.
in_sample_images, out_sample_images = next(val_generator)

# Generator hyper-parameters (read by resBlock / Generator).
feature_size   = 256  # conv channels in the generator trunk
num_layers     = 32   # number of residual blocks
scaling_factor = 0.1  # scale applied to each residual branch in resBlock

In [None]:
def resBlock(x, channels = 256, scale = None):
  """Residual block: conv-ReLU -> conv -> scale the branch -> add the input back.

  Args:
    x: input tensor; its channel count must equal ``channels`` for the Add.
    channels: filters in both 3x3 convolutions.
    scale: multiplier for the residual branch; defaults to the module-level
      ``scaling_factor``.

  Returns:
    ``x + scale * conv(relu(conv(x)))``.
  """
  if scale is None:
    scale = scaling_factor
  tmp = layers.Conv2D(channels, (3,3), padding='same', activation='relu')(x)
  tmp = layers.Conv2D(channels, (3,3), padding='same')(tmp)
  # Residual scaling. A built-in layer (not a Lambda over a global) keeps the
  # factor in the layer config, so saved .h5 models reload without it.
  tmp = layers.Rescaling(scale)(tmp)

  return layers.Add()([x, tmp])

In [None]:
def Generator(lr_shape=(32,32,3), upscale=8):
  """Build the super-resolution generator.

  3x3 conv stem -> ``num_layers`` residual blocks of ``feature_size`` channels
  -> conv + global skip back to the stem -> conv -> sub-pixel upsampling by
  ``upscale`` into a 3-channel image.

  Args:
    lr_shape: shape of one low-resolution input image.
    upscale: spatial upsampling factor; 8 maps 32x32 inputs to 256x256.

  Returns:
    An uncompiled ``Model`` whose output layer is named ``gen_output``.
  """
  gen_input = layers.Input(shape = lr_shape, dtype='float32')
  conv_1 = layers.Conv2D(feature_size, (3,3), padding='same', activation='relu')(gen_input)

  x = conv_1
  for _ in range(num_layers):
    x = resBlock(x)

  x = layers.Conv2D(feature_size, (3,3), padding='same', activation='relu')(x)
  tmp = layers.Add()([x, conv_1])
  tmp = layers.Conv2D(feature_size, (3,3), padding='same', activation='relu')(tmp)
  # 3 * upscale**2 channels are rearranged by depth_to_space into an image
  # that is `upscale` times larger in each spatial dimension with 3 channels.
  tmp = layers.Conv2D(3 * upscale ** 2, (3,3), padding='same')(tmp)
  # TF2 exposes depth_to_space under tf.nn; `name` belongs to the layer
  # constructor, and the layer is then called on `tmp`.
  output = layers.Lambda(lambda t: tf.nn.depth_to_space(t, upscale), name='gen_output')(tmp)

  gen_model = Model(inputs=gen_input, outputs=output)

  return gen_model

In [None]:
def discriminator_block(model, filters, kernel_size, strides):
  """One discriminator stage: Conv2D -> BatchNorm(momentum=0.5) -> LeakyReLU(0.2).

  Args:
    model: input tensor (the name is historical; it is not a Model).
    filters, kernel_size, strides: forwarded to the 'same'-padded Conv2D.

  Returns:
    The stage's output tensor.
  """
  conv = layers.Conv2D(filters=filters, kernel_size=kernel_size, strides=strides, padding='same')
  norm = layers.BatchNormalization(momentum=0.5)
  act = layers.LeakyReLU(alpha=0.2)
  return act(norm(conv(model)))

In [None]:
def Discriminator(hr_shape=(256,256,3)):
  """Build the discriminator that maps an HR image to a single raw score.

  Conv + LeakyReLU stem, eight ``discriminator_block`` stages, Flatten,
  Dense(512) + LeakyReLU, and a final Dense(1) with no activation (the
  caller's loss is expected to work on logits).

  Args:
    hr_shape: shape of one high-resolution input image.

  Returns:
    An uncompiled ``Model``.
  """
  # (filters, strides) of each discriminator_block, in order; kernel size is 3.
  stages = ((64, 2), (128, 1), (128, 2), (256, 1),
            (256, 2), (512, 1), (512, 2), (512, 2))

  dis_input = layers.Input(shape=hr_shape, dtype='float32')
  h = layers.Conv2D(filters=64, kernel_size=3, strides=1, padding='same')(dis_input)
  h = layers.LeakyReLU(alpha=0.2)(h)

  for n_filters, stride in stages:
    h = discriminator_block(h, n_filters, 3, stride)

  h = layers.Flatten()(h)
  h = layers.Dense(512)(h)
  h = layers.LeakyReLU(alpha=0.2)(h)
  logits = layers.Dense(1)(h)

  return Model(inputs=dis_input, outputs=logits)


In [None]:
# Frozen ImageNet VGG19, truncated at block2_conv2, used as the feature
# extractor inside vgg_loss.
vgg = applications.vgg19.VGG19(include_top=False, weights='imagenet', input_shape=(256,256,3))
vgg.trainable = False

for layer in vgg.layers:
  layer.trainable = False

# Deeper alternatives that were tried: 'block4_conv1', 'block5_conv4'.
vgg_model = Model(inputs=vgg.input, outputs=vgg.get_layer('block2_conv2').output)
vgg_model.trainable = False
# `learning_rate=`: the `lr=` alias is deprecated and not honoured by newer
# Keras optimizers.
vgg_model.compile(loss='mse', optimizer=Adam(learning_rate=1e-04, epsilon=1e-08))



In [None]:
def vgg_loss(y_true,y_pred):
  """Mean absolute difference between ``vgg_model`` features of target and prediction.

  Images are presumably RGB in [0, 1] (image_generator divides by 255).
  They are scaled back to [0, 255] and passed through VGG19's own
  ``preprocess_input``, the input format the ImageNet weights expect.

  Returns:
    Scalar mean |VGG(y_true) - VGG(y_pred)|.
  """
  # Multiply, not add: `+= 255` shifted every pixel into [255, 510]. New
  # tensors are built, so caller-owned arrays are not modified.
  y_true = tf.keras.applications.vgg19.preprocess_input(y_true * 255.0)
  y_pred = tf.keras.applications.vgg19.preprocess_input(y_pred * 255.0)
  return K.mean(K.abs(vgg_model(y_true) - vgg_model(y_pred)))

In [None]:
def BCEwithLogitLoss(y_true, y_pred):
  """Relativistic-average BCE-with-logits loss over a [fake; real] batch.

  ``y_pred`` holds discriminator logits for the fake half followed by the
  real half. Each half is shifted by the mean logit of the other half and
  scored with sigmoid BCE against the matching half of ``y_true``.

  Args:
    y_true: 0/1 targets in the same [fake; real] order as ``y_pred``.
    y_pred: raw logits; the first dimension is assumed to be even.

  Returns:
    Scalar mean BCE over both halves.
  """
  # Split at the actual half of the batch. For full batches this equals
  # config.batch_size, but it also works for any other even batch size.
  half = tf.shape(y_pred)[0] // 2
  # Give the labels the logits' shape so (N,) vs (N, 1) cannot broadcast
  # into an (N, N) loss matrix.
  y_true = K.reshape(K.cast(y_true, y_pred.dtype), tf.shape(y_pred))

  pred_g_fake = y_pred[:half]
  pred_d_real = y_pred[half:]
  fake_vs_real = pred_g_fake - K.mean(pred_d_real)
  real_vs_fake = pred_d_real - K.mean(pred_g_fake)
  # Both terms take logits; the real-half term previously lacked its logits
  # and from_logits argument entirely.
  bce = (K.binary_crossentropy(y_true[:half], fake_vs_real, from_logits=True)
         + K.binary_crossentropy(y_true[half:], real_vs_fake, from_logits=True))

  return K.mean(bce)



In [None]:
def get_gan_network(discriminator, lr_shape, generator, hr_shape):
  """Wire generator and frozen discriminator into one compiled GAN model.

  Inputs are ``[lr_input, hr_input]``. Outputs are ``[generated images,
  discriminator logits on concat([generated, hr], axis=0)]``. Losses are
  ``vgg_loss`` on the images (weight 2e-3) and ``BCEwithLogitLoss`` on the
  logits (weight 1e-3).

  ``discriminator.trainable`` is set to False before compiling, so the GAN's
  updates reach only the generator's weights.

  NOTE(review): the second output has twice the batch rows of the first, so
  its target has a different length from the image target; some Keras
  versions reject mismatched target lengths in train_on_batch -- confirm.
  """
  discriminator.trainable = False
  lr_images   = layers.Input(shape=lr_shape, name='lr_input')
  fake_images = generator(lr_images)
  real_images = layers.Input(shape=hr_shape, name='hr_input')
  # Fake batch first, real batch second: the order BCEwithLogitLoss splits on.
  combined_img = layers.Lambda(lambda x: K.concatenate([x[0], x[1]], axis=0), name='fr_combine')([fake_images, real_images])
  gan_output = discriminator(combined_img)
  gan = Model(inputs=[lr_images, real_images], outputs=[fake_images, gan_output])

  gan.compile(loss=[vgg_loss, BCEwithLogitLoss], loss_weights=[2e-3, 1e-3],
              optimizer=Adam(learning_rate=2.5e-05, epsilon=1e-08))

  return gan

In [None]:
# Optimizer for re-compiling a reloaded generator. `learning_rate=` replaces
# the deprecated `lr=` alias, which newer Keras optimizers do not honour.
adam = Adam(learning_rate=.5e-05, beta_1=0.9, beta_2=0.999, epsilon=1e-08)

In [None]:
# Targets for a [fake; real] discriminator batch, and the swapped targets the
# generator is pushed towards through the GAN. Plain float32 NumPy arrays are
# enough here; K.ones/K.zeros would allocate backend variables.
real_data_Y = np.ones((config.batch_size,), dtype='float32')
fake_data_Y = np.zeros((config.batch_size,), dtype='float32')

dis_label = np.concatenate([fake_data_Y, real_data_Y], axis=0)
gan_label = np.concatenate([real_data_Y, fake_data_Y], axis=0)

Model_save_path = '/content/drive/MyDrive/superres/model_save_esrgan_optv4/'
os.makedirs(Model_save_path, exist_ok=True)

generator = Generator(lr_shape=(32,32,3))
discriminator = Discriminator(hr_shape=(256,256,3))
generator.compile(loss=perceptual_distance, optimizer=Adam(learning_rate=1e-04, epsilon=1e-08))
discriminator.compile(loss=BCEwithLogitLoss, optimizer=Adam(learning_rate=1e-04, epsilon=1e-08))




In [None]:
def _pretrain(generator, discriminator, num_epochs=500):
  """Phase-1 loop: fit the generator on (LR, HR) pairs; also update D on even epochs.

  On every even epoch one validation batch is scored with perceptual_distance,
  and both models are checkpointed to Model_save_path as gen_model.h5 /
  dis_model.h5.
  """
  for epoch in range(1, num_epochs + 1):
    print('EPOCH', epoch)
    for _ in range(config.steps_per_epoch):
      train_lr_images, train_hr_images = next(train_generator)

      # NOTE(review): the previous loop never updated the generator, so
      # gen_model.h5 was checkpointed untrained. This pre-trains it with the
      # loss it was compiled with -- confirm this is the intended schedule.
      generator.train_on_batch(train_lr_images, train_hr_images)

      if epoch % 2 == 0:
        gen_images_sr = generator.predict(train_lr_images)
        # [fake; real] along the batch axis, matching dis_label's order.
        # np.concatenate also accepts the mix of the generator's predictions
        # and the float64 arrays from image_generator.
        dis_input = np.concatenate([gen_images_sr, train_hr_images], axis=0)
        discriminator.train_on_batch(dis_input, dis_label)

    if epoch % 2 == 0:
      # Local names, so the module-level sample batch used by ImageLogger is
      # not overwritten.
      val_lr_images, val_hr_images = next(val_generator)
      pred = generator.predict(val_lr_images)
      val_perceptual = perceptual_distance(np.array(val_hr_images), pred)
      print('val_perceptual = ', val_perceptual)
      generator.save(Model_save_path+'gen_model.h5', overwrite=True)
      discriminator.save(Model_save_path+'dis_model.h5', overwrite=True)


if not os.path.exists(Model_save_path+'gen_model.h5'):
  # Fresh start from the models built in the previous cell.
  _pretrain(generator, discriminator)
else:
  # Resume from the phase-1 checkpoints. The reloaded discriminator must be
  # bound to `discriminator` (a misspelt name previously discarded it).
  discriminator = load_model(Model_save_path+'dis_model.h5', custom_objects={'BCEwithLogitLoss':BCEwithLogitLoss})
  discriminator.compile(loss=BCEwithLogitLoss, optimizer=Adam(learning_rate=2.5e-05, epsilon=1e-08))

  generator = load_model(Model_save_path+'gen_model.h5', custom_objects={'perceptual_distance':perceptual_distance, 'tf':tf})
  generator.compile(loss=perceptual_distance, optimizer=adam)

  _pretrain(generator, discriminator)


In [None]:
def _train_gan(gen_saved, dis_saved, gan, num_epochs, val_sequence):
  """Phase-2 loop: alternate discriminator and GAN (generator) steps.

  After every epoch, one validation batch is scored with perceptual_distance,
  the score is appended to ``val_sequence``, and both models are checkpointed
  as gen_model_gan.h5 / dis_model_gan.h5 under Model_save_path.
  """
  for epoch in range(1, num_epochs + 1):
    print('EPOCH', epoch)
    for _ in range(config.steps_per_epoch):
      train_lr_images, train_hr_images = next(train_generator)

      # Discriminator step on [fake; real] against dis_label.
      gen_images_sr = gen_saved.predict(train_lr_images)
      dis_saved.trainable = True
      dis_input = np.concatenate([gen_images_sr, train_hr_images], axis=0)
      dis_saved.train_on_batch(dis_input, dis_label)

      # Generator step through the GAN with the discriminator frozen.
      # NOTE(review): gan_label has 2*batch rows while train_hr_images has
      # batch rows; some Keras versions reject targets of different lengths
      # in train_on_batch -- confirm this runs on the installed version.
      dis_saved.trainable = False
      gan.train_on_batch([train_lr_images, train_hr_images], [train_hr_images, gan_label])

    val_lr_images, val_hr_images = next(val_generator)
    pred = gen_saved.predict(val_lr_images)
    tmp_val = perceptual_distance(np.array(val_hr_images), pred)
    # Print this epoch's score (a stale phase-1 variable was printed before).
    print('val_perceptual = ', tmp_val)
    val_sequence.append(tmp_val)
    gen_saved.save(Model_save_path+'gen_model_gan.h5', overwrite=True)
    dis_saved.save(Model_save_path+'dis_model_gan.h5', overwrite=True)


val_sequence = []
if not os.path.exists(Model_save_path+'gen_model_gan.h5'):
  # Start the GAN phase from the phase-1 checkpoints.
  dis_checkpoint, gen_checkpoint, gan_epochs = 'dis_model.h5', 'gen_model.h5', 500
else:
  # Resume the GAN phase from its own checkpoints.
  dis_checkpoint, gen_checkpoint, gan_epochs = 'dis_model_gan.h5', 'gen_model_gan.h5', 250

dis_saved = load_model(Model_save_path+dis_checkpoint, custom_objects={'BCEwithLogitLoss':BCEwithLogitLoss})
dis_saved.compile(loss=BCEwithLogitLoss, optimizer=Adam(learning_rate=2.5e-05, epsilon=1e-08))
gen_saved = load_model(Model_save_path+gen_checkpoint, custom_objects={'vgg_loss': vgg_loss, 'perceptual_distance':perceptual_distance, 'tf':tf})
gan = get_gan_network(dis_saved, (32,32,3), gen_saved, (256,256,3))

_train_gan(gen_saved, dis_saved, gan, gan_epochs, val_sequence)

                         

In [None]:
model = load_model(Model_save_path+"gen_model_gan.h5", custom_objects={'vgg_loss': vgg_loss, 'perceptual_distance':perceptual_distance, 'tf':tf} )
# NOTE(review): with trainable=False the fit below updates no weights; it only
# runs forward passes so the callbacks and metrics are logged. Some Keras
# versions refuse to fit a model with no trainable weights -- confirm intent.
model.trainable = False
model.compile(loss='mae', optimizer=Adam(learning_rate=2.5e-05, epsilon=1e-08), metrics=[perceptual_distance])
# summary() prints the table itself and returns None, so it is not wrapped in print().
model.summary()




In [None]:
# Model.fit accepts Python generators directly (fit_generator is deprecated).
# validation_data supplies the val_perceptual_distance metric that
# ImageLogger's early-stopping rule reads.
model.fit(image_generator(config.batch_size, train_dir),
          steps_per_epoch=config.steps_per_epoch,
          epochs=config.num_epochs,
          callbacks=[ImageLogger(), WandbCallback()],
          validation_data=val_generator,
          validation_steps=config.val_steps_per_epoch)