In [19]:
#   Ignore warning
import warnings
warnings.filterwarnings("ignore")

In [20]:
#   Base directory for this notebook, relative to the current working directory
base_path = './'
import os
#   Sanity check: as the last expression of the cell, this displays whether base_path is an existing directory
os.path.isdir(base_path)

In [21]:
import numpy as np
import matplotlib.pyplot as plt
from keras.models import Model
from keras.layers import Dense, Flatten, Conv2D, Conv2DTranspose, Input, Reshape, Embedding, Dropout
from tensorflow.keras.optimizers import Adam
from keras.layers import Concatenate, LeakyReLU
from keras.datasets import fashion_mnist
from tqdm import tqdm_notebook

In [22]:
#   Define Generator
def generator(noise_dims, n_class):
    """Build the conditional generator network.

    Args:
        noise_dims: Length of the latent noise vector fed to the generator.
        n_class: Number of distinct class labels (size of the embedding table).

    Returns:
        An uncompiled Keras ``Model`` taking ``[noise, label]`` and producing a
        28x28x1 image whose values lie in [-1, 1] (tanh output).
    """
    #   Label branch: class id -> 50-d embedding -> 7x7x1 feature map
    label_in = Input(shape = (1,))
    label_branch = Embedding(n_class, 50)(label_in)
    label_branch = Dense(7 * 7)(label_branch)
    label_branch = Reshape((7, 7, 1))(label_branch)

    #   Noise branch: latent vector -> 7x7x128 feature maps
    noise_in = Input(shape = (noise_dims, ))
    noise_branch = Dense(7 * 7 * 128)(noise_in)
    noise_branch = LeakyReLU(alpha = .2)(noise_branch)
    noise_branch = Reshape((7, 7, 128))(noise_branch)

    #   Stack the label map onto the noise maps along the channel axis (7x7x129)
    features = Concatenate()([noise_branch, label_branch])

    #   Two stride-2 transposed convolutions: 7x7 -> 14x14x64 -> 28x28x32
    for n_filters in (64, 32):
        features = Conv2DTranspose(n_filters, (4, 4), strides = 2, padding = 'same')(features)
        features = LeakyReLU(alpha = .2)(features)

    #   Project to a single-channel 28x28x1 image
    image_out = Conv2D(1, (7, 7), padding = 'same', activation = 'tanh')(features)

    return Model(inputs = [noise_in, label_in], outputs = image_out)

In [23]:
def discriminator(input_shape, n_class):
    """Build and compile the conditional discriminator.

    The class label is embedded, projected to one extra image channel and
    concatenated with the input image before two stride-2 convolutions.

    Args:
        input_shape: ``(height, width, channels)`` of the input images.
        n_class: Number of distinct class labels (size of the embedding table).

    Returns:
        A compiled Keras ``Model`` taking ``[image, label]`` and returning the
        sigmoid probability that the image is real.
    """
    height, width = input_shape[0], input_shape[1]

    input_label = Input(shape = (1,))
    embed_label = Embedding(n_class, 50)(input_label)
    label_dense = Dense(height * width)(embed_label)
    #   The label always becomes ONE extra channel, whatever the image channel count is
    label_img = Reshape((height, width, 1))(label_dense)

    input_img = Input(shape = input_shape)
    #   Concatenate input_img and label_img along the channel axis
    merge = Concatenate()([input_img, label_img])
    #   Downsample (sizes below are for a 28x28 input)
    #   14x14x64
    discr = Conv2D(64, (3, 3), strides = 2, padding = 'same')(merge)
    discr = LeakyReLU(alpha = .2)(discr)
    #   7x7x128 -- must consume the previous activation, not `merge`,
    #   otherwise the first convolution is silently dropped from the graph
    discr = Conv2D(128, (3, 3), strides = 2, padding = 'same')(discr)
    discr = LeakyReLU(alpha = .2)(discr)
    #   Flatten feature maps
    discr = Flatten()(discr)
    discr = Dropout(0.4)(discr)
    out_layer = Dense(1, activation = 'sigmoid')(discr)
    #   Define Discriminator model
    discr_model = Model(inputs = [input_img, input_label], outputs = out_layer)
    discr_model.compile(optimizer = Adam(learning_rate = 2e-4, beta_1 = 0.5), loss = 'binary_crossentropy', metrics = ['acc'])
    return discr_model

In [24]:
def preprocess_data(X_data, labels):
    """Scale raw pixel data to [-1, 1] and add a trailing channel axis.

    Args:
        X_data: Array of images with pixel values in [0, 255].
        labels: Labels belonging to ``X_data``; passed through unchanged.

    Returns:
        ``[X_processed, labels]`` where ``X_processed`` is a float32 array of
        shape ``X_data.shape + (1,)`` with values in [-1, 1].
    """
    #   Cast BEFORE scaling so the result stays float32 instead of being promoted to float64
    X = np.expand_dims(X_data, -1).astype('float32')
    #   Map [0, 255] -> [-1, 1] to match the generator's tanh output range
    X_processed = (X / 255.0) * 2 - 1
    return [X_processed, labels]

In [25]:
def define_gan(gen_model, discr_model):
    """Chain generator and discriminator into the model used to train the generator.

    Args:
        gen_model: Generator model with inputs ``[noise, label]``.
        discr_model: Discriminator model with inputs ``[image, label]``.

    Returns:
        A compiled Keras ``Model`` mapping ``[noise, label]`` to the
        discriminator's output for the generated image.
    """
    #   Mark the discriminator as non-trainable before it is wired into the combined model
    discr_model.trainable = False

    #   Reuse the generator's own input tensors as the combined model's inputs
    noise_in, label_in = gen_model.input
    #   The discriminator receives the generated image together with the same label
    validity = discr_model([gen_model.output, label_in])

    #   Combined model: noise and label in, real/fake classification out
    combined = Model(inputs = [noise_in, label_in], outputs = validity)
    combined.compile(
        loss = 'binary_crossentropy',
        optimizer = Adam(learning_rate = 2e-4, beta_1 = 0.5),
        metrics = ['acc'],
    )
    return combined

In [26]:
def generate_real_samples(dataset, n_sample):
    """Draw a random batch of real images (sampling with replacement).

    Args:
        dataset: Pair ``(images, labels)`` indexable along the first axis.
        n_sample: Number of samples to draw.

    Returns:
        ``(X_real, labels, Y1)`` where ``Y1`` is an ``(n_sample, 1)`` array of
        ones marking the samples as real.
    """
    all_images, all_labels = dataset
    #   Random row indices into the dataset
    picks = np.random.randint(0, all_images.shape[0], n_sample)
    #   Target of 1 for every real sample
    real_targets = np.ones((n_sample, 1))
    return all_images[picks], all_labels[picks], real_targets

In [27]:
def generate_latent_point(latent_dims, n_sample, n_class):
    """Sample generator inputs: standard-normal noise plus random class labels.

    Args:
        latent_dims: Length of each noise vector.
        n_sample: Number of (noise, label) pairs to draw.
        n_class: Labels are drawn uniformly from ``0 .. n_class - 1``.

    Returns:
        ``(noise, labels)`` with shapes ``(n_sample, latent_dims)`` and
        ``(n_sample, 1)``.
    """
    latent_vectors = np.random.standard_normal((n_sample, latent_dims))
    class_ids = np.random.randint(0, n_class, (n_sample, 1))
    return latent_vectors, class_ids

In [28]:
def generate_fake_samples(gen_model, latent_dims, n_sample, n_class):
    """Produce a batch of generated images together with their labels.

    Args:
        gen_model: Generator model accepting ``[noise, labels]``.
        latent_dims: Length of each noise vector.
        n_sample: Number of images to generate.
        n_class: Number of class labels to sample from.

    Returns:
        ``(X_fake, labels, Y0)`` where ``Y0`` is an ``(n_sample, 1)`` array of
        zeros marking the samples as fake.
    """
    latent_vectors, class_ids = generate_latent_point(latent_dims, n_sample, n_class)
    generated = gen_model.predict_on_batch([latent_vectors, class_ids])
    #   Target of 0 for every generated sample
    fake_targets = np.zeros((n_sample, 1))
    return generated, class_ids, fake_targets

In [29]:
def plot_result(images):
    """Show a batch of generated images as one square mosaic.

    Args:
        images: Array of shape ``(n, h, w, c)`` with values in [-1, 1].
            If ``n`` is not a perfect square, only the first
            ``floor(sqrt(n)) ** 2`` images are shown.
    """
    h, w, c = images.shape[1:]
    grid_size = int(np.sqrt(images.shape[0]))
    #   Keep only as many images as fit the square grid; without this the reshape
    #   below fails whenever the batch size is not a perfect square
    images = images[:grid_size * grid_size]
    #   Map [-1, 1] back to [0, 255] pixel values
    images = ((images + 1)/2.)*255.
    images = images.astype('uint8')
    #   (rows, cols, h, w, c) -> (rows, h, cols, w, c) -> one (rows*h, cols*w, c) mosaic
    images = images.reshape(grid_size, grid_size, h, w, c).transpose(0, 2, 1, 3, 4).reshape(grid_size*h, grid_size*w, c)
    plt.imshow(images, cmap = 'gray_r')
    plt.axis('off')
    plt.show()

In [30]:
def train(dataset, C_gan, gen_model, discr_model, latent_dims, n_class, epochs, batch_size, plot_freq, n_sample = 36):
    """Alternately train the discriminator and the generator.

    Each iteration updates the discriminator on half a batch of real and half
    a batch of generated samples, then updates the generator through ``C_gan``
    on a full batch labelled as real.

    Args:
        dataset: Pair ``(images, labels)`` as returned by ``preprocess_data``.
        C_gan: Combined generator + discriminator model.
        gen_model: Generator model.
        discr_model: Compiled discriminator model.
        latent_dims: Length of the generator's noise vector.
        n_class: Number of class labels.
        epochs: Number of passes over the dataset.
        batch_size: Samples per generator update (discriminator sees two halves).
        plot_freq: Preview images are plotted at epoch 1 and every ``plot_freq`` epochs.
        n_sample: Number of preview images generated from the fixed latent points.
    """
    def _loss_of(result):
        #   train_on_batch returns [loss, metric, ...] when the model was compiled
        #   with metrics and a bare scalar otherwise; keep only the loss value
        return float(np.ravel(result)[0])

    iters = dataset[0].shape[0] // batch_size
    half_batch = batch_size // 2
    #   Per-iteration scalar losses.
    #   NOTE(review): not returned, so the notebook cell calling train() keeps no output value
    losses = {'D': [], 'G': []}

    #   Fixed latent points to evaluate the generator; drawn with the same helper
    #   (standard normal noise) that is used for the training batches
    noise_fixed, desire_labels = generate_latent_point(latent_dims, n_sample, n_class)
    for epoch in tqdm_notebook(range(1, epochs + 1)):
        show_progress = epoch == 1 or epoch % plot_freq == 0
        if show_progress:
            print('-'*30, f'Epoch {epoch}', '-'*30)
        for i in range(iters):
            # ======================== Train Discriminator ========================
            #   Get randomly selected 'real' samples
            X_real, real_labels, Y1 = generate_real_samples(dataset, half_batch)
            #   Update discriminator weights by training it with REAL samples
            D_loss_real = discr_model.train_on_batch([X_real, real_labels], Y1)
            #   Generate FAKE samples
            X_fake, fake_labels, Y0 = generate_fake_samples(gen_model, latent_dims, half_batch, n_class)
            #   Update discriminator weights by training it with FAKE samples
            D_loss_fake = discr_model.train_on_batch([X_fake, fake_labels], Y0)
            #   Total discriminator loss: sum the two scalar losses
            #   (adding the raw results would concatenate the [loss, acc] lists)
            losses['D'].append(_loss_of(D_loss_real) + _loss_of(D_loss_fake))

            # ======================== Train Generator ========================
            #   Prepare points in latent space as input of generator
            noise, labels = generate_latent_point(latent_dims, batch_size, n_class)
            #   Label the generated samples as real so the generator is pushed to fool the discriminator
            Y_gen_1 = np.ones((batch_size, 1))
            #   Update generator weights via the discriminator error
            G_loss = C_gan.train_on_batch([noise, labels], Y_gen_1)
            losses['G'].append(_loss_of(G_loss))
        if show_progress:
            fake_imgs = gen_model.predict_on_batch([noise_fixed, desire_labels])
            print(fake_imgs.shape)
            plot_result(fake_imgs)

In [31]:
#   Hyper-parameters
noise_dims = 100
n_class = 10
epochs = 100
batch_size = 64
plot_freq = 10

#   Data: only the Fashion-MNIST training split is used
(X_train, Y_train), (_, _) = fashion_mnist.load_data()
dataset = preprocess_data(X_train, Y_train)

#   Models: generator, discriminator, and the combined model that trains the generator
gen_model = generator(noise_dims = noise_dims, n_class = n_class)
discr_model = discriminator(input_shape = (28, 28, 1), n_class = n_class)
C_gan = define_gan(gen_model = gen_model, discr_model = discr_model)

#   Run the training loop
train(
    dataset = dataset,
    C_gan = C_gan,
    gen_model = gen_model,
    discr_model = discr_model,
    latent_dims = noise_dims,
    n_class = n_class,
    epochs = epochs,
    batch_size = batch_size,
    plot_freq = plot_freq,
)

In [32]:
import os
#   Save only the generator (gen_model) to ./Conditional_GAN.h5;
#   the discriminator and the combined C_gan model are not written
gen_model.save(os.path.join('./', 'Conditional_GAN.h5'))

In [33]:
from keras.models import load_model
#   Reload the model stored in Conditional_GAN.h5.
#   NOTE(review): this rebinds the name C_gan; the file is written from gen_model in the
#   preceding cell, so from here on C_gan refers to the generator, not the combined GAN
C_gan = load_model('Conditional_GAN.h5')

In [34]:
#   Inputs for sampling from the reloaded model: one noise vector and one label per image
n_sample = 16
noise_dims = 100
noise_input = np.random.randn(n_sample, noise_dims)
#   Use n_sample (not a second hard-coded 16) so noise and labels always have the same batch size
desired_label = np.random.randint(0, 10, (n_sample, 1))
print(noise_input.shape)
print(desired_label.shape)

In [35]:
#   Generate one image per (noise, label) pair; C_gan here is the model loaded from Conditional_GAN.h5
gen_imgs = C_gan.predict([noise_input, desired_label])

In [40]:
#   Draw the generated images on a 4x4 grid, one subplot per image
fig = plt.figure(figsize = (3, 3))
for index, image in enumerate(gen_imgs):
    ax = fig.add_subplot(4, 4, index + 1)
    #   Draw on the subplot that was just created, with its axes hidden
    ax.imshow(image, cmap = 'gray_r')
    ax.axis('off')
plt.show()