# CycleGAN: Apple to Orange Translation

This notebook implements CycleGAN for unpaired image-to-image translation between apples and oranges, following the architecture described in the [original CycleGAN paper](https://arxiv.org/abs/1703.10593). Compared with the basic CycleGAN tutorial, this implementation:

1. Uses the ResNet generator architecture from the paper
2. Trains for 200 epochs for better results
3. Uses the apple-to-orange dataset

In [None]:
%pip install tensorflow

In [None]:
import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow_datasets as tfds
from tensorflow_examples.models.pix2pix import pix2pix

import os
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output

In [None]:
def normalize(image):
    """Cast ``image`` to float32 and rescale it as ``x / 127.5 - 1``.

    For 8-bit pixel values in [0, 255] the result lies in [-1, 1], which
    matches the ``tanh`` output range of the generators.
    """
    as_float = tf.cast(image, tf.float32)
    return (as_float / 127.5) - 1

def preprocess_image_train(image, label):
    """Prepare one training example: random jitter followed by normalization.

    Args:
        image: A single, unbatched 3-channel image (``random_crop`` crops to
            ``[256, 256, 3]``).
        label: Second element of the supervised ``(image, label)`` tuple;
            unused because CycleGAN training is unpaired and unlabeled.

    Returns:
        A float32 256x256x3 image, randomly cropped and mirrored, then
        rescaled by ``normalize``.

    Note:
        The augmentation is random, so it is only re-sampled when this
        function is re-executed. If the dataset is cached *after* being
        mapped with this function, later epochs replay the first epoch's
        crops and flips.
    """
    # Augment on the raw image first; normalize() then casts and rescales.
    image = random_jitter(image)
    image = normalize(image)
    return image

def random_crop(image):
    """Return a randomly positioned 256x256x3 crop of ``image``."""
    crop_shape = [256, 256, 3]
    return tf.image.random_crop(image, crop_shape)

def random_jitter(image):
    """Apply the random augmentation: enlarge, random crop, random mirror.

    The image is resized to 286x286 with nearest-neighbor interpolation,
    cropped back to 256x256 at a random offset via ``random_crop``, and then
    flipped left-right at random.
    """
    enlarged = tf.image.resize(
        image, [286, 286],
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    cropped = random_crop(enlarged)
    return tf.image.random_flip_left_right(cropped)

def preprocess_image_test(image, label):
    """Prepare one test example: normalize, then resize to 256x256.

    ``label`` comes from the supervised ``(image, label)`` tuple and is
    ignored. No random augmentation is applied.
    """
    normalized = normalize(image)
    return tf.image.resize(
        normalized, [256, 256],
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

In [None]:
def resnet_block(n_filters, input_layer):
    """Build one residual block on top of ``input_layer``.

    The residual branch is conv3x3 -> instance norm -> ReLU -> conv3x3 ->
    instance norm (both convolutions use 'same' padding and no bias). Its
    output is added element-wise to ``input_layer``, so ``n_filters`` must
    equal the channel count of ``input_layer`` for the addition to be valid.

    Returns:
        The output tensor of the block (same shape as ``input_layer``).
    """
    conv_options = dict(kernel_size=3, padding='same', use_bias=False)

    # Residual branch, first conv stage (with activation).
    residual = tf.keras.layers.Conv2D(n_filters, **conv_options)(input_layer)
    residual = tfa.layers.InstanceNormalization()(residual)
    residual = tf.keras.layers.ReLU()(residual)

    # Residual branch, second conv stage (no activation before the sum).
    residual = tf.keras.layers.Conv2D(n_filters, **conv_options)(residual)
    residual = tfa.layers.InstanceNormalization()(residual)

    # Skip connection.
    return tf.keras.layers.Add()([residual, input_layer])

def resnet_generator(output_channels=3, dim=64, n_resnet=9, name=None):
    """Build the ResNet-style generator as a Keras functional model.

    Layout: 7x7 stem convolution -> two stride-2 3x3 convolutions (filters
    double each time) -> ``n_resnet`` residual blocks -> two stride-2 3x3
    transposed convolutions (filters halve each time) -> 7x7 convolution
    with ``tanh`` activation. Every convolution except the last is followed
    by instance normalization and ReLU.

    Args:
        output_channels: Number of channels in the generated image.
        dim: Number of filters in the stem convolution.
        n_resnet: Number of residual blocks in the bottleneck.
        name: Optional name for the resulting ``tf.keras.Model``.

    Returns:
        A ``tf.keras.Model`` taking inputs of shape ``[256, 256, 3]``.
    """
    def _norm_relu(tensor):
        # Instance normalization followed by ReLU, shared by all conv stages.
        tensor = tfa.layers.InstanceNormalization()(tensor)
        return tf.keras.layers.ReLU()(tensor)

    image_in = tf.keras.layers.Input(shape=[256, 256, 3])
    filters = dim

    # Stem
    features = tf.keras.layers.Conv2D(
        filters, 7, padding='same', use_bias=False)(image_in)
    features = _norm_relu(features)

    # Encoder: halve the spatial size twice while doubling the filters.
    for _ in range(2):
        filters *= 2
        features = tf.keras.layers.Conv2D(
            filters, 3, strides=2, padding='same', use_bias=False)(features)
        features = _norm_relu(features)

    # Bottleneck of residual blocks at the lowest resolution.
    for _ in range(n_resnet):
        features = resnet_block(filters, features)

    # Decoder: double the spatial size twice while halving the filters.
    for _ in range(2):
        filters //= 2
        features = tf.keras.layers.Conv2DTranspose(
            filters, 3, strides=2, padding='same', use_bias=False)(features)
        features = _norm_relu(features)

    # Projection to image space; tanh keeps outputs in [-1, 1].
    generated = tf.keras.layers.Conv2D(
        output_channels, 7, padding='same', activation='tanh')(features)

    return tf.keras.Model(inputs=image_in, outputs=generated, name=name)

In [None]:
# Load the dataset. as_supervised=True yields (image, label) tuples, which is
# why the preprocessing functions accept (and ignore) a label argument.
dataset, metadata = tfds.load('cycle_gan/apple2orange',
                            with_info=True, as_supervised=True)

train_apples, train_oranges = dataset['trainA'], dataset['trainB']
test_apples, test_oranges = dataset['testA'], dataset['testB']

# Define constants
BUFFER_SIZE = 1000
BATCH_SIZE = 1
IMG_WIDTH = 256
IMG_HEIGHT = 256
EPOCHS = 200
LAMBDA = 10
OUTPUT_CHANNELS = 3

# Training pipelines: cache the decoded examples *before* mapping so that
# preprocess_image_train is re-executed on every epoch. Caching after the map
# would replay the first epoch's output and freeze any random augmentation
# performed during preprocessing.
train_apples = train_apples.cache().map(
    preprocess_image_train, num_parallel_calls=tf.data.AUTOTUNE).shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)
train_oranges = train_oranges.cache().map(
    preprocess_image_train, num_parallel_calls=tf.data.AUTOTUNE).shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)

# Test pipelines: the mapped output is cached directly so the preprocessing
# work is not repeated on later passes.
test_apples = test_apples.map(
    preprocess_image_test, num_parallel_calls=tf.data.AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)
test_oranges = test_oranges.map(
    preprocess_image_test, num_parallel_calls=tf.data.AUTOTUNE).cache().shuffle(
    BUFFER_SIZE).batch(BATCH_SIZE)

In [None]:
# Generators: G maps apples -> oranges, F maps oranges -> apples.
generator_g = resnet_generator(output_channels=OUTPUT_CHANNELS,
                               name='apple2orange')
generator_f = resnet_generator(output_channels=OUTPUT_CHANNELS,
                               name='orange2apple')

# One discriminator per image domain, both built with the same settings.
_disc_options = dict(norm_type='instancenorm', target=False)
discriminator_x = pix2pix.discriminator(**_disc_options)
discriminator_y = pix2pix.discriminator(**_disc_options)


def _make_adam():
    # Every network gets its own Adam instance with identical hyperparameters.
    return tf.keras.optimizers.Adam(2e-4, beta_1=0.5)


generator_g_optimizer = _make_adam()
generator_f_optimizer = _make_adam()
discriminator_x_optimizer = _make_adam()
discriminator_y_optimizer = _make_adam()

In [None]:
def discriminator_loss(real, generated):
    """Return the scalar adversarial loss for a discriminator.

    Args:
        real: Discriminator logits for real images.
        generated: Discriminator logits for generated images.

    Returns:
        A scalar tensor: half the sum of the binary cross-entropy of ``real``
        against all-ones targets and of ``generated`` against all-zeros
        targets, each averaged over every logit.
    """
    # Use the default reduction so each term is a scalar mean. With
    # Reduction.NONE the result is a per-patch tensor; tape.gradient() then
    # sums over all patches, which silently rescales the gradients and makes
    # the logged loss non-scalar.
    bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)

    real_loss = bce(tf.ones_like(real), real)
    generated_loss = bce(tf.zeros_like(generated), generated)

    total_disc_loss = real_loss + generated_loss

    return total_disc_loss * 0.5

def generator_loss(generated):
    """Return the scalar adversarial loss for a generator.

    Args:
        generated: Discriminator logits for the generator's output.

    Returns:
        A scalar tensor: binary cross-entropy of ``generated`` against
        all-ones targets, averaged over every logit.
    """
    # The default reduction yields a scalar. With Reduction.NONE the loss is
    # a per-patch tensor; adding the scalar cycle/identity terms to it
    # broadcasts them onto every patch, and tape.gradient() then sums over
    # patches, multiplying their effective weight by the patch count.
    bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    return bce(tf.ones_like(generated), generated)

def calc_cycle_loss(real_image, cycled_image):
    """Cycle-consistency loss: ``LAMBDA`` times the mean absolute error.

    Compares an image with its reconstruction after a round trip through
    both generators and returns a scalar tensor.
    """
    abs_diff = tf.abs(real_image - cycled_image)
    mean_abs_error = tf.reduce_mean(abs_diff)
    return LAMBDA * mean_abs_error

def identity_loss(real_image, same_image):
    """Identity loss: ``LAMBDA * 0.5`` times the mean absolute error.

    ``same_image`` is the output of a generator fed an image that already
    belongs to its target domain; the result is a scalar tensor.
    """
    weight = LAMBDA * 0.5
    mean_abs_error = tf.reduce_mean(tf.abs(real_image - same_image))
    return weight * mean_abs_error

In [None]:
@tf.function
def train_step(real_x, real_y):
    """Run one optimization step for both generators and both discriminators.

    Args:
        real_x: Batch of real images from domain X (input of ``generator_g``).
        real_y: Batch of real images from domain Y (input of ``generator_f``).

    Returns:
        Dict with the adversarial generator losses (``gen_g_loss``,
        ``gen_f_loss``) and the discriminator losses (``disc_x_loss``,
        ``disc_y_loss``) computed before the weight update.
    """
    # The tape is persistent because four separate gradients are taken from it.
    with tf.GradientTape(persistent=True) as tape:
        # G: X -> Y and back through F.
        fake_y = generator_g(real_x, training=True)
        cycled_x = generator_f(fake_y, training=True)

        # F: Y -> X and back through G.
        fake_x = generator_f(real_y, training=True)
        cycled_y = generator_g(fake_x, training=True)

        # Feeding each generator an image of its own target domain
        # (inputs to the identity loss).
        same_x = generator_f(real_x, training=True)
        same_y = generator_g(real_y, training=True)

        # Discriminator logits on real and translated images.
        disc_real_x = discriminator_x(real_x, training=True)
        disc_real_y = discriminator_y(real_y, training=True)

        disc_fake_x = discriminator_x(fake_x, training=True)
        disc_fake_y = discriminator_y(fake_y, training=True)

        # Adversarial terms for the generators.
        gen_g_loss = generator_loss(disc_fake_y)
        gen_f_loss = generator_loss(disc_fake_x)

        # Both round trips contribute to a shared cycle-consistency term.
        forward_cycle = calc_cycle_loss(real_x, cycled_x)
        backward_cycle = calc_cycle_loss(real_y, cycled_y)
        total_cycle_loss = forward_cycle + backward_cycle

        # Generator objective = adversarial + cycle + identity.
        total_gen_g_loss = (
            gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y))
        total_gen_f_loss = (
            gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x))

        disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
        disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)

    # (objective, network, optimizer) for each of the four networks.
    update_plan = (
        (total_gen_g_loss, generator_g, generator_g_optimizer),
        (total_gen_f_loss, generator_f, generator_f_optimizer),
        (disc_x_loss, discriminator_x, discriminator_x_optimizer),
        (disc_y_loss, discriminator_y, discriminator_y_optimizer),
    )

    # Compute every gradient first so that no update influences another
    # network's gradient within this step.
    all_gradients = [
        tape.gradient(objective, network.trainable_variables)
        for objective, network, _ in update_plan
    ]

    # Then apply the updates, in the same network order.
    for gradients, (_, network, optimizer) in zip(all_gradients, update_plan):
        optimizer.apply_gradients(
            zip(gradients, network.trainable_variables))

    return {
        "gen_g_loss": gen_g_loss,
        "gen_f_loss": gen_f_loss,
        "disc_x_loss": disc_x_loss,
        "disc_y_loss": disc_y_loss
    }

In [None]:
def generate_images(model, test_input):
    """Show the first image of ``test_input`` next to the model's output.

    Both images are mapped from [-1, 1] back to [0, 1] before being
    displayed side by side with matplotlib.
    """
    prediction = model(test_input)

    plt.figure(figsize=(12, 12))

    panels = (
        ('Input Image', test_input[0]),
        ('Predicted Image', prediction[0]),
    )

    for position, (caption, picture) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.title(caption)
        # Undo the [-1, 1] normalization for display.
        plt.imshow(picture * 0.5 + 0.5)
        plt.axis('off')
    plt.show()

def plot_training_progress(epoch, g_losses, d_losses):
    """Plot the generator and discriminator loss histories side by side.

    Args:
        epoch: Current epoch index; accepted for the caller's convenience
            but not used in the plot.
        g_losses: Sequence of per-batch generator loss values.
        d_losses: Sequence of per-batch discriminator loss values.
    """
    plt.figure(figsize=(15,5))

    charts = (
        (g_losses, 'Generator Loss', 'Generator Loss History'),
        (d_losses, 'Discriminator Loss', 'Discriminator Loss History'),
    )

    for slot, (history, legend_label, heading) in enumerate(charts, start=1):
        plt.subplot(1, 2, slot)
        plt.plot(history, label=legend_label)
        plt.title(heading)
        plt.xlabel('Batch')
        plt.ylabel('Loss')
        plt.legend()

    plt.tight_layout()
    plt.show()

In [None]:
# Training
# Loss histories hold one Python float per batch so they stay small and can
# be passed straight to plt.plot().
g_losses = []
d_losses = []

for epoch in range(EPOCHS):
    start = time.time()

    paired_batches = tf.data.Dataset.zip((train_apples, train_oranges))
    for n, (image_x, image_y) in enumerate(paired_batches):
        losses = train_step(image_x, image_y)

        # Collapse each averaged loss to a scalar before storing it.
        # reduce_mean is a no-op on scalar losses and also copes with
        # per-element loss tensors, which plt.plot() could not draw.
        g_losses.append(float(tf.reduce_mean(
            (losses['gen_g_loss'] + losses['gen_f_loss']) / 2)))
        d_losses.append(float(tf.reduce_mean(
            (losses['disc_x_loss'] + losses['disc_y_loss']) / 2)))

        # Lightweight progress indicator: one dot every 10 batches.
        if n % 10 == 0:
            print('.', end='')

    clear_output(wait=True)
    # Using test_apples as example input
    for inp in test_apples.take(1):
        generate_images(generator_g, inp)

    if (epoch + 1) % 10 == 0:
        plot_training_progress(epoch, g_losses, d_losses)

    print(f'Time taken for epoch {epoch + 1} is {time.time()-start} sec\n')