In [1]:
!pip install opendatasets
!pip install tensorflow-addons

Collecting opendatasets
  Downloading opendatasets-0.1.22-py3-none-any.whl (15 kB)
Installing collected packages: opendatasets
Successfully installed opendatasets-0.1.22
Collecting tensorflow-addons
  Downloading tensorflow_addons-0.21.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (612 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m612.1/612.1 kB[0m [31m33.0 MB/s[0m eta [36m0:00:00[0m
Collecting typeguard<3.0.0,>=2.7 (from tensorflow-addons)
  Downloading typeguard-2.13.3-py3-none-any.whl (17 kB)
Installing collected packages: typeguard, tensorflow-addons
Successfully installed tensorflow-addons-0.21.0 typeguard-2.13.3


## Import required packages

In [None]:
from data_loader import DataLoader
from discriminator import Discriminator
from generator import Generator
from utils import plot_generator_loss, generate_and_visualize, plot_loss

import tensorflow as tf
import time
from tqdm import tqdm
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import MeanSquaredError, MeanAbsoluteError
import os
import matplotlib.pyplot as plt

from IPython import display
import imageio.v2 as imageio
import config

import opendatasets as od

## Prepare dataset

In [5]:
# Fetch the monet2photo dataset from Kaggle; the call prompts for Kaggle
# credentials and downloads into ./monet2photo.
# Dataset page: https://www.kaggle.com/datasets/balraj98/monet2photo
MONET2PHOTO_URL = 'https://www.kaggle.com/datasets/balraj98/monet2photo'
od.download(MONET2PHOTO_URL)

Please provide your Kaggle credentials to download this dataset. Learn more: http://bit.ly/kaggle-creds
Your Kaggle username: lamlenn
Your Kaggle Key: ··········
Downloading monet2photo.zip to ./monet2photo


100%|██████████| 291M/291M [00:17<00:00, 17.6MB/s]





In [6]:
# Training pipeline: paired X/Y image folders, batched with the settings from config.
train_loader = DataLoader(config.TRAIN_X_PATH, config.TRAIN_Y_PATH)
train_dataset = train_loader.get_dataset(config.BATCH_SIZE, config.BUFFER_SIZE, config.SHUFFLE)

# Validation pipeline: same batching, but the third argument is fixed to False
# (presumably the shuffle flag -- confirm against DataLoader.get_dataset).
val_loader = DataLoader(config.VAL_X_PATH, config.VAL_Y_PATH)
val_dataset = val_loader.get_dataset(config.BATCH_SIZE, config.BUFFER_SIZE, False)

## Train model

In [None]:
# Optimizers: one Adam instance for the generator pair and one for the
# discriminator pair, both built from the same hyper-parameters in config.
adam_settings = dict(
    learning_rate=config.LEARNING_RATE,
    beta_1=config.BETA_1,
    beta_2=config.BETA_2,
)
gen_optim = Adam(**adam_settings)
disc_optim = Adam(**adam_settings)

# Networks: two generators and two discriminators.
gen_g, gen_f = Generator(), Generator()
disc_X, disc_Y = Discriminator(), Discriminator()

# Checkpoint object tracking both optimizers and all four networks.
checkpoint = tf.train.Checkpoint(
    gen_optim=gen_optim,
    disc_optim=disc_optim,
    gen_g=gen_g,
    gen_f=gen_f,
    disc_x=disc_X,
    disc_y=disc_Y,
)

# Fixed validation sample: the first 6 validation batches, concatenated along
# axis 0 into a single input tensor and a single target tensor.
val_batches = list(val_dataset.take(6))
test_inputs = tf.concat([x_batch for x_batch, _ in val_batches], axis=0)
test_targets = tf.concat([y_batch for _, y_batch in val_batches], axis=0)

# Starting value of the progress-image counter handed to train().
num_iter = 15

def train(dataset, continue_training=False, checkpoint_dir=None, start_epoch=0, num_iter=0):
    """Run the adversarial training loop for both generators and both discriminators.

    Uses the notebook-level objects ``gen_g``, ``gen_f``, ``disc_X``, ``disc_Y``,
    ``gen_optim``, ``disc_optim``, ``checkpoint`` and ``test_inputs``, plus the
    settings in ``config``.

    Args:
        dataset: Iterable of ``(x_image, y_image)`` batches. It must also support
            ``len()``, which is used in the progress message.
        continue_training: When True, ``checkpoint`` is restored from
            ``checkpoint_dir`` before training starts.
        checkpoint_dir: Path handed to ``checkpoint.restore``. Required when
            ``continue_training`` is True.
        start_epoch: First epoch index; epochs run over
            ``range(start_epoch, config.EPOCHS)``.
        num_iter: Index passed to ``generate_and_visualize`` for the first
            progress report; incremented by one after every report.

    Raises:
        ValueError: If ``continue_training`` is True but ``checkpoint_dir`` is None.
    """
    if continue_training:
        if checkpoint_dir is None:
            # Fail loudly instead of handing None to checkpoint.restore, so a
            # resume with a forgotten path cannot go unnoticed.
            raise ValueError('checkpoint_dir must be provided when continue_training=True')
        checkpoint.restore(checkpoint_dir)

    report_every = 200   # batches between two progress reports
    save_every = 3000    # batches between two periodic checkpoints

    G_losses = []
    D_losses = []
    # Initialize loss
    mse_loss = MeanSquaredError()
    mae_loss = MeanAbsoluteError()

    last_epoch = None           # stays None when the epoch range is empty
    last_epoch_seconds = None
    unsaved_progress = False    # True while weights changed since the last save

    for epoch in range(start_epoch, config.EPOCHS):
        epoch_start = time.perf_counter()
        # Timer for the current reporting interval; reset after every report so
        # the printed duration covers only the batches named in the message.
        report_start = epoch_start

        for batch_idx, (x_image, y_image) in enumerate(tqdm(dataset)):
            # ---- Train discriminator X and discriminator Y ----
            with tf.GradientTape() as d_tape:
                fake_image_G = gen_g(x_image, training=True)
                true_disc_Y = disc_Y(y_image, training=True)
                disc_Y_true_loss = mse_loss(tf.ones_like(true_disc_Y), true_disc_Y)
                fake_disc_Y = disc_Y(fake_image_G, training=True)
                disc_Y_fake_loss = mse_loss(tf.zeros_like(fake_disc_Y), fake_disc_Y)

                # Total loss of discriminator Y
                disc_Y_loss = disc_Y_true_loss + disc_Y_fake_loss

                fake_image_F = gen_f(y_image, training=True)
                true_disc_X = disc_X(x_image, training=True)
                disc_X_true_loss = mse_loss(tf.ones_like(true_disc_X), true_disc_X)
                fake_disc_X = disc_X(fake_image_F, training=True)
                disc_X_fake_loss = mse_loss(tf.zeros_like(fake_disc_X), fake_disc_X)

                # Total loss of discriminator X
                disc_X_loss = disc_X_true_loss + disc_X_fake_loss

                # Total discriminator objective, halved to slow the
                # discriminators down relative to the generators.
                disc_loss = (disc_Y_loss + disc_X_loss) / 2.0

            # Collect the variables after the forward pass (so lazily created
            # weights exist) and reuse the same list for gradients and update.
            disc_vars = list(disc_X.trainable_variables) + list(disc_Y.trainable_variables)
            disc_grads = d_tape.gradient(disc_loss, disc_vars)
            disc_optim.apply_gradients(zip(disc_grads, disc_vars))

            # Monitor loss
            D_losses.append(disc_loss)

            # ---- Train generator G and generator F ----
            with tf.GradientTape() as g_tape:
                # Adversarial losses: each generator wants its fakes scored as real (ones).
                fake_image_G = gen_g(x_image, training=True)
                fake_disc_Y = disc_Y(fake_image_G, training=True)
                disc_Y_fake_loss = mse_loss(tf.ones_like(fake_disc_Y), fake_disc_Y)

                fake_image_F = gen_f(y_image, training=True)
                fake_disc_X = disc_X(fake_image_F, training=True)
                disc_X_fake_loss = mse_loss(tf.ones_like(fake_disc_X), fake_disc_X)

                # Cycle-consistency losses. The reconstruction passes run with
                # training=True like every other network call in this step.
                consistence_loss_X = mae_loss(x_image, gen_f(fake_image_G, training=True))
                consistence_loss_Y = mae_loss(y_image, gen_g(fake_image_F, training=True))

                # Total generator objective
                gen_loss = (
                    disc_Y_fake_loss
                    + disc_X_fake_loss
                    + config.CONSISTENCY_LOSS_LAMBDA * consistence_loss_X
                    + config.CONSISTENCY_LOSS_LAMBDA * consistence_loss_Y
                )

            gen_vars = list(gen_g.trainable_variables) + list(gen_f.trainable_variables)
            gen_grads = g_tape.gradient(gen_loss, gen_vars)
            gen_optim.apply_gradients(zip(gen_grads, gen_vars))

            # Monitor loss
            G_losses.append(gen_loss)
            unsaved_progress = True

            # Periodic progress report
            if (batch_idx + 1) % report_every == 0:
                display.clear_output(wait=True)
                interval_seconds = time.perf_counter() - report_start
                print(
                    f'Training from batch {batch_idx + 1 - report_every} to batch {batch_idx} '
                    f'takes {interval_seconds} seconds'
                )
                print(
                    f"Epoch [{epoch}/{config.EPOCHS}] Batch {batch_idx}/{len(dataset)} "
                    f"Loss D: {disc_loss:.4f}, loss G: {gen_loss:.4f}"
                )
                generate_and_visualize(gen_g, test_inputs, num_iter)
                plot_loss(G_losses, D_losses)
                num_iter += 1
                report_start = time.perf_counter()

            # Periodic checkpoint
            if (batch_idx + 1) % save_every == 0:
                checkpoint.save(config.checkpoint_prefix)
                unsaved_progress = False

        last_epoch = epoch
        last_epoch_seconds = time.perf_counter() - epoch_start

    if last_epoch is None:
        # Empty epoch range: there is nothing to summarize or save.
        print(f'No epochs to run (start_epoch={start_epoch}, config.EPOCHS={config.EPOCHS}).')
        return

    if unsaved_progress:
        # Keep the weights trained since the last periodic checkpoint.
        checkpoint.save(config.checkpoint_prefix)

    # Display the last result. 'Done!' is printed last so that clear_output
    # does not erase it.
    display.clear_output(wait=True)
    print(f'Training epoch {last_epoch} takes {last_epoch_seconds} seconds')
    generate_and_visualize(gen_g, test_inputs, config.EPOCHS)
    plot_loss(G_losses, D_losses)
    print('Done!')

# Resume from the checkpoint stored under '<checkpoint_prefix>-1' and continue
# training on the training dataset.
resume_from = config.checkpoint_prefix + '-1'
train(
    train_dataset,
    continue_training=True,
    checkpoint_dir=resume_from,
    num_iter=num_iter,
)

Output hidden; open in https://colab.research.google.com to view.

In [5]:
import imageio.v2 as imageio

# Assemble the saved progress images into one animated GIF.
GIF_PATH = 'cycleGAN_training_process.gif'
FRAME_PATTERN = 'training-images/image_at_epoch_{:04d}.png'
FRAME_COUNT = 67  # number of progress images to include (indices 0..66)

# Read every frame first. The output file is only written by mimsave below,
# so a missing frame no longer leaves a truncated/empty GIF behind.
frames = [imageio.imread(FRAME_PATTERN.format(frame_idx)) for frame_idx in range(FRAME_COUNT)]

imageio.mimsave(GIF_PATH, frames)

# this is a hack to display the gif inside the notebook
#os.system('cp cycleGAN_training_process.gif cycleGAN_training_process.gif.png')

#display.Image(filename="cycleGAN_training_process.gif.png")