In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa

import matplotlib.pyplot as plt
import numpy as np

import re

from utils.preprocessing import (
    load_dataset,
    get_gan_dataset, get_photo_dataset,
)

from utils.gan import generator, discriminator

from utils.diffaugmentation import data_augment_flip, CycleGan

from utils.cyclegan import (
    generator_loss, discriminator_loss,
    calc_cycle_loss, identity_loss
)

In [None]:
# Use a TPU when one can be resolved and initialised; otherwise fall back to
# TensorFlow's current default distribution strategy.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Device:', tpu.master())
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
except Exception as tpu_error:
    # Catch `Exception` rather than using a bare `except:` so that kernel
    # interrupts (KeyboardInterrupt / SystemExit) are not swallowed, and report
    # why the fallback happened instead of hiding it.
    print('TPU unavailable, using the default strategy:', tpu_error)
    strategy = tf.distribute.get_strategy()

# Import the data

In [None]:
# Painter selection: decides which TFRecord directory is globbed for the
# painting files and is embedded in the saved model file names.
# Swap the active line to switch painters.
PAINTER = "monet"
# PAINTER = "vangogh"

In [None]:
# Collect the TFRecord file paths for the selected painter and for the photos.
data_dir = 'data/'

if PAINTER == "monet":
    PAINTER_FILES = tf.io.gfile.glob(data_dir + 'monet_tfrec/*.tfrec')
elif PAINTER == "vangogh":
    PAINTER_FILES = tf.io.gfile.glob(data_dir + 'vangogh_tfrec/*.tfrecord')
else:
    # Fail here with a clear message instead of leaving PAINTER_FILES undefined
    # and hitting a NameError in a later cell.
    raise ValueError(f"Unsupported PAINTER {PAINTER!r}; expected 'monet' or 'vangogh'")

PHOTO_FILES = tf.io.gfile.glob(data_dir + 'photo_tfrec/*.tfrec')

In [None]:
def count_data_items(filenames):
    """Sum the item counts encoded in a list of file names.

    Each name must contain ``-<digits>.`` (for example ``monet00-60.tfrec``
    contributes 60); the first such occurrence in the name is used.

    Args:
        filenames: iterable of file name / path strings.

    Returns:
        The integer total of all parsed counts (0 for an empty iterable).

    Raises:
        ValueError: if a file name does not contain a ``-<digits>.`` part.
    """
    # Compile once instead of once per file name; `+` (not `*`) so an empty
    # digit run can never reach int().
    count_pattern = re.compile(r"-([0-9]+)\.")

    counts = []
    for filename in filenames:
        match = count_pattern.search(filename)
        if match is None:
            raise ValueError(
                f"Cannot read an item count from file name {filename!r}; "
                "expected a '-<count>.' part"
            )
        counts.append(int(match.group(1)))

    # An explicit integer dtype keeps the total an integer even for an empty
    # list (np.sum([]) would otherwise return the float 0.0).
    return np.sum(counts, dtype=np.int64)

# Totals of the per-file counts parsed from the file names of each dataset;
# used later to derive the number of training steps per epoch.
n_paint_samples = count_data_items(PAINTER_FILES)
n_photo_samples = count_data_items(PHOTO_FILES)

# Creating the model

In [None]:
# Training hyper-parameters.
BATCH_SIZE =  128  # forwarded to get_gan_dataset as batch_size
EPOCHS_NUM = 28  # forwarded to cycle_gan_model.fit as epochs

In [None]:
# Training dataset built by the project helper from the painter and photo
# TFRecord files, with flip augmentation, repetition on and shuffling off.
full_dataset = get_gan_dataset(
    PAINTER_FILES,
    PHOTO_FILES,
    augment=data_augment_flip,
    repeat=True,
    shuffle=False,
    batch_size=BATCH_SIZE,
)

In [None]:
# Build the two generators and two discriminators inside the strategy scope
# so their variables are created under the selected distribution strategy.
with strategy.scope():
    paint_generator, photo_generator = generator(), generator()
    paint_discriminator, photo_discriminator = discriminator(), discriminator()

In [None]:
# One independent Adam optimizer per network, all with the same settings
# (learning rate 2e-4, beta_1 0.5), created inside the strategy scope.
with strategy.scope():
    paint_generator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
    photo_generator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)

    paint_discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
    photo_discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)

In [None]:
# Wrap the four networks in the CycleGan model and attach one optimizer per
# network plus the four loss functions imported from utils.cyclegan.
with strategy.scope():
    cycle_gan_model = CycleGan(
        paint_generator,
        photo_generator,
        paint_discriminator,
        photo_discriminator,
    )

    cycle_gan_model.compile(
        paint_gen_optimizer=paint_generator_optimizer,
        photo_gen_optimizer=photo_generator_optimizer,
        paint_disc_optimizer=paint_discriminator_optimizer,
        photo_disc_optimizer=photo_discriminator_optimizer,
        gen_loss_fn=generator_loss,
        disc_loss_fn=discriminator_loss,
        cycle_loss_fn=calc_cycle_loss,
        identity_loss_fn=identity_loss,
    )


In [None]:
# The dataset repeats indefinitely, so Keras needs an explicit step count.
# One epoch is defined as one pass over the larger of the two datasets, which
# means the divisor must be the batch size actually used to build the dataset.
# The previous hard-coded 4 did not track BATCH_SIZE. Clamp to at least one
# step so a tiny dataset still trains.
steps_per_epoch = max(1, int(max(n_paint_samples, n_photo_samples) // BATCH_SIZE))

cycle_gan_model.fit(
    full_dataset,
    epochs=EPOCHS_NUM,
    steps_per_epoch=steps_per_epoch,
)


# Saving the model

In [None]:
# Write each of the four networks to its own HDF5 file under SAVE_PATH; the
# painter-specific files carry the PAINTER name.
SAVE_PATH = 'images/cycleGAN/'
for file_name, network in (
    (f'G_{PAINTER}.h5', paint_generator),
    ('G_photo.h5', photo_generator),
    (f'D_{PAINTER}.h5', paint_discriminator),
    ('D_photo.h5', photo_discriminator),
):
    network.save(SAVE_PATH + file_name)

# Generate the output images

In [None]:
# `import PIL` alone does not load the Image submodule; import it explicitly
# so that `PIL.Image` is guaranteed to be available when images are saved.
import PIL.Image

# Create output/ and output/images/ in one call. Unlike `! mkdir`, this is
# shell-independent and does not fail when the cell is re-run and the
# directories already exist.
tf.io.gfile.makedirs("output/images")

In [None]:
# Run every photo through the trained painter generator and save the result
# as a numbered JPEG inside output/images/.
# NOTE(review): `photos` is not defined in any cell visible here - it must be
# created first (presumably from PHOTO_FILES via load_dataset /
# get_photo_dataset; confirm the helper's signature) or this cell raises
# NameError on a fresh kernel.
for i, img in enumerate(photos, start=1):
    # The generator produces the stylised image; the discriminator only
    # scores images, so it must not be used here. [0] takes the first image
    # of the batch.
    prediction = paint_generator(img, training=False)[0].numpy()
    # Rescale to 8-bit pixel values (assumes generator output in [-1, 1] - TODO confirm).
    prediction = (prediction * 127.5 + 127.5).astype(np.uint8)
    im = PIL.Image.fromarray(prediction)
    # The "/" after "images" matters: without it files land in output/ as
    # images1.jpg, ... and the archive of output/images would be empty.
    im.save(f"output/images/{i}.jpg")

In [None]:
import shutil

# Zip the contents of output/images into output/output_diffaugment.zip.
shutil.make_archive(
    base_name="output/output_diffaugment",
    format="zip",
    root_dir="output/images",
)