<a href="https://colab.research.google.com/github/milindsoorya/colab-notebooks/blob/main/GAN_on_Fashion_MNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import tensorflow as tf

# Fix random seeds so runs are reproducible.
# - TensorFlow's global seed covers tf.random ops (training noise), the
#   tf.data shuffle and Keras weight initialisation.
# - NumPy's seed covers the noise drawn with np.random in the plotting helper.
SEED = 1
tf.random.set_seed(SEED)
np.random.seed(SEED)

In [None]:
from tensorflow.keras.datasets import fashion_mnist

# load_data() returns two (images, labels) pairs: the train split and the test split.
train_pair, test_pair = fashion_mnist.load_data()
x_train, y_train = train_pair
x_test, y_test = test_pair
del train_pair, test_pair  # keep only the four arrays in the namespace

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz


In [None]:
import numpy as np

# Labels are not needed to train a GAN, so the train and test images are pooled.
# A trailing channel axis is appended (28x28 -> 28x28x1: one grayscale channel,
# the rank-4 layout the convolution layers expect once batched), and pixel
# values are scaled from 0..255 into [0, 1] as float32.
dataset = (
    np.concatenate((x_train, x_test))[..., np.newaxis]
    .astype(np.float32)
    / 255
)

In [None]:
BATCH_SIZE = 64
SHUFFLE_BUFFER = 1024

# The previous cell already produced an (N, 28, 28, 1) float32 array, so no
# reshape is needed here; fail loudly if that preprocessing ever changes.
assert dataset.shape[1:] == (28, 28, 1), dataset.shape

# Wrap the array in a tf.data pipeline:
#   shuffle  - randomise image order each epoch,
#   batch    - otherwise the dataset yields one image at a time,
#   prefetch - prepare the next batch while the current one is training.
dataset = (
    tf.data.Dataset.from_tensor_slices(dataset)
    .shuffle(buffer_size=SHUFFLE_BUFFER)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
from tensorflow import keras
from tensorflow.keras import layers

# Length of the noise vector fed to the generator (a tunable hyperparameter)
NOISE_DIM = 150

# Generator: noise vector -> 7x7x256 feature map -> two 2x upsamplings -> 28x28x1 image.
# LeakyReLU is used as an explicit layer rather than the string
# activation="LeakyReLU", which depends on Keras resolving a layer class by
# name. A slope of 0.3 matches the Keras 2 LeakyReLU default.
generator = keras.models.Sequential([
  keras.layers.InputLayer(input_shape=(NOISE_DIM,)),
  layers.Dense(7 * 7 * 256),
  # Without a non-linearity here, the Dense layer and the first
  # Conv2DTranspose compose into a single linear map.
  layers.LeakyReLU(0.3),
  layers.Reshape(target_shape=(7, 7, 256)),
  layers.Conv2DTranspose(256, 3, strides=2, padding="same"),  # -> 14x14x256
  layers.LeakyReLU(0.3),
  layers.Conv2DTranspose(128, 3, strides=2, padding="same"),  # -> 28x28x128
  layers.LeakyReLU(0.3),
  # Sigmoid keeps pixels in [0, 1], the same range as the normalised real images.
  layers.Conv2DTranspose(1, 3, activation="sigmoid", padding="same"),  # -> 28x28x1
])

generator.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense (Dense)               (None, 12544)             14438144  
                                                                 
 reshape (Reshape)           (None, 7, 7, 256)         0         
                                                                 
 conv2d_transpose (Conv2DTra  (None, 14, 14, 256)      590080    
 nspose)                                                         
                                                                 
 conv2d_transpose_1 (Conv2DT  (None, 28, 28, 128)      295040    
 ranspose)                                                       
                                                                 
 conv2d_transpose_2 (Conv2DT  (None, 28, 28, 1)        1153      
 ranspose)                                                       
                                                        

In [None]:
# Discriminator: 28x28x1 image -> probability that the image is real.
# Strided Conv2D layers *downsample* 28x28 -> 14x14 -> 7x7. The previous
# Conv2DTranspose layers upsampled to 112x112 instead, which made the Flatten
# output ~1.6M wide and the following Dense layer ~100M parameters.
discriminator = keras.models.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28, 1)),
  layers.Conv2D(256, 3, activation="relu", strides=2, padding="same"),  # -> 14x14x256
  layers.Conv2D(128, 3, activation="relu", strides=2, padding="same"),  # -> 7x7x128
  layers.Flatten(),
  layers.Dense(64, activation="relu"),
  layers.Dropout(0.2),
  # Sigmoid output: a probability in [0, 1], not a logit.
  layers.Dense(1, activation="sigmoid")
])

discriminator.summary()

Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_transpose_3 (Conv2DT  (None, 56, 56, 256)      2560      
 ranspose)                                                       
                                                                 
 conv2d_transpose_4 (Conv2DT  (None, 112, 112, 128)    295040    
 ranspose)                                                       
                                                                 
 flatten (Flatten)           (None, 1605632)           0         
                                                                 
 dense_1 (Dense)             (None, 64)                102760512 
                                                                 
 dropout (Dropout)           (None, 64)                0         
                                                                 
 dense_2 (Dense)             (None, 1)                

In [None]:
# Separate optimizers for the two networks; the discriminator uses a 3x larger learning rate.
optimizerG = keras.optimizers.Adam(learning_rate=0.00001, beta_1=0.5)
optimizerD = keras.optimizers.Adam(learning_rate=0.00003, beta_1=0.5)

# Binary real/fake loss. The discriminator's last layer already applies a
# sigmoid, so its outputs are probabilities. from_logits=True would apply a
# second sigmoid inside the loss.
lossFn = keras.losses.BinaryCrossentropy(from_logits=False)

# Accuracy metrics. The default 0.5 threshold is meant for probabilities,
# which is what the discriminator outputs.
gAccMetric = tf.keras.metrics.BinaryAccuracy()
dAccMetric = tf.keras.metrics.BinaryAccuracy()

In [None]:
@tf.function
def trainDStep(data):
  """Run one discriminator update on a batch of real images.

  Args:
    data: a batch of real images from `dataset` (assumed (batch, 28, 28, 1)
      float32, per the preprocessing cells).

  Returns:
    dict with the scalar "discriminator_loss" for this batch and the running
    "discriminator_accuracy" accumulated in `dAccMetric`.
  """
  batchSize = tf.shape(data)[0]
  # Generator input: noise sampled from a standard normal distribution
  # (a uniform distribution is a possible alternative to try).
  noise = tf.random.normal(shape=(batchSize, NOISE_DIM))

  # The generator is not updated in this step, so its forward pass can run
  # outside the gradient tape.
  fake = generator(noise, training=False)

  # Labels: real images -> 1, generated images -> 0, shaped (2 * batchSize, 1)
  # to match the discriminator's (N, 1) output. The shape must be passed as a
  # single tuple: tf.ones/tf.zeros treat a second positional arg as the dtype.
  y_true = tf.concat([
      tf.ones((batchSize, 1)),
      tf.zeros((batchSize, 1)),
  ], axis=0)

  with tf.GradientTape() as tape:
    # Real images first, then fakes, matching the label order above.
    x = tf.concat([data, fake], axis=0)
    # training=True so the Dropout layer is active while the discriminator is trained.
    y_pred = discriminator(x, training=True)
    discriminatorLoss = lossFn(y_true, y_pred)

  # Backpropagate and update only the discriminator's weights.
  grads = tape.gradient(discriminatorLoss, discriminator.trainable_weights)
  optimizerD.apply_gradients(zip(grads, discriminator.trainable_weights))

  # Running accuracy (reset by the caller at the start of each epoch)
  dAccMetric.update_state(y_true, y_pred)

  return {
      "discriminator_loss": discriminatorLoss,
      "discriminator_accuracy": dAccMetric.result()
  }

In [None]:
@tf.function
def trainGStep(data):
  """Run one generator update.

  Args:
    data: a batch of real images. Only its batch size is used, so the
      generator produces the same number of samples per step.

  Returns:
    dict with the scalar "generator_loss" for this batch and the running
    "generator_accuracy" accumulated in `gAccMetric`. This is the fraction of
    fakes the discriminator labelled as real.
  """
  batchSize = tf.shape(data)[0]
  noise = tf.random.normal(shape=(batchSize, NOISE_DIM))

  # The generator's target is to have its fakes classified as real (1).
  # The shape is a single tuple; a second positional arg would be the dtype.
  y_true = tf.ones((batchSize, 1))

  with tf.GradientTape() as tape:
    y_pred = discriminator(generator(noise, training=True), training=True)
    generatorLoss = lossFn(y_true, y_pred)

  # Update the GENERATOR's weights. The discriminator is held fixed in this step.
  grads = tape.gradient(generatorLoss, generator.trainable_weights)
  optimizerG.apply_gradients(zip(grads, generator.trainable_weights))

  # Running accuracy (reset by the caller at the start of each epoch)
  gAccMetric.update_state(y_true, y_pred)

  return {
      "generator_loss": generatorLoss,
      "generator_accuracy": gAccMetric.result()
  }

In [None]:
from matplotlib import pyplot as plt

def plotImages(model, grid_size=9):
  """Sample images from `model` and display them in a grid_size x grid_size grid.

  Args:
    model: callable mapping a (grid_size**2, NOISE_DIM) noise batch to images
      with a trailing single-channel axis, e.g. the generator.
    grid_size: rows and columns of the grid (default 9 -> 81 images).
  """
  num_images = grid_size * grid_size
  noise = np.random.normal(size=(num_images, NOISE_DIM)).astype("float32")
  images = np.asarray(model(noise))

  plt.figure(figsize=(grid_size, grid_size))

  for i, image in enumerate(images):
    plt.subplot(grid_size, grid_size, i + 1)
    # Drop the channel axis: imshow expects (H, W) for a grayscale image.
    plt.imshow(np.squeeze(image, -1), cmap="Greys_r")
    plt.axis('off')

  plt.show()

In [None]:
EPOCHS = 30

for epoch in range(EPOCHS):
  # The accuracy metrics are running means; reset them so each epoch reports
  # its own accuracy rather than one accumulated since the first epoch.
  dAccMetric.reset_state()
  gAccMetric.reset_state()

  # Accumulate the per-batch losses to report their epoch average
  dLossSum = 0.0
  gLossSum = 0.0
  cnt = 0

  # Loop over the dataset one batch at a time
  for batch in dataset:
    # Train the discriminator, then the generator, on each batch
    dLossSum += trainDStep(batch)['discriminator_loss']
    gLossSum += trainGStep(batch)['generator_loss']
    cnt += 1

  # Log the performance. The metric results already average over the whole epoch.
  print("E:{}, Loss G:{:0.4f}, Loss D:{:0.4f}, Acc G:%{:0.2f}, Acc D:%{:0.2f}".format(
      epoch,
      float(gLossSum) / cnt,
      float(dLossSum) / cnt,
      100 * float(gAccMetric.result()),
      100 * float(dAccMetric.result())
  ))

  if epoch % 2 == 0:
    plotImages(generator)

KeyboardInterrupt: ignored

In [None]:
# Generate a grid of samples with the trained generator.
# If the samples all look alike (the same or similar clothing class), the GAN
# is showing "mode collapse", a common GAN failure mode.
# This reuses the plotImages helper instead of duplicating its plotting code.
# The duplicate code had rebound the name `plotImages` to matplotlib.pyplot
# and used an invalid colormap name.
plotImages(generator)