In [1]:
import tensorflow as tf
import numpy as np
import util
import matplotlib.pyplot as plt
import io
from PIL import Image
import os
import cytoolz as cz
import sys
from tqdm import tqdm_notebook as tqdm


from callbacks.GanSummary import GanSummary, GanSummary2
from callbacks import Callbacks

In [2]:
# Load Kaggle MNIST: training/test features and labels from the project's util module.
(x_train, y_train,
 x_test, y_test) = util.getKaggleMNIST()

In [3]:
def build_discriminator(input_shape, first_filters=2, dropout_rate=0.3):
    """Build a small convolutional discriminator with a single sigmoid output.

    Args:
        input_shape: shape of one input image, e.g. ``(28, 28, 1)``.
        first_filters: number of filters in the first 5x5 conv layer. The
            default of 2 reproduces the original architecture.
            NOTE(review): 2 filters is unusually narrow next to the 64 filters
            of the second conv layer -- possibly a typo for 64; confirm.
        dropout_rate: dropout rate applied after each conv stage (default 0.3,
            as in the original).

    Returns:
        An uncompiled ``tf.keras.Sequential`` model named 'discriminator'
        ending in a single sigmoid unit.
    """
    layers = tf.keras.layers
    net = tf.keras.Sequential([
        # Stage 1: strided conv halves the spatial size (28 -> 14 for MNIST).
        layers.Conv2D(first_filters, (5, 5),
                      strides=(2, 2), padding='same',
                      activation=tf.nn.leaky_relu,
                      input_shape=input_shape),
        layers.Dropout(dropout_rate),

        # Stage 2: strided conv + batch norm halves it again (14 -> 7).
        layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same'),
        layers.BatchNormalization(),
        layers.LeakyReLU(),
        layers.Dropout(dropout_rate),

        # Dense classifier head.
        layers.Flatten(),
        layers.Dense(1024),
        layers.BatchNormalization(),
        layers.LeakyReLU(),
        layers.Dense(1, activation='sigmoid'),
    ], name='discriminator')
    return net

In [4]:
# Build the helper-defined discriminator for 28x28 single-channel images and
# print its layer table.
# NOTE(review): a later cell rebinds `discriminator` to a different inline
# architecture, so this model is not the one that gets trained.
discriminator = build_discriminator(input_shape=(28, 28, 1))
discriminator.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d (Conv2D)              (None, 14, 14, 2)         52        
_________________________________________________________________
dropout (Dropout)            (None, 14, 14, 2)         0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 7, 7, 64)          3264      
_________________________________________________________________
batch_normalization (BatchNo (None, 7, 7, 64)          256       
_________________________________________________________________
leaky_re_lu (LeakyReLU)      (None, 7, 7, 64)          0         
_________________________________________________________________
dropout_1 (Dropout)          (None, 7, 7, 64)          0         
_________________________________________________________________
flatten (Flatten)            (None, 3136)              0         
__________

In [None]:
def build_generator(input_shape, output_activation="sigmoid"):
    """Build a generator that upsamples a latent vector to a 28x28x1 image.

    Args:
        input_shape: shape of the latent input, e.g. ``(latent_dim,)``.
        output_activation: activation of the final transposed conv layer
            (defaults to "sigmoid").

    Returns:
        An uncompiled ``tf.keras.Sequential`` model named 'generator'.
    """
    layers = tf.keras.layers
    stack = [
        # Project the latent vector up, then reshape into a 7x7x128 feature map.
        layers.Dense(1024, activation=tf.nn.leaky_relu,
                     input_shape=input_shape),
        layers.Dense(7 * 7 * 128, activation=tf.nn.leaky_relu),
        layers.Reshape((7, 7, 128)),

        # Upsample 7x7 -> 14x14.
        layers.Conv2DTranspose(128, (5, 5), strides=(2, 2), padding='same',
                               use_bias=True),
        layers.BatchNormalization(),
        layers.LeakyReLU(),

        # Upsample 14x14 -> 28x28 with a single output channel.
        layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same',
                               use_bias=True, activation=output_activation),
    ]
    return tf.keras.Sequential(stack, name='generator')

In [None]:
# Training configuration.
epochs = 150
batch_size = 32
plot_data = 10          # number of fixed latent vectors sampled for GanSummary2
latent_dim = 100        # length of the generator's input noise vector
data_length = x_train.shape[0]
learning_rate = 0.0001

# Seed NumPy (noise, batch order) and TensorFlow's graph-level seed so that
# runs are more repeatable. GPU kernels may still be nondeterministic.
random_seed = 42
np.random.seed(random_seed)
tf.set_random_seed(random_seed)

In [None]:
# Build the helper-defined generator (latent_dim -> 28x28x1 image) and print
# its layer table.
# NOTE(review): a later cell rebinds `generator` to a different inline
# architecture, so this model is not the one that gets trained.
generator = build_generator(input_shape=(latent_dim,))
generator.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_2 (Dense)              (None, 1024)              103424    
_________________________________________________________________
dense_3 (Dense)              (None, 6272)              6428800   
_________________________________________________________________
reshape (Reshape)            (None, 7, 7, 128)         0         
_________________________________________________________________
conv2d_transpose (Conv2DTran (None, 14, 14, 128)       409728    
_________________________________________________________________
batch_normalization_2 (Batch (None, 14, 14, 128)       512       
_________________________________________________________________
leaky_re_lu_2 (LeakyReLU)    (None, 14, 14, 128)       0         
_________________________________________________________________
conv2d_transpose_1 (Conv2DTr (None, 28, 28, 1)         3201      
Total para

In [None]:
# Architectures actually used for training. NOTE: this cell rebinds both
# `discriminator` and `generator`, replacing the models built earlier with
# build_discriminator / build_generator.

# Discriminator: two strided conv stages (28 -> 14 -> 7) and a sigmoid unit.
discriminator = tf.keras.Sequential([
    tf.keras.layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                           input_shape=[28, 28, 1]),
    tf.keras.layers.LeakyReLU(),
    tf.keras.layers.Dropout(0.3),

    tf.keras.layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'),
    tf.keras.layers.LeakyReLU(),
    tf.keras.layers.Dropout(0.3),

    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(1, activation='sigmoid')
], name='discriminator')

# Generator: project the noise to 7x7x256, then upsample 7 -> 7 -> 14 -> 28.
# The input size follows `latent_dim` (rather than a hard-coded 100) so it
# stays consistent with the noise sampled in the training loop.
generator = tf.keras.Sequential([
    tf.keras.layers.Dense(7 * 7 * 256, use_bias=False, input_shape=(latent_dim,)),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.LeakyReLU(),

    tf.keras.layers.Reshape((7, 7, 256)),

    tf.keras.layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.LeakyReLU(),

    tf.keras.layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.LeakyReLU(),

    tf.keras.layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='sigmoid')
], name='generator')

In [None]:
def AdversarialLoss(discriminator, eps=1e-5):
    """Build a Keras loss that scores generator output with ``discriminator``.

    The returned loss is the non-saturating generator loss
    ``-log(D(G(z)) + eps)``. It is small when the discriminator gives the
    generated images a high output.

    Args:
        discriminator: model applied to the generated images. Its output is
            presumably in (0, 1) (e.g. a sigmoid unit); the log requires
            non-negative values -- verify for any new discriminator.
        eps: small constant added inside the log to avoid ``log(0)``
            (defaults to 1e-5, the original value).

    Returns:
        A ``model_loss(y_true, y_pred)`` callable for ``Model.compile``.
    """
    def model_loss(y_true, y_pred):
        """Adversarial loss for one batch.

        Args:
            y_true: target passed to ``train_on_batch`` (the "valid" labels in
                the training loop). It is ignored; the loss depends only on
                ``y_pred``.
            y_pred: images produced by the generator.

        Returns:
            Per-sample tensor ``-log(discriminator(y_pred) + eps)``.
        """
        disc_fake = discriminator(y_pred)
        return -tf.log(disc_fake + eps)

    return model_loss

In [None]:
# Compile models
# Compile both models with Adam at the shared learning rate.
# The discriminator is trained as a binary classifier with accuracy tracking.
discriminator.compile(
    loss='binary_crossentropy',
    optimizer=tf.train.AdamOptimizer(learning_rate),
    metrics=['accuracy'],
)

# The generator's loss runs its output through the discriminator
# (see AdversarialLoss).
generator.compile(
    loss=AdversarialLoss(discriminator),
    optimizer=tf.train.AdamOptimizer(learning_rate),
)

In [None]:
# A fixed batch of `plot_data` latent vectors drawn from N(0, 1). It is handed
# to GanSummary2 in the callbacks cell, presumably to render samples from the
# same noise throughout training -- confirm in GanSummary2.
data_to_plot = np.random.normal(loc=0, scale=1, size=(plot_data, latent_dim))

In [None]:
# TensorBoard output for this run; any logs from a previous run are deleted first.
summaries_dir = "summaries/gan_mnist"
if tf.gfile.Exists(summaries_dir):
    tf.gfile.DeleteRecursively(summaries_dir)

tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=summaries_dir)
gan_summary = GanSummary2(tensorboard_callback, data_to_plot, discriminator,
                          update_freq=1)
callbacks = [tensorboard_callback, gan_summary]

In [None]:
# Epoch to start counting from (non-zero only when resuming a run).
initial_epoch = 0
# Floor division: each epoch uses only full batches; the remainder is skipped.
steps_per_epoch = data_length // batch_size

In [None]:
# Wrap the plain callback list so the manual training loop can drive it. The
# params dict below mirrors the one Keras's fit() passes to its callbacks.
# Guarded so that re-running this cell does not wrap an already-wrapped
# Callbacks object a second time.
if not isinstance(callbacks, Callbacks):
    callbacks = Callbacks(callbacks)

callbacks.set_model(generator)
callbacks.set_params(
    {
        "batch_size": batch_size,
        "epochs": epochs,
        "steps": steps_per_epoch,
        "samples": None,
        "verbose": 1,
        "do_validation": False,
        "metrics": generator.metrics_names,
    }
)
callbacks.on_train_begin()

# Adversarial ground truths: the training loop labels real images with
# `valid` (1) and generated images with `fake` (0).
valid = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))

In [None]:
# Global step indices for this run, grouped into one tuple per epoch.
if steps_per_epoch < 1:
    raise ValueError(
        "batch_size ({}) is larger than the training set ({} samples); "
        "no full batch fits in an epoch".format(batch_size, data_length))

initial_step = initial_epoch * steps_per_epoch
final_step = initial_step + epochs * steps_per_epoch

# Materialized as a list rather than left as a lazy iterator. The training
# cell consumes it, and an exhausted iterator would make a re-run of that cell
# silently train for zero epochs.
step_partitions = list(
    cz.partition_all(steps_per_epoch, range(initial_step, final_step))
)

In [None]:
# Training loop. For every batch: one generator step (scored through the
# discriminator by AdversarialLoss), then one discriminator step on an equal
# mix of real and generated images.
epoch_bar = tqdm(total=epochs, desc="Epoch 0/{}".format(epochs))
logs = {}  # defined up front so on_train_end has a value even if no epoch runs

try:
    for epoch_index, epoch_steps in enumerate(step_partitions):
        epoch = initial_epoch + epoch_index
        # leave=False: the per-epoch bar disappears once closed instead of
        # accumulating one widget per epoch.
        batch_bar = tqdm(total=steps_per_epoch, leave=False)

        logs = {}
        callbacks.on_epoch_begin(epoch, logs=logs)

        # Fresh sample order each epoch so the batches differ between epochs.
        epoch_order = np.random.permutation(data_length)

        try:
            for batch_index, _ in enumerate(epoch_steps):
                callbacks.on_batch_begin(batch_index, logs=logs)

                batch_ids = epoch_order[batch_index * batch_size:(batch_index + 1) * batch_size]
                train_batch = x_train[batch_ids].reshape(-1, 28, 28, 1)
                noise = np.random.normal(0, 1, (batch_size, latent_dim)).astype(np.float32)

                # Generator step: the target is the "valid" label.
                gen_values = generator.train_on_batch(x=noise, y=valid)
                if not isinstance(gen_values, (list, tuple)):
                    gen_values = [gen_values]
                logs.update(zip(generator.metrics_names, gen_values))
                gen_loss = gen_values[0]

                # Discriminator step on the same noise: real images labelled
                # `valid`, generated images labelled `fake`.
                generated_batch = generator.predict(noise)
                images = np.concatenate([train_batch, generated_batch], axis=0)
                labels = np.concatenate([valid, fake], axis=0)
                disc_values = discriminator.train_on_batch(x=images, y=labels)
                if not isinstance(disc_values, (list, tuple)):
                    disc_values = [disc_values]
                logs.update(
                    ("discriminator_{}".format(name), value)
                    for name, value in zip(discriminator.metrics_names, disc_values)
                )
                disc_loss = disc_values[0]

                callbacks.on_batch_end(batch_index, logs=logs)

                batch_bar.update(1)
                batch_bar.set_description(
                    desc="gen_loss {:.3f} disc_loss {:.3f}".format(gen_loss, disc_loss))
        finally:
            # Close the bar even if training is interrupted mid-epoch.
            batch_bar.close()

        callbacks.on_epoch_end(epoch, logs=logs)

        epoch_bar.update(1)
        epoch_bar.set_description(desc="Epoch {}/{}".format(epoch_index + 1, epochs))
finally:
    epoch_bar.close()

callbacks.on_train_end(logs=logs)

HBox(children=(IntProgress(value=0, description='Epoch 0/150', max=150, style=ProgressStyle(description_width=…

HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))



HBox(children=(IntProgress(value=0, max=1281), HTML(value='')))