In [1]:
import os

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

from tensorflow import keras
from tensorflow.keras import layers

In [2]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

BATCH_SIZE = 256
# A buffer as large as the training set gives a full shuffle every epoch
# (a 1000-element buffer only shuffles locally).
SHUFFLE_BUFFER = len(x_train)

# [0,255] -> [-1,1], matching the generator's tanh output range.
# Converting to float32 first avoids the float64 copy that `x_train / 255.` creates.
x_train = x_train.astype(np.float32) / 127.5 - 1.0

# (N, 28, 28) -> (N, 28, 28, 1): add the channel axis the conv layers expect.
x_train = x_train[..., np.newaxis]

# Each element is an (image, label) pair; labels condition both GAN models.
train_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(SHUFFLE_BUFFER)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

2022-05-19 16:23:12.227756: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [12]:
class Conditional_Generator(tf.keras.Model):
    """Label-conditioned generator that turns (latent code, class label) into a 1-channel image.

    The latent code is projected and reshaped to a 7x7x128 feature map. The label is
    embedded (30-d), projected to 49 units and reshaped to a single 7x7 channel.
    The two maps are stacked on the channel axis and upsampled 7 -> 14 -> 28 by two
    stride-2 transposed convolutions. A final tanh transposed convolution emits one
    channel with values in [-1, 1].

    NOTE(review): `label` is presumably a (batch,) vector of integer class ids in
    [0, 10) — confirm against callers.
    """

    def __init__(self):
        super().__init__()

        # Attribute order is kept as-is: Keras tracks layers (and thus weight order
        # and auto-generated layer names) in assignment order.

        # latent-code branch
        self.main_linear1 = layers.Dense(7 * 7 * 128)
        self.main_reshape = layers.Reshape((7, 7, 128))
        self.main_conv2d_tr1 = layers.Conv2DTranspose(
            128, kernel_size=(4, 4), strides=(2, 2), padding="same"
        )
        self.main_leaky1 = layers.LeakyReLU(alpha=0.2)
        self.main_conv2d_tr2 = layers.Conv2DTranspose(
            128, kernel_size=(4, 4), strides=(2, 2), padding="same"
        )
        self.main_leaky2 = layers.LeakyReLU(alpha=0.2)
        self.main_conv2d_tr3 = layers.Conv2DTranspose(
            1, kernel_size=(7, 7), padding="same", activation="tanh"
        )

        # label branch: class id -> 30-d embedding -> 49 units -> 7x7x1 map
        self.label_emb = layers.Embedding(input_dim=10, output_dim=30)
        self.label_linear1 = layers.Dense(units=49)
        self.label_reshape = layers.Reshape(target_shape=(7, 7, -1))

        self.concat_layer = layers.Concatenate(axis=-1)

    def call(self, latent_code, label):
        """Generate images for `latent_code` conditioned on `label`."""
        # Label branch runs first, as before, so weights are built in the same order.
        embedded_label = self.label_emb(label)
        cond_channel = self.label_reshape(self.label_linear1(embedded_label))

        # latent code -> 7x7x128 feature map
        projected = self.main_linear1(latent_code)
        feature_map = self.main_reshape(projected)

        # image features first, conditioning channel last
        x = self.concat_layer([feature_map, cond_channel])

        # two upsampling stages (7 -> 14 -> 28), each followed by LeakyReLU
        upsampling_stages = (
            (self.main_conv2d_tr1, self.main_leaky1),
            (self.main_conv2d_tr2, self.main_leaky2),
        )
        for upsample, activate in upsampling_stages:
            x = activate(upsample(x))

        # project to one tanh channel
        return self.main_conv2d_tr3(x)

In [25]:
# Scratch check of the per-epoch summary format used by the training loop.
print('epoch: {}, disc_loss: {:.4f}, gen_loss: {:.4f}'.format(2, 3, 6))

epoch: 2, disc_loss: 3.0000, gen_loss: 6.0000


In [None]:
class Conditional_Discriminator(tf.keras.Model):
    """Label-conditioned discriminator that scores (image, class label) pairs.

    The label is embedded (50-d), projected to 784 units and reshaped to a single
    28x28 channel, which is stacked onto the input image. Three stride-2
    convolutions with LeakyReLU downsample the result. Global max pooling collapses
    the spatial axes, and a Dense(1) with no activation returns one raw score
    (logit) per example.

    NOTE(review): `img` must be (batch, 28, 28, C) so the 28x28 label map can be
    concatenated; `label` is presumably (batch,) integer class ids in [0, 10).
    """

    def __init__(self):
        super().__init__()

        # Attribute order is kept as-is: Keras tracks layers (weight order, layer
        # names) in assignment order.

        # image branch: three stride-2 conv + LeakyReLU stages, then pool and score
        self.main_conv2d_1 = layers.Conv2D(64, kernel_size=4, strides=2, padding="same")
        self.main_leaky_1 = layers.LeakyReLU(alpha=0.2)
        self.main_conv2d_2 = layers.Conv2D(128, kernel_size=4, strides=2, padding="same")
        self.main_leaky_2 = layers.LeakyReLU(alpha=0.2)
        self.main_conv2d_3 = layers.Conv2D(128, kernel_size=4, strides=2, padding="same")
        self.main_leaky_3 = layers.LeakyReLU(alpha=0.2)
        self.main_maxpool = layers.GlobalMaxPooling2D()
        self.main_linear = layers.Dense(1)

        # label branch: class id -> 50-d embedding -> 784 units -> 28x28x1 map
        self.label_emb = layers.Embedding(input_dim=10, output_dim=50)
        self.label_linear1 = layers.Dense(units=784)
        self.label_reshape = layers.Reshape(target_shape=(28, 28, -1))

        self.concat_layer = layers.Concatenate(axis=-1)

    def call(self, img, label):
        """Return one unnormalised real/fake score per (img, label) pair."""
        # label -> one extra 28x28 input channel
        cond_channel = self.label_reshape(self.label_linear1(self.label_emb(label)))

        # image channels first, conditioning channel last
        x = self.concat_layer([img, cond_channel])

        # downsample through three conv stages
        conv_stages = (
            (self.main_conv2d_1, self.main_leaky_1),
            (self.main_conv2d_2, self.main_leaky_2),
            (self.main_conv2d_3, self.main_leaky_3),
        )
        for conv, activate in conv_stages:
            x = activate(conv(x))

        # collapse spatial axes and score
        pooled = self.main_maxpool(x)
        return self.main_linear(pooled)

array([0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9, 2, 3,
       4, 5, 6])

In [17]:
# Smoke test: pull one batch and check both models run and agree on shapes.
b = next(iter(train_ds))
real_imgs, real_labels = b

gen = Conditional_Generator()
disc = Conditional_Discriminator()

# Exercise the generator with the latent code (previously computed but unused).
latent_code = tf.random.normal(shape=(real_imgs.shape[0], 128))
fake_imgs = gen(latent_code, real_labels)
assert fake_imgs.shape == real_imgs.shape, (fake_imgs.shape, real_imgs.shape)

disc_scores = disc(real_imgs, real_labels)
assert disc_scores.shape == (real_imgs.shape[0], 1), disc_scores.shape
disc_scores

<tf.Tensor: shape=(256, 1), dtype=float32, numpy=
array([[-0.19739182],
       [-0.24189803],
       [-0.19816041],
       [-0.20420948],
       [-0.20160657],
       [-0.17229977],
       [-0.22877547],
       [-0.2351267 ],
       [-0.19057983],
       [-0.1843774 ],
       [-0.21497685],
       [-0.18236613],
       [-0.20547935],
       [-0.22799948],
       [-0.23009518],
       [-0.24682301],
       [-0.18350708],
       [-0.21895677],
       [-0.20812657],
       [-0.17855173],
       [-0.18525246],
       [-0.21598698],
       [-0.2046971 ],
       [-0.11593517],
       [-0.16629311],
       [-0.24558498],
       [-0.22924101],
       [-0.24169654],
       [-0.17694716],
       [-0.19483742],
       [-0.18677871],
       [-0.18714418],
       [-0.1756235 ],
       [-0.22283612],
       [-0.20114493],
       [-0.25978965],
       [-0.19561923],
       [-0.16660868],
       [-0.18597284],
       [-0.20292082],
       [-0.1996056 ],
       [-0.139725  ],
       [-0.24012002],
    

In [None]:
# loss function
# label_smoothing=0.1 moves every target toward 0.5 (1 -> 0.95, 0 -> 0.05);
# presumably the "lbl01" in the experiment name below.
loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True, label_smoothing=0.1)

# models — built here so this cell does not depend on the scratch smoke-test cell
generator = Conditional_Generator()
discriminator = Conditional_Discriminator()

# optimizers
gen_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0004)
disc_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0003)

# metrics
disc_loss_tracker = tf.keras.metrics.Mean(name='disc_loss')
gen_loss_tracker = tf.keras.metrics.Mean(name='gen_loss')

# tensorboard / output locations
experiment_name = 'lbl01_glr0004_dlr0003_g13m_d390k'
log_dir = '../logs/' + experiment_name
img_save_dir = '../generated_imgs/' + experiment_name
os.makedirs(img_save_dir, exist_ok=True)  # savefig fails if the directory is missing
summary_writer = tf.summary.create_file_writer(log_dir)

latent_code_size = 128
num_vis_samples = 25  # rendered as a 5x5 grid below
# Fix latent codes AND labels so per-epoch samples are directly comparable.
# Labels cycle through the 10 classes.
latent_code4visualization = tf.random.normal(shape=(num_vis_samples, latent_code_size))
labels4visualization = tf.constant(np.arange(num_vis_samples) % 10)
epochs = 30


def disc_train_step(real_imgs, real_labels):
    """Update the discriminator once with the generator held fixed; return the loss."""
    batch_size = tf.shape(real_imgs)[0]
    latent_code = tf.random.normal(shape=(batch_size, latent_code_size))
    # The generator is fixed in this step, so its forward pass stays off the tape.
    generated_imgs = generator(latent_code, real_labels)

    with tf.GradientTape() as disc_tape:
        # Fakes are conditioned on the same labels as the real batch.
        real_preds = discriminator(real_imgs, real_labels)
        fake_preds = discriminator(generated_imgs, real_labels)
        y_pred = tf.concat([real_preds, fake_preds], axis=0)
        y_true = tf.concat([tf.ones_like(real_preds), tf.zeros_like(fake_preds)], axis=0)
        disc_loss = loss_fn(y_true=y_true, y_pred=y_pred)

    disc_gradients = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    disc_optimizer.apply_gradients(zip(disc_gradients, discriminator.trainable_variables))
    return disc_loss


def gen_train_step(batch_labels):
    """Update the generator once with the discriminator held fixed; return the loss."""
    batch_size = tf.shape(batch_labels)[0]
    latent_code = tf.random.normal(shape=(batch_size, latent_code_size))

    with tf.GradientTape() as gen_tape:
        generated_imgs = generator(latent_code, batch_labels)
        fake_preds = discriminator(generated_imgs, batch_labels)
        # The generator is rewarded when its images are scored as real.
        gen_loss = loss_fn(y_true=tf.ones_like(fake_preds), y_pred=fake_preds)

    gen_gradients = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gen_optimizer.apply_gradients(zip(gen_gradients, generator.trainable_variables))
    return gen_loss


def save_samples(epoch):
    """Render the fixed visualization batch, save a 5x5 PNG grid, return images in [0, 1]."""
    imgs = generator(latent_code4visualization, labels4visualization)
    # tanh output [-1, 1] -> [0, 1]; tf.summary.image expects float images in [0, 1]
    # (values in [0, 255] would saturate to white).
    imgs = ((imgs + 1.) / 2.).numpy()

    fig, axes = plt.subplots(5, 5, figsize=(5, 5))
    for ax, img, label in zip(axes.flat, imgs, labels4visualization.numpy()):
        ax.imshow(img[:, :, 0], cmap='gray', vmin=0., vmax=1.)
        ax.set_title(str(label), fontsize=8)
        ax.axis('off')
    # One file per epoch instead of overwriting a single path.
    fig.savefig(os.path.join(img_save_dir, f'epoch_{epoch:03d}.png'))
    plt.close(fig)  # otherwise one open figure accumulates per epoch
    return imgs


for epoch in range(epochs):

    # Each dataset element is an (images, labels) batch.
    for real_imgs, real_labels in train_ds:
        # PART 1: DISC TRAINING, fixed generator
        disc_loss_tracker.update_state(disc_train_step(real_imgs, real_labels))

        # PART 2: GEN TRAINING, fixed discriminator
        gen_loss_tracker.update_state(gen_train_step(real_labels))

    # generate and save sample images per epoch
    test_generated_imgs = save_samples(epoch)

    # record metrics and samples at the end of each epoch
    with summary_writer.as_default():
        tf.summary.scalar('disc_loss', disc_loss_tracker.result(), step=epoch)
        tf.summary.scalar('gen_loss', gen_loss_tracker.result(), step=epoch)
        tf.summary.image(name='test_samples', data=test_generated_imgs,
                         max_outputs=test_generated_imgs.shape[0], step=epoch)

    disc_loss = float(disc_loss_tracker.result())
    gen_loss = float(gen_loss_tracker.result())
    print(f'epoch: {epoch}, disc_loss: {disc_loss:.4f}, gen_loss: {gen_loss:.4f}')

    # reset metric states
    disc_loss_tracker.reset_state()
    gen_loss_tracker.reset_state()