# GANs/Cycle GANs For Data Augmentation and Neural Style Transfer

In [None]:
import tensorflow as tf
import os
import glob
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
import shutil
import random
import itertools
import math
import time
import wandb

from skimage.io import imread
from skimage.transform import resize
from mpl_toolkits.axes_grid1 import ImageGrid

# Model class
from tensorflow.keras.models import Model, Sequential

# Layers
from tensorflow.keras.layers import (LeakyReLU, BatchNormalization, Dropout, 
                                     Flatten, Conv2D, Dense, MaxPool2D, Conv2DTranspose, 
                                     GlobalMaxPool2D, Reshape, BatchNormalization, Input, Embedding, multiply)
from tensorflow.keras.layers import Layer
from tensorflow.keras import layers

# Optimizer
from tensorflow.keras.optimizers import Adam

# Loss function
from tensorflow.keras.losses import BinaryCrossentropy, MeanAbsoluteError 

# Data loader
from tensorflow.keras.utils import Sequence

# Metrics
from tensorflow.keras.metrics import BinaryAccuracy, Recall, Precision, MSE

# ImageLoader
from skimage.io import imread

# Weights Initializer
from tensorflow.keras.initializers import RandomNormal, GlorotNormal, GlorotUniform

# Callbacks
from tensorflow.keras.callbacks import Callback, ModelCheckpoint

# Initializers
from tensorflow.keras.initializers import RandomNormal, RandomUniform

In [None]:
!pip install tensorflow-addons

In [None]:
import tensorflow_addons as tfa

In [None]:
# To limit GPU VRAM allocation by tensorflow: enable on-demand memory growth
# for every visible GPU instead of letting TensorFlow reserve all memory.
# NOTE(review): TensorFlow expects this to run before the GPUs are first used —
# keep this cell ahead of any model/tensor creation.
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
  tf.config.experimental.set_memory_growth(gpu, True)

## Pre-Processing The Data

In [None]:
class DataLoader(Sequence):
    """Keras ``Sequence`` serving image batches found under ``im_dir/<class_dir>/<file>``.

    Args:
        im_dir: Root directory whose immediate sub-directories each hold one class.
        labels: Optional class-name -> id mapping. A list assigns ids by position,
            a dict is used as given, and ``None`` derives the ids from the
            alphabetically sorted sub-directories of ``im_dir``.
        batch_size: Images per batch; ``-1`` serves the whole dataset as one batch.
        output_dim: Optional output shape handed to ``skimage.transform.resize``.
        classification: If true, pixels are scaled to [0, 1]; otherwise to [-1, 1].
        shuffle: Whether the sample order is reshuffled at the end of every epoch.
    """

    def __init__(self, im_dir, labels= None, batch_size= 32, output_dim= None, classification= True, shuffle= True):
        super().__init__()

        self.im_dir = im_dir
        self.classification = classification
        self.shuffle = shuffle
        # Not used internally; kept so existing code reading ``loader.resize`` keeps working.
        self.resize = resize
        self.output_dim = output_dim

        if isinstance(labels, list):
            self.labels = {label: i for i, label in enumerate(labels)}
        elif isinstance(labels, dict):
            self.labels = labels
        else:
            # Only the first level below im_dir holds class folders (the glob
            # below reads im_dir/*/*). Sorting makes the ids reproducible
            # instead of depending on the file system's listing order.
            _, class_dirs, _ = next(os.walk(self.im_dir), (self.im_dir, [], []))
            self.labels = {class_dir: i for i, class_dir in enumerate(sorted(class_dirs))}

        self.images = glob.glob(f"{self.im_dir}/*/*")
        random.shuffle(self.images)
        self.batch_size = len(self.images) if batch_size == -1 else batch_size

        self.on_epoch_end()

    def __len__(self):
        """Return the number of full batches per epoch (a trailing partial batch is dropped)."""
        return int(np.floor(len(self.images)/self.batch_size))

    def __getitem__(self, index):
        """Return batch ``index`` as ``(images, labels)``.

        ``images`` is a float32 array scaled according to ``classification``;
        ``labels`` is a uint8 array of shape ``(batch, 1)``.
        """
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
        im_files_batch = [self.images[idx] for idx in indexes]

        if self.output_dim:
            # preserve_range keeps the 0-255 scale so the uint8 cast is meaningful.
            raw_images = [resize(imread(file), self.output_dim, preserve_range= True) for file in im_files_batch]
        else:
            raw_images = [imread(file) for file in im_files_batch]
        batch_images = np.array(raw_images, dtype= np.uint8)

        if self.classification:
            batch_images = np.asanyarray(batch_images, dtype= np.float32)/255.
        else:
            batch_images = (np.asanyarray(batch_images, dtype= np.float32) - 127.5)/127.5

        batch_labels = np.array([self._label_for(file) for file in im_files_batch], dtype= np.uint8)
        return batch_images, batch_labels.reshape(-1, 1)

    def _label_for(self, im_file):
        """Return the class id for ``im_file`` (0 when no label matches)."""
        # Prefer an exact match on the class folder so that a label which also
        # occurs elsewhere in the path cannot win by accident.
        class_dir = os.path.basename(os.path.dirname(im_file))
        if class_dir in self.labels:
            return self.labels[class_dir]
        # Fallback for user-supplied labels that are only part of the folder
        # name: substring match on the path, last matching label wins.
        matched_id = 0
        for label, class_id in self.labels.items():
            if label in im_file:
                matched_id = class_id
        return matched_id

    def on_epoch_end(self):
        "Updates all indexes after the end of each epoch"
        self.indexes = np.arange(len(self.images))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def get_all_data(self, batched_data):
        """Placeholder that is not implemented; currently returns ``None``."""
        return

In [None]:
# Loaders used to preview the data: images are resized to 128x128x3 and,
# because classification= True, scaled to [0, 1] by DataLoader.
train_horses_cows = DataLoader(im_dir= "dataset/Newdata/Train", classification= True, output_dim= (128, 128, 3), batch_size= 16)
# No batch_size is given here, so DataLoader's default of 32 applies.
val_horses_cows = DataLoader(im_dir= "dataset/Newdata/Test", classification= True, output_dim= (128, 128, 3))

In [None]:
# Get a batch of data
horses_cows_iter = iter(train_horses_cows)
batch_data = next(horses_cows_iter)

In [None]:
# Split the preview batch by class id.
# NOTE(review): assumes id 1 = horse and id 0 = cow — confirm against
# train_horses_cows.labels, since the ids are derived from the folder listing.
horse_batch = [data for i, data in enumerate(batch_data[0]) if batch_data[1][i] == 1]
cow_batch = [data for i, data in enumerate(batch_data[0]) if batch_data[1][i] == 0]

In [None]:
# Draw up to four distinct sample indices per class. The size is capped at the
# number of available samples because replace= False raises when a batch holds
# fewer than four images of a class, and every sample (including the last cow)
# is eligible.
random_horse = np.random.choice(len(horse_batch), size= min(4, len(horse_batch)), replace= False)
random_cow = np.random.choice(len(cow_batch), size= min(4, len(cow_batch)), replace= False)
random_horse, random_cow

In [None]:
# Gather the images selected by the random indices drawn above, one list per class.
cow_images, horse_images = [cow_batch[i] for i in random_cow], [horse_batch[i] for i in random_horse]

In [None]:
# Show the sampled images: cows are plotted first, then horses.
fig = plt.figure(figsize=(10., 10.))
grid = ImageGrid(fig, 111,  # similar to subplot(111)
                 nrows_ncols=(2, 4),  # creates 2x4 grid of axes
                 axes_pad=0.1,  # pad between axes in inch.
                 )

for ax, im in zip(grid, [*cow_images, *horse_images]):
    # Iterating over the grid returns the Axes.
    ax.imshow(im)

plt.show()

## Define Metrics For Classification
- Accuracy
- Precision
- Recall

In [None]:
def show_metrics_classification(model_history):
    """Plot every metric stored in a Keras ``History`` object against the epochs.

    Training curves are read from the keys that do not start with ``val_``; a
    leading ``train_`` (as produced by ``Classifier.train_step``) is stripped to
    obtain the metric name. The matching ``val_<metric>`` curve is drawn in the
    same subplot when it exists.

    Args:
        model_history: Object exposing a ``history`` dict of metric name -> list
            of per-epoch values.
    """
    history = model_history.history
    train_prefix = "train_"

    # Compare prefixes rather than searching for the substring "val", so a metric
    # whose name merely contains "val" is still plotted.
    train_keys = [key for key in history if not key.startswith("val_")]
    metrics = [key[len(train_prefix):] if key.startswith(train_prefix) else key for key in train_keys]
    print(metrics)
    colors = list(itertools.combinations(['b', 'g', 'r', 'c', 'm', 'y', 'k'], 2))

    f = plt.figure()
    f.set_figwidth(2.5*len(metrics))
    f.set_figheight(3*len(metrics))
    for i, (train_key, metric) in enumerate(zip(train_keys, metrics)):
        plt.subplot(math.ceil(len(metrics)/2), 2, i+1)
        # One random colour pair per subplot: first colour for train, second for val.
        color_plts = np.random.randint(0, len(colors))
        plt.plot(history[train_key], color= colors[color_plts][0])
        legend = ["train"]

        # Validation values are absent when fit() ran without validation data.
        val_key = f"val_{metric}"
        if val_key in history:
            plt.plot(history[val_key], color= colors[color_plts][1])
            legend.append("val")

        plt.ylabel(metric)
        plt.xlabel("epochs")
        plt.legend(legend, loc= "upper left")
        plt.title(f"{metric.title()} vs Epochs")
    plt.show()
    

## Building The Classifier

In [None]:
class Classifier(Model):
    """Small CNN binary classifier: three conv/max-pool stages, then three dense layers.

    The final layer is a single sigmoid unit. ``train_step`` reports metrics under
    ``train_<name>`` keys and ``test_step`` under ``<name>`` keys.
    """

    def __init__(self, name):
        super().__init__(name=name)

    def build(self, input_shape):
        """Create the layers; ``input_shape`` includes the batch dimension."""
        # No ``input_shape`` kwarg on the first conv: that kwarg expects a shape
        # without the batch axis, whereas the shape received here includes it.
        self.conv1 = Conv2D(filters=50, kernel_size=3, activation='relu', padding='same', kernel_initializer='glorot_uniform')
        self.maxpool1 = MaxPool2D(pool_size=(2, 2))

        self.conv2 = Conv2D(filters=20, kernel_size=3, activation='relu', padding='valid', kernel_initializer='glorot_uniform')
        self.maxpool2 = MaxPool2D(pool_size=(2, 2))

        self.conv3 = Conv2D(filters=5, kernel_size=3, activation='relu', padding='valid', kernel_initializer='glorot_uniform')
        self.maxpool3 = MaxPool2D(pool_size=(2, 2))

        self.flatten = Flatten()
        self.dense1 = Dense(units=30, activation='relu', kernel_initializer='glorot_uniform')
        self.dense2 = Dense(units=20, activation='relu', kernel_initializer='glorot_uniform')
        self.dense3 = Dense(units=1, activation='sigmoid', kernel_initializer='glorot_uniform')

        super().build(input_shape)

    def call(self, input_):
        """Run the forward pass and return the sigmoid output of the last dense layer."""
        x = self.conv1(input_)
        x = self.maxpool1(x)

        x = self.conv2(x)
        x = self.maxpool2(x)

        x = self.conv3(x)
        x = self.maxpool3(x)

        x = self.flatten(x)
        x = self.dense1(x)
        x = self.dense2(x)

        return self.dense3(x)

    def train_step(self, train_batch):
        """Run one optimisation step on ``(X, y)`` and return ``train_``-prefixed metrics."""
        X_train, y_train = train_batch

        with tf.GradientTape() as tape:
            # apply forward pass
            y_pred = self(X_train, training=True)
            loss = self.compiled_loss(y_train, y_pred, regularization_losses=self.losses)
        # calculate gradients - uses reverse-mode autodiff
        gradients = tape.gradient(loss, self.trainable_variables)
        # update the weights using the compiled optimizer
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        self.compiled_metrics.update_state(y_train, y_pred)
        return {f"train_{metric.name}": metric.result() for metric in self.metrics}

    def test_step(self, test_batch):
        """Evaluate one ``(X, y)`` batch and return the metrics under their plain names."""
        X_test, y_test = test_batch
        # obtain prediction
        y_pred = self(X_test, training=False)

        # updates loss metric; regularization losses are included so the value is
        # computed the same way as in train_step
        self.compiled_loss(y_test, y_pred, regularization_losses=self.losses)

        # updates metrics
        self.compiled_metrics.update_state(y_test, y_pred)

        return {f"{metric.name}": metric.result() for metric in self.metrics}

## Training the Classifier

In [None]:
# Hyper-parameters for the baseline classifier run.
EPOCHS= 50
BATCH_SIZE= 16
# Passed to DataLoader as output_dim, i.e. the shape every image is resized to.
IMAGE_SIZE= (128, 128, 3)

In [None]:
# Train/validation loaders for the classifier; classification= True scales pixels to [0, 1].
train_horses_cows_classification = DataLoader(im_dir= "dataset/Newdata/Train", classification= True, output_dim= IMAGE_SIZE, batch_size= BATCH_SIZE)
val_horses_cows_classification = DataLoader(im_dir= "dataset/Newdata/Test", classification= True, output_dim= IMAGE_SIZE, batch_size= BATCH_SIZE)

In [None]:
# Build the classifier for batches of IMAGE_SIZE images (None = any batch size).
horses_cows_classifier = Classifier(name= "Horses_vs_Cows_Classifier")
horses_cows_classifier.build(input_shape= (None, *IMAGE_SIZE))

horses_cows_classifier.summary()
# Binary cross-entropy matches the single sigmoid output unit of Classifier.
horses_cows_classifier.compile(optimizer= Adam(0.001), loss= BinaryCrossentropy(), metrics= [BinaryAccuracy(), Recall(), Precision(), MSE])
# model_data holds the value returned by fit().
model_data = horses_cows_classifier.fit(train_horses_cows_classification, validation_data= val_horses_cows_classification, epochs= EPOCHS, workers= 10)

In [None]:
# Plot the train/validation curves recorded during the fit above.
show_metrics_classification(horses_cows_classifier.history)

## DC-GAN
First, we will experiment with a plain DCGAN, without conditioning the output on a target class. This gives us a better understanding of how GANs train, which carries over to the Conditional GAN part.

### Building The DCGAN Architecture
#### Generator
#### Transposed Convolutional Layer With Batch Normalization and Leaky ReLU Activation Function
This block will be used multiple times in the Generator Block

In [None]:
class Conv2DTBatchNorm(Layer):
    """Transposed convolution followed by batch normalization and LeakyReLU.

    Args:
        filters: Number of output channels of the transposed convolution.
        kernel_size: Kernel size of the transposed convolution.
        strides: Strides of the transposed convolution.
        padding: Padding mode of the transposed convolution.
        use_bias: Whether the transposed convolution has a bias term.
    """

    def __init__(self, filters, kernel_size, strides, padding, use_bias):
        super().__init__()
        self.filters = filters
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = padding
        self.use_bias = use_bias

    def build(self, input_shape):
        # ``padding`` is forwarded here; it used to be stored but ignored, so the
        # layer always ran with Conv2DTranspose's default padding.
        self.conv2d_t = Conv2DTranspose(
            filters=self.filters,
            kernel_size=self.kernel_size,
            strides=self.strides,
            padding=self.padding,
            use_bias=self.use_bias,
        )
        self.batch_norm = BatchNormalization()
        self.leaky_relu = LeakyReLU()
        super().build(input_shape)

    def call(self, input_tensor):
        """Apply transposed conv -> batch norm -> LeakyReLU to ``input_tensor``."""
        x = self.conv2d_t(input_tensor)
        x = self.batch_norm(x)
        return self.leaky_relu(x)

In [None]:
class Generator(Model):
    """DCGAN generator: six ``Conv2DTBatchNorm`` blocks and a final 3-channel tanh conv.

    Expects latent input of shape ``(batch_size, 1, 1, latent_dim)``.
    """

    def __init__(self, latent_dim, name, **kwargs):
        super().__init__(name=name, **kwargs)
        self.latent_dim = latent_dim

    def build(self, input_shape):
        assert input_shape[1:] == (1, 1, self.latent_dim), f"input_shape should have shape (batch_size, 1, 1, latent_dimension), received: {input_shape}"

        # Settings shared by the five stride-2 blocks that follow the first block.
        upsample_kwargs = dict(kernel_size=(2, 2), strides=(2, 2), padding="same", use_bias=False)

        self.conv2d_t_1 = Conv2DTBatchNorm(filters=64 * 12, kernel_size=(4, 4), strides=(1, 1), padding="valid", use_bias=False)
        self.conv2d_t_2 = Conv2DTBatchNorm(filters=64 * 8, **upsample_kwargs)
        self.conv2d_t_3 = Conv2DTBatchNorm(filters=64 * 8, **upsample_kwargs)
        self.conv2d_t_4 = Conv2DTBatchNorm(filters=64 * 4, **upsample_kwargs)
        self.conv2d_t_5 = Conv2DTBatchNorm(filters=64 * 4, **upsample_kwargs)
        self.conv2d_t_6 = Conv2DTBatchNorm(filters=64 * 2, **upsample_kwargs)

        # Output head: three channels squashed into [-1, 1] by tanh.
        self.conv2d = Conv2D(filters=3, kernel_size=(3, 3), strides=(1, 1), padding="same", use_bias=False, activation="tanh")
        super().build(input_shape)

    def call(self, input_tensor):
        """Map latent vectors to images by applying the blocks in order."""
        stages = (
            self.conv2d_t_1,
            self.conv2d_t_2,
            self.conv2d_t_3,
            self.conv2d_t_4,
            self.conv2d_t_5,
            self.conv2d_t_6,
            self.conv2d,
        )
        features = input_tensor
        for stage in stages:
            features = stage(features)
        return features

#### Discriminator

In [None]:
class Conv2DBatchNorm(Layer):
    """Convolution, optional batch normalization, then LeakyReLU(alpha=0.2).

    ``batch_norm`` toggles the normalization step; the other arguments are passed
    to the ``Conv2D`` layer unchanged.
    """

    def __init__(self, filters, kernel_size, strides, padding, use_bias, batch_norm= True):
        super().__init__()
        self.batch_norm = batch_norm
        self.use_bias = use_bias
        self.padding = padding
        self.strides = strides
        self.kernel_size = kernel_size
        self.filters = filters

    def build(self, input_shape):
        self.conv = Conv2D(
            filters=self.filters,
            kernel_size=self.kernel_size,
            strides=self.strides,
            padding=self.padding,
            use_bias=self.use_bias,
        )
        # The normalization layer is only created when it will actually be used.
        if self.batch_norm:
            self.bn = BatchNormalization()
        self.leaky_relu = LeakyReLU(alpha=0.2)

        super().build(input_shape)

    def call(self, input_tensor):
        """Apply conv -> (batch norm) -> LeakyReLU to ``input_tensor``."""
        features = self.conv(input_tensor)
        if self.batch_norm:
            features = self.bn(features)
        return self.leaky_relu(features)

In [None]:
class Discriminator(Model):
    """DCGAN discriminator: eight stride-2 conv blocks and a 1-channel sigmoid conv.

    Each ``Conv2DBatchNorm`` block uses "same" padding with stride 2; the final
    convolution emits one sigmoid channel.
    """

    def __init__(self, name, **kwargs):
        # ``name`` must be passed by keyword (as Generator and Classifier do);
        # passing it positionally does not set the model name.
        super().__init__(name=name, **kwargs)

    def build(self, input_shape):
        # NOTE(review): every block is created with batch_norm= False despite the
        # conv_bn_* attribute names — confirm that normalization is meant to be off.
        self.conv1 = Conv2DBatchNorm(filters=64, kernel_size=(4, 4), strides=(2, 2), padding="same", use_bias=False, batch_norm=False)

        self.conv_bn_1 = Conv2DBatchNorm(filters=64*2, kernel_size=(4, 4), strides=(2, 2), padding="same", use_bias=False, batch_norm=False)
        self.conv_bn_2 = Conv2DBatchNorm(filters=64*4, kernel_size=(4, 4), strides=(2, 2), padding="same", use_bias=False, batch_norm=False)
        self.conv_bn_3 = Conv2DBatchNorm(filters=64*8, kernel_size=(4, 4), strides=(2, 2), padding="same", use_bias=False, batch_norm=False)
        self.conv_bn_4 = Conv2DBatchNorm(filters=64*8, kernel_size=(4, 4), strides=(2, 2), padding="same", use_bias=False, batch_norm=False)
        self.conv_bn_5 = Conv2DBatchNorm(filters=64*4, kernel_size=(4, 4), strides=(2, 2), padding="same", use_bias=False, batch_norm=False)
        self.conv_bn_6 = Conv2DBatchNorm(filters=64*2, kernel_size=(4, 4), strides=(2, 2), padding="same", use_bias=False, batch_norm=False)
        self.conv_bn_7 = Conv2DBatchNorm(filters=64, kernel_size=(4, 4), strides=(2, 2), padding="same", use_bias=False, batch_norm=False)

        self.conv2 = Conv2D(1, (3, 3), strides=(4, 4), padding="same", use_bias=False, activation='sigmoid')
        super().build(input_shape)

    def call(self, input_tensor):
        """Return the sigmoid output of the final convolution for ``input_tensor``."""
        x = self.conv1(input_tensor)

        x = self.conv_bn_1(x)
        x = self.conv_bn_2(x)
        x = self.conv_bn_3(x)
        x = self.conv_bn_4(x)
        x = self.conv_bn_5(x)
        x = self.conv_bn_6(x)
        x = self.conv_bn_7(x)

        return self.conv2(x)

#### GAN

In [None]:
class GAN(Model):
    """Couples a generator and a discriminator and trains them adversarially.

    Label convention used by ``train_step``: generated images are labelled 1 and
    real images 0, so the generator is trained towards label 0.

    Args:
        discriminator: Model mapping images to predictions of shape (batch, 1, 1, 1).
        generator: Model mapping (batch, 1, 1, latent_dim) noise to images.
        latent_dim: Size of the latent vector.
        batch_size: Stored for reference; ``train_step`` derives the batch size
            from the incoming batch.
    """

    def __init__(self, discriminator, generator, latent_dim, batch_size):
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        self.batch_size = batch_size

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        """Store one optimizer per sub-model and the shared loss function."""
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def call(self, data, training=False):
        # Method needed to be implemented for tensorflow reasons when using a custom data loader
        pass

    def train_step(self, real_images):
        """Run one discriminator update followed by one generator update.

        Returns:
            Dict with the discriminator loss ``d_loss`` and generator loss ``g_loss``.
        """
        if isinstance(real_images, tuple):
            # The data loader yields (images, labels); labels are unused here.
            real_images = real_images[0]

        # Take the batch size from the data so that labels and noise always match
        # the number of real images actually received.
        batch_size = tf.shape(real_images)[0]

        # Sample random points in the latent space
        random_latent_vectors = tf.random.normal(shape=(batch_size, 1, 1, self.latent_dim))

        # Decode them to fake images. training=True is passed explicitly: without
        # it the generator's BatchNormalization layers run in inference mode.
        generated_images = self.generator(random_latent_vectors, training=True)

        # Combine them with real images
        combined_images = tf.concat([generated_images, real_images], axis=0)

        # Assemble labels discriminating real from fake images (fake = 1, real = 0)
        labels = tf.concat(
            [tf.ones((batch_size, 1, 1, 1)), tf.zeros((batch_size, 1, 1, 1))], axis=0
        )

        # Add random noise to the labels - important trick!
        labels += 0.05 * tf.random.uniform(tf.shape(labels))

        # Train the discriminator
        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_images, training=True)
            d_loss = self.loss_fn(labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
        )

        # Sample fresh random points in the latent space
        random_latent_vectors = tf.random.normal(shape=(batch_size, 1, 1, self.latent_dim))

        # Assemble labels that say "all real images" (real = 0)
        misleading_labels = tf.zeros((batch_size, 1, 1, 1))

        # Train the generator (note that we should *not* update the weights
        # of the discriminator)!
        with tf.GradientTape() as tape:
            fake_images = self.generator(random_latent_vectors, training=True)
            predictions = self.discriminator(fake_images, training=True)
            g_loss = self.loss_fn(misleading_labels, predictions)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))
        return {"d_loss": d_loss, "g_loss": g_loss}

**Callbacks**

In [None]:
import datetime
class SaveImagesCallback(Callback):
    """Periodically render generator output for a fixed noise batch and save it as PNG.

    Images are written to ``<logdir>/gan_image_output/<timestamp>/``.

    Args:
        logdir: Base directory for the image output.
        latent_dim: Size of the generator's latent vector.
        save_freq: Images are rendered on epochs that are a multiple of this value.
        batch_size: Number of images rendered per snapshot.
    """

    def __init__(self, logdir, latent_dim, save_freq, batch_size):
        super().__init__()
        self.logdir = Path(f"{logdir}/gan_image_output/{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}")
        self.latent_dim = latent_dim
        self.save_freq = save_freq

        self.logdir.mkdir(exist_ok=True, parents=True)
        self.batch_size = batch_size

        # The same noise is reused for every snapshot so images are comparable across epochs.
        self.fixed_noise = tf.random.normal([self.batch_size, 1, 1, self.latent_dim])

    def on_epoch_end(self, epoch, logs= None):
        """Render and save the image grid when ``epoch`` is a multiple of ``save_freq``."""
        if epoch % self.save_freq != 0:
            return

        predictions = self.model.generator(self.fixed_noise, training=False)
        # Images are plotted in noise order (no random resampling), so grid cell i
        # always shows the same latent vector. Map tanh output [-1, 1] to [0, 255].
        images = np.asarray(np.asarray(predictions) * 127.5 + 127.5, dtype=np.uint8)

        # Smallest square grid that fits every image, also for non-square counts.
        grid_size = math.ceil(math.sqrt(images.shape[0]))

        fig = plt.figure(figsize=(8, 8))
        for i, image in enumerate(images):
            plt.subplot(grid_size, grid_size, i + 1)
            plt.imshow(image)
            plt.axis('off')
        # Save before show(): show() may leave the current figure empty, and close
        # afterwards so figures do not accumulate over a long training run.
        fig.savefig(f'{str(self.logdir)}/tf_image_at_epoch_{epoch:04d}.png')
        plt.show()
        plt.close(fig)

In [None]:
# Hyper-parameters for the unconditional DCGAN run.
BATCH_SIZE = 16
EPOCHS= 30000
# NOTE(review): EPOCH_SAVE_FREQ is not referenced in the cells shown here; the
# SaveImagesCallback is created with a literal save_freq= 100 — confirm which is intended.
EPOCH_SAVE_FREQ = 30000
LATENT_DIMENSION = 128
IMAGE_SIZE = (128, 128, 3)

In [None]:
# classification= False makes DataLoader scale pixels to [-1, 1], the range of the generator's tanh output.
train_horses_cows_data = DataLoader(im_dir= "dataset/Newdata/Train", classification= False, output_dim= IMAGE_SIZE, batch_size= BATCH_SIZE)

In [None]:
# Snapshot callback: renders BATCH_SIZE generated images every 100 epochs under logdir.
logdir = 'gan-logdir/horses_cows/'
isaveimg = SaveImagesCallback(logdir= logdir, latent_dim= LATENT_DIMENSION, save_freq= 100, batch_size= BATCH_SIZE)

In [None]:
# Assemble the DCGAN from its two sub-models.
discriminator = Discriminator(name= "Discriminator")
generator = Generator(latent_dim= LATENT_DIMENSION, name= "Generator")
gan = GAN(discriminator=discriminator, generator=generator, latent_dim=LATENT_DIMENSION, batch_size= BATCH_SIZE)

# Separate Adam optimizers (learning rate 0.0001, beta_1 0.5) for the two sub-models.
optimizer_d = Adam(0.0001, 0.5)
optimizer_g = Adam(0.0001, 0.5)

gan.compile(d_optimizer= optimizer_d, g_optimizer= optimizer_g, loss_fn=BinaryCrossentropy())
gan.fit(train_horses_cows_data, epochs=EPOCHS, callbacks= [isaveimg])

## Conditional GAN

In [None]:
#tf.config.experimental_run_functions_eagerly(True)

In [None]:
class Generator(Model):
    """Conditional generator: noise is multiplied by a label embedding and decoded.

    ``call`` expects ``(noise, label)`` where ``noise`` has ``latent_dim`` features
    per sample and ``label`` holds one class id per sample.
    """

    def __init__(self, latent_dim, num_classes):
        super().__init__()
        self.latent_dim, self.num_classes = latent_dim, num_classes

    def build(self, input_shape):
        # The embedding is created once here so that its weights are tracked and
        # trained; creating it inside call() would yield fresh, untracked weights.
        self.label_embedding = Embedding(self.num_classes, self.latent_dim)
        self.flatten = Flatten()
        self.seq_model = Sequential(
            [
                Dense(16 * 16 * (self.latent_dim + self.num_classes)),
                LeakyReLU(alpha=0.2),
                Reshape((16, 16, (self.latent_dim + self.num_classes))),
                Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
                LeakyReLU(alpha=0.2),
                Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
                LeakyReLU(alpha=0.2),
                Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
                LeakyReLU(alpha=0.2),
                Conv2D(3, (7, 7), padding="same", activation="tanh"),
            ]
        )
        super().build(input_shape)

    def call(self, input_tensor):
        """Generate images for ``(noise, label)``."""
        noise, label = input_tensor
        label_embedding = self.flatten(self.label_embedding(label))
        # Element-wise product conditions the noise on the class.
        model_input = multiply([noise, label_embedding])

        return self.seq_model(model_input)

In [None]:
def build_generator(latent_dim, num_classes):
    """Build the functional conditional generator.

    Args:
        latent_dim: Number of noise features per sample.
        num_classes: Number of classes of the label embedding.

    Returns:
        A Keras ``Model`` taking ``[noise, label]`` and returning the images
        produced by the decoder's final 3-channel tanh convolution. The decoder
        summary is printed as a side effect.
    """
    model = Sequential(
        [
            # The shape must be a tuple; ``InputLayer(latent_dim,)`` passed a bare int.
            layers.InputLayer((latent_dim,)),
            Dense(16 * 16 * (latent_dim + num_classes)),
            LeakyReLU(alpha=0.2),
            Reshape((16, 16, (latent_dim + num_classes))),
            Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
            LeakyReLU(alpha=0.2),
            Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
            LeakyReLU(alpha=0.2),
            Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
            LeakyReLU(alpha=0.2),
            Conv2D(3, (7, 7), padding="same", activation="tanh"),
        ],
        name="generator",
    )

    model.summary()

    noise = Input(shape=(latent_dim,))
    label = Input(shape=(1,), dtype='int32')
    # Embed the class id into latent_dim features and gate the noise with it.
    label_embedding = Flatten()(Embedding(num_classes, latent_dim)(label))
    model_input = multiply([noise, label_embedding])
    img = model(model_input)

    return Model([noise, label], img)

In [None]:
class Discriminator(Model):
    """Conditional discriminator: the flattened image is multiplied by a label embedding.

    ``call`` expects ``(img, label)`` and returns one sigmoid score per sample.
    """

    def __init__(self, image_size, num_classes):
        super().__init__()
        self.img_shape, self.num_classes = image_size, num_classes

    def build(self, input_shape):
        # Created once here so the embedding weights are tracked and trained;
        # creating it inside call() would yield fresh, untracked weights.
        self.label_embedding = Embedding(self.num_classes, int(np.prod(self.img_shape)))
        self.flatten_label = Flatten()
        self.flatten_img = Flatten()
        self.seq_model = Sequential(
            [
                Dense(512),
                LeakyReLU(alpha=0.2),
                Dense(512),
                LeakyReLU(alpha=0.2),
                Dropout(0.4),
                Dense(512),
                LeakyReLU(alpha=0.2),
                Dropout(0.4),
                Dense(1, activation='sigmoid')
            ]
        )
        super().build(input_shape)

    def call(self, input_tensor):
        """Score ``(img, label)`` pairs; ``img`` must match the configured image shape."""
        img, label = input_tensor
        assert img.shape[1:] == self.img_shape, f"Received input shape: {img.shape[1:]} does not match image shape: {self.img_shape} in Discriminator Call"

        label_embedding = self.flatten_label(self.label_embedding(label))
        flat_img = self.flatten_img(img)
        # Element-wise product conditions the image features on the class.
        model_input = multiply([flat_img, label_embedding])

        return self.seq_model(model_input)

In [None]:
def build_discriminator(img_shape, num_classes):
    """Build the functional conditional discriminator.

    The image is flattened, multiplied element-wise by an embedding of its class
    label and scored by a dense network ending in a single sigmoid unit. The
    dense network's summary is printed as a side effect.

    Returns:
        A Keras ``Model`` taking ``[img, label]`` and returning the validity score.
    """
    scorer = Sequential()
    scorer.add(Dense(512, input_dim=np.prod(img_shape)))
    scorer.add(LeakyReLU(alpha=0.2))
    # Two further Dense/LeakyReLU stages, each followed by dropout.
    for _ in range(2):
        scorer.add(Dense(512))
        scorer.add(LeakyReLU(alpha=0.2))
        scorer.add(Dropout(0.4))
    scorer.add(Dense(1, activation='sigmoid'))
    scorer.summary()

    image_in = Input(shape=img_shape)
    label_in = Input(shape=(1,), dtype='int32')

    # One embedding vector per class, sized like the flattened image.
    embedded_label = Flatten()(Embedding(num_classes, np.prod(img_shape))(label_in))
    flattened_image = Flatten()(image_in)

    conditioned = multiply([flattened_image, embedded_label])

    return Model([image_in, label_in], scorer(conditioned))

In [None]:
class CondGAN():
    """Conditional GAN trained with ``train_on_batch`` on an in-memory dataset.

    Args:
        image_size: Image shape as (rows, cols, channels).
        num_classes: Number of classes used for conditioning.
        latent_dim: Size of the noise vector.
        generator: Accepted for interface compatibility but not used; the
            generator is created with ``build_generator``.
        discriminator: Accepted for interface compatibility but not used; the
            discriminator is created with ``build_discriminator``.

    After ``train`` the per-iteration losses are available in ``self.history``.
    """

    def __init__(self, image_size, num_classes, latent_dim, generator, discriminator):
        # Input shape
        self.img_rows, self.cols, self.channels = image_size
        self.img_shape = image_size
        self.num_classes = num_classes
        self.latent_dim = latent_dim

        # Per-iteration training curves, filled by train().
        self.history = {"d_loss": [], "g_loss": [], "acc": []}

        # Build and compile the discriminator. Each compiled model gets its own
        # optimizer instance so optimizer state is never shared between them.
        self.discriminator = build_discriminator(self.img_shape, self.num_classes)
        self.discriminator.compile(loss=['binary_crossentropy'],
            optimizer=Adam(0.0002, 0.5),
            metrics=['accuracy'])

        # Build the generator
        self.generator = build_generator(latent_dim= self.latent_dim, num_classes= self.num_classes)

        # The generator takes noise and the target label as input
        # and generates the corresponding class of that label
        noise = Input(shape=(self.latent_dim,))
        label = Input(shape=(1,), dtype='int32')
        img = self.generator([noise, label])

        # For the combined model we will only train the generator
        self.discriminator.trainable = False

        # The discriminator takes generated image as input and determines validity
        # and the label of that image
        valid = self.discriminator([img, label])

        # The combined model  (stacked generator and discriminator)
        # Trains generator to fool discriminator
        self.combined = Model([noise, label], valid)
        self.combined.compile(loss=['binary_crossentropy'],
            optimizer=Adam(0.0002, 0.5))

    def train(self, epochs, dataset, batch_size=128, sample_interval=50):
        """Alternate discriminator and generator updates for ``epochs`` iterations.

        Args:
            epochs: Number of training iterations (one random batch each).
            dataset: Iterable whose first item is ``(images, labels)`` holding the
                full training data, e.g. a DataLoader built with ``batch_size=-1``.
            batch_size: Number of samples drawn per iteration.
            sample_interval: Sample images are plotted every this many iterations.
        """
        D_loss, G_loss, acc = self.history["d_loss"], self.history["g_loss"], self.history["acc"]
        # Load iterator dataset assuming DataLoader with one batch (full data)
        X_train, y_train = next(iter(dataset))

        # Adversarial ground truths
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):

            # ---------------------
            #  Train Discriminator
            # ---------------------

            # Select a random batch of images (sampled with replacement)
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs, labels = X_train[idx], y_train[idx]

            # Sample noise as generator input
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            # Generate a batch of new images; predict_on_batch avoids the
            # per-call overhead of predict() inside this loop.
            gen_imgs = self.generator.predict_on_batch([noise, labels])

            # Train the discriminator
            d_loss_real = self.discriminator.train_on_batch([imgs, labels], valid)
            d_loss_fake = self.discriminator.train_on_batch([gen_imgs, labels], fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train Generator
            # ---------------------

            # Condition on labels
            sampled_labels = np.random.randint(0, self.num_classes, batch_size).reshape(-1, 1)

            # Train the generator
            g_loss = self.combined.train_on_batch([noise, sampled_labels], valid)

            # Record and print the progress
            D_loss.append(d_loss[0])
            G_loss.append(g_loss)
            acc.append(100*d_loss[1])
            print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

            # If at save interval => show generated image samples
            if epoch % sample_interval == 0:
                self.sample_images(epoch)

    def sample_images(self, epoch):
        """Plot a 2x4 grid of generated images: row 0 uses label 0, row 1 label 1.

        ``epoch`` is accepted for interface compatibility and is not used.
        """
        r, c = 2, 4
        titles = ("Cow", "Horse")
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        # Row i is conditioned on class id i: [0, 0, 0, 0, 1, 1, 1, 1]
        sampled_labels = np.repeat(np.arange(r), c).reshape(-1, 1)

        gen_imgs = self.generator.predict_on_batch([noise, sampled_labels])
        # Rescale the tanh output from [-1, 1] to displayable uint8 [0, 255]
        gen_imgs = np.asarray(gen_imgs * 127.5 + 127.5, dtype= np.uint8)

        fig, axs = plt.subplots(r, c)
        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(gen_imgs[i * c + j])
                axs[i, j].set_title(titles[i])
                axs[i, j].axis('off')
        plt.show()
        plt.close(fig)

In [None]:
# Hyper-parameters for the conditional GAN run.
IMAGE_SIZE = (128, 128, 3)
# -1 makes DataLoader serve the whole dataset as one batch.
BATCH_SIZE = -1
LATENT_DIM = 100
NUM_CLASSES = 2

In [None]:
# Full training set as a single batch, scaled to [-1, 1] (classification= False).
train_horses_cows = DataLoader(im_dir= "dataset/Newdata/Train", classification= False, output_dim= IMAGE_SIZE, batch_size= BATCH_SIZE)

In [None]:
# NOTE(review): CondGAN.__init__ builds its own models via build_generator /
# build_discriminator; the two instances created here are passed in but not used by it.
generator = Generator(LATENT_DIM, NUM_CLASSES)
discriminator = Discriminator(IMAGE_SIZE, NUM_CLASSES)

cgan = CondGAN(image_size= IMAGE_SIZE, num_classes= NUM_CLASSES, latent_dim= LATENT_DIM, generator= generator, discriminator= discriminator)
# Each "epoch" of CondGAN.train is one iteration on a random batch of batch_size samples.
cgan.train(epochs=40000, dataset= train_horses_cows, batch_size=4, sample_interval=100)

### Generating 250, 500 and 1000 Images of Cows and Horses

In [None]:
nb_of_images = [250, 500, 1000]

# For every requested set size, draw noise plus random class ids and decode them.
# The noise width and label range are taken from the trained cgan so they cannot
# drift from the model's configuration.
noise_labels = {}
generated_images = {}
for no_im in nb_of_images:
    noise_labels[f"noise_{no_im}"] = tf.random.normal(shape= (no_im, cgan.latent_dim))
    noise_labels[f"labels_{no_im}"] = np.random.randint(low= 0, high= cgan.num_classes, size= (no_im, 1))
    predicted = cgan.generator.predict([noise_labels[f"noise_{no_im}"], noise_labels[f"labels_{no_im}"]])
    # Map the generator's tanh output from [-1, 1] to uint8 [0, 255].
    generated_images[no_im] = np.asarray(predicted * 127.5 + 127.5, dtype= np.uint8)

generated_cows_horses_images_250 = generated_images[250]
generated_cows_horses_images_500 = generated_images[500]
generated_cows_horses_images_1000 = generated_images[1000]

In [None]:
# Pair each generated image set with its labels as (images float32, labels int32).
# The images keep their 0-255 value range; only the dtype changes.
generated_data_250 = (np.asarray(generated_cows_horses_images_250, dtype= np.float32), np.asarray(noise_labels["labels_250"], dtype= np.int32))
generated_data_500 = (np.asarray(generated_cows_horses_images_500, dtype= np.float32), np.asarray(noise_labels["labels_500"], dtype= np.int32))
generated_data_1000 = (np.asarray(generated_cows_horses_images_1000, dtype= np.float32), np.asarray(noise_labels["labels_1000"], dtype= np.int32))

In [None]:
# Split the 250 generated images by the class id they were conditioned on
# (1 is treated as horse, 0 as cow).
horse_batch = [data for i, data in enumerate(generated_data_250[0]) if generated_data_250[1][i] == 1]
cow_batch = [data for i, data in enumerate(generated_data_250[0]) if generated_data_250[1][i] == 0]

In [None]:
# Draw up to four distinct sample indices per class. The size is capped at the
# number of available samples so replace= False cannot raise, and every sample
# (including the last cow) is eligible.
random_horse = np.random.choice(len(horse_batch), size= min(4, len(horse_batch)), replace= False)
random_cow = np.random.choice(len(cow_batch), size= min(4, len(cow_batch)), replace= False)
random_horse, random_cow

In [None]:
# Gather the selected generated images and cast them back to uint8 for plotting.
cow_images, horse_images = [np.asarray(cow_batch[i], dtype= np.uint8) for i in random_cow], [np.asarray(horse_batch[i], dtype= np.uint8) for i in random_horse]

In [None]:
# Show the sampled generated images: cows are plotted first, then horses.
fig = plt.figure(figsize=(10., 10.), constrained_layout=True)
grid = ImageGrid(fig, 111,  # similar to subplot(111)
                 nrows_ncols=(2, 4),  # creates 2x4 grid of axes
                 axes_pad=0.1,  # pad between axes in inch.
                 )

for ax, im in zip(grid, [*cow_images, *horse_images]):
    # Iterating over the grid returns the Axes.
    ax.imshow(im)
print("\tGenerated Data")
plt.show()

### Appending To Original Dataset

In [None]:
# Reload the real training images as a tf.data dataset (batches of 16, 128x128)
# so they can be concatenated with the generated datasets.
# Note: this rebinds train_horses_cows_classification, previously a DataLoader.
train_horses_cows_classification = tf.keras.preprocessing.image_dataset_from_directory(directory= "dataset/Newdata/Train",
                                                                                      batch_size= 16, image_size= (128, 128))

In [None]:
# Display the dataset object to inspect its element spec.
train_horses_cows_classification

In [None]:
# Wrap each (images, labels) tuple in a tf.data dataset, batched like the real data (16 per batch).
generated_dataset_250 = tf.data.Dataset.from_tensor_slices(generated_data_250).batch(batch_size= 16)
generated_dataset_500 = tf.data.Dataset.from_tensor_slices(generated_data_500).batch(batch_size= 16)
generated_dataset_1000 = tf.data.Dataset.from_tensor_slices(generated_data_1000).batch(batch_size= 16)

In [None]:
def _to_classifier_format(images, labels):
    """Scale a batch from [0, 255] to [0, 1] and give the labels shape (batch, 1).

    Both the real and the generated batches carry 0-255 pixel values, whereas the
    validation DataLoader (classification= True) feeds the classifier [0, 1]
    images with (batch, 1) labels; this puts the training data in that format.
    """
    return tf.cast(images, tf.float32) / 255., tf.reshape(labels, (-1, 1))


# Real training batches in classifier format.
real_scaled = train_horses_cows_classification.map(_to_classifier_format)

# Real + generated batches, shuffled at batch level.
# NOTE(review): assumes the generated class ids (0 = cow, 1 = horse) agree with
# the ids image_dataset_from_directory assigns to the class folders — confirm.
combined_dataset_250 = real_scaled.concatenate(generated_dataset_250.map(_to_classifier_format)).shuffle(buffer_size= len(generated_dataset_250))
combined_dataset_500 = real_scaled.concatenate(generated_dataset_500.map(_to_classifier_format)).shuffle(buffer_size= len(generated_dataset_500))
combined_dataset_1000 = real_scaled.concatenate(generated_dataset_1000.map(_to_classifier_format)).shuffle(buffer_size= len(generated_dataset_1000))

val_horses_cows = DataLoader(im_dir= "dataset/Newdata/Test", classification= True, output_dim= (128, 128, 3))

### Classification Re-training

In [None]:
# Same hyper-parameters as the baseline classifier run.
EPOCHS= 50
BATCH_SIZE= 16
IMAGE_SIZE= (128, 128, 3)

In [None]:
# Fresh classifier with the same architecture, trained on real + generated data.
horses_cows_classifier_gen = Classifier(name= "Horses_vs_Cows_Classifier_with_Gen_Data")
horses_cows_classifier_gen.build(input_shape= (None, *IMAGE_SIZE))

horses_cows_classifier_gen.summary()
horses_cows_classifier_gen.compile(optimizer= Adam(0.001), loss= BinaryCrossentropy(), metrics= [BinaryAccuracy(), Recall(), Precision(), MSE])
# combined_dataset_250, combined_dataset_500, combined_dataset_1000
# NOTE(review): validation uses val_horses_cows_classification, which is scaled to
# [0, 1] — confirm the chosen combined dataset is scaled the same way.
model_data = horses_cows_classifier_gen.fit(combined_dataset_1000, validation_data= val_horses_cows_classification, epochs= EPOCHS, workers= 10)

### Re-training Results

In [None]:
# Plot the train/validation curves of the classifier trained with generated data.
show_metrics_classification(horses_cows_classifier_gen.history)

## Neural Style Transfer and Cycle GANs

In [None]:
# Load images one by one (unbatch) so they can be filtered per class below.
train_dataset = tf.keras.preprocessing.image_dataset_from_directory(directory= "dataset/Newdata/Train").unbatch()
# The test split is read from the Test directory (it previously pointed at Train,
# which made the "test" images identical to the training images).
test_dataset = tf.keras.preprocessing.image_dataset_from_directory(directory= "dataset/Newdata/Test").unbatch()

# Keep only the images of each class; label 0 is treated as cow and label 1 as horse.
train_cows = train_dataset.filter(lambda img, label: label == 0).map(lambda img, label: img)
train_horses = train_dataset.filter(lambda img, label: label == 1).map(lambda img, label: img)

test_cows = test_dataset.filter(lambda img, label: label == 0).map(lambda img, label: img)
test_horses = test_dataset.filter(lambda img, label: label == 1).map(lambda img, label: img)

In [None]:
# Define the standard image size.
orig_img_size = (256, 256)
# Size of the random crops to be used during training.
input_img_size = (200, 200, 3)
# Weights initializer for the layers.
kernel_init = RandomNormal(mean=0.0, stddev=0.02)
# Gamma initializer for instance normalization.
gamma_init = RandomNormal(mean=0.0, stddev=0.02)

buffer_size = 256
batch_size = 1


def normalize_img(img):
    """Cast `img` to float32 and rescale it as `img / 127.5 - 1`.

    For 8-bit pixel values in [0, 255] this yields values in [-1, 1].
    """
    as_float = tf.cast(img, dtype=tf.float32)
    return (as_float / 127.5) - 1.0


def preprocess_train_image(img):
    """Augment and normalize one training image.

    Steps, in order: random horizontal flip, resize to `orig_img_size`,
    random crop of shape `input_img_size`, then `normalize_img`.
    """
    flipped = tf.image.random_flip_left_right(img)
    resized = tf.image.resize(flipped, list(orig_img_size))
    # The crop shape is input_img_size, i.e. smaller than the resized image.
    cropped = tf.image.random_crop(resized, size=list(input_img_size))
    return normalize_img(cropped)


def preprocess_test_image(img):
    """Resize a test image to the model's spatial input size and normalize it.

    No random augmentation is applied to test images.
    """
    target_size = list(input_img_size[:2])
    return normalize_img(tf.image.resize(img, target_size))


In [None]:
# Apply the preprocessing operations to the training data.
#
# cache() must come BEFORE the random augmentation: a cache replays exactly the
# elements it stored, so caching after preprocess_train_image would freeze one
# random flip/crop per image and repeat it every epoch. Caching the loaded
# images and mapping afterwards keeps the loading cost saved while sampling a
# new flip/crop on every pass.
train_horses = (
    train_horses.cache()
    .map(preprocess_train_image)
    .shuffle(buffer_size)
    .batch(batch_size)
)
train_cows = (
    train_cows.cache()
    .map(preprocess_train_image)
    .shuffle(buffer_size)
    .batch(batch_size)
)

# Test-time pipeline: deterministic resize + normalization, cached after the
# map, then shuffled and batched with the same settings as the training data.
test_horses = test_horses.map(preprocess_test_image).cache()
test_horses = test_horses.shuffle(buffer_size).batch(batch_size)

test_cows = test_cows.map(preprocess_test_image).cache()
test_cows = test_cows.shuffle(buffer_size).batch(batch_size)

In [None]:
# Preview four training batches per domain: horses on the left, cows on the right.
_, ax = plt.subplots(4, 2, figsize=(10, 15))
preview_pairs = zip(train_horses.take(4), train_cows.take(4))
for row, (horse_batch, cow_batch) in enumerate(preview_pairs):
    # Take the first image of each batch and undo the `x / 127.5 - 1`
    # normalization so it can be displayed as an 8-bit image.
    horse = ((horse_batch[0] * 127.5) + 127.5).numpy().astype(np.uint8)
    cow = ((cow_batch[0] * 127.5) + 127.5).numpy().astype(np.uint8)
    ax[row, 0].imshow(horse)
    ax[row, 1].imshow(cow)
plt.show()


In [None]:
class ReflectionPadding2D(Layer):
    """Pad the two middle axes of a 4-D tensor by reflection.

    Args:
        padding: pair `(padding_width, padding_height)`. Axis 1 of the input
            is padded by `padding_height` on both sides and axis 2 by
            `padding_width` on both sides; the first and last axes are left
            untouched.
        **kwargs: forwarded to the base `Layer` constructor.
    """

    def __init__(self, padding=(1, 1), **kwargs):
        super(ReflectionPadding2D, self).__init__(**kwargs)
        # Stored as a tuple; a list (e.g. from a deserialized config) is
        # accepted as well.
        self.padding = tuple(padding)

    def call(self, input_tensor, mask=None):
        """Return `input_tensor` padded with `tf.pad(..., mode="REFLECT")`."""
        padding_width, padding_height = self.padding
        padding_tensor = [
            [0, 0],
            [padding_height, padding_height],
            [padding_width, padding_width],
            [0, 0],
        ]
        return tf.pad(input_tensor, padding_tensor, mode="REFLECT")

    def get_config(self):
        """Return the layer config including `padding`.

        Without this override the custom constructor argument is not part of
        the config, so models containing the layer cannot be serialized or
        cloned with the right padding.
        """
        config = super(ReflectionPadding2D, self).get_config()
        config.update({"padding": self.padding})
        return config


def residual_block(
    x,
    activation,
    kernel_initializer=kernel_init,
    kernel_size=(3, 3),
    strides=(1, 1),
    padding="valid",
    gamma_initializer=gamma_init,
    use_bias=False,
):
    """Apply a residual block and return `x + F(x)`.

    `F` consists of two stages of reflection padding -> Conv2D -> instance
    normalization. `activation` is applied after the first stage only. The
    convolutions keep the channel count of `x`.
    """
    channels = x.shape[-1]
    shortcut = x
    conv_options = dict(
        strides=strides,
        kernel_initializer=kernel_initializer,
        padding=padding,
        use_bias=use_bias,
    )

    out = shortcut
    for is_first_stage in (True, False):
        out = ReflectionPadding2D()(out)
        out = layers.Conv2D(channels, kernel_size, **conv_options)(out)
        out = tfa.layers.InstanceNormalization(
            gamma_initializer=gamma_initializer
        )(out)
        if is_first_stage:
            out = activation(out)

    # Skip connection: add the block input to the transformed features.
    return layers.add([shortcut, out])


def downsample(
    x,
    filters,
    activation,
    kernel_initializer=kernel_init,
    kernel_size=(3, 3),
    strides=(2, 2),
    padding="same",
    gamma_initializer=gamma_init,
    use_bias=False,
):
    """Conv2D -> instance normalization -> optional activation.

    The default stride is (2, 2). Pass a falsy `activation` (e.g. `None`) to
    skip the activation step.
    """
    conv = layers.Conv2D(
        filters,
        kernel_size,
        strides=strides,
        kernel_initializer=kernel_initializer,
        padding=padding,
        use_bias=use_bias,
    )
    norm = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)
    out = norm(conv(x))
    return activation(out) if activation else out


def upsample(
    x,
    filters,
    activation,
    kernel_size=(3, 3),
    strides=(2, 2),
    padding="same",
    kernel_initializer=kernel_init,
    gamma_initializer=gamma_init,
    use_bias=False,
):
    """Conv2DTranspose -> instance normalization -> optional activation.

    The default stride is (2, 2). Pass a falsy `activation` (e.g. `None`) to
    skip the activation step.
    """
    deconv = layers.Conv2DTranspose(
        filters,
        kernel_size,
        strides=strides,
        padding=padding,
        kernel_initializer=kernel_initializer,
        use_bias=use_bias,
    )
    norm = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)
    out = norm(deconv(x))
    return activation(out) if activation else out


In [None]:
def get_resnet_generator(
    filters=64,
    num_downsampling_blocks=2,
    num_residual_blocks=9,
    num_upsample_blocks=2,
    gamma_initializer=gamma_init,
    name=None,
):
    """Build the ResNet-style generator as a Keras functional model.

    Layout: 7x7 convolution stem -> `num_downsampling_blocks` downsampling
    blocks (filters doubled each time) -> `num_residual_blocks` residual
    blocks -> `num_upsample_blocks` upsampling blocks (filters halved each
    time) -> 7x7 convolution to 3 channels with tanh output.

    Args:
        filters: number of filters of the stem convolution.
        num_downsampling_blocks: number of `downsample` blocks.
        num_residual_blocks: number of `residual_block` blocks.
        num_upsample_blocks: number of `upsample` blocks.
        gamma_initializer: gamma initializer of the stem's instance
            normalization layer.
        name: model name; the input layer is named `name + "_img_input"`.
            May be `None`, in which case Keras picks the names.

    Returns:
        A `Model` mapping images of shape `input_img_size` to 3-channel
        tanh outputs.
    """
    # `name` defaults to None, so only derive the input name when one is given
    # (string concatenation with None would raise a TypeError).
    input_name = None if name is None else name + "_img_input"
    img_input = layers.Input(shape=input_img_size, name=input_name)
    x = ReflectionPadding2D(padding=(3, 3))(img_input)
    x = layers.Conv2D(filters, (7, 7), kernel_initializer=kernel_init, use_bias=False)(
        x
    )
    x = tfa.layers.InstanceNormalization(gamma_initializer=gamma_initializer)(x)
    x = layers.Activation("relu")(x)

    # Downsampling
    for _ in range(num_downsampling_blocks):
        filters *= 2
        x = downsample(x, filters=filters, activation=layers.Activation("relu"))

    # Residual blocks
    for _ in range(num_residual_blocks):
        x = residual_block(x, activation=layers.Activation("relu"))

    # Upsampling
    for _ in range(num_upsample_blocks):
        filters //= 2
        x = upsample(x, filters, activation=layers.Activation("relu"))

    # Final block: tanh matches the value range produced by normalize_img
    # for 8-bit inputs.
    x = ReflectionPadding2D(padding=(3, 3))(x)
    x = layers.Conv2D(3, (7, 7), padding="valid")(x)
    x = layers.Activation("tanh")(x)

    model = Model(img_input, x, name=name)
    return model


In [None]:
def get_discriminator(
    filters=64, kernel_initializer=kernel_init, num_downsampling=3, name=None
):
    """Build the convolutional discriminator as a Keras functional model.

    Layout: 4x4 stride-2 convolution + LeakyReLU(0.2), then `num_downsampling`
    `downsample` blocks with 4x4 kernels (filters doubled each time; stride 2
    for all but the last block, which uses stride 1), then a 4x4 convolution
    to a single output channel. No activation is applied to the output.

    Args:
        filters: number of filters of the first convolution.
        kernel_initializer: kernel initializer for every convolution,
            including those inside the downsample blocks.
        num_downsampling: number of `downsample` blocks after the first
            convolution.
        name: model name; the input layer is named `name + "_img_input"`.
            May be `None`, in which case Keras picks the names.

    Returns:
        A `Model` mapping images of shape `input_img_size` to a
        single-channel score map.
    """
    # `name` defaults to None, so only derive the input name when one is given.
    input_name = None if name is None else name + "_img_input"
    img_input = layers.Input(shape=input_img_size, name=input_name)
    x = layers.Conv2D(
        filters,
        (4, 4),
        strides=(2, 2),
        padding="same",
        kernel_initializer=kernel_initializer,
    )(img_input)
    x = layers.LeakyReLU(0.2)(x)

    num_filters = filters
    # Honour `num_downsampling` instead of a hard-coded 3; with the default
    # value this builds exactly the same layers as before.
    for block_index in range(num_downsampling):
        num_filters *= 2
        is_last_block = block_index == num_downsampling - 1
        x = downsample(
            x,
            filters=num_filters,
            activation=layers.LeakyReLU(0.2),
            kernel_initializer=kernel_initializer,
            kernel_size=(4, 4),
            strides=(1, 1) if is_last_block else (2, 2),
        )

    x = layers.Conv2D(
        1, (4, 4), strides=(1, 1), padding="same", kernel_initializer=kernel_initializer
    )(x)

    model = Model(inputs=img_input, outputs=x, name=name)
    return model


# One generator per translation direction, built in the order G, F.
gen_G, gen_F = (
    get_resnet_generator(name=model_name)
    for model_name in ("generator_G", "generator_F")
)

# One discriminator per image domain, built in the order X, Y.
disc_X, disc_Y = (
    get_discriminator(name=model_name)
    for model_name in ("discriminator_X", "discriminator_Y")
)


In [None]:
class CycleGan(Model):
    """CycleGAN wrapper that trains two generators and two discriminators.

    Domain X holds horse images and domain Y cow images. `gen_G` translates
    X -> Y, `gen_F` translates Y -> X, `disc_X` scores X images and `disc_Y`
    scores Y images.

    The base class is the `Model` imported at the top of the file
    (`tensorflow.keras.models.Model`), the same class the generator and
    discriminator builders use.

    Args:
        generator_G: model translating X -> Y.
        generator_F: model translating Y -> X.
        discriminator_X: model scoring images of domain X.
        discriminator_Y: model scoring images of domain Y.
        lambda_cycle: weight of the cycle-consistency loss.
        lambda_identity: additional weight of the identity loss, which is
            scaled by `lambda_cycle * lambda_identity`.
    """

    def __init__(
        self,
        generator_G,
        generator_F,
        discriminator_X,
        discriminator_Y,
        lambda_cycle=10.0,
        lambda_identity=0.5,
    ):
        super(CycleGan, self).__init__()
        self.gen_G = generator_G
        self.gen_F = generator_F
        self.disc_X = discriminator_X
        self.disc_Y = discriminator_Y
        self.lambda_cycle = lambda_cycle
        self.lambda_identity = lambda_identity

    def compile(
        self,
        gen_G_optimizer,
        gen_F_optimizer,
        disc_X_optimizer,
        disc_Y_optimizer,
        gen_loss_fn,
        disc_loss_fn,
    ):
        """Store one optimizer per sub-model and the adversarial loss functions.

        Args:
            gen_G_optimizer: optimizer for `gen_G`.
            gen_F_optimizer: optimizer for `gen_F`.
            disc_X_optimizer: optimizer for `disc_X`.
            disc_Y_optimizer: optimizer for `disc_Y`.
            gen_loss_fn: callable `(disc_fake) -> loss` for the generators.
            disc_loss_fn: callable `(disc_real, disc_fake) -> loss` for the
                discriminators.

        The cycle and identity losses are always mean absolute error.
        """
        super(CycleGan, self).compile()
        self.gen_G_optimizer = gen_G_optimizer
        self.gen_F_optimizer = gen_F_optimizer
        self.disc_X_optimizer = disc_X_optimizer
        self.disc_Y_optimizer = disc_Y_optimizer
        self.generator_loss_fn = gen_loss_fn
        self.discriminator_loss_fn = disc_loss_fn
        self.cycle_loss_fn = MeanAbsoluteError()
        self.identity_loss_fn = MeanAbsoluteError()

    def train_step(self, batch_data):
        """Run one optimization step on a `(real_x, real_y)` batch pair.

        Returns:
            Dict with the total generator losses (`G_loss`, `F_loss`) and the
            discriminator losses (`D_X_loss`, `D_Y_loss`).
        """
        # x is Horse and y is Cow
        real_x, real_y = batch_data

        # For CycleGAN, we need to calculate different
        # kinds of losses for the generators and discriminators.
        # We will perform the following steps here:
        #
        # 1. Pass real images through the generators and get the generated images
        # 2. Pass the generated images back to the generators to check if
        #    we can predict the original image from the generated image.
        # 3. Do an identity mapping of the real images using the generators.
        # 4. Pass the generated images in 1) to the corresponding discriminators.
        # 5. Calculate the generators total loss (adversarial + cycle + identity)
        # 6. Calculate the discriminators loss
        # 7. Update the weights of the generators
        # 8. Update the weights of the discriminators
        # 9. Return the losses in a dictionary

        # The tape is persistent because four separate gradient computations
        # (two generators, two discriminators) are taken from it below.
        with tf.GradientTape(persistent=True) as tape:
            # Horse to fake Cow
            fake_y = self.gen_G(real_x, training=True)
            # Cow to fake horse -> y2x
            fake_x = self.gen_F(real_y, training=True)

            # Cycle (Horse to fake Cow to fake horse): x -> y -> x
            cycled_x = self.gen_F(fake_y, training=True)
            # Cycle (Cow to fake horse to fake Cow) y -> x -> y
            cycled_y = self.gen_G(fake_x, training=True)

            # Identity mapping
            same_x = self.gen_F(real_x, training=True)
            same_y = self.gen_G(real_y, training=True)

            # Discriminator output
            disc_real_x = self.disc_X(real_x, training=True)
            disc_fake_x = self.disc_X(fake_x, training=True)

            disc_real_y = self.disc_Y(real_y, training=True)
            disc_fake_y = self.disc_Y(fake_y, training=True)

            # Generator adversarial loss
            gen_G_loss = self.generator_loss_fn(disc_fake_y)
            gen_F_loss = self.generator_loss_fn(disc_fake_x)

            # Generator cycle loss
            cycle_loss_G = self.cycle_loss_fn(real_y, cycled_y) * self.lambda_cycle
            cycle_loss_F = self.cycle_loss_fn(real_x, cycled_x) * self.lambda_cycle

            # Generator identity loss
            id_loss_G = (
                self.identity_loss_fn(real_y, same_y)
                * self.lambda_cycle
                * self.lambda_identity
            )
            id_loss_F = (
                self.identity_loss_fn(real_x, same_x)
                * self.lambda_cycle
                * self.lambda_identity
            )

            # Total generator loss
            total_loss_G = gen_G_loss + cycle_loss_G + id_loss_G
            total_loss_F = gen_F_loss + cycle_loss_F + id_loss_F

            # Discriminator loss
            disc_X_loss = self.discriminator_loss_fn(disc_real_x, disc_fake_x)
            disc_Y_loss = self.discriminator_loss_fn(disc_real_y, disc_fake_y)

        # Get the gradients for the generators
        grads_G = tape.gradient(total_loss_G, self.gen_G.trainable_variables)
        grads_F = tape.gradient(total_loss_F, self.gen_F.trainable_variables)

        # Get the gradients for the discriminators
        disc_X_grads = tape.gradient(disc_X_loss, self.disc_X.trainable_variables)
        disc_Y_grads = tape.gradient(disc_Y_loss, self.disc_Y.trainable_variables)

        # Update the weights of the generators
        self.gen_G_optimizer.apply_gradients(
            zip(grads_G, self.gen_G.trainable_variables)
        )
        self.gen_F_optimizer.apply_gradients(
            zip(grads_F, self.gen_F.trainable_variables)
        )

        # Update the weights of the discriminators
        self.disc_X_optimizer.apply_gradients(
            zip(disc_X_grads, self.disc_X.trainable_variables)
        )
        self.disc_Y_optimizer.apply_gradients(
            zip(disc_Y_grads, self.disc_Y.trainable_variables)
        )

        return {
            "G_loss": total_loss_G,
            "F_loss": total_loss_F,
            "D_X_loss": disc_X_loss,
            "D_Y_loss": disc_Y_loss,
        }

In [None]:
class GANMonitor(Callback):
    """A callback to generate and save images after each epoch.

    At the end of every epoch, `num_img` batches are taken from the
    module-level `test_horses` dataset, translated with the model's `gen_G`,
    shown next to their inputs and saved as
    `generated_img_{i}_{epoch}.png` (epoch numbered from 1).

    Args:
        num_img: number of test batches to translate per epoch.
    """

    def __init__(self, num_img=4):
        # Let the base Callback set up its own attributes first.
        super(GANMonitor, self).__init__()
        self.num_img = num_img

    def on_epoch_end(self, epoch, logs=None):
        """Plot and save `num_img` input/translation pairs."""
        if self.num_img < 1:
            # Nothing to draw; plt.subplots rejects a zero-row grid.
            return

        # Size the grid from num_img instead of a fixed 4 rows, so values
        # other than 4 neither overflow nor leave rows unused. squeeze=False
        # keeps `ax` two-dimensional even for a single row. For the default
        # num_img=4 the figure size is the original (12, 12).
        _, ax = plt.subplots(
            self.num_img, 2, figsize=(12, 3 * self.num_img), squeeze=False
        )
        for i, img in enumerate(test_horses.take(self.num_img)):
            prediction = self.model.gen_G(img, training=False)[0].numpy()
            # Map the generator's tanh output back to the 8-bit pixel range.
            prediction = (prediction * 127.5 + 127.5).astype(np.uint8)
            img = (img[0] * 127.5 + 127.5).numpy().astype(np.uint8)

            ax[i, 0].imshow(img)
            ax[i, 1].imshow(prediction)
            ax[i, 0].set_title("Input image")
            ax[i, 1].set_title("Translated image")
            ax[i, 0].axis("off")
            ax[i, 1].axis("off")

            prediction = tf.keras.preprocessing.image.array_to_img(prediction)
            prediction.save(
                "generated_img_{i}_{epoch}.png".format(i=i, epoch=epoch + 1)
            )
        plt.show()
        plt.close()

In [None]:
# Loss function for evaluating adversarial loss (mean squared error).
# Reached through `tf.keras`: `tf` is imported at the top of the file, whereas
# no bare `keras` name is bound by the file's import cell.
adv_loss_fn = tf.keras.losses.MeanSquaredError()

# Adversarial loss for the generators
def generator_loss_fn(fake):
    """Return `adv_loss_fn` between an all-ones target and `fake`.

    `fake` is the discriminator output for generated images; the generator
    is rewarded when that output is close to 1.
    """
    real_targets = tf.ones_like(fake)
    return adv_loss_fn(real_targets, fake)


# Adversarial loss for the discriminators
def discriminator_loss_fn(real, fake):
    """Return the mean of the real-image and fake-image adversarial losses.

    `real` is scored against an all-ones target and `fake` against an
    all-zeros target, both with `adv_loss_fn`.
    """
    loss_on_real = adv_loss_fn(tf.ones_like(real), real)
    loss_on_fake = adv_loss_fn(tf.zeros_like(fake), fake)
    combined = loss_on_real + loss_on_fake
    return combined * 0.5


# Assemble the CycleGAN from the two generators and the two discriminators.
cycle_gan_model = CycleGan(
    generator_G=gen_G,
    generator_F=gen_F,
    discriminator_X=disc_X,
    discriminator_Y=disc_Y,
)

# Every sub-model gets its own Adam optimizer with identical settings.
adam_settings = dict(learning_rate=2e-4, beta_1=0.5)
cycle_gan_model.compile(
    gen_G_optimizer=Adam(**adam_settings),
    gen_F_optimizer=Adam(**adam_settings),
    disc_X_optimizer=Adam(**adam_settings),
    disc_Y_optimizer=Adam(**adam_settings),
    gen_loss_fn=generator_loss_fn,
    disc_loss_fn=discriminator_loss_fn,
)

# Callback that plots and saves sample translations after each epoch.
plotter = GANMonitor()

# Each training element pairs one horse batch with one cow batch.
paired_batches = tf.data.Dataset.zip((train_horses, train_cows))
cycle_gan_model.fit(paired_batches, epochs=50, callbacks=[plotter])

In [None]:
# Translate four test horse batches with the trained generator G, show each
# input next to its translation and save the translations to disk.
_, ax = plt.subplots(4, 2, figsize=(10, 15))
for i, img in enumerate(test_horses.take(4)):
    prediction = cycle_gan_model.gen_G(img, training=False)[0].numpy()
    # Map the generator's tanh output back to the 8-bit pixel range.
    prediction = (prediction * 127.5 + 127.5).astype(np.uint8)
    img = (img[0] * 127.5 + 127.5).numpy().astype(np.uint8)

    ax[i, 0].imshow(img)
    ax[i, 1].imshow(prediction)
    # The input title was previously set twice; once is enough.
    ax[i, 0].set_title("Input image")
    ax[i, 1].set_title("Translated image")
    ax[i, 0].axis("off")
    ax[i, 1].axis("off")

    prediction = tf.keras.preprocessing.image.array_to_img(prediction)
    prediction.save("predicted_img_{i}.png".format(i=i))
plt.tight_layout()
plt.show()