In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa

import matplotlib.pyplot as plt
import numpy as np
import os

from utils.preprocessing import load_dataset
from utils.gan import generator

from utils.preprocessing import get_gan_dataset

from utils.diffaugmentation import data_augment_flip, aug_fn

from utils.dualdiscriminator import (
    discriminator_paint, discriminator_photo, d_head,
    CycleGan,
    generator_loss1, generator_loss2,
    discriminator_loss1, discriminator_loss2,
    calc_cycle_loss, identity_loss
)

In [None]:
# Pick a distribution strategy: use a TPU when one can be resolved, otherwise fall
# back to the default strategy returned by tf.distribute.get_strategy().
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
except ValueError:
    # TPUClusterResolver raises ValueError when no TPU is configured. Only this
    # detection step is guarded (instead of a bare `except:` around everything), so
    # genuine TPU connection/initialisation errors and KeyboardInterrupt surface
    # rather than silently falling back to non-TPU training.
    tpu = None

if tpu is not None:
    print('Device:', tpu.master())
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
else:
    strategy = tf.distribute.get_strategy()

print('Replicas in sync:', strategy.num_replicas_in_sync)

# Import the data

In [None]:
# Painter whose style the CycleGAN learns. Supported values: "monet", "vangogh".
PAINTER = "monet"

In [None]:
# Collect the TFRecord shard paths for the selected painter and for the photo domain.
data_dir = 'data/'

# Glob pattern per supported painter. The Van Gogh shards use a `.tfrecord`
# extension, while the Monet and photo shards use `.tfrec`.
PAINTER_GLOBS = {
    "monet": 'monet_tfrec/*.tfrec',
    "vangogh": 'vangogh_tfrec/*.tfrecord',
}
if PAINTER not in PAINTER_GLOBS:
    # Fail fast. Otherwise PAINTER_FILES would be left undefined and the next cell
    # would die with a confusing NameError.
    raise ValueError(f"Unknown PAINTER {PAINTER!r}; expected one of {sorted(PAINTER_GLOBS)}")

painter_pattern = data_dir + PAINTER_GLOBS[PAINTER]
photo_pattern = data_dir + 'photo_tfrec/*.tfrec'

PAINTER_FILES = tf.io.gfile.glob(painter_pattern)
PHOTO_FILES = tf.io.gfile.glob(photo_pattern)

# An empty glob presumably yields an empty dataset downstream; surface missing data here instead.
if not PAINTER_FILES:
    raise FileNotFoundError(f"No TFRecord files matched {painter_pattern!r}")
if not PHOTO_FILES:
    raise FileNotFoundError(f"No TFRecord files matched {photo_pattern!r}")

In [None]:
# Datasets with batch size 1, used for per-image previews.
paint_ds = load_dataset(PAINTER_FILES).batch(1)
photo_ds = load_dataset(PHOTO_FILES).batch(1)

# The same photos, batched 32 per replica and prefetched, for bulk inference when exporting images.
fast_photo_ds = (
    load_dataset(PHOTO_FILES)
    .batch(32 * strategy.num_replicas_in_sync)
    .prefetch(32)
)

In [None]:
# Training stream pairing painter and photo records, augmented with data_augment_flip,
# shuffled and repeated (repeat=True, so fit() is given an explicit steps_per_epoch).
# NOTE(review): unlike fast_photo_ds, this is not scaled by strategy.num_replicas_in_sync.
# Confirm get_gan_dataset treats it as the global batch size.
BATCH_SIZE = 32
final_dataset = get_gan_dataset(
    PAINTER_FILES,
    PHOTO_FILES,
    augment=data_augment_flip,
    repeat=True,
    shuffle=True,
    batch_size=BATCH_SIZE,
)

# Creating and training the model

In [None]:
with strategy.scope():
    # All networks are built inside the strategy scope so their variables are created
    # under the active distribution strategy.

    # Generators: paint_generator maps photo -> painting, photo_generator maps painting -> photo.
    paint_generator, photo_generator = generator(), generator()

    # One discriminator per domain.
    paint_discriminator, photo_discriminator = discriminator_paint(), discriminator_photo()

    # Two separate discriminator heads; judging by their names, one feeds a BCE loss
    # and the other a hinge loss.
    d_head_bce, d_head_hinge_loss = d_head(), d_head()

In [None]:
# NOTE(review): no callback below writes into this directory yet.
os.makedirs('checkpoints', exist_ok=True)

# Stop after 10 epochs without improvement in 'total_loss' (restoring the best weights),
# and terminate as soon as a NaN loss appears.
# NOTE(review): assumes CycleGan's train_step reports a 'total_loss' metric; confirm.
callbacks = [
    tf.keras.callbacks.EarlyStopping(
        monitor='total_loss',
        mode='min',
        patience=10,
        restore_best_weights=True,
    ),
    tf.keras.callbacks.TerminateOnNaN(),
]

In [None]:
with strategy.scope():
    # One independent Adam optimizer per generator/discriminator, all configured
    # identically (learning rate 2e-4, beta_1 0.5).
    paint_generator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
    photo_generator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
    paint_discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
    photo_discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)

In [None]:
with strategy.scope():
    # Wire all six networks into the CycleGan wrapper (arguments are positional, so order matters).
    cycle_gan_model = CycleGan(
        paint_generator,
        photo_generator,
        paint_discriminator,
        photo_discriminator,
        d_head_bce,
        d_head_hinge_loss,
    )

In [None]:
with strategy.scope():
    # Attach optimizers, loss functions and the augmentation function to the CycleGan model.
    cycle_gan_model.compile(
        # one optimizer per generator / discriminator
        paint_gen_optimizer=paint_generator_optimizer,
        photo_gen_optimizer=photo_generator_optimizer,
        paint_disc_optimizer=paint_discriminator_optimizer,
        photo_disc_optimizer=photo_discriminator_optimizer,
        # adversarial losses (from utils.dualdiscriminator)
        gen_loss_fn1=generator_loss1,
        gen_loss_fn2=generator_loss2,
        disc_loss_fn1=discriminator_loss1,
        disc_loss_fn2=discriminator_loss2,
        # cycle-consistency and identity losses
        cycle_loss_fn=calc_cycle_loss,
        identity_loss_fn=identity_loss,
        # augmentation function from utils.diffaugmentation
        aug_fn=aug_fn,
    )

In [None]:
# Training length. final_dataset is built with repeat=True, so fit() needs an explicit
# steps_per_epoch. NOTE(review): 1407 is hard-coded; document how it relates to the
# dataset size and BATCH_SIZE.
STEPS_PER_EPOCH = 1407
EPOCHS = 23

# `callbacks` is already a list, so pass it directly. Wrapping it in another list
# relied on Keras flattening nested callback lists.
history = cycle_gan_model.fit(
    final_dataset,
    steps_per_epoch=STEPS_PER_EPOCH,
    epochs=EPOCHS,
    callbacks=callbacks,
)

# Visualising the results

In [None]:
# Preview photo -> painting translations for up to 8 photos.
# `take` stops cleanly if the dataset is shorter; `next(iter(...))` would raise StopIteration.
for example_sample in photo_ds.take(8):
    # Explicit inference mode, consistent with the export cell; keeps any
    # training-only behaviour of the generator switched off.
    generated_sample = paint_generator(example_sample, training=False)

    fig, (input_ax, generated_ax) = plt.subplots(1, 2, figsize=(32, 32))
    panels = (
        (input_ax, 'Input image', example_sample[0]),
        (generated_ax, 'Generated image', generated_sample[0]),
    )
    for ax, title, image in panels:
        ax.set_title(title)
        # Images are presumably in [-1, 1]; rescale to [0, 1] for imshow.
        ax.imshow(image * 0.5 + 0.5)
        ax.axis('off')
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across the loop

In [None]:
# Preview painting -> photo translations for up to 10 paintings.
# `take` stops cleanly if the dataset is shorter; `next(iter(...))` would raise StopIteration.
for example_sample in paint_ds.take(10):
    # Explicit inference mode, consistent with the export cell; keeps any
    # training-only behaviour of the generator switched off.
    generated_sample = photo_generator(example_sample, training=False)

    fig, (input_ax, generated_ax) = plt.subplots(1, 2, figsize=(24, 24))
    panels = (
        (input_ax, 'Input image', example_sample[0]),
        (generated_ax, 'Generated image', generated_sample[0]),
    )
    for ax, title, image in panels:
        ax.set_title(title)
        # Images are presumably in [-1, 1]; rescale to [0, 1] for imshow.
        ax.imshow(image * 0.5 + 0.5)
        ax.axis('off')
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across the loop

# Saving the images


In [None]:
# `import PIL` alone does not load the PIL.Image submodule used when saving images.
# Import it explicitly; this also binds the name `PIL`.
import PIL.Image

# Create the directory the image export writes into (output/images).
# exist_ok=True keeps re-running this cell harmless, unlike a shell `mkdir`.
os.makedirs("output/images", exist_ok=True)

In [None]:
%%time
# Translate every photo into the painter's style and save each result as a numbered
# JPEG (1.jpg, 2.jpg, ...) under output/images.
OUTPUT_IMAGE_DIR = "output/images"
os.makedirs(OUTPUT_IMAGE_DIR, exist_ok=True)  # saving fails if the directory is missing

image_number = 1
for photo_batch in fast_photo_ds:
    # paint_generator is the photo -> painting generator (`monet_generator` is not
    # defined anywhere in this notebook).
    prediction = paint_generator(photo_batch, training=False).numpy()
    # Rescale presumed [-1, 1] generator output to 0..255 pixel values.
    prediction = (prediction * 127.5 + 127.5).astype(np.uint8)
    for pred in prediction:
        PIL.Image.fromarray(pred).save(os.path.join(OUTPUT_IMAGE_DIR, f"{image_number}.jpg"))
        image_number += 1

In [None]:
import shutil

# Bundle the exported JPEGs into output/output_cyclegan.zip; the cell displays the archive path.
shutil.make_archive(base_name="output/output_cyclegan", format="zip", root_dir="output/images")

# Saving the model

In [None]:
# Persist all six networks as HDF5 files. Painter-specific models carry the painter
# name in their filename.
SAVE_PATH = 'images/dualD/'
os.makedirs(SAVE_PATH, exist_ok=True)  # make sure the target directory exists before saving

models_to_save = {
    f'G_{PAINTER}.h5': paint_generator,
    'G_photo.h5': photo_generator,
    f'D_{PAINTER}.h5': paint_discriminator,
    'D_photo.h5': photo_discriminator,
    'D_head_bce.h5': d_head_bce,
    'D_head_hinge_loss.h5': d_head_hinge_loss,
}
for filename, network in models_to_save.items():
    network.save(SAVE_PATH + filename)