# Import Libraries

Make sure TensorFlow Addons (`tfa`) is installed, using the latest available version.

In [1]:
!pip install tensorflow_addons

[0m

In [2]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa

from kaggle_datasets import KaggleDatasets
import matplotlib.pyplot as plt
import numpy as np

import re

from functools import partial

In [3]:
#Make our Exception reporting verbose for error checking
%xmode Verbose

Exception reporting mode: Verbose


# Configure Tensorflow

In [4]:
# Connect to a TPU when the runtime provides one; otherwise fall back to
# TensorFlow's default distribution strategy.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Device:', tpu.master())
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    # Non-experimental alias of the deprecated tf.distribute.experimental.TPUStrategy.
    strategy = tf.distribute.TPUStrategy(tpu)
except Exception as tpu_error:
    # `Exception` (not a bare `except:`) lets KeyboardInterrupt/SystemExit
    # propagate, and the reason for the fallback is reported instead of hidden.
    print('TPU unavailable, falling back to the default strategy:', repr(tpu_error))
    strategy = tf.distribute.get_strategy()
print('Number of replicas:', strategy.num_replicas_in_sync)

# Lets tf.data choose parallelism / prefetch buffer sizes at runtime.
AUTOTUNE = tf.data.AUTOTUNE
#BATCH_SIZE = 10 * strategy.num_replicas_in_sync # reset BATCH SIZE to a single number?
BATCH_SIZE = 10

print(tf.__version__)

Device: grpc://10.0.0.2:8470
Number of replicas: 8
2.11.0


# Load datasets

## Connect to Google Cloud Storage
N.B. You need to enable the Google Cloud SDK in the Add-ons menu.

In [5]:
# Fetch the notebook user's gcloud credential through Kaggle's secrets client
# and register it with TensorFlow.
# NOTE(review): presumably required so TensorFlow can read the dataset's GCS
# bucket used below — confirm against the Kaggle documentation.
from kaggle_secrets import UserSecretsClient
user_secrets = UserSecretsClient()
user_credential = user_secrets.get_gcloud_credential()
user_secrets.set_tensorflow_credential(user_credential)

In [6]:
# Base path of the attached Kaggle dataset; the TFRecord glob patterns further
# down are appended to this value.
GCS_PATH = KaggleDatasets().get_gcs_path()
print(GCS_PATH)

BackendError: Unexpected response from the service. Response: {'errors': ['Google Cloud SDK must be authorized before copying private datasets.'], 'error': {'code': 9, 'details': []}, 'wasSuccessful': False}.

## Define helper functions

In [None]:
IMAGE_SIZE = [256, 256] # [height, width] that decode_image resizes every image to
OUTPUT_CHANNELS = 3 # number of channels produced by the generator's final layer
BATCH_SIZE =  10 # batch size passed to get_gan_dataset for the training dataset
EPOCHS_NUM = 10 # number of epochs passed to cycle_gan_model.fit

def decode_image(image):
    """Decode JPEG bytes into a float32 RGB image of size IMAGE_SIZE.

    Pixel values are rescaled from the decoded 0..255 range to [-1, 1]
    before the image is resized.
    """
    pixels = tf.image.decode_jpeg(image, channels=3)
    scaled = (tf.cast(pixels, tf.float32) / 127.5) - 1
    resized = tf.image.resize(scaled, [*IMAGE_SIZE])
    # Reshaping to the known size gives the tensor a fully defined static shape.
    return tf.reshape(resized, [*IMAGE_SIZE, 3])

def read_tfrecord(example):
    """Parse one serialized TFRecord example and return its decoded image.

    The record is expected to hold two scalar string features, "image_name"
    and "image"; only "image" is used.
    """
    feature_spec = {
        feature_name: tf.io.FixedLenFeature([], tf.string)
        for feature_name in ("image_name", "image")
    }
    parsed = tf.io.parse_single_example(example, feature_spec)
    return decode_image(parsed['image'])

# This function extracts the image data from the files to create our dataset
def load_dataset(filenames, labeled=True, ordered=False):
    """Build a tf.data pipeline of decoded images from TFRecord files.

    Args:
        filenames: TFRecord file path(s) passed to tf.data.TFRecordDataset.
        labeled: Unused. The records carry no labels, so the dataset always
            yields images only; the parameter is kept so existing calls that
            pass it keep working.
        ordered: When False (default), deterministic ordering is switched off
            so elements may be produced as soon as they are ready.

    Returns:
        A tf.data.Dataset whose elements are the tensors returned by
        read_tfrecord (one image per record, unbatched).
    """
    options = tf.data.Options()
    if not ordered:
        # `deterministic` supersedes the deprecated `experimental_deterministic`.
        options.deterministic = False

    records = tf.data.TFRecordDataset(filenames).with_options(options)
    # read_tfrecord is passed directly; wrapping it in partial() with no bound
    # arguments added nothing.
    return records.map(read_tfrecord, num_parallel_calls=AUTOTUNE)

In [None]:
def get_gan_dataset(source_files, target_files, augment=None, repeat=True, shuffle=True, batch_size=1):
    """Create the paired dataset consumed by the CycleGAN training loop.

    The source and target files go through the same pipeline independently and
    are then zipped, so each element is a ``(source_batch, target_batch)``
    tuple. The pairing is arbitrary: the two domains are unaligned.

    Args:
        source_files: TFRecord files of the source domain.
        target_files: TFRecord files of the target domain.
        augment: Optional per-image function mapped over both datasets.
        repeat: Repeat both datasets indefinitely.
        shuffle: Shuffle both datasets with a 2048-element buffer.
        batch_size: Batch size; incomplete final batches are dropped.

    Returns:
        A zipped tf.data.Dataset of ``(source_batch, target_batch)``.
    """
    def _prepare(files):
        # One domain's pipeline: load -> augment -> repeat -> shuffle -> batch -> prefetch.
        dataset = load_dataset(files)
        if augment:
            # Fixed: this referenced the undefined name `AUTO`, raising
            # NameError whenever an augment function was supplied.
            dataset = dataset.map(augment, num_parallel_calls=AUTOTUNE)
        if repeat:
            dataset = dataset.repeat()
        if shuffle:
            dataset = dataset.shuffle(2048)
        dataset = dataset.batch(batch_size, drop_remainder=True)
        return dataset.prefetch(AUTOTUNE)

    return tf.data.Dataset.zip((_prepare(source_files), _prepare(target_files)))

## Load our source and target datasets
I.E. what style of images we are converting from and to

In [None]:
# Glob patterns, relative to the dataset's GCS root, for the two image domains.
source_path = '/art_portraits_tfrec/photo_tfrec/*.tfrec'
target_path = '/art_portraits_tfrec/ukiyo_tfrec/*.tfrec'

source_pattern = GCS_PATH + source_path
SOURCE_FILENAMES = tf.io.gfile.glob(source_pattern)
print('Source TFRecord Files:', len(SOURCE_FILENAMES))

target_pattern = GCS_PATH + target_path
TARGET_FILENAMES = tf.io.gfile.glob(target_pattern)
print('Target TFRecord Files:', len(TARGET_FILENAMES))

In [None]:
# One-image batches of each domain. `labeled` is accepted by load_dataset but
# has no effect there. source_ds is consumed by the final visualisation cell.
source_ds = load_dataset(SOURCE_FILENAMES, labeled=False).batch(1)
target_ds = load_dataset(TARGET_FILENAMES, labeled=False).batch(1)

In [None]:
# We will also load some datasets for FID implementation
# Both use 32 images per replica as the batch size. Only the source set is
# capped (first 1024 images taken); the target set is used in full.
fid_source_ds = load_dataset(SOURCE_FILENAMES).take(1024).batch(32*strategy.num_replicas_in_sync).prefetch(32)
fid_target_ds = load_dataset(TARGET_FILENAMES).batch(32*strategy.num_replicas_in_sync).prefetch(32)

In [None]:
# Training dataset: endlessly repeating, unshuffled, unaugmented
# (source_batch, target_batch) pairs of BATCH_SIZE images each.
full_dataset = get_gan_dataset(SOURCE_FILENAMES, TARGET_FILENAMES, augment=None, repeat=True, shuffle=False, batch_size=BATCH_SIZE)

### Check that our data has loaded as expected

In [None]:
# Pull a single (source, target) batch and show the first image of each side.
buffer_dataset = iter(full_dataset)
example_source , example_target = next(buffer_dataset)

# Images are stored in [-1, 1]; `* 0.5 + 0.5` maps them to [0, 1] for imshow.
plt.subplot(121)
plt.title('Source')
plt.imshow(example_source[0] * 0.5 + 0.5)

plt.subplot(122)
plt.title('target')
plt.imshow(example_target[0] * 0.5 + 0.5)

# Using Frechet Inception Distance as a metric
FID measures the distance between the distribution of feature vectors extracted from the generated images and the distribution extracted from the real images used to train the generator. The feature vectors come from a pre-trained Inception model for image classification. The lower the FID score, the more similar the two sets of images are, i.e. the better the GAN is at producing them.

In [None]:
with strategy.scope():

    # Feature extractor for FID: InceptionV3 without its classification head,
    # cut at the "mixed9" layer and globally average-pooled to one vector per image.
    inception_model = tf.keras.applications.InceptionV3(
        include_top=False,
        pooling="avg",
        input_shape=(256, 256, 3),
    )

    mix3 = inception_model.get_layer("mixed9").output
    f0 = tf.keras.layers.GlobalAveragePooling2D()(mix3)

    inception_model = tf.keras.Model(inputs=inception_model.input, outputs=f0)
    inception_model.trainable = False

    def calculate_activation_statistics_mod(images, fid_model):
        """Return the mean vector and covariance matrix of fid_model's outputs.

        Args:
            images: Input accepted by ``fid_model.predict``.
            fid_model: Model producing one feature vector per image.

        Returns:
            (mu, sigma): the per-feature mean, and the covariance computed as
            E[x^T x] - E[x]^T E[x] (normalised by the sample count N, not N - 1).
        """
        act = tf.cast(fid_model.predict(images), tf.float32)
        sample_count = tf.cast(tf.shape(act)[0], tf.float32)

        mu = tf.reduce_mean(act, axis=0)
        mean_x = tf.reduce_mean(act, axis=0, keepdims=True)

        second_moment = tf.matmul(tf.transpose(act), act) / sample_count
        outer_mean = tf.matmul(tf.transpose(mean_x), mean_x)
        sigma = second_moment - outer_mean
        return mu, sigma

    # Reference statistics of the real target-domain images, computed once and reused by FID().
    myFID_mu2, myFID_sigma2 = calculate_activation_statistics_mod(fid_target_ds, inception_model)
    fids = []

In [None]:
with strategy.scope():
    def calculate_frechet_distance(mu1,sigma1,mu2,sigma2):
        """Frechet distance between the Gaussians N(mu1, sigma1) and N(mu2, sigma2).

        Evaluates ``|mu1 - mu2|^2 + Tr(sigma1) + Tr(sigma2) - 2 * Tr(sqrt(sigma1 @ sigma2))``.

        Args:
            mu1, mu2: 1-D mean vectors.
            sigma1, sigma2: Square covariance matrices.

        Returns:
            A tensor of shape [1, 1], because the squared mean distance is
            computed as a (1, d) x (d, 1) matrix product.
        """
        # The matrix square root is taken in complex64; only its real part is kept.
        covmean = tf.linalg.sqrtm(tf.cast(tf.matmul(sigma1, sigma2), tf.complex64))
        covmean = tf.cast(tf.math.real(covmean), tf.float32)
        tr_covmean = tf.linalg.trace(covmean)

        # Compute the mean difference once instead of twice.
        mean_diff = mu1 - mu2
        squared_mean_distance = tf.matmul(tf.expand_dims(mean_diff, axis=0),
                                          tf.expand_dims(mean_diff, axis=1))

        return squared_mean_distance + tf.linalg.trace(sigma1) + tf.linalg.trace(sigma2) - 2 * tr_covmean


    
    
    def FID(images,gen_model,inception_model=inception_model,myFID_mu2=myFID_mu2, myFID_sigma2=myFID_sigma2):
        """Compute the FID of gen_model's outputs against the reference statistics.

        Args:
            images: Dataset of real input images fed through ``gen_model``.
            gen_model: Generator whose outputs are evaluated.
            inception_model: Feature extractor applied to the generated images.
            myFID_mu2, myFID_sigma2: Reference mean and covariance to compare against.

        Returns:
            The value returned by calculate_frechet_distance (a [1, 1] tensor).
        """
        # Take the input shape from the generator itself. The previous
        # hard-coded [512, 512, 3] did not match the 256x256 generators and
        # datasets used in this notebook.
        inp = layers.Input(shape=gen_model.input_shape[1:], name='input_image')
        generated = gen_model(inp)
        features = inception_model(generated)
        fid_model = tf.keras.Model(inputs=inp, outputs=features)

        mu1, sigma1 = calculate_activation_statistics_mod(images, fid_model)

        return calculate_frechet_distance(mu1, sigma1, myFID_mu2, myFID_sigma2)

# Building our Generator network
We'll be using a UNET architecture for our CycleGAN. (Paper: https://arxiv.org/pdf/1505.04597v1.pdf)

First, we define upsample and downsample functions. Each downsample block divides the height and width of the image by the stride (the size of the steps the filter takes when performing the convolution), and each upsample block multiplies them by the stride using a transposed convolution.

We are also using Instance Normalisation (https://arxiv.org/pdf/1607.08022.pdf) rather than Batch Normalisation, as this is faster than applying normalisation to sets of images and obtains similar results.

In [None]:
OUTPUT_CHANNELS = 3

def downsample(filters, size, apply_instancenorm=True):
    """Build a stride-2 block: Conv2D -> (optional InstanceNormalization) -> LeakyReLU.

    Args:
        filters: Number of convolution filters.
        size: Convolution kernel size.
        apply_instancenorm: Whether to insert instance normalisation after the convolution.

    Returns:
        A keras.Sequential that halves the height and width of its input.
    """
    initializer = tf.random_normal_initializer(0., 0.02)
    gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    block_layers = [
        layers.Conv2D(filters, size, strides=2, padding='same',
                      kernel_initializer=initializer, use_bias=False),
    ]
    if apply_instancenorm:
        block_layers.append(tfa.layers.InstanceNormalization(gamma_initializer=gamma_init))
    block_layers.append(layers.LeakyReLU())

    return keras.Sequential(block_layers)

def upsample(filters, size, apply_dropout=False):
    """Build a stride-2 block: Conv2DTranspose -> InstanceNormalization -> (optional Dropout) -> ReLU.

    Args:
        filters: Number of transposed-convolution filters.
        size: Kernel size.
        apply_dropout: Whether to add Dropout(0.5) after the normalisation.

    Returns:
        A keras.Sequential that doubles the height and width of its input.
    """
    initializer = tf.random_normal_initializer(0., 0.02)
    gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    block_layers = [
        layers.Conv2DTranspose(filters, size, strides=2,
                               padding='same',
                               kernel_initializer=initializer,
                               use_bias=False),
        tfa.layers.InstanceNormalization(gamma_initializer=gamma_init),
    ]
    if apply_dropout:
        block_layers.append(layers.Dropout(0.5))
    block_layers.append(layers.ReLU())

    return keras.Sequential(block_layers)

With these functions defined, we can now build our generator model. We do this by downsampling then upsampling the image, while establishing long skip connections. A skip connection is where we feed the output from an earlier layer into a later one, "skipping" some layers. This solves the degradation problem, where adding deeper layers reduces the performance of the model, as deep layers don't learn as effectively as shallow ones (https://www.analyticsvidhya.com/blog/2021/08/all-you-need-to-know-about-skip-connections/). In the UNET architecture, our skip connections concatenate the output of each downsample layer onto the input of the equivalent upsample layer:

TODO: adjust our up and downsample layers to accommodate the new size of our input images.
* Could add an additional layer?
* Resize the images prior to loading into model?
* Or just leave it as is and see what happens?

In [None]:
def Generator():
    """Build the U-Net generator for 256x256 RGB images.

    Eight stride-2 downsample blocks reduce the image to 1x1, seven upsample
    blocks bring it back to 128x128 while concatenating the matching encoder
    output after each one, and a final tanh transposed convolution restores
    256x256 with OUTPUT_CHANNELS channels.

    Returns:
        keras.Model mapping (bs, 256, 256, 3) to (bs, 256, 256, OUTPUT_CHANNELS).
    """
    inputs = layers.Input(shape=[256,256,3])

    # bs = batch size; shapes below are each block's output for a 256x256 input.
    down_stack = [
        downsample(64, 4, apply_instancenorm=False), # (bs, 128, 128, 64)
        downsample(128, 4), # (bs, 64, 64, 128)
        downsample(256, 4), # (bs, 32, 32, 256)
        downsample(512, 4), # (bs, 16, 16, 512)
        downsample(512, 4), # (bs, 8, 8, 512)
        downsample(512, 4), # (bs, 4, 4, 512)
        downsample(512, 4), # (bs, 2, 2, 512)
        downsample(512, 4), # (bs, 1, 1, 512)
    ]

    # Shapes below are after concatenation with the matching skip connection.
    up_stack = [
        upsample(512, 4, apply_dropout=True), # (bs, 2, 2, 1024)
        upsample(512, 4, apply_dropout=True), # (bs, 4, 4, 1024)
        upsample(512, 4, apply_dropout=True), # (bs, 8, 8, 1024)
        upsample(512, 4), # (bs, 16, 16, 1024)
        upsample(256, 4), # (bs, 32, 32, 512)
        upsample(128, 4), # (bs, 64, 64, 256)
        upsample(64, 4), # (bs, 128, 128, 128)
    ]

    initializer = tf.random_normal_initializer(0., 0.02)
    last = layers.Conv2DTranspose(OUTPUT_CHANNELS, 4,
                                  strides=2,
                                  padding='same',
                                  kernel_initializer=initializer,
                                  activation='tanh') # (bs, 256, 256, 3)

    # Encoder pass, remembering every block's output for the skip connections.
    x = inputs
    encoder_outputs = []
    for down in down_stack:
        x = down(x)
        encoder_outputs.append(x)

    # Decoder pass: the bottleneck (last encoder output) is not a skip, and the
    # remaining outputs are consumed deepest-first.
    for up, skip in zip(up_stack, encoder_outputs[-2::-1]):
        x = layers.Concatenate()([up(x), skip])

    return keras.Model(inputs=inputs, outputs=last(x))

## Visualising the Generator model

In [None]:
from tensorflow.keras.utils import plot_model

In [None]:
# Separate generator instance used only to plot and summarise the architecture.
gen_viz = Generator()
plot_model(gen_viz, show_shapes=True)

In [None]:
gen_viz.summary()

## Building the discriminator

In [None]:
def Discriminator():
    """Build the discriminator for 256x256 RGB images.

    Three stride-2 downsample blocks are followed by two zero-padded 4x4
    convolutions. The final convolution has no activation, so the model
    returns a (bs, 30, 30, 1) map of raw scores (logits).

    Returns:
        tf.keras.Model mapping (bs, 256, 256, 3) to (bs, 30, 30, 1).
    """
    initializer = tf.random_normal_initializer(0., 0.02)
    gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    inp = layers.Input(shape=[256, 256, 3], name='input_image')

    # (bs, 128, 128, 64) -> (bs, 64, 64, 128) -> (bs, 32, 32, 256)
    x = inp
    for filters, use_instancenorm in ((64, False), (128, True), (256, True)):
        x = downsample(filters, 4, use_instancenorm)(x)

    x = layers.ZeroPadding2D()(x) # (bs, 34, 34, 256)
    x = layers.Conv2D(512, 4, strides=1,
                      kernel_initializer=initializer,
                      use_bias=False)(x) # (bs, 31, 31, 512)
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_init)(x)
    x = layers.LeakyReLU()(x)

    x = layers.ZeroPadding2D()(x) # (bs, 33, 33, 512)
    scores = layers.Conv2D(1, 4, strides=1,
                           kernel_initializer=initializer)(x) # (bs, 30, 30, 1)

    return tf.keras.Model(inputs=inp, outputs=scores)

## Visualising the Discriminator model

In [None]:
# Separate discriminator instance used only to plot and summarise the architecture.
Discriminator_viz = Discriminator()
plot_model(Discriminator_viz, show_shapes=True,show_dtype=True,show_layer_names=False , expand_nested=True)

In [None]:
Discriminator_viz.summary()

In [None]:
with strategy.scope():
    # The four networks are created inside the distribution strategy's scope.
    # NOTE(review): the roles below are the author's stated intent. CycleGan.train_step
    # unpacks each batch as (real_target, real_source) while get_gan_dataset yields
    # (source, target), so confirm which images each network is actually trained on.
    source_generator = Generator() # transforms source images to the target style
    target_generator = Generator() # transforms target style images to be more like the source

    source_discriminator = Discriminator() # differentiates real source images and generated source images
    target_discriminator = Discriminator() # differentiates real target style images and generated images

We can at least quickly and easily test the Generator constructor without having to train the model.

In [None]:
# Run the (still untrained) generator on the example source batch, so the
# right-hand image only demonstrates that the model executes end to end.
to_target = source_generator(example_source)

# `* 0.5 + 0.5` maps the [-1, 1] image range to [0, 1] for imshow.
plt.subplot(1, 2, 1)
plt.title("Original Image")
plt.imshow(example_source[0] * 0.5 + 0.5)

plt.subplot(1, 2, 2)
plt.title("Target Image Style")
plt.imshow(to_target[0] * 0.5 + 0.5)
plt.show()

# Image Augmentation
We can make the model more data-efficient by using Differentiable Augmentation (DiffAugment), as described here: (https://arxiv.org/pdf/2006.10738.pdf).

In [None]:
# Differentiable Augmentation for Data-Efficient GAN Training
# Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
# https://arxiv.org/pdf/2006.10738

#import tensorflow as tf

with strategy.scope():
    def DiffAugment(x, policy='', channels_first=False):
        """Apply the augmentations named in a comma-separated policy string.

        Args:
            x: Batch of images. The augmentations operate channels-last, so a
                channels-first batch is transposed before and after.
            policy: Comma-separated keys of AUGMENT_FNS; empty means "do nothing".
            channels_first: Whether ``x`` is laid out as (batch, channels, h, w).

        Returns:
            The augmented batch in the same layout as the input.
        """
        if not policy:
            return x

        if channels_first:
            x = tf.transpose(x, [0, 2, 3, 1])

        for policy_name in policy.split(','):
            for augment in AUGMENT_FNS[policy_name]:
                x = augment(x)

        if channels_first:
            x = tf.transpose(x, [0, 3, 1, 2])
        return x


    def rand_brightness(x):
        """Add one random offset in [-0.5, 0.5) to every pixel of each image in the batch."""
        per_image_shape = [tf.shape(x)[0], 1, 1, 1]
        offset = tf.random.uniform(per_image_shape) - 0.5
        return x + offset


    def rand_saturation(x):
        """Scale each pixel's deviation from its channel mean by a per-image factor in [0, 2)."""
        per_image_shape = [tf.shape(x)[0], 1, 1, 1]
        factor = tf.random.uniform(per_image_shape) * 2
        channel_mean = tf.reduce_mean(x, axis=3, keepdims=True)
        return (x - channel_mean) * factor + channel_mean


    def rand_contrast(x):
        """Scale each image's deviation from its overall mean by a per-image factor in [0.5, 1.5)."""
        per_image_shape = [tf.shape(x)[0], 1, 1, 1]
        factor = tf.random.uniform(per_image_shape) + 0.5
        image_mean = tf.reduce_mean(x, axis=[1, 2, 3], keepdims=True)
        return (x - image_mean) * factor + image_mean


    def rand_translation(x, ratio=0.125):
        """Shift every image in the batch by a random integer offset, zero-filling the gap.

        Args:
            x: Image batch; axes 1 and 2 are treated as the two spatial axes
                (assumes channels-last layout — confirm against callers).
            ratio: Maximum shift as a fraction of each spatial dimension.

        Returns:
            The translated batch.
        """
        batch_size = tf.shape(x)[0]
        image_size = tf.shape(x)[1:3]
        # Maximum shift per axis, rounded to the nearest integer.
        shift = tf.cast(tf.cast(image_size, tf.float32) * ratio + 0.5, tf.int32)
        # One integer offset per image and axis, drawn from [-shift, shift].
        translation_x = tf.random.uniform([batch_size, 1], -shift[0], shift[0] + 1, dtype=tf.int32)
        translation_y = tf.random.uniform([batch_size, 1], -shift[1], shift[1] + 1, dtype=tf.int32)
        # Gather indices into a copy padded by one zero row on each side:
        # the +1 accounts for the padding, and clipping sends out-of-range
        # positions to the zero rows at index 0 / size + 1.
        grid_x = tf.clip_by_value(tf.expand_dims(tf.range(image_size[0], dtype=tf.int32), 0) + translation_x + 1, 0, image_size[0] + 1)
        grid_y = tf.clip_by_value(tf.expand_dims(tf.range(image_size[1], dtype=tf.int32), 0) + translation_y + 1, 0, image_size[1] + 1)
        # Shift along axis 1.
        x = tf.gather_nd(tf.pad(x, [[0, 0], [1, 1], [0, 0], [0, 0]]), tf.expand_dims(grid_x, -1), batch_dims=1)
        # Shift along axis 2 by swapping it into position 1, gathering, and swapping back.
        x = tf.transpose(tf.gather_nd(tf.pad(tf.transpose(x, [0, 2, 1, 3]), [[0, 0], [1, 1], [0, 0], [0, 0]]), tf.expand_dims(grid_y, -1), batch_dims=1), [0, 2, 1, 3])
        return x


    def rand_cutout(x, ratio=0.5):
        """Zero out one randomly placed rectangle in every image of the batch.

        Args:
            x: Image batch; axes 1 and 2 are treated as the two spatial axes
                (assumes channels-last layout — confirm against callers).
            ratio: Rectangle size as a fraction of each spatial dimension.

        Returns:
            The batch multiplied by a 0/1 mask that is 0 inside the rectangle.
        """
        batch_size = tf.shape(x)[0]
        image_size = tf.shape(x)[1:3]
        # Rectangle size per axis, rounded to the nearest integer.
        cutout_size = tf.cast(tf.cast(image_size, tf.float32) * ratio + 0.5, tf.int32)
        # Random rectangle centre for each image.
        offset_x = tf.random.uniform([tf.shape(x)[0], 1, 1], maxval=image_size[0] + (1 - cutout_size[0] % 2), dtype=tf.int32)
        offset_y = tf.random.uniform([tf.shape(x)[0], 1, 1], maxval=image_size[1] + (1 - cutout_size[1] % 2), dtype=tf.int32)
        # (batch, row, col) index of every cell of every rectangle.
        grid_batch, grid_x, grid_y = tf.meshgrid(tf.range(batch_size, dtype=tf.int32), tf.range(cutout_size[0], dtype=tf.int32), tf.range(cutout_size[1], dtype=tf.int32), indexing='ij')
        cutout_grid = tf.stack([grid_batch, grid_x + offset_x - cutout_size[0] // 2, grid_y + offset_y - cutout_size[1] // 2], axis=-1)
        mask_shape = tf.stack([batch_size, image_size[0], image_size[1]])
        # Clamp indices so rectangles overlapping the border stay inside the image.
        cutout_grid = tf.maximum(cutout_grid, 0)
        cutout_grid = tf.minimum(cutout_grid, tf.reshape(mask_shape - 1, [1, 1, 1, 3]))
        # scatter_nd sums values at duplicate (clamped) indices, so take max(1 - sum, 0)
        # to keep the mask in {0, 1}.
        mask = tf.maximum(1 - tf.scatter_nd(cutout_grid, tf.ones([batch_size, cutout_size[0], cutout_size[1]], dtype=tf.float32), mask_shape), 0)
        x = x * tf.expand_dims(mask, axis=3)
        return x


    # Policy name -> augmentations applied, in order, when DiffAugment is given that name.
    AUGMENT_FNS = {
        'color': [rand_brightness, rand_saturation, rand_contrast],
        'translation': [rand_translation],
        'cutout': [rand_cutout],
    }
    # To use in our code, we will create a wrapper function and incorporate it into our training step.
    def aug_fn(image):
         """Apply the colour, translation and cutout policies to a batch of images."""
         return DiffAugment(image,"color,translation,cutout")

# Building the CycleGAN
Our main CycleGAN class will be doing a lot: we will inherit from the `keras.Model` class in order to use the `.fit()` method when training our GAN.

In [None]:
class CycleGan(keras.Model):
    """CycleGAN training wrapper: two generators and two discriminators trained via ``fit()``.

    Args:
        target_generator: Stored as ``m_gen``; its outputs are scored by ``m_disc``.
        source_generator: Stored as ``p_gen``; its outputs are scored by ``p_disc``.
        target_discriminator: Stored as ``m_disc``.
        source_discriminator: Stored as ``p_disc``.
        lambda_cycle: Weight passed to the cycle-consistency and identity losses.
    """

    def __init__(
        self,
        target_generator,
        source_generator,
        target_discriminator,
        source_discriminator,
        lambda_cycle=10,
    ):
        super().__init__()
        self.m_gen = target_generator
        self.p_gen = source_generator
        self.m_disc = target_discriminator
        self.p_disc = source_discriminator
        self.lambda_cycle = lambda_cycle

    def compile(
        self,
        m_gen_optimizer,
        p_gen_optimizer,
        m_disc_optimizer,
        p_disc_optimizer,
        gen_loss_fn,
        disc_loss_fn,
        cycle_loss_fn,
        identity_loss_fn,
        jit_compile=False #disables XLA, might fix the problem?
    ):
        """Store the four optimizers and four loss functions used by train_step.

        ``jit_compile`` is forwarded to ``keras.Model.compile``; it used to be
        accepted here but silently ignored.
        """
        super().compile(jit_compile=jit_compile)
        self.m_gen_optimizer = m_gen_optimizer
        self.p_gen_optimizer = p_gen_optimizer
        self.m_disc_optimizer = m_disc_optimizer
        self.p_disc_optimizer = p_disc_optimizer
        self.gen_loss_fn = gen_loss_fn
        self.disc_loss_fn = disc_loss_fn
        self.cycle_loss_fn = cycle_loss_fn
        self.identity_loss_fn = identity_loss_fn

    def train_step(self, batch_data):
        """Run one optimisation step on all four networks.

        Args:
            batch_data: A 2-tuple of image batches, unpacked as
                ``(real_target, real_source)``.

        Returns:
            Dict with the two total generator losses and the two discriminator losses.

        NOTE(review): get_gan_dataset yields ``(source_batch, target_batch)``,
        so with that dataset ``real_target`` below holds source-domain images
        and ``real_source`` holds target-domain images. As a result ``p_gen``
        is the network trained to turn dataset-source images into
        dataset-target style. The unpacking order is kept because other cells
        rely on the resulting roles — confirm whether this is intended.
        """
        real_target, real_source = batch_data
        batch_size = tf.shape(real_target)[0]

        # persistent=True because four separate gradients are taken from this tape.
        with tf.GradientTape(persistent=True) as tape:
            # source to target back to source
            fake_target = self.m_gen(real_source, training=True)
            cycled_source = self.p_gen(fake_target, training=True)

            # target to source back to target
            fake_source = self.p_gen(real_target, training=True)
            cycled_target = self.m_gen(fake_source, training=True)

            # generating itself
            same_target = self.m_gen(real_target, training=True)
            same_source = self.p_gen(real_source, training=True)

            # Applying DiffAug: concatenate real and fake images, augment them
            # in one call, then split them again. Only the m_disc inputs are
            # augmented; p_disc sees unaugmented images.
            both_target = tf.concat([real_target, fake_target], axis=0)

            aug_target = aug_fn(both_target)

            aug_real_target = aug_target[:batch_size]
            aug_fake_target = aug_target[batch_size:]

            # discriminator used to check, inputting real images
            disc_real_target = self.m_disc(aug_real_target, training=True) # Use aug_real_target
            disc_real_source = self.p_disc(real_source, training=True)

            # discriminator used to check, inputting fake images
            disc_fake_target = self.m_disc(aug_fake_target, training=True) # Use aug_fake_target
            disc_fake_source = self.p_disc(fake_source, training=True)

            # evaluates generator loss
            target_gen_loss = self.gen_loss_fn(disc_fake_target)
            source_gen_loss = self.gen_loss_fn(disc_fake_source)

            # evaluates total cycle consistency loss
            total_cycle_loss = self.cycle_loss_fn(real_target, cycled_target, self.lambda_cycle) + self.cycle_loss_fn(real_source, cycled_source, self.lambda_cycle)

            # evaluates total generator loss
            total_target_gen_loss = target_gen_loss + total_cycle_loss + self.identity_loss_fn(real_target, same_target, self.lambda_cycle)
            total_source_gen_loss = source_gen_loss + total_cycle_loss + self.identity_loss_fn(real_source, same_source, self.lambda_cycle)

            # evaluates discriminator loss
            target_disc_loss = self.disc_loss_fn(disc_real_target, disc_fake_target)
            source_disc_loss = self.disc_loss_fn(disc_real_source, disc_fake_source)

        # Calculate the gradients for generator and discriminator
        target_generator_gradients = tape.gradient(total_target_gen_loss,
                                                  self.m_gen.trainable_variables)
        source_generator_gradients = tape.gradient(total_source_gen_loss,
                                                  self.p_gen.trainable_variables)

        target_discriminator_gradients = tape.gradient(target_disc_loss,
                                                      self.m_disc.trainable_variables)
        source_discriminator_gradients = tape.gradient(source_disc_loss,
                                                      self.p_disc.trainable_variables)

        # Apply the gradients to the optimizer
        self.m_gen_optimizer.apply_gradients(zip(target_generator_gradients,
                                                 self.m_gen.trainable_variables))

        self.p_gen_optimizer.apply_gradients(zip(source_generator_gradients,
                                                 self.p_gen.trainable_variables))

        self.m_disc_optimizer.apply_gradients(zip(target_discriminator_gradients,
                                                  self.m_disc.trainable_variables))

        self.p_disc_optimizer.apply_gradients(zip(source_discriminator_gradients,
                                                  self.p_disc.trainable_variables))

        return {
            "target_gen_loss": total_target_gen_loss,
            "source_gen_loss": total_source_gen_loss,
            "target_disc_loss": target_disc_loss,
            "source_disc_loss": source_disc_loss
        }

## Define Loss Functions

Each model requires a loss function. Our discriminator will compare real images to a matrix of 1s and fake images to a matrix of 0s, as a perfect discriminator will output pure 1s or 0s for real or fake. We will use the BinaryCrossEntropy loss function for this, and output the average of the real and generated loss:


In [None]:
with strategy.scope():
    def discriminator_loss(real, generated):
        """Half the sum of the BCE-from-logits losses on real (label 1) and generated (label 0) scores.

        The loss uses Reduction.NONE, so the result is not reduced to a scalar.
        """
        bce = tf.keras.losses.BinaryCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

        real_loss = bce(tf.ones_like(real), real)
        generated_loss = bce(tf.zeros_like(generated), generated)

        return (real_loss + generated_loss) * 0.5

Our generator loss is simpler, as we only check the generated images, which we want the discriminator to evaluate as a matrix of 1s:

In [None]:
with strategy.scope():
    def generator_loss(generated):
        """BCE-from-logits loss of the discriminator scores for generated images against all-ones labels."""
        bce = tf.keras.losses.BinaryCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
        all_real_labels = tf.ones_like(generated)
        return bce(all_real_labels, generated)

To calculate the cycle consistency loss, we compare our original image to the twice transformed one. They should be identical, so we can take the average of their difference to be the loss:

In [None]:
with strategy.scope():
    def calc_cycle_loss(real_image, cycled_image, LAMBDA):
        """Mean absolute difference between an image and its twice-translated version, scaled by LAMBDA."""
        mean_abs_error = tf.reduce_mean(tf.abs(real_image - cycled_image))
        return LAMBDA * mean_abs_error

We can also calculate the identity loss of our generators. For example, if we feed a source image into our source generator, the output should be identical to the original image.

In [None]:
with strategy.scope():
    def identity_loss(real_image, same_image, LAMBDA):
        """Mean absolute difference between an image and the generator's output for it, scaled by LAMBDA * 0.5."""
        weight = LAMBDA * 0.5
        mean_abs_error = tf.reduce_mean(tf.abs(real_image - same_image))
        return weight * mean_abs_error

# Training the GAN
The big moment! We compile our model, and as we inherited from the Keras.Model class, we can use the fit() method to train our model

In [None]:
with strategy.scope():
    # One independent Adam optimizer per network, all with learning rate 2e-4 and beta_1 = 0.5.
    target_generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
    source_generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

    target_discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
    source_discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)

In [None]:
with strategy.scope():
    # Positional order follows CycleGan.__init__: the first generator and first
    # discriminator become m_gen / m_disc, the second pair p_gen / p_disc.
    cycle_gan_model = CycleGan(
        target_generator, source_generator, target_discriminator, source_discriminator
    )

    # Each optimizer is paired with the network stored under the matching attribute.
    cycle_gan_model.compile(
        m_gen_optimizer = target_generator_optimizer,
        p_gen_optimizer = source_generator_optimizer,
        m_disc_optimizer = target_discriminator_optimizer,
        p_disc_optimizer = source_discriminator_optimizer,
        gen_loss_fn = generator_loss,
        disc_loss_fn = discriminator_loss,
        cycle_loss_fn = calc_cycle_loss,
        identity_loss_fn = identity_loss
    )

Remember, as we have augmented our dataset, we will need to reduce the size

In [None]:
def count_data_items(filenames):
    """Sum the record counts embedded in TFRecord file names.

    Each name must contain ``-<count>.`` (for instance ``photo00-352.tfrec``);
    the first such occurrence in the name is used.

    Args:
        filenames: Iterable of file names or paths.

    Returns:
        The total count as a plain Python int (0 for an empty iterable).

    Raises:
        ValueError: If a file name does not contain a ``-<count>.`` part.
    """
    # `+` rather than `*` so an empty digit run is never accepted as a count.
    count_pattern = re.compile(r"-([0-9]+)\.")

    total = 0
    for filename in filenames:
        match = count_pattern.search(filename)
        if match is None:
            raise ValueError(f"Cannot read the record count from file name: {filename!r}")
        total += int(match.group(1))
    return total

n_target_samples = count_data_items(TARGET_FILENAMES)
n_source_samples = count_data_items(SOURCE_FILENAMES)

# Cast to a plain int so steps_per_epoch is never a numpy integer.
# NOTE(review): the divisor 4 is unrelated to BATCH_SIZE (10), so one "epoch"
# covers about 2.5 passes over the larger dataset — confirm this is intended.
epoch_steps = int(max(n_target_samples, n_source_samples) // 4)
#epoch_steps = 600

cycle_gan_model.fit(
    full_dataset,
    epochs=EPOCHS_NUM,
    steps_per_epoch=epoch_steps,
)

# Evaluate the generator that produces target-style images from source images.
# With the batch order that CycleGan.train_step uses on full_dataset, that is
# source_generator; target_generator (used here before) is trained in the
# opposite direction, so scoring its output against the target statistics was
# not the intended metric.
print(FID(fid_source_ds, source_generator))

# Visualise some outputs

In [None]:
# Show five source images next to their translated versions.
_, ax = plt.subplots(5, 2, figsize=(12, 12))
for i, img in enumerate(source_ds.take(5)):
    # Fixed: `photo_generator` was never defined (NameError). source_generator
    # is the network that maps source images to the target style.
    prediction = source_generator(img, training=False)[0].numpy()
    # Map [-1, 1] back to displayable 0..255 uint8 pixels.
    prediction = (prediction * 127.5 + 127.5).astype(np.uint8)
    img = (img[0] * 127.5 + 127.5).numpy().astype(np.uint8)

    ax[i, 0].imshow(img)
    ax[i, 1].imshow(prediction)
    ax[i, 0].set_title("Input image")
    ax[i, 1].set_title("Target style image")
    ax[i, 0].axis("off")
    ax[i, 1].axis("off")
plt.show()

# Save the model
TODO