In [4]:
# Number of images per training mini-batch; also sizes the real/fake label arrays.
BATCH_SIZE: int = 32

In [1]:
import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
from random import choice, sample

In [2]:
import tensorflow
from tensorflow.keras.models import Model, Sequential
import tensorflow.keras.optimizers as Optimizer
from tensorflow.keras.layers import Flatten, Dense, Dropout, LeakyReLU, BatchNormalization, Reshape, Conv2DTranspose, Conv2D, Input
from tensorflow.keras.preprocessing.image import img_to_array, array_to_img, load_img

In [3]:
def getRealMiniBatch(path, batch_size=BATCH_SIZE, target_size=(512, 512)):
    """Yield mini-batches of real images scaled to [-1, 1], forever.

    Each batch samples ``batch_size`` distinct entries of ``path`` (without
    replacement). Entries that cannot be loaded as images are skipped, so a
    yielded batch may hold fewer than ``batch_size`` images; callers must check
    ``batch.shape[0]``.

    Args:
        path: directory containing the training images.
        batch_size: number of directory entries sampled per batch.
        target_size: (height, width) each image is resized to; it must match
            the discriminator's input shape.

    Yields:
        numpy array of the loaded images, rescaled from [0, 255] to [-1, 1]
        (the same range as the generator's tanh output).

    Raises:
        ValueError: if ``path`` has fewer than ``batch_size`` entries.
    """
    train_set = os.listdir(path)
    if len(train_set) < batch_size:
        raise ValueError(
            'Dataset directory {!r} has {} entries, fewer than batch_size={}'.format(
                path, len(train_set), batch_size))
    while True:
        batch = sample(train_set, batch_size)
        images = []
        for img in batch:
            try:
                x = load_img(os.path.join(path, img), target_size=target_size)
                images.append(img_to_array(x))
            except Exception:
                # Skip unreadable / non-image entries. Catch Exception rather than
                # using a bare except so KeyboardInterrupt (kernel interrupt) and
                # SystemExit still propagate.
                continue
        yield (np.array(images) / 127.5) - 1

In [4]:
# Directory holding the training images: replace the placeholder before running.
DATASET_PATH = 'mention_path_of_dataset_here'
# Lazy, infinite iterator of normalised real-image batches (nothing is read until next()).
real_images_batch = getRealMiniBatch(DATASET_PATH, batch_size=BATCH_SIZE)

In [None]:
# Smoke test: pull one batch and report how many images were actually loaded.
Real_batch = next(real_images_batch)
print(len(Real_batch))

In [None]:
def Generator():
    """Build the generator: a 100-d noise vector -> a 512x512x3 image in [-1, 1].

    A bias-free dense projection is reshaped into an 8x8x1024 seed map. Five
    stride-2 transposed convolutions (512 -> 32 filters) each double the spatial
    size, and each is followed by batch normalisation and LeakyReLU. A final
    stride-2 transposed convolution with tanh produces the 3-channel
    512x512 output. Shape asserts guard every stage.
    """
    net = Sequential()

    # Latent projection and reshape into the 8x8x1024 seed feature map.
    net.add(Dense(8 * 8 * 1024, use_bias=False, input_shape=(100,)))
    net.add(BatchNormalization())
    net.add(LeakyReLU())
    net.add(Reshape((8, 8, 1024)))
    assert net.output_shape == (None, 8, 8, 1024)

    # Upsampling stages: 8 -> 16 -> 32 -> 64 -> 128 -> 256.
    side = 8
    for filters in (512, 256, 128, 64, 32):
        side *= 2
        net.add(Conv2DTranspose(filters, (3, 3), strides=(2, 2), padding='same'))
        net.add(BatchNormalization())
        net.add(LeakyReLU())
        assert net.output_shape == (None, side, side, filters)

    # Output stage: 256 -> 512, tanh to match the [-1, 1] scaling of real images.
    net.add(Conv2DTranspose(3, (3, 3), strides=(2, 2), padding='same', activation='tanh'))
    assert net.output_shape == (None, 512, 512, 3)

    return net

In [None]:
def Discriminator():
    """Build the discriminator: a 512x512x3 image -> probability that it is real.

    Five stride-2 convolutions (32 -> 512 filters) each halve the spatial size
    (512 -> 16), and each is followed by batch normalisation (momentum 0.8),
    LeakyReLU(0.2) and 30% dropout. The result is flattened into a single
    sigmoid unit.
    """
    net = Sequential()

    for stage, filters in enumerate((32, 64, 128, 256, 512)):
        # Only the first layer declares the input shape.
        extra = {'input_shape': [512, 512, 3]} if stage == 0 else {}
        net.add(Conv2D(filters, (3, 3), strides=(2, 2), padding='same', **extra))
        net.add(BatchNormalization(momentum=0.8))
        net.add(LeakyReLU(alpha=0.2))
        net.add(Dropout(0.3))

    net.add(Flatten())
    net.add(Dense(1, activation='sigmoid'))

    return net

In [None]:
def GAN(generator, discriminator):
    """Chain ``generator`` into ``discriminator`` as one model: noise -> D(G(noise)).

    Side effect: sets ``discriminator.trainable = False`` on the caller's model
    before wiring it in. The intent is that compiling and training the combined
    model updates only the generator.

    Args:
        generator: model mapping a (100,) noise vector to an image.
        discriminator: model mapping that image to a real/fake score.

    Returns:
        An uncompiled Keras ``Model`` with a (100,) input.
    """
    discriminator.trainable = False
    latent = Input(shape=(100,))
    return Model(latent, discriminator(generator(latent)))

In [None]:
generator = Generator()

# `learning_rate` is the supported keyword; the legacy `lr` alias is rejected by
# newer Keras optimizers.
discriminator = Discriminator()
discriminator.compile(optimizer=Optimizer.Adam(learning_rate=0.0003), loss='binary_crossentropy')

# The discriminator is compiled before GAN() freezes it (trainable=False), so it
# still trains when called directly while staying frozen inside `gan`.
gan = GAN(generator, discriminator)
gan.compile(optimizer=Optimizer.Adam(learning_rate=0.0003), loss='binary_crossentropy')

# Targets for one batch: 1 = real image, 0 = generated image.
real = np.ones((BATCH_SIZE, 1))
fake = np.zeros((BATCH_SIZE, 1))

# Resume from checkpoints only when every weight file exists. The old check
# looked only at the directory, so an empty or partial directory made
# load_weights crash.
WEIGHTS_DIR = 'saved_weights'
checkpoints = [
    (generator, os.path.join(WEIGHTS_DIR, 'generator_weights.hdf5')),
    (discriminator, os.path.join(WEIGHTS_DIR, 'discriminator_weights.hdf5')),
    (gan, os.path.join(WEIGHTS_DIR, 'gan_weights.hdf5')),
]
os.makedirs(WEIGHTS_DIR, exist_ok=True)

if all(os.path.exists(weights_path) for _, weights_path in checkpoints):
    print('Loading weights from saved_weights...')
    for net, weights_path in checkpoints:
        net.load_weights(weights_path)
    print('Complete')
else:
    print('No complete checkpoint in saved_weights; training from scratch.')

In [None]:
# NOTE: each "epoch" here is one training step on a single mini-batch.
epochs = 5000
# Relative path, matching where the setup cell loads checkpoints from. The
# previous absolute '/content/...' path only worked on Colab.
checkpoint_dir = 'saved_weights'
os.makedirs(checkpoint_dir, exist_ok=True)

for e in range(epochs):

    # Unreadable files can shrink a batch, so redraw until it is full; the
    # `real`/`fake` label arrays have exactly BATCH_SIZE rows.
    Real_batch = next(real_images_batch)
    while Real_batch.shape[0] != BATCH_SIZE:
        Real_batch = next(real_images_batch)

    # Discriminator step: one update on real images (label 1), one on generated
    # images (label 0).
    noise = np.random.normal(loc=0, scale=1, size=(BATCH_SIZE, 100))
    real_loss = discriminator.train_on_batch(x=Real_batch, y=real)
    Fake_batch = generator.predict_on_batch(noise)
    fake_loss = discriminator.train_on_batch(x=Fake_batch, y=fake)
    discriminator_loss = np.add(real_loss, fake_loss) / 2

    # Generator step: train the combined model, whose discriminator is frozen,
    # with "real" targets.
    gan_loss = gan.train_on_batch(noise, real)

    print("Epoch: {} Discriminator loss: {:.3f} GAN loss: {:.3f}".format(e + 1, discriminator_loss, gan_loss))

    if e % 50 == 0:
        samples = 10
        x_fake = generator.predict(np.random.normal(loc=0, scale=1, size=(samples, 100)))
        # Map tanh output from [-1, 1] to [0, 1] for display.
        x_fake = np.clip((x_fake / 2) + 0.5, 0, 1)

        fig, axes = plt.subplots(2, 5)
        for ax, image in zip(axes.flat, x_fake):
            ax.imshow(image)
            ax.set_xticks([])
            ax.set_yticks([])
        fig.tight_layout()
        plt.show()
        plt.close(fig)  # avoid accumulating figures over a long run

        for net, file_name in (
            (discriminator, 'discriminator_weights.hdf5'),
            (generator, 'generator_weights.hdf5'),
            (gan, 'gan_weights.hdf5'),
        ):
            net.save_weights(os.path.join(checkpoint_dir, file_name))
        print('Saved the model at epoch {}'.format(e + 1))