<a href="https://colab.research.google.com/github/jivemachine/DCGAN/blob/main/DCGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Imports

In [1]:
import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns

import tensorflow as tf
from tensorflow import keras

In [2]:
# Load Fashion-MNIST from Keras' bundled datasets.
(X_train_full, y_train_full), (X_test, y_test) = keras.datasets.fashion_mnist.load_data()
# Cast to float32 and divide by 255 so pixel values land in the unit interval.
X_train_full, X_test = (pixels.astype(np.float32) / 255 for pixels in (X_train_full, X_test))
# Hold out the last 5000 training examples (images and labels alike) for validation.
X_valid, X_train = X_train_full[-5000:], X_train_full[:-5000]
y_valid, y_train = y_train_full[-5000:], y_train_full[:-5000]

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz


In [5]:
coding_size = 100  # length of the random noise vector fed to the generator
# Building the generator: project the noise vector to a 7x7x128 feature map,
# then upsample twice with stride-2 transposed convolutions (7 -> 14 -> 28).
# The intermediate transposed convolution uses 64 filters; the previous value
# of 1 squeezed the 128-channel feature map down to a single channel before the
# output layer, leaving the generator almost no capacity.
# The final layer emits one channel through tanh, so samples lie in [-1, 1],
# the same range the training images are rescaled to before training.
generator = keras.models.Sequential([
  keras.layers.Dense(7 * 7 * 128, input_shape=[coding_size]),
  keras.layers.Reshape([7, 7, 128]),
  keras.layers.BatchNormalization(),
  keras.layers.Conv2DTranspose(64, kernel_size=5, strides=2, padding='same', activation='selu'),
  keras.layers.BatchNormalization(),
  keras.layers.Conv2DTranspose(1, kernel_size=5, strides=2, padding='same', activation='tanh'),
])


In [6]:
# Building the discriminator: a small CNN mapping a 28x28x1 image to a single
# sigmoid score (probability the image is real). Two stride-2 5x5 convolutions,
# each with its own LeakyReLU(0.2) activation and followed by 40% dropout,
# downsample the image before the flattened features reach the output unit.
discriminator = keras.models.Sequential()
discriminator.add(keras.layers.Conv2D(
  filters=64, kernel_size=5, strides=2, padding='same',
  activation=keras.layers.LeakyReLU(0.2), input_shape=[28, 28, 1]))
discriminator.add(keras.layers.Dropout(rate=0.4))
discriminator.add(keras.layers.Conv2D(
  filters=128, kernel_size=5, strides=2, padding='same',
  activation=keras.layers.LeakyReLU(0.2)))
discriminator.add(keras.layers.Dropout(rate=0.4))
discriminator.add(keras.layers.Flatten())
discriminator.add(keras.layers.Dense(units=1, activation='sigmoid'))

In [7]:
# Chain generator -> discriminator: noise goes in, a "looks real" score comes out.
gan = keras.models.Sequential(layers=[generator, discriminator])

In [8]:
# Compile the discriminator on its own while it is still trainable, so that
# discriminator.train_on_batch() can update its weights. Then freeze it and
# compile the combined gan, so gan.train_on_batch() is meant to update only the
# generator (tf.keras takes the trainable flag into account at compile time).
discriminator.compile(optimizer="rmsprop", loss="binary_crossentropy")
discriminator.trainable = False
gan.compile(optimizer="rmsprop", loss="binary_crossentropy")

In [9]:
# Rescale the training images from [0, 1] to [-1, 1] to match the generator's
# tanh output, and add the trailing channel axis the discriminator's
# input_shape=[28, 28, 1] expects.
# Build from the untouched X_train_full (minus the validation tail) instead of
# rescaling X_train in place. The in-place version was not idempotent: running
# this cell a second time rescaled again and pushed the data to [-3, 1].
X_train = X_train_full[:len(X_train_full) - len(X_valid)].reshape(-1, 28, 28, 1) * 2. - 1.
assert X_train.min() >= -1. and X_train.max() <= 1., "X_train must lie in [-1, 1]"

In [10]:
# Input pipeline for the custom GAN training loop: shuffle the training images
# through a 1000-element buffer, cut them into batches of batch_size, and
# prefetch one batch ahead. drop_remainder=True discards the ragged final
# batch, so every batch has exactly batch_size images, which the training
# loop's label tensors rely on.
batch_size = 32
dataset = (
    tf.data.Dataset.from_tensor_slices(X_train)
    .shuffle(1000)
    .batch(batch_size, drop_remainder=True)
    .prefetch(1)
)

In [11]:
# Helper for eyeballing generator samples: draw a batch of images on a grid.
def plot_multiple_images(images, n_cols=None):
    """Show ``images`` side by side on a grid of subplots.

    Args:
        images: array-like batch of images supporting ``len()`` and ``.shape``
            (e.g. a NumPy array or tensor). A trailing channel axis of size 1
            is squeezed away so single-channel images can go to ``imshow``.
        n_cols: images per grid row; when falsy (the default) every image is
            placed on a single row.

    Draws into a new matplotlib figure with axes hidden; it does not call
    ``plt.show()`` itself.
    """
    cols = n_cols if n_cols else len(images)
    rows = (len(images) - 1) // cols + 1
    if images.shape[-1] == 1:
        # drop the singleton channel axis: (..., H, W, 1) -> (..., H, W)
        images = np.squeeze(images, axis=-1)
    plt.figure(figsize=(cols, rows))  # figure size in inches: one per image
    for slot, picture in enumerate(images, start=1):
        plt.subplot(rows, cols, slot)
        plt.imshow(picture, cmap="binary")
        plt.axis("off")

In [12]:
# training loop
def train_gan(gan, dataset, batch_size, size, n_epochs=50):
  """Train a GAN with the alternating two-phase scheme.

  Each batch runs two phases:
    1. Train the discriminator on generated images (label 0) concatenated with
       the real batch (label 1).
    2. Train the generator through ``gan`` with the discriminator frozen,
       labelling fresh fakes as real (1) so the generator learns to fool it.

  Args:
    gan: model whose ``layers`` unpack as ``(generator, discriminator)``.
    dataset: iterable of real-image batches; every batch is assumed to hold
      exactly ``batch_size`` images (the label tensors are sized for that).
    batch_size: number of real and of fake images per step.
    size: length of the noise vector fed to the generator.
    n_epochs: number of passes over ``dataset``.

  After each epoch, the last generated batch is plotted with
  ``plot_multiple_images`` (8 per row) and shown.
  """
  generator, discriminator = gan.layers
  # The labels do not depend on the batch contents, so build them once.
  y1 = tf.constant([[0.]] * batch_size + [[1.]] * batch_size)  # fakes first, then reals
  y2 = tf.constant([[1.]] * batch_size)  # generator wants its fakes judged real
  for epoch in range(n_epochs):
    print("Epoch {}/{}".format(epoch + 1, n_epochs))
    generated_images = None  # stays None if the dataset yields no batches
    for X_batch in dataset:
      # phase 1 - training the discriminator
      noise = tf.random.normal(shape=[batch_size, size])
      generated_images = generator(noise)
      X_fake_and_real = tf.concat([generated_images, X_batch], axis=0)
      # Must be True here: the discriminator is the model being trained in
      # this phase (previously set to False by mistake).
      discriminator.trainable = True
      discriminator.train_on_batch(X_fake_and_real, y1)
      # phase 2 - training the generator (discriminator frozen)
      noise = tf.random.normal(shape=[batch_size, size])
      discriminator.trainable = False
      gan.train_on_batch(noise, y2)
    # Guard against an empty dataset, which previously raised NameError here.
    if generated_images is not None:
      plot_multiple_images(generated_images, 8)
      plt.show()