<a href="https://colab.research.google.com/github/jivemachine/GAN/blob/main/Simple_GAN_notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# imports

In [1]:
import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns

import tensorflow as tf
from tensorflow import keras

In [2]:
# Fashion-MNIST ships with keras: fetch its train/test split
(X_train_full, y_train_full), (X_test, y_test) = keras.datasets.fashion_mnist.load_data()

# cast the images to float32 and divide by 255 to rescale the pixel values
X_train_full = X_train_full.astype(np.float32) / 255
X_test = X_test.astype(np.float32) / 255

# keep the last 5000 training samples (and their labels) aside for validation
X_train = X_train_full[:-5000]
X_valid = X_train_full[-5000:]
y_train = y_train_full[:-5000]
y_valid = y_train_full[-5000:]

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz


In [3]:
# building the generator
size = 30  # length of the random noise vector the generator takes as input

# noise vector (length `size`) -> 100 -> 150 -> 28 * 28 values, reshaped to 28x28;
# the sigmoid on the last Dense layer keeps every generated pixel between 0 and 1
generator = keras.models.Sequential()
generator.add(keras.layers.Dense(100, activation='selu', input_shape=[size]))
generator.add(keras.layers.Dense(150, activation='selu'))
generator.add(keras.layers.Dense(28 * 28, activation='sigmoid'))
generator.add(keras.layers.Reshape([28, 28]))

In [4]:
# building the discriminator
# 28x28 image -> flattened -> 150 -> 100 -> a single sigmoid output unit
discriminator = keras.models.Sequential()
discriminator.add(keras.layers.Flatten(input_shape=[28,28]))
discriminator.add(keras.layers.Dense(150, activation='selu'))
discriminator.add(keras.layers.Dense(100, activation='selu'))
discriminator.add(keras.layers.Dense(1, activation='sigmoid'))

In [5]:
# stack the two models: the generator's output is fed straight into the discriminator
gan = keras.models.Sequential(layers=[generator, discriminator])

In [7]:
# compiling the models
# The discriminator is compiled first, while it is still trainable, so that it
# can be trained on its own with `discriminator.train_on_batch`.
discriminator.compile(loss='binary_crossentropy', optimizer='rmsprop')
# Freeze the discriminator before compiling the combined model: training `gan`
# is meant to update the generator only.
# NOTE(review): the statement order matters here - confirm how `trainable`
# interacts with `compile()` in the installed Keras version.
discriminator.trainable = False
gan.compile(loss='binary_crossentropy', optimizer='rmsprop')

In [8]:
# input pipeline for the custom training loop: shuffle the training images,
# then cut them into batches of exactly `batch_size`
batch_size = 32
dataset = (
    tf.data.Dataset.from_tensor_slices(X_train)
    .shuffle(1000)
    # drop_remainder=True: the training loop builds its labels from batch_size,
    # so a smaller final batch would not line up with them
    .batch(batch_size, drop_remainder=True)
    .prefetch(1)
)


In [9]:
# training loop
def train_gan(gan, dataset, batch_size, size, n_epochs=50):
  """Train a GAN by alternating discriminator and generator updates.

  For every batch of real images the loop runs two phases:
    1. the discriminator is trained on `batch_size` generated images
       (target 0) followed by the real batch (target 1);
    2. the generator is trained through `gan` on fresh noise with target 1,
       while the discriminator is flagged as non-trainable.

  Args:
    gan: compiled model whose `layers` are exactly (generator, discriminator);
      the discriminator must also be compiled on its own.
    dataset: iterable of real-image batches. Every batch must hold exactly
      `batch_size` images, because the targets are built from `batch_size`.
    batch_size: number of real images per batch, and of fake images generated
      for each training step.
    size: length of each random noise vector fed to the generator.
    n_epochs: number of passes over `dataset`.
  """
  generator, discriminator = gan.layers
  # the targets are the same for every batch, so build them once up front
  y1 = tf.constant([[0.]] * batch_size + [[1.]] * batch_size)  # fakes -> 0, reals -> 1
  y2 = tf.constant([[1.]] * batch_size)  # the generator wants its fakes judged real
  for epoch in range(n_epochs):
    for X_batch in dataset:
      # phase 1 - training the discriminator
      noise = tf.random.normal(shape=[batch_size, size])
      generated_images = generator(noise)
      # fakes first, then reals: must match the order of the targets in y1
      X_fake_and_real = tf.concat([generated_images, X_batch], axis=0)
      # the discriminator has to be trainable while it is being trained
      # (this flag was previously left at False, so it was never re-enabled)
      discriminator.trainable = True
      discriminator.train_on_batch(X_fake_and_real, y1)
      # phase 2 - training the generator
      noise = tf.random.normal(shape=[batch_size, size])
      # freeze the discriminator so this step only updates the generator
      discriminator.trainable = False
      gan.train_on_batch(noise, y2)