<a href="https://colab.research.google.com/github/jmerceron/AI-Activity/blob/main/Julien_GAN_Handwritten_Digit_Generation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

INITIALIZATION

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

import keras
from keras.models import Model, Sequential
from keras.layers import Dense, BatchNormalization, Reshape, Dropout, LeakyReLU, Input, Flatten
from keras.optimizers import Adam
from keras.datasets import mnist
from keras.utils import plot_model

In [None]:
# Parameters
# NOTE: each "epoch" in the training loop is ONE mini-batch update,
# not a full pass over the MNIST training set.
epochs = 20_000            # number of training iterations
batch_size = 128           # generator batch; D sees half real + half fake per step
save_every = 1_000         # iterations between saved sample grids

mnist_shape = (28, 28, 1)  # (height, width, channels) of one MNIST image
noise_shape = (100,)       # shape of one latent noise vector fed to G

GENERATOR

In [None]:
def build_generator(noise_shape, mnist_shape):
    """Build the MLP generator that maps a noise vector to an MNIST-shaped image.

    Architecture: three Dense stages (256 -> 512 -> 1024 units), each followed
    by LeakyReLU(0.2) and BatchNormalization(momentum=0.8), then a tanh Dense
    layer reshaped to ``mnist_shape``. The tanh output lies in [-1, 1], which
    matches the [-1, 1] scaling applied to ``X_train`` in the training cell.

    Args:
        noise_shape: shape tuple of one latent vector, e.g. ``(100,)``.
        mnist_shape: shape tuple of one output image, e.g. ``(28, 28, 1)``.

    Returns:
        A ``keras.Model`` mapping noise inputs to images of shape
        ``mnist_shape``. Its summary is printed as a side effect.
    """
    noise = Input(shape=noise_shape)

    # The functional Input already fixes the input shape, so no layer needs an
    # `input_shape=` argument (Keras 3 warns when one is passed).
    x = noise
    for units in (256, 512, 1024):
        x = Dense(units)(x)
        x = LeakyReLU(alpha=0.2)(x)
        x = BatchNormalization(momentum=0.8)(x)

    # np.prod returns a NumPy integer; cast to a plain int to be safe for Dense.
    x = Dense(int(np.prod(mnist_shape)), activation='tanh')(x)
    img = Reshape(mnist_shape)(x)

    model = Model(noise, img)

    # print model summary
    model.summary()

    # Return the model itself instead of wrapping it in a second, redundant Model.
    return model

In [None]:
# Instantiate the generator (build_generator prints its summary).
G = build_generator(noise_shape=noise_shape, mnist_shape=mnist_shape)

DISCRIMINATOR

In [None]:
def build_discriminator(mnist_shape):
    """Build the MLP discriminator that scores images as real (1) or fake (0).

    The image is flattened, passed through Dense(512) and Dense(256) layers
    with LeakyReLU(0.2), and reduced to a single sigmoid unit.

    Args:
        mnist_shape: shape tuple of one input image, e.g. ``(28, 28, 1)``.

    Returns:
        A ``keras.Model`` mapping a batch of images to a ``(batch, 1)``
        sigmoid score. Its summary is printed as a side effect.
    """
    # Named img_input rather than `input` so the builtin is not shadowed.
    img_input = Input(shape=mnist_shape)

    x = Flatten()(img_input)
    for units in (512, 256):
        x = Dense(units)(x)
        x = LeakyReLU(alpha=0.2)(x)
    validity = Dense(1, activation='sigmoid')(x)

    model = Model(img_input, validity)
    model.summary()

    # Return the model itself instead of wrapping it in a second, redundant Model.
    return model

In [None]:
# Instantiate the discriminator (build_discriminator prints its summary).
D = build_discriminator(mnist_shape=mnist_shape)

COMPILE

In [None]:
# Both networks use Adam with learning rate 2e-4 and beta_1 = 0.5.
# The training loop only calls train_on_batch on D; G is updated through
# the combined model built in the next cell.
G.compile(optimizer=Adam(learning_rate=0.0002, beta_1=0.5), loss='binary_crossentropy')
D.compile(
    optimizer=Adam(learning_rate=0.0002, beta_1=0.5),
    loss='binary_crossentropy',
    metrics=['accuracy'],
)

BUILD GAN (noise -> G -> D)

In [None]:
# Combined model: noise -> G -> D. Training it with "real" labels pushes G
# toward samples that D classifies as real.
gan_input = Input(shape=noise_shape)  # not `input`: that would shadow the builtin
fake_image = G(gan_input)

# Freeze D inside the combined model so training D_G_model only moves G's
# weights. In Keras 2 the trainable flag is captured at compile time, so D,
# compiled above while trainable, still learns from D.train_on_batch.
# NOTE(review): verify this still holds under Keras 3, where the flag may be
# read when the train function is first built.
D.trainable = False
validity = D(fake_image)

D_G_model = Model(gan_input, validity)
D_G_model.compile(optimizer=Adam(0.0002, 0.5), loss='binary_crossentropy')

D_G_model.summary()

In [None]:
def save_image(epoch):
    """Save a 10 x 10 grid of generator samples to ``images/mnist_<k>.png``.

    Draws ``rows * cols`` noise vectors from the same uniform [0, 1)
    distribution used during training and runs them through the global
    generator ``G``. The tanh output is rescaled from [-1, 1] to [0, 1], and
    the grid is written to disk.

    Args:
        epoch: current training iteration. The file index is
            ``epoch // save_every`` (global ``save_every``).
    """
    rows = 10
    cols = 10

    noise = np.random.uniform(0, 1, (rows * cols, noise_shape[0]))
    # verbose=0 keeps predict from printing a progress bar on each call.
    images = G.predict(noise, verbose=0)

    # rescale from the tanh range [-1, 1] to [0, 1]
    images = 0.5 * images + 0.5

    fig, axes = plt.subplots(rows, cols)
    # axes.flat walks the grid row by row, matching sample order.
    for idx, ax in enumerate(axes.flat):
        ax.imshow(images[idx, :, :, 0], cmap='gray')
        ax.axis('off')

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs('images', exist_ok=True)

    fname = 'images/mnist_{}.png'.format(epoch // save_every)
    fig.savefig(fname)
    print('saved: {}'.format(fname))
    # Close this specific figure so repeated calls do not accumulate figures.
    plt.close(fig)

TRAIN

In [None]:
# Load dataset
(X_train, _), (_,_) = mnist.load_data()
X_train.shape

# center data
X_train = (X_train.astype('float32') - 127.5)/127.5

X_train = np.expand_dims(X_train, axis=3)
X_train.shape

print(np.mean(X_train), np.std(X_train))

# Train

# Each iteration trains D on half a batch of real images and half a batch of
# fake images, then trains G (through D_G_model) on a full batch.
half_batch = batch_size // 2
print('half batch size : {}'.format(half_batch))

# Target labels never change, so build them once instead of every iteration.
real_labels = np.ones((half_batch, 1))
fake_labels = np.zeros((half_batch, 1))
gan_labels = np.ones((batch_size, 1))  # G is trained to make D output "real"

for epoch in range(epochs):
    # --- train discriminator ---

    # real
    indices = np.random.randint(0, X_train.shape[0], half_batch)
    images = X_train[indices]
    d_real_loss = D.train_on_batch(images, real_labels)

    # fake; verbose=0 so predict does not print a progress bar every iteration
    noise = np.random.uniform(0, 1, (half_batch, noise_shape[0]))
    noise_images = G.predict(noise, verbose=0)
    d_fake_loss = D.train_on_batch(noise_images, fake_labels)

    # D was compiled with an accuracy metric, so each result is [loss, accuracy];
    # average the real and fake half-batches.
    d_loss = np.add(d_real_loss, d_fake_loss) / 2

    # --- train generator ---
    noise = np.random.uniform(0, 1, (batch_size, noise_shape[0]))
    g_loss = D_G_model.train_on_batch(noise, gan_labels)

    if epoch % save_every == 0:
        save_image(epoch)
        print('Epoch: {}, D_Loss:{}, D_Acc:{}, G_Loss:{}'.format(epoch, d_loss[0], d_loss[1], g_loss))

TEST #1

In [None]:
# Sample one digit from the trained generator (same noise distribution as training).
noise = np.random.uniform(0, 1, (1, noise_shape[0]))
image = G.predict(noise, verbose=0)

# Visualise on the generator's fixed tanh range [-1, 1] so a faint sample is
# not auto-stretched to full contrast.
plt.imshow(image[0, :, :, 0], cmap='gray', vmin=-1, vmax=1)

TEST #2

In [None]:
# Sample one digit from the trained generator (same noise distribution as training).
noise = np.random.uniform(0, 1, (1, noise_shape[0]))
image = G.predict(noise, verbose=0)

# Visualise on the generator's fixed tanh range [-1, 1] so a faint sample is
# not auto-stretched to full contrast.
plt.imshow(image[0, :, :, 0], cmap='gray', vmin=-1, vmax=1)

TEST #3

In [None]:
# Sample one digit from the trained generator (same noise distribution as training).
noise = np.random.uniform(0, 1, (1, noise_shape[0]))
image = G.predict(noise, verbose=0)

# Visualise on the generator's fixed tanh range [-1, 1] so a faint sample is
# not auto-stretched to full contrast.
plt.imshow(image[0, :, :, 0], cmap='gray', vmin=-1, vmax=1)