In [10]:
import tensorflow as tf

In [11]:
from tensorflow.keras.layers import Input, Dense, LeakyReLU, Dropout, \
BatchNormalization
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import SGD, Adam

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import sys, os

In [12]:
# load the data
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# map the inputs between (-1,1) for better training to center around zero
x_train, x_test = x_train / 255.0 * 2 - 1, x_test / 255.0 * 2 - 1
print("x_train shape: ", x_train.shape)

x_train shape:  (60000, 28, 28)


In [13]:
# flatten the data to be tabular
N, H, W = x_train.shape
D = H * W
x_train = x_train.reshape(-1, D)
x_test = x_test.reshape(-1, D)

In [21]:
# dimentionality of the latent space
latent_dim = 100

In [15]:
# get the generator model
def build_generator(latent_dim):
  i = Input(shape=(latent_dim,))
  x = Dense(256, activation = LeakyReLU(alpha=0.2))(i)
  x = BatchNormalization(momentum=0.8)(x)
  x = Dense(512, activation = LeakyReLU(alpha=0.2))(x)
  x = BatchNormalization(momentum=0.8)(x)
  x = Dense(1024, activation = LeakyReLU(alpha=0.2))(x)
  x = BatchNormalization(momentum=0.8)(x)
  x = Dense(1, activation = 'tanh')(x) #since the image pixels are between -1,1

  model = Model(i,x)
  return model


In [16]:
def build_discriminator(img_size):
  i = Input(shape=(img_size,))
  x = Dense(512, activation = LeakyReLU(alpha=0.2))(i)
  x = Dense(256, activation = LeakyReLU(alpha=0.2))(x)
  x = BatchNormalization(momentum=0.8)(x)
  x = Dense(1, activation = 'sigmoid')(x) #since this is binary classification
  model = Model(i,x)
  return model

In [22]:
# compile both models to prep for training

#build and compile the discriminator
discriminator = build_discriminator(D)
discriminator.compile(
    loss = 'binary_crossentropy',
    optimizer = Adam(0.0002, 0.5),
    metrics = ['accuracy']
)

#build and compile the generator
generator = build_generator(latent_dim)

#create input for noise from latent space
z = Input(shape=(latent_dim,))

#pass noise through generator to get an image
img = generator(z)

#make sure only the generator gets trained
discriminator.trainable = False

#the true output is fake, but we label them as real
fake_pred = discriminator(img)

#create the cobined model object
combined_model = Model(z, fake_pred)

#compile the combined model
combined_model.compile(loss='binary_crossentropy',optimizer=Adam(0.0002, 0.5))



ValueError: ignored

In [18]:
# train the GAN

#config
batch_size = 32
epochs = 30000
sample_period = 200

#create batch labels to use on calling tain_on_batch
ones=np.ones(batch_size)
zeros=np.zeros(batch_size)

#store the losses
g_losses =  []
d_losses = []

# create a folder to store the images
if not os.path.exists('gan_images'):
  os.makedirs('gan_images')

In [19]:
# function to generate a grid of random samples of 25 images from the generator
def sample_images(epoch):
  rows, cols = 5, 5
  noise = np.random.randn(rows*cols, latent_dim)
  imgs = generator.predict(noise)

  #rescale images between 0,1
  imgs = imgs * 0.5 + 0.5

  fig, axs = plt.subplots(rows, cols)
  idx = 0
  for i in range(rows):
    for j in range(cols):
      axs[i,j].imshow(imgs[idx].reshape(H,W), cmap='gray')
      axs[i,j].axis('off')
      idx += 1
  fig.savefig("gan_images/%d.png" %epoch)
  plt.close()

In [20]:
#main training loop
for epoch in range(epochs):
  #train discriminator
  #select a random batch of images
  idx = np.random.randint(0, x_train.shape[0], batch_size)
  real_imgs = x_train[idx]

  #generate fake images
  noise = np.random.randn(batch_size, latent_dim)
  fake_imgs = generator.predict(noise)

  #train the discriminator
  d_loss_real, d_acc_real = discriminator.train_on_batch(real_imgs, ones)
  d_loss_fake, d_acc_fake = discriminator.train_on_batch(fake_imgs, zeros)
  d_loss = 0.5 * (d_loss_real + d_loss_fake)
  d_acc = 0.5 * (d_acc_real + d_acc_fake)

  #train generator
  noise = np.random.randn(batch_size, latent_dim)
  g_loss = combined_model.train_on_batch(noise, ones)

  #save the losses
  d_losses.append(d_loss)
  g_losses.append(g_loss)

  if epoch % 100 == 0:
    print(f"epoch: {epoch+1}/{epochs}, d_loss: {d_loss:.2f}, \
    d_acc: {d_acc:.2f}, g_loss: {g_loss:.2f}")

  if epoch % sample_period == 0:
    sample_images(epoch)




InvalidArgumentError: ignored

In [None]:
plt.plot(g_losses, label='g_losses')
plt.plot(d_losses, label='d_losses')
plt.legend()

In [None]:
!ls gan_images

In [None]:
from skimage.io import imread
a = imread('gan_images/0.png')
plt.imshow(a)

In [None]:
a = imread('gan_images/1000.png')
plt.imshow(a)

In [None]:
a = imread('gan_images/5000.png')
plt.imshow(a)

In [None]:
a = imread('gan_images/10000.png')
plt.imshow(a)

In [None]:
a = imread('gan_images/20000.png')
plt.imshow(a)

In [None]:
a = imread('gan_images/29800.png')
plt.imshow(a)