In [1]:
import os

import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.models import Sequential
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense, Dropout, Input
from tensorflow.keras.datasets import mnist

In [2]:
class Generator(keras.Model):
    """Fully connected generator: maps a noise vector to a flat output vector.

    Three LeakyReLU Dense layers (256 -> 512 -> 1024 units) feed a final
    ``tanh`` Dense layer with ``output_dim`` units, so every output value
    lies in [-1, 1].

    Args:
        input_dim: Size of the noise vector. Stored on the instance only; the
            first Dense layer infers its input width the first time it is
            called.
        output_dim: Number of units in the output layer.
        name: Keras model name.
        **kwargs: Forwarded to ``keras.Model``.
    """

    def __init__(self, input_dim=100, output_dim=784, name="Generator", **kwargs):
        super().__init__(name=name, **kwargs)

        self.input_dim = input_dim
        self.output_dim = output_dim

        # The first layer is built lazily on first call, so no `input_dim=`
        # kwarg is passed to Dense (newer Keras versions warn about it).
        self.hidden = [
            Dense(units=256, activation=tf.nn.leaky_relu, name='generator_input'),
            Dense(units=512, activation=tf.nn.leaky_relu),
            Dense(units=1024, activation=tf.nn.leaky_relu),
        ]
        self.output_layer = Dense(units=output_dim, activation='tanh', name='generator_output')

    def call(self, inputs):
        """Run ``inputs`` through the hidden stack and the tanh output layer."""
        x = inputs
        for layer in self.hidden:
            x = layer(x)
        return self.output_layer(x)

    def generate_noise(self, batch_size, random_noise_size):
        """Sample uniform noise in [-1, 1) of shape (batch_size, random_noise_size).

        A TensorFlow op is used instead of ``np.random``: inside a
        ``tf.function`` a NumPy call runs only once, while the graph is being
        traced, and its result is baked in as a constant, so every later step
        would reuse the same noise. ``tf.random.uniform`` is part of the graph
        and draws fresh values on every execution.

        Returns:
            A float32 ``tf.Tensor``.
        """
        return tf.random.uniform(
            shape=(batch_size, random_noise_size), minval=-1.0, maxval=1.0
        )
  

In [3]:
# The discriminator's output layer applies a sigmoid, so the loss receives
# probabilities, not raw logits. With from_logits=True the loss would apply a
# second sigmoid on top of them.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=False)


def generator_objective(dx_of_gx):
    """Generator loss: cross-entropy of D(G(z)) against all-ones targets.

    The generator is rewarded when the discriminator scores its samples as
    real, so the target for every generated sample is 1.

    Args:
        dx_of_gx: Discriminator output for a batch of generated samples.

    Returns:
        The scalar loss computed by ``cross_entropy``.
    """
    return cross_entropy(tf.ones_like(dx_of_gx), dx_of_gx)


In [4]:
class Discriminator(keras.layers.Layer):
    """Fully connected discriminator: maps a flat input to one sigmoid score.

    Layer stack: Dense(1024) -> Dropout(0.2) -> Dense(512) -> Dropout(0.2)
    -> Dense(256), all Dense layers with LeakyReLU, followed by a single-unit
    sigmoid output. Because the output is already a probability, any
    cross-entropy loss applied to it should be configured for probabilities
    (``from_logits=False``).

    Args:
        input_dim: Expected width of the input vectors. Stored on the instance
            only; the first Dense layer infers its input width on first call.
        name: Keras layer name.
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, input_dim=784, name="Discriminator", **kwargs):
        super().__init__(name=name, **kwargs)

        self.input_dim = input_dim

        # The first layer is built lazily on first call, so no `input_dim=`
        # kwarg is passed to Dense (newer Keras versions warn about it).
        self.hidden = [
            Dense(units=1024, activation=tf.nn.leaky_relu, name='discriminator_input'),
            Dropout(0.2),
            Dense(units=512, activation=tf.nn.leaky_relu),
            Dropout(0.2),
            Dense(units=256, activation=tf.nn.leaky_relu),
        ]
        self.output_layer = Dense(units=1, activation='sigmoid', name='discriminator_output')

    def call(self, inputs, training=None):
        """Score ``inputs``.

        Args:
            inputs: Batch of flat input vectors.
            training: Forwarded to the Dropout layers. Pass ``True`` from a
                custom training loop to activate dropout; when left as
                ``None`` Keras decides from the surrounding call context.

        Returns:
            Sigmoid output of the final Dense layer, one value per input row.
        """
        x = inputs
        for layer in self.hidden:
            if isinstance(layer, Dropout):
                # Only Dropout behaves differently between training and inference.
                x = layer(x, training=training)
            else:
                x = layer(x)
        return self.output_layer(x)

In [5]:
def discriminator_objective(d_x, g_z, smoothing_factor = 0.9):
    """Discriminator loss: real-batch loss plus generated-batch loss.

    Args:
        d_x: Discriminator output for the real batch.
        g_z: Discriminator output for the generated (fake) batch.
        smoothing_factor: Target value used for real samples in place of 1
            (one-sided label smoothing).

    Returns:
        Sum of the two losses computed by ``cross_entropy``.
    """
    # Real samples are labelled "real", softened from 1 to `smoothing_factor`.
    smoothed_real_targets = tf.ones_like(d_x) * smoothing_factor
    # Everything produced from noise is labelled "fake" (0).
    fake_targets = tf.zeros_like(g_z)

    loss_on_real = cross_entropy(smoothed_real_targets, d_x)
    loss_on_fake = cross_entropy(fake_targets, g_z)
    return loss_on_real + loss_on_fake

In [6]:
class GAN(keras.Model):
    """Generator and discriminator bundled into one Keras model.

    Args:
        input_dim: Width of the data vectors, i.e. what the generator emits
            and the discriminator consumes.
        latent_dim: Width of the noise vector fed to the generator.
        name: Keras model name.
        **kwargs: Forwarded to ``keras.Model``.

    The defaults used to be ``input_dim=100, latent_dim=784``. Combined with
    the constructor calls below, that built a generator with 100 output
    units. The defaults are swapped so a default ``GAN()`` has a 784-unit
    generator output. Explicitly passed values are interpreted exactly as
    before.
    """

    def __init__(self, input_dim=784, latent_dim=100, name="GAN", **kwargs):
        super().__init__(name=name, **kwargs)
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.G = Generator(input_dim=latent_dim, output_dim=input_dim)
        self.D = Discriminator(input_dim=input_dim)

    def call(self, input):
        """Score generated samples, optionally together with real ones.

        NOTE(review): the previous body passed the whole ``input`` to the
        generator, indexed ``input[1]``, and called ``tf.concat`` without its
        required ``axis`` argument, so it could not run. The behaviour below
        is the most plausible reading of that code; confirm against the
        intended use.

        Args:
            input: Either a noise tensor, or a ``(noise, real_samples)`` pair.
                (The parameter name shadows the builtin but is kept so the
                signature stays unchanged.)

        Returns:
            Discriminator scores. For a pair, generated samples come first
            and real samples second along the batch axis.
        """
        if isinstance(input, (list, tuple)):
            noise, real_samples = input[0], input[1]
            # Stack fake and real samples along the batch dimension.
            batch = tf.concat([self.G(noise), real_samples], axis=0)
        else:
            batch = self.G(input)
        return self.D(batch)


In [7]:
class GANback(tf.keras.callbacks.Callback):
    """Callback that flips ``model.D.trainable`` at the end of every epoch.

    NOTE(review): Keras documents that a changed ``trainable`` flag only
    takes effect for a compiled model after it is compiled again. Confirm
    this callback is used in a setup where the toggle is actually picked up.
    """

    def on_epoch_end(self, epoch, logs=None):
        # `not` is Python's boolean negation. The previous `= !value` is not
        # valid Python, and IPython would treat it as a shell-capture
        # assignment.
        self.model.D.trainable = not self.model.D.trainable


callback = GANback()

In [11]:
# Load MNIST; the label arrays are discarded because the GAN is unsupervised.
(X, _), (X_test, _) = mnist.load_data()

# (v - 127.5) / 127.5 maps 0 -> -1 and 255 -> 1, matching the range of the
# generator's tanh output. Each image is then flattened to a 784-vector.
X = ((X.astype(np.float32) - 127.5) / 127.5).reshape(60000, 784)

In [10]:
# Training configuration.
BATCH_SIZE = 256     # samples per training step
BUFFER_SIZE = 60000  # shuffle buffer size
EPOCHES = 300        # passes over the training set
OUTPUT_DIR = "img"   # where sample images are written

# Input pipeline: flat image vectors, shuffled, then batched.
train_dataset = (
    tf.data.Dataset
    .from_tensor_slices(X.reshape(X.shape[0], 784))
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)

# One optimizer per network so each keeps its own RMSprop state.
generator_optimizer = keras.optimizers.RMSprop()
discriminator_optimizer = keras.optimizers.RMSprop()

In [12]:
@tf.function()
def training_step(generator: Generator, discriminator: Discriminator, images: tf.Tensor, k: int = 1, batch_size=32, noise_dim: int = 100):
    """Run ``k`` adversarial updates of both networks on one batch.

    Uses the module-level ``generator_optimizer``, ``discriminator_optimizer``,
    ``generator_objective`` and ``discriminator_objective``.

    Args:
        generator: Network mapping noise to samples.
        discriminator: Network scoring samples.
        images: Batch of real samples.
        k: Number of update iterations to run on this batch; both networks
            are updated in every iteration.
        batch_size: Number of noise vectors (generated samples) per iteration.
        noise_dim: Width of each noise vector.

    Returns:
        None. The networks' variables are updated in place.
    """
    for _ in range(k):
        # Sample with a TF op: a NumPy call here would run once at trace time
        # and be frozen into the graph, so every step would reuse the same noise.
        noise = tf.random.uniform(shape=(batch_size, noise_dim), minval=-1.0, maxval=1.0)

        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
            # training=True so layers such as Dropout run in training mode.
            g_z = generator(noise, training=True)
            d_x_true = discriminator(images, training=True)
            d_x_fake = discriminator(g_z, training=True)  # D(G(z))

            discriminator_loss = discriminator_objective(d_x_true, d_x_fake)
            generator_loss = generator_objective(d_x_fake)

        # Gradients are taken outside the tape context so the backward pass
        # itself is not recorded.
        gradients_of_discriminator = disc_tape.gradient(discriminator_loss, discriminator.trainable_variables)
        gradients_of_generator = gen_tape.gradient(generator_loss, generator.trainable_variables)

        # apply_gradients takes (gradient, variable) pairs.
        discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
        generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))

In [None]:
def training(dataset, epoches):
    """Train the GAN and save a sample image every 50 epochs.

    Relies on the module-level ``generator``, ``discriminator``, ``seed``,
    ``BATCH_SIZE`` and ``OUTPUT_DIR``.
    NOTE(review): ``generator``, ``discriminator`` and ``seed`` are not defined
    in the cells visible here; they must exist before this function is called.
    ``seed`` is presumably a single noise vector, since the generated output
    is reshaped to one 28x28 image.

    Args:
        dataset: Iterable of real-image batches.
        epoches: Number of passes over ``dataset``.
    """
    # imsave does not create missing directories.
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    for epoch in range(epoches):
        for batch in dataset:
            training_step(generator, discriminator, batch, batch_size=BATCH_SIZE, k=1)

        # Every 50th epoch (including epoch 0), save a sample from the fixed seed.
        if epoch % 50 == 0:
            fake_image = tf.reshape(generator(seed), shape=(28, 28)).numpy()
            print(f"{epoch}/{epoches} epoches")
            # An f-string replaces the old str.format call, which used named
            # fields with positional arguments and raised KeyError.
            plt.imsave(f"{OUTPUT_DIR}/{epoch}.png", fake_image, cmap="gray")