In [1]:
import numpy as np
import tensorflow as tf

In [None]:
# roadmap:
# 1-implement dcgan in tf and observe generated examples
# 2-investigate tf dcgan implementation

# 3-implement dcgan in torch and observed generated examples
# 4-investigate torch dcgan implementation

In [2]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# [0,255] -> [0,1] -> [-1,1]
x_train = (x_train/255.) * 2. - 1.

x_train = np.expand_dims(x_train,axis=3)
x_train = tf.cast(x_train,dtype=tf.float32)

train_ds = tf.data.Dataset.from_tensor_slices(x_train)
train_ds = train_ds.shuffle(1000).batch(32)
train_ds = train_ds.prefetch(tf.data.AUTOTUNE)

2022-04-30 01:15:24.829208: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
# GENERATOR
# use batchnorm and dropout on G !
inputs = tf.keras.Input(shape=(100,))
x = tf.keras.layers.Dense(units=(4*4*256),activation='relu')(inputs)
x = tf.reshape(x,[-1,4,4,256])
x = tf.keras.layers.Conv2DTranspose(filters=128,kernel_size=5,strides=2,activation='relu',padding='same')(x)
x = tf.keras.layers.Conv2DTranspose(filters=64,kernel_size=7,strides=1,activation='relu',padding='valid')(x)
outputs = tf.keras.layers.Conv2DTranspose(filters=1,kernel_size=5,strides=2,activation='tanh',padding='same')(x)
Generator = tf.keras.Model(inputs=inputs, outputs=[outputs])

# DISCRIMINATOR
# use batchnorm on D !
inputs = tf.keras.Input(shape=(28,28,1))
x = tf.keras.layers.Conv2D(filters=64,kernel_size=5,strides=2)(inputs)
x = tf.keras.layers.LeakyReLU(alpha=0.2)(x)
x = tf.keras.layers.Conv2D(filters=128,kernel_size=5,strides=2)(x)
x = tf.keras.layers.LeakyReLU(alpha=0.2)(x)
x = tf.keras.layers.Flatten()(x)
outputs = tf.keras.layers.Dense(units=1)(x)
Discriminator = tf.keras.Model(inputs=inputs, outputs=[outputs])


In [5]:
loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True,label_smoothing=0.0)

gen_optimizer = tf.keras.optimizers.Adam(learning_rate = 0.001)
disc_optimizer = tf.keras.optimizers.Adam(learning_rate = 0.001)

disc_loss_tracker = tf.keras.metrics.Mean(name='disc_loss')
gen_loss_tracker = tf.keras.metrics.Mean(name='gen_loss')

# tensorboard show generated images per epoch
# add other callbacks

epochs = 5
latent_code_size = 100
for epoch in range(epochs):

    for _,real_imgs in train_ds.enumerate():
        
        # PART 1: DISC TRAINING, fixed generator
        latent_code = tf.random.normal(shape=(real_imgs.shape[0],latent_code_size))

        with tf.GradientTape() as disc_tape:
            # generate fake images
            generated_imgs = Generator(latent_code)

            # forward pass real and fake images
            real_preds,fake_preds = Discriminator(real_imgs),Discriminator(generated_imgs)
            y_pred = tf.concat([real_preds,fake_preds],axis=0)
            y_true = tf.concat([tf.ones_like(real_preds),tf.zeros_like(fake_preds)],axis=0)
            
            # compute loss
            disc_loss = loss_fn(y_true=y_true,y_pred=y_pred)

        # compute disc gradients
        disc_gradients = disc_tape.gradient(disc_loss,Discriminator.variables)

        # update disc weights
        disc_optimizer.apply_gradients(zip(disc_gradients, Discriminator.variables))

        # update disc metrics
        disc_loss_tracker.update_state(disc_loss)


        # PART 2: GEN TRAINING, fixed discriminator
        latent_code = tf.random.normal(shape=(real_imgs.shape[0],latent_code_size))

        with tf.GradientTape() as gen_tape:
            # generate fake images
            generated_imgs = Generator(latent_code)

            # forward pass only images
            fake_preds = Discriminator(generated_imgs)

            # compute loss
            gen_loss = loss_fn(y_true=tf.ones_like(fake_preds),y_pred=fake_preds)

        # compute gen gradients
        gen_gradients = gen_tape.gradient(gen_loss,Generator.variables)

        # update gen weights
        gen_optimizer.apply_gradients(zip(gen_gradients, Generator.variables))

        # update gen metrics
        gen_loss_tracker.update_state(gen_loss)


    # display metrics at the end of each epoch.
    disc_loss,gen_loss = disc_loss_tracker.result(),gen_loss_tracker.result()
    print(f'epoch: {epoch}, disc_loss: {disc_loss:.4f}, gen_loss: {gen_loss:.4f}')

    # reset metric states
    disc_loss_tracker.reset_state()
    gen_loss_tracker.reset_state()

epoch: 0, disc_loss: 0.0029, gen_loss: 11.3791


KeyboardInterrupt: 