based on: https://medium.com/analytics-vidhya/transforming-the-world-into-paintings-with-cyclegan-6748c0b85632

## The Data

In [68]:
import tensorflow_datasets as tfds

# Download (or reuse the cached copy of) the monet2photo CycleGAN dataset.
# `as_supervised=True` makes every element an (image, label) pair, which is
# why preprocess_image() takes a second, ignored argument.
data, metadata = tfds.load(
    "cycle_gan/monet2photo",
    as_supervised=True,
    with_info=True,
)

# Domain A / domain B splits (presumably Monet paintings / photos -- confirm
# against the TFDS catalog entry).
train_x = data["trainA"]
train_y = data["trainB"]
test_x = data["testA"]
test_y = data["testB"]

## Setup

In [69]:
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.initializers import RandomNormal

# Number of passes over the training data.
EPOCHS = 50

# Relative importance of the cycle loss to the adversarial loss.
LAMBDA = 10

# Image geometry used by the preprocessing step and by every model input.
img_rows, img_cols, channels = 256, 256, 3

# Initializer passed to every convolution built in this notebook.
weight_initializer = RandomNormal(stddev=0.02)

# One Adam instance per network. The original chained assignment bound both
# generators (and both discriminators) to a single optimizer object, so they
# shared one step counter and one set of optimizer state; newer Keras
# optimizers also refuse to be applied to a second set of variables.
# `learning_rate` replaces the deprecated `lr` alias.
gen_g_optimizer = Adam(learning_rate=0.0002, beta_1=0.5)
gen_f_optimizer = Adam(learning_rate=0.0002, beta_1=0.5)
dis_x_optimizer = Adam(learning_rate=0.0002, beta_1=0.5)
dis_y_optimizer = Adam(learning_rate=0.0002, beta_1=0.5)

In [70]:
import tensorflow as tf

# Report the GPU TensorFlow will use, if any. Query the device name once
# instead of three times.
gpu_device_name = tf.test.gpu_device_name()
if gpu_device_name:
    print(f"Default GPU Device: {gpu_device_name}")
    # The original also called `tf.device(...)` here without a `with`
    # statement. That only builds a context manager and never enters it, so
    # it had no effect and is dropped. TensorFlow places ops on an available
    # GPU by default; wrap code in `with tf.device(...):` to pin it.

# TODO: add random cropping
def preprocess_image(image, _):
    """
    Resize an image, rescale it to [-1, 1] and add a leading batch axis.

    The second argument is ignored; it exists because the datasets are
    loaded with `as_supervised=True` and therefore yield pairs.
    Returns a float32 tensor of shape (1, img_rows, img_cols, channels).
    """
    resized = tf.image.resize(image, (int(img_rows), int(img_cols)))

    # x / 127.5 - 1 maps 0 -> -1 and 255 -> 1, matching the tanh output
    # range of the generators.
    scaled = tf.cast(resized, tf.float32) / 127.5 - 1

    return tf.reshape(scaled, (1, img_rows, img_cols, channels))

In [71]:
# Build the preprocessed datasets straight from the raw splits in `data`.
# Mapping `train_x` onto itself made this cell non-idempotent: a second run
# fed already-preprocessed single tensors into the two-argument
# preprocess_image() and failed.
train_x = data["trainA"].map(preprocess_image)
train_y = data["trainB"].map(preprocess_image)
test_x = data["testA"].map(preprocess_image)
test_y = data["testB"].map(preprocess_image)

## The Model

In [72]:
from tensorflow.keras.layers import Conv2D, LeakyReLU
from tensorflow_addons.layers import InstanceNormalization

def Ck(inpt, k, use_instancenorm=True):
    """
    Apply a 4x4, stride-2 Convolution-InstanceNorm-LeakyReLU block.

    inpt: tensor to transform.
    k: number of convolution filters.
    use_instancenorm: when False, the normalisation layer is skipped.
    Returns the output tensor of the LeakyReLU (slope 0.2).
    """
    convolution = Conv2D(
        filters=k,
        kernel_size=(4, 4),
        strides=2,
        padding="same",
        kernel_initializer=weight_initializer,
    )
    features = convolution(inpt)

    if use_instancenorm:
        features = InstanceNormalization(axis=-1)(features)

    return LeakyReLU(0.2)(features)


In [73]:
from tensorflow.keras.layers import Input, Conv2D
from tensorflow.keras.models import Model

def discriminator():
    """
    Build and compile a convolutional discriminator.

    Four stride-2 Ck blocks (64, 128, 256, 512 filters; no instance norm on
    the first) are followed by a 4x4 convolution with a single filter, so
    the model outputs a grid of unactivated scores rather than one scalar.
    """
    image_in = Input(shape=(img_rows, img_cols, channels))

    # The first block skips normalisation; the rest use the Ck default.
    features = Ck(image_in, 64, False)
    for n_filters in (128, 256, 512):
        features = Ck(features, n_filters)

    patch_scores = Conv2D(
        filters=1,
        kernel_size=(4, 4),
        padding="same",
        kernel_initializer=weight_initializer,
    )(features)

    model = Model(inputs=image_in, outputs=patch_scores)
    model.compile(loss="mse", optimizer=dis_x_optimizer)
    return model

In [74]:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, Activation
from tensorflow_addons.layers import InstanceNormalization

def dk(k, use_instancenorm=True):
    """
    Build a 3x3, stride-2 Convolution-InstanceNorm-ReLU block.

    k: number of convolution filters.
    use_instancenorm: when False, the normalisation layer is left out.
    Returns an unbuilt Sequential model wrapping the layers.
    """
    stages = [
        Conv2D(
            filters=k,
            kernel_size=(3, 3),
            strides=2,
            padding="same",
            kernel_initializer=weight_initializer,
        )
    ]

    if use_instancenorm:
        stages.append(InstanceNormalization(axis=-1))

    stages.append(Activation("relu"))

    return Sequential(stages)

In [75]:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, Activation
from tensorflow_addons.layers import InstanceNormalization

def uk(k):
    """
    Build a 3x3 transposed-convolution (stride 2) InstanceNorm-ReLU block.

    k: number of filters of the transposed convolution.
    Returns an unbuilt Sequential model wrapping the three layers.
    """
    # Conv2DTranspose must already be imported when this function is called.
    upsample = Conv2DTranspose(
        filters=k,
        kernel_size=(3, 3),
        strides=2,
        padding="same",
        kernel_initializer=weight_initializer,
    )

    return Sequential([
        upsample,
        InstanceNormalization(axis=-1),
        Activation("relu"),
    ])

In [76]:
from tensorflow.keras.layers import Input, Concatenate, Conv2DTranspose
from tensorflow.keras.models import Model

def generator():
    """
    Build a U-Net style generator (encoder, decoder, skip connections).

    Eight stride-2 `dk` blocks downsample the input, seven `uk` blocks
    upsample it again, and each decoder output is concatenated with the
    encoder activation of matching resolution. A final stride-2 transposed
    convolution with tanh produces an image with `channels` channels at the
    input resolution.

    Returns an uncompiled Keras Model.
    """
    gen_input = Input(shape=(img_rows, img_cols, channels))

    encoder_layers = [
        dk(64, False),
        dk(128),
        dk(256),
        dk(512),
        dk(512),
        dk(512),
        dk(512),
        # Bottleneck: with the 256x256 input configured in this notebook,
        # eight stride-2 convolutions leave a 1x1 feature map. Instance
        # normalisation over a single spatial position subtracts the value
        # from itself and zeroes the activations, so it is disabled here.
        dk(512, False),
    ]

    decoder_layers = [
        uk(n_filters) for n_filters in (512, 512, 512, 512, 256, 128, 64)
    ]

    gen = gen_input

    # Run the encoder, remembering every intermediate activation.
    skips = []
    for layer in encoder_layers:
        gen = layer(gen)
        skips.append(gen)

    # The bottleneck output feeds the decoder directly, so it is not a skip
    # connection; the remaining activations are used deepest-first.
    skips = skips[:-1][::-1]

    for skip_layer, layer in zip(skips, decoder_layers):
        gen = layer(gen)
        gen = Concatenate()([gen, skip_layer])

    # Last upsampling step back to the input resolution; tanh keeps the
    # output in [-1, 1], the same range preprocess_image() produces.
    gen = Conv2DTranspose(
        channels,
        (3, 3),
        strides=2,
        padding="same",
        kernel_initializer=weight_initializer,
        activation="tanh"
    )(gen)

    return Model(gen_input, gen)

In [77]:
# generator_g is applied to domain-X images and generator_f to domain-Y
# images in the training step; each direction gets its own network.
generator_g, generator_f = generator(), generator()

# One discriminator per image domain.
discriminator_x, discriminator_y = discriminator(), discriminator()

## Training

In [78]:
from tensorflow.keras.losses import BinaryCrossentropy

# The discriminators end in a plain convolution, so their outputs are
# treated as logits.
loss = BinaryCrossentropy(from_logits=True)


def discriminator_loss(real, generated):
    """
    Average the cross-entropy of scoring real images as 1 and generated
    images as 0.

    real: discriminator output for real images.
    generated: discriminator output for generated images.
    """
    real_term = loss(tf.ones_like(real), real)
    generated_term = loss(tf.zeros_like(generated), generated)
    return (real_term + generated_term) * 0.5


def gen_loss(validity):
    """
    Adversarial generator loss: cross-entropy between the discriminator's
    output for generated images (`validity`) and an all-ones target.
    """
    target = tf.ones_like(validity)
    return loss(target, validity)


def image_similarity(image1, image2):
    """Return the mean absolute difference between two image tensors."""
    absolute_difference = tf.abs(image1 - image2)
    return tf.reduce_mean(absolute_difference)

In [79]:
@tf.function
def step(real_x, real_y):
    """
    Run one training step that updates both discriminators and both
    generators.

    real_x: image tensor fed to generator_g and discriminator_x.
    real_y: image tensor fed to generator_f and discriminator_y.
    The training loop in this notebook passes both reshaped to
    (1, img_rows, img_cols, channels).

    Nothing is returned; the four optimizers update the model weights in
    place. Statement order matters: each discriminator is updated before
    the generator losses are assembled.
    """
    # persistent=True because tape.gradient() is called four times below.
    # NOTE(review): the tape is never explicitly deleted -- confirm whether
    # a `del tape` is wanted once the step is done.
    with tf.GradientTape(persistent=True) as tape:
        # compute discriminator y loss
        fake_y = generator_g(real_x, training=True)
        gen_g_validity = discriminator_y(fake_y, training=True)
        dis_y_loss = discriminator_loss(
            discriminator_y(real_y, training=True),
            gen_g_validity
        )

        # compute and apply discriminator y gradients
        # (recording is paused so the gradient computation and the weight
        # update are not themselves added to the tape)
        with tape.stop_recording():
            discriminator_y_gradients = tape.gradient(
                dis_y_loss,
                discriminator_y.trainable_variables
            )

            dis_y_optimizer.apply_gradients(
                zip(discriminator_y_gradients, discriminator_y.trainable_variables)
            )

        # compute discriminator x loss
        fake_x = generator_f(real_y, training=True)
        gen_f_validity = discriminator_x(fake_x, training=True)
        dis_x_loss = discriminator_loss(
            discriminator_x(real_x, training=True),
            gen_f_validity
        )

        # compute and apply discriminator x gradients
        with tape.stop_recording():
            discriminator_x_gradients = tape.gradient(
                dis_x_loss,
                discriminator_x.trainable_variables
            )

            dis_x_optimizer.apply_gradients(
                zip(discriminator_x_gradients, discriminator_x.trainable_variables)
            )

        # adversarial losses - how real the dis believed the generated images
        # (these reuse the validity tensors computed above, i.e. before the
        # discriminator updates were applied)
        gen_g_adv_loss = gen_loss(gen_g_validity)
        gen_f_adv_loss = gen_loss(gen_f_validity)

        # cycle losses - how well a translated image can be re-translated
        cyc_x = generator_f(fake_y, training=True)
        cyc_x_loss = image_similarity(real_x, cyc_x)
        cyc_y = generator_g(fake_x, training=True)
        cyc_y_loss = image_similarity(real_y, cyc_y)

        # identity loss - each generator is fed an image from the domain it
        # outputs and is penalised for changing it
        id_x = generator_f(real_x, training=True)
        id_x_loss = image_similarity(real_x, id_x)
        id_y = generator_g(real_y, training=True)
        id_y_loss = image_similarity(real_y, id_y)

        # generator losses: own adversarial term + both cycle terms weighted
        # by LAMBDA + own identity term weighted by 0.5 * LAMBDA
        gen_g_loss = gen_g_adv_loss + (cyc_x_loss + cyc_y_loss) * LAMBDA + id_y_loss * 0.5 * LAMBDA
        gen_f_loss = gen_f_adv_loss + (cyc_x_loss + cyc_y_loss) * LAMBDA + id_x_loss * 0.5 * LAMBDA

        # compute and apply generator gradients
        with tape.stop_recording():
            generator_g_gradients = tape.gradient(gen_g_loss, generator_g.trainable_variables)
            gen_g_optimizer.apply_gradients(
                zip(generator_g_gradients, generator_g.trainable_variables)
            )

            generator_f_gradients = tape.gradient(gen_f_loss, generator_f.trainable_variables)
            gen_f_optimizer.apply_gradients(
                zip(generator_f_gradients, generator_f.trainable_variables)
            )

In [80]:
import matplotlib.pyplot as plt

def generate_images():
    """
    Plot one random test image from each domain next to its translation.

    The 2x2 grid shows, in order: a domain-X image, generator_g's output
    for it, a domain-Y image, and generator_f's output for it.
    """
    # sample images
    x = next(iter(test_x.shuffle(1000))).numpy()
    y = next(iter(test_y.shuffle(1000))).numpy()

    # get predictions for those images
    batch_shape = (1, img_rows, img_cols, channels)
    y_hat = generator_g.predict(x.reshape(batch_shape))
    # The original assigned this to `x_had`, so `x_hat` below was undefined.
    x_hat = generator_f.predict(y.reshape(batch_shape))

    # plot images (`figsize` is a keyword argument; the original called it
    # as an undefined function)
    plt.figure(figsize=(12, 12))

    images = [x[0], y_hat[0], y[0], x_hat[0]]

    for i, image in enumerate(images):
        plt.subplot(2, 2, i + 1)
        # Undo the [-1, 1] scaling so imshow receives values in [0, 1].
        plt.imshow(image * 0.5 + 0.5)
        plt.axis("off")

    plt.tight_layout()
    plt.show()

In [81]:
import time

# Shape handed to step(): a single image with a leading batch axis.
batch_shape = (1, img_rows, img_cols, channels)

for epoch in range(EPOCHS):
    print(f"Epoch: {epoch}")
    epoch_start = time.time()

    # Pair the two domains element by element; iteration stops when the
    # shorter dataset is exhausted.
    paired_batches = tf.data.Dataset.zip((train_x, train_y))

    for batch_index, (real_x, real_y) in enumerate(paired_batches):
        # Lightweight progress indicator every 100 batches.
        if batch_index % 100 == 0:
            print(batch_index)

        step(
            tf.reshape(real_x, batch_shape),
            tf.reshape(real_y, batch_shape),
        )

    # Show sample translations once per epoch.
    generate_images()
    print(f"Time taken: {time.time() - epoch_start}")

Epoch: 0
0
100
200
300
400
500
600
700


KeyboardInterrupt: 