In [None]:
import os
# Silence TensorFlow's C++ logging; this must run before TensorFlow is imported.
os.environ.update({'TF_CPP_MIN_LOG_LEVEL': '3'})

In [None]:
import time

import numpy as np

from PIL import Image

from tensorflow.keras.datasets import fashion_mnist
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.models import Model
from tensorflow.keras.activations import sigmoid, tanh
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Conv2D, Conv2DTranspose, LeakyReLU, Dropout, Embedding, Concatenate

In [None]:
# A GAN has no use for a held-out split, so the official train and test sets
# are merged and all images become training data.
(X_train, y_train), (X_test, y_test) = fashion_mnist.load_data()
# float32 halves memory versus the float64 NumPy would otherwise produce, and
# matches the dtype Keras trains in (no per-batch cast).
X = np.concatenate((X_train, X_test)).astype(np.float32)
# Explicit channel axis so X matches the (28, 28, 1) image Input of the
# discriminator instead of relying on Keras' implicit rank fix-up.
X = np.expand_dims(X, axis=-1)
y = np.concatenate((y_train, y_test))
# Scale pixels from [0, 255] to [-1, 1], the range of the generator's tanh output.
X = (X - 127.5) / 127.5

In [None]:
def create_discriminator(input_shape=(28,28,1), classes_number=10):
    """Build and compile the conditional discriminator.

    The integer class label is embedded, projected to one image-sized plane and
    stacked onto the image as an extra channel. Two stride-2 convolutions then
    downsample the result, which is flattened into a single sigmoid score.

    Args:
        input_shape: (height, width, channels) of the input images.
        classes_number: number of distinct integer labels the embedding accepts.

    Returns:
        A Keras ``Model`` taking ``[image, label]``, compiled with binary
        cross-entropy and Adam(learning_rate=0.0002, beta_1=0.5).
    """
    height, width = input_shape[0], input_shape[1]

    # Label branch: label -> 50-d embedding -> one (height, width, 1) plane.
    label_input = Input(shape=(1,))
    label_plane = Embedding(input_dim=classes_number, output_dim=50)(label_input)
    label_plane = Dense(units=height * width)(label_plane)
    label_plane = Reshape(target_shape=(height, width, 1))(label_plane)

    image_input = Input(shape=input_shape)
    features = Concatenate()([image_input, label_plane])

    # Two identical downsampling blocks: stride-2 conv followed by LeakyReLU(0.2).
    for _ in range(2):
        features = Conv2D(filters=128, kernel_size=(3, 3), strides=(2, 2), padding='same')(features)
        features = LeakyReLU(alpha=0.2)(features)

    features = Flatten()(features)
    features = Dropout(rate=0.4)(features)
    score = Dense(units=1, activation=sigmoid)(features)

    model = Model([image_input, label_input], score)
    model.compile(loss=BinaryCrossentropy(), optimizer=Adam(learning_rate=0.0002, beta_1=0.5))
    return model
 

def create_generator(hidden_dim, classes_number=10):
    """Build the (uncompiled) conditional generator.

    The noise vector is projected to a 7x7x128 feature map and the embedded
    label to a single 7x7 plane. The two are stacked and upsampled twice by
    stride-2 transposed convolutions (7 -> 14 -> 28). A 7x7 convolution then
    produces one tanh channel.

    Args:
        hidden_dim: length of the latent noise vector.
        classes_number: number of distinct integer labels the embedding accepts.

    Returns:
        A Keras ``Model`` mapping ``[noise, label]`` to a (28, 28, 1) image
        with values in (-1, 1).
    """
    # Label branch: label -> 50-d embedding -> one 7x7 plane.
    label_input = Input(shape=(1,))
    label_plane = Embedding(input_dim=classes_number, output_dim=50)(label_input)
    label_plane = Dense(units=7 * 7)(label_plane)
    label_plane = Reshape(target_shape=(7, 7, 1))(label_plane)

    # Noise branch: latent vector -> 7x7x128 feature map.
    noise_input = Input(shape=(hidden_dim,))
    features = Dense(units=7 * 7 * 128)(noise_input)
    features = LeakyReLU(alpha=0.2)(features)
    features = Reshape(target_shape=(7, 7, 128))(features)

    features = Concatenate()([features, label_plane])

    # Two identical upsampling blocks: stride-2 transposed conv + LeakyReLU(0.2).
    for _ in range(2):
        features = Conv2DTranspose(filters=128, kernel_size=(4, 4), strides=(2, 2), padding='same')(features)
        features = LeakyReLU(alpha=0.2)(features)

    image = Conv2D(filters=1, kernel_size=(7, 7), activation=tanh, padding='same')(features)
    return Model([noise_input, label_input], image)


def connect_models(generator, discriminator):
    """Stack generator and discriminator into the model used for generator updates.

    The discriminator is frozen (``trainable = False``) before the stacked model
    is compiled, so the stacked model's optimizer only updates generator weights.
    The discriminator itself was compiled while still trainable. Keras keeps the
    trainable state captured at compile time for each model, so the
    discriminator's own ``train_on_batch`` keeps updating it.

    Returns:
        A compiled ``Model`` taking the generator's ``[noise, label]`` inputs and
        returning the discriminator's score for the generated image.
    """
    discriminator.trainable = False
    noise_input, label_input = generator.input

    # The same label tensor conditions both the generator and the discriminator.
    validity = discriminator([generator.output, label_input])

    stacked = Model([noise_input, label_input], validity)
    stacked.compile(loss=BinaryCrossentropy(), optimizer=Adam(learning_rate=0.0002, beta_1=0.5))
    return stacked
 


def random_arange(start, stop):
    """Return the integers in [start, stop) in random order (global NumPy RNG)."""
    return np.random.permutation(np.arange(start, stop))



def train_discriminator_on_batch(discriminator, generator, hidden_dim, X_real, labels_real, classes_number):
    """Run one discriminator update on real images, then one on generated images.

    Args:
        discriminator: compiled model taking ``[images, labels]``.
        generator: model taking ``[noise, labels]`` and returning images.
        hidden_dim: length of each noise vector.
        X_real: batch of real images; its first dimension sets the batch size.
        labels_real: class labels aligned with ``X_real``.
        classes_number: fake-image labels are drawn uniformly from
            ``[0, classes_number)``.

    Returns:
        ``(d_loss_real, d_loss_fake)``: the losses from the two
        ``train_on_batch`` calls.
    """
    batch_size_half = X_real.shape[0]

    # Real images carry target 1.
    d_loss_real = discriminator.train_on_batch([X_real, labels_real], np.ones(batch_size_half))

    # Equally sized fake batch with random labels, carrying target 0.
    labels_fake = np.random.randint(0, classes_number, size=batch_size_half)
    noise = np.random.normal(size=(batch_size_half, hidden_dim))
    # predict_on_batch is a single forward pass. predict() sets up a data
    # pipeline on every call and, depending on the Keras version, prints a
    # progress bar each time, which would flood the per-step progress line.
    X_fake = generator.predict_on_batch([noise, labels_fake])
    d_loss_fake = discriminator.train_on_batch([X_fake, labels_fake], np.zeros(batch_size_half))

    return d_loss_real, d_loss_fake



def train_generator_on_batch(connected_model, hidden_dim, classes_number, batch_size):
    """Run one generator update through the stacked generator+discriminator model.

    Random labels and noise are fed with target 1 ("real"), so the update pushes
    the generator towards images the discriminator scores as real. Presumably
    ``connected_model`` comes from ``connect_models`` (discriminator frozen).
    TODO confirm at call sites.

    Returns:
        The loss reported by ``connected_model.train_on_batch``.
    """
    sampled_labels = np.random.randint(0, classes_number, size=batch_size)
    latent_noise = np.random.normal(size=(batch_size, hidden_dim))
    real_targets = np.ones(batch_size)
    return connected_model.train_on_batch([latent_noise, sampled_labels], real_targets)



def train(generator, discriminator, hidden_dim, X, y, classes_number=10, epochs=100, batch_size=140, epochs_per_save=5, first_epoch_number=0, save_path='.'):
    """Train the conditional GAN, alternating discriminator and generator updates.

    Each epoch draws real samples from one half of ``X``: the first half on even
    epoch indices, the second half on odd ones. At every step the discriminator
    sees ``batch_size // 2`` real images plus as many generated images, and the
    generator is then updated once on ``batch_size`` fresh noise vectors.

    Args:
        generator: model taking ``[noise, labels]``.
        discriminator: compiled model taking ``[images, labels]``.
        hidden_dim: length of the noise vector fed to the generator.
        X: real images, indexed along axis 0.
        y: integer class labels aligned with ``X``.
        classes_number: number of distinct labels.
        epochs: number of epochs to run in this call.
        batch_size: real+fake batch size for the discriminator step; also the
            batch size of the generator step.
        epochs_per_save: checkpoint frequency, in epochs.
        first_epoch_number: epoch index to resume from. It shifts the logged
            epoch numbers, which half of ``X`` each epoch uses, and checkpoint
            file names.
        save_path: directory receiving ``generator_<n>.h5`` and
            ``discriminator_<n>.h5``; created if missing.
    """
    connected_model = connect_models(generator, discriminator)
    # Create the checkpoint directory up front. Otherwise the first save, after
    # `epochs_per_save` epochs of training, fails when the directory is missing.
    os.makedirs(save_path, exist_ok=True)

    batches_number = X.shape[0] // batch_size
    batch_size_half = batch_size // 2
    # Divisor for per-epoch loss means; guards against X smaller than one batch.
    mean_divisor = max(batches_number, 1)

    epochs += first_epoch_number
    for i in range(first_epoch_number, epochs):
        # Even epochs sample from the first half of X, odd epochs from the second.
        if i % 2 == 0:
            random_indices = random_arange(0, X.shape[0] // 2)
        else:
            random_indices = random_arange(X.shape[0] // 2, X.shape[0])

        d_loss_real, d_loss_fake, g_loss = 0., 0., 0.
        average_time = 0

        for j in range(batches_number):
            print(f'Epoch {i + 1}/{epochs}, batch {j + 1}/{batches_number}, average_time {round(average_time, 3) or "?"}', end='\r')
            start_time = time.perf_counter()

            # batches_number * batch_size_half never exceeds the size of the
            # selected half, so every slice is full-sized.
            samples_index = random_indices[j * batch_size_half:(j + 1) * batch_size_half]

            d_loss_real_batch, d_loss_fake_batch = train_discriminator_on_batch(
                discriminator, generator, hidden_dim,
                X[samples_index], y[samples_index], classes_number
            )
            d_loss_real += d_loss_real_batch
            d_loss_fake += d_loss_fake_batch

            g_loss += train_generator_on_batch(connected_model, hidden_dim, classes_number, batch_size)

            # Running mean of the per-step wall time.
            average_time = (average_time * j + time.perf_counter() - start_time) / (j + 1)

        # Report per-batch mean losses so the numbers do not scale with the batch count.
        print(f'Epoch {i + 1}/{epochs}, time {average_time * batches_number:.3f}, d_real={d_loss_real / mean_divisor:.3f}, d_fake={d_loss_fake / mean_divisor:.3f}, g={g_loss / mean_divisor:.3f}')
        if (i + 1) % epochs_per_save == 0:
            generator.save(os.path.join(save_path, f'generator_{i + 1}.h5'))
            discriminator.save(os.path.join(save_path, f'discriminator_{i + 1}.h5'))
            print('Saved generator and discriminator')

In [None]:
# Length of the latent noise vector sampled for the generator.
hidden_dim = 100

# Build the (compiled) discriminator first, then the generator.
discriminator, generator = create_discriminator(), create_generator(hidden_dim)

In [None]:
# Train for 200 epochs; checkpoints go to `models/` relative to the working directory.
train(
    generator=generator,
    discriminator=discriminator,
    hidden_dim=hidden_dim,
    X=X,
    y=y,
    epochs=200,
    save_path='models',
)