<a href="https://colab.research.google.com/github/consequencesunintended/BEGAN/blob/main/BEGAN_custom_dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf
from keras.engine.topology import Layer
import tensorflow_datasets as tfds
import numpy as np
import matplotlib.pyplot as plt
import os
import time
from IPython import display
# Report the installed TensorFlow version.
print (tf.__version__)

In [None]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Activation, Reshape
from tensorflow.keras.layers import Convolution2D
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import MaxPooling2D
from tensorflow.keras.layers import Convolution2DTranspose
from tensorflow.keras.layers import concatenate

In [None]:
# Report whether TensorFlow can see a GPU device.
gpu_name = tf.test.gpu_device_name()
if not gpu_name:
    print("Please install GPU version of TF")
else:
    print('Default GPU Device: {}'.format(gpu_name))

In [None]:
# Side length in pixels that every training image is resized to; the
# decoder/generator also outputs DIMENSION x DIMENSION images.
DIMENSION = 64
# Length of the latent vector fed to the decoder and produced by the encoder.
noise_dim = 128
# Number of images per training batch.
BATCH_SIZE = 8
# Number of epochs passed to train().
EPOCHS = 1_000

In [None]:
# Glob pattern for the training images; list_files yields one file path per
# match (shuffled by default).
data_dir = "data/real/images/*.png"
list_ds = tf.data.Dataset.list_files(file_pattern=data_dir)

In [None]:
# Display the file-path dataset object (notebook cell output).
list_ds

In [None]:
# Let tf.data pick the degree of parallelism dynamically (same sentinel
# as tf.data.experimental.AUTOTUNE, and the spelling used by the pipeline below).
AUTOTUNE = tf.data.AUTOTUNE

In [None]:
# Sanity check: print the first five matched file paths.
for f in list_ds.take(5):
    print(f.numpy())

In [None]:
def decode_img(img):
    """Decode raw PNG bytes into a resized, alpha-premultiplied RGB image.

    The PNG is decoded as RGBA and converted to float32 in [0, 1]. The RGB
    channels are multiplied by the alpha channel, so fully transparent pixels
    become black. The result is resized to ``DIMENSION x DIMENSION`` with
    ``tf.image.resize``'s default method.

    Args:
        img: scalar string tensor holding the PNG file contents.

    Returns:
        float32 tensor of shape ``(DIMENSION, DIMENSION, 3)``.
    """
    rgba = tf.image.convert_image_dtype(
        tf.image.decode_png(img, channels=4), tf.float32
    )
    rgb = rgba[:, :, :3]
    alpha = rgba[:, :, 3:4]
    premultiplied = rgb * alpha
    return tf.image.resize(premultiplied, [DIMENSION, DIMENSION])

In [None]:
def process_path(file_path):
    """Load the image stored at ``file_path`` and return it via ``decode_img``."""
    raw_bytes = tf.io.read_file(file_path)
    return decode_img(raw_bytes)

In [None]:
# Turn each file path into a decoded image tensor, decoding in parallel.
labeled_ds = list_ds.map(map_func=process_path, num_parallel_calls=AUTOTUNE)

In [None]:
# Training pipeline: shuffle decoded images, group them into batches of
# BATCH_SIZE (the final batch may be smaller), and prefetch ahead of the model.
SHUFFLE_BUFFER_SIZE = 100
train_dataset_final = (
    labeled_ds
    .shuffle(SHUFFLE_BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)

In [None]:
# Total number of training images; used below to derive the steps per epoch.
# NOTE(review): len() raises if the dataset's cardinality is unknown; it is
# presumably known for list_files + map — re-check if the pipeline changes.
NUM_IMAGES = len(labeled_ds)

In [None]:
# Print the shape of one batch; given decode_img and .batch(BATCH_SIZE) this
# should be (BATCH_SIZE, DIMENSION, DIMENSION, 3) unless the dataset is smaller
# than one batch.
for image in train_dataset_final.take(1):
    print("Image shape: ", image.numpy().shape)

In [None]:
# Grab one batch of real training images for visual inspection.
sample_training_images = next(iter(train_dataset_final))

In [None]:
# Show the first image of the sample batch.
plt.imshow(sample_training_images[0])

In [None]:
def plotImages(images_arr):
    """Display a sequence of images side by side in a single row.

    One subplot is created per image, so no empty frames are drawn and no
    images are silently dropped. The figure size is given in inches (4 per
    image). It is deliberately not tied to the image's pixel size: the old
    ``figsize=(DIMENSION, DIMENSION)`` produced a 64x64-inch figure.

    Args:
        images_arr: iterable of HxWxC images (eager tensors or numpy arrays).
            Nothing is drawn when it is empty.
    """
    images = list(images_arr)
    if not images:
        return

    count = len(images)
    # squeeze=False always returns a 2-D array of axes, even for one image.
    _, axes = plt.subplots(1, count, figsize=(4 * count, 4), squeeze=False)
    for img, ax in zip(images, axes.flatten()):
        ax.imshow(np.asarray(img))
        ax.axis('off')
    plt.tight_layout()
    plt.show()

In [None]:
# Show the first five images of the sample batch.
plotImages(sample_training_images[:5])

In [None]:
class Resize_nn(tf.keras.layers.Layer):
    """Keras layer that resizes feature maps with nearest-neighbour sampling.

    The decoder uses it to upsample and the encoder to downsample. Only the
    spatial size changes; the batch and channel dimensions are kept.

    Args:
        image_size: target ``(height, width)``; only the first two entries
            are used. Defaults to ``(512, 512)``.
    """

    def __init__(self, image_size=(512, 512), **kwargs):
        # Subclass tf.keras.layers.Layer rather than the standalone
        # ``keras.engine.topology.Layer``. The models in this file are tf.keras
        # models, mixing the two Layer hierarchies is not supported, and
        # ``keras.engine.topology`` is not available in recent Keras.
        # Initialise the base class before assigning attributes, as Keras expects.
        super(Resize_nn, self).__init__(**kwargs)
        self.image_size = image_size[0], image_size[1]

    def call(self, inputs, **kwargs):
        return tf.image.resize(inputs, self.image_size, method='nearest')

    def compute_output_shape(self, input_shape):
        return input_shape[0], self.image_size[0], self.image_size[1], input_shape[-1]

    def get_config(self):
        # Makes the layer serialisable (model.save / from_config).
        config = super(Resize_nn, self).get_config()
        config.update({"image_size": tuple(self.image_size)})
        return config

In [None]:
def decoder():
    """Build the decoder network as a Keras functional ``Model``.

    The network maps a latent vector of length ``noise_dim`` to a
    ``DIMENSION x DIMENSION x 3`` image with values in (0, 1) (sigmoid output).
    It works as follows:

    - A dense projection is reshaped into an 8x8 feature map.
    - Four stages follow. Each has two 3x3 ELU convolutions and ends with a
      nearest-neighbour resize to DIMENSION/8, DIMENSION/4, DIMENSION/2 and
      DIMENSION in turn.
    - Three more ELU convolutions and a 3-channel sigmoid convolution produce
      the image.

    This network is used as the generator and as the decoder half of the
    discriminator.

    Returns:
        tf.keras.Model named "Decoder".
    """
    kernel = 3
    filters = 16

    def conv_elu(tensor, width):
        # 3x3, stride-1, "same"-padded convolution followed by ELU.
        layer = Convolution2D(width, (kernel, kernel), strides=(1, 1), padding="same")
        return tf.nn.elu(layer(tensor))

    # Per stage: (channels of the first conv, target side length, progress
    # message). The second conv of every stage has `filters` channels.
    stages = (
        (filters, DIMENSION // 8, "decorder-1 DONE!"),
        (2 * filters, DIMENSION // 4, "decorder-2 DONE!"),
        (2 * filters, DIMENSION // 2, "decorder-3 DONE!"),
        (2 * filters, DIMENSION, "decorder-4 DONE!"),
    )

    inputs = Input(shape=(noise_dim,))
    projected = Dense(8 * 8 * filters)(inputs)
    x = Reshape([8, 8, filters])(projected)

    for first_width, side, message in stages:
        x = conv_elu(x, first_width)
        x = conv_elu(x, filters)
        x = Resize_nn([side, side])(x)
        print(message)

    # Full-resolution refinement before the RGB projection.
    for width in (2 * filters, 2 * filters, filters):
        x = conv_elu(x, width)

    rgb = Convolution2D(3, (kernel, kernel), padding="same")(x)
    outputs = Activation("sigmoid")(rgb)

    print("Decoder Model DONE!")

    return Model(inputs=inputs, outputs=outputs, name="Decoder")

In [None]:
def encoder():
    """Build the encoder half of the discriminator as a Keras functional ``Model``.

    The network maps a ``DIMENSION x DIMENSION x 3`` image to a vector of
    length ``noise_dim``:

    - Four stages of two 3x3 ELU convolutions each end with a
      nearest-neighbour resize to DIMENSION, DIMENSION/2, DIMENSION/4 and
      DIMENSION/8. The first resize leaves the size unchanged.
    - The result is flattened and passed through two Dense layers.

    Returns:
        tf.keras.Model named "encoder".
    """
    kernel = 3
    filters = 16

    def conv_elu(tensor, width):
        # 3x3, stride-1, "same"-padded convolution followed by ELU.
        layer = Convolution2D(width, (kernel, kernel), strides=(1, 1), padding="same")
        return tf.nn.elu(layer(tensor))

    # Per stage: (first conv channels, second conv channels, target side,
    # progress message or None when the stage prints nothing).
    stages = (
        (filters, 2 * filters, DIMENSION, "encoder-1 DONE!"),
        (2 * filters, 3 * filters, DIMENSION // 2, "encoder-2 DONE!"),
        (3 * filters, 3 * filters, DIMENSION // 4, "encoder-3 DONE!"),
        (3 * filters, 3 * filters, DIMENSION // 8, None),
    )

    inputs = Input(shape=[DIMENSION, DIMENSION, 3])
    x = inputs
    for first_width, second_width, side, message in stages:
        x = conv_elu(x, first_width)
        x = conv_elu(x, second_width)
        x = Resize_nn([side, side])(x)
        if message is not None:
            print(message)

    flat = Flatten()(x)
    # NOTE(review): this Dense has no activation, so together with the next
    # Dense it is a composition of two linear maps. Kept for checkpoint
    # compatibility; the hard-coded 8*8 presumably assumes DIMENSION // 8 == 8.
    hidden = Dense(8 * 8 * 3 * filters)(flat)
    outputs = Dense(noise_dim)(hidden)

    print("Encoder Model DONE!")

    return Model(inputs=inputs, outputs=outputs, name="encoder")

In [None]:
# Inspect the distinct pixel values of the first sample image.
np.unique(sample_training_images[0])

In [None]:
# Shape of one decoded image; decode_img keeps 3 channels and resizes to DIMENSION x DIMENSION.
sample_training_images[0].shape

In [None]:
def make_generator_model():
    """Return the generator: a fresh ``decoder()`` wrapped in a Sequential model."""
    return tf.keras.Sequential([decoder()])

In [None]:
# Instantiate the generator (a Sequential wrapping decoder()).
generator = make_generator_model()

In [None]:
# Show a real image again for comparison with the untrained generator output below.
plt.imshow(sample_training_images[0])

In [None]:
# Smoke test: feed one random latent vector through the untrained generator
# and display the resulting image.
input_images = tf.random.normal([1, noise_dim])
generated_image = generator(input_images, training=False)

plt.imshow(generated_image[0, :, :])

In [None]:
def make_discriminator_model():
    """Return the BEGAN discriminator: an auto-encoder built as Sequential(encoder, decoder).

    Its output is a reconstruction of the input image, not a real/fake score.
    """
    autoencoder_parts = [encoder(), decoder()]
    return tf.keras.Sequential(autoencoder_parts)

In [None]:
# Build the discriminator and run it once on the generated sample to check
# the output shape. The discriminator is encoder followed by decoder, so
# `decision` is a reconstructed image rather than a real/fake score.
discriminator = make_discriminator_model()
decision = discriminator(generated_image)
print (decision.shape)

In [None]:
def began_autoencoder_loss(out, inp, eta=1):
    """Return the mean pixel-wise auto-encoder reconstruction loss.

    Computes ``mean(|out - inp| ** eta)`` over all elements, which is the
    BEGAN loss L(v) = |v - D(v)|^eta.

    Args:
        out: reconstruction produced by the auto-encoder.
        inp: the image that was fed to the auto-encoder.
        eta: norm exponent. ``1`` (default, the previous behaviour) gives the
            mean absolute error; ``2`` gives the mean squared error.

    Returns:
        Scalar tensor holding the mean loss.
    """
    diff = tf.abs(out - inp)
    if eta != 1:
        diff = tf.pow(diff, eta)
    return tf.reduce_mean(diff)

In [None]:
def get_loss_values(k_t, gamma, D_real_in, D_real_out, D_gen_in, D_gen_out, lambda_k=0.001):
    """Compute the BEGAN losses, the next equilibrium weight and the convergence measure.

    Args:
        k_t: current equilibrium weight on the generated-image loss.
        gamma: target ratio between the generated and real reconstruction losses.
        D_real_in, D_real_out: real images and their discriminator reconstructions.
        D_gen_in, D_gen_out: generated images and their discriminator reconstructions.
        lambda_k: step size of the k_t update. Defaults to 0.001, the value
            previously hard-coded.

    Returns:
        Tuple ``(G_loss, D_loss, k_tp, convergence_measure)``:

        - ``G_loss = mu_gen``
        - ``D_loss = mu_real - k_t * mu_gen``
        - ``k_tp = k_t + lambda_k * (gamma * mu_real - mu_gen)``, unclipped
        - ``convergence_measure = mu_real + |gamma * mu_real - mu_gen|``
    """
    mu_real = began_autoencoder_loss(D_real_out, D_real_in)
    mu_gen = began_autoencoder_loss(D_gen_out, D_gen_in)

    D_loss = mu_real - k_t * mu_gen
    G_loss = mu_gen

    balance = gamma * mu_real - mu_gen
    k_tp = k_t + lambda_k * balance

    # tf.abs instead of np.abs keeps the computation in TensorFlow, so this
    # also works when traced inside tf.function (np.abs fails on symbolic tensors).
    convergence_measure = mu_real + tf.abs(balance)

    return G_loss, D_loss, k_tp, convergence_measure

In [None]:
# Independent Adam optimisers (learning rate 1e-4) for the two networks.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

In [None]:
# Checkpoint tracking both models and their optimisers; train() saves it
# under `checkpoint_prefix` every 15 epochs.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

In [None]:
# Fixed latent vectors reused for every preview, so generator progress can be
# compared across epochs.
num_examples_to_generate = 16
seed = tf.random.normal([num_examples_to_generate, noise_dim])

In [None]:
def train_step(k_t, images, gamma=0.75):
    """Run one BEGAN update of both the generator and the discriminator.

    Args:
        k_t: equilibrium weight multiplying the generated-image reconstruction
            loss inside the discriminator loss; train() clips it to [0, 1].
        images: batch of real images.
        gamma: target ratio passed to get_loss_values for the k_t update and
            the convergence measure. Defaults to 0.75, the previously
            hard-coded value.

    Returns:
        Tuple ``(gen_loss, disc_loss, k_t_next, convergence_measure)``, where
        ``k_t_next`` is the updated, unclipped equilibrium weight.
    """
    noise = tf.random.normal([BATCH_SIZE, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

        # The discriminator is an auto-encoder: score both batches by how
        # well it reconstructs them.
        reconstructed_fake = discriminator(generated_images, training=True)
        reconstructed_real = discriminator(images, training=True)

        gen_loss, disc_loss, k_t, convergence_measure = get_loss_values(
            k_t, gamma, images, reconstructed_real, generated_images, reconstructed_fake)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

    return gen_loss, disc_loss, k_t, convergence_measure

In [None]:
def generate_and_save_images(model, epoch, test_input, save_dir=None):
    """Run ``model`` on ``test_input`` and show the outputs in a square grid.

    The grid size is derived from the number of samples, so batches larger
    than 16 no longer break the old fixed 4x4 layout. For 16 samples the
    result is the same 4x4 grid in an 8x8-inch figure.

    Args:
        model: generator called as ``model(test_input, training=False)``.
        epoch: epoch number; used only to name the saved file.
        test_input: batch of latent vectors.
        save_dir: optional directory. When given, the figure is also written
            to ``<save_dir>/image_at_epoch_XXXX.png``. Defaults to None
            (display only), which matches the previous behaviour.
    """
    import math

    predictions = model(test_input, training=False)
    count = predictions.shape[0]
    cols = max(1, math.ceil(math.sqrt(count)))
    rows = max(1, math.ceil(count / cols))

    fig = plt.figure(figsize=(2 * cols, 2 * rows))

    for i in range(count):
        plt.subplot(rows, cols, i + 1)
        plt.imshow(predictions[i, :, :])
        plt.axis('off')

    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)
        fig.savefig(os.path.join(save_dir, 'image_at_epoch_{:04d}.png'.format(epoch)))

    plt.show()
    # Release the figure so repeated previews during training do not accumulate.
    plt.close(fig)

In [None]:
# Number of full batches per epoch (a trailing partial batch is skipped).
STEP_SIZE = NUM_IMAGES // BATCH_SIZE


def train(epochs):
    """Run the BEGAN training loop for ``epochs`` epochs.

    Each epoch iterates ``train_dataset_final`` once and takes ``STEP_SIZE``
    full batches. After every step the equilibrium weight k_t is clipped to
    [0, 1]. Losses and a preview grid are shown every 100 steps and at the end
    of each epoch. A checkpoint is written every 15 epochs.
    """
    k_t = 0.0
    step = 0
    convergence_measure = 0.0

    for epoch in range(epochs):
        start = time.time()

        g_loss = 0
        d_loss = 0

        # Iterate the dataset once per epoch. Calling next(iter(...)) on every
        # step built a fresh iterator each time, refilling the shuffle buffer
        # (slow) and giving no guarantee that each image is seen once per epoch.
        for images in train_dataset_final.take(STEP_SIZE):
            g_loss, d_loss, k_t, convergence_measure = train_step(k_t, images)
            # Clip right after the update so the logged value is the one
            # actually used in the next step.
            k_t = min(max(float(k_t), 0.0), 1.0)
            step += 1

            if step % 100 == 0:
                display.clear_output(wait=True)
                print('Generator loss:{} Discrimantor loss:{} Convergence:{} K_t:{} step: {}'.format(
                    float(g_loss), float(d_loss), float(convergence_measure), k_t, step))

                generate_and_save_images(generator, epoch + 1, seed)

        display.clear_output(wait=True)

        print('Generator loss:{} Discrimantor loss:{} Convergence:{} K_t:{} step: {}'.format(
            float(g_loss), float(d_loss), float(convergence_measure), k_t, step))
        generate_and_save_images(generator, epoch + 1, seed)

        if (epoch + 1) % 15 == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)

        print('Time for epoch {} is {} sec'.format(epoch + 1, time.time() - start))

    display.clear_output(wait=True)
    generate_and_save_images(generator, epochs, seed)

In [None]:
# Preview the untrained generator on the fixed seed before training starts.
generate_and_save_images(generator, 0, seed)

In [None]:
%%time
# Run the full training loop for EPOCHS epochs; %%time reports the cell's run time.
train(EPOCHS)