Note: The markdown texts are my notes from understanding the code blocks and from ChatGPT's explanations of some of the blocks
- Kaggle competition "[I'm Something of a Painter Myself](https://www.kaggle.com/competitions/gan-getting-started/overview)"
- Original file copied in Feb. 2024 from Saravana Kumar, which was copied from UnfriendlyAI

# GAN
- **GAN** = Generative Adversarial Network
- Comprised of:
    1. **Generator**: takes random noise as input and generates images that resemble the training data
    2. **Discriminator**: like a classifier that determines between real images of Monet from the training data and replications from the generator
        - Trained on real and fake to learn to differentiate between them
- Training a GAN: **Adversarial training**
    - Generator tries to create realistic images that trick the discriminator, and the discriminator tries to catch the fakes from the generator
    - Both improve iteratively

## Two-Objective Discriminator
- Used with adversarial training
- Two discriminators in this GAN
    1. Real vs. Fake Discriminator: distinguishes between real images from dataset and fake ones from generator
    2. Attribute/Class Discriminator: focuses on a specific attribute/class of images, discriminating between images with a certain style/color/etc.
- Constraining the generator
    - Must generate images that meet criteria of *both* discriminators

In [None]:
BATCH_SIZE = 15 # Training batch size (originally 32; lowered to 15 to help with the OOM error received later on)

# Importing necessary packages
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa

import matplotlib.pyplot as plt
import numpy as np
import re
# kaggle_datasets only exists inside Kaggle kernels; outside Kaggle the
# import is skipped so the rest of the notebook can still be loaded.
try:
    from kaggle_datasets import KaggleDatasets
except ImportError:
    # Only a missing package is tolerated; any other failure should surface.
    pass

In [None]:
# Checking the Runtime TPU
# Try to connect to a TPU; if that fails for any reason, fall back to the
# default distribution strategy.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    print('Device:', tpu.master())
    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)
    strategy = tf.distribute.experimental.TPUStrategy(tpu)
except Exception as tpu_error:
    # `Exception` (not a bare except) so Ctrl-C / SystemExit are not swallowed,
    # and the reason for the fallback is reported instead of hidden.
    print('TPU not available, using the default strategy:', tpu_error)
    strategy = tf.distribute.get_strategy()
print('Number of replicas:', strategy.num_replicas_in_sync)

# Two aliases for the same tf.data autotune constant; both names are used below.
AUTOTUNE = tf.data.experimental.AUTOTUNE
AUTO = tf.data.experimental.AUTOTUNE
print(tf.__version__)

In [None]:
# Locate the competition data on GCS and list the TFRecord shards of each
# domain (Monet paintings and photos).
GCS_PATH = KaggleDatasets().get_gcs_path("gan-getting-started")

MONET_FILENAMES = tf.io.gfile.glob(
    str(GCS_PATH + '/monet_tfrec/*.tfrec')
)
PHOTO_FILENAMES = tf.io.gfile.glob(
    str(GCS_PATH + '/photo_tfrec/*.tfrec')
)

# Number of shard files found for each domain.
for shard_list in (MONET_FILENAMES, PHOTO_FILENAMES):
    print(len(shard_list))

### def data_augment_flip
- Function that flips images with a 50% chance
    - Each image has a 50% chance of being flipped horizontally
    - Increases randomness => increases diversity in training data
    - Improves discriminator's performance

In [None]:
def data_augment_flip(image):
    """Return `image` after a random left-right flip (tf.image.random_flip_left_right)."""
    return tf.image.random_flip_left_right(image)

### def decode_image & def read_tfrecord
- Used to decode JPEG images from TFRecord format
- Resizes, normalizes pixel values, prepares them for training

In [None]:
IMAGE_SIZE = [256, 256]

def decode_image(image):
    """Decode JPEG bytes into a float32 image tensor scaled to [-1, 1].

    The decoded 3-channel pixels (0..255) are divided by 127.5 and shifted
    by -1, then reshaped to [*IMAGE_SIZE, 3].
    """
    pixels = tf.cast(tf.image.decode_jpeg(image, channels=3), tf.float32)
    scaled = (pixels / 127.5) - 1
    return tf.reshape(scaled, [*IMAGE_SIZE, 3])

def read_tfrecord(example):
    """Parse one serialized TFRecord example and return its decoded image.

    Only the "image" feature (a scalar string holding the JPEG bytes) is read.
    """
    feature_spec = {"image": tf.io.FixedLenFeature([], tf.string)}
    parsed = tf.io.parse_single_example(example, feature_spec)
    return decode_image(parsed['image'])

### def load_dataset
- Takes a list of TFRecord filenames
- Sends each record to the previous function (read_tfrecord) => decodes the JPEG bytes into an image tensor
- Returns Tensorflow Dataset object

### ds Variables
- Loading Monet dataset
    - Batching the Monet dataset
- Loading photo dataset w. larger batch size
    - Batching the photo dataset
- Basically preprocessing datasets and batching to improve training performance

In [None]:
def load_dataset(filenames):
    """Build an unbatched tf.data dataset of decoded images from TFRecord files."""
    records = tf.data.TFRecordDataset(filenames)
    return records.map(read_tfrecord, num_parallel_calls=AUTOTUNE)

# One image per batch for each domain.
monet_ds = load_dataset(MONET_FILENAMES).batch(1)
photo_ds = load_dataset(PHOTO_FILENAMES).batch(1)


# Larger batches (32 images per replica) with prefetching.
fast_photo_ds = (
    load_dataset(PHOTO_FILENAMES)
    .batch(32*strategy.num_replicas_in_sync)
    .prefetch(32)
)
# Only the first 1024 photos are kept for the FID photo set.
fid_photo_ds = (
    load_dataset(PHOTO_FILENAMES)
    .take(1024)
    .batch(32*strategy.num_replicas_in_sync)
    .prefetch(32)
)
fid_monet_ds = (
    load_dataset(MONET_FILENAMES)
    .batch(32*strategy.num_replicas_in_sync)
    .prefetch(32)
)

### def get_gan_dataset
- Preprocessing the Monet & photo datasets, mapping vectors to train on information in photos later
- Zips Monet & photo datasets together (makes them a pair to train)
- Combining Monet and photo datasets

In [None]:
def get_gan_dataset(monet_files, photo_files, augment=None, repeat=True, shuffle=True, batch_size=1):
    """Build the zipped (monet_batch, photo_batch) training dataset.

    Both domains go through the same stages, in this order: optional
    repeat, optional shuffle (buffer of 2048), batching with
    drop_remainder=True, optional `augment` map (applied to whole batches),
    and prefetch.

    Returns:
        A tf.data.Dataset yielding (monet_batch, photo_batch) tuples.
    """
    # Index 0 is the Monet stream, index 1 the photo stream; every stage is
    # applied to Monet first and photo second.
    streams = [load_dataset(monet_files), load_dataset(photo_files)]

    if repeat:
        streams = [ds.repeat() for ds in streams]
    if shuffle:
        streams = [ds.shuffle(2048) for ds in streams]

    streams = [ds.batch(batch_size, drop_remainder=True) for ds in streams]

    if augment:
        streams = [ds.map(augment, num_parallel_calls=AUTO) for ds in streams]

    streams = [ds.prefetch(AUTO) for ds in streams]

    return tf.data.Dataset.zip(tuple(streams))

### def get_photo_dataset
- Basically same function as above, except this is just for photo preprocessing
- Does not zip with Monet images

In [None]:
def get_photo_dataset(photo_files, augment=None, repeat=False, shuffle=False, batch_size=1):
    """Build a photo-only dataset with the same stages as get_gan_dataset.

    Stages, in order: optional repeat, optional shuffle (buffer of 2048),
    batching with drop_remainder=True, optional `augment` map, prefetch.
    """
    pipeline = load_dataset(photo_files)
    pipeline = pipeline.repeat() if repeat else pipeline
    pipeline = pipeline.shuffle(2048) if shuffle else pipeline
    pipeline = pipeline.batch(batch_size, drop_remainder=True)

    if augment:
        pipeline = pipeline.map(augment, num_parallel_calls=AUTO)

    return pipeline.prefetch(AUTO)



In [None]:
# Training dataset: each element is a (monet_batch, photo_batch) pair of
# randomly flipped, shuffled, endlessly repeated batches of BATCH_SIZE images.
final_dataset = get_gan_dataset(
    MONET_FILENAMES,
    PHOTO_FILENAMES,
    augment=data_augment_flip,
    repeat=True,
    shuffle=True,
    batch_size=BATCH_SIZE,
)

# FID (Frechet Inception Distance): How performance is scored in this Kaggle Competition
- Calculates FID scores (feature extraction with InceptionV3)

In [None]:
with strategy.scope():

    # Feature extractor for FID: InceptionV3 cut at the "mixed9" layer and
    # followed by global average pooling; frozen (never trained).
    inception_model = tf.keras.applications.InceptionV3(
        input_shape=(256,256,3),
        pooling="avg",
        include_top=False,
    )
    mix3 = inception_model.get_layer("mixed9").output
    f0 = tf.keras.layers.GlobalAveragePooling2D()(mix3)
    inception_model = tf.keras.Model(inputs=inception_model.input, outputs=f0)
    inception_model.trainable = False

    def calculate_activation_statistics_mod(images,fid_model):
        """Return (mu, sigma) of the activations `fid_model` predicts for `images`.

        mu is the mean over axis 0; sigma is E[x x^T] - mu mu^T, i.e. the
        covariance normalised by the number of samples (not samples - 1).
        """
        features = tf.cast(fid_model.predict(images), tf.float32)
        sample_count = tf.cast(tf.shape(features)[0], tf.float32)

        row_mean = tf.reduce_mean(features, axis=0, keepdims=True)
        outer_mean = tf.matmul(tf.transpose(row_mean), row_mean)
        second_moment = tf.matmul(tf.transpose(features), features)/sample_count

        return tf.reduce_mean(features, axis=0), second_moment - outer_mean

    # Reference statistics of the real Monet paintings, computed once.
    myFID_mu2, myFID_sigma2 = calculate_activation_statistics_mod(fid_monet_ds,inception_model)
    fids=[]


In [None]:
with strategy.scope():
    def calculate_frechet_distance(mu1,sigma1,mu2,sigma2):
        """Frechet distance between two Gaussians (mu1, sigma1) and (mu2, sigma2).

        Computes |mu1 - mu2|^2 + Tr(sigma1) + Tr(sigma2) - 2 Tr(sqrt(sigma1 sigma2)).
        The matrix square root is taken in complex64 and only its real part is
        kept. The squared mean distance comes from a matmul of a row and a
        column vector, so for 1-D means the result is a 1x1 tensor.
        """
        mean_diff = mu1 - mu2

        covmean = tf.linalg.sqrtm(tf.cast(tf.matmul(sigma1,sigma2),tf.complex64))
        covmean = tf.cast(tf.math.real(covmean),tf.float32)
        tr_covmean = tf.linalg.trace(covmean)

        squared_mean_distance = tf.matmul(tf.expand_dims(mean_diff, axis=0),tf.expand_dims(mean_diff, axis=1))

        return squared_mean_distance + tf.linalg.trace(sigma1) + tf.linalg.trace(sigma2) - 2 * tr_covmean


    def FID(images,gen_model,inception_model=inception_model,myFID_mu2=myFID_mu2, myFID_sigma2=myFID_sigma2):
        """FID between `gen_model`'s outputs on `images` and the reference statistics.

        Args:
            images: input passed to `predict` of the chained model.
            gen_model: generator applied to each 256x256x3 input.
            inception_model: feature extractor applied to the generator output.
            myFID_mu2, myFID_sigma2: reference mean and covariance (default to
                the statistics bound when this function was defined).

        Returns:
            The value returned by calculate_frechet_distance.
        """
        # Chain generator -> feature extractor so a single predict() call
        # yields the features of the generated images.
        inp = layers.Input(shape=[256, 256, 3], name='input_image')
        x = gen_model(inp)
        x = inception_model(x)
        fid_model = tf.keras.Model(inputs=inp, outputs=x)

        mu1, sigma1 = calculate_activation_statistics_mod(images,fid_model)

        fid_value = calculate_frechet_distance(mu1, sigma1,myFID_mu2, myFID_sigma2)

        return fid_value

### def down_sample
- Creates downsampling layer for CNN (convolutional neural network) (used in the CycleGAN later)
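- Reduces the spatial resolution of the images (the stride-2 convolution halves height and width) to extract feature information
    - Used in the discriminators and in the encoder half of the generator
- Combines a convolutional layer with (optional) instance normalization and a LeakyReLU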

In [None]:
OUTPUT_CHANNELS = 3

def down_sample(filters, size, apply_instancenorm=True):
    """Return a Sequential block: stride-2 Conv2D -> [InstanceNorm] -> LeakyReLU.

    Args:
        filters: number of convolution filters.
        size: convolution kernel size.
        apply_instancenorm: when False the InstanceNormalization layer is left out.
    """
    kernel_init = tf.random_normal_initializer(0., 0.02)
    norm_gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    block = keras.Sequential()
    block.add(
        layers.Conv2D(
            filters,
            size,
            strides=2,
            padding='same',
            kernel_initializer=kernel_init,
            use_bias=False,
        )
    )
    if apply_instancenorm:
        block.add(tfa.layers.InstanceNormalization(gamma_initializer=norm_gamma_init))
    block.add(layers.LeakyReLU())
    return block

### def up_sample
- Creates upsampling layer for CNN
- Increases quality/dimensions of photos 
    - Used in generator
- Combines transposed convolutional layer w. normalization

In [None]:
def up_sample(filters, size, apply_dropout=False):
    """Return a Sequential block: stride-2 Conv2DTranspose -> InstanceNorm -> [Dropout 0.5] -> ReLU.

    Args:
        filters: number of transposed-convolution filters.
        size: kernel size.
        apply_dropout: when True a Dropout(0.5) layer is inserted before the ReLU.
    """
    kernel_init = tf.random_normal_initializer(0., 0.02)
    norm_gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    block = keras.Sequential()
    block.add(
        layers.Conv2DTranspose(
            filters,
            size,
            strides=2,
            padding='same',
            kernel_initializer=kernel_init,
            use_bias=False,
        )
    )
    block.add(tfa.layers.InstanceNormalization(gamma_initializer=norm_gamma_init))
    if apply_dropout:
        block.add(layers.Dropout(0.5))
    block.add(layers.ReLU())
    return block

# Generator
- Image to Image translation
- Downsampling layers --> Encoder
    - down_stack
    - Progressively reduces the spatial dimensions of the image
    - Extracts features
- Upsampling layers --> Decoder
    - up_stack (sequence of upsampling layers)
    - Increases the spatial dimensions
    - Reconstructs the image from the extracted features
- Generates realistic Monet images

In [None]:
def Generator(output_channels=OUTPUT_CHANNELS):
    """Build the U-Net style image-to-image generator.

    Eight stride-2 down_sample blocks encode a 256x256x3 image down to
    1x1x512; seven up_sample blocks decode it again, each followed by a
    concatenation with the matching encoder output (skip connection); a final
    transposed convolution with tanh produces the output image.

    Args:
        output_channels: number of channels of the generated image. Defaults
            to the module constant OUTPUT_CHANNELS (previously hard-coded as 3).

    Returns:
        A keras.Model mapping (batch, 256, 256, 3) to
        (batch, 256, 256, output_channels).
    """
    inputs = layers.Input(shape=[256,256,3])
    down_stack = [
        down_sample(64, 4, apply_instancenorm=False),# (size, 128, 128, 64)
        down_sample(128, 4),                         # (size, 64, 64, 128)
        down_sample(256, 4),                         # (size, 32, 32, 256)
        down_sample(512, 4),                         # (size, 16, 16, 512)
        down_sample(512, 4),                         # (size, 8, 8, 512)
        down_sample(512, 4),                         # (size, 4, 4, 512)
        down_sample(512, 4),                         # (size, 2, 2, 512)
        down_sample(512, 4),                         # (size, 1, 1, 512)
    ]

    # Shapes below are after concatenation with the skip connection.
    up_stack = [
        up_sample(512, 4, apply_dropout=True),       # (size, 2, 2, 1024)
        up_sample(512, 4, apply_dropout=True),       # (size, 4, 4, 1024)
        up_sample(512, 4, apply_dropout=True),       # (size, 8, 8, 1024)
        up_sample(512, 4),                           # (size, 16, 16, 1024)
        up_sample(256, 4),                           # (size, 32, 32, 512)
        up_sample(128, 4),                           # (size, 64, 64, 256)
        up_sample(64, 4),                            # (size, 128, 128, 128)
    ]

    initializer = tf.random_normal_initializer(0., 0.02)
    last = layers.Conv2DTranspose(output_channels, 4, strides=2, padding='same', kernel_initializer=initializer, activation='tanh')
    # (size, 256, 256, output_channels)

    x = inputs

    # Downsampling through the model
    skips = []
    for down in down_stack:
        x = down(x)
        skips.append(x)

    # The bottleneck output is not a skip; the rest are used deepest-first.
    skips = reversed(skips[:-1])

    # Upsampling and establishing the skip connections
    for up, skip in zip(up_stack, skips):
        x = up(x)
        x = layers.Concatenate()([x, skip])

    x = last(x)

    return keras.Model(inputs=inputs, outputs=x)

# Discriminator (Headless)
- Takes images
- Downsampling layers --> Encoder
- Zero padding (x2) & convolutional layer
- Constructs Keras Model w. image and outputs a zero-padded Leaky ReLU output
- Discriminates between real Monet paintings and fake reproductions from GAN's Generator (above)

- **GOAL** = for the discriminator to become *so* good at distinguishing real and fake Monet images that the generator has trouble fooling it
    - When discriminator is good, the generator is also (results from the adversarial training: blessing & curse)
        - Issues that arise include the generator becoming good at producing only a few kinds of output and ignoring the rest of the variety in the data (mode collapse)

## Headless?
- Means that the network doesn't have a final classification or output layer that produces specific prediction/decision
- Outputs intermediate features that are used later for processing/analysis
- This discriminator doesn't make a final output layer (=> binary classification: real or fake)
    - Feature extraction
    - **DOESN'T MEMORIZE CLASSIFICATION RESULTS, BUT INSTEAD LEARNS THE FEATURES THAT MAKE AN IMAGE REAL/FAKE**
        - This is a lot better, because it won't just memorize the results; it will actually learn

In [None]:
# Headless Monet discriminator backbone: turns an image into a feature map.
# The real/fake scores are produced by separate heads (see DHead).
def Discriminator():
    """Build the headless discriminator: 256x256x3 image -> (33, 33, 512) features."""
    kernel_init = tf.random_normal_initializer(0., 0.02)
    norm_gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    image_in = layers.Input(shape=[256, 256, 3], name='input_image')
    features = image_in

    # Three stride-2 blocks: 256 -> 128 -> 64 -> 32; no norm on the first one.
    for n_filters, use_norm in ((64, False), (128, True), (256, True)):
        features = down_sample(n_filters, 4, use_norm)(features)

    features = layers.ZeroPadding2D()(features)  # (size, 34, 34, 256)
    features = layers.Conv2D(
        512, 4, strides=1, kernel_initializer=kernel_init, use_bias=False
    )(features)  # (size, 31, 31, 512)
    features = tfa.layers.InstanceNormalization(gamma_initializer=norm_gamma_init)(features)
    features = layers.LeakyReLU()(features)
    features = layers.ZeroPadding2D()(features)  # (size, 33, 33, 512)

    # No prediction layer here = headless
    return tf.keras.Model(inputs=image_in, outputs=features)

## Head for two-objective discriminator for Monet
- Defines discriminator head = part of discriminator
- takes the feature maps from discriminator and => decision/score whether input is real/fake

In [None]:
def DHead():
    """Build a discriminator head: (33, 33, 512) features -> (30, 30, 1) score map."""
    kernel_init = tf.random_normal_initializer(0., 0.02)

    features_in = layers.Input(shape=[33, 33, 512], name='input_image')
    scores = layers.Conv2D(
        1, 4, strides=1, kernel_initializer=kernel_init
    )(features_in)  # (size, 30, 30, 1)

    return tf.keras.Model(inputs=features_in, outputs=scores)

## Why headless and then define a head?? Why not just have a head-discriminator?

1. Headless: 
    - Lacks the final classification layer; it is the shared feature extractor (backbone) of the Monet discriminator
    - Doesn't make any final decisions by itself: it outputs a (33, 33, 512) feature map

2. Head
    - A small convolutional layer (`DHead`) that turns the backbone's feature map into a (30, 30, 1) grid of real/fake scores
    - Two heads (`dHead1`, `dHead2`) sit on the same backbone, and each one is trained with its own loss function
    - In this notebook the heads are not added later: `train_step` runs the backbone and both heads together from the very first training step
        
        
- - - - - - 
- Basically, splitting the discriminator into a headless backbone plus heads lets one set of learned features feed two different objectives (the "two-objective discriminator" described above)
- The generator then has to satisfy both heads at the same time, which constrains it more than a single real/fake score would

## Discriminator for Photos
- Analyzes the features specifically in photos
- Decides if it's real/fake

In [None]:
def DiscriminatorP():
    """Build the photo discriminator: 256x256x3 image -> (30, 30, 1) score map.

    Same layer stack as the headless Discriminator, with the final
    1-filter convolution (the head) built in.
    """
    kernel_init = tf.random_normal_initializer(0., 0.02)
    norm_gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    image_in = layers.Input(shape=[256, 256, 3], name='input_image')

    # Layers are created in the order they are applied.
    out = down_sample(64, 4, False)(image_in)   # (size, 128, 128, 64)
    out = down_sample(128, 4)(out)              # (size, 64, 64, 128)
    out = down_sample(256, 4)(out)              # (size, 32, 32, 256)

    out = layers.ZeroPadding2D()(out)           # (size, 34, 34, 256)
    out = layers.Conv2D(
        512, 4, strides=1, kernel_initializer=kernel_init, use_bias=False
    )(out)                                      # (size, 31, 31, 512)

    out = tfa.layers.InstanceNormalization(gamma_initializer=norm_gamma_init)(out)
    out = layers.LeakyReLU()(out)
    out = layers.ZeroPadding2D()(out)           # (size, 33, 33, 512)
    out = layers.Conv2D(
        1, 4, strides=1, kernel_initializer=kernel_init
    )(out)                                      # (size, 30, 30, 1)

    return tf.keras.Model(inputs=image_in, outputs=out)

## Why have two discriminators?
- They're each tailored to different input types
- Better analyze and decide authenticity 

In [None]:
# All models are created inside the distribution strategy scope.
with strategy.scope():
    monet_generator = Generator() # transforms photos to Monet-esque paintings
    photo_generator = Generator() # transforms Monet paintings to be more like photos

    monet_discriminator = Discriminator() # headless backbone: features of real / generated Monet paintings
    photo_discriminator = DiscriminatorP() # differentiates real photos and generated photos
    # Two heads on top of the Monet backbone. In CycleGan.train_step the first
    # head is scored with gen_loss_fn1/disc_loss_fn1 and the second with
    # gen_loss_fn2/disc_loss_fn2.
    # NOTE(review): the labels below are the original author's; which loss each
    # head really gets depends on the arguments passed to compile() - confirm.
    dHead1 = DHead() # Head for BCE
    dHead2 = DHead() # Head for hinge loss


# CycleGAN with DiffAugment and two-objective discriminator
- Training logic for the CycleGAN
    - Compiles
    - Train step method: custom, calculates losses and gradients

In [None]:
class CycleGan(keras.Model):
    """CycleGAN trainer with DiffAugment and a two-head Monet discriminator.

    Holds two generators (photo->Monet `m_gen`, Monet->photo `p_gen`), a
    headless Monet discriminator backbone `m_disc` with two heads
    (`dhead1`, `dhead2`), and a photo discriminator `p_disc`.

    Args:
        monet_generator: model turning photos into Monet-style images.
        photo_generator: model turning Monet images into photo-style images.
        monet_discriminator: headless backbone producing features for the heads.
        photo_discriminator: full discriminator for photos.
        dhead1: head scored with gen_loss_fn1 / disc_loss_fn1.
        dhead2: head scored with gen_loss_fn2 / disc_loss_fn2.
        lambda_cycle: weight of the cycle-consistency loss.
        lambda_id: weight of the identity loss.
    """

    def __init__(
        self,
        monet_generator,
        photo_generator,
        monet_discriminator,
        photo_discriminator,
        dhead1,
        dhead2,
        lambda_cycle=3,
        lambda_id=3,
    ):
        super(CycleGan, self).__init__()
        self.m_gen = monet_generator
        self.p_gen = photo_generator
        self.m_disc = monet_discriminator
        self.p_disc = photo_discriminator
        self.lambda_cycle = lambda_cycle
        self.lambda_id = lambda_id
        self.dhead1 = dhead1
        self.dhead2 = dhead2

    def compile(
        self,
        m_gen_optimizer,
        p_gen_optimizer,
        m_disc_optimizer,
        p_disc_optimizer,
        gen_loss_fn1,
        gen_loss_fn2,
        disc_loss_fn1,
        disc_loss_fn2,
        cycle_loss_fn,
        identity_loss_fn,
        aug_fn,

    ):
        """Store the optimizers, loss functions and augmentation function.

        `m_disc_optimizer` is used for the Monet backbone and for both heads.
        `gen_loss_fn1` / `disc_loss_fn1` are also used for the photo discriminator.
        `aug_fn` is applied to the concatenated real+fake Monet batch.
        """
        super(CycleGan, self).compile()
        self.m_gen_optimizer = m_gen_optimizer
        self.p_gen_optimizer = p_gen_optimizer
        self.m_disc_optimizer = m_disc_optimizer
        self.p_disc_optimizer = p_disc_optimizer
        self.gen_loss_fn1 = gen_loss_fn1
        self.gen_loss_fn2 = gen_loss_fn2
        self.disc_loss_fn1 = disc_loss_fn1
        self.disc_loss_fn2 = disc_loss_fn2
        self.cycle_loss_fn = cycle_loss_fn
        self.identity_loss_fn = identity_loss_fn
        self.aug_fn = aug_fn

        self.step_num = 0

    def train_step(self, batch_data):
        """Run one training step on a (real_monet, real_photo) batch pair.

        Returns:
            A dict with the head losses, the raw head outputs for real and
            fake Monet images, the Monet generator loss and the photo
            discriminator loss.
        """
        real_monet, real_photo = batch_data
        batch_size = tf.shape(real_monet)[0]
        # Loss weights are divided by the batch size because the cycle and
        # identity loss functions receive them as their LAMBDA argument.
        batch_size_f = tf.cast(batch_size, tf.float32)
        cycle_weight = self.lambda_cycle / batch_size_f
        identity_weight = self.lambda_id / batch_size_f

        # Persistent tape: several gradients are taken from the same forward pass.
        with tf.GradientTape(persistent=True) as tape:

            # photo to monet back to photo
            fake_monet = self.m_gen(real_photo, training=True)
            cycled_photo = self.p_gen(fake_monet, training=True)

            # monet to photo back to monet
            fake_photo = self.p_gen(real_monet, training=True)
            cycled_monet = self.m_gen(fake_photo, training=True)

            # generating itself
            same_monet = self.m_gen(real_monet, training=True)
            same_photo = self.p_gen(real_photo, training=True)

            # Diffaugment: real and fake Monet images are augmented in one
            # call on the concatenated batch, then split again.
            both_monet = tf.concat([real_monet, fake_monet], axis=0)

            aug_monet = self.aug_fn(both_monet)

            aug_real_monet = aug_monet[:batch_size]
            aug_fake_monet = aug_monet[batch_size:]

            # two-objective discriminator
            # The backbone is run once per input and its features are shared
            # by both heads (previously it was run twice per input, doubling
            # the activations kept alive by the persistent tape).
            # Assumes the backbone is deterministic in training mode, which
            # holds for Discriminator() in this file (no dropout).
            fake_monet_features = self.m_disc(aug_fake_monet, training=True)
            real_monet_features = self.m_disc(aug_real_monet, training=True)

            disc_fake_monet1 = self.dhead1(fake_monet_features, training=True)
            disc_real_monet1 = self.dhead1(real_monet_features, training=True)
            disc_fake_monet2 = self.dhead2(fake_monet_features, training=True)
            disc_real_monet2 = self.dhead2(real_monet_features, training=True)

            monet_gen_loss1 = self.gen_loss_fn1(disc_fake_monet1)
            monet_head_loss1 = self.disc_loss_fn1(disc_real_monet1, disc_fake_monet1)
            monet_gen_loss2 = self.gen_loss_fn2(disc_fake_monet2)
            monet_head_loss2 = self.disc_loss_fn2(disc_real_monet2, disc_fake_monet2)

            monet_gen_loss = (monet_gen_loss1 + monet_gen_loss2) * 0.4
            monet_disc_loss = monet_head_loss1 + monet_head_loss2

            # photo discriminator on real images (no DiffAugment on this side)
            disc_real_photo = self.p_disc(real_photo, training=True)
            # photo discriminator on generated images
            disc_fake_photo = self.p_disc(fake_photo, training=True)

            # evaluates generator loss
            photo_gen_loss = self.gen_loss_fn1(disc_fake_photo)

            # evaluates discriminator loss
            photo_disc_loss = self.disc_loss_fn1(disc_real_photo, disc_fake_photo)

            # evaluates total cycle-consistency loss (both directions)
            total_cycle_loss = self.cycle_loss_fn(real_monet, cycled_monet, cycle_weight) + self.cycle_loss_fn(real_photo, cycled_photo, cycle_weight)

            # evaluates total generator loss
            total_monet_gen_loss = monet_gen_loss + total_cycle_loss + self.identity_loss_fn(real_monet, same_monet, identity_weight)
            total_photo_gen_loss = photo_gen_loss + total_cycle_loss + self.identity_loss_fn(real_photo, same_photo, identity_weight)

        # Calculate the gradients for generator and discriminator
        monet_generator_gradients = tape.gradient(total_monet_gen_loss,
                                                  self.m_gen.trainable_variables)
        photo_generator_gradients = tape.gradient(total_photo_gen_loss,
                                                  self.p_gen.trainable_variables)

        monet_discriminator_gradients = tape.gradient(monet_disc_loss,
                                                      self.m_disc.trainable_variables)
        photo_discriminator_gradients = tape.gradient(photo_disc_loss,
                                                      self.p_disc.trainable_variables)

        # Heads gradients: each head is updated only from its own loss.
        # Note m_disc_optimizer.apply_gradients is called three times per
        # step (head 1, head 2, backbone); that order is kept as in the original.
        monet_head_gradients = tape.gradient(monet_head_loss1,
                                             self.dhead1.trainable_variables)

        self.m_disc_optimizer.apply_gradients(zip(monet_head_gradients,
                                                  self.dhead1.trainable_variables))

        monet_head_gradients = tape.gradient(monet_head_loss2,
                                             self.dhead2.trainable_variables)
        self.m_disc_optimizer.apply_gradients(zip(monet_head_gradients,
                                                  self.dhead2.trainable_variables))

        # Apply the gradients to the optimizer
        self.m_gen_optimizer.apply_gradients(zip(monet_generator_gradients,
                                                 self.m_gen.trainable_variables))

        self.p_gen_optimizer.apply_gradients(zip(photo_generator_gradients,
                                                 self.p_gen.trainable_variables))

        self.m_disc_optimizer.apply_gradients(zip(monet_discriminator_gradients,
                                                  self.m_disc.trainable_variables))

        self.p_disc_optimizer.apply_gradients(zip(photo_discriminator_gradients,
                                                  self.p_disc.trainable_variables))

        return {
            "monet_head_loss1": monet_head_loss1,
            "monet_head_loss2": monet_head_loss2,
            "disc_real_monet": disc_real_monet1,
            "disc_fake_monet": disc_fake_monet1,
            "disc_real_monet2": disc_real_monet2,
            "disc_fake_monet2": disc_fake_monet2,
            "monet_gen_loss": monet_gen_loss,
            "photo_disc_loss": photo_disc_loss,
            }


## Two loss functions for discriminator
1. Loss function #1: uses hinge loss formulation
    - Takes real and generated data as inputs, and penalizes the discriminator when it classifies them incorrectly
    - Total loss = sum of losses
    
2. Loss function #2: uses Binary Cross Entropy (BCE) loss
    - Handles logits (unnormalized outputs) and applies sigmoid activation internally
    - Same as the other loss function: penalizes when it classifies incorrectly

- - - - 
- We have two loss functions defined because they are both helpful in different areas of training.

### Binary Cross Entropy Loss Explanation
- Binary classification tasks
- Measures the difference between 2 probability distributions: predicted probabilities & actual target probabilities

In [None]:
with strategy.scope(): # for TPU

    def discriminator_loss1(real, generated):
        """Hinge discriminator loss on raw scores.

        Real scores are penalised when below 1, generated scores when above
        -1; the two terms are summed and the mean of half the total is returned.
        """
        real_term = tf.math.minimum(tf.zeros_like(real), real-tf.ones_like(real))
        fake_term = tf.math.minimum(tf.zeros_like(generated), -generated-tf.ones_like(generated))

        return tf.reduce_mean(-(real_term + fake_term) * 0.5)

    def discriminator_loss2(real, generated):
        """Binary cross-entropy discriminator loss on logits.

        Label convention used here: generated images are labelled 1 and real
        images 0 (generator_loss2 uses the matching target of 0).
        """
        bce = tf.keras.losses.BinaryCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )
        fake_term = bce(tf.ones_like(generated), generated)
        real_term = bce(tf.zeros_like(real), real)

        return tf.reduce_mean((real_term + fake_term) * 0.5)



## Two loss functions for generator

1. Generator loss function #1: Simple negative mean formulation
    - Takes generated data as input => negative mean of generated data
    - Encourages the generator to produce data that discriminator classifies as real
    
    Soooo....
    1. Generates fake data w. generator
    2. Calculates the "distance" bw generated data & real data (w. binary cross entropy usually)
    3. Take negative mean of distance measure
    4. Basically penalizes the generator for generated data that's '"farther" than the real, input data
    
2. Generator loss function #2: Binary cross entropy (BCE) loss
    - Same as before basically...
    - Considered more sophisticated by ChatGPT

In [None]:
with strategy.scope():
    def generator_loss1(generated):
        """Generator loss for raw scores: the mean of the negated discriminator output."""
        negated_scores = -generated
        return tf.reduce_mean(negated_scores)

    def generator_loss2(generated):
        """Generator loss for logits: BCE of the discriminator output against a target of 0.

        The target 0 is the label discriminator_loss2 assigns to real images.
        """
        bce = tf.keras.losses.BinaryCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )
        per_element = bce(tf.zeros_like(generated), generated)
        return tf.reduce_mean(per_element)

## So we're calculating loss for both discriminators and generators?
- Yep.
- Loss functions for both the discriminator and the generator
- Help train CycleGAN by => gradients that update parameters of generator & discriminator => better performance

## calc_cycle_loss
- Calculates cycle consistency loss
- Ensures that image translated from domain A -> domain B and then back again should keep ~ original image from A
- Can update LAMBDA (hyperparameter) to control influence of cycle gan consistency loss

In [None]:
with strategy.scope():
    def calc_cycle_loss(real_image, cycled_image, LAMBDA):
        """Cycle-consistency loss: weighted, rescaled sum of absolute differences.

        The sum of |real - cycled| over all elements is multiplied by LAMBDA
        and by 0.0000152587890625, which equals 1/65536 = 1/(256*256).
        """
        # presumably the per-image pixel count of a 256x256 image - confirm
        inverse_pixel_count = 0.0000152587890625
        abs_error_sum = tf.reduce_sum(tf.abs(real_image - cycled_image))

        return LAMBDA * abs_error_sum * inverse_pixel_count

## identity_loss
- Similar to calc_cycle_loss, but *preserves important features of original image during translation*

In [None]:
with strategy.scope():
    def identity_loss(real_image, same_image, LAMBDA):
        """Identity loss: half-weighted, rescaled sum of absolute differences.

        The sum of |real - same| over all elements is multiplied by
        LAMBDA * 0.5 and by 0.0000152587890625 (= 1/65536 = 1/(256*256)).
        """
        inverse_pixel_count = 0.0000152587890625
        abs_error_sum = tf.reduce_sum(tf.abs(real_image - same_image))
        return LAMBDA * 0.5 * abs_error_sum * inverse_pixel_count

## Why 'with strategy.scope()'?
- Ensures that loss functions executed within specified distributed training context (TPU's / GPU's)
- Keeps training consistent

## Differentiable Augmentation for Data-Efficient GAN Training
- Data augmentation in training (brightness, saturation, contrast, etc.)
    - Combined => training data to make model more robust

In [None]:
with strategy.scope():
# Differentiable Augmentation for Data-Efficient GAN Training
# Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
# https://arxiv.org/pdf/2006.10738
# from https://github.com/mit-han-lab/data-efficient-gans/blob/master/DiffAugment_tf.py



    def DiffAugment(x, policy='', channels_first=False):
        if policy:
            if channels_first:
                x = tf.transpose(x, [0, 2, 3, 1])
            for p in policy.split(','):
                for f in AUGMENT_FNS[p]:
                    x = f(x)
            if channels_first:
                x = tf.transpose(x, [0, 3, 1, 2])
        return x


    def rand_brightness(x):
        magnitude = tf.random.uniform([tf.shape(x)[0], 1, 1, 1]) - 0.5
        x = x + magnitude
        return x


    def rand_saturation(x):
        magnitude = tf.random.uniform([tf.shape(x)[0], 1, 1, 1]) * 2
        x_mean = tf.reduce_sum(x, axis=3, keepdims=True) * 0.3333333333333333333
        x = (x - x_mean) * magnitude + x_mean
        return x


    def rand_contrast(x):
        magnitude = tf.random.uniform([tf.shape(x)[0], 1, 1, 1]) + 0.5
        x_mean = tf.reduce_sum(x, axis=[1, 2, 3], keepdims=True) * 5.086e-6
        x = (x - x_mean) * magnitude + x_mean
        return x

    def rand_translation(x, ratio=0.125):
        batch_size = tf.shape(x)[0]
        image_size = tf.shape(x)[1:3]
        shift = tf.cast(tf.cast(image_size, tf.float32) * ratio + 0.5, tf.int32)
        translation_x = tf.random.uniform([batch_size, 1], -shift[0], shift[0] + 1, dtype=tf.int32)
        translation_y = tf.random.uniform([batch_size, 1], -shift[1], shift[1] + 1, dtype=tf.int32)
        grid_x = tf.clip_by_value(tf.expand_dims(tf.range(image_size[0], dtype=tf.int32), 0) + translation_x + 1, 0, image_size[0] + 1)
        grid_y = tf.clip_by_value(tf.expand_dims(tf.range(image_size[1], dtype=tf.int32), 0) + translation_y + 1, 0, image_size[1] + 1)
        x = tf.gather_nd(tf.pad(x, [[0, 0], [1, 1], [0, 0], [0, 0]]), tf.expand_dims(grid_x, -1), batch_dims=1)
        x = tf.transpose(tf.gather_nd(tf.pad(tf.transpose(x, [0, 2, 1, 3]), [[0, 0], [1, 1], [0, 0], [0, 0]]), tf.expand_dims(grid_y, -1), batch_dims=1), [0, 2, 1, 3])
        return x


    def rand_cutout(x, ratio=0.5):
        """Zero out one randomly placed rectangular patch in every sample.

        The patch spans ``ratio`` of axes 1 and 2 of ``x`` (rounded to the
        nearest integer). Its centre is drawn independently per sample; patch
        cells that would fall outside the image are clamped onto the border.
        """
        n_samples = tf.shape(x)[0]
        extent = tf.shape(x)[1:3]
        patch = tf.cast(tf.cast(extent, tf.float32) * ratio + 0.5, tf.int32)

        # One random patch centre per sample along each of the two axes.
        centre_a = tf.random.uniform([tf.shape(x)[0], 1, 1], maxval=extent[0] + (1 - patch[0] % 2), dtype=tf.int32)
        centre_b = tf.random.uniform([tf.shape(x)[0], 1, 1], maxval=extent[1] + (1 - patch[1] % 2), dtype=tf.int32)

        # Index grids enumerating every (sample, axis-1, axis-2) patch cell.
        idx_n, idx_a, idx_b = tf.meshgrid(
            tf.range(n_samples, dtype=tf.int32),
            tf.range(patch[0], dtype=tf.int32),
            tf.range(patch[1], dtype=tf.int32),
            indexing='ij',
        )
        cells = tf.stack(
            [idx_n, idx_a + centre_a - patch[0] // 2, idx_b + centre_b - patch[1] // 2],
            axis=-1,
        )

        # Clamp every index into the valid range [0, dim - 1].
        full_shape = tf.stack([n_samples, extent[0], extent[1]])
        upper = tf.reshape(full_shape - 1, [1, 1, 1, 3])
        cells = tf.minimum(tf.maximum(cells, 0), upper)

        # scatter_nd sums values written to the same index (clamped cells can
        # collide), so floor the result at 0 to keep a clean 0/1 mask.
        hits = tf.scatter_nd(cells, tf.ones([n_samples, patch[0], patch[1]], dtype=tf.float32), full_shape)
        keep = tf.maximum(1 - hits, 0)
        return x * tf.expand_dims(keep, axis=3)


    # Policy name -> augmentation functions applied, in list order, when that
    # policy is requested.
    AUGMENT_FNS = dict(
        color=[rand_brightness, rand_saturation, rand_contrast],
        translation=[rand_translation],
        cutout=[rand_cutout],
    )
    def aug_fn(image):
        """Run ``image`` through DiffAugment with the color, translation and cutout policies."""
        policy = "color,translation,cutout"
        return DiffAugment(image, policy)

## Optimizers
- Adam = Adaptive Moment Estimation = combo of momentum and Root Mean Square Propagation (RMSProp)
    - Dynamically changes learning rate for each parameter
    - Momentum to accelerate gradient descent
    - Bias correction to adjust estimates for moments of gradients
    - Efficient
    
1. Learning rate = 2e-4
    - Step size that optimizer updates weights during training
    - Stable, gradual updates to model parameters
2. Beta value (`beta_1`) = 0.5
    - Controls exponential decay rate for the running mean of the gradients (Keras default is 0.9)
    - With the lower value, past gradients have less influence on the current update than with the default

In [None]:
with strategy.scope():
    # One independent Adam instance per network, all with the same settings:
    # learning rate 2e-4 and beta_1 = 0.5.
    monet_generator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
    photo_generator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)

    monet_discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)
    photo_discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=2e-4, beta_1=0.5)

## cycle_gan_model
- Defining cycle gan!
- Takes the generators and discriminators we defined earlier and combines them with the two discriminator heads (`dHead1`, `dHead2`)
- Learns the mapping between 2 different domains (Monet & photos) and generates realistic images that mimic Monet

In [None]:
with strategy.scope():
    # Bundle both generators, both discriminators and the two discriminator
    # heads into a single CycleGan model.
    cycle_gan_model = CycleGan(
        monet_generator,
        photo_generator,
        monet_discriminator,
        photo_discriminator,
        dHead1,
        dHead2,
    )



# Testing with different epoch #'s and alpha and beta values
- Compile to prepare model for training by specifying how it should be optimized and what loss functions to use

In [None]:
with strategy.scope():
    # Attach the optimizers, loss functions and augmentation function that
    # the model uses during training.
    cycle_gan_model.compile(
        # Optimizers
        m_gen_optimizer=monet_generator_optimizer,
        p_gen_optimizer=photo_generator_optimizer,
        m_disc_optimizer=monet_discriminator_optimizer,
        p_disc_optimizer=photo_discriminator_optimizer,
        # Loss functions
        gen_loss_fn1=generator_loss1,
        gen_loss_fn2=generator_loss2,
        disc_loss_fn1=discriminator_loss1,
        disc_loss_fn2=discriminator_loss2,
        cycle_loss_fn=calc_cycle_loss,
        identity_loss_fn=identity_loss,
        # Augmentation
        aug_fn=aug_fn,
    )

## cycle_gan_model.fit
- Fitting with a specific training configuration (steps per epoch and number of epochs)
- Then computes the FID (Fréchet Inception Distance) score by passing the photo dataset and the Monet generator to `FID`

In [None]:
# First training run: 26 epochs of 1407 steps each.
cycle_gan_model.fit(
    final_dataset,
    steps_per_epoch=1407,
    epochs=26,
)
# Kept as the last expression so the notebook displays the score.
FID(fid_photo_ds, monet_generator)

# OOM Error:
- My GPU often ran out of memory at this point
- - - 
Solutions produced by ChatGPT:

1. **Reduce Batch Size:** Decrease the batch size used during training. A smaller batch size requires less memory but may also slow down training.

2. **Reduce Model Complexity:** Simplify the architecture of your neural network, especially in terms of the number of layers, units, and parameters.

3. **Use a Different GPU:** If possible, switch to a GPU with more memory or use a cloud-based service that provides access to more powerful GPUs.

4. **Memory Management:** Check if there are any memory leaks in your code or if there are unnecessary tensors being stored in memory during training.

5. **Gradient Accumulation:** Implement gradient accumulation techniques to simulate larger batch sizes without increasing memory usage.

6. **Memory Profiling:** Use TensorFlow's memory profiler to analyze memory usage and identify areas where memory is being consumed excessively.

7. **Limit Data Loading:** If using large datasets, consider loading data in batches rather than all at once to reduce memory consumption.

## Compiling w. a Lower Learning Rate (2e-4 -> 1e-4)
- Updates parameters more slowly during training

In [None]:
with strategy.scope():
    # Fresh Adam optimizers with the learning rate halved to 1e-4
    # (beta_1 stays at 0.5).
    monet_generator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0.5)
    photo_generator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0.5)

    monet_discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0.5)
    photo_discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4, beta_1=0.5)

In [None]:
with strategy.scope():
    # Re-compile so the model picks up the newly created optimizers; the
    # loss and augmentation functions are the same as before.
    cycle_gan_model.compile(
        # Optimizers
        m_gen_optimizer=monet_generator_optimizer,
        p_gen_optimizer=photo_generator_optimizer,
        m_disc_optimizer=monet_discriminator_optimizer,
        p_disc_optimizer=photo_discriminator_optimizer,
        # Loss functions
        gen_loss_fn1=generator_loss1,
        gen_loss_fn2=generator_loss2,
        disc_loss_fn1=discriminator_loss1,
        disc_loss_fn2=discriminator_loss2,
        cycle_loss_fn=calc_cycle_loss,
        identity_loss_fn=identity_loss,
        # Augmentation
        aug_fn=aug_fn,
    )

In [None]:
# Second training run: 18 epochs of 1407 steps each.
cycle_gan_model.fit(
    final_dataset,
    steps_per_epoch=1407,
    epochs=18,
)
# Kept as the last expression so the notebook displays the score.
FID(fid_photo_ds, monet_generator)

## Compiling w. an even Smaller Learning Rate (1e-4 -> 1e-5)

In [None]:
with strategy.scope():
    # Fresh Adam optimizers with the learning rate lowered again, to 1e-5
    # (beta_1 stays at 0.5).
    monet_generator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5, beta_1=0.5)
    photo_generator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5, beta_1=0.5)

    monet_discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5, beta_1=0.5)
    photo_discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5, beta_1=0.5)

In [None]:
with strategy.scope():
    # Re-compile so the model picks up the newly created optimizers; the
    # loss and augmentation functions are the same as before.
    cycle_gan_model.compile(
        # Optimizers
        m_gen_optimizer=monet_generator_optimizer,
        p_gen_optimizer=photo_generator_optimizer,
        m_disc_optimizer=monet_discriminator_optimizer,
        p_disc_optimizer=photo_discriminator_optimizer,
        # Loss functions
        gen_loss_fn1=generator_loss1,
        gen_loss_fn2=generator_loss2,
        disc_loss_fn1=discriminator_loss1,
        disc_loss_fn2=discriminator_loss2,
        cycle_loss_fn=calc_cycle_loss,
        identity_loss_fn=identity_loss,
        # Augmentation
        aug_fn=aug_fn,
    )

In [None]:
# Third training run: 8 epochs of 1407 steps each.
cycle_gan_model.fit(
    final_dataset,
    steps_per_epoch=1407,
    epochs=8,
)
# Kept as the last expression so the notebook displays the score.
FID(fid_photo_ds, monet_generator)

## Showing the images

In [None]:
# Show 8 batches from photo_ds: the first image of each batch next to the
# monet_generator output for it.
ds_iter = iter(photo_ds)
for n_sample in range(8):
    example_sample = next(ds_iter)
    # Inference only: pass training=False explicitly, matching how the
    # generator is called when the predictions are saved.
    generated_sample = monet_generator(example_sample, training=False)

    plt.figure(figsize=(32, 32))

    # `* 0.5 + 0.5` maps values from the (assumed) [-1, 1] range to [0, 1]
    # so imshow renders them correctly.
    plt.subplot(121)
    plt.title('Input image')
    plt.imshow(example_sample[0] * 0.5 + 0.5)
    plt.axis('off')

    plt.subplot(122)
    plt.title('Generated image')
    plt.imshow(generated_sample[0] * 0.5 + 0.5)
    plt.axis('off')
    plt.show()

## Showing the monet images

In [None]:
# Show 8 batches from monet_ds: the first image of each batch next to the
# photo_generator output for it.
ds_iter = iter(monet_ds)
for n_sample in range(8):
    example_sample = next(ds_iter)
    # Inference only: pass training=False explicitly, matching how the
    # Monet generator is called when the predictions are saved.
    generated_sample = photo_generator(example_sample, training=False)

    plt.figure(figsize=(24, 24))

    # `* 0.5 + 0.5` maps values from the (assumed) [-1, 1] range to [0, 1]
    # so imshow renders them correctly.
    plt.subplot(121)
    plt.title('Input image')
    plt.imshow(example_sample[0] * 0.5 + 0.5)
    plt.axis('off')

    plt.subplot(122)
    plt.title('Generated image')
    plt.imshow(generated_sample[0] * 0.5 + 0.5)
    plt.axis('off')
    plt.show()

In [None]:
import PIL
! mkdir ../images

## Saving Predictions as jpg's
- Iterates over the batches of photos and produces predictions from the Monet generator
- Converts the predictions to 8-bit images and saves them as numbered .jpg files

In [None]:
%%time
i = 1
for img in fast_photo_ds:
    prediction = monet_generator(img, training=False).numpy()
    prediction = (prediction * 127.5 + 127.5).astype(np.uint8)
    for pred in prediction:
        im = PIL.Image.fromarray(pred)
        im.save("../images/" + str(i) + ".jpg")
        i += 1
    

In [None]:
import shutil
# Zip the saved images into /kaggle/working/images.zip.
# NOTE(review): the images were written to the relative path "../images";
# this absolute path matches only when the working directory is
# /kaggle/working - confirm when running outside Kaggle.
shutil.make_archive(
    base_name="/kaggle/working/images",
    format="zip",
    root_dir="/kaggle/images",
)

## Saving weights
- Weights = trainable parameters of the specified models, i.e. the internal state of a model that is adjusted during training to minimize the loss function & improve performance
- Can be loaded later to restore the models' states
- Lets us reuse the models without retraining them from scratch

In [None]:

# Write each network's weights to "<variable name>.h5" (paths are relative,
# so the files land in the current working directory).
for _net, _weights_path in (
    (dHead1, "dHead1.h5"),
    (dHead2, "dHead2.h5"),
    (photo_generator, "photo_generator.h5"),
    (monet_generator, "monet_generator.h5"),
    (photo_discriminator, "photo_discriminator.h5"),
    (monet_discriminator, "monet_discriminator.h5"),
):
    _net.save_weights(_weights_path)