In [18]:
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow_examples.models.pix2pix import pix2pix
import os
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output

In [19]:
# horse2zebra split naming: trainA/testA hold the horse images (domain X),
# trainB/testB the zebra images (domain Y).
dataset, info = tfds.load('cycle_gan/horse2zebra', with_info=True, as_supervised=True)
train_horses, train_zebras = dataset['trainA'], dataset['trainB']
test_horses, test_zebras = dataset['testA'], dataset['testB']

SEED = 42          # seeds TensorFlow's global RNG so random ops are more reproducible across runs
BUFFERSIZE = 1000  # shuffle buffer size for every pipeline
BATCHSIZE = 1
WIDTH = 256        # training crop width (used by random_crop)
HEIGHT = 256       # training crop height (used by random_crop)
CHANNELS = 3       # image channels; also the generators' output channel count
LAMBDA = 10        # weight of the cycle-consistency loss (identity loss uses 0.5 * LAMBDA)
EPOCHS = 40

# Set before any pipeline or model is built so augmentation, shuffling and
# weight initialisation all draw from the seeded global RNG.
tf.random.set_seed(SEED)

In [20]:
def random_crop(image):
    """Randomly crop a single (rank-3) image to HEIGHT x WIDTH x CHANNELS.

    Uses the CHANNELS config constant rather than a literal 3 so the crop stays
    consistent with the rest of the notebook if the channel count changes.
    """
    return tf.image.random_crop(image, size=[HEIGHT, WIDTH, CHANNELS])

def normalize(image):
    """Cast to float32 and rescale pixel values from [0, 255] to [-1, 1].

    Assumes input pixel values lie in [0, 255] (0 maps to -1, 255 maps to 1).
    """
    as_float = tf.cast(image, tf.float32)
    scaled = as_float / 127.5
    return scaled - 1

def random_jitter(image, resize_to=(286, 286)):
    """Random-jitter augmentation: upscale, random-crop to HEIGHT x WIDTH, random mirror.

    Args:
        image: a single image tensor; it must be rank 3 because random_crop
            crops to a 3-element size.
        resize_to: (height, width) to upscale to before cropping. It must be at
            least (HEIGHT, WIDTH). The default (286, 286) leaves a 30-pixel
            margin around the default 256x256 crop.

    Returns:
        The augmented image, with the same dtype as the input.
    """
    # Nearest-neighbour resizing keeps the input dtype (other methods return float32).
    image = tf.image.resize(image, resize_to, method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    image = random_crop(image)
    image = tf.image.random_flip_left_right(image)
    return image

def preprocess_train(image, label):
    """Map an (image, label) training pair to an augmented, normalized image.

    The label is ignored: CycleGAN is unpaired, so only the image is kept.
    Jitter runs first, on the raw pixels, and normalization runs last.
    """
    return normalize(random_jitter(image))

def preprocess_test(image, label):
    """Map an (image, label) test pair to a normalized image; the label is dropped.

    No resizing or cropping is applied, so this presumably relies on the test
    images already matching the training resolution. TODO confirm.
    """
    return normalize(image)

def _train_pipeline(ds):
    """Cache raw images, then augment, shuffle, batch and prefetch.

    Caching happens before the map, so the random jitter is re-applied on
    every pass over the data rather than being frozen into the cache.
    """
    return (ds.cache()
            .map(preprocess_train, num_parallel_calls=tf.data.AUTOTUNE)
            .shuffle(BUFFERSIZE)
            .batch(BATCHSIZE)
            .prefetch(tf.data.AUTOTUNE))


def _test_pipeline(ds):
    """Normalize, cache the (deterministic) result, then shuffle, batch and prefetch."""
    return (ds.map(preprocess_test, num_parallel_calls=tf.data.AUTOTUNE)
            .cache()
            .shuffle(BUFFERSIZE)
            .batch(BATCHSIZE)
            .prefetch(tf.data.AUTOTUNE))


# NOTE: these reassign the raw dataset names, so re-running this cell on its own
# fails (the mapped functions expect (image, label) pairs). Re-run the loading
# cell first.
train_horses = _train_pipeline(train_horses)
train_zebras = _train_pipeline(train_zebras)
test_horses = _test_pipeline(test_horses)
test_zebras = _test_pipeline(test_zebras)

In [21]:
# Generators: gen_g translates X (horses) -> Y (zebras), gen_f translates Y -> X.
# Discriminators: disc_x judges domain-X images, disc_y judges domain-Y images.
# The comprehensions build the models in the same order as before: gen_g, gen_f, disc_x, disc_y.
gen_g, gen_f = [pix2pix.unet_generator(CHANNELS, norm_type='instancenorm') for _ in range(2)]
disc_x, disc_y = [pix2pix.discriminator(norm_type='instancenorm', target=False) for _ in range(2)]

In [22]:
# Discriminator outputs are treated as raw logits (from_logits=True).
loss_bce = tf.keras.losses.BinaryCrossentropy(from_logits=True)


def disc_loss(real, gen):
    """Discriminator loss: score real images as 1 and generated images as 0.

    Args:
        real: discriminator logits for real images.
        gen: discriminator logits for generated images.

    Returns:
        The mean of the real-image and generated-image BCE terms.
    """
    real_loss = loss_bce(tf.ones_like(real), real)
    # Named fake_loss rather than gen_loss so it does not shadow the gen_loss function.
    fake_loss = loss_bce(tf.zeros_like(gen), gen)
    return 0.5 * (real_loss + fake_loss)


def gen_loss(gen):
    """Adversarial generator loss: reward logits the discriminator scores as real (1)."""
    return loss_bce(tf.ones_like(gen), gen)


def cycle_loss(real, cycled):
    """Cycle-consistency loss: LAMBDA-weighted mean absolute (L1) error vs. the round trip."""
    return LAMBDA * tf.reduce_mean(tf.abs(real - cycled))


def identity_loss(real, same):
    """Identity loss: 0.5 * LAMBDA-weighted L1 error when a generator sees its target domain."""
    return 0.5 * LAMBDA * tf.reduce_mean(tf.abs(real - same))

# Shared Adam hyperparameters for all four networks.
LEARNING_RATE = 2e-4
BETA_1 = 0.5

# One optimizer instance per network, so each keeps its own Adam state.
gen_g_optim = tf.keras.optimizers.Adam(LEARNING_RATE, beta_1=BETA_1)
gen_f_optim = tf.keras.optimizers.Adam(LEARNING_RATE, beta_1=BETA_1)
disc_x_optim = tf.keras.optimizers.Adam(LEARNING_RATE, beta_1=BETA_1)
disc_y_optim = tf.keras.optimizers.Adam(LEARNING_RATE, beta_1=BETA_1)

In [23]:
def generate_images(model, test_input):
    """Show the first image of a batch next to the model's translation of it.

    Args:
        model: a Keras generator mapping a batch of images to a batch of images.
        test_input: a batch of images scaled to [-1, 1] (as produced by
            `normalize`); only element 0 is displayed.
    """
    # Explicit inference mode, which matters for layers such as dropout.
    prediction = model(test_input, training=False)
    fig, axes = plt.subplots(1, 2, figsize=(8, 4))
    panels = [(test_input[0], 'Input Image'), (prediction[0], 'Predicted Image')]
    for ax, (image, title) in zip(axes, panels):
        # Undo normalize(): map [-1, 1] back to [0, 1] for imshow.
        ax.imshow(image * 0.5 + 0.5)
        ax.set_title(title)
        ax.axis('off')
    plt.show()
    # Called once per epoch by the training loop; close so figures don't accumulate.
    plt.close(fig)

In [24]:
@tf.function
def train_step(real_x, real_y):
    """Run one CycleGAN optimisation step on one batch from each domain.

    X is the horse domain and Y the zebra domain, as fed by the training loop.
    gen_g translates X -> Y and gen_f translates Y -> X. disc_x scores
    domain-X images and disc_y scores domain-Y images.

    Args:
        real_x: batch of domain-X images normalized to [-1, 1].
        real_y: batch of domain-Y images normalized to [-1, 1].

    Returns:
        dict of scalar loss tensors for logging: 'gen_g_loss' and 'gen_f_loss'
        (totals including cycle and identity terms), 'disc_x_loss' and
        'disc_y_loss'. Callers may ignore it.
    """
    # persistent=True because four separate gradients are taken from one forward pass.
    with tf.GradientTape(persistent=True) as tape:
        # Translation and round trip: X -> Y -> X and Y -> X -> Y.
        fake_y = gen_g(real_x, training=True)
        cycled_x = gen_f(fake_y, training=True)

        fake_x = gen_f(real_y, training=True)
        cycled_y = gen_g(fake_x, training=True)

        # Identity mapping: each generator is fed an image already in its output domain.
        same_x = gen_f(real_x, training=True)
        same_y = gen_g(real_y, training=True)

        disc_real_x = disc_x(real_x, training=True)
        disc_real_y = disc_y(real_y, training=True)

        disc_fake_x = disc_x(fake_x, training=True)
        disc_fake_y = disc_y(fake_y, training=True)

        gen_g_loss = gen_loss(disc_fake_y)
        gen_f_loss = gen_loss(disc_fake_x)

        # Both generators share the full cycle loss, since each takes part in both round trips.
        total_cycle_loss = cycle_loss(real_x, cycled_x) + cycle_loss(real_y, cycled_y)

        total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y)
        total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x)

        disc_x_loss = disc_loss(disc_real_x, disc_fake_x)
        disc_y_loss = disc_loss(disc_real_y, disc_fake_y)

    gen_g_grads = tape.gradient(total_gen_g_loss, gen_g.trainable_variables)
    gen_f_grads = tape.gradient(total_gen_f_loss, gen_f.trainable_variables)
    disc_x_grads = tape.gradient(disc_x_loss, disc_x.trainable_variables)
    disc_y_grads = tape.gradient(disc_y_loss, disc_y.trainable_variables)
    # A persistent tape keeps its resources until dropped; all gradients are computed now.
    del tape

    gen_g_optim.apply_gradients(zip(gen_g_grads, gen_g.trainable_variables))
    gen_f_optim.apply_gradients(zip(gen_f_grads, gen_f.trainable_variables))
    disc_x_optim.apply_gradients(zip(disc_x_grads, disc_x.trainable_variables))
    disc_y_optim.apply_gradients(zip(disc_y_grads, disc_y.trainable_variables))

    return {
        'gen_g_loss': total_gen_g_loss,
        'gen_f_loss': total_gen_f_loss,
        'disc_x_loss': disc_x_loss,
        'disc_y_loss': disc_y_loss,
    }

In [None]:
# One fixed batch, drawn before training, so the per-epoch previews are comparable.
sample_horse = next(iter(train_horses))

for epoch in range(EPOCHS):
    epoch_start = time.time()
    # Pairs horse (X) and zebra (Y) batches; zip stops at the shorter dataset.
    for step, (image_x, image_y) in enumerate(tf.data.Dataset.zip((train_horses, train_zebras))):
        train_step(image_x, image_y)
        if step % 10 == 0:
            # Lightweight progress indicator: one dot every 10 steps.
            print('.', end='', flush=True)
    clear_output(wait=True)
    generate_images(gen_g, sample_horse)
    print(f'Epoch {epoch + 1}/{EPOCHS} took {time.time() - epoch_start:.1f}s')