In [1]:
!pip install tensorflow==2.0.0 -q

[K     |████████████████████████████████| 86.3MB 68kB/s 
[K     |████████████████████████████████| 450kB 43.9MB/s 
[K     |████████████████████████████████| 3.8MB 34.8MB/s 
[?25h  Building wheel for gast (setup.py) ... [?25l[?25hdone


In [0]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Flatten, LeakyReLU, Dropout, BatchNormalization
from tensorflow.keras.optimizers import SGD, Adam

import os

In [0]:
from tensorflow.keras.datasets import mnist

In [4]:
(x_train, y_train), (x_test, y_test) = mnist.load_data()

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [11]:
x_train.shape

(60000, 784)

In [0]:
x_train = x_train/255.0 *2 - 1
x_test = x_test/255.0 *2 - 1

In [0]:
#Flatten the data
N, H, W = x_train.shape
D = H*W
x_train = x_train.reshape(-1, D)
x_test = x_test.reshape(-1, D)

In [0]:
latent_dim = 100

In [0]:
#Generator Model
def build_generator(latent_dim):
  i = Input((latent_dim,))
  x = Dense(256, activation=LeakyReLU(alpha = 0.2))(i)
  x = BatchNormalization(momentum=0.7)(x)
  x = Dense(512, activation=LeakyReLU(alpha = 0.2))(x)
  x = BatchNormalization(momentum=0.7)(x)
  x = Dense(1024, activation=LeakyReLU(alpha = 0.2))(x)
  x = BatchNormalization(momentum=0.7)(x)
  x = Dense(D, activation='tanh')(x)
  model = Model(i, x)
  return model

In [0]:
#Discriminator Model
def build_discriminator(img_size):
  i = Input((img_size,))
  x = Dense(512, activation=LeakyReLU(alpha = 0.2))(i)
  x = Dense(256, activation=LeakyReLU(alpha= 0.2))(x)
  x = Dense(1, activation='sigmoid')(x)
  model = Model(i,x)
  return model

In [0]:
#Compiling the model
discriminator = build_discriminator(D)
discriminator.compile(
    loss = 'binary_crossentropy',
    optimizer = Adam(learning_rate=0.0002, beta_1=0.5),
    metrics = ['accuracy']
)

generator = build_generator(latent_dim)
z = Input(shape=(latent_dim,))
img = generator(z)
discriminator.trainable = False

fake_pred = discriminator(img)

combined_model = Model(z, fake_pred)

In [0]:
combined_model.compile(
    loss = 'binary_crossentropy',
    optimizer = Adam(0.0002, 0.5),
    metrics = ['accuracy']

)

In [0]:
#Training the GAN
batch_size = 32
epochs = 30000
sample_period = 200

In [0]:
#creating batch labels to use when calling train_on_batch

ones = np.ones(batch_size)
zeros = np.zeros(batch_size)
d_losses = []
g_losses = []

In [0]:
#create a folder to store the images of GAN
if not os.path.exists('gan_images'):
  os.makedirs('gan_images')


In [0]:
#A function to generate a grid of random samples from the sequence

def sample_images(epoch):
  rows, cols = 5,5
  noise = np.random.randn(rows*cols, latent_dim)
  imgs = generator.predict(noise)

  #rescale the images 0-1
  imgs = 0.5*imgs + 0.5
  fig, axs = plt.subplot(rows, cols)
  idx = 0
  for i in range(rows):
    for j in range(cols):
      axs[i,j].imshow(imgs[idx].reshape(H,W), cmap='gray')
      axs[i,j].axis('off')
      idx += 1
  fig.savefig('gan_images/%d.png' % epoch)
  plt.close()

In [44]:
#Main training loop

#Training discriminator

for epoch in range(epochs):
  #selecting random batch of images
  idx = np.random.randint(0, x_train.shape[0], batch_size)
  real_img = x_train[idx]

  #generating fake images
  noise = np.random.randn(batch_size, latent_dim)
  fake_img = generator.predict(noise)

  #Training the discriminator returning both loss and accuracy
  d_loss_real, d_acc_real = discriminator.train_on_batch(real_img, ones)
  d_loss_fake, d_acc_fake = discriminator.train_on_batch(fake_img, zeros)

  d_loss = 0.5*(d_loss_real + d_loss_fake)
  d_acc = 0.5*(d_acc_fake + d_acc_real)

  #Training the generator
  noise  = np.random.randn(batch_size, latent_dim)
  g_loss = combined_model.train_on_batch(noise, ones)

  #saving the losses
  d_losses.append(d_loss)
  g_losses.append(g_loss)

  if epoch%100 == 0:
    print("epoch: " + str((epoch+1)/epochs) + f", d_loss: {d_loss:.2f}, \
      d_acc: {d_acc:.2f}, g_loss: {g_loss:.2f}")
  
  if epoch%sample_period == 0:
    sample_images(epoch)



TypeError: ignored