# GAN - MNIST + Keras

## 0. Necessities

### 0.0 Imports

In [None]:
# Maths
import numpy                                            as np

# Matplotlib
import matplotlib                                       as mp
import matplotlib.pyplot                                as pt

# Machine / Deep Learning
import tensorflow                                       as tf
import keras                                            as ks
from keras              import models                   as mls
from keras              import layers                   as lys
from keras.datasets     import mnist                    as mn
from keras.utils        import to_categorical           as tc

# Report the version of every library this notebook relies on.
for label, module in (
    ( "Numpy .... ", np ),
    ( "Matplotlib ", mp ),
    ( "Tensorflow ", tf ),
    ( "Keras .... ", ks ),
):
    print( f"{label}: {module.__version__}" )

### 0.1 Class Declaration : MonoGAN

In [None]:
class MonoGAN(ks.Model):
    """Keras model that trains a discriminator and a generator together.

    Label convention used throughout ``train_step``: generated (fake) images
    are labelled 1 and real images are labelled 0.
    """

    def __init__(self, discriminator, generator, dimension):
        """Keep references to both sub-models and the latent vector size.

        Args:
            discriminator: Model called on image batches; its output is
                handed to the loss function as the predictions.
            generator: Model called on latent vectors to produce images.
            dimension: Length of each latent vector fed to the generator.
        """
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.dimension = dimension

    def compile(self, dr_opt, gr_opt, loss_function):
        """Register one optimizer per sub-model and the shared loss.

        Args:
            dr_opt: Optimizer applied to the discriminator's weights.
            gr_opt: Optimizer applied to the generator's weights.
            loss_function: Callable invoked as ``loss_function(labels, predictions)``.
        """
        super().compile()
        self.dr_opt = dr_opt
        self.gr_opt = gr_opt
        self.loss_function = loss_function

    def _sample_latent(self, count):
        """Draw ``count`` normally distributed latent vectors."""
        return tf.random.normal(shape=(count, self.dimension))

    def train_step(self, real_images):
        """Run one discriminator update followed by one generator update.

        Args:
            real_images: Batch of real images; when a tuple is received only
                its first element is used.

        Returns:
            Dict with the discriminator loss under ``"dr_loss"`` and the
            generator loss under ``"gr_loss"``.
        """
        if isinstance(real_images, tuple):
            real_images = real_images[0]

        n_samples = tf.shape(real_images)[0]

        # --- Discriminator phase -------------------------------------------
        # Fakes go first in the combined batch, so their labels (ones) go
        # first in the label tensor as well.
        fakes = self.generator(self._sample_latent(n_samples))
        mixed_batch = tf.concat([fakes, real_images], axis=0)
        targets = tf.concat(
            [tf.ones((n_samples, 1)), tf.zeros((n_samples, 1))],
            axis=0,
        )
        # Perturb the labels with a little uniform noise.
        targets = targets + 0.05 * tf.random.uniform(tf.shape(targets))

        with tf.GradientTape() as dr_tape:
            dr_scores = self.discriminator(mixed_batch)
            dr_loss = self.loss_function(targets, dr_scores)

        dr_weights = self.discriminator.trainable_weights
        dr_grads = dr_tape.gradient(dr_loss, dr_weights)
        self.dr_opt.apply_gradients(zip(dr_grads, dr_weights))

        # --- Generator phase -----------------------------------------------
        fresh_latent = self._sample_latent(n_samples)
        # Zeros mean "real" under this class's convention: the generator is
        # rewarded when the discriminator takes its output for real images.
        fooling_targets = tf.zeros((n_samples, 1))

        with tf.GradientTape() as gr_tape:
            gr_scores = self.discriminator(self.generator(fresh_latent))
            gr_loss = self.loss_function(fooling_targets, gr_scores)

        # Only the generator's weights are differentiated and updated here;
        # the discriminator is left untouched in this phase.
        gr_weights = self.generator.trainable_weights
        gr_grads = gr_tape.gradient(gr_loss, gr_weights)
        self.gr_opt.apply_gradients(zip(gr_grads, gr_weights))

        return {"dr_loss": dr_loss, "gr_loss": gr_loss}

### 0.2 Class Declaration : MonoGAN Results Viewer

In [None]:
class MonoGAN_Viewer(ks.callbacks.Callback):
    """Callback that saves generator samples as PNG files after every epoch."""

    def __init__(self, image_number=100, dimension=128):
        """Configure how many samples to save and the latent vector size.

        Args:
            image_number: Number of images generated and saved per epoch.
            dimension: Length of each latent vector passed to the generator.
        """
        # Let the Callback base class initialise its own state.
        super().__init__()
        self.image_number = image_number
        self.dimension = dimension

    def on_epoch_end(self, epoch, logs=None):
        """Generate ``image_number`` images and write each one to disk.

        Files are written relative to the current working directory as
        ``generated_img_{i}_{epoch}.png``.

        Args:
            epoch: Epoch index supplied by Keras; used in the file name.
            logs: Unused.
        """
        random_latent_vectors = tf.random.normal(
            shape=(self.image_number, self.dimension)
        )
        generated_images = self.model.generator(random_latent_vectors)
        # Keep the converted array: the original discarded the result of
        # `.numpy()`, leaving the conversion without effect.
        generated_images = (generated_images * 255).numpy()
        for i, image in enumerate(generated_images):
            img = ks.preprocessing.image.array_to_img(image)
            img.save(f"generated_img_{i}_{epoch}.png")

## 1. Set Up Model

### 1.0 Set Up Discriminator

In [None]:
# Discriminator: takes a 28x28x1 image and emits a single score with no
# output activation. Two stride-2 convolutions are followed by a global
# max-pool over the spatial axes and one dense unit.
dr = mls.Sequential(
    [
        lys.Input(shape=(28, 28, 1)),
        lys.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
        lys.LeakyReLU(alpha=0.2),
        lys.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),
        lys.LeakyReLU(alpha=0.2),
        lys.GlobalMaxPooling2D(),
        lys.Dense(1),
    ],
    name="Discriminator",
)

dr.summary()

### 1.1 Set Up Generator

In [None]:
# Length of the latent vector consumed by the generator.
dimension = 128

# Generator: projects a latent vector to a 7x7x128 tensor, applies two
# stride-2 transposed convolutions, then a final single-channel layer with
# a sigmoid activation.
gr = mls.Sequential(
    [
        lys.Input(shape=(dimension,)),
        lys.Dense(7 * 7 * 128),
        lys.LeakyReLU(alpha=0.2),
        lys.Reshape((7, 7, 128)),
        lys.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        lys.LeakyReLU(alpha=0.2),
        lys.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
        lys.LeakyReLU(alpha=0.2),
        lys.Conv2DTranspose(1, (7, 7), padding="same", activation="sigmoid"),
    ],
    name="Generator",
)

gr.summary()

## 2. Set Up MNIST Dataset

In [None]:
# Load both MNIST splits. The labels are unpacked but not used in this cell.
(tr_images, tr_labels), (ts_images, ts_labels) = mn.load_data()

batch_size = 64
# partial_size = 100

# Full run: train and test images are pooled into one array.
all_digits = np.concatenate([tr_images, ts_images])
# Partial run (swap in for the line above):
# all_digits = tr_images[:partial_size]

# Convert to float32, divide by 255 (presumably mapping 8-bit pixel values
# into [0, 1] - confirm against the loader), and add a trailing channel axis.
all_digits = (all_digits.astype("float32") / 255).reshape(-1, 28, 28, 1)

# Input pipeline: slice per image, shuffle, batch, prefetch.
dataset = tf.data.Dataset.from_tensor_slices(all_digits)
dataset = dataset.shuffle(buffer_size=1024)
dataset = dataset.batch(batch_size)
dataset = dataset.prefetch(32)

## 3. Create GAN Model

In [None]:
# Number of epochs passed to `fit` in the training cell.
epochs = 30

# Wrap the discriminator and generator built above into a single model.
gan = MonoGAN( dr, gr, dimension )

## 4. Compile GAN Model

In [None]:
# Each network gets its own Adam optimizer with the same learning rate.
# The loss is told to expect raw logits rather than probabilities.
gan.compile(
    dr_opt=ks.optimizers.Adam(learning_rate=3e-4),
    gr_opt=ks.optimizers.Adam(learning_rate=3e-4),
    loss_function=ks.losses.BinaryCrossentropy(from_logits=True),
)


## 4. Train Model

In [None]:
# Train the GAN; the viewer callback saves three generated samples as PNG
# files at the end of every epoch.
gan.fit(
    x=dataset,
    epochs=epochs,
    callbacks=[MonoGAN_Viewer(image_number=3, dimension=dimension)],
)

## 5. Evaluate Model

In [None]:
# Display the three images saved by the viewer callback at the last epoch.
# The figure is created once, before the loop: creating it inside the loop
# (as the original did) opened a new figure per image and left the subplot
# grid empty.
pt.figure(figsize=(30, 30))
for i in range(3):
    pt.subplot(3, 3, i + 1)
    # Callback epoch indices are zero-based, so the last epoch is
    # `epochs - 1` (29 for the default 30 epochs).
    im = pt.imread(f"generated_img_{i}_{epochs - 1}.png")
    # `cmap` only affects single-channel images; it is ignored for RGB(A).
    pt.imshow(im, cmap="gray")