In [115]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [116]:
from tensorflow.keras.datasets import mnist

# Loading Data

In [117]:
# Load MNIST as (train, test) splits of images and their digit labels.
# X_test / y_test are not used in the visible part of this notebook.
(X_train, y_train), (X_test, y_test) = mnist.load_data()

In [118]:
# Show one training image; the title carries its label so the figure stands alone.
plt.imshow(X_train[1])
plt.title(f"Training sample 1 (label {y_train[1]})")
plt.show()

<Figure size 432x288 with 1 Axes>

In [119]:
# Label of the training image displayed above.
y_train[1]

0

In [120]:
# Keep only the training images whose label is the digit 0 (boolean-mask selection).
only_zeroes = X_train[np.equal(y_train, 0)]
np.shape(only_zeroes)

(5923, 28, 28)

In [121]:
# Shape of the full training set: (number of images, height, width).
np.shape(X_train)

(60000, 28, 28)

In [122]:
# Display one of the zero-digit images. plt.show() renders the figure without the
# AxesImage repr that print(plt.imshow(...)) used to emit.
plt.imshow(only_zeroes[14])
plt.title("only_zeroes[14]")
plt.show()

AxesImage(size=(28, 28))


<Figure size 432x288 with 1 Axes>

# Creating the Model

In [123]:
import tensorflow as tf
from tensorflow.keras.layers import Dense, Reshape, Flatten
from tensorflow.keras.models import Sequential

In [124]:
# Creating the discriminator: a binary classifier that scores a 28x28 image as
# real (label 1) or fake (label 0); those labels are assigned in the training loop.

# Seed TensorFlow's global RNG before any weights are created, so that weight
# initialization, the dataset shuffle and the noise draws repeat across fresh-kernel
# runs (on the same TensorFlow version/hardware).
tf.random.set_seed(42)

discriminator = Sequential([
    Flatten(input_shape=[28, 28]),
    Dense(150, activation='relu'),
    Dense(100, activation='relu'),
    # Final output layer: a single sigmoid unit = estimated probability the image is real
    Dense(1, activation='sigmoid'),
])

# Compiled while still trainable, so discriminator.train_on_batch updates its weights.
discriminator.compile(loss='binary_crossentropy', optimizer='adam')

In [125]:
# Creating the generator (plays a role similar to a decoder): it maps a random
# coding vector to a 28x28 image.
# codings_size is the length of that input noise vector; it should be much smaller
# than the 28x28 = 784 pixels the generator has to produce.
codings_size = 100

generator = Sequential([
    Dense(100, activation='relu', input_shape=[codings_size]),
    Dense(150, activation='relu'),
    # One unit per pixel. NOTE(review): ReLU output is non-negative but unbounded above;
    # the real images shown to the discriminator should be on a range it can match.
    Dense(784, activation='relu'),
    Reshape([28, 28]),
])

# The generator is not compiled on its own: it is trained through the full GAN model.

In [126]:
# Stack generator -> discriminator into one model: noise in, "is real" probability out.
GAN = Sequential()
GAN.add(generator)
GAN.add(discriminator)

In [127]:
# Freeze the discriminator before compiling the combined GAN below, so the
# generator-training phase (GAN.train_on_batch) does not update its weights.
discriminator.trainable=False

In [128]:
# Compile the stacked model while the discriminator is frozen; training GAN therefore
# only adjusts the generator.
GAN.compile(optimizer='adam', loss='binary_crossentropy')

# Training the Model

In [129]:
batch_size = 32 # smaller values mean more (smaller) steps per epoch, so each epoch takes longer

In [130]:
# Training data for the GAN: the zero-digit images, converted from uint8 pixels in
# [0, 255] to float32 in [0, 1]. Without scaling, the discriminator can separate real
# from generated images by raw magnitude alone, and its Dense inputs are badly scaled.
my_data = only_zeroes.astype('float32') / 255.0
my_data.shape

(5923, 28, 28)

In [131]:
# Number of batches per epoch, derived from the data instead of hard-coding 5923/32.
# The dataset is batched with drop_remainder=True further down, so only the integer
# part (complete batches) is actually iterated.
len(my_data) / batch_size

185.09375

In [132]:
# Slice the array into individual 28x28 images, then shuffle them through a
# 1000-element buffer.
dataset = tf.data.Dataset.from_tensor_slices(my_data)
dataset = dataset.shuffle(buffer_size=1000)

In [133]:
# Batch and prefetch the shuffled images. drop_remainder=True discards the final partial
# batch (5923 % 32 = 3 images here), so every batch holds exactly batch_size images;
# the training loop relies on that when it builds labels of length batch_size.
# Guard: unbatched elements are single 28x28 images (rank 2) and batched ones are rank 3,
# so re-running this cell does not batch the dataset a second time.
if dataset.element_spec.shape.rank == 2:
    dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(1)

In [134]:
# Number of full passes over `dataset` made by the training loop below.
epochs=1

In [135]:
# Sub-model 0 of the stacked GAN is the generator.
GAN.get_layer(index=0).summary()

Model: "sequential_6"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense_12 (Dense)            (None, 100)               10100     
                                                                 
 dense_13 (Dense)            (None, 150)               15150     
                                                                 
 dense_14 (Dense)            (None, 784)               118384    
                                                                 
 reshape_1 (Reshape)         (None, 28, 28)            0         
                                                                 
Total params: 143,634
Trainable params: 143,634
Non-trainable params: 0
_________________________________________________________________


In [136]:
# Sub-model 1 of the stacked GAN is the discriminator. Its params are reported as
# non-trainable because it was frozen before the GAN was compiled.
GAN.get_layer(index=1).summary()

Model: "sequential_5"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 flatten_2 (Flatten)         (None, 784)               0         
                                                                 
 dense_9 (Dense)             (None, 150)               117750    
                                                                 
 dense_10 (Dense)            (None, 100)               15100     
                                                                 
 dense_11 (Dense)            (None, 1)                 101       
                                                                 
Total params: 132,951
Trainable params: 0
Non-trainable params: 132,951
_________________________________________________________________


In [137]:
# Re-bind the names to the GAN's two sub-models (the same objects it was built from).
generator = GAN.layers[0]
discriminator = GAN.layers[1]

In [138]:
# GAN training: on every batch, alternate two phases:
#   1. train the discriminator on generated (label 0) and real (label 1) images;
#   2. train the generator through the combined GAN to make the discriminator output 1.
# The label tensors never change between steps, so they are built once, up front.
y1 = tf.constant([[0.0]] * batch_size + [[1.0]] * batch_size)  # fakes first, then reals (same order as the concat below)
y2 = tf.constant([[1.0]] * batch_size)  # we want the generator to believe that all images are real
n_batches = len(my_data) // batch_size  # complete batches per epoch (dataset uses drop_remainder=True)

for epoch in range(epochs):
    print(f"Currently on Epoch {epoch+1}")

    # i counts the batches processed within the current epoch, starting at 1
    for i, X_batch in enumerate(dataset, start=1):
        if i % 100 == 0:
            print(f"\t Currently on batch number {i} of {n_batches}")

        # DISCRIMINATOR TRAINING PHASE
        noise = tf.random.normal(shape=[batch_size, codings_size])
        gen_images = generator(noise)
        X_fake_vs_real = tf.concat([gen_images, tf.cast(X_batch, tf.float32)], axis=0)
        discriminator.trainable = True
        discriminator.train_on_batch(X_fake_vs_real, y1)

        # GENERATOR TRAINING PHASE
        noise = tf.random.normal(shape=[batch_size, codings_size])
        discriminator.trainable = False  # must not be updated while the generator trains
        GAN.train_on_batch(noise, y2)

Currenty on Epoch1
	 Currently on batch number 100 of 185


In [141]:
# Draw 10 latent vectors (one per image to generate) from a standard normal distribution.
noise = tf.random.normal([10, codings_size])

In [142]:
# (number of samples, codings_size)
noise.get_shape()

TensorShape([10, 100])

In [143]:
# Visualize the latent batch as an image: one row per sample, one column per coding dimension.
plt.imshow(noise)
plt.title("Latent noise (rows: samples, columns: coding dims)")
plt.show()

<matplotlib.image.AxesImage at 0x7f3954bdcbb0>

<Figure size 432x288 with 1 Axes>

In [144]:
# Map the latent vectors through the trained generator. This is an inference call; the
# generator has only Dense/Reshape layers, so training=False does not change the result.
images = generator(noise, training=False)

In [145]:
# (number of samples, 28, 28)
images.get_shape()

TensorShape([10, 28, 28])

In [147]:
# Display the first generated image. plt.show() avoids echoing the AxesImage repr.
plt.imshow(images[0])
plt.title("Generated image (sample 0)")
plt.show()

<matplotlib.image.AxesImage at 0x7f39ad53d550>

<Figure size 432x288 with 1 Axes>