In [1]:
# https://keras.io/examples/generative/cyclegan/

In [None]:
#import tf_keras as keras
import keras
import os
#os.environ["TF_USE_LEGACY_KERAS"] = "1"

import tensorflow as tf
from tensorflow.keras import layers
import glob
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras import layers

# `tf.data.AUTOTUNE` is the stable name of the deprecated
# `tf.data.experimental.AUTOTUNE` alias; it is passed as `num_parallel_calls`
# to the dataset `map` calls further down.
AUTOTUNE = tf.data.AUTOTUNE

In [None]:
#tf.keras.layers.GroupNormalization(groups=-1)
#https://www.tensorflow.org/api_docs/python/tf/keras/layers/GroupNormalization
#Relation to Instance Normalization: If the number of groups is set to the input dimension (number of groups is equal to number of channels), then this operation becomes identical to Instance Normalization. You can achieve this via groups=-1.

In [None]:
# Let TensorFlow grow GPU memory on demand instead of reserving it all up front.
gpus = tf.config.list_physical_devices('GPU')
try:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as err:
    # Memory growth must be set before the GPUs are initialized; when this cell
    # is re-run in a live kernel TensorFlow raises RuntimeError. Report it and
    # carry on rather than aborting the notebook.
    print(f"Could not enable GPU memory growth: {err}")

In [None]:
# Root folder of the image dataset on Kaggle.
dataset_path = '/kaggle/input/solesensei_bdd100k/bdd100k/bdd100k/images/100k'

# Image geometry and input-pipeline settings.
img_size = (256, 256)            # spatial size fed to the networks
orig_img_size = (720, 720)       # resize target applied before the random crop
input_img_size = (*img_size, 3)  # model input shape: img_size plus 3 channels
batch_size = 1
BUFFER_SIZE = 256                # shuffle buffer size

# Per-domain image folders (A and B, train and test).
trainA_path, trainB_path = (
    os.path.join(dataset_path, 'train/trainA'),
    os.path.join(dataset_path, 'train/trainB'),
)
testA_path, testB_path = (
    os.path.join(dataset_path, 'train/testA'),
    os.path.join(dataset_path, 'train/testB'),
)

In [None]:
#import shutil
#if os.path.exists('/kaggle/working/Model'):
#    shutil.rmtree('/kaggle/working/Model')

In [None]:
# Output locations under the Kaggle working directory.
output_img_dir = '/kaggle/working/Model/Model_Data/save'    # images written by GANMonitor
output_ckpt_dir = '/kaggle/working/Model/Model_Data/ckpt'   # per-epoch weight checkpoints
backup_dir = '/kaggle/working/Model/Model_Data/backup'      # BackupAndRestore state

# `exist_ok=True` keeps the cell safe to re-run and removes the
# check-then-create race of `if not os.path.exists(...)`.
for directory in (output_img_dir, output_ckpt_dir, backup_dir):
    os.makedirs(directory, exist_ok=True)

In [None]:
# Domain A holds the day images, domain B the night images.
# NOTE(review): these four lists are not referenced again in the visible part of
# the notebook, and glob returns paths in arbitrary order, so the strided
# subsets below are not reproducible between runs -- confirm they are needed.

train_day = glob.glob(os.path.join(dataset_path, 'train/trainA/*.jpg'))
train_night = glob.glob(os.path.join(dataset_path, 'train/trainB/*.jpg'))
print(*map(len, (train_day, train_night)))

val_day = glob.glob(os.path.join(dataset_path, 'train/testA/*.jpg'))
val_night = glob.glob(os.path.join(dataset_path, 'train/testB/*.jpg'))
print(*map(len, (val_day, val_night)))

# Keep every 17th day / 13th night training path.
train_day, train_night = train_day[::17], train_night[::13]
print(*map(len, (train_day, train_night)))

# Keep every 10th day / 6th night validation path.
val_day, val_night = val_day[::10], val_night[::6]
print(*map(len, (val_day, val_night)))

# Data Preprocessing

In [None]:
# NOTE(review): variable names and folder names are crossed here -- `trainA` and
# `testA` read the `*B` folders and vice versa -- and the files come from the
# `test/` split while the earlier cells point at `train/`. Every later cell
# treats A as the X/"day" side, so confirm this swap is intentional.
trainB = tf.data.Dataset.list_files(
    os.path.join(dataset_path, 'test/trainA/*jpg'), shuffle=True
).take(1200)
trainA = tf.data.Dataset.list_files(
    os.path.join(dataset_path, 'test/trainB/*jpg'), shuffle=True
).take(1200)

testB = tf.data.Dataset.list_files(
    os.path.join(dataset_path, 'test/testA/*jpg'), shuffle=True
).take(400)
testA = tf.data.Dataset.list_files(
    os.path.join(dataset_path, 'test/testB/*jpg'), shuffle=True
).take(400)

In [None]:
# Weight initializers used as defaults by the layer builders below: N(0, 0.02).
# These single instances are shared by many layers. A Keras 3 initializer that
# is unseeded (or seeded with an int) returns the same values on every call, so
# same-shaped weights would all start identical. A SeedGenerator advances its
# state on each call, giving every layer its own draw.
kernel_init = keras.initializers.RandomNormal(
    mean=0.0, stddev=0.02, seed=keras.random.SeedGenerator()
)
gamma_init = keras.initializers.RandomNormal(
    mean=0.0, stddev=0.02, seed=keras.random.SeedGenerator()
)

In [None]:
def load(image_file):
    """Read the file at `image_file` and decode it as a 3-channel JPEG tensor."""
    raw_bytes = tf.io.read_file(image_file)
    return tf.io.decode_jpeg(raw_bytes, channels=3)


def normalize_img(img):
    """Cast `img` to float32 and rescale it as `x / 127.5 - 1.0`.

    For pixel values in 0..255 this yields values in [-1, 1].
    """
    as_float = tf.cast(img, dtype=tf.float32)
    scaled = as_float / 127.5
    return scaled - 1.0


def preprocess_train_image(img):
    """Augment and normalize one training image.

    Steps, in order: random horizontal flip, resize to `orig_img_size`,
    random crop to `img_size` with 3 channels, then `normalize_img`.
    """
    flipped = tf.image.random_flip_left_right(img)
    resized = tf.image.resize(flipped, list(orig_img_size))
    cropped = tf.image.random_crop(resized, size=list(img_size) + [3])
    return normalize_img(cropped)


def preprocess_test_image(img):
    """Resize a test image to `img_size` and normalize it (no augmentation)."""
    target_height, target_width = img_size[0], img_size[1]
    resized = tf.image.resize(img, [target_height, target_width])
    return normalize_img(resized)

def map_train_load(image):
    """Dataset map function for training: file path -> augmented, normalized image."""
    return preprocess_train_image(load(image))

def map_test_load(image):
    """Dataset map function for testing: file path -> resized, normalized image."""
    return preprocess_test_image(load(image))
    

In [None]:
def _as_batches(dataset, map_fn):
    """Apply `map_fn` to `dataset`, then cache, shuffle and batch the result."""
    decoded = dataset.map(map_fn, num_parallel_calls=AUTOTUNE)
    return decoded.cache().shuffle(BUFFER_SIZE).batch(batch_size)


# NOTE(review): cache() comes after the random flip/crop in map_train_load, so
# later passes presumably replay the crops produced on the first pass -- confirm
# that frozen augmentation is acceptable.
trainA = _as_batches(trainA, map_train_load)
trainB = _as_batches(trainB, map_train_load)

testA = _as_batches(testA, map_test_load)
testB = _as_batches(testB, map_test_load)

def read_jpg(path):
    """Read the file at `path` and decode it as a 3-channel JPEG tensor."""
    encoded = tf.io.read_file(path)
    return tf.image.decode_jpeg(encoded, channels=3)

def normalize(input_image):
    """Cast `input_image` to float32 and rescale it as `x / 127.5 - 1`."""
    as_float = tf.cast(input_image, tf.float32)
    return as_float / 127.5 - 1

def load_image(image_path):
    """Load one JPEG, resize it to 256x256 and normalize it with `normalize`."""
    decoded = read_jpg(image_path)
    resized = tf.image.resize(decoded, (256, 256))
    return normalize(resized)

In [None]:
# Preview two training pairs: column 0 shows the trainA batch, column 1 the trainB batch.
_, ax = plt.subplots(2, 2, figsize=(15, 15))
preview_pairs = zip(trainA.take(2), trainB.take(2))
for i, (batch_a, batch_b) in enumerate(preview_pairs):
    # Undo the `x / 127.5 - 1` scaling on the first image of each batch.
    day = (batch_a[0] * 127.5 + 127.5).numpy().astype(np.uint8)
    night = (batch_b[0] * 127.5 + 127.5).numpy().astype(np.uint8)
    ax[i, 0].imshow(day)
    ax[i, 1].imshow(night)
plt.show()

# Model Architecture

In [None]:
class ReflectionPadding2D(layers.Layer):
    """Layer that reflect-pads axes 1 and 2 of a 4-D input.

    Args:
        padding: Pair `(padding_width, padding_height)`. Axis 1 is padded by
            `padding_height` on both sides and axis 2 by `padding_width`;
            axes 0 and 3 are left untouched.
        **kwargs: Forwarded to `layers.Layer`.
    """

    def __init__(self, padding=(1, 1), **kwargs):
        self.padding = tuple(padding)
        super().__init__(**kwargs)

    def call(self, input_tensor, mask=None):
        """Return `input_tensor` padded with `tf.pad(..., mode="REFLECT")`."""
        pad_w, pad_h = self.padding
        paddings = [[0, 0], [pad_h, pad_h], [pad_w, pad_w], [0, 0]]
        return tf.pad(input_tensor, paddings, mode="REFLECT")

In [None]:
def residual_block(
    x,
    activation,
    kernel_initializer=kernel_init,
    kernel_size=(3, 3),
    strides=(1, 1),
    padding="valid",
    gamma_initializer=gamma_init,
    use_bias=False,
):
    """Append a residual block to `x` and return the block output.

    The block is: reflection pad -> conv -> group norm -> `activation` ->
    reflection pad -> conv -> group norm, whose result is added to the block
    input. Both convolutions keep the channel count of `x`.

    Args:
        x: Input tensor; its last dimension sets the number of conv filters.
        activation: Callable applied after the first normalization.
        kernel_initializer, kernel_size, strides, padding, use_bias:
            Forwarded to both `Conv2D` layers.
        gamma_initializer: Forwarded to both `GroupNormalization` layers.
    """
    channels = x.shape[-1]
    shortcut = x

    def pad_conv_norm(tensor):
        # One "pad -> conv -> norm" stage; the block uses it twice.
        tensor = ReflectionPadding2D()(tensor)
        tensor = layers.Conv2D(
            channels,
            kernel_size,
            strides=strides,
            kernel_initializer=kernel_initializer,
            padding=padding,
            use_bias=use_bias,
        )(tensor)
        return tf.keras.layers.GroupNormalization(
            groups=-1, gamma_initializer=gamma_initializer
        )(tensor)

    branch = activation(pad_conv_norm(shortcut))
    branch = pad_conv_norm(branch)
    return layers.add([shortcut, branch])


def downsample(
    x,
    filters,
    activation,
    kernel_initializer=kernel_init,
    kernel_size=(3, 3),
    strides=(2, 2),
    padding="same",
    gamma_initializer=gamma_init,
    use_bias=False,
):
    """Apply conv -> group norm -> optional activation to `x`.

    Args:
        x: Input tensor.
        filters: Number of `Conv2D` output filters.
        activation: Callable applied last; skipped when falsy.
        kernel_initializer, kernel_size, strides, padding, use_bias:
            Forwarded to `Conv2D`.
        gamma_initializer: Forwarded to `GroupNormalization`.
    """
    conv = layers.Conv2D(
        filters,
        kernel_size,
        strides=strides,
        kernel_initializer=kernel_initializer,
        padding=padding,
        use_bias=use_bias,
    )
    norm = tf.keras.layers.GroupNormalization(
        groups=-1, gamma_initializer=gamma_initializer
    )
    out = norm(conv(x))
    return activation(out) if activation else out


def upsample(
    x,
    filters,
    activation,
    kernel_size=(3, 3),
    strides=(2, 2),
    padding="same",
    kernel_initializer=kernel_init,
    gamma_initializer=gamma_init,
    use_bias=False,
):
    """Apply transposed conv -> group norm -> optional activation to `x`.

    Args:
        x: Input tensor.
        filters: Number of `Conv2DTranspose` output filters.
        activation: Callable applied last; skipped when falsy.
        kernel_size, strides, padding, kernel_initializer, use_bias:
            Forwarded to `Conv2DTranspose`.
        gamma_initializer: Forwarded to `GroupNormalization`.
    """
    deconv = layers.Conv2DTranspose(
        filters,
        kernel_size,
        strides=strides,
        padding=padding,
        kernel_initializer=kernel_initializer,
        use_bias=use_bias,
    )
    norm = tf.keras.layers.GroupNormalization(
        groups=-1, gamma_initializer=gamma_initializer
    )
    out = norm(deconv(x))
    return activation(out) if activation else out

In [None]:
def get_resnet_generator(
    filters=64,
    num_downsampling_blocks=2,
    num_residual_blocks=9,
    num_upsample_blocks=2,
    gamma_initializer=gamma_init,
    name=None,
):
    """Build the ResNet-style generator as a functional Keras model.

    Layout: 7x7 stem conv -> downsampling blocks -> residual blocks ->
    upsampling blocks -> 7x7 conv to 3 channels with `tanh`.

    Args:
        filters: Filters of the stem conv; doubled before each downsampling
            block and halved before each upsampling block.
        num_downsampling_blocks: Number of `downsample` blocks.
        num_residual_blocks: Number of `residual_block` blocks.
        num_upsample_blocks: Number of `upsample` blocks.
        gamma_initializer: Gamma initializer for every `GroupNormalization`
            layer in the model (stem and all blocks).
        name: Optional model name; when given, the input layer is named
            `<name>_img_input`.

    Returns:
        A `tf.keras.models.Model` taking inputs of shape `input_img_size`.
    """
    # `name` defaults to None, so only derive the input-layer name when one is
    # given (the unconditional `name + "_img_input"` raised TypeError).
    input_name = None if name is None else name + "_img_input"
    img_input = layers.Input(shape=input_img_size, name=input_name)
    x = ReflectionPadding2D(padding=(3, 3))(img_input)
    x = layers.Conv2D(filters, (7, 7), kernel_initializer=kernel_init, use_bias=False)(
        x
    )
    x = tf.keras.layers.GroupNormalization(groups=-1, gamma_initializer=gamma_initializer)(x)
    x = layers.Activation("relu")(x)

    # Downsampling. `gamma_initializer` is forwarded so the argument applies to
    # the whole model, not just the stem; with the default value nothing changes.
    for _ in range(num_downsampling_blocks):
        filters *= 2
        x = downsample(
            x,
            filters=filters,
            activation=layers.Activation("relu"),
            gamma_initializer=gamma_initializer,
        )

    # Residual blocks
    for _ in range(num_residual_blocks):
        x = residual_block(
            x,
            activation=layers.Activation("relu"),
            gamma_initializer=gamma_initializer,
        )

    # Upsampling
    for _ in range(num_upsample_blocks):
        filters //= 2
        x = upsample(
            x,
            filters,
            activation=layers.Activation("relu"),
            gamma_initializer=gamma_initializer,
        )

    # Final block: map back to 3 channels and squash with tanh.
    x = ReflectionPadding2D(padding=(3, 3))(x)
    x = layers.Conv2D(3, (7, 7), padding="valid")(x)
    x = layers.Activation("tanh")(x)

    model = tf.keras.models.Model(img_input, x, name=name)
    return model

In [None]:
def get_discriminator(
    filters=64, kernel_initializer=kernel_init, num_downsampling=3, name=None
):
    """Build the convolutional discriminator as a functional Keras model.

    Layout: 4x4 stride-2 conv + LeakyReLU(0.2), then `num_downsampling`
    `downsample` blocks (stride 2, except the last one which uses stride 1),
    then a 4x4 conv with a single output channel.

    Args:
        filters: Filters of the first conv; doubled before every downsample block.
        kernel_initializer: Kernel initializer for every convolution.
        num_downsampling: Number of `downsample` blocks after the first conv.
        name: Optional model name; when given, the input layer is named
            `<name>_img_input`.

    Returns:
        A `tf.keras.models.Model` taking inputs of shape `input_img_size`.
    """
    # `name` defaults to None, so only derive the input-layer name when one is
    # given (the unconditional `name + "_img_input"` raised TypeError).
    input_name = None if name is None else name + "_img_input"
    img_input = layers.Input(shape=input_img_size, name=input_name)
    x = layers.Conv2D(
        filters,
        (4, 4),
        strides=(2, 2),
        padding="same",
        kernel_initializer=kernel_initializer,
    )(img_input)
    x = layers.LeakyReLU(0.2)(x)

    num_filters = filters
    # Honour `num_downsampling` (the loop used to be hard-coded to 3 blocks);
    # the default of 3 builds exactly the same stack as before.
    for block_index in range(num_downsampling):
        num_filters *= 2
        is_last_block = block_index == num_downsampling - 1
        x = downsample(
            x,
            filters=num_filters,
            activation=layers.LeakyReLU(0.2),
            kernel_initializer=kernel_initializer,
            kernel_size=(4, 4),
            # The last block keeps the spatial size; the others halve it.
            strides=(1, 1) if is_last_block else (2, 2),
        )

    x = layers.Conv2D(
        1, (4, 4), strides=(1, 1), padding="same", kernel_initializer=kernel_initializer
    )(x)

    model = tf.keras.models.Model(inputs=img_input, outputs=x, name=name)
    return model

In [None]:
# Two generators with the same architecture; CycleGan.train_step applies G to
# real_x batches and F to real_y batches.
gen_G, gen_F = (
    get_resnet_generator(name="generator_G"),
    get_resnet_generator(name="generator_F"),
)

In [None]:
# One discriminator per image domain.
disc_X, disc_Y = (
    get_discriminator(name="discriminator_X"),
    get_discriminator(name="discriminator_Y"),
)

# Model

In [None]:
class CycleGan(tf.keras.Model):
    """CycleGAN training wrapper around two generators and two discriminators.

    `gen_G` is applied to X-domain batches and `gen_F` to Y-domain batches;
    `disc_X` and `disc_Y` score images of their respective domain. All four
    networks are updated together in `train_step`.
    """

    def __init__(
        self,
        generator_G,
        generator_F,
        discriminator_X,
        discriminator_Y,
        lambda_cycle=10.0,
        lambda_identity=0.5,
    ):
        """Store the four sub-models and the loss weights.

        Args:
            generator_G: Model mapping X-domain images to the Y domain.
            generator_F: Model mapping Y-domain images to the X domain.
            discriminator_X: Model scoring X-domain images.
            discriminator_Y: Model scoring Y-domain images.
            lambda_cycle: Multiplier applied to the cycle-consistency losses.
            lambda_identity: Extra multiplier for the identity losses, which
                are scaled by `lambda_cycle * lambda_identity`.
        """
        super().__init__()
        self.gen_G = generator_G
        self.gen_F = generator_F
        self.disc_X = discriminator_X
        self.disc_Y = discriminator_Y
        self.lambda_cycle = lambda_cycle
        self.lambda_identity = lambda_identity

    def compile(
        self,
        gen_G_optimizer,
        gen_F_optimizer,
        disc_X_optimizer,
        disc_Y_optimizer,
        gen_loss_fn,
        disc_loss_fn,
    ):
        """Attach one optimizer per network and the adversarial loss functions.

        Args:
            gen_G_optimizer: Optimizer applied to `gen_G`'s variables.
            gen_F_optimizer: Optimizer applied to `gen_F`'s variables.
            disc_X_optimizer: Optimizer applied to `disc_X`'s variables.
            disc_Y_optimizer: Optimizer applied to `disc_Y`'s variables.
            gen_loss_fn: Callable `f(disc_fake_output)` returning the generator
                adversarial loss.
            disc_loss_fn: Callable `f(disc_real_output, disc_fake_output)`
                returning the discriminator loss.
        """
        super().compile()
        self.gen_G_optimizer = gen_G_optimizer
        self.gen_F_optimizer = gen_F_optimizer
        self.disc_X_optimizer = disc_X_optimizer
        self.disc_Y_optimizer = disc_Y_optimizer
        self.generator_loss_fn = gen_loss_fn
        self.discriminator_loss_fn = disc_loss_fn
        # Cycle-consistency and identity terms both use mean absolute error.
        self.cycle_loss_fn = keras.losses.MeanAbsoluteError()
        self.identity_loss_fn = keras.losses.MeanAbsoluteError()

    def train_step(self, batch_data):
        """Run one optimization step on a `(real_x, real_y)` batch pair.

        Args:
            batch_data: Pair `(real_x, real_y)` of image batches, one per domain.

        Returns:
            Dict with the total generator losses (`"G_loss"`, `"F_loss"`) and
            the discriminator losses (`"D_X_loss"`, `"D_Y_loss"`).
        """
        # x is Day and y is Night
        real_x, real_y = batch_data

        # For CycleGAN, we need to calculate different
        # kinds of losses for the generators and discriminators.
        # We will perform the following steps here:
        #
        # 1. Pass real images through the generators and get the generated images
        # 2. Pass the generated images back to the generators to check if we
        #    can predict the original image from the generated image.
        # 3. Do an identity mapping of the real images using the generators.
        # 4. Pass the generated images in 1) to the corresponding discriminators.
        # 5. Calculate the generators total loss (adversarial + cycle + identity)
        # 6. Calculate the discriminators loss
        # 7. Update the weights of the generators
        # 8. Update the weights of the discriminators
        # 9. Return the losses in a dictionary

        # The tape is persistent because `gradient` is called four times below,
        # once per network.
        with tf.GradientTape(persistent=True) as tape:
            # Day to fake night
            fake_y = self.gen_G(real_x, training=True)
            # Night to fake day -> y2x
            fake_x = self.gen_F(real_y, training=True)

            # Cycle (Day to fake night to fake day): x -> y -> x
            cycled_x = self.gen_F(fake_y, training=True)
            # Cycle (Night to fake day to fake night) y -> x -> y
            cycled_y = self.gen_G(fake_x, training=True)

            # Identity mapping
            same_x = self.gen_F(real_x, training=True)
            same_y = self.gen_G(real_y, training=True)

            # Discriminator output
            disc_real_x = self.disc_X(real_x, training=True)
            disc_fake_x = self.disc_X(fake_x, training=True)

            disc_real_y = self.disc_Y(real_y, training=True)
            disc_fake_y = self.disc_Y(fake_y, training=True)

            # Generator adversarial loss
            gen_G_loss = self.generator_loss_fn(disc_fake_y)
            gen_F_loss = self.generator_loss_fn(disc_fake_x)

            # Generator cycle loss
            cycle_loss_G = self.cycle_loss_fn(real_y, cycled_y) * self.lambda_cycle
            cycle_loss_F = self.cycle_loss_fn(real_x, cycled_x) * self.lambda_cycle

            # Generator identity loss
            id_loss_G = (
                self.identity_loss_fn(real_y, same_y)
                * self.lambda_cycle
                * self.lambda_identity
            )
            id_loss_F = (
                self.identity_loss_fn(real_x, same_x)
                * self.lambda_cycle
                * self.lambda_identity
            )

            # Total generator loss
            total_loss_G = gen_G_loss + cycle_loss_G + id_loss_G
            total_loss_F = gen_F_loss + cycle_loss_F + id_loss_F

            # Discriminator loss
            disc_X_loss = self.discriminator_loss_fn(disc_real_x, disc_fake_x)
            disc_Y_loss = self.discriminator_loss_fn(disc_real_y, disc_fake_y)

        # Get the gradients for the generators
        grads_G = tape.gradient(total_loss_G, self.gen_G.trainable_variables)
        grads_F = tape.gradient(total_loss_F, self.gen_F.trainable_variables)

        # Get the gradients for the discriminators
        disc_X_grads = tape.gradient(disc_X_loss, self.disc_X.trainable_variables)
        disc_Y_grads = tape.gradient(disc_Y_loss, self.disc_Y.trainable_variables)

        # Update the weights of the generators
        self.gen_G_optimizer.apply_gradients(
            zip(grads_G, self.gen_G.trainable_variables)
        )
        self.gen_F_optimizer.apply_gradients(
            zip(grads_F, self.gen_F.trainable_variables)
        )

        # Update the weights of the discriminators
        self.disc_X_optimizer.apply_gradients(
            zip(disc_X_grads, self.disc_X.trainable_variables)
        )
        self.disc_Y_optimizer.apply_gradients(
            zip(disc_Y_grads, self.disc_Y.trainable_variables)
        )

        return {
            "G_loss": total_loss_G,
            "F_loss": total_loss_F,
            "D_X_loss": disc_X_loss,
            "D_Y_loss": disc_Y_loss,
        }

# Callback

In [None]:
class GANMonitor(tf.keras.callbacks.Callback):
    """Callback that plots and saves `gen_G` translations after every epoch.

    Args:
        num_img: Number of `testA` batches to translate and display per epoch.
    """

    def __init__(self, num_img=2):
        # Initialise the base Callback state before adding our own attribute.
        super().__init__()
        self.num_img = num_img

    def on_epoch_end(self, epoch, logs=None):
        """Translate `num_img` test batches, show them and save PNGs to `output_img_dir`."""
        # One row per sample. The grid used to be fixed at 2x2, so any
        # `num_img` above 2 indexed past the axes array. `squeeze=False` keeps
        # `ax` 2-D even for a single row, and the height scales so the default
        # num_img=2 still gives a 12x12 figure.
        rows = max(self.num_img, 1)
        fig, ax = plt.subplots(rows, 2, figsize=(12, 6 * rows), squeeze=False)
        for i, img in enumerate(testA.take(self.num_img)):
            # Translate the batch and map its first image from [-1, 1] back to uint8.
            prediction = self.model.gen_G(img, training=False)[0].numpy()
            prediction = (prediction * 127.5 + 127.5).astype(np.uint8)
            img = (img[0] * 127.5 + 127.5).numpy().astype(np.uint8)

            ax[i, 0].imshow(img)
            ax[i, 1].imshow(prediction)
            ax[i, 0].set_title("Input image")
            ax[i, 1].set_title("Translated image")
            ax[i, 0].axis("off")
            ax[i, 1].axis("off")

            prediction = tf.keras.utils.array_to_img(prediction)
            prediction.save(
                f"{output_img_dir}/generated_img_{i}_{epoch+1}.png"
            )
        plt.show()
        # Close this figure explicitly so figures do not pile up across epochs.
        plt.close(fig)

In [None]:
# Weights-only checkpoints; the `{epoch:03d}` placeholder is left in the path
# for ModelCheckpoint to format.
checkpoint_filepath = output_ckpt_dir + "/cyclegan_checkpoints.{epoch:03d}.weights.h5"
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    save_weights_only=True,
    filepath=checkpoint_filepath,
)

# Backup/restore of the training run under `backup_dir`.
backup_callback = keras.callbacks.BackupAndRestore(backup_dir=backup_dir)

# Per-epoch sample plots with the default number of images.
plotter = GANMonitor()

# Training

In [None]:
# Adversarial objective: mean squared error between discriminator outputs and
# all-ones / all-zeros targets (see the two loss functions below).
adv_loss_fn = tf.keras.losses.MeanSquaredError()

def generator_loss_fn(fake):
    """Generator adversarial loss: `adv_loss_fn` of `fake` against all-ones targets.

    Args:
        fake: Discriminator output for generated images.
    """
    real_targets = tf.ones_like(fake)
    return adv_loss_fn(real_targets, fake)


def discriminator_loss_fn(real, fake):
    """Discriminator loss: mean of the real-vs-ones and fake-vs-zeros losses.

    Args:
        real: Discriminator output for real images.
        fake: Discriminator output for generated images.
    """
    ones_target = tf.ones_like(real)
    loss_on_real = adv_loss_fn(ones_target, real)
    zeros_target = tf.zeros_like(fake)
    loss_on_fake = adv_loss_fn(zeros_target, fake)
    return (loss_on_real + loss_on_fake) * 0.5

In [None]:
# Wrap the four networks in the CycleGan training model.
cycle_gan_model = CycleGan(
    generator_G=gen_G, generator_F=gen_F, discriminator_X=disc_X, discriminator_Y=disc_Y
)

# Each network gets its own Adam instance with identical hyper-parameters.
# All four now come from `tf.keras.optimizers`, the namespace the models are
# built from; `disc_X_optimizer` used to be the odd one out (`keras.optimizers`).
cycle_gan_model.compile(
    gen_G_optimizer=tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
    gen_F_optimizer=tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
    disc_X_optimizer=tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
    disc_Y_optimizer=tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5),
    gen_loss_fn=generator_loss_fn,
    disc_loss_fn=discriminator_loss_fn,
)

In [None]:
# Call build() with a (1, 256, 256, 3) input shape on the wrapper first, then on
# each sub-model, in the same order as before.
for _net in (
    cycle_gan_model,
    cycle_gan_model.gen_G,
    cycle_gan_model.gen_F,
    cycle_gan_model.disc_X,
    cycle_gan_model.disc_Y,
):
    _net.build((1, 256, 256, 3))

In [None]:
# Each training element is an (A batch, B batch) pair, unpacked by
# CycleGan.train_step as (real_x, real_y).
paired_train_batches = tf.data.Dataset.zip((trainA, trainB))
training_callbacks = [plotter, model_checkpoint_callback, backup_callback]

history = cycle_gan_model.fit(
    paired_train_batches,
    epochs=100,
    callbacks=training_callbacks,
)