# **Cycle GAN**

Implementación de "Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks"

Paper: https://arxiv.org/abs/1703.10593

Código: https://github.com/eriklindernoren/Keras-GAN/tree/master/cyclegan

Datasets: https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/

<img src="https://drive.google.com/uc?export=download&id=1MecIG1hO_PnuVt0CcqQJwrkNIEnT8H1G" align="center">

<img src="https://drive.google.com/uc?export=download&id=1EQkS0QTNq0Zk-zINNQx3dlZImyDf4fdY" align="center">

In [None]:
# Set to False when running outside Google Colab; it gates the Google Drive
# download of the helper files (data_loader.py, download_dataset.sh).
COLAB: bool = True

In [None]:
!pip install git+https://www.github.com/keras-team/keras-contrib.git

Collecting git+https://www.github.com/keras-team/keras-contrib.git
  Cloning https://www.github.com/keras-team/keras-contrib.git to /tmp/pip-req-build-uqutysgx
  Running command git clone -q https://www.github.com/keras-team/keras-contrib.git /tmp/pip-req-build-uqutysgx
Building wheels for collected packages: keras-contrib
  Building wheel for keras-contrib (setup.py) ... [?25l[?25hdone
  Created wheel for keras-contrib: filename=keras_contrib-2.0.8-cp37-none-any.whl size=101078 sha256=48d20ee62b492013217cb954167fd06e81e4ba094450fd1c2a091efab2790de9
  Stored in directory: /tmp/pip-ephem-wheel-cache-o8k2dcnx/wheels/11/27/c8/4ed56de7b55f4f61244e2dc6ef3cdbaff2692527a2ce6502ba
Successfully built keras-contrib
Installing collected packages: keras-contrib
Successfully installed keras-contrib-2.0.8


In [None]:
if COLAB:
    # Fetch the helper files (data loader module and dataset download script)
    # from Google Drive into the working directory.
    from google_drive_downloader import GoogleDriveDownloader as gdd
    helper_files = [
        ('1MAMrDNOM6dvETOejtCtwKEaBt1yJnCZ0', './data_loader.py'),
        ('1ZlAPNbgtK644lbrM3G5rK6WqsI04BvHc', './download_dataset.sh'),
    ]
    for drive_id, local_path in helper_files:
        gdd.download_file_from_google_drive(file_id=drive_id,
                                            dest_path=local_path)

Downloading 1MAMrDNOM6dvETOejtCtwKEaBt1yJnCZ0 into ./data_loader.py... Done.
Downloading 1ZlAPNbgtK644lbrM3G5rK6WqsI04BvHc into ./download_dataset.sh... Done.


In [None]:
from keras.datasets import mnist
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam
import datetime
import matplotlib.pyplot as plt
import sys
from data_loader import DataLoader
import numpy as np
import os

In [None]:
class CycleGAN():
    """CycleGAN for unpaired image-to-image translation between domains A and B.

    Builds two U-Net generators (``g_AB``: A -> B, ``g_BA``: B -> A), two
    PatchGAN discriminators (``d_A``, ``d_B``) and a ``combined`` model that
    trains both generators against frozen discriminators using adversarial,
    cycle-consistency and identity losses.
    """

    def __init__(self):
        # Input shape
        self.img_rows = 128
        self.img_cols = 128
        self.channels = 3
        self.img_shape = (self.img_rows, self.img_cols, self.channels)

        # Configure data loader
        self.dataset_name = 'apple2orange'
        self.data_loader = DataLoader(dataset_name=self.dataset_name,
                                      img_res=(self.img_rows, self.img_cols))

        # Output shape of D (PatchGAN): the discriminator has four stride-2
        # convolutions, so the spatial size shrinks by 2**4 (128 -> 8).
        patch = self.img_rows // 2**4
        self.disc_patch = (patch, patch, 1)

        # Number of filters in the first layer of G and D
        self.gf = 32
        self.df = 64

        # Loss weights
        self.lambda_cycle = 10.0                    # Cycle-consistency loss
        self.lambda_id = 0.1 * self.lambda_cycle    # Identity loss

        # Build and compile the discriminators.
        # Every compiled model gets its own Adam instance: a single shared
        # optimizer would also share its iteration counter / moment state
        # across the three models, and newer Keras optimizers can error when
        # reused for a different model's variables.
        self.d_A = self.build_discriminator()
        self.d_B = self.build_discriminator()
        self.d_A.compile(loss='mse',
                         optimizer=self._make_optimizer(),
                         metrics=['accuracy'])
        self.d_B.compile(loss='mse',
                         optimizer=self._make_optimizer(),
                         metrics=['accuracy'])

        #-------------------------
        # Construct Computational
        #   Graph of Generators
        #-------------------------

        # Build the generators
        self.g_AB = self.build_generator()
        self.g_BA = self.build_generator()

        # Input images from both domains
        img_A = Input(shape=self.img_shape)
        img_B = Input(shape=self.img_shape)

        # Translate images to the other domain
        fake_B = self.g_AB(img_A)
        fake_A = self.g_BA(img_B)
        # Translate images back to original domain
        reconstr_A = self.g_BA(fake_B)
        reconstr_B = self.g_AB(fake_A)
        # Identity mapping of images
        img_A_id = self.g_BA(img_A)
        img_B_id = self.g_AB(img_B)

        # For the combined model we will only train the generators
        # (set after compiling d_A/d_B so their own train_on_batch still
        # updates their weights).
        self.d_A.trainable = False
        self.d_B.trainable = False

        # Discriminators determine validity of translated images
        valid_A = self.d_A(fake_A)
        valid_B = self.d_B(fake_B)

        # Combined model trains generators to fool discriminators.
        # Output order matters: _summarize_g_loss slices the loss vector
        # returned by train_on_batch according to this order.
        self.combined = Model(inputs=[img_A, img_B],
                              outputs=[ valid_A, valid_B,
                                        reconstr_A, reconstr_B,
                                        img_A_id, img_B_id ])
        self.combined.compile(loss=['mse', 'mse',
                                    'mae', 'mae',
                                    'mae', 'mae'],
                              loss_weights=[  1, 1,
                                              self.lambda_cycle, self.lambda_cycle,
                                              self.lambda_id, self.lambda_id ],
                              optimizer=self._make_optimizer())

    @staticmethod
    def _make_optimizer():
        """Return a fresh Adam optimizer (lr=0.0002, beta_1=0.5) for one model."""
        return Adam(0.0002, 0.5)

    @staticmethod
    def _summarize_g_loss(g_loss):
        """Collapse the combined model's loss vector into (total, adv, recon, id).

        Args:
            g_loss: sequence returned by ``combined.train_on_batch``: the total
                weighted loss followed by one loss per output, in the order
                [valid_A, valid_B, reconstr_A, reconstr_B, img_A_id, img_B_id].

        Returns:
            Tuple of (total loss, mean adversarial loss, mean cycle
            reconstruction loss, mean identity loss).
        """
        return (g_loss[0],
                np.mean(g_loss[1:3]),   # adversarial: valid_A, valid_B
                np.mean(g_loss[3:5]),   # cycle reconstruction: A, B
                np.mean(g_loss[5:7]))   # identity: A, B (both outputs)

    def build_generator(self):
        """U-Net Generator.

        Four stride-2 downsampling convolutions, three upsampling stages with
        skip connections to the matching encoder layers, then a final
        upsample + tanh convolution back to ``self.img_shape``.
        """

        def conv2d(layer_input, filters, f_size=4):
            """Layers used during downsampling"""
            d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            d = InstanceNormalization()(d)
            return d

        def deconv2d(layer_input, skip_input, filters, f_size=4, dropout_rate=0):
            """Layers used during upsampling"""
            u = UpSampling2D(size=2)(layer_input)
            u = Conv2D(filters, kernel_size=f_size, strides=1, padding='same', activation='relu')(u)
            if dropout_rate:
                u = Dropout(dropout_rate)(u)
            u = InstanceNormalization()(u)
            u = Concatenate()([u, skip_input])
            return u

        # Image input
        d0 = Input(shape=self.img_shape)

        # Downsampling
        d1 = conv2d(d0, self.gf)
        d2 = conv2d(d1, self.gf*2)
        d3 = conv2d(d2, self.gf*4)
        d4 = conv2d(d3, self.gf*8)

        # Upsampling
        u1 = deconv2d(d4, d3, self.gf*4)
        u2 = deconv2d(u1, d2, self.gf*2)
        u3 = deconv2d(u2, d1, self.gf)

        u4 = UpSampling2D(size=2)(u3)
        # tanh keeps generated pixels in [-1, 1], the same range as the inputs
        output_img = Conv2D(self.channels, kernel_size=4, strides=1, padding='same', activation='tanh')(u4)

        return Model(d0, output_img)

    def build_discriminator(self):
        """PatchGAN discriminator.

        Maps an image of ``self.img_shape`` to a one-channel grid of shape
        ``self.disc_patch`` (four stride-2 convolutions, then a stride-1
        convolution), one real/fake score per patch.
        """

        def d_layer(layer_input, filters, f_size=4, normalization=True):
            """Discriminator layer"""
            d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            if normalization:
                d = InstanceNormalization()(d)
            return d

        img = Input(shape=self.img_shape)

        d1 = d_layer(img, self.df, normalization=False)
        d2 = d_layer(d1, self.df*2)
        d3 = d_layer(d2, self.df*4)
        d4 = d_layer(d3, self.df*8)

        validity = Conv2D(1, kernel_size=4, strides=1, padding='same')(d4)

        return Model(img, validity)

    def train(self, epochs, batch_size=1, sample_interval=50):
        """Alternate discriminator and generator updates over the dataset.

        Args:
            epochs: number of passes over the data loader's batches.
            batch_size: images per domain per training step.
            sample_interval: save a sample grid whenever the batch index within
                the current epoch is a multiple of this value.
        """

        start_time = datetime.datetime.now()

        # Adversarial loss ground truths (one target per PatchGAN patch)
        valid = np.ones((batch_size,) + self.disc_patch)
        fake = np.zeros((batch_size,) + self.disc_patch)

        for epoch in range(epochs):
            for batch_i, (imgs_A, imgs_B) in enumerate(self.data_loader.load_batch(batch_size)):

                # ----------------------
                #  Train Discriminators
                # ----------------------

                # Translate images to opposite domain
                fake_B = self.g_AB.predict(imgs_A)
                fake_A = self.g_BA.predict(imgs_B)

                # Train the discriminators (original images = real / translated = Fake)
                dA_loss_real = self.d_A.train_on_batch(imgs_A, valid)
                dA_loss_fake = self.d_A.train_on_batch(fake_A, fake)
                dA_loss = 0.5 * np.add(dA_loss_real, dA_loss_fake)

                dB_loss_real = self.d_B.train_on_batch(imgs_B, valid)
                dB_loss_fake = self.d_B.train_on_batch(fake_B, fake)
                dB_loss = 0.5 * np.add(dB_loss_real, dB_loss_fake)

                # Total discriminator loss ([loss, accuracy])
                d_loss = 0.5 * np.add(dA_loss, dB_loss)


                # ------------------
                #  Train Generators
                # ------------------

                # Targets follow the combined model's output order:
                # fool both D's, reconstruct the inputs, preserve identity.
                g_loss = self.combined.train_on_batch([imgs_A, imgs_B],
                                                      [valid, valid,
                                                       imgs_A, imgs_B,
                                                       imgs_A, imgs_B])

                elapsed_time = datetime.datetime.now() - start_time
                g_total, g_adv, g_recon, g_id = self._summarize_g_loss(g_loss)

                # Plot the progress
                print ("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %3d%%] [G loss: %05f, adv: %05f, recon: %05f, id: %05f] time: %s " \
                                                                        % ( epoch, epochs,
                                                                            batch_i, self.data_loader.n_batches,
                                                                            d_loss[0], 100*d_loss[1],
                                                                            g_total,
                                                                            g_adv,
                                                                            g_recon,
                                                                            g_id,
                                                                            elapsed_time))

                # If at save interval => save generated image samples
                if batch_i % sample_interval == 0:
                    self.sample_images(epoch, batch_i)

    def sample_images(self, epoch, batch_i):
        """Save a 2x3 grid (original / translated / reconstructed, per domain).

        The figure is written to ``images/<dataset_name>/<epoch>_<batch_i>.png``
        using one test image from each domain.
        """
        os.makedirs('images/%s' % self.dataset_name, exist_ok=True)
        r, c = 2, 3

        imgs_A = self.data_loader.load_data(domain="A", batch_size=1, is_testing=True)
        imgs_B = self.data_loader.load_data(domain="B", batch_size=1, is_testing=True)

        # Demo (for GIF)
        #imgs_A = self.data_loader.load_img('datasets/apple2orange/testA/n07740461_1541.jpg')
        #imgs_B = self.data_loader.load_img('datasets/apple2orange/testB/n07749192_4241.jpg')

        # Translate images to the other domain
        fake_B = self.g_AB.predict(imgs_A)
        fake_A = self.g_BA.predict(imgs_B)
        # Translate back to original domain
        reconstr_A = self.g_BA.predict(fake_B)
        reconstr_B = self.g_AB.predict(fake_A)

        gen_imgs = np.concatenate([imgs_A, fake_B, reconstr_A, imgs_B, fake_A, reconstr_B])

        # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5

        titles = ['Original', 'Translated', 'Reconstructed']
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(gen_imgs[cnt])
                axs[i, j].set_title(titles[j])
                axs[i, j].axis('off')
                cnt += 1
        fig.savefig("images/%s/%d_%d.png" % (self.dataset_name, epoch, batch_i))
        # Close this specific figure (not just the "current" one) to avoid
        # accumulating open figures during long training runs.
        plt.close(fig)

In [None]:
# Make the downloaded dataset script executable.
!chmod +x download_dataset.sh

In [None]:
# List the working directory to confirm the helper files were downloaded.
!ls -la

total 28
drwxr-xr-x 1 root root 4096 Jun 25 08:38 .
drwxr-xr-x 1 root root 4096 Jun 25 08:38 ..
drwxr-xr-x 4 root root 4096 Jun 15 13:37 .config
-rw-r--r-- 1 root root 2515 Jun 25 08:38 data_loader.py
-rwxr-xr-x 1 root root  824 Jun 25 08:38 download_dataset.sh
drwxr-xr-x 2 root root 4096 Jun 25 08:38 __pycache__
drwxr-xr-x 1 root root 4096 Jun 15 13:37 sample_data


In [None]:
# Download and unzip the apple2orange dataset (extracted under ./datasets/apple2orange/).
!./download_dataset.sh apple2orange

for details.

--2021-06-25 08:38:53--  https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/apple2orange.zip
Resolving people.eecs.berkeley.edu (people.eecs.berkeley.edu)... 128.32.244.190
Connecting to people.eecs.berkeley.edu (people.eecs.berkeley.edu)|128.32.244.190|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 78456409 (75M) [application/zip]
Saving to: ‘./datasets/apple2orange.zip’


2021-06-25 08:39:05 (6.43 MB/s) - ‘./datasets/apple2orange.zip’ saved [78456409/78456409]

Archive:  ./datasets/apple2orange.zip
   creating: ./datasets/apple2orange/trainA/
  inflating: ./datasets/apple2orange/trainA/n07740461_6908.jpg  
  inflating: ./datasets/apple2orange/trainA/n07740461_7635.jpg  
  inflating: ./datasets/apple2orange/trainA/n07740461_586.jpg  
  inflating: ./datasets/apple2orange/trainA/n07740461_9813.jpg  
  inflating: ./datasets/apple2orange/trainA/n07740461_6835.jpg  
  inflating: ./datasets/apple2orange/trainA/n07740461_2818.jpg  
  inf

In [None]:
# Build and compile all models; the constructor also creates the DataLoader for 'apple2orange'.
gan = CycleGAN()

In [None]:
# Training schedule. This is long-running; sample grids are written to
# images/<dataset_name>/ every SAMPLE_INTERVAL batches.
EPOCHS = 200
BATCH_SIZE = 1
SAMPLE_INTERVAL = 200
gan.train(epochs=EPOCHS, batch_size=BATCH_SIZE, sample_interval=SAMPLE_INTERVAL)

[1;30;43mSe han truncado las últimas 5000 líneas del flujo de salida.[0m
[Epoch 0/200] [Batch 16/995] [D loss: 0.238765, acc:  80%] [G loss: 11.673101, adv: 0.991573, recon: 0.383772, id: 1.130964] time: 0:01:05.708397 
[Epoch 0/200] [Batch 17/995] [D loss: 0.330386, acc:  66%] [G loss: 14.287024, adv: 0.777162, recon: 0.549513, id: 0.936193] time: 0:01:05.889205 
[Epoch 0/200] [Batch 18/995] [D loss: 0.366270, acc:  52%] [G loss: 13.420904, adv: 0.757957, recon: 0.501769, id: 1.098626] time: 0:01:06.073768 
[Epoch 0/200] [Batch 19/995] [D loss: 0.962799, acc:  66%] [G loss: 13.609671, adv: 1.028240, recon: 0.487753, id: 1.014808] time: 0:01:06.259421 
[Epoch 0/200] [Batch 20/995] [D loss: 0.665527, acc:  60%] [G loss: 13.186418, adv: 0.715419, recon: 0.504655, id: 0.835880] time: 0:01:06.436585 
[Epoch 0/200] [Batch 21/995] [D loss: 1.068607, acc:  53%] [G loss: 10.575246, adv: 1.263398, recon: 0.303460, id: 1.072178] time: 0:01:06.638793 
[Epoch 0/200] [Batch 22/995] [D loss: 0.470

KeyboardInterrupt: ignored

In [None]:
# Input shape expected by the generators and discriminators (rows, cols, channels).
gan.img_shape

(128, 128, 3)

In [None]:
# Persist the weights of every sub-model so they can be reloaded later
# without retraining.
for model, weights_path in [(gan.combined, "./cycleGAN_combined.h5"),
                            (gan.d_A, "./cycleGAN_d_A.h5"),
                            (gan.d_B, "./cycleGAN_d_B.h5"),
                            (gan.g_AB, "./cycleGAN_g_AB.h5"),
                            (gan.g_BA, "./cycleGAN_g_BA.h5")]:
    model.save_weights(weights_path)

In [None]:
# Take a single preview batch per domain and translate it. Only one batch is
# needed for a visual check, so don't iterate (and predict on) the whole dataset.
PREVIEW_BATCH_SIZE = 16
imgs_A, imgs_B = next(iter(gan.data_loader.load_batch(PREVIEW_BATCH_SIZE)))

# Translate images to opposite domain
fake_B = gan.g_AB.predict(imgs_A)
fake_A = gan.g_BA.predict(imgs_B)

In [None]:
# Sanity check: shapes match, and both inputs and tanh generator outputs lie in [-1, 1].
imgs_A.shape, fake_A.shape, imgs_A.min(), imgs_A.max(), fake_A.min(), fake_A.max()

((16, 128, 128, 3), (16, 128, 128, 3), -1.0, 1.0, -0.99434435, 1.0)

In [None]:
# Dataset the models were trained on.
gan.dataset_name

'apple2orange'

In [None]:
# Side-by-side comparison: domain-A image vs. its translation to domain B.
# Images are in [-1, 1], so rescale to [0, 1] for imshow.
for real_img, translated_img in zip(imgs_A, fake_B):
    fig, (ax_real, ax_fake) = plt.subplots(1, 2, figsize=(10, 5))
    ax_real.imshow((real_img + 1) / 2)
    ax_real.set_title("Original (A)")
    ax_real.axis("off")
    ax_fake.imshow((translated_img + 1) / 2)
    ax_fake.set_title("Translated A -> B")
    ax_fake.axis("off")
    plt.show()
    plt.close(fig)

In [None]:
# Side-by-side comparison: domain-B image vs. its translation to domain A.
# Images are in [-1, 1], so rescale to [0, 1] for imshow.
for real_img, translated_img in zip(imgs_B, fake_A):
    fig, (ax_real, ax_fake) = plt.subplots(1, 2, figsize=(10, 5))
    ax_real.imshow((real_img + 1) / 2)
    ax_real.set_title("Original (B)")
    ax_real.axis("off")
    ax_fake.imshow((translated_img + 1) / 2)
    ax_fake.set_title("Translated B -> A")
    ax_fake.axis("off")
    plt.show()
    plt.close(fig)