In [1]:
# %% --------------------------------------- Load Packages -------------------------------------------------------------
import os
import random
import cv2
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import tensorflow.keras.backend as K
from tensorflow.keras import Model, Sequential
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.layers import Input, Reshape, Dense, Dropout, \
    Activation, LeakyReLU, Conv2D, Conv2DTranspose, Embedding, \
    Concatenate, multiply, Flatten, BatchNormalization
from tensorflow.keras.initializers import glorot_normal
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import train_test_split

In [2]:
import tensorflow as tf

# Enable on-demand GPU memory allocation instead of reserving all device memory
# up front. Applied to every visible GPU; on a CPU-only machine the list is
# empty and this is a no-op (indexing gpus[0] would raise IndexError).
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

In [4]:
from tensorflow.keras import layers

In [3]:
# %% ---------------------------------- Data Preparation ---------------------------------------------------------------
def change_image_shape(images):
    """Return ``images`` in channels-last ``(N, H, W, C)`` layout.

    * A 3-D stack ``(N, H, W)`` (e.g. grayscale MNIST) gets a trailing
      singleton channel axis -> ``(N, H, W, 1)``.
    * A 4-D stack whose last axis is larger than 3 is treated as
      channels-first ``(N, C, H, W)`` and transposed to ``(N, H, W, C)``.
    * Anything else (e.g. CIFAR's ``(N, 32, 32, 3)``) is returned unchanged.

    Args:
        images: numpy array holding a stack of images.

    Returns:
        numpy array in channels-last layout.
    """
    shape_tuple = images.shape
    if len(shape_tuple) == 3:
        # Use the real height and width so non-square images work too.
        return images.reshape(-1, shape_tuple[1], shape_tuple[2], 1)
    if len(shape_tuple) == 4 and shape_tuple[-1] > 3:
        # Move the channel axis with a transpose: a reshape would reinterpret
        # the memory layout and scramble pixels across channels. The copy makes
        # the result contiguous for downstream consumers such as cv2.
        return np.ascontiguousarray(np.transpose(images, (0, 2, 3, 1)))
    return images

In [None]:
######################## MNIST / CIFAR ##########################
# # Load MNIST Fashion
# from tensorflow.keras.datasets.fashion_mnist import load_data
# # Load CIFAR-10
from tensorflow.keras.datasets.cifar10 import load_data

# # Load training set
(x_train_raw, y_train_raw), (x_test_raw, y_test_raw) = load_data()
x_train_raw = change_image_shape(x_train_raw)
x_test_raw = change_image_shape(x_test_raw)

# Flatten labels to 1-D class-id vectors.
y_train = y_train_raw.reshape(-1)
y_test = y_test_raw.reshape(-1)

######################## Preprocessing ##########################
# Set channel
channel = x_train_raw.shape[-1]


def _resize_to_64(images, n_channels):
    """Resize every image of ``images`` to 64 x 64 x ``n_channels``.

    The output keeps the input dtype (uint8 for the Keras datasets) rather than
    numpy's float64 default, which would take 8x the memory (about 4.9 GB for
    50k CIFAR images at 64 x 64 x 3).
    """
    resized = np.empty((images.shape[0], 64, 64, n_channels), dtype=images.dtype)
    for idx in range(images.shape[0]):
        # The reshape restores the channel axis for single-channel images.
        resized[idx] = cv2.resize(images[idx], (64, 64)).reshape((64, 64, n_channels))
    return resized


# to 64 x 64 x channel
x_train = _resize_to_64(x_train_raw, channel)
x_test = _resize_to_64(x_test_raw, channel)

# Create imbalanced version: keep every sample of class 0 and only the first
# 100 * c samples of class c (c = 1..9).
for c in range(1, 10):
    kept = x_train[y_train == c][:100 * c]
    x_train = np.vstack([x_train[y_train != c], kept])
    # One label per image actually kept (so x and y can never get out of sync),
    # with the original integer dtype instead of np.ones' float64.
    y_train = np.concatenate([y_train[y_train != c], np.full(len(kept), c, dtype=y_train.dtype)])

# Train test split, for autoencoder (actually, this step is redundant if we already have test set)
# x_train, x_test, y_train, y_test = train_test_split(real, labels, test_size=0.3, shuffle=True, random_state=42)

# It is suggested to use [-1, 1] input for GAN training
x_train = (x_train.astype('float32') - 127.5) / 127.5
x_test = (x_test.astype('float32') - 127.5) / 127.5

# Get image size
img_size = x_train[0].shape
# Get number of classes
n_classes = len(np.unique(y_train))

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz


In [None]:
# %% --------------------------------------- Fix Seeds -----------------------------------------------------------------
SEED = 42
# NOTE(review): PYTHONHASHSEED is read when an interpreter starts, so setting it
# here presumably only affects child processes, not this kernel — confirm.
os.environ['PYTHONHASHSEED'] = str(SEED)
# Seed Python's, NumPy's and TensorFlow's global RNGs, in that order.
for seed_fn in (random.seed, np.random.seed, tf.random.set_seed):
    seed_fn(SEED)

# Seeded Glorot-normal initializer (not referenced by the models in this notebook).
weight_init = glorot_normal(seed=SEED)
# Size of the latent vector shared by encoder, decoder and label embedding.
latent_dim = 128

In [None]:
# %% ---------------------------------- Hyperparameters ----------------------------------------------------------------

# Adam optimizer used to pre-train the autoencoder.
optimizer = Adam(beta_1=0.5, beta_2=0.9, learning_rate=0.0002)
# Discriminator updates per generator update:
# trainRatio === times(Train D) / times(Train G)
trainRatio = 5

In [None]:
# %% ---------------------------------- Models Setup -------------------------------------------------------------------
# Build Generator/Decoder
def decoder():
    """Build the generator / decoder network.

    Maps a ``latent_dim`` vector (notebook global) to a 64 x 64 x ``channel``
    image with a tanh output. A dense projection to 4 x 4 x 256 is followed by
    four stride-2 transposed convolutions (4 -> 8 -> 16 -> 32 -> 64).

    Returns:
        A ``Model`` with a single input of shape ``(latent_dim,)``.
    """
    init = RandomNormal(stddev=0.02)  # shared by all transposed convolutions

    noise_le = Input((latent_dim,))

    # Dense projection, reshaped to 4 x 4 x 256
    x = layers.Dense(4*4*256)(noise_le)
    x = layers.LeakyReLU(alpha=0.2)(x)
    x = Reshape((4, 4, 256))(x)

    # 8 x 8 x 128 -> 16 x 16 x 128 -> 32 x 32 x 64
    for n_filters in (128, 128, 64):
        x = Conv2DTranspose(n_filters, (4, 4), strides=(2, 2), padding='same', kernel_initializer=init)(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(0.2)(x)

    # 64 x 64 x channel, values in [-1, 1]
    generated = Conv2DTranspose(channel, (4, 4), strides=(2, 2), padding='same',
                                activation='tanh', kernel_initializer=init)(x)

    return Model(inputs=noise_le, outputs=generated)

# Build Encoder
def encoder():
    """Build the encoder with a Gaussian sampling head.

    Four stride-2 convolutions reduce an ``img_size`` image to 4 x 4 x 256
    (for 64 x 64 inputs). The result is flattened and projected to
    ``latent_dim``. Two dense heads give a mean and a log-variance, and a
    ``Lambda`` layer draws ``z = mean + exp(0.5 * log_var) * eps`` with
    ``eps ~ N(0, I)``.

    NOTE(review): no KL-divergence term on ``z_mean`` / ``z_log_var`` is added
    anywhere in this notebook (the autoencoder is compiled with MAE only), so
    the sampling acts as learned noise rather than a true VAE posterior.

    Returns:
        A ``Model`` mapping an image to a sampled latent vector ``z``.
    """
    init = RandomNormal(stddev=0.02)

    encoder_inputs = Input(img_size)

    # No normalization layers: it is not suggested to use BN in the
    # discriminator of WGAN, and this conv stack may be reused as one.
    h = encoder_inputs
    for n_filters in (64, 128, 128, 256):
        h = Conv2D(n_filters, (4, 4), strides=(2, 2), padding='same', kernel_initializer=init)(h)
        h = LeakyReLU(0.2)(h)

    # 4 x 4 x 256 -> flat -> latent_dim
    h = Flatten()(h)
    h = Dense(latent_dim)(h)
    encoder_outputs = LeakyReLU(0.2)(h)

    # Get the mean and log variance from encoder_outputs
    z_mean = layers.Dense(latent_dim)(encoder_outputs)
    z_log_var = layers.Dense(latent_dim)(encoder_outputs)

    def sampling(args):
        """Sample from the Gaussian N(mu, exp(log_var)) via the reparameterization trick."""
        mu, log_var = args
        eps = tf.keras.backend.random_normal(shape=(tf.keras.backend.shape(mu)[0], latent_dim), mean=0.0, stddev=1.0)
        return mu + tf.keras.backend.exp(0.5 * log_var) * eps

    z = layers.Lambda(sampling)([z_mean, z_log_var])

    return Model(inputs=encoder_inputs, outputs=z)

In [None]:
def embedding_labeled_latent():
    """Build a model that conditions a latent vector on a class label.

    The int32 label of shape ``(1,)`` is looked up in an
    ``n_classes x latent_dim`` embedding table. The flattened embedding is
    multiplied element-wise with the ``latent_dim`` noise/latent input. Takes
    no arguments: ``latent_dim`` and ``n_classes`` are read from notebook
    globals.

    Returns:
        A ``tf.keras.Model`` with inputs ``[noise, label]`` and an output of
        shape ``(latent_dim,)``.
    """
    init = tf.random_normal_initializer(stddev=0.02)

    label = tf.keras.Input((1,), dtype='int32')
    noise = tf.keras.Input((latent_dim,), dtype='float32')

    label_vec = tf.keras.layers.Embedding(n_classes, latent_dim, embeddings_initializer=init)(label)
    label_vec = tf.keras.layers.Flatten()(label_vec)

    conditioned = tf.keras.layers.Multiply()([noise, label_vec])
    return tf.keras.Model([noise, label], conditioned)

In [None]:
# Build the conditional autoencoder used for pre-training:
# image -> encoder -> latent z; z is multiplied by the label embedding and the
# decoder reconstructs the image. The `em` and `de` instances are reused later
# to initialise the GAN generator.
# NOTE(review): despite the name `vae`, no KL-divergence loss is attached to the
# encoder's mean / log-variance heads in this notebook — confirm this is intended.
en = encoder()
de = decoder()
em = embedding_labeled_latent()

label = Input((1,), dtype='int32')  # class label fed to the embedding
img = Input(img_size)
latent_z = en(img)
labeled_latent = em([latent_z, label])

rec_img = de(labeled_latent)
vae = Model([img, label], rec_img)

In [None]:
# Pre-train the conditional autoencoder to reconstruct its input (mean absolute error).
vae.compile(loss='mae', optimizer=optimizer)

vae.fit(
    x=[x_train, y_train],
    y=x_train,
    batch_size=128,
    epochs=100,
    shuffle=True,
    verbose=True,
    validation_data=([x_test, y_test], x_test),
)

In [None]:
# Show results of reconstructed images: for each class, the first test image
# (top row) and its reconstruction (bottom row).
decoded_imgs = vae.predict([x_test, y_test])
n = n_classes
plt.figure(figsize=(2*n, 4))
# Map the tanh range [-1, 1] back to [0, 1] for imshow.
decoded_imgs = decoded_imgs*0.5 + 0.5
x_real = x_test*0.5 + 0.5


def _show_panel(position, image):
    """Draw ``image`` in cell ``position`` of the 2 x n subplot grid, axes hidden."""
    panel = plt.subplot(2, n, position)
    if channel != 3:
        plt.imshow(image.reshape(64, 64))
        plt.gray()
    else:
        plt.imshow(image.reshape(64, 64, channel))
    panel.get_xaxis().set_visible(False)
    panel.get_yaxis().set_visible(False)


for i in range(n):
    _show_panel(i + 1, x_real[y_test == i][0])             # original
    _show_panel(i + n + 1, decoded_imgs[y_test == i][0])   # reconstruction
plt.show()

In [None]:
# Build Discriminator without inheriting the pre-trained Encoder
# Similar to cWGAN
def discriminator_cwgan():
    """Build a conditional discriminator trained from scratch (cWGAN-style).

    Image branch: four stride-2 convolutions with LeakyReLU and no
    normalization, flattened to 4 * 4 * 256 features for 64 x 64 inputs.
    Label branch: a 512-d embedding projected to the same size. The two are
    multiplied element-wise, passed through Dense(512) and reduced to a single
    logit (no output activation).

    Returns:
        A ``Model`` with inputs ``[img, label]`` and output shape ``(1,)``.
    """
    init = RandomNormal(stddev=0.02)

    img = Input(img_size)
    label = Input((1,), dtype='int32')

    # Image branch (it is not suggested to use BN in the discriminator of WGAN).
    feat = img
    for n_filters in (64, 128, 128, 256):
        feat = Conv2D(n_filters, (4, 4), strides=(2, 2), padding='same', kernel_initializer=init)(feat)
        feat = LeakyReLU(0.2)(feat)
    feat = Flatten()(feat)

    # Label branch, projected to the size of the flattened image features.
    label_feat = Flatten()(Embedding(n_classes, 512)(label))
    label_feat = Dense(4 * 4 * 256)(label_feat)
    label_feat = LeakyReLU(0.2)(label_feat)

    joint = multiply([feat, label_feat])
    joint = Dense(512)(joint)
    logit = Dense(1)(joint)

    return Model(inputs=[img, label], outputs=logit)

In [None]:
# %% ----------------------------------- BAGAN-GP Part -----------------------------------------------------------------
# Refer to the WGAN-GP Architecture. https://github.com/keras-team/keras-io/blob/master/examples/generative/wgan_gp.py
# Build our BAGAN-GP
class BAGAN_GP(Model):
    """BAGAN-GP: a conditional GAN trained with a gradient penalty.

    Wraps a conditional generator ``[latent, labels] -> images`` and a
    conditional discriminator ``[images, labels] -> logits``. Each
    ``train_step`` (driven by ``fit(images, labels)``) performs
    ``self.train_ratio`` discriminator updates followed by one generator update.

    Args:
        discriminator: conditional discriminator model.
        generator: conditional generator model.
        latent_dim: size of the latent vector fed to the generator.
        discriminator_extra_steps: stored as an attribute only; the number of
            discriminator updates per batch comes from the global ``trainRatio``.
        gp_weight: multiplier of the gradient penalty in the discriminator loss.

    Relies on the notebook globals ``trainRatio`` and ``n_classes`` (assumed >= 2).
    """

    def __init__(
        self,
        discriminator,
        generator,
        latent_dim,
        discriminator_extra_steps=3,
        gp_weight=10.0,
    ):
        super(BAGAN_GP, self).__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        # NOTE(review): this argument does not drive training; the D/G update
        # ratio is taken from the global `trainRatio` (kept for compatibility).
        self.discriminator_extra_steps = discriminator_extra_steps
        self.train_ratio = trainRatio
        self.gp_weight = gp_weight

    def compile(self, d_optimizer, g_optimizer, d_loss_fn, g_loss_fn):
        """Attach the optimizers and loss functions used by ``train_step``."""
        super(BAGAN_GP, self).compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.d_loss_fn = d_loss_fn
        self.g_loss_fn = g_loss_fn

    def gradient_penalty(self, batch_size, real_images, fake_images, labels):
        """Return the gradient penalty ``mean((||grad_x D(x_hat, labels)||_2 - 1)^2)``.

        ``x_hat`` is a random point on the segment between each real image and
        its paired fake image. The result is added (weighted by ``gp_weight``)
        to the discriminator loss.
        """
        # WGAN-GP draws the interpolation weight uniformly from [0, 1]; a
        # normal draw would extrapolate far outside the real/fake segment.
        alpha = tf.random.uniform([batch_size, 1, 1, 1], 0.0, 1.0)
        interpolated = real_images + alpha * (fake_images - real_images)

        with tf.GradientTape() as gp_tape:
            gp_tape.watch(interpolated)
            pred = self.discriminator([interpolated, labels], training=True)

        grads = gp_tape.gradient(pred, [interpolated])[0]
        # The epsilon keeps sqrt differentiable when a gradient is exactly zero.
        norm = tf.sqrt(tf.reduce_sum(tf.square(grads), axis=[1, 2, 3]) + 1e-12)
        return tf.reduce_mean((norm - 1.0) ** 2)

    def _sample_labels(self, batch_size):
        """Draw ``batch_size`` integer class ids uniformly from ``[0, n_classes)``."""
        return tf.random.uniform((batch_size,), 0, n_classes, dtype=tf.int32)

    def train_step(self, data):
        """Run one BAGAN-GP step on a batch ``(real_images, labels)``.

        Returns:
            dict with the last discriminator loss (``d_loss``, including the
            weighted gradient penalty) and the generator loss (``g_loss``).

        Raises:
            ValueError: if the batch is not an ``(images, labels)`` pair.
        """
        if not isinstance(data, (tuple, list)) or len(data) < 2:
            raise ValueError("BAGAN_GP expects (images, labels) batches, e.g. fit(x_train, y_train).")
        real_images, labels = data[0], data[1]
        # Flat int32 labels so they can be combined with sampled label offsets.
        labels = tf.reshape(tf.cast(labels, tf.int32), [-1])

        batch_size = tf.shape(real_images)[0]

        ########################### Train the Discriminator ###########################
        for _ in range(self.train_ratio):
            random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
            fake_labels = self._sample_labels(batch_size)
            # Shift each true label by a random non-zero offset so the "wrong"
            # label never coincides with the true one.
            label_offsets = tf.random.uniform((batch_size,), 1, n_classes, dtype=tf.int32)
            wrong_labels = (labels + label_offsets) % n_classes
            with tf.GradientTape() as tape:
                fake_images = self.generator([random_latent_vectors, fake_labels], training=True)
                fake_logits = self.discriminator([fake_images, fake_labels], training=True)
                real_logits = self.discriminator([real_images, labels], training=True)
                # Real images paired with a mismatching label should be rejected.
                wrong_label_logits = self.discriminator([real_images, wrong_labels], training=True)

                d_cost = self.d_loss_fn(real_logits=real_logits, fake_logits=fake_logits,
                                        wrong_label_logits=wrong_label_logits)
                gp = self.gradient_penalty(batch_size, real_images, fake_images, labels)
                d_loss = d_cost + gp * self.gp_weight

            d_gradient = tape.gradient(d_loss, self.discriminator.trainable_variables)
            self.d_optimizer.apply_gradients(
                zip(d_gradient, self.discriminator.trainable_variables)
            )

        ########################### Train the Generator ###########################
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        fake_labels = self._sample_labels(batch_size)
        with tf.GradientTape() as tape:
            generated_images = self.generator([random_latent_vectors, fake_labels], training=True)
            gen_img_logits = self.discriminator([generated_images, fake_labels], training=True)
            g_loss = self.g_loss_fn(gen_img_logits)

        gen_gradient = tape.gradient(g_loss, self.generator.trainable_variables)
        self.g_optimizer.apply_gradients(
            zip(gen_gradient, self.generator.trainable_variables)
        )
        return {"d_loss": d_loss, "g_loss": g_loss}

In [None]:
# Separate Adam instances for the generator and the discriminator.
# learning_rate=0.0002, beta_1=0.5, beta_2=0.9 are recommended.
ADAM_KWARGS = dict(learning_rate=0.0002, beta_1=0.5, beta_2=0.9)
generator_optimizer = Adam(**ADAM_KWARGS)
discriminator_optimizer = Adam(**ADAM_KWARGS)

In [None]:
def lecam_reg(dis_real, dis_fake, ema_real=None, ema_fake=None):
    """LeCam regularization term for the discriminator.

        reg = mean(relu(dis_real - anchor_fake)^2) + mean(relu(anchor_real - dis_fake)^2)

    In LeCam regularization the anchors are exponential moving averages of the
    discriminator outputs on real and fake samples. Pass them as ``ema_real`` /
    ``ema_fake`` (scalars or tensors broadcastable to the logits).

    When an anchor is omitted, the current outputs of that kind are used
    (``anchor_fake = dis_fake``, ``anchor_real = dis_real``). With both omitted
    the result is ``2 * mean(relu(dis_real - dis_fake)^2)``, which is what this
    function computed before the anchors were added.

    Args:
        dis_real: discriminator logits on real images.
        dis_fake: discriminator logits on generated images.
        ema_real: optional anchor for the real logits.
        ema_fake: optional anchor for the fake logits.

    Returns:
        Scalar regularization tensor.
    """
    anchor_fake = dis_fake if ema_fake is None else ema_fake
    anchor_real = dis_real if ema_real is None else ema_real
    reg = tf.reduce_mean(tf.square(tf.nn.relu(dis_real - anchor_fake))) + \
          tf.reduce_mean(tf.square(tf.nn.relu(anchor_real - dis_fake)))
    return reg

In [None]:
# We refer to the DRAGAN loss function. https://github.com/kodalinaveen3/DRAGAN
# Define the loss functions to be used for discrimiator
# We will add the gradient penalty later to this loss function
def discriminator_loss(real_logits, fake_logits, wrong_label_logits):
    """Sigmoid cross-entropy discriminator loss, before the gradient penalty.

    Real images with their true labels are pushed towards target 1. Generated
    images, and real images paired with a wrong label, are pushed towards 0.

    Args:
        real_logits: logits for (real image, true label) pairs.
        fake_logits: logits for (generated image, sampled label) pairs.
        wrong_label_logits: logits for (real image, wrong label) pairs.

    Returns:
        Scalar loss: sum of the three mean cross-entropies.
    """
    def _mean_bce(logits, targets):
        return tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=targets))

    real_loss = _mean_bce(real_logits, tf.ones_like(real_logits))
    fake_loss = _mean_bce(fake_logits, tf.zeros_like(fake_logits))
    # Targets are shaped from the wrong-label logits themselves so the shapes
    # always agree, independent of fake_logits.
    wrong_label_loss = _mean_bce(wrong_label_logits, tf.zeros_like(wrong_label_logits))
    return wrong_label_loss + fake_loss + real_loss

In [None]:
# Define the loss functions to be used for generator
def generator_loss(fake_logits):
    """Non-saturating generator loss.

    Mean sigmoid cross-entropy that pushes the discriminator's logits on
    generated images towards the "real" target 1.
    """
    real_targets = tf.ones_like(fake_logits)
    per_element = tf.nn.sigmoid_cross_entropy_with_logits(labels=real_targets, logits=fake_logits)
    return tf.reduce_mean(per_element)

In [None]:
# build generator with pretrained decoder and embedding
def generator_label(embedding, decoder):
    """Build the conditional generator from a pre-trained embedding and decoder.

    The latent vector is multiplied by the label embedding and decoded into an
    image. The embedding is deliberately left trainable so it keeps learning
    during GAN training.

    Args:
        embedding: model ``[latent, label] -> labeled latent``.
        decoder: model ``labeled latent -> image``.

    Returns:
        A ``Model`` with inputs ``[latent, label]``.
    """
    label = Input((1,), dtype='int32')
    latent = Input((latent_dim,))
    return Model([latent, label], decoder(embedding([latent, label])))

In [None]:
# Build discriminator with pre-trained Encoder
def build_discriminator(encoder):
    """Build a conditional discriminator whose image branch reuses a pre-trained encoder.

    The encoder is cut at its last ``Flatten`` layer (the flattened 4 x 4 x 256
    conv features), so the feature size matches the ``Dense(4 * 4 * 256)``
    label projection it is multiplied with. The layer is selected by type, not
    by a fixed index, so heads appended after it (dense / mean / log-variance /
    sampling layers) do not shift the cut point.

    Args:
        encoder: functional encoder model containing a ``Flatten`` layer.

    Returns:
        A ``Model`` with inputs ``[img, label]`` and a single-logit output.

    Raises:
        ValueError: if ``encoder`` has no ``Flatten`` layer.
    """
    flatten_layers = [layer for layer in encoder.layers if isinstance(layer, Flatten)]
    if not flatten_layers:
        raise ValueError("build_discriminator: encoder has no Flatten layer to cut at.")

    label = Input((1,), dtype='int32')
    img = Input(img_size)

    inter_output_model = Model(inputs=encoder.input, outputs=flatten_layers[-1].output)
    x = inter_output_model(img)

    # Label branch, projected to the size of the flattened image features.
    le = Flatten()(Embedding(n_classes, 512)(label))
    le = Dense(4 * 4 * 256)(le)
    le = LeakyReLU(0.2)(le)
    x_y = multiply([x, le])
    x_y = Dense(512)(x_y)

    out = Dense(1)(x_y)

    return Model(inputs=[img, label], outputs=out)

In [None]:
# %% ----------------------------------- Compile Models ----------------------------------------------------------------
# d_model = build_discriminator(en)  # initialized with Encoder
d_model = discriminator_cwgan()    # discriminator trained from scratch (no encoder initialization)
g_model = generator_label(em, de)  # generator initialized with the pre-trained Decoder and Embedding

# NOTE(review): discriminator_extra_steps is accepted by BAGAN_GP, but the number
# of discriminator updates per generator update is the global `trainRatio`.
bagan_gp = BAGAN_GP(d_model, g_model, latent_dim, discriminator_extra_steps=3)

# Compile the model
bagan_gp.compile(
    d_optimizer=discriminator_optimizer,
    d_loss_fn=discriminator_loss,
    g_optimizer=generator_optimizer,
    g_loss_fn=generator_loss,
)

In [None]:
# Record the loss
d_loss_history = []
g_loss_history = []

############################# Start training #############################

LEARNING_STEPS = 10
EPOCHS_PER_STEP = 10
PLOT_EVERY = 1  # plot generated samples every PLOT_EVERY learning steps

# `plt_img` is not defined in the cells shown here; it is expected to come from
# an earlier cell. Look it up once so a missing helper gives a warning instead
# of a NameError after the first block of (expensive) training epochs.
plot_samples = globals().get('plt_img')
if plot_samples is None:
    print('WARNING: plt_img is not defined; sample plots will be skipped.')

for learning_step in range(LEARNING_STEPS):
    print('LEARNING STEP # ', learning_step + 1, '-' * 50)
    bagan_gp.fit(x_train, y_train, batch_size=128, epochs=EPOCHS_PER_STEP, verbose=True)
    d_loss_history += bagan_gp.history.history['d_loss']
    g_loss_history += bagan_gp.history.history['g_loss']
    if plot_samples is not None and (learning_step + 1) % PLOT_EVERY == 0:
        plot_samples(bagan_gp.generator, learning_step)
    #     bagan_gp.discriminator.save_weights('model_full_data_ciffar10/discriminator_weight_epoch' + str(learning_step) + '.h5')
    #     bagan_gp.generator.save_weights('model_full_data_ciffar10/generator_weight_epoch' + str(learning_step) + '.h5')


In [None]:
# Persist the trained discriminator and generator (discriminator first).
for net, path in ((bagan_gp.discriminator, 'dis_imba_bagan_gp_ep10_no_lc.h5'),
                  (bagan_gp.generator, 'gen_imba_bagan_gp_ep10_no_lc.h5')):
    net.save(path)

In [None]:
# d_model = discriminator_cwgan()  # without initialization
# g_model = generator_label(em, de)  # initialized with Decoder and Embedding

# d_model.load_weights('/content/discriminator_weight_step9.h5')
# g_model.load_weights('/content/generator_weight_step9.h5')

# retrain_bagan_gp = BAGAN_GP(
#     discriminator=d_model,
#     generator=g_model,
#     latent_dim=latent_dim,
#     discriminator_extra_steps=3,
# )

# # Compile the model
# retrain_bagan_gp.compile(
#     d_optimizer=discriminator_optimizer,
#     g_optimizer=generator_optimizer,
#     g_loss_fn=generator_loss,
#     d_loss_fn=discriminator_loss,
# )

# # # Record the loss
# # d_loss_history = []
# # g_loss_history = []



# # ############################# Start training #############################
# # LEARNING_STEPS = 15
# # for learning_step in range(LEARNING_STEPS):
# #     print('LEARNING STEP # ', learning_step + 1, '-' * 50)
# #     retrain_bagan_gp.fit(x_train, y_train, batch_size=128, epochs=10, verbose=True)
# #     d_loss_history += retrain_bagan_gp.history.history['d_loss']
# #     g_loss_history += retrain_bagan_gp.history.history['g_loss']
# #     if (learning_step+1)%1 == 0:
# #         plt_img(retrain_bagan_gp.generator, learning_step+15)
# #         retrain_bagan_gp.discriminator.save_weights('model_full_data_ciffar10/discriminator_weight_epoch' + str(learning_step+15) + '.h5')
# #         retrain_bagan_gp.generator.save_weights('model_full_data_ciffar10/generator_weight_epoch' + str(learning_step + 15) + '.h5')


In [None]:
# NOTE(review): `get_fid_score` is not defined in the cells shown here; it must
# come from an earlier/removed cell or an imported module — confirm before a
# fresh-kernel run. Given how the result is written to CSV below (columns
# 'Class', 'FID'), it presumably returns per-class (class, FID) rows.
list_fid_scores = get_fid_score(x_test, y_test,  bagan_gp, n_classes, latent_dim)

In [None]:
# get_fid_score(x_test, y_test,  bagan_gp, n_classes, latent_dim)

In [None]:
import pandas as pd

# Write the FID scores (columns: Class, FID) to CSV without the index column.
fid_table = pd.DataFrame(list_fid_scores, columns=['Class', 'FID'])
fid_table.to_csv('imba_cifar10_ep10_fid_no_lc.csv', index=False)