In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa

import matplotlib.pyplot as plt
import numpy as np

from utils.preprocessing import load_dataset

from utils.gan import (
    generator, discriminator,
)

from utils.cyclegan import (
    CycleGan,
    generator_loss, discriminator_loss,
    calc_cycle_loss, identity_loss
)

In [None]:
# Use a TPU when one is attached; otherwise fall back to the default
# distribution strategy (CPU or single GPU).
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Device:', tpu.master())
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
except ValueError:
    # TPUClusterResolver raises ValueError when no TPU can be located. Other
    # errors (e.g. a TPU that is present but fails to initialise) are left to
    # propagate instead of being silently hidden by a bare `except:`.
    strategy = tf.distribute.get_strategy()
print('Number of replicas:', strategy.num_replicas_in_sync)

# Import the data

In [None]:
# Painter whose dataset is used for training: "monet" or "vangogh".
PAINTER = "monet"

In [None]:
# Collect the TFRecord shard paths for the selected painter and for the photos.
data_dir = 'data/'

# Glob pattern (relative to data_dir) for each supported painter. The Van Gogh
# shards use the ".tfrecord" extension; the Monet shards use ".tfrec".
PAINTER_GLOBS = {
    "monet": 'monet_tfrec/*.tfrec',
    "vangogh": 'vangogh_tfrec/*.tfrecord',
}
if PAINTER not in PAINTER_GLOBS:
    # Fail here with a clear message instead of a NameError on PAINTER_FILES later.
    raise ValueError(
        f"Unknown PAINTER {PAINTER!r}; expected one of {sorted(PAINTER_GLOBS)}"
    )

PAINTER_FILES = tf.io.gfile.glob(data_dir + PAINTER_GLOBS[PAINTER])
PHOTO_FILES = tf.io.gfile.glob(data_dir + 'photo_tfrec/*.tfrec')

In [None]:
# Report how many TFRecord shards were found for each domain; label the painter
# line with the configured PAINTER rather than always saying "Monet".
print(f'{PAINTER.capitalize()} TFRecord Files:', len(PAINTER_FILES))
print('Photo TFRecord Files:', len(PHOTO_FILES))

In [None]:
# Build the painting and photo datasets, one image per batch.
paintings, photos = (
    load_dataset(filenames=PAINTER_FILES, batch_size=1),
    load_dataset(filenames=PHOTO_FILES, batch_size=1),
)

In [None]:
def show_samples(dataset, title, n=5):
    """Display the first ``n`` images of a batched image dataset in one row.

    Args:
        dataset: a dataset yielding batches of images; only the first image of
            each batch is shown.
        title: figure-level title (applied to the whole row, not one subplot).
        n: number of batches to take from the start of ``dataset``.
    """
    fig, axes = plt.subplots(1, n, figsize=(2 * n, 2.5), squeeze=False)
    axes = axes.ravel()
    # Hide every axis up front so unused panels stay blank if the dataset is short.
    for ax in axes:
        ax.axis('off')
    for ax, batch in zip(axes, dataset.take(n)):
        # Images presumably lie in [-1, 1]; map them to [0, 1] for imshow.
        ax.imshow((batch[0] * 0.5 + 0.5).numpy())
    fig.suptitle(f'{title}')
    plt.show()


# Show the first 5 images of each dataset (take() is ordered, not random).
show_samples(paintings, f'Sample of The {PAINTER.capitalize()} Dataset')
show_samples(photos, 'Sample of The Photo Dataset')


# Creating and training the model

In [None]:
# Create the two generators and the two discriminators under the distribution
# strategy so their variables are created within its scope.
with strategy.scope():
    G_paint, G_photo = generator(), generator()
    D_paint, D_photo = discriminator(), discriminator()

In [None]:
# One independent Adam optimizer per network, all with lr=2e-4 and beta_1=0.5.
with strategy.scope():
    def _make_adam():
        """Return a fresh Adam optimizer with the shared CycleGAN settings."""
        return tf.keras.optimizers.Adam(learning_rate=0.0002, beta_1=0.5)

    G_optimizer_paint, G_optimizer_photo = _make_adam(), _make_adam()
    D_optimizer_paint, D_optimizer_photo = _make_adam(), _make_adam()

In [None]:
# Wire the four networks into the CycleGan wrapper, then attach one optimizer
# per network together with the generator, discriminator, cycle and identity
# loss functions imported from utils.cyclegan.
with strategy.scope():
    cycle_gan_model = CycleGan(
        paint_generator=G_paint,
        paint_discriminator=D_paint,
        photo_generator=G_photo,
        photo_discriminator=D_photo,
    )

    cycle_gan_model.compile(
        # optimizers
        paint_gen_optimizer=G_optimizer_paint,
        paint_disc_optimizer=D_optimizer_paint,
        photo_gen_optimizer=G_optimizer_photo,
        photo_disc_optimizer=D_optimizer_photo,
        # losses
        gen_loss_fn=generator_loss,
        disc_loss_fn=discriminator_loss,
        cycle_loss_fn=calc_cycle_loss,
        identity_loss_fn=identity_loss,
    )

In [None]:
# Train on (painting, photo) pairs. Dataset.zip stops at the shorter input, so
# each epoch is as long as the smaller of the two datasets.
NUM_EPOCHS = 25
with strategy.scope():
    paired_ds = tf.data.Dataset.zip((paintings, photos))
    history = cycle_gan_model.fit(paired_ds, epochs=NUM_EPOCHS)

In [None]:
# Reduce the Keras training history to one scalar per epoch for each loss.
# NOTE(review): the metric names come from CycleGan.train_step in
# utils.cyclegan and are read as 'monet_*' even when PAINTER == "vangogh";
# confirm the exact keys against the printed list below.
print(history.history.keys())


def _per_epoch_mean(values):
    """Return the mean of each epoch's logged value as a list of floats.

    Each entry may be a Python scalar or an array; ``np.asarray`` accepts both,
    so plain-float history entries no longer break the reduction. Iterating the
    logged values (instead of ``range(NUM_EPOCHS)``) also copes with runs that
    stopped before NUM_EPOCHS.
    """
    return [float(np.mean(np.asarray(value))) for value in values]


paint_gen_loss = _per_epoch_mean(history.history['monet_gen_loss'])
photo_gen_loss = _per_epoch_mean(history.history['photo_gen_loss'])
paint_disc_loss = _per_epoch_mean(history.history['monet_disc_loss'])
photo_disc_loss = _per_epoch_mean(history.history['photo_disc_loss'])
# TODO: not populated — no cycle-loss key is read from the history yet.
total_cycle_loss = []

# Training loss curves

In [None]:
# Plot per-epoch generator and discriminator losses for both domains, labelling
# the painting curves with the configured PAINTER instead of always "Monet".
painter_label = PAINTER.capitalize()

fig, ax = plt.subplots()
ax.plot(paint_gen_loss, label=f'{painter_label} Gen Loss')
ax.plot(photo_gen_loss, label='Photo Gen Loss')
ax.plot(paint_disc_loss, label=f'{painter_label} Disc Loss')
ax.plot(photo_disc_loss, label='Photo Disc Loss')

ax.legend()
ax.set(xlabel='Epoch', ylabel='Loss', title='Generator and Discriminator Loss')
plt.show()

# Saving the models

In [None]:
# Save all four networks as .h5 files. The painter-side networks are tagged
# with PAINTER so Monet and Van Gogh runs do not overwrite each other.
import os

SAVE_PATH = 'images/cycleGAN/'
# Writing an .h5 file may fail if the parent directory is missing, so create it
# up front; exist_ok keeps re-runs of this cell safe.
os.makedirs(SAVE_PATH, exist_ok=True)

G_paint.save(os.path.join(SAVE_PATH, f'G_{PAINTER}.h5'))
G_photo.save(os.path.join(SAVE_PATH, 'G_photo.h5'))
D_paint.save(os.path.join(SAVE_PATH, f'D_{PAINTER}.h5'))
D_photo.save(os.path.join(SAVE_PATH, 'D_photo.h5'))

# Generate the output images

In [None]:
# `import PIL` alone does not guarantee the PIL.Image submodule is loaded;
# importing it explicitly makes `PIL.Image.fromarray` reliably available.
import os

import PIL.Image

# Create output/images/ portably. Unlike `! mkdir`, this does not error when the
# directories already exist, so the cell can be re-run.
os.makedirs('output/images', exist_ok=True)

In [None]:
# Translate every photo with the painting generator and save each result as
# output/images/<n>.jpg, numbered from 1. The path is joined with a separator
# so the files land inside output/images/ (the directory that gets zipped)
# rather than in output/ as "images1.jpg", "images2.jpg", ...
import os

import PIL.Image

OUTPUT_DIR = 'output/images'
os.makedirs(OUTPUT_DIR, exist_ok=True)

for i, img in enumerate(photos, start=1):
    prediction = G_paint(img, training=False)[0].numpy()
    # Generator output is presumably in [-1, 1]; map it to [0, 255]. Clip first
    # so any slight overshoot cannot wrap around when cast to uint8.
    prediction = np.clip(prediction * 127.5 + 127.5, 0, 255).astype(np.uint8)
    PIL.Image.fromarray(prediction).save(os.path.join(OUTPUT_DIR, f"{i}.jpg"))

In [None]:
import shutil

# Bundle the contents of output/images/ into output/output_cyclegan.zip; the
# returned archive path is the cell's displayed output.
shutil.make_archive(
    base_name="output/output_cyclegan",
    format='zip',
    root_dir="output/images",
)