# Conditional GAN for Generating MNIST Handwritten Digits

#### Importing Libraries

In [None]:
import tensorflow as tf

from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Concatenate
from tensorflow.keras.layers import BatchNormalization, Activation, Embedding, multiply
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import Conv2D, Conv2DTranspose
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.utils import to_categorical

import matplotlib.pyplot as plt
import numpy as np

import warnings
warnings.filterwarnings('ignore')

%matplotlib inline

## Loading data

In [None]:
# Load MNIST: train/test splits of digit images and their integer labels,
# then report the array shapes of each split.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
for _split, _x, _y in (('Train', x_train, y_train), ('Test', x_test, y_test)):
    print(_split, _x.shape, _y.shape)

In [None]:
# plot raw pixel data: show training image at index 10 in grayscale
# (values are still the unnormalized x_train pixels at this point)
plt.imshow(x_train[10], cmap='gray')

In [None]:
# Model hyper-parameters shared by the generator and discriminator:
# image geometry (height, width, channels), latent vector size, number of classes.
img_shape = (28, 28, 1)
noise_dim, num_classes = 100, 10

### Preprocessing Dataset

In [None]:
# normalizing
# scale from [0,255] to [-1,1] so real images share the range of the
# generator's tanh output. Casting to float32 first avoids numpy's implicit
# uint8 -> float64 promotion (twice the memory), and Keras computes in float32 anyway.
X_train = (x_train.astype('float32') - 127.5) / 127.5
# add a trailing channel axis so images match img_shape, e.g. (N, 28, 28) -> (N, 28, 28, 1)
X_train = np.expand_dims(X_train, axis=-1)

## Modeling

### Generator model

In [None]:
def def_generator(noise_dim):
    """Build the conditional generator.

    The returned model takes ``[z, label]``: ``z`` is a ``(batch, noise_dim)``
    noise tensor and ``label`` is an int32 ``(batch, 1)`` class index. It outputs
    ``(batch, 28, 28, 1)`` images squashed into ``[-1, 1]`` by a tanh.

    Conditioning works by embedding the label into a ``noise_dim``-sized vector
    (one learned vector per class, ``num_classes`` taken from module scope) and
    multiplying it element-wise with ``z`` before upsampling.

    Args:
        noise_dim: length of the latent noise vector.
    """
    # Upsampling network: (noise_dim,) -> 7x7x256 -> 14x14x128 -> 14x14x64 -> 28x28x1
    upsampler = Sequential()
    upsampler.add(Dense(7 * 7 * 256, input_shape=(noise_dim,)))
    upsampler.add(Reshape((7, 7, 256)))

    # Two Conv2DTranspose -> BatchNorm -> LeakyReLU stages, given as (filters, stride)
    for n_filters, step in ((128, 2), (64, 1)):
        upsampler.add(Conv2DTranspose(n_filters, kernel_size=3, strides=step, padding='same'))
        upsampler.add(BatchNormalization())
        upsampler.add(LeakyReLU(alpha=0.01))

    # Final stride-2 upsample to a single channel, squashed into [-1, 1]
    upsampler.add(Conv2DTranspose(1, kernel_size=3, strides=2, padding='same'))
    upsampler.add(Activation('tanh'))

    noise = Input(shape=(noise_dim,))
    digit = Input(shape=(1,), dtype='int32')

    # (batch, 1) -> (batch, 1, noise_dim) -> (batch, noise_dim)
    digit_embedded = Embedding(num_classes, noise_dim, input_length=1)(digit)
    digit_vec = Flatten()(digit_embedded)

    # Condition the noise on the class by element-wise product
    conditioned = multiply([noise, digit_vec])

    return Model([noise, digit], upsampler(conditioned))


In [None]:
# Build the conditional generator: it maps (noise, target label) to an image
# of the requested digit, then print its layer summary.
gen = def_generator(noise_dim=noise_dim)
gen.summary()
# the generator takes noise and the target label as input
# and generates the corresponding digit for that label

## The Discriminator

In [None]:
# define the standalone discriminator model
def def_discriminator(img_shape):
    """Build the conditional discriminator.

    The returned model takes ``[img, label]``: ``img`` has shape
    ``(batch, *img_shape)`` and ``label`` is an int32 ``(batch, 1)`` class index.
    It outputs a ``(batch, 1)`` sigmoid score (trained with 1 = real, 0 = fake).

    Conditioning: the label is embedded into ``prod(img_shape)`` values (one
    learned vector per class, ``num_classes`` taken from module scope), reshaped
    to ``img_shape`` and concatenated to the image on the channel axis. The
    convolutional stack therefore sees ``2 * channels`` input channels.

    Args:
        img_shape: ``(height, width, channels)`` of the input images.
    """
    height, width, channels = img_shape

    model = Sequential()

    # H*W*(2C) => (H/2)*(W/2)*64   (MNIST: 28*28*2 => 14*14*64)
    # Input depth is derived from img_shape (image channels + label-embedding
    # channels) instead of being hard-coded, so other image shapes also work.
    model.add(Conv2D(64, kernel_size=3, strides=2, padding='same',
                     input_shape=(height, width, 2 * channels)))
    model.add(LeakyReLU(alpha=0.01))

    # MNIST: 14*14*64 => 7*7*64
    model.add(Conv2D(64, kernel_size=3, strides=2, padding='same'))
    model.add(BatchNormalization())
    model.add(LeakyReLU(alpha=0.01))

    # MNIST: 7*7*64 => 4*4*128  ('same' padding with stride 2 rounds up: ceil(7/2) = 4)
    model.add(Conv2D(128, kernel_size=3, strides=2, padding='same'))
    model.add(BatchNormalization())
    model.add(LeakyReLU(alpha=0.01))

    model.add(Flatten())
    model.add(Dense(1, activation='sigmoid'))

    img = Input(shape=img_shape)

    label = Input(shape=(1,), dtype='int32')

    # embedding layer: one dense vector of size prod(img_shape) per class,
    # producing a 3D tensor of shape (batch_size, 1, prod(img_shape))
    label_embedding = Embedding(input_dim=num_classes, output_dim=np.prod(img_shape), input_length=1)(label)
    # Flatten to 2D (batch_size, prod(img_shape)), then reshape to the image shape
    label_embedding = Flatten()(label_embedding)
    label_embedding = Reshape(img_shape)(label_embedding)

    # concatenate images with their label "planes" along the channel axis
    concatenated = Concatenate(axis=-1)([img, label_embedding])

    prediction = model(concatenated)

    return Model([img, label], prediction)


In [None]:
# Build the discriminator and compile it for its own real/fake training steps
# (binary cross-entropy loss, accuracy reported alongside), then summarize it.
disc = def_discriminator(img_shape)
disc.compile(
    optimizer=Adam(),
    loss='binary_crossentropy',
    metrics=['accuracy'],
)
disc.summary()

## Combined Conditional GAN model (generator + frozen discriminator)

In [None]:

# Inputs of the combined model: latent noise vector and integer class label.
z = Input(shape=(noise_dim,))
# int32 to match the label Inputs declared in def_generator / def_discriminator
label = Input(shape=(1,), dtype='int32')

img = gen([z, label])

# keep the discriminator's params constant for generator training.
# disc was compiled above while still trainable, so disc.train_on_batch keeps
# updating it; only the combined model treats it as frozen.
# NOTE(review): relies on `trainable` being captured at compile time
# (tf.keras 2.x behaviour) -- confirm if running under Keras 3.
disc.trainable = False

prediction = disc([img, label])

# Conditional GAN model with a fixed discriminator, used to train the generator
cgan = Model([z, label], prediction)
cgan.compile(loss='binary_crossentropy', optimizer=Adam())
cgan.summary()

In [None]:
# function to display images
def display_images(epoch, image_grid_rows=2, image_grid_columns=5):
    """Generate a grid of conditional samples, title it and save it as a PNG.

    One image is generated per grid cell. The target digit of cell ``k``
    (row-major order) is ``k % num_classes``, so the default 2x5 grid shows
    digits 0-9. The figure is saved as ``'itter<epoch>.png'`` and left open.

    Args:
        epoch: training iteration number, used in the title and file name.
        image_grid_rows: number of rows in the grid.
        image_grid_columns: number of columns in the grid.
    """
    n_images = image_grid_rows * image_grid_columns
    z = np.random.normal(0, 1, (n_images, noise_dim))
    # exactly one label per grid cell (cycling through the classes), so any
    # grid size produces matching noise/label batches
    labels = (np.arange(n_images) % num_classes).reshape(-1, 1)
    gen_imgs = gen.predict([z, labels])
    # map tanh output from [-1, 1] to [0, 1] for display
    gen_imgs = 0.5 * gen_imgs + 0.5
    # squeeze=False keeps axs 2-D so axs[i, j] also works for 1-row/1-column grids
    fig, axs = plt.subplots(image_grid_rows, image_grid_columns, figsize=(10, 4),
                            sharey=True, sharex=True, squeeze=False)
    fig.suptitle('ittr:' + str(epoch), fontsize=10)
    cnt = 0
    for i in range(image_grid_rows):
        for j in range(image_grid_columns):
            axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
            axs[i, j].axis('off')
            # index the scalar explicitly: "%d" on a 1-element array relies on
            # array->scalar conversion, deprecated since NumPy 1.25
            axs[i, j].set_title("Digit: %d" % labels[cnt, 0])
            cnt += 1
    fig.savefig('itter' + str(epoch) + '.png')

### Training

In [None]:
accuracies = []  # discriminator accuracy, one entry per sample_interval
losses = []  # (discriminator loss, generator loss), one entry per sample_interval

def train(iterations, batch_size, sample_interval=500):
    """Train the conditional GAN with alternating discriminator/generator steps.

    Args:
        iterations: number of mini-batch training steps (not epochs).
        batch_size: number of real (and of fake) images per step.
        sample_interval: every this many iterations, print losses, append to the
            module-level ``losses`` / ``accuracies`` lists and save a sample grid
            via ``display_images``.
    """
    # discriminator targets: 1 = real, 0 = fake
    real = np.ones(shape=(batch_size, 1))
    fake = np.zeros(shape=(batch_size, 1))

    for iteration in range(iterations):

        # --- discriminator step: a random real batch and a generated batch ---
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        # labels reshaped to (batch, 1) to match the Input(shape=(1,)) label inputs
        # and the generator step below
        imgs, labels = X_train[idx], y_train[idx].reshape(-1, 1)

        z = np.random.normal(0, 1, size=(batch_size, noise_dim))
        # predict_on_batch: one inference pass, without predict()'s per-call
        # dataset setup and progress bar (which otherwise prints every iteration)
        gen_imgs = gen.predict_on_batch([z, labels])

        d_loss_real = disc.train_on_batch([imgs, labels], real)
        d_loss_fake = disc.train_on_batch([gen_imgs, labels], fake)
        # average [loss, accuracy] over the real and fake halves
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # --- generator step: fresh noise and random target digits ---
        z = np.random.normal(0, 1, size=(batch_size, noise_dim))
        labels = np.random.randint(0, num_classes, batch_size).reshape(-1, 1)

        # the generator is trained so the frozen discriminator labels its output "real"
        g_loss = cgan.train_on_batch([z, labels], real)

        if iteration % sample_interval == 0:
            print('{} >> G loss: {} D loss: {}, accuracy: {:.2f}'.format(iteration,g_loss, d_loss[0], 100 * d_loss[1] ))

            losses.append((d_loss[0], g_loss))
            accuracies.append(d_loss[1])

            display_images(iteration)


In [None]:
# Note: train() counts mini-batch iterations, not full passes over the dataset,
# so `epochs` here is really the number of training steps.
epochs, batch_size = 20000, 128

train(iterations=epochs, batch_size=batch_size)


In [None]:
# Split the recorded (discriminator loss, generator loss) pairs into two column
# vectors, keeping only the first 15 recorded points. np.asarray is required:
# `losses` is still a Python list at this point, and a list does not accept the
# 2-D slice `[0:15, :]` (TypeError).
d_loss, g_loss = np.hsplit(np.asarray(losses)[0:15, :], 2)
d_loss.shape

In [None]:
losses = np.array(losses)

# Plot training losses for Discriminator and Generator.
# x positions are the iterations at which each loss was recorded (every 500
# iterations, train()'s default sample_interval). They are derived from
# len(d_loss) so the x axis always has as many points as the plotted series.
steps = range(0, len(d_loss) * 500, 500)

plt.figure(figsize=(15, 5))
plt.plot(steps, d_loss.reshape(-1), label="Discriminator loss")
plt.plot(steps, g_loss.reshape(-1), label="Generator loss")

plt.xticks(steps, rotation=90)

plt.title("Training Loss")
plt.xlabel("Iteration")
plt.ylabel("Loss")
plt.legend()

## Generating MNIST Images

In [None]:
# Generate one image per grid cell; cell k (row-major) shows digit k % num_classes,
# so the default 2x5 grid shows digits 0-9.
image_grid_rows = 2
image_grid_columns = 5
n_images = image_grid_rows * image_grid_columns
z = np.random.normal(0, 1, (n_images, noise_dim))
# one label per grid cell, so noise and label batches always match in size
labels = (np.arange(n_images) % num_classes).reshape(-1, 1)
gen_imgs = gen.predict([z, labels])
# map tanh output from [-1, 1] to [0, 1] for display
gen_imgs = 0.5 * gen_imgs + 0.5
# squeeze=False keeps axs 2-D so axs[i, j] also works for 1-row/1-column grids
fig, axs = plt.subplots(image_grid_rows, image_grid_columns, figsize=(10, 4),
                        sharey=True, sharex=True, squeeze=False)
cnt = 0
for i in range(image_grid_rows):
    for j in range(image_grid_columns):
        axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
        axs[i, j].axis('off')
        # explicit scalar index: "%d" on a 1-element array is deprecated in NumPy >= 1.25
        axs[i, j].set_title("Digit: %d" % labels[cnt, 0])
        cnt += 1

In [None]:
# Generate and show a single sample of the chosen digit.
digit = 9

z = np.random.normal(0, 1, (1, noise_dim))
# a (1, 1) integer label batch, matching the generator's label Input(shape=(1,))
digit_label = np.array([[digit]])
sample = gen.predict([z, digit_label])
# map tanh output from [-1, 1] to [0, 1] for display
gen_imgs = 0.5 * sample + 0.5
plt.imshow(gen_imgs[0, :, :, 0], cmap='gray')