In [None]:
from __future__ import print_function, division
import tensorflow as tf
import numpy as np
import os
from tensorflow import keras
from keras import backend as K
import matplotlib.pyplot as plt
from matplotlib.image import imread
import cv2
from skimage.transform import resize
import skimage.io
import scipy
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate
from keras.layers import BatchNormalization, Activation, ZeroPadding2D, Add
from keras.layers.advanced_activations import PReLU, LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.applications import VGG19
from keras.models import Sequential, Model
from keras.optimizers import Adam
import datetime
import sys
from tensorflow.python.client import device_lib 
!pip install git+https://www.github.com/keras-team/keras-contrib.git
%matplotlib inline

print('Tensorflow Version: ', tf.__version__)
print('Tensorflow built with Cuda: ', tf.test.is_built_with_cuda())
print('Tensorflow built with GPU support: ', tf.test.is_built_with_gpu_support())
print('GPU available: ', tf.test.is_gpu_available(cuda_only=False, min_cuda_compute_capability=None), '\n')
print(device_lib.list_local_devices())

In [None]:
# Root folder that is searched (recursively) for images
directory = 'SRGAN/non_demented'

# One entry per file found under `directory`, spelled as
# "<directory>/<folder relative to directory>/<file name>"
filepaths = [
    directory + "/" + os.path.join(os.path.relpath(root, directory), name)
    for root, _subdirs, names in os.walk(directory)
    for name in names
]

In [None]:
# Load Data Function
def load_data(batch_size=1, is_testing=False):
    """
    Draw a random batch of (high-resolution, low-resolution) image pairs.

    `batch_size` paths are sampled at random (with replacement) from the
    module-level `filepaths` list. Each file is read with OpenCV and resized
    with skimage to a 720x720 high-resolution image and a 180x180 (4x smaller)
    low-resolution image.

    Parameters
    ----------
    batch_size : int
        Number of image pairs to return.
    is_testing : bool
        When False, each pair is mirrored left-right with probability 0.5
        (the same flip is applied to both resolutions). When True no
        augmentation is applied.

    Returns
    -------
    (imgs_hr, imgs_lr) : tuple of np.ndarray
        Arrays of shape (batch_size, 720, 720, 3) and (batch_size, 180, 180, 3).
        Channels are in OpenCV's BGR order.

    Raises
    ------
    ValueError
        If `filepaths` is empty.
    IOError
        If OpenCV cannot read one of the sampled files.
    """
    if len(filepaths) == 0:
        raise ValueError("filepaths is empty: no images available to sample from")

    h, w = (720, 720) # 256?
    low_h, low_w = h // 4, w // 4

    # Honour the requested batch size (it used to be hard-coded to 1)
    batch_images = np.random.choice(filepaths, size=batch_size)

    imgs_hr = []
    imgs_lr = []
    for img_path in batch_images:
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals an unreadable/undecodable file by returning None
            raise IOError("Could not read image: %s" % img_path)
        # np.float was removed in NumPy 1.24; np.float64 is the same dtype
        img = img.astype(np.float64)

        img_hr = resize(img, (h, w, 3))
        img_lr = resize(img, (low_h, low_w, 3))

        # Random horizontal flip as training-time augmentation
        if not is_testing and np.random.random() < 0.5:
            img_hr = np.fliplr(img_hr)
            img_lr = np.fliplr(img_lr)

        imgs_hr.append(img_hr)
        imgs_lr.append(img_lr)

    # NOTE(review): 719.5 looks like it was derived from the image size rather
    # than the pixel range. If the resized pixels are still 0..255 this maps
    # them to roughly [-1, -0.65] instead of [-1, 1] (which would need 127.5).
    # Left unchanged here because sample_images() undoes this scaling and the
    # two have to be changed together - TODO confirm and fix both.
    imgs_hr = np.array(imgs_hr) / 719.5 - 1.
    imgs_lr = np.array(imgs_lr) / 719.5 - 1.

    return imgs_hr, imgs_lr

In [None]:
# ---- Network attributes ----
dataset_name = 'non_demented'     # used in the output path of saved samples

channels = 3                      # channels per image
lr_height = lr_width = 180        # low-resolution height / width
lr_shape = (lr_height, lr_width, channels)

# High-resolution images are 4x larger along each spatial axis
hr_height, hr_width = 4 * lr_height, 4 * lr_width
hr_shape = (hr_height, hr_width, channels)

n_residual_blocks = 16            # residual blocks in the generator
optimizer = Adam(0.0002, 0.5)     # Adam(learning rate, beta_1); passed to every compile() call

In [None]:
# Construct VGG19 Pre-Trained Model for Feature Extraction
def build_vgg(hr_shape=hr_shape):
    """
    Build a VGG19 feature extractor for images of shape `hr_shape`.

    VGG19 is loaded with ImageNet weights but without its fully connected
    classifier (`include_top=False`), so it can be built for `hr_shape`
    instead of the fixed 224x224 classifier input. The returned model maps an
    image to the output of `vgg.layers[9]` (block3_conv3 in the Keras VGG19
    definition), i.e. features from the third convolutional block.

    Parameters
    ----------
    hr_shape : tuple
        (height, width, channels) of the images fed to the extractor.

    Returns
    -------
    Model
        Keras model: image -> intermediate VGG19 feature map.
    """
    vgg = VGG19(weights="imagenet", include_top=False, input_shape=hr_shape)

    # Build the sub-model explicitly from input to the chosen layer.
    # Re-assigning `vgg.outputs` does not re-wire a functional model.
    return Model(inputs=vgg.input, outputs=vgg.layers[9].output)

# Instantiate the feature extractor and mark it non-trainable before compiling
vgg = build_vgg(hr_shape=hr_shape)
vgg.trainable = False
vgg.compile(optimizer=optimizer, loss='mse', metrics=['accuracy'])

In [None]:
# Shape of the discriminator's output map: the high-resolution height
# divided by 2**4, used for both spatial axes, with a single channel
patch = int(hr_height / 16)
disc_patch = (patch, patch, 1)

# Number of filters in the first layer of the generator (gf) and discriminator (df)
gf, df = 64, 64

def build_discriminator(hr_shape=hr_shape, df=df):
    """
    Assemble the discriminator network.

    Eight convolutional blocks (filters df, df, 2df, 2df, 4df, 4df, 8df, 8df;
    every second block uses stride 2) are followed by Dense(16*df),
    LeakyReLU and a final sigmoid Dense(1).

    Parameters
    ----------
    hr_shape : tuple
        Shape of the input image.
    df : int
        Number of filters in the first convolution.

    Returns
    -------
    Model
        Keras model mapping an image to its validity output.
    """

    def d_block(layer_input, filters, strides=1, bn=True):
        """3x3 convolution -> LeakyReLU(0.2) -> optional batch normalisation."""
        out = Conv2D(filters, kernel_size=3, strides=strides, padding='same')(layer_input)
        out = LeakyReLU(alpha=0.2)(out)
        return BatchNormalization(momentum=0.8)(out) if bn else out

    # Input img
    d0 = Input(shape=hr_shape)

    # The first block is the only one without batch normalisation
    x = d_block(d0, df, bn=False)

    # (filter multiplier, stride) of the remaining seven blocks, in order
    block_specs = ((1, 2), (2, 1), (2, 2), (4, 1), (4, 2), (8, 1), (8, 2))
    for multiplier, stride in block_specs:
        x = d_block(x, df * multiplier, strides=stride)

    x = Dense(df * 16)(x)
    x = LeakyReLU(alpha=0.2)(x)
    validity = Dense(1, activation='sigmoid')(x)

    return Model(d0, validity)

# Instantiate the discriminator and compile it for stand-alone training
discriminator = build_discriminator(hr_shape=hr_shape, df=df)
discriminator.compile(optimizer=optimizer, loss='mse', metrics=['accuracy'])

In [None]:
def build_generator(lr_shape=lr_shape, gf=gf, n_residual_blocks=n_residual_blocks, channels=channels):
    """
    Assemble the generator that upscales a low-resolution image by 4x.

    Structure: 9x9 convolution -> residual blocks -> 3x3 convolution +
    batch norm with a skip connection back to the first convolution ->
    two 2x upsampling stages -> 9x9 convolution with tanh output.

    Parameters
    ----------
    lr_shape : tuple
        Shape of the low-resolution input image.
    gf : int
        Number of filters used throughout the residual trunk. The first and
        last trunk convolutions now use `gf` as well (they were hard-coded to
        64, which made the `Add` skip connections fail for any other value).
    n_residual_blocks : int
        Number of residual blocks; at least one block is always built.
    channels : int
        Number of channels of the generated image.

    Returns
    -------
    Model
        Keras model mapping a low-resolution image to a generated image.
    """

    def residual_block(layer_input, filters):
        """
        Conv -> ReLU -> BN -> Conv -> BN, added back onto the block input.
        """
        d = Conv2D(filters, kernel_size=3, strides=1, padding='same')(layer_input)
        d = Activation('relu')(d)
        d = BatchNormalization(momentum=0.8)(d)
        d = Conv2D(filters, kernel_size=3, strides=1, padding='same')(d)
        d = BatchNormalization(momentum=0.8)(d)
        d = Add()([d, layer_input])
        return d

    def deconv2d(layer_input):
        """
        One upsampling stage: 2x nearest upsampling -> 3x3 conv (256 filters) -> ReLU.
        """
        u = UpSampling2D(size=2)(layer_input)
        u = Conv2D(256, kernel_size=3, strides=1, padding='same')(u)
        u = Activation('relu')(u)
        return u

    # Low resolution image input
    img_lr = Input(shape=lr_shape)

    # Pre-residual convolution; must have `gf` channels so that it can be
    # added to the outputs of the residual trunk below
    c1 = Conv2D(gf, kernel_size=9, strides=1, padding='same')(img_lr)

    # Residual trunk
    r = residual_block(c1, gf)
    for _ in range(n_residual_blocks - 1):
        r = residual_block(r, gf)

    # Post-residual convolution with a long skip connection to c1
    c2 = Conv2D(gf, kernel_size=3, strides=1, padding='same')(r)
    c2 = BatchNormalization(momentum=0.8)(c2)
    c2 = Add()([c2, c1])

    # Two 2x stages give the overall 4x upscaling
    u1 = deconv2d(c2)
    u2 = deconv2d(u1)

    # Generate high resolution output
    gen_hr = Conv2D(channels, kernel_size=9, strides=1, padding='same', activation='tanh')(u2)

    return Model(img_lr, gen_hr)

# Instantiate the generator from the attributes configured above
generator = build_generator(
    lr_shape=lr_shape,
    gf=gf,
    n_residual_blocks=n_residual_blocks,
    channels=channels,
)


In [None]:

# Symbolic inputs of the combined model: a high-res image and its low-res version
img_hr = Input(shape=hr_shape)
img_lr = Input(shape=lr_shape)

# The discriminator is frozen before the combined model is compiled
discriminator.trainable = False

# Low-res input -> generated high-res image
fake_hr = generator(img_lr)

# VGG features of the generated image
fake_features = vgg(fake_hr)

# Discriminator output for the generated image
validity = discriminator(fake_hr)

# Two outputs, two losses: binary cross-entropy on the validity map
# (weight 1e-3) and MSE on the VGG features (weight 1)
combined = Model(inputs=[img_lr, img_hr], outputs=[validity, fake_features])
combined.compile(optimizer=optimizer,
                 loss=['binary_crossentropy', 'mse'],
                 loss_weights=[1e-3, 1])

        

def sample_images(epoch, dataset_name=dataset_name):
    """
    Save a visual comparison of the generator output against the original.

    Loads one un-augmented sample with `load_data`, runs the generator on its
    low-resolution version and writes two files:

    * images/<dataset_name>/720/<epoch>.png         - generated vs. original
    * images/<dataset_name>/720/<epoch>_lowres0.png - the low-resolution input

    The output directory is created if it does not exist.

    Parameters
    ----------
    epoch : int
        Iteration number, used in the file names.
    dataset_name : str
        Sub-folder of `images/` to write into.
    """
    out_dir = "images/%s/720" % dataset_name
    # savefig() does not create missing directories
    os.makedirs(out_dir, exist_ok=True)

    # Only the first sample of the batch is plotted, so one image is enough
    imgs_hr, imgs_lr = load_data(batch_size=1, is_testing=True)
    fake_hr = generator.predict(imgs_lr)

    def to_display(batch):
        """Turn the first image of a normalised batch into displayable pixels."""
        # Exact inverse of load_data's `pixels / 719.5 - 1.` scaling; this
        # constant must mirror the one used in load_data.
        pixels = (batch[0] + 1) * 719.5
        # The generator's tanh output can fall outside the range of real
        # images, so clip to the valid 8-bit range before casting.
        pixels = np.clip(pixels, 0, 255).astype(int)
        # load_data reads files with cv2.imread (BGR); matplotlib expects RGB
        return pixels[..., ::-1]

    # Save generated image and the high resolution original side by side
    titles = ['Generated', 'Original']
    fig, axs = plt.subplots(1, 2)
    for ax, title, batch in zip(axs, titles, (fake_hr, imgs_hr)):
        ax.imshow(to_display(batch))
        ax.set_title(title)
        ax.axis('off')
    fig.savefig("%s/%d.png" % (out_dir, epoch))
    plt.close(fig)

    # Save the low resolution input for reference
    fig = plt.figure()
    plt.imshow(to_display(imgs_lr))
    fig.savefig('%s/%d_lowres%d.png' % (out_dir, epoch, 0))
    plt.close(fig)



def train(epochs, generator=generator, discriminator=discriminator, disc_patch=disc_patch, vgg=vgg, combined=combined, batch_size=1, sample_interval=50):
    """
    Alternate discriminator and generator updates.

    Each "epoch" here is a single training step: one discriminator update on a
    real and a generated batch, followed by one generator update through the
    combined model on a freshly sampled batch.

    Parameters
    ----------
    epochs : int
        Number of training steps to run.
    generator, discriminator, vgg, combined : Model
        The networks to train / use; default to the module-level models.
    disc_patch : tuple
        Shape of the discriminator output for one image.
    batch_size : int
        Number of image pairs per step; forwarded to `load_data` so the data
        and the target arrays always have matching batch dimensions.
    sample_interval : int
        Sample images are saved every `sample_interval` steps.
    """
    start_time = datetime.datetime.now()

    # Discriminator targets (real = 1, generated = 0). They depend only on
    # batch_size and disc_patch, so they are built once outside the loop.
    valid = np.ones((batch_size,) + disc_patch)
    fake = np.zeros((batch_size,) + disc_patch)

    for epoch in range(epochs):

        # ----------------------
        #  Train Discriminator
        # ----------------------

        # Sample images and their conditioning counterparts
        imgs_hr, imgs_lr = load_data(batch_size=batch_size)

        # From low res. image generate high res. version
        fake_hr = generator.predict(imgs_lr)

        # Train the discriminator (original images = real / generated = fake)
        d_loss_real = discriminator.train_on_batch(imgs_hr, valid)
        d_loss_fake = discriminator.train_on_batch(fake_hr, fake)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # ------------------
        #  Train Generator
        # ------------------

        # Sample a new batch for the generator update
        imgs_hr, imgs_lr = load_data(batch_size=batch_size)

        # Extract ground truth image features using pre-trained VGG19 model
        image_features = vgg.predict(imgs_hr)

        # The generator wants the discriminator to label its images as real
        g_loss = combined.train_on_batch([imgs_lr, imgs_hr], [valid, image_features])

        elapsed_time = datetime.datetime.now() - start_time
        # d_loss is [loss, accuracy]; g_loss is [total, adversarial, feature].
        # Report the leading (total) loss of each so progress is visible.
        print("%d time: %s [D loss: %f] [G loss: %f]" % (epoch, elapsed_time, d_loss[0], g_loss[0]))

        # Save sample images at every sample interval
        if epoch % sample_interval == 0:
            sample_images(epoch)


# Run 3500 training steps with one image pair per step,
# saving sample images every 50 steps
train(
    epochs=3500,
    batch_size=1,
    sample_interval=50,
)
