Dans ce notebook, nous allons effectuer une "traduction" d'image à image en utilisant des CycleGANs. Nous proposons une méthode qui capture les caractéristiques des images d'une base de données (MNIST) et applique ces caractéristiques à un autre domaine de données (USPS), et ce sans exemples appariés pour l'entrainement (problème non supervisé).

Les CycleGANs utilisent une perte de consistance de cycle ("cycle consistency loss") pour permettre l'entrainement sans images en paires. En d'autres termes, ils permettent de traduire d'un domaine à un autre sans avoir apparié au préalable les données.

## Set up

In [1]:
# pip install -q git+https://github.com/tensorflow/examples.git

In [2]:
# pip install tensorflow

In [3]:
# pip install tensorflow_datasets

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow_examples.models.pix2pix import pix2pix

import os
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output

import h5py
import numpy as np

In [None]:
# Turn off the tensorflow_datasets progress bars for the rest of the notebook.
tfds.disable_progress_bar()
# Value passed as `num_parallel_calls` to the dataset `map` calls below.
AUTOTUNE = tf.data.experimental.AUTOTUNE

## Données

In [None]:
# Load the horse2zebra CycleGAN dataset together with its metadata.
dataset, metadata = tfds.load(
    'cycle_gan/horse2zebra',
    with_info=True,
    as_supervised=True,
)

# Splits A and B are bound to the horse and zebra names respectively.
train_horses = dataset['trainA']
train_zebras = dataset['trainB']
test_horses = dataset['testA']
test_zebras = dataset['testB']

In [None]:
def random_crop(image):
    """Return a randomly positioned IMG_HEIGHT x IMG_WIDTH x 3 crop of `image`."""
    return tf.image.random_crop(image, size=[IMG_HEIGHT, IMG_WIDTH, 3])

def normalize(image):
    """Cast `image` to float32 and rescale it as ``x / 127.5 - 1``.

    For pixel values in [0, 255] the result lies in [-1, 1].
    """
    as_float = tf.cast(image, tf.float32)
    return (as_float / 127.5) - 1

def random_jitter(image):
    """Augment `image`: resize to 286 x 286, random crop, random horizontal flip."""
    enlarged = tf.image.resize(
        image,
        [286, 286],
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
    )
    # random_crop cuts the enlarged image down to IMG_HEIGHT x IMG_WIDTH x 3.
    cropped = random_crop(enlarged)
    return tf.image.random_flip_left_right(cropped)

def preprocess_image_train(image, label):
    """Training-time map function: jitter, then normalize.

    `label` is accepted so the function matches (image, label) dataset
    elements, but it is not used and not returned.
    """
    return normalize(random_jitter(image))

def preprocess_image_test(image, label):
    """Test-time map function: normalize only (no augmentation).

    `label` is accepted so the function matches (image, label) dataset
    elements, but it is not used and not returned.
    """
    normalized = normalize(image)
    return normalized

In [None]:
BUFFER_SIZE = 1000
BATCH_SIZE = 1
IMG_WIDTH = 256
IMG_HEIGHT = 256


def _build_pipeline(raw_dataset, preprocess_fn):
    """Cache the raw examples, then preprocess, shuffle and batch them.

    The cache is placed BEFORE `map`: preprocess_image_train applies random
    jitter, and caching after the map would store the first epoch's augmented
    images and replay those same images on every later epoch.
    """
    return (
        raw_dataset
        .cache()
        .map(preprocess_fn, num_parallel_calls=AUTOTUNE)
        .shuffle(BUFFER_SIZE)
        .batch(BATCH_SIZE)  # use the declared constant instead of a literal 1
    )


train_horses = _build_pipeline(train_horses, preprocess_image_train)
train_zebras = _build_pipeline(train_zebras, preprocess_image_train)
test_horses = _build_pipeline(test_horses, preprocess_image_test)
test_zebras = _build_pipeline(test_zebras, preprocess_image_test)

# One batch from each training domain, kept as a visual reference sample.
sample_horse = next(iter(train_horses))
sample_zebra = next(iter(train_zebras))

In [None]:
# Only the MNIST training split is used; the test split is discarded.
(mnist_images, mnist_labels), (_, _) = tf.keras.datasets.mnist.load_data()

# Add an explicit channel axis: (N, 28, 28) -> (N, 28, 28, 1).
mnist_images = mnist_images.reshape(mnist_images.shape[0], 28, 28, 1).astype('float32')
mnist_images = (mnist_images - 127.5) / 127.5  # Normalize the images to [-1, 1]

BUFFER_SIZE = len(mnist_images)
BATCH_SIZE = 1

# Replicate the single grey channel into 3 identical channels, because the
# generators are built with OUTPUT_CHANNELS = 3. np.repeat does this in one
# vectorised call instead of a Python loop over every pixel, and it follows
# the actual number of images instead of a hard-coded 60000.
# The result is float32 (the dtype of mnist_images); the previous
# np.ndarray(...) buffer defaulted to float64.
mnist_images2 = np.repeat(mnist_images, 3, axis=-1)

# Batch and shuffle the data
mnist_dataset = tf.data.Dataset.from_tensor_slices(mnist_images2).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
# train_mnist = tf.convert_to_tensor(mnist_images2, dtype=None, dtype_hint=None, name=None)

In [None]:
filename = "Data/usps.h5"

# The file is only read, so open it read-only; the context manager makes sure
# the HDF5 handle is closed once the arrays have been copied into memory.
with h5py.File(filename, 'r') as f:
    usps = f['train']
    usps_images = np.array(usps['data'])
    usps_labels = np.array(usps['target'])

# Give every image an explicit channel axis: (N, 16, 16, 1).
usps_images = usps_images.reshape(usps_images.shape[0], 16, 16, 1).astype('float32')
# NOTE(review): MNIST is rescaled to [-1, 1] in its own cell but no rescaling
# is applied here; the value range of usps['data'] is not visible from this
# notebook -- confirm both domains end up on the same scale.

BUFFER_SIZE = len(usps_images)
BATCH_SIZE = 1

# Replicate the grey channel into 3 identical channels. The previous version
# allocated an uninitialised (60000, 28, 28, 3) buffer and filled only its
# [:N, :16, :16] corner, so every image carried garbage padding and the
# dataset contained rows that were never written at all.
usps_images2 = np.repeat(usps_images, 3, axis=-1)

usps_dataset = tf.data.Dataset.from_tensor_slices(usps_images2).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
# Whole array as a tensor; the next cell rebinds sample_usps to a single batch.
sample_usps = tf.convert_to_tensor(usps_images2, dtype=None, dtype_hint=None, name=None)

In [None]:
# Take the first batch of each digit dataset as a fixed sample.
sample_mnist, sample_usps = (
    next(iter(digit_dataset))
    for digit_dataset in (mnist_dataset, usps_dataset)
)

In [None]:
def random_crop(image):
    """Return a randomly positioned 256 x 256 x 3 crop of `image`."""
    target_shape = [256, 256, 3]
    return tf.image.random_crop(image, size=target_shape)

def normalize(image):
    """Cast `image` to float32 and rescale it as ``x / 127.5 - 1``.

    For pixel values in [0, 255] the result lies in [-1, 1].
    """
    float_image = tf.cast(image, tf.float32)
    return (float_image / 127.5) - 1

def random_jitter(image):
    """Resize `image` to 256 x 256, random-crop it and randomly mirror it.

    NOTE(review): the resize target (256 x 256) equals the crop size used by
    random_crop, so the crop cannot shift the image; the earlier definition
    of this function resized to 286 x 286 first -- confirm which is intended.
    """
    resized = tf.image.resize(
        image,
        [256, 256],
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
    )
    cropped = random_crop(resized)
    # Random left/right flip.
    return tf.image.random_flip_left_right(cropped)

def preprocess_image_train(image, label):
    """Training-time preprocessing: random jitter followed by normalization.

    `label` is part of the signature but is neither used nor returned.
    """
    jittered = random_jitter(image)
    return normalize(jittered)

def preprocess_image_test(image, label):
    """Test-time preprocessing: normalization only.

    `label` is part of the signature but is neither used nor returned.
    """
    return normalize(image)

In [None]:
# 2 x 2 preview: each sample digit next to its randomly jittered version.
# NOTE(review): the samples carry 3 channels, and matplotlib documents that
# `cmap` is ignored for RGB input -- confirm the grey colormap is still wanted.
preview_panels = (
    (221, 'mnist', sample_mnist[0]),
    (222, 'mnist with random jitter', random_jitter(sample_mnist[0])),
    (223, 'usps', sample_usps[0]),
    (224, 'usps with random jitter', random_jitter(sample_usps[0])),
)

for grid_position, caption, picture in preview_panels:
    plt.subplot(grid_position)
    plt.title(caption)
    plt.imshow(picture, cmap='gray')

In [None]:
# Bring both sample batches to 256 x 256 with nearest-neighbour resizing.
sample_mnist = tf.image.resize(
    sample_mnist,
    [256, 256],
    method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
)
sample_usps = tf.image.resize(
    sample_usps,
    [256, 256],
    method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
)

## Pix2Pix models

On importe le générateur et le discriminateur du modèle Pix2Pix en installant le package tensorflow_examples.

L'architecture du modèle utilisé est très similaire à celle du modèle Pix2Pix.

Il y a deux générateurs ($G$ et $F$) qui sont entrainés, et deux discriminateurs ($D_X$ et $D_Y$) :
* Le générateur $G$ apprend à transformer les images $X$ en images $Y$
* Le générateur $F$ apprend à transformer les images $Y$ en images $X$
* Le discriminateur $D_X$ apprend à différencier les vraies images $X$ des images $X$ générées
* Le discriminateur $D_Y$ apprend à différencier les vraies images $Y$ des images $Y$ générées


In [None]:
OUTPUT_CHANNELS = 3

# Two generators (G and F) built by pix2pix.unet_generator with instance
# normalization, created in the order G, F.
generator_g, generator_f = (
    pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
    for _ in range(2)
)

# Two discriminators (D_X and D_Y), created in the order X, Y.
discriminator_x, discriminator_y = (
    pix2pix.discriminator(norm_type='instancenorm', target=False)
    for _ in range(2)
)

In [None]:
# to_zebra = generator_g(sample_horse)
# to_horse = generator_f(sample_zebra)
# plt.figure(figsize=(8, 8))
# contrast = 8

# imgs = [sample_horse, to_zebra, sample_zebra, to_horse]
# title = ['Horse', 'To Zebra', 'Zebra', 'To Horse']

# for i in range(len(imgs)):
#     plt.subplot(2, 2, i+1)
#     plt.title(title[i])
#     if i % 2 == 0:
#         plt.imshow(imgs[i][0] * 0.5 + 0.5)
#     else:
#         plt.imshow(imgs[i][0] * 0.5 * contrast + 0.5)
# plt.show()

In [None]:
# Run both generators once on the sample batches and show inputs next to outputs.
to_usps = generator_g(sample_mnist)
to_mnist = generator_f(sample_usps)
plt.figure(figsize=(8, 8))
contrast = 8

imgs = [sample_mnist, to_usps, sample_usps, to_mnist]
title = ['mnist', 'to usps', 'usps', 'to mnist']

for slot, (caption, batch) in enumerate(zip(title, imgs)):
    plt.subplot(2, 2, slot + 1)
    plt.title(caption)
    first_image = batch[0]
    if slot % 2:
        # Odd slots hold generator outputs: scaled by `contrast` before display.
        plt.imshow(first_image * 0.5 * contrast + 0.5)
    else:
        # Even slots hold the input samples: shown as x * 0.5 + 0.5.
        plt.imshow(first_image * 0.5 + 0.5)
plt.show()

In [None]:
# Bare expression: the notebook echoes the current value of sample_mnist.
sample_mnist

In [None]:
# Bare expression: the notebook echoes the current value of sample_usps.
sample_usps

In [None]:
# plt.figure(figsize=(8, 8))
# contrast = 8

# imgs = [sample_mnist, to_usps, sample_usps, to_mnist]
# title = ['mnist', 'to usps', 'usps', 'to mnist']

# for i in range(len(imgs)):
#     plt.subplot(2, 2, i+1)
#     plt.title(title[i])
#     if i % 2 == 0:
#         plt.imshow(imgs[i][0] * 0.5 + 0.5)
#     else:
#         plt.imshow(imgs[i][0] * 0.5 * contrast + 0.5)
# plt.show()

In [None]:
# plt.figure(figsize=(8, 8))

# plt.subplot(121)
# plt.title('Is a real zebra?')
# plt.imshow(discriminator_y(sample_zebra)[0, ..., -1], cmap='RdBu_r')

# plt.subplot(122)
# plt.title('Is a real horse?')
# plt.imshow(discriminator_x(sample_horse)[0, ..., -1], cmap='RdBu_r')

# plt.show()

## Fonctions de perte
Dans les CycleGANs, il n'y a pas de données couplées pour l'entrainement, donc il n'y a aucune garantie que l'input x et la target y soient couplés de manière sensée durant l'entrainement. Pour apprendre un mapping correct, on utilise donc une perte de consistance de cycle ("cycle consistency loss").

In [None]:
# Weight applied to the cycle loss (calc_cycle_loss); identity_loss uses half of it.
LAMBDA = 10

In [None]:
# Shared adversarial loss used by discriminator_loss and generator_loss.
# from_logits=True: predictions passed to it are treated as raw logits.
loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)

In [None]:
def discriminator_loss(real, generated):
    """Return the discriminator loss for one batch.

    Args:
        real: discriminator output on real images (scored against ones).
        generated: discriminator output on generated images (scored against zeros).

    Returns:
        The sum of the two `loss_obj` terms, divided by 2.
    """
    loss_on_real = loss_obj(tf.ones_like(real), real)
    loss_on_generated = loss_obj(tf.zeros_like(generated), generated)
    return (loss_on_real + loss_on_generated) / 2

In [None]:
def generator_loss(generated):
    """Return the adversarial generator loss.

    `generated` is the discriminator output on generated images; it is scored
    against an all-ones target of the same shape.
    """
    all_real_target = tf.ones_like(generated)
    return loss_obj(all_real_target, generated)

La consistance du cycle signifie que le résultat devrait être proche de l'input original. Par exemple, si on traduit une phrase de l'anglais au français, puis qu'on retraduit du français vers l'anglais, alors la phrase obtenue devrait être la même que la phrase originale.

Dans la perte de consistance de cycle, 
* le générateur $G$ crée une image (générée) $\hat{Y}$ à partir d'une vraie image $X$
* le générateur $F$ crée une image (cyclée) $\hat{X}$ à partir de l'image générée $\hat{Y}$
* on calcule l'erreur moyenne absolue entre $X$ et $\hat{X}$ (forward cycle consistency loss) et, symétriquement, entre $Y$ et son image cyclée $G(F(Y))$ (backward cycle consistency loss)

In [None]:
def calc_cycle_loss(real_image, cycled_image):
    """Return LAMBDA times the mean absolute difference of the two images."""
    mean_abs_error = tf.reduce_mean(tf.abs(real_image - cycled_image))
    return LAMBDA * mean_abs_error

In [None]:
def identity_loss(real_image, same_image):
    """Return LAMBDA * 0.5 times the mean absolute difference of the two images."""
    mean_abs_error = tf.reduce_mean(tf.abs(real_image - same_image))
    return LAMBDA * 0.5 * mean_abs_error

### Fonctions d'optimisation

In [None]:
# One independent Adam optimizer per network, all with the same settings
# (learning rate 2e-4, beta_1 = 0.5).
(
    generator_g_optimizer,
    generator_f_optimizer,
    discriminator_x_optimizer,
    discriminator_y_optimizer,
) = (tf.keras.optimizers.Adam(2e-4, beta_1=0.5) for _ in range(4))

## Checkpoints

In [None]:
checkpoint_path = "./checkpoints/train"

# Everything tracked by the checkpoint: the four networks and their optimizers.
tracked_objects = {
    'generator_g': generator_g,
    'generator_f': generator_f,
    'discriminator_x': discriminator_x,
    'discriminator_y': discriminator_y,
    'generator_g_optimizer': generator_g_optimizer,
    'generator_f_optimizer': generator_f_optimizer,
    'discriminator_x_optimizer': discriminator_x_optimizer,
    'discriminator_y_optimizer': discriminator_y_optimizer,
}
ckpt = tf.train.Checkpoint(**tracked_objects)

# Keep at most the 5 most recent checkpoints under checkpoint_path.
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# Resume from the most recent checkpoint when one is found.
latest_checkpoint = ckpt_manager.latest_checkpoint
if latest_checkpoint:
    ckpt.restore(latest_checkpoint)
    print ('Latest checkpoint restored!!')

## Entrainement
La boucle d'entrainement est composée des 4 étapes suivantes : 
1. Obtenir les prédictions
2. Calculer la perte
3. Calculer les gradients par backpropagation
4. Appliquer les gradients à l'aide de la fonction d'optimisation

In [None]:
# Number of passes of the training loop over the data.
EPOCHS = 1  # NOTE(review): 40 was noted here, presumably the full-run value -- confirm

In [None]:
def generate_images(model, test_input):
    """Plot the first image of `test_input` next to the model's output for it.

    Args:
        model: callable applied to `test_input` to obtain the prediction batch.
        test_input: batch of images; only element 0 is displayed.
    """
    prediction = model(test_input)

    plt.figure(figsize=(12, 12))

    panels = (
        ('Input Image', test_input[0]),
        ('Predicted Image', prediction[0]),
    )
    for column, (caption, picture) in enumerate(panels, start=1):
        plt.subplot(1, 2, column)
        plt.title(caption)
        # Rescale with x * 0.5 + 0.5 so [-1, 1] values land in [0, 1] for plotting.
        plt.imshow(picture * 0.5 + 0.5)
        plt.axis('off')
    plt.show()

In [None]:
@tf.function
def train_step(real_x, real_y):
    """Run one update of both generators and both discriminators.

    Args:
        real_x: batch of real images from domain X.
        real_y: batch of real images from domain Y.

    Nothing is returned; the four networks are updated through their
    respective optimizers.
    """
    # The tape is persistent because four separate gradients are taken from it.
    with tf.GradientTape(persistent=True) as tape:
        # Generator G translates X -> Y, generator F translates Y -> X.

        # Forward cycle: X -> G -> F.
        fake_y = generator_g(real_x, training=True)
        cycled_x = generator_f(fake_y, training=True)

        # Backward cycle: Y -> F -> G.
        fake_x = generator_f(real_y, training=True)
        cycled_y = generator_g(fake_x, training=True)

        # Each generator applied to an image of its own target domain
        # (inputs of the identity loss).
        same_x = generator_f(real_x, training=True)
        same_y = generator_g(real_y, training=True)

        # Discriminator outputs on real images, then on generated images.
        disc_real_x = discriminator_x(real_x, training=True)
        disc_real_y = discriminator_y(real_y, training=True)

        disc_fake_x = discriminator_x(fake_x, training=True)
        disc_fake_y = discriminator_y(fake_y, training=True)

        # Adversarial terms of the two generators.
        gen_g_loss = generator_loss(disc_fake_y)
        gen_f_loss = generator_loss(disc_fake_x)

        # The cycle term is shared by both generators.
        total_cycle_loss = (
            calc_cycle_loss(real_x, cycled_x)
            + calc_cycle_loss(real_y, cycled_y)
        )

        # Total generator loss = adversarial loss + cycle loss + identity loss.
        total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y)
        total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x)

        disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
        disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)

    # (loss, network, optimizer) triples, in the order G, F, D_X, D_Y.
    updates = (
        (total_gen_g_loss, generator_g, generator_g_optimizer),
        (total_gen_f_loss, generator_f, generator_f_optimizer),
        (disc_x_loss, discriminator_x, discriminator_x_optimizer),
        (disc_y_loss, discriminator_y, discriminator_y_optimizer),
    )

    # Compute every gradient first ...
    all_gradients = [
        tape.gradient(loss, network.trainable_variables)
        for loss, network, _ in updates
    ]

    # ... and only then apply them, so no update precedes a gradient computation.
    for gradients, (_, network, optimizer) in zip(all_gradients, updates):
        optimizer.apply_gradients(zip(gradients, network.trainable_variables))

In [None]:
# Main training loop.
# NOTE(review): this loop trains on the horse/zebra datasets and previews with
# sample_horse, while the introduction of the notebook targets MNIST -> USPS --
# confirm which pair of datasets is intended here.
for epoch in range(EPOCHS):
    start = time.time()

    paired_batches = tf.data.Dataset.zip((train_horses, train_zebras))
    for n, (image_x, image_y) in enumerate(paired_batches):
        train_step(image_x, image_y)
        # Print one progress dot every 10 steps.
        if not n % 10:
            print ('.', end='')

    clear_output(wait=True)
    # The same image (sample_horse) is used after every epoch so that the
    # progress of the model is clearly visible.
    generate_images(generator_g, sample_horse)

    epoch_number = epoch + 1
    # Save a checkpoint every 5th epoch.
    if epoch_number % 5 == 0:
        ckpt_save_path = ckpt_manager.save()
        print ('Saving checkpoint for epoch {} at {}'.format(epoch_number, ckpt_save_path))

    print ('Time taken for epoch {} is {} sec\n'.format(epoch_number, time.time()-start))