In [1]:
import os

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

In [3]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

BATCH_SIZE = 256
# Shuffle over the whole training set: a small buffer (previously 1000) only
# mixes examples that are close together in the original ordering.
SHUFFLE_BUFFER = len(x_train)

# [0,255] uint8 -> [-1,1] float32, matching the generator's tanh output range.
# Casting to float32 *before* the arithmetic avoids a full float64 copy of the
# training set.
x_train = x_train.astype(np.float32) / 127.5 - 1.
# add a trailing channel axis: (N, 28, 28) -> (N, 28, 28, 1)
x_train = np.expand_dims(x_train, axis=-1)
assert x_train.min() >= -1. and x_train.max() <= 1., 'pixel scaling out of [-1, 1]'
# keep x_train as a float32 tf.Tensor, as in the original cell
x_train = tf.convert_to_tensor(x_train, dtype=tf.float32)

train_ds = (
    tf.data.Dataset.from_tensor_slices(x_train)
    .shuffle(SHUFFLE_BUFFER)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

2022-04-30 23:32:33.714503: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [1]:
from tensorflow import keras
from tensorflow.keras import layers

# GENERATOR: latent vector of size 128 -> single-channel 28x28 image.
# A dense layer produces a 7x7x128 feature map; two stride-2 transposed
# convolutions upsample it 7 -> 14 -> 28, and the final tanh keeps pixel
# values in [-1, 1], the same range the training images are scaled to.
Generator = keras.Sequential(name="generator")
Generator.add(keras.Input(shape=(128,)))
Generator.add(layers.Dense(7 * 7 * 128))
Generator.add(layers.Reshape((7, 7, 128)))
Generator.add(layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"))
Generator.add(layers.LeakyReLU(alpha=0.2))
Generator.add(layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"))
Generator.add(layers.LeakyReLU(alpha=0.2))
Generator.add(layers.Conv2D(1, (7, 7), padding="same", activation="tanh"))

# DISCRIMINATOR: 28x28x1 image -> one unnormalized score.
# Three stride-2 convolutions downsample 28 -> 14 -> 7 -> 4, global max pooling
# collapses the spatial axes, and Dense(1) has no activation, so the output is
# a raw logit.
Discriminator = keras.Sequential(name="discriminator")
Discriminator.add(keras.Input(shape=(28, 28, 1)))
for _n_filters in (64, 128, 128):
    Discriminator.add(layers.Conv2D(_n_filters, (4, 4), strides=(2, 2), padding="same"))
    Discriminator.add(layers.LeakyReLU(alpha=0.2))
del _n_filters  # don't leave the loop variable in the notebook namespace
Discriminator.add(layers.GlobalMaxPooling2D())
Discriminator.add(layers.Dense(1))

2022-05-19 11:01:16.646889: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [4]:
# Print the discriminator's layer-by-layer output shapes and parameter counts.
Discriminator.summary()

Model: "discriminator"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_1 (Conv2D)           (None, 14, 14, 64)        1088      
                                                                 
 leaky_re_lu_2 (LeakyReLU)   (None, 14, 14, 64)        0         
                                                                 
 conv2d_2 (Conv2D)           (None, 7, 7, 128)         131200    
                                                                 
 leaky_re_lu_3 (LeakyReLU)   (None, 7, 7, 128)         0         
                                                                 
 conv2d_3 (Conv2D)           (None, 4, 4, 128)         262272    
                                                                 
 leaky_re_lu_4 (LeakyReLU)   (None, 4, 4, 128)         0         
                                                                 
 global_max_pooling2d (Globa  (None, 128)            

In [None]:
# loss function
# NOTE: Keras applies label_smoothing to *both* classes (targets 1 -> 0.95 and
# 0 -> 0.05), and the generator's "real" targets below are smoothed as well.
# That is two-sided smoothing; GAN literature usually recommends smoothing only
# the real labels of the discriminator.
loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True, label_smoothing=0.1)

# optimizers
gen_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0004)
disc_optimizer = tf.keras.optimizers.Adam(learning_rate=0.0003)

# metrics: running means of the per-batch losses, reset after every epoch
disc_loss_tracker = tf.keras.metrics.Mean(name='disc_loss')
gen_loss_tracker = tf.keras.metrics.Mean(name='gen_loss')

# tensorboard / image output
# The name appears to encode: label smoothing 0.1, gen lr 4e-4, disc lr 3e-4,
# and rough generator/discriminator parameter counts.
experiment_name = 'lbl01_glr0004_dlr0003_g13m_d390k'
log_dir = '../logs/' + experiment_name
img_save_dir = '../generated_imgs/' + experiment_name
# savefig does not create missing directories, so create it up front.
os.makedirs(img_save_dir, exist_ok=True)
summary_writer = tf.summary.create_file_writer(log_dir)

latent_code_size = 128
assert Generator.input_shape[-1] == latent_code_size, 'latent_code_size must match the Generator Input'

vis_grid_side = 5  # samples are shown as a vis_grid_side x vis_grid_side grid
# fixed latent codes so the same samples can be compared across epochs
latent_code4visualization = tf.random.normal(shape=(vis_grid_side * vis_grid_side, latent_code_size))
epochs = 30


def disc_train_step(real_imgs):
    """One discriminator update on a batch of real images; the generator is fixed.

    Returns the scalar discriminator loss for the batch.
    """
    batch_size = tf.shape(real_imgs)[0]
    latent_code = tf.random.normal(shape=(batch_size, latent_code_size))
    # Only discriminator weights are differentiated here, so the generator
    # forward pass does not need to be recorded on the tape.
    generated_imgs = Generator(latent_code)

    with tf.GradientTape() as disc_tape:
        real_preds = Discriminator(real_imgs)
        fake_preds = Discriminator(generated_imgs)
        y_pred = tf.concat([real_preds, fake_preds], axis=0)
        y_true = tf.concat([tf.ones_like(real_preds), tf.zeros_like(fake_preds)], axis=0)
        disc_loss = loss_fn(y_true=y_true, y_pred=y_pred)

    disc_gradients = disc_tape.gradient(disc_loss, Discriminator.trainable_variables)
    disc_optimizer.apply_gradients(zip(disc_gradients, Discriminator.trainable_variables))
    return disc_loss


def gen_train_step(batch_size):
    """One generator update on `batch_size` fresh latent codes; the discriminator is fixed.

    The generator is trained to make the discriminator label its samples as real.
    Returns the scalar generator loss for the batch.
    """
    latent_code = tf.random.normal(shape=(batch_size, latent_code_size))

    with tf.GradientTape() as gen_tape:
        fake_preds = Discriminator(Generator(latent_code))
        gen_loss = loss_fn(y_true=tf.ones_like(fake_preds), y_pred=fake_preds)

    gen_gradients = gen_tape.gradient(gen_loss, Generator.trainable_variables)
    gen_optimizer.apply_gradients(zip(gen_gradients, Generator.trainable_variables))
    return gen_loss


def save_sample_grid(epoch):
    """Render the fixed latent codes, save them as one PNG per epoch, return them.

    Returns a float32 numpy array of images scaled to [0, 1].
    """
    # tanh output in [-1, 1] -> [0, 1]. tf.summary.image clips float images to
    # [0, 1], so logging [0, 255] floats (as before) produced all-white samples.
    samples = ((Generator(latent_code4visualization) + 1.) / 2.).numpy()

    fig, axes = plt.subplots(vis_grid_side, vis_grid_side, figsize=(5, 5))
    for ax, img in zip(axes.flat, samples):
        # fixed vmin/vmax so brightness is comparable across images and epochs
        ax.imshow(img[:, :, 0], cmap='gray', vmin=0., vmax=1.)
        ax.axis('off')
    fig.savefig(os.path.join(img_save_dir, f'epoch_{epoch:03d}.png'))
    plt.close(fig)  # avoid accumulating one open figure per epoch
    return samples


for epoch in range(epochs):
    for real_imgs in train_ds:
        # PART 1: DISC TRAINING, fixed generator
        disc_loss_tracker.update_state(disc_train_step(real_imgs))
        # PART 2: GEN TRAINING, fixed discriminator (same batch size as the real batch)
        gen_loss_tracker.update_state(gen_train_step(tf.shape(real_imgs)[0]))

    samples = save_sample_grid(epoch)
    epoch_disc_loss = float(disc_loss_tracker.result())
    epoch_gen_loss = float(gen_loss_tracker.result())

    # record metrics and samples at the end of each epoch
    with summary_writer.as_default():
        tf.summary.scalar('disc_loss', epoch_disc_loss, step=epoch)
        tf.summary.scalar('gen_loss', epoch_gen_loss, step=epoch)
        tf.summary.image(name='test_samples', data=samples, max_outputs=samples.shape[0], step=epoch)

    print(f'epoch: {epoch}, disc_loss: {epoch_disc_loss:.4f}, gen_loss: {epoch_gen_loss:.4f}')

    # reset metric states
    disc_loss_tracker.reset_state()
    gen_loss_tracker.reset_state()