In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from keras.layers import Dense, Conv2D, Flatten, Conv2DTranspose, MaxPool2D

import tensorflow_addons as tfa

import PIL
import PIL.Image

# Sentinel passed as num_parallel_calls to Dataset.map below, letting tf.data choose the parallelism at runtime.
autotune = tf.data.AUTOTUNE
# Input data files are available in the read-only "../input/" directory.
# Set LIST_INPUT_FILES = True below to print the path of every file under it.

import os

# Off by default: walking and printing the whole input tree can produce a very long listing.
LIST_INPUT_FILES = False
if LIST_INPUT_FILES:
    for dirname, _, filenames in os.walk('/kaggle/input'):
        for filename in filenames:
            print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
# (height, width, channels) of the training images.
# NOTE(review): not referenced by the model builders, which spell out 256x256x3 themselves.
input_image_size = (256, 256, 3)

def preprocess_image(img):
    """Rescale an image tensor to float32 values in [-1, 1].

    Assumes ``img`` holds pixel values in [0, 255]; the plotting cells undo this
    mapping with ``x * 127.5 + 127.5``.
    """
    scaled = tf.cast(img, dtype=tf.float32) / 127.5
    return scaled - 1
    
monet_filepath = '../input/gan-getting-started/monet_jpg'
photo_filepath = '../input/gan-getting-started/photo_jpg'

# Both folders are read identically: unlabeled images resized to 256x256,
# one image per batch.
_loader_options = {'labels': None, 'image_size': (256, 256), 'batch_size': 1}
monet_data = tf.keras.utils.image_dataset_from_directory(monet_filepath, **_loader_options)
photo_data = tf.keras.utils.image_dataset_from_directory(photo_filepath, **_loader_options)

# Rescale every image with preprocess_image, letting tf.data pick the parallelism.
monet_ds, photo_ds = [
    raw.map(preprocess_image, num_parallel_calls=autotune)
    for raw in (monet_data, photo_data)
]

monet_ds

In [None]:
# Show four photo / Monet pairs side by side, undoing the [-1, 1] scaling for display.
_, ax = plt.subplots(4, 2, figsize=(10, 15))
pairs = zip(photo_ds.take(4), monet_ds.take(4))
for row, (photo_batch, monet_batch) in enumerate(pairs):
    for col, batch in enumerate((photo_batch, monet_batch)):
        pixels = ((batch[0] * 127.5) + 127.5).numpy().astype(np.uint8)
        ax[row, col].imshow(pixels)
plt.show()

# Model Building Blocks

In [None]:
# Both initializers draw from a normal distribution with mean 0 and stddev 0.02:
# kernel_init for convolution kernels, gamma_init for InstanceNormalization's gamma.
kernel_init = tf.random_normal_initializer(mean=0.0, stddev=0.02)
gamma_init = keras.initializers.RandomNormal(stddev=0.02, mean=0.0)

def downsample(
    filters, 
    kernel_size,
    activation, 
    strides = 2,
    kernel_initializer = kernel_init, 
    gamma_initializer = gamma_init,
    apply_instancenorm=True
):
    """Build a Conv2D -> [InstanceNormalization] -> activation block as a Sequential.

    Args:
        filters: number of convolution filters.
        kernel_size: convolution kernel size.
        activation: Keras layer appended as the last step.
        strides: convolution stride (2 by default, shrinking height/width).
        kernel_initializer: initializer for the (bias-free) convolution kernel.
        gamma_initializer: initializer for the InstanceNormalization gamma.
        apply_instancenorm: whether to insert InstanceNormalization after the conv.

    Returns:
        A ``tf.keras.models.Sequential`` containing the layers above.
    """
    stack = [
        Conv2D(
            filters,
            kernel_size,
            strides=strides,
            padding='same',
            kernel_initializer=kernel_initializer,
            use_bias=False,
        )
    ]
    if apply_instancenorm:
        stack.append(tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer))
    stack.append(activation)
    return tf.keras.models.Sequential(stack)

def upsample(
    filters, 
    kernel_size, 
    activation,
    strides = 2,
    kernel_initializer = kernel_init, 
    gamma_initializer = gamma_init,
    apply_instancenorm = True, 
    apply_dropout=False
):
    """Build a Conv2DTranspose -> [InstanceNorm] -> [Dropout(0.5)] -> activation Sequential.

    Args:
        filters: number of transposed-convolution filters.
        kernel_size: transposed-convolution kernel size.
        activation: Keras layer appended as the last step.
        strides: transposed-convolution stride (2 by default, growing height/width).
        kernel_initializer: initializer for the (bias-free) kernel.
        gamma_initializer: initializer for the InstanceNormalization gamma.
        apply_instancenorm: whether to insert InstanceNormalization.
        apply_dropout: whether to insert a Dropout layer with rate 0.5.

    Returns:
        A ``tf.keras.models.Sequential`` containing the layers above.
    """
    stack = [
        Conv2DTranspose(
            filters,
            kernel_size,
            strides=strides,
            padding='same',
            kernel_initializer=kernel_initializer,
            use_bias=False,
        )
    ]
    if apply_instancenorm:
        stack.append(tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer))
    if apply_dropout:
        stack.append(layers.Dropout(0.5))
    stack.append(activation)
    return tf.keras.models.Sequential(stack)

def residual_block(
    dim, 
    activation,
    kernel_size = (3,3),
    strides = 1,
    kernel_initializer = kernel_init, 
    gamma_initializer = gamma_init, 
    apply_instancenorm = True
):
    """Build a residual block that computes ``x + F(x)``.

    ``F`` is conv -> [InstanceNorm] -> activation -> conv -> [InstanceNorm], with
    both convolutions producing ``dim`` channels and 'same' padding. The identity
    shortcut is what makes the block residual; without it the generator's
    residual section degenerates into a plain stack of convolutions.

    Args:
        dim: number of filters. Must equal the channel count of the input tensor
            so the shortcut can be added.
        activation: Keras layer applied between the two convolutions.
        kernel_size: convolution kernel size.
        strides: stride of the first convolution. Must stay 1 (the default) so
            ``F(x)`` keeps the input's spatial size, which the addition requires.
        kernel_initializer: initializer for both (bias-free) convolution kernels.
        gamma_initializer: initializer for the InstanceNormalization gamma.
        apply_instancenorm: whether to insert InstanceNormalization after each conv.

    Returns:
        A callable mapping a tensor ``x`` to ``x + F(x)``; use it as
        ``x = residual_block(dim, activation)(x)``.
    """
    body = tf.keras.models.Sequential()

    body.add(Conv2D(
        dim,
        kernel_size = kernel_size,
        strides = strides,
        padding = 'same',
        kernel_initializer = kernel_initializer,
        use_bias = False
    ))
    if apply_instancenorm:
        body.add(tfa.layers.InstanceNormalization(gamma_initializer = gamma_initializer))
    body.add(activation)

    # Second conv has no activation afterwards: the sum is returned as-is.
    body.add(Conv2D(
        dim,
        kernel_size = kernel_size,
        strides = 1,
        padding = 'same',
        kernel_initializer = kernel_initializer,
        use_bias = False
    ))
    if apply_instancenorm:
        body.add(tfa.layers.InstanceNormalization(gamma_initializer = gamma_initializer))

    def apply(x):
        # Identity shortcut: requires body(x) to have exactly x's shape.
        return layers.add([x, body(x)])

    return apply

# GENERATOR

In [None]:
def get_gen_model(
    filters = 64, 
    num_downsample = 2, 
    num_residual = 9, 
    num_upsample = 2, 
    gamma_initializer = gamma_init, 
    activation = None,
    input_shape = (256, 256, 3),
):
    """Build the CycleGAN generator: image in, 3-channel tanh image out.

    Layout: 7x7 conv stem -> ``num_downsample`` stride-2 downsample blocks
    (doubling filters each time) -> ``num_residual`` residual blocks ->
    ``num_upsample`` stride-2 upsample blocks (halving filters each time) ->
    7x7 conv to 3 channels with tanh, so outputs lie in [-1, 1].

    Args:
        filters: filter count of the stem convolution.
        num_downsample: number of downsampling blocks.
        num_residual: number of residual blocks.
        num_upsample: number of upsampling blocks.
        gamma_initializer: InstanceNormalization gamma initializer, used by the
            stem and by every down/residual/up block.
        activation: Keras layer applied after each normalized conv. Defaults to
            a new ``layers.Activation('relu')`` created per call.
        input_shape: shape of one input image (height, width, channels).

    Returns:
        A ``keras.Model`` mapping batches of ``input_shape`` images to images.
    """
    # Created per call: a layer object as the default argument would be built
    # once and shared by every generator made with the default.
    if activation is None:
        activation = layers.Activation('relu')

    img_input = layers.Input(shape=input_shape)

    # Stem: 7x7 conv at full resolution.
    x = layers.Conv2D(filters, (7,7), kernel_initializer = kernel_init, padding = 'same')(img_input)
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
    x = activation(x)

    # Downsampling: each block uses stride 2 and doubles the filter count.
    for _ in range(num_downsample):
        filters *= 2
        x = downsample(
            filters = filters,
            kernel_size = (3,3),
            activation = activation,
            gamma_initializer = gamma_initializer,
        )(x)

    # Residual blocks at the bottleneck resolution.
    for _ in range(num_residual):
        x = residual_block(filters, activation = activation, gamma_initializer = gamma_initializer)(x)

    # Upsampling: each block uses a stride-2 transposed conv and halves the filter count.
    for _ in range(num_upsample):
        filters //= 2
        x = upsample(
            filters,
            kernel_size = (3,3),
            activation = activation,
            gamma_initializer = gamma_initializer,
        )(x)

    # Output: back to 3 channels; tanh keeps values in [-1, 1].
    x = layers.Conv2D(3, (7,7), padding='same')(x)
    x = layers.Activation('tanh')(x)

    return keras.Model(img_input, x)

# DISCRIMINATOR

In [None]:
def get_disc_model(
    filters = 64,
    num_downsample = 4,
    gamma_initializer = gamma_init,
    activation = None,
    input_shape = (256, 256, 3),
):
    """Build the discriminator: image in, spatial map of sigmoid scores out.

    Layout: ``num_downsample`` stride-2 downsample blocks with 4x4 kernels (the
    first keeps ``filters``, each later one doubles it) -> 4x4 stride-1 conv to a
    single channel with sigmoid, giving one score per output location.

    Args:
        filters: filter count of the first downsample block.
        num_downsample: total number of downsample blocks (at least one is built).
        gamma_initializer: InstanceNormalization gamma initializer for every block.
        activation: Keras layer used inside each block. Defaults to a new
            ``layers.Activation('relu')`` created per call.
        input_shape: shape of one input image (height, width, channels).

    Returns:
        A ``keras.Model`` mapping batches of ``input_shape`` images to score maps.
    """
    # Created per call so separately built discriminators don't share a layer object.
    if activation is None:
        activation = layers.Activation('relu')

    img_input = layers.Input(shape = input_shape)

    # First downsample block keeps the base filter count.
    x = downsample(filters, kernel_size = 4, activation = activation,
                   gamma_initializer = gamma_initializer)(img_input)

    # Remaining downsample blocks double the filter count each time.
    for _ in range(num_downsample - 1):
        filters *= 2
        x = downsample(filters, kernel_size = 4, activation = activation,
                       gamma_initializer = gamma_initializer)(x)

    # Single-channel sigmoid map: one real/fake score per spatial location.
    x = layers.Conv2D(1, kernel_size = 4, strides = 1, padding = 'same', activation='sigmoid')(x)

    return keras.Model(img_input, x)

# Build the CycleGAN Model

In [None]:
class CycleGan(keras.Model):
    """CycleGAN wrapper that trains two generators and two discriminators jointly.

    ``monet_gen`` maps photos to Monet-style images and ``photo_gen`` maps Monet
    images to photos; ``monet_disc`` and ``photo_disc`` score images of the Monet
    and photo domains respectively. ``fit`` must be fed ``(real_monet, real_photo)``
    batch pairs.

    Args:
        monet_gen: photo -> Monet generator.
        photo_gen: Monet -> photo generator.
        monet_disc: discriminator for Monet-domain images.
        photo_disc: discriminator for photo-domain images.
        lambda_cycle: weight applied to each cycle-consistency loss.
        lambda_identity: weight applied to each identity loss.
    """

    def __init__(
        self,
        monet_gen,
        photo_gen,
        monet_disc,
        photo_disc,
        lambda_cycle = 10,
        lambda_identity = 0.5,
    ):
        super().__init__()
        self.monet_gen = monet_gen
        self.photo_gen = photo_gen
        self.monet_disc = monet_disc
        self.photo_disc = photo_disc
        self.lambda_cycle = lambda_cycle
        self.lambda_identity = lambda_identity

    def compile(
        self,
        monet_gen_optimizer,
        photo_gen_optimizer,
        monet_disc_optimizer,
        photo_disc_optimizer,
        gen_loss_fn,
        disc_loss_fn,
        cycle_loss_fn,
        identity_loss_fn,
    ):
        """Store one optimizer per sub-model plus the loss callables used by train_step.

        ``super().compile()`` is called without losses or metrics; all losses
        are computed manually in ``train_step``.
        """
        super().compile()
        self.monet_gen_optimizer = monet_gen_optimizer
        self.photo_gen_optimizer = photo_gen_optimizer
        self.monet_disc_optimizer = monet_disc_optimizer
        self.photo_disc_optimizer = photo_disc_optimizer
        self.gen_loss_fn = gen_loss_fn
        self.disc_loss_fn = disc_loss_fn
        self.cycle_loss_fn = cycle_loss_fn
        self.identity_loss_fn = identity_loss_fn

    def train_step(self, batch_data):
        """Run one update of all four sub-models on a ``(real_monet, real_photo)`` batch.

        Returns:
            Dict of the total generator losses and the discriminator losses,
            reported by Keras as training metrics.
        """
        real_monet, real_photo = batch_data

        # persistent=True: four separate gradients are taken from this one tape.
        with tf.GradientTape(persistent = True) as tape:
            # Generate fake images from real images.
            fake_monet = self.monet_gen(real_photo, training = True)
            fake_photo = self.photo_gen(real_monet, training = True)

            # Cycle consistency: map the fakes back to their source domain.
            cycled_photo = self.photo_gen(fake_monet, training = True)
            cycled_monet = self.monet_gen(fake_photo, training = True)

            # Identity mapping: a generator fed an image already in its target
            # domain should leave it unchanged.
            identity_monet = self.monet_gen(real_monet, training = True)
            identity_photo = self.photo_gen(real_photo, training = True)

            # Discriminators receiving real and fake images.
            disc_fake_monet = self.monet_disc(fake_monet, training = True)
            disc_fake_photo = self.photo_disc(fake_photo, training = True)

            disc_real_monet = self.monet_disc(real_monet, training = True)
            disc_real_photo = self.photo_disc(real_photo, training = True)

            # Generator adversarial loss.
            monet_adv_loss = self.gen_loss_fn(disc_fake_monet)
            photo_adv_loss = self.gen_loss_fn(disc_fake_photo)

            # Generator cycle loss.
            monet_cycle_loss = self.cycle_loss_fn(real_monet, cycled_monet) * self.lambda_cycle
            photo_cycle_loss = self.cycle_loss_fn(real_photo, cycled_photo) * self.lambda_cycle

            # Generator identity loss.
            # NOTE(review): the CycleGAN paper and the Keras CycleGAN example scale this
            # by lambda_cycle * lambda_identity; here it is lambda_identity alone — confirm intent.
            monet_identity_loss = self.identity_loss_fn(real_monet, identity_monet) * self.lambda_identity
            photo_identity_loss = self.identity_loss_fn(real_photo, identity_photo) * self.lambda_identity

            # Total generator loss.
            total_monet_gen_loss = monet_adv_loss + monet_cycle_loss + monet_identity_loss
            total_photo_gen_loss = photo_adv_loss + photo_cycle_loss + photo_identity_loss

            # Discriminator loss.
            monet_disc_loss = self.disc_loss_fn(disc_real_monet, disc_fake_monet)
            photo_disc_loss = self.disc_loss_fn(disc_real_photo, disc_fake_photo)

        # Gradients and updates run outside the tape context so these ops are not
        # themselves recorded on the persistent tape.
        monet_gen_grads = tape.gradient(total_monet_gen_loss, self.monet_gen.trainable_variables)
        photo_gen_grads = tape.gradient(total_photo_gen_loss, self.photo_gen.trainable_variables)

        monet_disc_grads = tape.gradient(monet_disc_loss, self.monet_disc.trainable_variables)
        photo_disc_grads = tape.gradient(photo_disc_loss, self.photo_disc.trainable_variables)

        # A persistent tape holds its resources until released.
        del tape

        self.monet_gen_optimizer.apply_gradients(zip(monet_gen_grads, self.monet_gen.trainable_variables))
        self.photo_gen_optimizer.apply_gradients(zip(photo_gen_grads, self.photo_gen.trainable_variables))

        self.monet_disc_optimizer.apply_gradients(zip(monet_disc_grads, self.monet_disc.trainable_variables))
        self.photo_disc_optimizer.apply_gradients(zip(photo_disc_grads, self.photo_disc.trainable_variables))

        return {
            'monet_gen_loss': total_monet_gen_loss,
            'photo_gen_loss': total_photo_gen_loss,
            'monet_disc_loss': monet_disc_loss,
            'photo_disc_loss': photo_disc_loss,
        }
            

# Loss Functions

In [None]:
l1_loss_fn = tf.keras.losses.MeanAbsoluteError()
l2_loss_fn = tf.keras.losses.MeanSquaredError()

def gen_loss_fn(fake):
    """Least-squares generator loss: squared distance of fake scores from 1."""
    return l2_loss_fn(tf.ones_like(fake), fake)

def disc_loss_fn(real, fake):
    """Least-squares discriminator loss: mean of (real scores vs 1) and (fake scores vs 0)."""
    target_real = tf.ones_like(real)
    target_fake = tf.zeros_like(fake)
    return 0.5 * (l2_loss_fn(target_real, real) + l2_loss_fn(target_fake, fake))

def cycle_loss_fn(real_image, cycled_image):
    """Mean absolute error between an image and its round-trip reconstruction."""
    return l1_loss_fn(real_image, cycled_image)

def identity_loss_fn(real, identity):
    """Mean absolute error between an image and the generator's output for it."""
    return l1_loss_fn(real, identity)

In [None]:
class GANMonitor(keras.callbacks.Callback):
    """A callback that plots photo -> Monet translations at the end of each epoch.

    Each row shows one input photo next to ``self.model.monet_gen``'s translation.
    Figures are displayed, not saved.

    Args:
        num_img: number of photos (rows) to plot; nothing is plotted if <= 0.
        dataset: batched images scaled to [-1, 1]; only the first image of each
            batch is shown. When omitted, the notebook-level ``photo_ds`` is used.
    """

    def __init__(self, num_img=4, dataset=None):
        # Let keras.callbacks.Callback initialize its base attributes.
        super().__init__()
        self.num_img = num_img
        self.dataset = dataset

    def on_epoch_end(self, epoch, logs=None):
        if self.num_img <= 0:
            return
        source = photo_ds if self.dataset is None else self.dataset

        # One row per image; squeeze=False keeps ``ax`` 2-D even for a single row.
        fig, ax = plt.subplots(self.num_img, 2, figsize=(12, 3 * self.num_img), squeeze=False)
        for i, img in enumerate(source.take(self.num_img)):
            prediction = self.model.monet_gen(img, training=False)[0].numpy()
            prediction = (prediction * 127.5 + 127.5).astype(np.uint8)
            img = (img[0] * 127.5 + 127.5).numpy().astype(np.uint8)

            ax[i, 0].imshow(img)
            ax[i, 1].imshow(prediction)
            ax[i, 0].set_title("Input image")
            ax[i, 1].set_title("Translated image")
            ax[i, 0].axis("off")
            ax[i, 1].axis("off")

        plt.show()
        plt.close(fig)

In [None]:
LEARNING_RATE = 2e-4
BETA_1 = 0.5
EPOCHS = 50

# Get generator and discriminator models
monet_gen = get_gen_model(num_residual = 3)
photo_gen = get_gen_model(num_residual = 3)
monet_disc = get_disc_model()
photo_disc = get_disc_model()

# Create CycleGan object
cycle_gan = CycleGan(monet_gen, photo_gen, monet_disc, photo_disc)

def make_optimizer():
    """Return a new Adam optimizer configured with LEARNING_RATE and BETA_1."""
    return tf.keras.optimizers.Adam(learning_rate = LEARNING_RATE, beta_1 = BETA_1)

# One optimizer per model. Adam keeps per-variable moment estimates and a step
# counter; a single shared instance would advance that counter four times per
# training step, and newer Keras optimizers refuse to update variables outside
# the set they were first built with.
cycle_gan.compile(
    monet_gen_optimizer = make_optimizer(),
    photo_gen_optimizer = make_optimizer(),
    monet_disc_optimizer = make_optimizer(),
    photo_disc_optimizer = make_optimizer(),
    gen_loss_fn = gen_loss_fn,
    disc_loss_fn = disc_loss_fn,
    cycle_loss_fn = cycle_loss_fn,
    identity_loss_fn = identity_loss_fn,
)

# Plot sample translations after every epoch.
plotter = GANMonitor()

# Dataset.zip stops at the shorter of the two datasets, which bounds the steps per epoch.
cycle_gan.fit(tf.data.Dataset.zip((monet_ds, photo_ds)), epochs=EPOCHS, callbacks =[plotter])
    

In [None]:

# Translate four photos with the trained Monet generator and show input / output pairs.
_, ax = plt.subplots(4, 2, figsize=(10, 15))
for i, img in enumerate(photo_ds.take(4)):
    prediction = monet_gen(img, training=False)[0].numpy()
    prediction = (prediction * 127.5 + 127.5).astype(np.uint8)
    img = (img[0] * 127.5 + 127.5).numpy().astype(np.uint8)

    ax[i, 0].imshow(img)
    ax[i, 1].imshow(prediction)
    ax[i, 0].set_title("Input Photo")
    ax[i, 1].set_title("Monet-esque")
    ax[i, 0].axis("off")
    ax[i, 1].axis("off")
plt.show()

In [None]:
import os
import PIL.Image

# Write one JPEG per photo, translated by the trained Monet generator.
# os.makedirs(exist_ok=True) keeps the cell re-runnable, unlike a shell `mkdir`.
output_dir = '../images'
os.makedirs(output_dir, exist_ok=True)

for i, img in enumerate(photo_ds, start=1):
    prediction = monet_gen(img, training = False)[0].numpy()
    prediction = (prediction * 127.5 + 127.5).astype(np.uint8)
    im = PIL.Image.fromarray(prediction)
    im.save(os.path.join(output_dir, str(i) + '.jpg'))



In [None]:
import shutil

# Bundle the generated images into /kaggle/working/images.zip, a directory
# whose contents are preserved as notebook output.
archive_path = shutil.make_archive(
    base_name='/kaggle/working/images',
    format='zip',
    root_dir='/kaggle/images',
)
archive_path