<a href="https://colab.research.google.com/github/dude123studios/AdvancedGenerativeLearning/blob/main/BicycleGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -U tensorflow-addons

Collecting tensorflow-addons
  Downloading tensorflow_addons-0.13.0-cp37-cp37m-manylinux2010_x86_64.whl (679 kB)
[?25l[K     |▌                               | 10 kB 29.1 MB/s eta 0:00:01[K     |█                               | 20 kB 21.3 MB/s eta 0:00:01[K     |█▌                              | 30 kB 10.9 MB/s eta 0:00:01[K     |██                              | 40 kB 8.7 MB/s eta 0:00:01[K     |██▍                             | 51 kB 6.4 MB/s eta 0:00:01[K     |███                             | 61 kB 7.2 MB/s eta 0:00:01[K     |███▍                            | 71 kB 6.6 MB/s eta 0:00:01[K     |███▉                            | 81 kB 7.5 MB/s eta 0:00:01[K     |████▍                           | 92 kB 6.8 MB/s eta 0:00:01[K     |████▉                           | 102 kB 6.9 MB/s eta 0:00:01[K     |█████▎                          | 112 kB 6.9 MB/s eta 0:00:01[K     |█████▉                          | 122 kB 6.9 MB/s eta 0:00:01[K     |██████▎                  

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, Model
from tensorflow.keras.activations import relu, tanh
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.optimizers import RMSprop, Adam
from tensorflow.keras.metrics import binary_accuracy
import tensorflow_datasets as tfds
from tensorflow_addons.layers import InstanceNormalization

import numpy as np
import matplotlib.pyplot as plt
import os

import warnings
warnings.filterwarnings('ignore')

In [None]:
# Download and extract the paired dataset (the URL points at the pix2pix
# dataset mirror). `dataset_name` must keep naming the data actually loaded.
dataset_name = 'edges2shoes'
_URL = f'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/{dataset_name}.tar.gz'

path_to_zip = tf.keras.utils.get_file(f'{dataset_name}.tar.gz',
                                      origin=_URL,
                                      extract=True)

# The archive is extracted next to the downloaded tarball.
PATH = os.path.join(os.path.dirname(path_to_zip), f'{dataset_name}/')

Downloading data from http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/edges2shoes.tar.gz


In [None]:
# Every model in this notebook works on 256x256 RGB images; height and width
# are unpacked for the data pipeline's crop/resize helpers.
image_shape = (256, 256, 3)
IMG_HEIGHT, IMG_WIDTH = image_shape[:2]

In [None]:
# Training pipeline sizes. BicycleGAN.train_step splits every batch into two
# halves (cVAE-GAN / cLR-GAN), so BATCH_SIZE should stay even and >= 2.
BATCH_SIZE, BUFFER_SIZE = 8, 400

def load(image_file):
    """Read a side-by-side paired JPEG and split it into its two halves.

    Args:
        image_file: path (string tensor) of a JPEG whose left half is the
            input image and whose right half is the target image.

    Returns:
        (input_image, real_image) as float32 tensors with values in [0, 255].
    """
    raw = tf.io.read_file(image_file)
    # Force 3 channels so grayscale JPEGs still decode to RGB and the channel
    # dimension is statically known (random_crop and the models expect 3).
    image = tf.image.decode_jpeg(raw, channels=3)

    half_width = tf.shape(image)[1] // 2
    input_image = tf.cast(image[:, :half_width, :], tf.float32)
    real_image = tf.cast(image[:, half_width:, :], tf.float32)

    return input_image, real_image

def resize(input_image, real_image, height, width):
    """Resize both images of a pair to (height, width) with nearest-neighbour sampling."""
    target_size = [height, width]
    nearest = tf.image.ResizeMethod.NEAREST_NEIGHBOR
    resized_input = tf.image.resize(input_image, target_size, method=nearest)
    resized_real = tf.image.resize(real_image, target_size, method=nearest)
    return resized_input, resized_real

def random_crop(input_image, real_image):
    """Crop the same random IMG_HEIGHT x IMG_WIDTH x 3 window from both images."""
    # Stacking the pair lets one random_crop call choose a single offset for both.
    pair = tf.stack([input_image, real_image], axis=0)
    cropped = tf.image.random_crop(pair, size=[2, IMG_HEIGHT, IMG_WIDTH, 3])
    input_crop, real_crop = cropped[0], cropped[1]
    return input_crop, real_crop

def normalize(input_image, real_image):
    """Map pixel values from [0, 255] to [-1, 1] for both images of a pair."""
    def _to_unit_range(img):
        return img / 127.5 - 1

    return _to_unit_range(input_image), _to_unit_range(real_image)

def random_jitter(input_image, real_image):
    """pix2pix-style augmentation applied identically to both images of a pair.

    Upscale to 286x286, take a random IMG_HEIGHT x IMG_WIDTH crop, then mirror
    horizontally with probability 0.5.
    """
    enlarged_input, enlarged_real = resize(input_image, real_image, 286, 286)
    input_image, real_image = random_crop(enlarged_input, enlarged_real)

    # Tensor-valued condition; tf.data.map traces this through AutoGraph.
    should_flip = tf.random.uniform(()) > 0.5
    if should_flip:
        input_image = tf.image.flip_left_right(input_image)
        real_image = tf.image.flip_left_right(real_image)

    return input_image, real_image

def load_image_train(image_file):
    """Load one training pair, apply random jitter, and scale it to [-1, 1]."""
    pair = load(image_file)
    pair = random_jitter(*pair)
    return normalize(*pair)

def load_image_test(image_file):
    """Load one validation pair, resize it to the model size, and scale it to [-1, 1]."""
    raw_input, raw_real = load(image_file)
    return normalize(*resize(raw_input, raw_real, IMG_HEIGHT, IMG_WIDTH))

train_dataset = tf.data.Dataset.list_files(PATH+'train/*.jpg')
train_dataset = train_dataset.map(load_image_train,
                                  num_parallel_calls=tf.data.experimental.AUTOTUNE)
train_dataset = train_dataset.shuffle(BUFFER_SIZE)
train_dataset = train_dataset.batch(BATCH_SIZE)

test_dataset = tf.data.Dataset.list_files(PATH+'val/*.jpg')
test_dataset = test_dataset.map(load_image_test)
test_dataset = test_dataset.batch(1).repeat()

In [None]:
def downsample(channels, kernels, strides=2, norm=True, activation=True, dropout=False):
    """Build an encoder block: Conv2D followed by optional norm / activation / dropout.

    Args:
        channels: number of output filters.
        kernels: convolution kernel size.
        strides: convolution stride (default 2, with padding='same').
        norm: append an InstanceNormalization layer.
        activation: append LeakyReLU(0.2).
        dropout: append Dropout(0.5).

    Returns:
        A tf.keras.Sequential containing the requested layers in that order.
    """
    conv = layers.Conv2D(channels, kernels, strides=strides, padding='same',
                         use_bias=False,
                         kernel_initializer=tf.random_normal_initializer(0., 0.02))
    stack = [conv]
    if norm:
        stack.append(InstanceNormalization())
    if activation:
        stack.append(layers.LeakyReLU(0.2))
    if dropout:
        stack.append(layers.Dropout(0.5))
    return tf.keras.Sequential(stack)

def upsample(channels, kernels, strides=1, norm=True, activation=True, dropout=False):
    """Build a decoder block: 2x UpSampling2D, Conv2D, then optional extras.

    Args:
        channels: number of output filters.
        kernels: convolution kernel size.
        strides: convolution stride applied after upsampling (default 1).
        norm: append an InstanceNormalization layer.
        activation: append LeakyReLU(0.2).
        dropout: append Dropout(0.5).

    Returns:
        A tf.keras.Sequential containing the requested layers in that order.
    """
    stack = [
        layers.UpSampling2D((2, 2)),
        layers.Conv2D(channels, kernels, strides=strides, padding='same',
                      use_bias=False,
                      kernel_initializer=tf.random_normal_initializer(0., 0.02)),
    ]
    if norm:
        stack.append(InstanceNormalization())
    if activation:
        stack.append(layers.LeakyReLU(0.2))
    if dropout:
        stack.append(layers.Dropout(0.5))
    return tf.keras.Sequential(stack)

In [None]:
def build_generator(z_dim, image_shape):
    """Build the U-Net generator G(A, z) -> B.

    The latent code is broadcast to every pixel and concatenated with the input
    image, then passed through a 7-level encoder/decoder with skip connections.

    Args:
        z_dim: length of the latent vector z.
        image_shape: (H, W, C) of the input image; H and W must be static.

    Returns:
        Keras Model taking [image, z] and returning an image in [-1, 1] (tanh).
    """
    DIM = 64
    input_image = layers.Input(shape=image_shape)
    input_z = layers.Input(shape=(z_dim,), name='z')

    # Broadcast z spatially: (batch, z_dim) -> (batch, H, W, z_dim).
    # The channel multiplier must be 1; a larger value just duplicates z.
    z = layers.Reshape((1, 1, z_dim))(input_z)
    z_tiles = tf.tile(z, [1, image_shape[0], image_shape[1], 1])
    x = layers.Concatenate()([input_image, z_tiles])

    # Encoder: seven stride-2 convolutions (256x256 -> 2x2 for a 256x256 input).
    skips = []
    for channels in (DIM, 2 * DIM, 4 * DIM, 4 * DIM, 4 * DIM, 4 * DIM, 4 * DIM):
        x = downsample(channels, 4, norm=False)(x)
        skips.append(x)

    # Decoder: mirror the encoder with U-Net skips; dropout on the two deepest stages.
    decoder_channels = (4 * DIM, 4 * DIM, 4 * DIM, 4 * DIM, 2 * DIM, DIM)
    for stage, (channels, skip) in enumerate(zip(decoder_channels, reversed(skips[:-1]))):
        x = upsample(channels, 4, dropout=stage < 2)(x)
        x = layers.Concatenate()([x, skip])

    output_image = tanh(upsample(3, 4, norm=False, activation=False)(x))

    return Model([input_image, input_z], output_image, name='generator')

In [None]:
 def build_discriminator():
    DIM = 64
    input_image_B = layers.Input(shape=image_shape)
        
    x = downsample(DIM, 4, norm=False)(input_image_B) # 128
    x = downsample(2*DIM, 4)(x) # 64
    x = downsample(4*DIM, 4)(x) # 32
    x =downsample(8*DIM, 4, strides=1)(x) 
    output = layers.Conv2D(1, 4)(x)

    return Model(input_image_B, output, name='discriminator')     

In [None]:
class GaussianSampling(layers.Layer):
    """Reparameterisation trick: z = mean + exp(logvar / 2) * eps, eps ~ N(0, I).

    Args:
        z_dim: length of the latent vector.
        name: Keras layer name.
    """

    def __init__(self, z_dim, name):
        super(GaussianSampling, self).__init__(name=name)
        self.z_dim = z_dim

    def call(self, inputs):
        """Sample z from [mean, logvar], each shaped (batch, z_dim)."""
        mean, logvar = inputs
        # Independent noise for every example: sample with the full shape of
        # `mean` instead of one (1, z_dim) vector broadcast over the batch.
        epsilon = tf.random.normal(tf.shape(mean), mean=0., stddev=1.)
        return mean + tf.exp(0.5 * logvar) * epsilon

In [None]:
def build_encoder(z_dim):
    """Build the encoder E(B) -> (z, mean, logvar).

    Six stride-2 conv blocks reduce the notebook-level `image_shape` input.
    Two Dense heads produce the Gaussian parameters, and GaussianSampling
    draws z from them.
    """
    DIM = 64
    input_image = layers.Input(shape=image_shape)

    features = downsample(DIM, 4, norm=False)(input_image)
    for channels in (2 * DIM, 4 * DIM, 8 * DIM, 8 * DIM, 8 * DIM):
        features = downsample(channels, 4)(features)
    features = layers.Flatten()(features)

    mean = layers.Dense(z_dim, name='mean')(features)
    logvar = layers.Dense(z_dim, name='logvar')(features)
    z = GaussianSampling(name='z', z_dim=z_dim)([mean, logvar])

    return Model(input_image, [z, mean, logvar], name='encoder')

In [None]:
class BicycleGAN(Model):
    """BicycleGAN: a cVAE-GAN and a cLR-GAN trained jointly.

    Sub-models:
        encoder:          image B -> (z, mean, logvar)
        generator:        (image A, z) -> fake image B
        discriminator_1:  judges cVAE-GAN outputs (z encoded from the real B)
        discriminator_2:  judges cLR-GAN outputs (z sampled from N(0, I))

    Each training batch is split in half. The first half feeds the cVAE-GAN
    branch and the second half the cLR-GAN branch, so batches need at least
    2 examples (an even batch size keeps the halves equal).
    """

    def __init__(self, image_shape, z_dim):
        super(BicycleGAN, self).__init__()

        self.image_shape = image_shape
        self.z_dim = z_dim

        self.discriminator_1 = build_discriminator()
        self.discriminator_2 = build_discriminator()
        self.encoder = build_encoder(z_dim)
        self.generator = build_generator(z_dim, image_shape)

        # NOTE(review): not read by train_step (which uses the LAMBDA_* weights
        # set in compile); kept in case other code references it.
        self.LAMBDA = 100

        # Spatial size of the PatchGAN score map, read from the static output
        # shape rather than by wiring an extra node onto the generator graph.
        self.patch_size = self.discriminator_1.output_shape[1]

    def compile(self):
        """Create the optimizers, loss weights and loss functions (takes no Keras compile args)."""
        super(BicycleGAN, self).compile()

        self.LAMBDA_IMAGE = 10
        self.LAMBDA_LATENT = 0.5
        self.LAMBDA_KL = 0.01

        self.d1_optimizer = Adam(2e-4, 0.5)
        self.d2_optimizer = Adam(2e-4, 0.5)
        self.g_optimizer = Adam(2e-4, 0.5)
        self.e_optimizer = Adam(2e-4, 0.5)

        self.mae = tf.keras.losses.MeanAbsoluteError()
        self.mse = tf.keras.losses.MeanSquaredError()

    def train_step(self, inputs):
        """Run one optimisation step for both discriminators, the generator and the encoder.

        Args:
            inputs: tuple (images_A, images_B) of paired input / target batches.

        Returns:
            dict with the total generator loss and the mean discriminator loss.
        """
        images_A, images_B = inputs
        batch_size = tf.shape(images_A)[0]
        # Each branch needs at least one example.
        tf.debugging.assert_greater_equal(
            batch_size, 2, message='BicycleGAN.train_step needs batches of >= 2 images')
        half = batch_size // 2

        images_A_1, images_A_2 = images_A[:half], images_A[half:]
        images_B_1, images_B_2 = images_B[:half], images_B[half:]

        # One random latent code per cLR-GAN example (the second half only).
        z = tf.random.normal((tf.shape(images_A_2)[0], self.z_dim))

        with tf.GradientTape(persistent=True) as tape:
            # cVAE-GAN: B -> z -> B'
            z_encode, mean_encode, logvar_encode = self.encoder(images_B_1, training=True)
            fake_B_encode = self.generator([images_A_1, z_encode], training=True)
            encode_fake = self.discriminator_1(fake_B_encode, training=True)
            encode_real = self.discriminator_1(images_B_1, training=True)

            # cLR-GAN: z -> B' -> z'
            fake_B_random = self.generator([images_A_2, z], training=True)
            _, mean_random, _ = self.encoder(fake_B_random, training=True)
            random_fake = self.discriminator_2(fake_B_random, training=True)
            random_real = self.discriminator_2(images_B_2, training=True)

            # Least-squares GAN losses. Targets are built from the score maps
            # themselves, so their batch and spatial shapes always match.
            d1_loss = (self.mse(tf.zeros_like(encode_fake), encode_fake)
                       + self.mse(tf.ones_like(encode_real), encode_real))
            d2_loss = (self.mse(tf.zeros_like(random_fake), random_fake)
                       + self.mse(tf.ones_like(random_real), random_real))

            g_1_loss = self.mse(tf.ones_like(encode_fake), encode_fake)
            g_2_loss = self.mse(tf.ones_like(random_fake), random_fake)

            image_loss = self.LAMBDA_IMAGE * self.mae(images_B_1, fake_B_encode)
            kl_loss = self.LAMBDA_KL * (-0.5 * tf.reduce_sum(
                1 + logvar_encode - tf.square(mean_encode) - tf.exp(logvar_encode)))
            latent_loss = self.LAMBDA_LATENT * self.mae(z, mean_random)

            # The encoder is updated without the latent-reconstruction term;
            # only the generator receives latent_loss.
            e_loss = g_1_loss + g_2_loss + image_loss + kl_loss
            g_loss = e_loss + latent_loss

        d1_grads = tape.gradient(d1_loss, self.discriminator_1.trainable_variables)
        d2_grads = tape.gradient(d2_loss, self.discriminator_2.trainable_variables)
        g_grads = tape.gradient(g_loss, self.generator.trainable_variables)
        e_grads = tape.gradient(e_loss, self.encoder.trainable_variables)
        # Release the persistent tape's recorded intermediates.
        del tape

        self.d1_optimizer.apply_gradients(zip(d1_grads, self.discriminator_1.trainable_variables))
        self.d2_optimizer.apply_gradients(zip(d2_grads, self.discriminator_2.trainable_variables))
        self.g_optimizer.apply_gradients(zip(g_grads, self.generator.trainable_variables))
        self.e_optimizer.apply_gradients(zip(e_grads, self.encoder.trainable_variables))

        return {'g_loss': g_loss, 'd_loss': (d1_loss + d2_loss) / 2}

    def call(self, input_imgs, training=None):
        """Translate a batch of A images, drawing a fresh random z for each image."""
        num_imgs = tf.shape(input_imgs)[0]
        z = tf.random.normal((num_imgs, self.z_dim))
        return self.generator([input_imgs, z], training=training)

In [None]:
class GenerativeCallback(tf.keras.callbacks.Callback):
    """Every `interval` epochs, plot `num_imgs` generated samples for one test input.

    The model draws a new random z for each row it translates. Repeating a
    single input image therefore shows the spread of outputs the generator
    produces for it. This also works when the test dataset yields batches of
    only one image.
    """

    def __init__(self, test_dataset, num_imgs=5, interval=5):
        super(GenerativeCallback, self).__init__()

        self.num_imgs = num_imgs
        self.interval = interval
        # NOTE(review): assumes test_dataset repeats; a finite dataset would
        # eventually raise StopIteration from next() below.
        self.ds = iter(test_dataset)

    def on_epoch_end(self, epoch, logs=None):
        if epoch % self.interval != 0:
            return
        batch, _ = next(self.ds)
        # Use the first input num_imgs times so every panel is a valid index.
        inputs = tf.repeat(batch[:1], self.num_imgs, axis=0)
        images = self.model(inputs)

        # squeeze=False keeps a 2-D axes array even when num_imgs == 1.
        fig, axes = plt.subplots(1, self.num_imgs,
                                 figsize=(self.num_imgs * 1.5, 1.5), squeeze=False)
        for col, ax in enumerate(axes[0]):
            ax.imshow((images[col] + 1) / 2)  # tanh output [-1, 1] -> [0, 1]
            ax.axis('off')
        plt.show()
        plt.close(fig)

In [None]:
# Training configuration
Z_DIM = 8      # length of the latent code z
EPOCHS = 30

bicyclegan = BicycleGAN(image_shape, z_dim=Z_DIM)
bicyclegan.compile()

callbacks = [GenerativeCallback(test_dataset)]
bicyclegan.fit(train_dataset, epochs=EPOCHS, callbacks=callbacks)