In [None]:
!pip install tensorflow-addons==0.9.1 

In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

import tensorflow as tf
import tensorflow_addons as tfa
import IPython.display as display

%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.image as mpimg


import random
import math

import time
from IPython.display import clear_output

In [None]:
# Try to attach to a TPU; if any step of that fails, fall back to MirroredStrategy.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Device:', tpu.master())
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
except Exception as tpu_error:
    # `Exception` (not a bare `except:`) lets KeyboardInterrupt/SystemExit propagate,
    # and the reason for the fallback is reported instead of silently discarded.
    print('TPU not available, falling back to MirroredStrategy:', repr(tpu_error))
    strategy = tf.distribute.MirroredStrategy()
print('Number of replicas:', strategy.num_replicas_in_sync)

In [None]:
from kaggle_datasets import KaggleDatasets

# GCS location reported by KaggleDatasets for the dataset attached to this notebook.
GCS_PATH = KaggleDatasets().get_gcs_path()

# TFRecord shards holding the Monet paintings.
MONET_FILENAMES = tf.io.gfile.glob(f'{GCS_PATH}/monet_tfrec/*.tfrec')
print('Monet TFRecord Files:', len(MONET_FILENAMES))

# TFRecord shards holding the photos.
PHOTO_FILENAMES = tf.io.gfile.glob(f'{GCS_PATH}/photo_tfrec/*.tfrec')
print('Photo TFRecord Files:', len(PHOTO_FILENAMES))

In [None]:
# Height and width (in pixels) every decoded image is reshaped to.
IMAGE_SIZE = [256, 256]

def decode_image(image):
    """Decode a JPEG byte string into a float32 image scaled to [-1, 1].

    Args:
        image: scalar string tensor holding JPEG-encoded bytes.

    Returns:
        A float32 tensor reshaped to ``[*IMAGE_SIZE, 3]``.
    """
    pixels = tf.image.decode_jpeg(image, channels=3)
    # 0..255 -> -1..1
    scaled = (tf.cast(pixels, tf.float32) / 127.5) - 1
    return tf.reshape(scaled, [*IMAGE_SIZE, 3])

def random_crop(image):
    """Return a randomly positioned ``[*IMAGE_SIZE, 3]`` crop of ``image``."""
    return tf.image.random_crop(image, size=[*IMAGE_SIZE, 3])

def random_jitter(image):
    """Augment one image: upscale, random crop, random horizontal flip.

    Args:
        image: a single image tensor (as produced by ``decode_image``).

    Returns:
        The augmented image, cropped back to ``[*IMAGE_SIZE, 3]``.
    """
    # Enlarge to 286 x 286 so the crop below has room to move.
    enlarged = tf.image.resize(
        image, [286, 286], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR
    )
    # Cut a random IMAGE_SIZE window out of the enlarged image.
    patch = random_crop(enlarged)
    # Mirror left/right at random.
    return tf.image.random_flip_left_right(patch)

def read_tfrecord(example):
    """Parse one serialized TFRecord example and return its decoded image.

    The record is expected to carry three scalar string features
    (``image_name``, ``image``, ``target``); only ``image`` is used.

    Args:
        example: a serialized example, as yielded by ``tf.data.TFRecordDataset``.

    Returns:
        The image tensor produced by ``decode_image``.
    """
    feature_spec = {
        key: tf.io.FixedLenFeature([], tf.string)
        for key in ("image_name", "image", "target")
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    return decode_image(parsed['image'])

def load_dataset(filenames, labeled=True, ordered=False):
    """Build a dataset of decoded images from TFRecord files.

    Args:
        filenames: TFRecord file paths to read.
        labeled: accepted for call compatibility; not used by this function.
        ordered: accepted for call compatibility; not used by this function.

    Returns:
        A ``tf.data.Dataset`` yielding one decoded image per record.
    """
    # AUTOTUNE is a notebook-level name looked up when this function is called.
    records = tf.data.TFRecordDataset(filenames)
    return records.map(read_tfrecord, num_parallel_calls=AUTOTUNE)

In [None]:
AUTOTUNE = tf.data.experimental.AUTOTUNE

# Un-augmented photos, one image per batch (used for previews and final generation).
submit_ds = load_dataset(PHOTO_FILENAMES, labeled=True).batch(1)


def _augmented_batches(filenames):
    """Load `filenames`, apply random_jitter, shuffle (buffer 300) and batch by 1."""
    images = load_dataset(filenames, labeled=True)
    jittered = images.map(random_jitter, num_parallel_calls=AUTOTUNE)
    return jittered.shuffle(300).batch(1)


monet_ds = _augmented_batches(MONET_FILENAMES)
photo_ds = _augmented_batches(PHOTO_FILENAMES)

# Training input: (monet_batch, photo_batch) pairs.
final_dataset = tf.data.Dataset.zip((monet_ds, photo_ds))

In [None]:
# Take the first batch of the un-augmented photo dataset; `tes` is the fixed
# sample that CustomCallback renders at the end of every epoch.
it = iter(submit_ds)
tes = next(it)

In [None]:
def generate_images(model, test_input):
    """Plot the first input image next to the model's output for it.

    Args:
        model: callable applied to ``test_input`` to obtain the prediction.
        test_input: batch of images; only element 0 is displayed.
    """
    prediction = model(test_input)

    plt.figure(figsize=(12, 12))

    panels = zip(
        ['Input Image', 'Predicted Image'],
        [test_input[0], prediction[0]],
    )
    for position, (caption, picture) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.title(caption)
        # Rescale by 0.5 and shift by 0.5 so [-1, 1] pixel values land in [0, 1].
        plt.imshow(picture * 0.5 + 0.5)
        plt.axis('off')
    plt.show()

In [None]:
class ReflectionPadding2D(tf.keras.layers.Layer):
    """Reflection-pads axes 1 and 2 of a 4-D, channels-last tensor.

    Args:
        padding: pair ``(w_pad, h_pad)``. ``h_pad`` entries are added on each side
            of axis 1 and ``w_pad`` entries on each side of axis 2, which is the
            order ``call`` has always used.
        **kwargs: forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, padding=(1, 1), **kwargs):
        # Initialise the base layer before assigning attributes, as Keras expects;
        # attributes assigned earlier risk being reset by the base initialiser.
        super(ReflectionPadding2D, self).__init__(**kwargs)
        self.padding = tuple(padding)
        self.input_spec = [tf.keras.layers.InputSpec(ndim=4)]

    def compute_output_shape(self, s):
        """Return the padded shape for a channels-last input shape ``s``.

        Uses the same (w_pad, h_pad) interpretation as ``call`` so the reported
        shape matches the tensor actually produced, and keeps unknown (``None``)
        spatial dimensions unknown instead of failing on ``None + int``.
        """
        w_pad, h_pad = self.padding
        height = None if s[1] is None else s[1] + 2 * h_pad
        width = None if s[2] is None else s[2] + 2 * w_pad
        return (s[0], height, width, s[3])

    def call(self, x, mask=None):
        """Pad ``x`` with mirrored border values; ``mask`` is accepted but unused."""
        w_pad, h_pad = self.padding
        return tf.pad(x, [[0, 0], [h_pad, h_pad], [w_pad, w_pad], [0, 0]], 'REFLECT')

    def get_config(self):
        """Include ``padding`` so the layer can be rebuilt from its config."""
        config = super(ReflectionPadding2D, self).get_config()
        config.update({'padding': self.padding})
        return config

In [None]:
def FeatureMapBlock(output_channels):
    """Build a 7x7, stride-1 convolution preceded by reflection padding of 3.

    Args:
        output_channels: number of filters of the convolution.

    Returns:
        A ``tf.keras.Sequential`` of [ReflectionPadding2D((3, 3)), Conv2D].
    """
    conv = tf.keras.layers.Conv2D(
        output_channels,
        7,
        strides=1,
        kernel_initializer=tf.random_normal_initializer(0., 0.02),
        use_bias=False,
        padding='valid',
    )
    return tf.keras.Sequential([ReflectionPadding2D((3, 3)), conv])

In [None]:
def ContractingBlock(output_channels, use_bn=True, kernel_size=3, strides=2, activation='relu'):
    """Build a padded convolution block: pad -> conv -> [instance norm] -> activation.

    Args:
        output_channels: number of convolution filters.
        use_bn: when true, an ``InstanceNormalization`` layer follows the convolution.
        kernel_size: convolution kernel size.
        strides: convolution stride.
        activation: ``'lrelu'`` selects LeakyReLU(0.2), ``'relu'`` selects ReLU;
            any other value selects a sigmoid.

    Returns:
        A ``tf.keras.Sequential`` holding the layers above.
    """
    stack = [
        # Reflection padding is always (1, 1), independent of kernel_size.
        ReflectionPadding2D((1, 1)),
        tf.keras.layers.Conv2D(
            output_channels,
            kernel_size=kernel_size,
            kernel_initializer=tf.random_normal_initializer(0., 0.02),
            use_bias=False,
            strides=strides,
            padding='valid',
        ),
    ]

    if use_bn:
        norm_gamma = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
        stack.append(tfa.layers.InstanceNormalization(gamma_initializer=norm_gamma))

    if activation == 'lrelu':
        stack.append(tf.keras.layers.LeakyReLU(0.2))
    else:
        # Everything that is not 'lrelu' or 'relu' ends in a sigmoid.
        chosen = 'relu' if activation == 'relu' else 'sigmoid'
        stack.append(tf.keras.layers.Activation(chosen))

    return tf.keras.Sequential(stack)

In [None]:
class ResidualBlock(tf.keras.Model):
    """Two padded 3x3 convolutions with instance norm, added back onto the input.

    Args:
        output_channels: number of filters of both convolutions.
    """

    def __init__(self, output_channels):
        super(ResidualBlock, self).__init__()

        weight_init = tf.random_normal_initializer(0., 0.02)
        norm_gamma = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

        def make_conv():
            # 3x3, no bias, 'valid' padding (the reflection pad supplies the border).
            return tf.keras.layers.Conv2D(
                output_channels,
                3,
                padding='valid',
                kernel_initializer=weight_init,
                use_bias=False,
            )

        def make_norm():
            return tfa.layers.InstanceNormalization(gamma_initializer=norm_gamma)

        self.padding1 = ReflectionPadding2D((1, 1))
        self.padding2 = ReflectionPadding2D((1, 1))

        self.conv1 = make_conv()
        self.conv2 = make_conv()

        self.instancenorm1 = make_norm()
        self.instancenorm2 = make_norm()
        self.activation = tf.keras.layers.Activation('relu')

    def call(self, x):
        """Return ``x`` plus the output of the two-convolution branch applied to ``x``."""
        shortcut = tf.identity(x)

        # First half: pad -> conv -> norm -> relu.
        branch = self.activation(self.instancenorm1(self.conv1(self.padding1(x))))
        # Second half: pad -> conv -> norm (no activation before the sum).
        branch = self.instancenorm2(self.conv2(self.padding2(branch)))

        return (shortcut + branch)

In [None]:
def ExpandingBlock(output_channels, use_bn=True, kernel_size=3):
    """Build an upsampling block: 2x nearest upsample -> pad -> conv -> [norm] -> ReLU.

    Args:
        output_channels: number of convolution filters.
        use_bn: when true, an ``InstanceNormalization`` layer follows the convolution.
        kernel_size: convolution kernel size.

    Returns:
        A ``tf.keras.Sequential`` holding the layers above.
    """
    stack = [
        tf.keras.layers.UpSampling2D(size=(2, 2), interpolation='nearest'),
        # Reflection padding is always (1, 1), independent of kernel_size.
        ReflectionPadding2D((1, 1)),
        tf.keras.layers.Conv2D(
            output_channels,
            kernel_size=kernel_size,
            kernel_initializer=tf.random_normal_initializer(0., 0.02),
            use_bias=False,
            strides=1,
            padding='valid',
        ),
    ]

    if use_bn:
        norm_gamma = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
        stack.append(tfa.layers.InstanceNormalization(gamma_initializer=norm_gamma))

    stack.append(tf.keras.layers.Activation('relu'))

    return tf.keras.Sequential(stack)

In [None]:
class Generator(tf.keras.Model):
    """Image-to-image generator: feature map -> 2 contractions -> residual blocks
    -> 2 expansions -> feature map -> tanh.

    Args:
        output_channels: channels of the generated image.
        res_layer: number of residual blocks in the middle of the network.
        hidden_channels: filters of the first feature map; later stages use
            multiples of it.
    """

    def __init__(self, output_channels=3, res_layer= 9, hidden_channels=64):
        super(Generator, self).__init__()
        base = hidden_channels

        self.upfeature = FeatureMapBlock(base)
        self.contract1 = ContractingBlock(base * 2)
        self.contract2 = ContractingBlock(base * 4)

        blocks = []
        for _ in range(res_layer):
            blocks.append(ResidualBlock(base * 4))
        self.res = blocks

        self.expand2 = ExpandingBlock(base * 2)
        self.expand3 = ExpandingBlock(base)
        self.downfeature = FeatureMapBlock(output_channels)
        self.tanh = tf.keras.layers.Activation('tanh')

    def call(self, x):
        """Run ``x`` through every stage in order and return the tanh output."""
        features = self.contract2(self.contract1(self.upfeature(x)))

        for block in self.res:
            features = block(features)

        features = self.expand3(self.expand2(features))

        return self.tanh(self.downfeature(features))

In [None]:
class Discriminator(tf.keras.Model):
    """Convolutional discriminator built from five ``ContractingBlock`` stages.

    The first three stages use the default stride of 2, the last two use stride 1.
    The first and last stages skip normalisation; the last stage has one filter
    and requests a sigmoid activation.

    Args:
        hidden_channels: filters of the first stage; the next three stages use
            2x, 4x and 8x that number.
    """

    def __init__(self, hidden_channels=64):
        super(Discriminator, self).__init__()

        # The blocks create their own initializers, so none is needed here
        # (the previous version built one and never used it).
        self.contract1 = ContractingBlock(hidden_channels, use_bn=False, kernel_size=4, activation='lrelu')
        self.contract2 = ContractingBlock(hidden_channels*2, kernel_size=4, activation='lrelu')
        self.contract3 = ContractingBlock(hidden_channels*4, kernel_size=4, activation='lrelu')
        self.contract4 = ContractingBlock(hidden_channels*8, kernel_size=4, strides=1, activation='lrelu')
        self.final = ContractingBlock(1, kernel_size=4, use_bn=False, strides=1, activation='sigmoid')

    def call(self, x):
        """Apply the five stages in order and return the final single-channel map."""
        stages = (self.contract1, self.contract2, self.contract3, self.contract4, self.final)
        for stage in stages:
            x = stage(x)
        return x

In [None]:
# Create the networks and their optimizers inside the distribution strategy's scope.
with strategy.scope():
    def _make_adam():
        """Adam with learning rate 0.0002 and beta_1 0.5, shared by all four networks."""
        return tf.keras.optimizers.Adam(0.0002, beta_1=0.5)

    # Generators: g_monet is fed photos and g_photo is fed Monets during training.
    g_monet, g_photo = Generator(), Generator()

    # One discriminator per image domain.
    d_monet, d_photo = Discriminator(), Discriminator()

    # Each network gets its own optimizer instance.
    g_monet_optimizer = _make_adam()
    g_photo_optimizer = _make_adam()
    d_monet_optimizer = _make_adam()
    d_photo_optimizer = _make_adam()

In [None]:
#adv_criterion = tf.keras.losses.MeanSquaredError(reduction=tf.keras.losses.Reduction.SUM)
#recon_criterion = tf.keras.losses.MeanAbsoluteError(reduction=tf.keras.losses.Reduction.SUM)

with strategy.scope():
    def adv_criterion(x,y):
        """Mean of the element-wise squared difference between ``x`` and ``y``."""
        squared_error = tf.math.squared_difference(x, y)
        return tf.reduce_mean(squared_error)

    def recon_criterion(x,y):
        """Mean of the element-wise absolute difference between ``x`` and ``y``."""
        absolute_error = tf.abs(x - y)
        return tf.reduce_mean(absolute_error)

    def get_disc_loss(d_fake, d_real, adv_criterion):
        """Discriminator loss: real outputs are scored against ones, fake outputs
        against zeros, and the two terms are each weighted by 0.5."""
        real_term = adv_criterion(d_real, tf.ones_like(d_real))
        fake_term = adv_criterion(d_fake, tf.zeros_like(d_fake))
        return 0.5 * real_term + 0.5 * fake_term

    def get_gen_loss(d_fake, adv_criterion):
        """Generator adversarial loss: half the criterion of the discriminator's
        output on fakes against a target of ones."""
        target = tf.ones_like(d_fake)
        return 0.5*adv_criterion(d_fake, target)

    def get_identity_loss(real, identity, identity_criterion, lambda_identity):
        """``lambda_identity`` times the criterion between ``real`` and ``identity``."""
        distance = identity_criterion(real, identity)
        return lambda_identity * distance

    def get_cycle_consistency_loss(real, cycle, cycle_criterion, lambda_cycle):
        """``lambda_cycle`` times the criterion between ``real`` and ``cycle``."""
        distance = cycle_criterion(real, cycle)
        return lambda_cycle * distance
    
   
    

In [None]:
class CycleGan(tf.keras.Model):
    """Trains two generators and two discriminators together as a CycleGAN.

    Args:
        monet_generator: model applied to photos to produce Monet-style images.
        photo_generator: model applied to Monets to produce photo-style images.
        monet_discriminator: model scoring real vs. generated Monets.
        photo_discriminator: model scoring real vs. generated photos.
    """

    def __init__(
                self,
                monet_generator,
                photo_generator,
                monet_discriminator,
                photo_discriminator
        ):
        super(CycleGan, self).__init__()
        # Bind the models that were passed in. The previous version ignored these
        # parameters and silently captured the notebook globals of the same role.
        self.g_monet = monet_generator
        self.g_photo = photo_generator
        self.d_monet = monet_discriminator
        self.d_photo = photo_discriminator

    def compile(
        self,
        m_gen_optimizer,
        p_gen_optimizer,
        m_disc_optimizer,
        p_disc_optimizer,
        disc_loss_fn,
        gen_loss_fn,
        cycle_loss_fn,
        identity_loss_fn,
        lambda_cycle=10,
        lambda_identity=5
    ):
        """Store the optimizers, loss functions and loss weights used by ``train_step``.

        Args:
            m_gen_optimizer: optimizer for the Monet generator.
            p_gen_optimizer: optimizer for the photo generator.
            m_disc_optimizer: optimizer for the Monet discriminator.
            p_disc_optimizer: optimizer for the photo discriminator.
            disc_loss_fn: called as ``fn(d_fake, d_real, adv_criterion)``.
            gen_loss_fn: called as ``fn(d_fake, adv_criterion)``.
            cycle_loss_fn: called as ``fn(real, cycle, recon_criterion, lambda_cycle)``.
            identity_loss_fn: called as ``fn(real, identity, recon_criterion, lambda_identity)``.
            lambda_cycle: weight handed to ``cycle_loss_fn`` (default 10, the value
                that used to be hard-coded in ``train_step``).
            lambda_identity: weight handed to ``identity_loss_fn`` (default 5, the
                value that used to be hard-coded in ``train_step``).
        """
        super(CycleGan, self).compile()
        self.g_monet_optimizer = m_gen_optimizer
        self.g_photo_optimizer = p_gen_optimizer
        self.d_monet_optimizer = m_disc_optimizer
        self.d_photo_optimizer = p_disc_optimizer
        self.get_disc_loss = disc_loss_fn
        self.get_gen_loss = gen_loss_fn
        self.get_cycle_loss = cycle_loss_fn
        self.get_identity_loss = identity_loss_fn
        self.lambda_cycle = lambda_cycle
        self.lambda_identity = lambda_identity

    def train_step(self, batch_data):
        """Run one optimisation step on a ``(real_monet, real_photo)`` batch.

        Returns:
            Dict with the total generator loss and the discriminator loss of each
            domain (``monet_gen_loss``, ``photo_gen_loss``, ``monet_disc_loss``,
            ``photo_disc_loss``).
        """
        real_monet, real_photo = batch_data

        # NOTE: `adv_criterion` and `recon_criterion` are notebook-level functions.
        # The tape is persistent because four separate gradients are taken from it.
        with tf.GradientTape(persistent=True) as tape:
            # Translate each domain into the other.
            fake_monet = self.g_monet(real_photo, training=True)
            fake_photo = self.g_photo(real_monet, training=True)

            # Feed each generator an image already in its output domain.
            identity_monet = self.g_monet(real_monet, training=True)
            identity_photo = self.g_photo(real_photo, training=True)

            # Translate the fakes back to their source domain.
            cycle_monet = self.g_monet(fake_photo, training=True)
            cycle_photo = self.g_photo(fake_monet, training=True)

            # Discriminator scores for generated and real images.
            d_fake_monet = self.d_monet(fake_monet, training=True)
            d_fake_photo = self.d_photo(fake_photo, training=True)

            d_real_monet = self.d_monet(real_monet, training=True)
            d_real_photo = self.d_photo(real_photo, training=True)

            # Adversarial loss of each discriminator.
            monet_disc_loss = self.get_disc_loss(d_fake_monet, d_real_monet, adv_criterion)
            photo_disc_loss = self.get_disc_loss(d_fake_photo, d_real_photo, adv_criterion)

            # Cycle loss is summed over both directions and shared by both generators.
            total_cycle_loss = (
                self.get_cycle_loss(real_monet, cycle_monet, recon_criterion, self.lambda_cycle) +
                self.get_cycle_loss(real_photo, cycle_photo, recon_criterion, self.lambda_cycle)
            )

            monet_gen_loss = self.get_gen_loss(d_fake_monet, adv_criterion)
            photo_gen_loss = self.get_gen_loss(d_fake_photo, adv_criterion)

            monet_identity_loss = self.get_identity_loss(
                real_monet, identity_monet, recon_criterion, self.lambda_identity)
            photo_identity_loss = self.get_identity_loss(
                real_photo, identity_photo, recon_criterion, self.lambda_identity)

            monet_total_gen_loss = monet_gen_loss + total_cycle_loss + monet_identity_loss
            photo_total_gen_loss = photo_gen_loss + total_cycle_loss + photo_identity_loss

        # Update order is unchanged: Monet disc, photo disc, Monet gen, photo gen.
        monet_disc_gradients = tape.gradient(monet_disc_loss, self.d_monet.trainable_variables)
        self.d_monet_optimizer.apply_gradients(zip(monet_disc_gradients, self.d_monet.trainable_variables))

        photo_disc_gradients = tape.gradient(photo_disc_loss, self.d_photo.trainable_variables)
        self.d_photo_optimizer.apply_gradients(zip(photo_disc_gradients, self.d_photo.trainable_variables))

        monet_gen_gradients = tape.gradient(monet_total_gen_loss, self.g_monet.trainable_variables)
        self.g_monet_optimizer.apply_gradients(zip(monet_gen_gradients, self.g_monet.trainable_variables))

        photo_gen_gradients = tape.gradient(photo_total_gen_loss, self.g_photo.trainable_variables)
        self.g_photo_optimizer.apply_gradients(zip(photo_gen_gradients, self.g_photo.trainable_variables))

        # Drop the reference to the persistent tape once all gradients are taken.
        del tape

        return {
            "monet_gen_loss": monet_total_gen_loss,
            "photo_gen_loss": photo_total_gen_loss,
            "monet_disc_loss": monet_disc_loss,
            "photo_disc_loss": photo_disc_loss
        }

In [None]:
class CustomCallback(tf.keras.callbacks.Callback):
    """Keras callback that redraws a sample translation after every epoch.

    Relies on the notebook-level names `g_monet` (generator) and `tes`
    (sample batch) rather than on constructor arguments.
    """
    def on_epoch_end(self, epoch, logs=None):
        """Clear the cell's previous output and plot `g_monet` applied to `tes`.

        `epoch` and `logs` are supplied by Keras and are not used here.
        """
        clear_output(wait=True)
        generate_images(g_monet, tes) 

In [None]:
# Assemble and compile the combined model inside the distribution strategy's scope.
with strategy.scope():
    cycle_gan = CycleGan(
        monet_generator=g_monet,
        photo_generator=g_photo,
        monet_discriminator=d_monet,
        photo_discriminator=d_photo,
    )
    cycle_gan.compile(
        m_gen_optimizer=g_monet_optimizer,
        p_gen_optimizer=g_photo_optimizer,
        m_disc_optimizer=d_monet_optimizer,
        p_disc_optimizer=d_photo_optimizer,
        disc_loss_fn=get_disc_loss,
        gen_loss_fn=get_gen_loss,
        cycle_loss_fn=get_cycle_consistency_loss,
        identity_loss_fn=get_identity_loss,
    )

In [None]:
# The original cell read `epochs=,`, which is a syntax error and stopped the
# notebook from running past this point.
# NOTE(review): 25 is a placeholder; the intended epoch count was never recorded.
# Set it to fit the available compute budget.
EPOCHS = 25

history = cycle_gan.fit(
    final_dataset,
    epochs=EPOCHS,
    callbacks=[CustomCallback()]
)

In [None]:
# Export the trained photo discriminator to ./d_photo using the 'tf' save format.
cycle_gan.d_photo.save('./d_photo',save_format='tf')

In [None]:
import shutil
# Zip the contents of /kaggle/working into /kaggle/working/g_photo.zip.
# NOTE(review): the only model saved by this notebook is `d_photo`, yet the archive
# is named `g_photo`, and the whole working directory is archived rather than just
# the saved model. Confirm which model and which directory were intended.
shutil.make_archive("/kaggle/working/g_photo", 'zip', "/kaggle/working")

In [None]:
# Run the trained model on the test dataset
# Show input/output pairs for the first 20 batches of the un-augmented photo
# dataset, using the Monet generator.
for inp in submit_ds.take(20):
  generate_images(g_monet, inp)

In [None]:
import os

# `import PIL` alone does not load the `PIL.Image` submodule; it only worked when
# another library had already imported it. Import the submodule explicitly.
import PIL.Image

IMAGES_DIR = "./images"
# Unlike `! mkdir`, this does not fail when the cell is run a second time.
os.makedirs(IMAGES_DIR, exist_ok=True)

# Translate every photo and write the results as 1.jpg, 2.jpg, ...
for i, img in enumerate(submit_ds, start=1):
    prediction = g_monet(img, training=False)[0].numpy()
    # Map the generator's tanh output from [-1, 1] back to 8-bit pixel values.
    prediction = (prediction * 127.5 + 127.5).astype(np.uint8)
    im = PIL.Image.fromarray(prediction)
    im.save(os.path.join(IMAGES_DIR, str(i) + ".jpg"))

In [None]:
import shutil

# Archive only the generated images. The previous call zipped all of
# /kaggle/working, so images.zip also picked up the saved model and any other
# archive already written there.
# NOTE(review): assumes the notebook's working directory is /kaggle/working, so
# that ./images (written by the previous cell) is /kaggle/working/images.
shutil.make_archive("/kaggle/working/images", 'zip', root_dir="/kaggle/working/images")