In [1]:
from numpy import expand_dims
from numpy import zeros
from numpy import ones
from numpy.random import randn
from numpy.random import randint
from keras.datasets import fashion_mnist
from keras.datasets import mnist
# from tensorflow.keras.datasets.mnist import load_data

# from keras.datasets import mnist, cifar10, fashion_mnist

from keras.optimizers import Adam
from keras.models import Model
from keras.layers import Input, Dense, Reshape, Flatten, Conv2D, Conv2DTranspose
from keras.layers import LeakyReLU, Dropout, Embedding, Concatenate, BatchNormalization
from IPython.display import clear_output

from pathlib import Path

import matplotlib.pyplot as plt


In [2]:
# Display names for every class, keyed by dataset name. The position of a name
# in its list is the integer class label it belongs to.
CLASS_NAMES = dict(
    mnist=[str(digit) for digit in range(10)],
    fashion_mnist=[
        't-shirt/top', 'trouser', 'pullover', 'dress', 'coat',
        'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot',
    ],
    cifar10=[
        'airplane', 'car', 'bird', 'cat', 'deer',
        'dog', 'frog', 'horse', 'ship', 'truck',
    ],
)

# Directory the kernel was started in; saved figures and models are placed
# beneath it.
cwd = Path.cwd()

In [3]:
def define_discriminator(in_shape=(28,28,1), n_classes=10):
    """Build and compile the conditional discriminator.

    The class label is embedded, projected to one value per pixel and reshaped
    into an extra image channel, which is concatenated with the input image
    before the convolutional stack.

    Args:
        in_shape: (height, width, channels) of the input images.
        n_classes: number of distinct class labels accepted by the embedding.

    Returns:
        A compiled Keras ``Model`` taking ``[image, label]`` and producing a
        single sigmoid output per sample.
    """
    height, width = in_shape[0], in_shape[1]

    # label branch: class id -> embedding -> one (height, width, 1) map
    in_label = Input(shape=(1,))
    label_map = Embedding(n_classes, 50)(in_label)
    label_map = Dense(height * width)(label_map)
    label_map = Reshape((height, width, 1))(label_map)

    # image branch, with the label map appended as an additional channel
    in_image = Input(shape=in_shape)
    fe = Concatenate()([in_image, label_map])

    # two identical stride-2 downsampling stages
    for _ in range(2):
        fe = Conv2D(128, (3,3), strides=(2,2), padding='same')(fe)
        fe = LeakyReLU(alpha=0.2)(fe)

    # flatten feature maps, regularise, classify
    fe = Flatten()(fe)
    fe = Dropout(0.4)(fe)
    out_layer = Dense(1, activation='sigmoid')(fe)

    model = Model([in_image, in_label], out_layer)

    # `learning_rate` replaces the deprecated `lr` keyword, which newer Keras
    # releases no longer accept.
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
    return model

def define_generator(latent_dim, n_classes=10):
    """Build the (uncompiled) conditional generator.

    A class label and a latent vector are each mapped to 7x7 feature maps,
    concatenated, upsampled twice with stride-2 transposed convolutions and
    reduced to a single tanh-activated channel.

    Args:
        latent_dim: length of the latent noise vector.
        n_classes: number of distinct class labels accepted by the embedding.

    Returns:
        A Keras ``Model`` taking ``[latent_vector, label]`` and producing an
        image tensor.
    """
    # label branch: class id -> embedding -> a single 7x7 map
    label_in = Input(shape=(1,))
    label_map = Embedding(n_classes, 50)(label_in)
    label_map = Dense(7 * 7)(label_map)
    label_map = Reshape((7, 7, 1))(label_map)

    # noise branch: latent vector -> 128 maps of 7x7
    noise_in = Input(shape=(latent_dim,))
    features = Dense(128 * 7 * 7)(noise_in)
    features = LeakyReLU(alpha=0.2)(features)
    features = Reshape((7, 7, 128))(features)

    # condition the noise features on the label map
    features = Concatenate()([features, label_map])

    # two identical stride-2 upsampling stages
    for _ in range(2):
        features = Conv2DTranspose(128, (4,4), strides=(2,2), padding='same')(features)
        features = LeakyReLU(alpha=0.2)(features)

    # collapse to one channel; tanh keeps pixel values in [-1, 1]
    image_out = Conv2D(1, (7,7), activation='tanh', padding='same')(features)

    return Model([noise_in, label_in], image_out)

def define_gan(g_model, d_model):
    """Chain generator and discriminator into the model used to train the generator.

    The discriminator is marked non-trainable before the combined model is
    built, and the generator's label input is fed to both networks.

    Args:
        g_model: generator whose ``input`` is ``[noise, label]``.
        d_model: discriminator callable on ``[image, label]``.

    Returns:
        A compiled Keras ``Model`` taking ``[noise, label]`` and producing the
        discriminator's output for the generated image.
    """
    # freeze the discriminator inside the combined model
    d_model.trainable = False

    # reuse the generator's own inputs so the same label conditions both nets
    gen_noise, gen_label = g_model.input
    gan_output = d_model([g_model.output, gen_label])

    model = Model([gen_noise, gen_label], gan_output)

    # `learning_rate` replaces the deprecated `lr` keyword, which newer Keras
    # releases no longer accept.
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt)
    return model

In [4]:
# load the training split of a dataset (fashion mnist by default) and scale it
def load_real_samples(dataset=None):
    """Load training images and labels and scale the pixels to [-1, 1].

    Args:
        dataset: object exposing ``load_data()`` that returns
            ``(train_images, train_labels), (test_images, test_labels)``.
            When omitted, ``fashion_mnist`` is used instead of failing on
            ``None.load_data()``.

    Returns:
        ``[X, labels]`` where ``X`` is float32 with a trailing axis of size 1
        added and values mapped from [0, 255] to [-1, 1]; ``labels`` is
        returned as loaded.
    """
    if dataset is None:
        dataset = fashion_mnist

    # only the training split is used; the test split is discarded
    (trainX, trainy), (_, _) = dataset.load_data()

    # add a trailing channel axis and convert from ints to floats
    X = expand_dims(trainX, axis=-1).astype('float32')
    # scale from [0, 255] to [-1, 1]
    X = (X - 127.5) / 127.5

    return [X, trainy]

def generate_real_samples(dataset, n_samples):
    """Draw a random batch of real images with their class labels.

    Args:
        dataset: pair ``(images, labels)`` of arrays indexed along axis 0.
        n_samples: number of samples to draw (with replacement).

    Returns:
        ``([images, labels], y)`` where ``y`` is an ``(n_samples, 1)`` array of
        ones marking the samples as real.
    """
    all_images, all_labels = dataset
    # random row indices in [0, number of images)
    picks = randint(0, all_images.shape[0], n_samples)
    batch = [all_images[picks], all_labels[picks]]
    return batch, ones((n_samples, 1))

def generate_latent_points(latent_dim, n_samples, n_classes=10):
    """Sample generator inputs: Gaussian noise vectors and random class labels.

    Args:
        latent_dim: length of each noise vector.
        n_samples: number of noise vectors / labels to draw.
        n_classes: labels are drawn uniformly from ``[0, n_classes)``.

    Returns:
        ``[z, labels]`` with ``z`` of shape ``(n_samples, latent_dim)`` and
        ``labels`` of shape ``(n_samples,)``.
    """
    # draw all noise values at once, then arrange one row per sample
    noise = randn(n_samples * latent_dim).reshape((n_samples, latent_dim))
    class_ids = randint(0, n_classes, n_samples)
    return [noise, class_ids]

def generate_fake_samples(generator, latent_dim, n_samples, n_classes=10):
    """Generate a batch of fake images with the labels they were conditioned on.

    Args:
        generator: model whose ``predict`` accepts ``[noise, labels]``.
        latent_dim: length of each noise vector.
        n_samples: number of images to generate.
        n_classes: number of classes to sample labels from; forwarded to
            ``generate_latent_points`` so datasets with a different class
            count can be used. Defaults to 10, the previous behaviour.

    Returns:
        ``([images, labels], y)`` where ``y`` is an ``(n_samples, 1)`` array of
        zeros marking the samples as fake.
    """
    # generate points in latent space
    z_input, labels_input = generate_latent_points(latent_dim, n_samples, n_classes)
    # predict outputs; verbose=0 keeps Keras from printing a progress bar for
    # every call, since this runs once per training batch
    images = generator.predict([z_input, labels_input], verbose=0)
    # create class labels
    y = zeros((n_samples, 1))
    return [images, labels_input], y

In [5]:
def plot_data(data=None,
              n=5,
              figsize=(10, 8), 
              save=False, 
              name=None, path='/plot_data/',
              axis='on', show=False):
    """Plot up to ``n * n`` images in an ``n`` x ``n`` grid, optionally saving it.

    Args:
        data: pair ``(images, labels)``; ``labels`` may be ``None`` to omit
            titles. Required despite the ``None`` default.
        n: grid side length.
        figsize: matplotlib figure size.
        save: when true, write the figure as ``<name>.png``.
        name: file name without extension; prompted for when ``None`` and
            ``save`` is true.
        path: sub-directory below ``<cwd>/figures`` to save into. Leading
            path separators are ignored so the location always stays under
            ``<cwd>/figures``.
        axis: value passed to ``plt.axis`` for every subplot.
        show: when true, display the figure before closing it.

    Titles are looked up in ``CLASS_NAMES[DATASET_NAME]``, so the global
    ``DATASET_NAME`` must be defined when labels are supplied.
    """
    images, labels = data

    # Do not index past the end when fewer than n*n images are supplied.
    n_shown = min(n ** 2, len(images))

    plt.figure(figsize=figsize)
    for i in range(n_shown):
        plt.subplot(n, n, i+1)
        plt.axis(axis)
        plt.imshow(images[i].squeeze(), cmap='gray_r')
        if labels is not None:
            plt.title(CLASS_NAMES[DATASET_NAME][labels[i].squeeze()], fontsize=14)
    plt.tight_layout()
    if save:
        while name is None:
            name = input("Enter name for figure: ")

        # Joining an absolute path discards everything to its left, so the
        # default '/plot_data/' used to resolve to the filesystem root rather
        # than <cwd>/figures/plot_data. Strip leading separators first.
        relative_path = Path(str(path).lstrip('/\\'))
        file_path = cwd / Path('figures') / relative_path
        file_path.mkdir(parents=True, exist_ok=True)
        plt.savefig(file_path.joinpath(str(name) + '.png'))
    if show:
        plt.show()
    plt.close()

def save_model(g_model, epoch, path=None):
    """Save the generator for a given epoch below ``<cwd>/models``.

    Args:
        g_model: model exposing ``save(filepath)``.
        epoch: zero-based epoch index; the file name uses ``epoch + 1``.
        path: optional sub-directory below ``<cwd>/models``. ``None`` (the
            default) saves directly into ``<cwd>/models`` instead of raising
            ``TypeError`` from ``Path(None)``.
    """
    sub_dir = Path(path) if path is not None else Path()
    file_path = cwd / Path('models') / sub_dir
    file_path.mkdir(parents=True, exist_ok=True)

    g_model.save(file_path / f'gen_model_e-{epoch+1:03d}.h5')

def summarise_performance(epoch, g_model, latent_dim, n_samples=100):
    """Save a grid of generated images and a generator checkpoint for an epoch.

    Args:
        epoch: zero-based epoch index; file names use ``epoch + 1``.
        g_model: generator used to produce the samples and to be saved.
        latent_dim: length of the latent noise vector.
        n_samples: number of fake samples to generate.

    Output locations are derived from the global ``GAN_NAME``.
    """
    (fake_images, fake_labels), _ = generate_fake_samples(g_model, latent_dim, n_samples)
    # rescale pixel values from [-1, 1] to [0, 1] for plotting
    fake_images = (fake_images + 1) / 2.0

    plot_data(
        data=(fake_images, fake_labels),
        save=True,
        axis='off',
        path=f'{GAN_NAME}/images',
        name=f'gen_image_e-{epoch+1:03d}',
    )

    save_model(g_model, epoch, path=f'{GAN_NAME}')

In [6]:
def train(g_model, d_model, gan_model, dataset, latent_dim, n_epochs=100, n_batch=128):
    """Run the alternating discriminator / generator training loop.

    For every batch the discriminator is updated on half a batch of real
    samples and half a batch of generated samples, then the generator is
    updated through ``gan_model`` on a full batch labelled as real. After each
    epoch ``summarise_performance`` is called, and after the last epoch the
    generator is saved to ``cgan_generator.h5``.

    Args:
        g_model: generator model.
        d_model: compiled discriminator; ``train_on_batch`` must return
            ``(loss, metric)``.
        gan_model: compiled combined model; ``train_on_batch`` must return the
            scalar loss.
        dataset: ``[images, labels]`` as returned by ``load_real_samples``.
        latent_dim: length of the latent noise vector.
        n_epochs: number of passes over the dataset.
        n_batch: batch size for the generator update.

    Returns:
        dict with keys ``'d_loss1'``, ``'d_loss2'`` and ``'g_loss'``, each a
        list holding one loss value per training batch. Previously this
        history was collected but never returned, so callers received ``None``.
    """
    # Named `history` rather than `hash` so the builtin is not shadowed.
    history = {
        'd_loss1': [],
        'd_loss2': [],
        'g_loss': [],
    }
    bat_per_epo = dataset[0].shape[0] // n_batch
    half_batch = n_batch // 2
    for i in range(n_epochs):
        for j in range(bat_per_epo):

            # discriminator update on real samples
            [X_real, labels_real], y_real = generate_real_samples(dataset, half_batch)
            d_loss1, _ = d_model.train_on_batch([X_real, labels_real], y_real)
            history['d_loss1'].append(d_loss1)

            # discriminator update on generated samples
            [X_fake, labels], y_fake = generate_fake_samples(g_model, latent_dim, half_batch)
            d_loss2, _ = d_model.train_on_batch([X_fake, labels], y_fake)
            history['d_loss2'].append(d_loss2)

            # generator update: generated samples are labelled as real
            [z_input, labels_input] = generate_latent_points(latent_dim, n_batch)
            y_gan = ones((n_batch, 1))
            g_loss = gan_model.train_on_batch([z_input, labels_input], y_gan)
            history['g_loss'].append(g_loss)

            print('>%d, %d/%d, d1=%.3f, d2=%.3f g=%.3f' % (i+1, j+1, bat_per_epo, d_loss1, d_loss2, g_loss))
            # periodically wipe the cell output so the log does not grow unbounded
            if j % 35 == 0:
                clear_output(wait=True)
        summarise_performance(i, g_model, latent_dim)
    # save the generator model
    g_model.save('cgan_generator.h5')
    return history

In [10]:
# Globals read by plot_data (class-name lookup) and summarise_performance
# (output folders); they must exist before train() is called.
DATASET_NAME = 'fashion_mnist'
GAN_NAME = f'CGAN/{DATASET_NAME}'

# size of the generator's noise input
latent_dim = 100

# build the networks: discriminator, generator, then the combined model
d_model = define_discriminator()
g_model = define_generator(latent_dim)
gan_model = define_gan(g_model, d_model)

# training images scaled to [-1, 1] together with their labels
dataset = load_real_samples(dataset=fashion_mnist)

# NOTE: the name `hash` shadows the built-in hash() for the rest of the session.
hash = train(g_model, d_model, gan_model, dataset, latent_dim)

>4, 37/468, d1=0.000, d2=0.000 g=0.019
>4, 38/468, d1=0.000, d2=0.000 g=0.019
>4, 39/468, d1=0.000, d2=0.000 g=0.018
>4, 40/468, d1=0.000, d2=0.000 g=0.019
>4, 41/468, d1=0.000, d2=0.000 g=0.019
>4, 42/468, d1=0.000, d2=0.000 g=0.019
>4, 43/468, d1=0.000, d2=0.000 g=0.020
>4, 44/468, d1=0.000, d2=0.000 g=0.019
>4, 45/468, d1=0.000, d2=0.000 g=0.021
>4, 46/468, d1=0.000, d2=0.000 g=0.021
>4, 47/468, d1=0.000, d2=0.000 g=0.020
>4, 48/468, d1=0.000, d2=0.000 g=0.019
>4, 49/468, d1=0.000, d2=0.000 g=0.019
>4, 50/468, d1=0.000, d2=0.000 g=0.020
>4, 51/468, d1=0.000, d2=0.000 g=0.019


Exception ignored in: <function WeakKeyDictionary.__init__.<locals>.remove at 0x12a6aa940>
Traceback (most recent call last):
  File "/Users/hamzz/anaconda3/envs/tensorflow_env/lib/python3.8/weakref.py", line 345, in remove
    def remove(k, selfref=ref(self)):
KeyboardInterrupt: 


KeyboardInterrupt: 