# Deep Learning project work, Image Deblurring with WGANs

Federico Battistella 0000926542

In this project we implement a GAN for image deblurring. The dataset is CIFAR10, and each image is blurred with a kernel of random size (the code uses OpenCV's `cv2.blur`, a normalized box filter).
This project is based on the paper "DeblurGAN: Blind Motion Deblurring Using Conditional Adversarial Networks" by Orest Kupyn et al.

https://arxiv.org/pdf/1711.07064.pdf

The WGAN (Wasserstein GAN) model is based on the Wasserstein loss.
The WGAN is an extension of the traditional GAN that stabilizes training, which is otherwise usually slow and struggles to converge.
The reason GANs are difficult to train is that the architecture involves the simultaneous training of a generator and a discriminator model in a zero-sum game. Stable training requires finding and maintaining an equilibrium between the capabilities of the two models.
Instead of using a discriminator that predicts the probability of a generated image being real or fake, the WGAN replaces the discriminator with a critic that scores the realness or fakeness of a given image.

This change is motivated by a mathematical argument: training the generator should minimize the distance between the distribution of the data observed in the training dataset and the distribution of the generated examples.

In [None]:
# Setting Google Drive for Colab
from google.colab import drive
drive.mount('/content/gdrive')
# !pip install tensorflow
# !pip install keras

Imports

In [None]:
import tqdm
import datetime
import math
import os
import gc
import h5py
import numpy as np
from PIL import Image
from skimage.metrics import structural_similarity
import tensorflow as tf
import keras
import keras.backend as K

from keras import losses
from keras import backend
from keras.backend import image_data_format
#from keras.backend import normalize_data_format
from keras.layers import InputSpec
from tensorflow.keras.layers import Layer
from keras.utils import conv_utils
from keras.models import Model, load_model
from keras.applications.vgg16 import VGG16
from tensorflow.keras.optimizers import Adam
from keras.layers import Input, Lambda
from keras.layers.merge import Add
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.layers.convolutional import Conv2D, UpSampling2D
from keras.layers.advanced_activations import LeakyReLU
from tensorflow.keras.layers import BatchNormalization
from keras.initializers import RandomNormal

Eager execution is required to train the model, because some of the modules used are not compatible with TF2 graph execution.

In [None]:
tf.config.run_functions_eagerly(True)

Test if the GPU device is available for training

In [None]:
#GPU TEST
device_name = tf.test.gpu_device_name()
if device_name != '/device:GPU:0':
  print(
      '\n\nThis error most likely means that this notebook is not '
      'configured to use a GPU.  Change this in Notebook Settings via the '
      'command palette (cmd/ctrl-shift-P) or the Edit menu.\n\n')
  raise SystemError('GPU device not found')

In [None]:
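# Parameters

image_shape = (32, 32, 3)    # CIFAR10 images: 32x32 RGB
ngf = 64                     # filters in the first generator convolution
ndf = 64                     # filters in the first discriminator convolution
n_blocks_generator = 9       # residual blocks in the generator
n_layers_discriminator = 3   # strided convolution blocks in the discriminator loop
len_all = 50000              # number of training images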

## Models

Each residual block starts with a reflection padding, so that the following 3x3 convolution keeps the same spatial dimensions as its input. The convolution is followed by a batch normalization layer, which normalizes the activations to zero mean and unit standard deviation over the batch (before a learned scale and shift), and by a ReLU activation.
An optional dropout comes next, then a second padding, convolution and batch normalization. A residual connection at the end sums the input and the output of the block, making it easier for the network to learn the identity function without adding computational complexity or parameters.
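
As a quick check of the padding arithmetic, the sketch below computes the spatial size after a symmetric padding followed by a 'valid' convolution. The helper `conv_output_size` is hypothetical and not part of the notebook; it only illustrates why a padding of 1 with a 3x3 kernel leaves the feature map size unchanged, which is what makes the final addition between input and output possible.

```python
def conv_output_size(size, kernel, stride=1, pad=0):
    """Spatial output size of a 'valid' convolution applied after symmetric padding."""
    return (size + 2 * pad - kernel) // stride + 1

# Padding of 1 followed by a 3x3 convolution keeps an 8x8 feature map at 8x8.
assert conv_output_size(8, kernel=3, pad=1) == 8
# Without the padding, the same convolution would shrink the map by 2 pixels.
assert conv_output_size(8, kernel=3, pad=0) == 6
```

The same formula explains the padding of 3 used with the 7x7 convolutions in the generator.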

In [None]:
def res_block(inputs, filters, kernel_size=(3, 3), strides=(1, 1), use_dropout=True):
    # ResNet block: pad -> conv -> BN -> ReLU -> (dropout) -> pad -> conv -> BN -> add input
    init = RandomNormal(mean=0.0, stddev=0.0025)

    block = ReflectionPadding2D((1, 1))(inputs)
    block = Conv2D(filters=filters,
               kernel_size=kernel_size,
               strides=strides, kernel_initializer=init)(block)
    block = BatchNormalization()(block)
    block = Activation('relu')(block)

    if use_dropout:
        block = Dropout(0.5)(block)

    block = ReflectionPadding2D((1, 1))(block)
    block = Conv2D(filters=filters,
               kernel_size=kernel_size,
               strides=strides, kernel_initializer=init)(block)
    block = BatchNormalization()(block)

    # Residual connection: shapes must match, so `filters` must equal the input channels
    merged = Add()([inputs, block])
    return merged

Reflection padding is applied before a convolutional layer so that the convolution preserves the spatial dimensions of its input. It mirrors the pixel values next to the border, so the padded input still follows the same data distribution as the image, unlike zero padding.
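
To make the mirroring concrete, here is a minimal pure-Python sketch of reflection padding on nested lists. The helpers `reflect_pad_1d` and `reflect_pad_2d` are hypothetical names used only for illustration; they follow the same convention as `tf.pad` in "REFLECT" mode, where the border value itself is not repeated.

```python
def reflect_pad_1d(row, pad):
    """Reflect-pad a list on both sides without repeating the edge value."""
    left = row[1:pad + 1][::-1]
    right = row[-pad - 1:-1][::-1]
    return left + row + right

def reflect_pad_2d(image, pad):
    """Apply the same reflection padding to the columns, then to the rows."""
    padded_rows = [reflect_pad_1d(row, pad) for row in image]
    return reflect_pad_1d(padded_rows, pad)

image = [[1, 2, 3],
         [4, 5, 6],
         [7, 8, 9]]
for row in reflect_pad_2d(image, 1):
    print(row)
```

A 3x3 input padded by 1 becomes 5x5, and every added value is copied from inside the image.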

In [None]:
def spatial_reflection_2d_padding(x, padding=((1, 1), (1, 1)), data_format=None):
    """
    Used to pad the 2nd and 3rd dimensions of a 4D tensor.

    Input tensor -> x
    Padding dimension -> padding
    Format of the data ('channels_last', 'channels_first') -> data_format
    """
    assert len(padding) == 2
    assert len(padding[0]) == 2
    assert len(padding[1]) == 2

    if data_format is None:
        data_format = image_data_format()
    
    if data_format not in {'channels_first', 'channels_last'}:
        raise ValueError('Unknown data_format ' + str(data_format))

    if data_format == 'channels_first':
        pattern = [[0, 0],
                   [0, 0],
                   list(padding[0]),
                   list(padding[1])]
    else:
        pattern = [[0, 0],
                   list(padding[0]), list(padding[1]),
                   [0, 0]]
    return tf.pad(x, pattern, "REFLECT")

In [None]:
def normalize_data_format(value):
    if value is None:
        value = image_data_format()
    data_format = value.lower()
    if data_format not in {'channels_first', 'channels_last'}:
        raise ValueError('The `data_format` argument must be one of '
                         '"channels_first", "channels_last". Received: ' +
                         str(value))
    return data_format

In [None]:
class ReflectionPadding2D(Layer):

    def __init__(self,
                 padding=(1, 1),
                 data_format=None,
                 **kwargs):
        super(ReflectionPadding2D, self).__init__(**kwargs)
        self.data_format = normalize_data_format(data_format)
        if isinstance(padding, int):
            self.padding = ((padding, padding), (padding, padding))
        elif hasattr(padding, '__len__'):
            if len(padding) != 2:
                raise ValueError('`padding` should have two elements. '
                                 'Found: ' + str(padding))
            height_padding = conv_utils.normalize_tuple(padding[0], 2,
                                                        '1st entry of padding')
            width_padding = conv_utils.normalize_tuple(padding[1], 2,
                                                       '2nd entry of padding')
            self.padding = (height_padding, width_padding)
        else:
            raise ValueError('`padding` should be either an int, '
                             'a tuple of 2 ints '
                             '(symmetric_height_pad, symmetric_width_pad), '
                             'or a tuple of 2 tuples of 2 ints '
                             '((top_pad, bottom_pad), (left_pad, right_pad)). '
                             'Found: ' + str(padding))
        self.input_spec = InputSpec(ndim=4)

    def compute_output_shape(self, input_shape):
        if self.data_format == 'channels_first':
            if input_shape[2] is not None:
                rows = input_shape[2] + self.padding[0][0] + self.padding[0][1]
            else:
                rows = None
            if input_shape[3] is not None:
                cols = input_shape[3] + self.padding[1][0] + self.padding[1][1]
            else:
                cols = None
            return (input_shape[0],
                    input_shape[1],
                    rows,
                    cols)
        elif self.data_format == 'channels_last':
            if input_shape[1] is not None:
                rows = input_shape[1] + self.padding[0][0] + self.padding[0][1]
            else:
                rows = None
            if input_shape[2] is not None:
                cols = input_shape[2] + self.padding[1][0] + self.padding[1][1]
            else:
                cols = None
            return (input_shape[0],
                    rows,
                    cols,
                    input_shape[3])

    def call(self, inputs):
        return spatial_reflection_2d_padding(inputs,
                                             padding=self.padding,
                                             data_format=self.data_format)

    def get_config(self):
        config = {'padding': self.padding,
                  'data_format': self.data_format}
        base_config = super(ReflectionPadding2D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

Generator model architecture

The generator architecture consists of:

*   Downsampling
*   Residual blocks (x9)
*   Upsampling
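
The three stages can be traced with a little arithmetic. The sketch below is an illustration only: `generator_shape_trace` is a hypothetical helper, and it assumes the values `ngf = 64` and two downsampling steps used in `generator_model`. It lists the spatial size and the number of filters after each stage for a 32x32 input.

```python
ngf = 64            # base number of generator filters, as in the Parameters cell
n_downsampling = 2  # number of stride-2 convolutions, as in generator_model

def generator_shape_trace(size=32):
    """Return (stage, spatial size, filters) tuples for the generator's main stages."""
    trace = [("input conv 7x7", size, ngf)]
    for i in range(n_downsampling):
        size //= 2  # a stride-2 convolution with 'same' padding halves an even size
        trace.append((f"downsample {i + 1}", size, ngf * 2 ** (i + 1)))
    trace.append(("9 residual blocks", size, ngf * 2 ** n_downsampling))
    for i in range(n_downsampling):
        size *= 2  # UpSampling2D doubles the size
        trace.append((f"upsample {i + 1}", size, ngf * 2 ** (n_downsampling - i) // 2))
    trace.append(("output conv 7x7", size, 3))
    return trace

for stage, size, filters in generator_shape_trace():
    print(f"{stage:20s} {size}x{size}x{filters}")
```

The residual blocks therefore work on 8x8 feature maps with 256 channels, and the output has the same 32x32x3 shape as the input, which is required by the final addition with the blurred image.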

In [None]:
def generator_model():
    init = RandomNormal(mean=0.0, stddev=0.0025)
    inputs = Input(shape=image_shape)

    x = ReflectionPadding2D((3, 3))(inputs)
    x = Conv2D(filters=ngf, kernel_size=(7, 7), padding='valid', kernel_initializer=init)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    n_downsampling = 2
    for i in range(n_downsampling):
        mult = 2**i
        x = Conv2D(filters=ngf*mult*2, kernel_size=(3, 3), strides=2, padding='same', kernel_initializer=init)(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
    
    mult = 2**n_downsampling
    for i in range(n_blocks_generator):
        x = res_block(x, ngf*mult, use_dropout=True)
    

    for i in range(n_downsampling):
        mult = 2**(n_downsampling - i)
        x = UpSampling2D()(x)
        x = Conv2D(filters=int(ngf * mult / 2), kernel_size=(3, 3), padding='same', kernel_initializer=init)(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
        

    x = ReflectionPadding2D((3, 3))(x)
    x = Conv2D(filters=3, kernel_size=(7, 7), padding='valid', kernel_initializer=init)(x)
    x = Activation('tanh')(x)
    
    
    outputs = Add()([x, inputs])
    outputs = Lambda(lambda z: z/2)(outputs)

    model = Model(inputs=inputs, outputs=outputs, name='Generator')
    return model

Discriminator model architecture

The discriminator is a convolutional PatchGAN modified to output a critic score instead of the probability of the image being real.
The critic score indicates how close the model distribution is to the real distribution.
The model consists of several strided convolutions for downsampling, followed by batch normalization and LeakyReLU activation functions.
At the end, a dense layer with a linear activation outputs the critic score.
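
As a rough check of the downsampling, the sketch below traces the spatial size and filter count after each convolution of `discriminator_model` for a 32x32 input. The helper `critic_shape_trace` is hypothetical, and it assumes `ndf = 64` and `n_layers_discriminator = 3` from the Parameters cell.

```python
import math

ndf = 64                    # base number of critic filters, as in the Parameters cell
n_layers_discriminator = 3  # as in the Parameters cell

def critic_shape_trace(size=32):
    """Return (spatial size, filters) after each convolution of the critic."""
    trace = []
    size = math.ceil(size / 2)            # first stride-2 convolution, 'same' padding
    trace.append((size, ndf))
    for n in range(n_layers_discriminator):
        size = math.ceil(size / 2)        # stride-2 convolution, 'same' padding
        trace.append((size, ndf * min(2 ** n, 8)))
    trace.append((size, ndf * min(2 ** n_layers_discriminator, 8)))  # stride 1
    trace.append((size, 1))               # single-channel score map
    return trace

trace = critic_shape_trace()
flattened_units = trace[-1][0] ** 2 * trace[-1][1]  # inputs to the final Dense layer
print(trace, flattened_units)
```

With CIFAR10-sized inputs the score map is only 2x2, so the final dense layer combines four values into the single critic score.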

In [None]:
def discriminator_model():
    # discriminator architecture
    init = RandomNormal(mean=0.0, stddev=0.0025)
    
    inputs = Input(shape=image_shape)

    x = Conv2D(filters=ndf, kernel_size=(4, 4), strides=2, padding='same', kernel_initializer=init)(inputs)
    x = LeakyReLU(0.2)(x)

    nf_mult = 1
    for n in range(n_layers_discriminator):
        nf_mult = min(2**n, 8)
        x = Conv2D(filters=ndf*nf_mult, kernel_size=(4, 4), strides=2, padding='same', kernel_initializer=init)(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(0.2)(x)

    nf_mult = min(2**n_layers_discriminator, 8)
    x = Conv2D(filters=ndf*nf_mult, kernel_size=(4, 4), strides=1, padding='same', kernel_initializer=init)(x)
    x = BatchNormalization()(x)
    x = LeakyReLU(0.2)(x)

    x = Conv2D(filters=1, kernel_size=(4, 4), strides=1, padding='same', kernel_initializer=init)(x)

    x = Flatten()(x)
    x = Dense(1, activation='linear', kernel_initializer=init)(x)

    model = Model(inputs=inputs, outputs=x, name='Discriminator')
    return model

Model combining generator and discriminator

In [None]:
def generator_containing_discriminator(generator, discriminator):
    inputs = Input(shape=image_shape)
    generated_image = generator(inputs)
    outputs = discriminator(generated_image)
    model = Model(inputs=inputs, outputs=[generated_image, outputs])
    return model

## Losses

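Generator loss = content_loss + wasserstein_loss, where content_loss = 100 * perceptual_loss + 140 * mae_loss. The critic is trained with wasserstein_loss alone.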
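
To see how the Wasserstein loss behaves with the labels used in this notebook (+1 for sharp images, -1 for generated ones), here is a minimal pure-Python sketch. The function `wasserstein_loss_py` and the critic scores are made up for illustration; the function mirrors `K.mean(y_true * y_pred)` on flat lists.

```python
def wasserstein_loss_py(y_true, y_pred):
    """Pure-Python equivalent of K.mean(y_true * y_pred) for flat lists."""
    return sum(t * p for t, p in zip(y_true, y_pred)) / len(y_true)

# Hypothetical critic scores for a batch of three images.
scores_real = [-2.0, -1.0, -3.0]
scores_fake = [1.0, 2.0, 0.0]

loss_real = wasserstein_loss_py([1.0] * 3, scores_real)   # labels +1 for sharp images
loss_fake = wasserstein_loss_py([-1.0] * 3, scores_fake)  # labels -1 for generated images
print(loss_real, loss_fake)
```

Minimizing these two terms drives the critic towards low scores on sharp images and high scores on generated ones. The generator is trained through the combined model with +1 labels, so it tries to lower the score of its own outputs, that is, to make them look sharp to the critic.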

In [None]:
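# VGG16 feature extractor, built lazily and reused across calls:
# rebuilding it inside the loss on every batch would reload the weights each time
_vgg_loss_model = None

def _get_vgg_loss_model():
    global _vgg_loss_model
    if _vgg_loss_model is None:
        vgg = VGG16(include_top=False, weights='imagenet', input_shape=image_shape)
        _vgg_loss_model = Model(inputs=vgg.input, outputs=vgg.get_layer('block3_conv3').output)
        _vgg_loss_model.trainable = False
    return _vgg_loss_model

def perceptual_loss(y_true, y_pred):
    # Mean squared distance between VGG16 'block3_conv3' feature maps
    loss_model = _get_vgg_loss_model()
    return K.mean(K.square(loss_model(y_true) - loss_model(y_pred)))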

def wasserstein_loss(y_true, y_pred):
    return K.mean(y_true*y_pred)

def mae_loss(y_true, y_pred):
    return K.mean(K.abs(y_true - y_pred))

def content_loss(y_true, y_pred):
    return 100*perceptual_loss(y_true, y_pred) + 140*mae_loss(y_true, y_pred)

## Dataset loading

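The dataset is CIFAR10: 50000 training and 10000 test RGB images of size 32x32. A blur with a random kernel size between 1 and 9 is applied to each image to obtain the dataset for the deblurring task. The code uses `cv2.blur`, which is a normalized box (averaging) filter rather than a Gaussian one.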

In [None]:
from keras.datasets import cifar10
from matplotlib import pyplot
import cv2
import random
# Loading dataset CIFAR-10
(trainX, trainY), (testX, testY) = cifar10.load_data()
# Check dataset shapes
print('Train: X=%s, y=%s' % (trainX.shape, trainY.shape))
print('Test: X=%s, y=%s' % (testX.shape, testY.shape))

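# Blurring train images
# Note: cv2.blur applies a normalized box (averaging) filter; a Gaussian kernel
# would need cv2.GaussianBlur, which requires an odd kernel size
trainX_blur = []
for i in range(trainX.shape[0]):
  size = random.randint(1,9)
  kernel = (size, size)
  trainX_blur.append(cv2.blur(trainX[i], kernel))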

# Visual checks
trainX_blur = np.array(trainX_blur)
print(trainX_blur.shape)

for i in range(9):
	# define subplot
	pyplot.subplot(330 + 1 + i)
	# plot raw pixel data
	pyplot.imshow(trainX_blur[i])
# show the figure
pyplot.show()

# Blurring test images
testX_blur = []
for i in range(testX.shape[0]):
  size = random.randint(1,9)
  kernel = (size, size)
  testX_blur.append(cv2.blur(testX[i], kernel))

# Visual checks
testX_blur = np.array(testX_blur)
print(testX_blur.shape)

for i in range(9):
	# define subplot
	pyplot.subplot(330 + 1 + i)
	# plot raw pixel data
	pyplot.imshow(testX_blur[i])
# show the figure
pyplot.show()


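Preprocessing applied to the images: pixel values are scaled from [0, 255] to [-1, 1] before training, and mapped back to [0, 255] for evaluation and visualization.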
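
The scaling can be checked on single pixel values. The sketch below uses hypothetical scalar helpers, `preprocess_pixel` and `deprocess_pixel`, that apply the same formulas as the array functions in this notebook.

```python
def preprocess_pixel(value):
    """Map a pixel value from [0, 255] to [-1, 1]."""
    return (value - 127.5) / 127.5

def deprocess_pixel(value):
    """Map a value from [-1, 1] back to [0, 255]."""
    return value * 127.5 + 127.5

assert preprocess_pixel(0) == -1.0
assert preprocess_pixel(255) == 1.0
# The two functions are inverses of each other, up to floating point error.
assert abs(deprocess_pixel(preprocess_pixel(64)) - 64) < 1e-9
```

The [-1, 1] range matches the `tanh` activation at the end of the generator.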

In [None]:
def preprocess_image(img):
    return (img - 127.5) / 127.5

def deprocess_image(img):
    return img * 127.5 + 127.5

# def load_data(data_type, path):
#     with h5py.File(path, 'r') as f:
#         data_sharp = f['%s_data_sharp' % data_type][:].astype(np.float16)
#         data_sharp = preprocess_image(data_sharp)

#         data_blur = f['%s_data_blur' % data_type][:].astype(np.float16)
#         data_blur = preprocess_image(data_blur)

#         return data_sharp, data_blur

def load_data(train, train_blur):
  data_sharp = train.astype(np.float16)
  data_sharp = preprocess_image(data_sharp)

  data_blur = train_blur.astype(np.float16)
  data_blur = preprocess_image(data_blur)

  return data_sharp, data_blur


In [None]:
# !zip -r train.zip dataset/train
# !zip -r train_blur.zip dataset/train_blur
# from google.colab import files
# files.download("train.zip")
# files.download("train_blur.zip")

In [None]:
weights_save_dir =  '/content/gdrive/My Drive/Colab Notebooks/weights'
model_save_dir = '/content/gdrive/My Drive/Colab Notebooks/models'
model_load_dir = '/content/gdrive/My Drive/Colab Notebooks/models'


def save_all_weights(d, g, epoch_number, current_loss):
    now = datetime.datetime.now()
    save_dir = os.path.join(weights_save_dir, '{}{}'.format(now.month, now.day))
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    g.save_weights(os.path.join(save_dir, 'generator_{}_{}.h5'.format(epoch_number, current_loss)), True)
    d.save_weights(os.path.join(save_dir, 'discriminator_{}.h5'.format(epoch_number)), True)

def save_models(d, d_on_g):
    if not os.path.exists(model_save_dir):
        os.makedirs(model_save_dir)
    for layer in d.layers[:]:
        layer.trainable = True
    d.save(os.path.join(model_save_dir, 'discriminator_model.h5'))
    for layer in d.layers[:]:
        layer.trainable = False
    d_on_g.save(os.path.join(model_save_dir, 'combined_model.h5'))
    for layer in d.layers[:]:
        layer.trainable = True   


## Training
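
The training function in this section runs `critic_updates` critic steps for each generator step and clips the critic weights to [-0.01, 0.01] after every critic step, which is the original WGAN way of approximately enforcing the Lipschitz constraint. As a minimal illustration of the clipping alone, the sketch below uses a hypothetical `clip_weights` helper on a plain list instead of Keras layer weights.

```python
def clip_weights(weights, limit=0.01):
    """Clamp every value in a list of weights to [-limit, limit]."""
    return [max(-limit, min(limit, w)) for w in weights]

clipped = clip_weights([0.5, -0.003, -0.2, 0.01])
print(clipped)
```

Values already inside the interval are left unchanged; larger ones are moved to the nearest bound.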

In [None]:
def train(batch_size, epoch_num, critic_updates=5, load=False):
    with tf.device('/device:GPU:0'):
      if load: 
          dis = load_model(os.path.join(model_load_dir, 'discriminator_model.h5'), custom_objects={'wasserstein_loss': wasserstein_loss})
          dis_and_gen = load_model(os.path.join(model_load_dir, 'combined_model.h5'), custom_objects={'ReflectionPadding2D': ReflectionPadding2D, 'wasserstein_loss': wasserstein_loss, 'content_loss' : content_loss})
          gen = dis_and_gen.get_layer('Generator')
          
      else:
          gen = generator_model()
          dis = discriminator_model()
          dis_and_gen = generator_containing_discriminator(gen, dis)

          dis_opt = Adam(learning_rate=0.0001, beta_1=0.9)
          dis_and_gen_opt = Adam(learning_rate=0.0001, beta_1=0.9)

          dis.trainable = True
          dis.compile(loss=wasserstein_loss, optimizer=dis_opt)
          dis.trainable = False
          dis_and_gen.compile(loss=[content_loss, wasserstein_loss], optimizer=dis_and_gen_opt)
          dis.trainable = True
        
      y_train, x_train = load_data(trainX, trainX_blur)
      
      for epoch in tqdm.tqdm(range(epoch_num)):
          
          d_losses = []
          d_losses_r = []
          d_losses_f = []
          dis_and_gen_losses = []
          num_batches = int(len_all / batch_size)
          curr_dataset = 0
          
          for batch in range(num_batches):        
              if batch%100 == 0:
                print(f'Epoch {epoch}, iteration: {batch}')
              permutated_indexes = np.random.permutation(x_train.shape[0])
              batch_indexes = permutated_indexes[0:batch_size]
              image_blur_batch = x_train[batch_indexes]
              image_sharp_batch = y_train[batch_indexes]
              
              positive_labels = np.ones((batch_size, 1))
              negative_labels = -np.ones((batch_size, 1))
              batch_and_labels = [image_sharp_batch, positive_labels]

              # Discriminator (critic) training
              gen_imgs = gen.predict(image_blur_batch)
              for _ in range(critic_updates):
                  d_loss_real = dis.train_on_batch(image_sharp_batch, positive_labels)
                  d_loss_fake = dis.train_on_batch(gen_imgs, negative_labels)
                  d_losses_r.append(d_loss_real)
                  d_losses_f.append(d_loss_fake)
                  d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
                  d_losses.append(d_loss)
                  # Weights clipping to [-0.01,0.01] to enforce the Lipschitz constraint
                  for l in dis.layers:
                      weights = l.get_weights()
                      weights = [np.clip(w, -0.01, 0.01) for w in weights]
                      l.set_weights(weights)

              # Generator Training 
              dis.trainable = False
              dis_and_gen_loss = dis_and_gen.train_on_batch(image_blur_batch, batch_and_labels)
              dis_and_gen_losses.append(dis_and_gen_loss)
              dis.trainable = True
              

          print(f"{epoch} [D loss r: {np.mean(d_losses_r)} | D loss f: {np.mean(d_losses_f)} | D loss: {np.mean(d_losses)}] [G loss: {np.mean(dis_and_gen_losses)}]")
          with open('/content/gdrive/My Drive/Colab Notebooks/logs/log.txt', 'a+') as f:
              f.write(f"{epoch} [D loss r: {np.mean(d_losses_r)} | D loss f: {np.mean(d_losses_f)} | D loss: {np.mean(d_losses)}] [G loss: {np.mean(dis_and_gen_losses)}]\n")
        
          # save weights and full models (to resume training) every 5 epochs
          if (epoch+1) % 5 == 0:
              save_all_weights(dis, gen, epoch, int(np.mean(dis_and_gen_losses)))
              save_models(dis, dis_and_gen)

In [None]:
train(32, 50)#, load=True)

## Testing

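## Metrics

Mean Squared Error (MSE) measures the average squared pixel difference between the deblurred image and the sharp original; lower is better.

SSIM (Structural Similarity Index) measures the perceived quality of the deblurred image, taking the sharp original as reference; a value of 1 means the two images are identical.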
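
As a small worked example of the MSE used here, the sketch below applies the same formula to flat lists of pixel values. The helper `mse_normalised` is hypothetical; like the `MSE` function of this notebook, it scales pixels to [0, 1] before averaging the squared differences.

```python
def mse_normalised(pixels_a, pixels_b):
    """Mean squared error between two flat pixel lists, after scaling to [0, 1]."""
    diffs = [(a / 255.0 - b / 255.0) ** 2 for a, b in zip(pixels_a, pixels_b)]
    return sum(diffs) / len(diffs)

# Identical images give zero error.
assert mse_normalised([0, 128, 255], [0, 128, 255]) == 0.0
# The largest possible error, black against white, gives 1.0.
assert mse_normalised([0, 0], [255, 255]) == 1.0
```

Because of the scaling, the metric always lies between 0 and 1 for 8-bit images.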

In [None]:
def MSE(img1, img2):
    return np.mean( (img1/255. - img2/255.) ** 2 )

In [None]:
def test(batch_size):
    results_dir = '/content/gdrive/My Drive/Colab Notebooks/results'
    y_test, x_test = load_data(testX, testX_blur)
    gen = generator_model()
    gen.load_weights('/content/gdrive/My Drive/Colab Notebooks/weights/14/generator_600_3.h5')
    generated_images = gen.predict(x=x_test, batch_size=batch_size)
    generated = np.array([deprocess_image(img) for img in generated_images])
    x_test = deprocess_image(x_test)
    y_test = deprocess_image(y_test)

    os.makedirs(results_dir, exist_ok=True)

    ssim = 0
    mse = 0

    for i in range(generated_images.shape[0]):
        y = y_test[i, :, :, :]
        x = x_test[i, :, :, :]
        img = generated[i, :, :, :]

        # randomly save about 5% of the tested images (sharp | blurred | deblurred)
        if (np.random.random()*100) < 5:
            output = np.concatenate((y, x, img), axis=1)
            im = Image.fromarray(output.astype(np.uint8))
            im.save(os.path.join(results_dir, 'result{}.png'.format(i)))

        # metrics; images are floats in [0, 255], so data_range is given explicitly
        # (channel_axis replaces multichannel in scikit-image >= 0.19)
        ssim += structural_similarity(y, img, channel_axis=-1, data_range=255)
        mse += MSE(y, img)

    avg_ssim = ssim / generated_images.shape[0]
    avg_mse = mse / generated_images.shape[0]

    print('SSIM: {} \n'.format(avg_ssim))
    print('MSE: {} \n'.format(avg_mse))
    with open('/content/gdrive/My Drive/Colab Notebooks/log.txt', 'a+') as f:
        f.write('SSIM: {} \n'.format(avg_ssim))
        f.write('MSE: {} \n'.format(avg_mse))

Local test, to run the evaluation without connecting to Drive

In [None]:
def test_local(batch_size):
    y_test, x_test = load_data(testX, testX_blur)
    gen = generator_model()
    gen.load_weights('generator_600_3.h5')
    generated_images = gen.predict(x=x_test, batch_size=batch_size)
    generated = np.array([deprocess_image(img) for img in generated_images])
    x_test = deprocess_image(x_test)
    y_test = deprocess_image(y_test)

    ssim = 0
    mse = 0

    for i in range(generated_images.shape[0]):
        y = y_test[i, :, :, :]
        img = generated[i, :, :, :]

        # metrics; images are floats in [0, 255], so data_range is given explicitly
        # (channel_axis replaces multichannel in scikit-image >= 0.19)
        ssim += structural_similarity(y, img, channel_axis=-1, data_range=255)
        mse += MSE(y, img)

    avg_ssim = ssim / generated_images.shape[0]
    avg_mse = mse / generated_images.shape[0]

    print('SSIM: {} \n'.format(avg_ssim))
    print('MSE: {} \n'.format(avg_mse))

In [None]:
#test(16)
test_local(16)

## Results

In [None]:
with open('/content/gdrive/My Drive/Colab Notebooks/logs/log.txt') as f:
  lines = f.readlines()
  lines = [line.rstrip() for line in lines]
index_list = []
dis_loss_history = []
gen_loss_history = []
for line in lines:
  if not line:
    continue
  # Line format: "<epoch> [D loss r: a | D loss f: b | D loss: c] [G loss: d]"
  index = line.split(' [')[0]
  dis_loss = line.split('[')[1].split('D loss: ')[1].split(']')[0]
  gen_loss = line.split('[')[2].split('G loss: ')[1].split(']')[0]
  index_list.append(int(index)+1)
  dis_loss_history.append(float(dis_loss))
  gen_loss_history.append(float(gen_loss))

In [None]:
import matplotlib.pyplot as plt
x = np.array(index_list)
y1 = np.array(dis_loss_history)
y2 = np.array(gen_loss_history)

plt.plot(x, y1)
plt.show()

In [None]:
plt.plot(x, y2)
plt.show()

Some results

In [None]:
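import cv2
import matplotlib.pyplot as plt
images_list = ['result1', 'result2', 'result3', 'result4']
for img in images_list:
  img_file = img + '.png'
  # cv2.imread returns BGR; convert to RGB so matplotlib shows the right colors
  image = cv2.cvtColor(cv2.imread(img_file), cv2.COLOR_BGR2RGB)
  plt.figure()
  plt.imshow(image)
  plt.show()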