In [1]:
import numpy as np
import tensorflow as tf
import PIL
import cv2
import matplotlib.pyplot as plt
%matplotlib inline
import tqdm.notebook as tq
import os
from PIL import Image

import keras
from keras.models import Model
from keras.models import Sequential
from keras.layers import Conv2D
from keras.layers import Conv2DTranspose
from keras.layers import Dense
from keras.layers import BatchNormalization
from keras.layers import Dropout
from keras.preprocessing.image import ImageDataGenerator
from keras import initializers
from keras.layers import LeakyReLU
from keras.activations import relu
from keras.layers import Activation
from keras.layers import Flatten
from keras.layers import UpSampling2D
from keras.layers import GlobalAveragePooling2D

import gc
import time

In [2]:
!pip install -q kaggle

In [3]:
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/

### Downloading the dataset from Kaggle: CelebA, about 200,000 (202,599) aligned face images of celebrities.


In [4]:
!kaggle datasets list -s celeba

ref                                                   title                                        size  lastUpdated          downloadCount  
----------------------------------------------------  ------------------------------------------  -----  -------------------  -------------  
jessicali9530/celeba-dataset                          CelebFaces Attributes (CelebA) Dataset        1GB  2018-06-01 20:08:48          43119  
zuozhaorui/celeba                                     celeba                                        3GB  2018-11-03 05:29:21            342  
ashishjangra27/gender-recognition-200k-images-celeba  Gender Classification 200K Images | CelebA    1GB  2020-05-22 20:15:23            319  
ruchi798/periocular-detection                         Periocular Recognition                       13MB  2020-08-09 00:45:00            134  
ahmedshawaf/celeba                                    celeba                                        1GB  2020-04-07 21:56:35            133  
megh24

In [5]:
!kaggle datasets download -d jessicali9530/celeba-dataset

Downloading celeba-dataset.zip to /content
100% 1.32G/1.33G [00:23<00:00, 52.7MB/s]
100% 1.33G/1.33G [00:23<00:00, 60.3MB/s]


In [7]:
!unzip '/content/celeba-dataset.zip'

Archive:  /content/celeba-dataset.zip
replace img_align_celeba/img_align_celeba/000001.jpg? [y]es, [n]o, [A]ll, [N]one, [r]ename: 

### Preprocessing: center-crop each image to a square and downscale it to 128x128.

In [8]:
# Output folder for the resized images. flow_from_directory treats every
# sub-folder as a class, so the images live one level down.
# makedirs(exist_ok=True) creates both levels and keeps the cell re-runnable
# (os.mkdir raised FileExistsError on a second run).
os.makedirs('/content/reduced_dataset/reduced_dataset/', exist_ok=True)

In [9]:
# Center-crop every CelebA image to a square and shrink it to 128x128.
# The crop box is computed from each image's real size. The previous hard-coded
# height (208) did not match the 178x218 img_align_celeba images, so the crop
# was shifted toward the top instead of being centered.
directory = '/content/img_align_celeba/img_align_celeba/'
# Relative path, as read by flow_from_directory below; assumes the working
# directory is /content (Colab default).
output_dir = 'reduced_dataset/reduced_dataset/'
target_size = (128, 128)

for image in tq.tqdm(os.listdir(directory)):
    # `with` closes the source file handle; crop() loads the pixels eagerly,
    # so `square` stays valid after the file is closed.
    with Image.open(os.path.join(directory, image)) as img:
        width, height = img.size
        side = min(width, height)
        left = (width - side) // 2
        top = (height - side) // 2
        square = img.crop((left, top, left + side, top + side))
    # thumbnail() downscales in place, preserving the (square) aspect ratio.
    # LANCZOS replaces ANTIALIAS, which was removed in Pillow 10.
    square.thumbnail(target_size, Image.LANCZOS)
    square.save(os.path.join(output_dir, image))

HBox(children=(FloatProgress(value=0.0, max=202599.0), HTML(value='')))




### Loading the dataset in batches with Keras' `ImageDataGenerator` (90% training / 10% validation).

In [10]:
def preprocessing_function(x):
    """Scale 8-bit pixel values from [0, 255] to [-1, 1].

    The range matches the generator's tanh output and the `(x + 1) * 0.5`
    inverse used when plotting. Dividing by 128 (as before) mapped 255 to
    0.992, so real images never reached the top of the generator's range.

    Args:
        x: pixel value(s) in [0, 255] (scalar or array).

    Returns:
        The same values rescaled to [-1, 1].
    """
    return x / 127.5 - 1.

# One generator serves both splits: 10% of the images are reserved for
# validation and the two iterators are told apart by `subset`.
datagen = ImageDataGenerator(validation_split=0.1,
                             preprocessing_function=preprocessing_function)

# class_mode=None -> the iterators yield image batches only (no labels).
flow_kwargs = dict(target_size=(128, 128), batch_size=128, class_mode=None)
train_ds = datagen.flow_from_directory('reduced_dataset/', subset='training', **flow_kwargs)
valid_ds = datagen.flow_from_directory('reduced_dataset/', subset='validation', **flow_kwargs)

Found 182340 images belonging to 1 classes.
Found 20259 images belonging to 1 classes.


### Building the architecture: a residual generator (32x32 -> 128x128) and a convolutional discriminator

In [11]:
class ResidualUnit(keras.layers.Layer):
    """Residual block: conv-BN-act-conv-BN on the main path plus a 1x1 conv-BN shortcut.

    Output is ``activation(main(x) + shortcut(x))``.

    Args:
        filters: number of output channels of every convolution in the block.
        kernel_size: kernel size of the two main-path convolutions (previously
            ignored: 3 was hard-coded).
        strides: stride of the first main-path conv and of the shortcut conv.
            The second main-path conv always uses stride 1 so both paths end
            with the same spatial size (applying the stride twice broke the
            residual addition whenever strides > 1).
        activation: activation used inside the main path and after the
            addition. By default a new LeakyReLU(alpha=0.2) is created for each
            unit; any explicitly passed value is resolved with
            ``keras.activations.get`` exactly as before.
        **kwargs: forwarded to ``keras.layers.Layer``.
    """

    # Sentinel so the default LeakyReLU is built per instance rather than once
    # at class-definition time and shared by every ResidualUnit.
    _DEFAULT_ACTIVATION = object()

    def __init__(self, filters, kernel_size, strides=1, activation=_DEFAULT_ACTIVATION, **kwargs):
        super().__init__(**kwargs)
        if activation is ResidualUnit._DEFAULT_ACTIVATION:
            activation = LeakyReLU(alpha=0.2)
        self.activation = keras.activations.get(activation)
        self.main_layers = [
            Conv2D(filters, kernel_size=kernel_size, strides=strides, padding="SAME"),
            BatchNormalization(),
            self.activation,
            Conv2D(filters, kernel_size=kernel_size, strides=1, padding="SAME"),
            BatchNormalization(),
        ]
        # 1x1 projection so the shortcut matches the main path's channels and stride.
        self.skip_layers = [
            Conv2D(filters, kernel_size=1, strides=strides, padding="SAME"),
            BatchNormalization(),
        ]

    def call(self, inputs):
        Z = inputs
        for layer in self.main_layers:
            Z = layer(Z)
        skip_Z = inputs
        for layer in self.skip_layers:
            skip_Z = layer(skip_Z)
        return self.activation(Z + skip_Z)

In [12]:
np.random.seed(42)
tf.random.set_seed(42)

LR_SHAPE = [32, 32, 3]    # generator input (low resolution)
HR_SHAPE = [128, 128, 3]  # generator output / discriminator input


def build_generator():
    """Residual generator: 32x32x3 input -> two 2x upsamplings -> 128x128x3 tanh output."""
    net = Sequential()
    net.add(Conv2D(64, kernel_size=7, strides=1, padding="SAME",
                   activation=LeakyReLU(alpha=0.2), input_shape=LR_SHAPE))

    for filters in [256, 128, 64]:
        net.add(ResidualUnit(filters, kernel_size=3, strides=1))
        net.add(ResidualUnit(filters, kernel_size=3, strides=1))

    net.add(UpSampling2D(size=2))
    net.add(Conv2D(64, kernel_size=3, strides=1, padding="SAME"))
    net.add(BatchNormalization())
    # Previously `Activation(LeakyReLU(...))` was constructed but never added.
    net.add(LeakyReLU(alpha=0.2))

    net.add(UpSampling2D(size=2))
    net.add(Conv2D(3, kernel_size=9, strides=1, padding="SAME", activation='tanh'))
    return net


def build_discriminator():
    """Convolutional discriminator: 128x128x3 image -> probability that it is real."""
    net = Sequential()
    net.add(Conv2D(64, kernel_size=3, strides=1, padding="SAME",
                   activation=LeakyReLU(alpha=0.2), input_shape=HR_SHAPE))

    # Four stride-2 stages: 128 -> 64 -> 32 -> 16 -> 8.
    for filters in [64, 128, 256, 512]:
        net.add(Conv2D(filters, kernel_size=3, strides=2, padding="SAME"))
        net.add(BatchNormalization())
        net.add(LeakyReLU(alpha=0.2))  # was created but never added

    net.add(Conv2D(256, kernel_size=3, strides=1, padding="SAME"))
    net.add(BatchNormalization())
    net.add(LeakyReLU(alpha=0.2))  # was created but never added

    # Dense on a 4-D tensor acts on the last axis at every spatial position.
    # Without an activation, Dense -> BN -> GAP -> Dense was a purely affine chain.
    net.add(Dense(1024))
    net.add(BatchNormalization())
    net.add(LeakyReLU(alpha=0.2))
    net.add(GlobalAveragePooling2D())
    net.add(Dense(1, activation='sigmoid'))
    return net


generator = build_generator()
discriminator = build_discriminator()


In [13]:
# Layer-by-layer output shapes and parameter counts, generator first.
for _model in (generator, discriminator):
    _model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d (Conv2D)              (None, 32, 32, 64)        9472      
_________________________________________________________________
residual_unit (ResidualUnit) (None, 32, 32, 256)       757504    
_________________________________________________________________
residual_unit_1 (ResidualUni (None, 32, 32, 256)       1249024   
_________________________________________________________________
residual_unit_2 (ResidualUni (None, 32, 32, 128)       477056    
_________________________________________________________________
residual_unit_3 (ResidualUni (None, 32, 32, 128)       313216    
_________________________________________________________________
residual_unit_4 (ResidualUni (None, 32, 32, 64)        119744    
_________________________________________________________________
residual_unit_5 (ResidualUni (None, 32, 32, 64)        7

### Utility functions: plotting samples, saving GIF frames, and the pixel-wise MSE loss

In [14]:
def plot_gan(SR_images, LR_images, n_images=8):
    """Plot generated images (top row) above the inputs they came from (bottom row).

    Both batches are expected in [-1, 1] and are mapped back to [0, 1]
    with ``(x + 1) * 0.5`` for display.

    Args:
        SR_images: batch of generator outputs.
        LR_images: batch of the corresponding low-resolution inputs.
        n_images: number of grid columns. If a batch holds fewer images, only
            that many pairs are drawn (previously this raised IndexError).
    """
    n_pairs = min(n_images, len(SR_images), len(LR_images))
    plt.figure(figsize=(12, 3))
    for i in range(n_pairs):
        for row, batch in enumerate((SR_images, LR_images)):
            plt.subplot(2, n_images, row * n_images + i + 1)
            # Defensive clip: imshow expects floats in [0, 1].
            plt.imshow(np.clip((np.asarray(batch[i]) + 1) * 0.5, 0.0, 1.0))
            plt.axis('off')

In [15]:
def save_GIF_images(model, GIF_seed, epoch,
                    output_dir='/content/drive/MyDrive/Colab/Super Resolution using DCGAN/GIF'):
    """Save one GIF frame: model outputs for a fixed input batch, above the inputs.

    Args:
        model: generator, called in inference mode (``training=False``).
        GIF_seed: fixed batch of low-resolution inputs in [-1, 1].
        epoch: zero-based epoch index; the file is named ``image_at_epoch_{epoch+1:04d}.png``.
        output_dir: folder for the frames (created if missing). The default is
            the previous hard-coded Google Drive path.
    """
    pred = model(GIF_seed, training=False)
    n_pairs = min(8, len(pred), len(GIF_seed))
    fig = plt.figure(figsize=(12, 3))

    for i in range(n_pairs):
        plt.subplot(2, 8, i + 1)
        plt.imshow(np.clip((np.asarray(pred[i]) + 1) * 0.5, 0.0, 1.0))
        plt.axis('off')

        plt.subplot(2, 8, i + 8 + 1)
        plt.imshow(np.clip((np.asarray(GIF_seed[i]) + 1) * 0.5, 0.0, 1.0))
        plt.axis('off')

    os.makedirs(output_dir, exist_ok=True)
    fig.savefig(os.path.join(output_dir, 'image_at_epoch_{:04d}.png'.format(epoch + 1)))
    # Close the figure: left open, it was rendered later under the next epoch's output.
    plt.close(fig)

In [16]:
def mse_loss(HR_batch, SR_batch):
    """Content loss: mean of squared pixel differences between target and generated batches.

    Keras calls it as ``loss(y_true, y_pred)``. The mean is taken over every
    element, so the result is a scalar tensor.
    """
    residual = HR_batch - SR_batch
    return tf.reduce_mean(tf.square(residual))

In [17]:
from keras.layers import Input

# Discriminator: RMSprop with gradient-value clipping. The legacy `decay=1e-8`
# (lr / (1 + decay * iterations)) is expressed as the equivalent
# InverseTimeDecay schedule, because newer Keras optimizers reject `decay`
# and deprecate `lr` in favour of `learning_rate`.
discriminator_optimizer = keras.optimizers.RMSprop(
    learning_rate=keras.optimizers.schedules.InverseTimeDecay(
        initial_learning_rate=1e-4, decay_steps=1, decay_rate=1e-8),
    clipvalue=1.0)
discriminator.compile(loss="binary_crossentropy", optimizer=discriminator_optimizer)

# Freeze the discriminator *before* compiling the combined model, so that
# gan.train_on_batch only updates the generator's weights.
discriminator.trainable = False

# Named lr_input (not `input`) to avoid shadowing the builtin.
lr_input = Input(shape=[32, 32, 3])
SR_image = generator(lr_input)
gan_output = discriminator(SR_image)
gan = Model(inputs=lr_input, outputs=[SR_image, gan_output])
generator_optimizer = keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)
# Generator objective: 0.8 * pixel MSE (content) + 0.2 * binary cross-entropy (adversarial).
gan.compile(loss=[mse_loss, "binary_crossentropy"], loss_weights=[0.8, 0.2],
            optimizer=generator_optimizer)

In [18]:
# Track both networks together with their optimizers so training can resume
# from a saved checkpoint.
# NOTE(review): the path is on Google Drive; this assumes Drive is mounted at
# /content/drive (no mount cell is visible in this notebook).
checkpoint_dir = '/content/drive/MyDrive/Colab/Super Resolution using DCGAN/checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(
    generator=generator,
    generator_optimizer=generator_optimizer,
    discriminator=discriminator,
    discriminator_optimizer=discriminator_optimizer,
)

In [19]:
def _soft_labels(n, real):
    """Return an (n, 1) float32 array of noisy discriminator targets.

    Real targets are 1 - 0.1*U[0, 1), in (0.9, 1]; fake targets are
    0.1*U[0, 1), in [0, 0.1).
    """
    noise = 0.1 * np.random.random_sample((n, 1))
    labels = 1.0 - noise if real else noise
    return labels.astype(np.float32)


def train_gan(generator, discriminator, dataset, batch_size, epochs):
    """Alternately train the discriminator and the generator for `epochs` epochs.

    Each step:
      1. resizes a high-resolution batch (HR) to 32x32 (LR);
      2. trains the discriminator on generated images (SR, soft label ~0) and
         on HR (soft label ~1);
      3. trains the combined `gan` model to reproduce HR from LR (MSE) while
         having the discriminator label SR as real.

    Relies on notebook globals from earlier cells: `gan`, `valid_ds`,
    `checkpoint` and `checkpoint_prefix`.

    Args:
        generator: model mapping 32x32 LR batches to HR-sized outputs.
        discriminator: model returning a real/fake probability per image.
        dataset: Keras directory iterator of HR batches (uses `dataset.n`,
            indexing by batch number and `on_epoch_end`).
        batch_size: batch size used to compute the steps per epoch
            (`dataset.n // batch_size`; any trailing partial batch is skipped).
        epochs: number of passes over `dataset`.

    Raises:
        ValueError: if the dataset holds fewer than `batch_size` images.
    """
    n_batches = dataset.n // batch_size
    if n_batches == 0:
        raise ValueError("dataset has fewer images than batch_size; nothing to train on")

    # Fixed inputs so the saved GIF frames are comparable across epochs.
    GIF_seed = tf.image.resize(valid_ds[0], size=[32, 32])

    for epoch in range(epochs):
        G_loss_epoch = 0
        D_loss_epoch = 0
        print("Epoch no. " + str(epoch + 1))
        start_time = time.time()

        for batch_number in tq.tqdm(range(n_batches)):
            HR_batch = dataset[batch_number]
            LR_batch = tf.image.resize(HR_batch, size=[32, 32])
            # Size the labels from the batch itself so they always match it.
            current_size = len(HR_batch)

            # PHASE 1: train the discriminator on fake (SR) and real (HR) images.
            SR_batch = generator(LR_batch)
            discriminator.trainable = True
            D_loss_epoch += discriminator.train_on_batch(SR_batch, _soft_labels(current_size, real=False))
            D_loss_epoch += discriminator.train_on_batch(HR_batch, _soft_labels(current_size, real=True))

            # PHASE 2: train the generator through `gan`, asking the frozen
            # discriminator to call its outputs real.
            discriminator.trainable = False
            G_losses = gan.train_on_batch(LR_batch, [HR_batch, _soft_labels(current_size, real=True)])
            G_loss_epoch += G_losses[0]  # total weighted loss

            if (batch_number + 1) % 500 == 0:
                plot_gan(SR_batch, LR_batch)
                j = np.random.randint(0, len(valid_ds))
                valid_LR = tf.image.resize(valid_ds[j], size=[32, 32])
                plot_gan(generator(valid_LR), valid_LR)
                plt.show()  # render now instead of piling figures up until epoch end
                checkpoint.save(file_prefix=checkpoint_prefix)

        # Reshuffle so the next epoch visits the batches in a new order.
        dataset.on_epoch_end()

        # D loss is the per-batch sum of the fake and real losses, averaged over batches.
        D_loss_epoch /= n_batches
        G_loss_epoch /= n_batches
        print("D_loss = "+str(D_loss_epoch)+"   G_loss = "+str(G_loss_epoch)+"   @epoch "+str(epoch+1)+"   time = "+str(time.time() - start_time))
        plot_gan(SR_batch, LR_batch)
        plt.show()

        print("Saving weights and generating GIF images")
        save_GIF_images(generator, GIF_seed, epoch)
        checkpoint.save(file_prefix=checkpoint_prefix)
    # The former post-loop save duplicated the last epoch's frame/checkpoint
    # and raised NameError when epochs == 0; it is intentionally gone.

### Training

In [None]:
# 15 epochs over the training split; batch_size matches the iterator's batch_size (128).
train_gan(generator, discriminator, train_ds, batch_size=128, epochs=15)

### The results are in the accompanying GitHub README.md.