In this example, we use CycleGAN to translate horse images into zebra images (and vice versa), learning from two unpaired collections of horse and zebra photos.

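CycleGAN paper (Zhu et al., 2017, "Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks"): https://arxiv.org/abs/1703.10593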

Dataset from: http://efrosgans.eecs.berkeley.edu/cyclegan/datasets/horse2zebra.zip

CycleGAN architecture:

<img src="https://blog.jaysinha.me/content/images/size/w2000/2023/03/cyclegan.png" alt="cycleGAN Architecture"/>

In [None]:
# keras-contrib provides the InstanceNormalization layer imported in the next cell.
!pip install git+https://www.github.com/keras-team/keras-contrib.git

In [14]:
from keras.layers import Input, Conv2D, Conv2DTranspose, LeakyReLU, Concatenate, Dropout, BatchNormalization, Activation
from keras.initializers import RandomNormal
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization
from keras.models import Model
try:
    from keras.optimizers.legacy import Adam
except ImportError:
    from keras.optimizers import Adam
from keras.utils import plot_model, set_random_seed, load_img, img_to_array
import matplotlib.pyplot as plt
import numpy as np
from tqdm.notebook import tqdm
from os import listdir

User inputs (dataset location and random seed):

In [15]:
set_random_seed(1000)  # one call seeds Python's `random`, NumPy and the backend framework (e.g. TF)
# Root folder of the horse2zebra dataset: trainA/ holds the horse images,
# trainB/ the zebra images, and the cycleGAN_*_weights.h5 files live here too.
data_path = "/content/drive/MyDrive/Colab Notebooks/horse2zebra/"

In [16]:
def load_images(path, size=(256,256)):
    """Load every image file in a directory and scale its pixels to [-1, +1].

    Args:
        path: directory containing the images. A trailing path separator is
            optional because file names are joined with ``os.path.join``.
        size: target size passed to ``load_img`` (each image is resized to it).

    Returns:
        A NumPy array with one entry per image (stacked along a new leading
        axis), pixel values mapped from [0, 255] to [-1, +1].
    """
    import os

    imgs = []
    # Sort the listing: os.listdir order is filesystem-dependent, and a stable
    # order keeps the seeded random subset sampling reproducible.
    for filename in sorted(os.listdir(path)):
        file_path = os.path.join(path, filename)
        # Skip hidden entries (e.g. .DS_Store) and sub-directories
        # (e.g. .ipynb_checkpoints), which load_img cannot decode.
        if filename.startswith('.') or not os.path.isfile(file_path):
            continue
        im = img_to_array(load_img(file_path, target_size=size))
        imgs.append((im - 127.5) / 127.5)  # normalize pixel values to [-1,+1]
    return np.array(imgs)

# Domain A (trainA/) = horses, domain B (trainB/) = zebras; both scaled to [-1, +1].
horse_images_all = load_images(f"{data_path}trainA/")
zebra_images_all = load_images(f"{data_path}trainB/")


In [17]:
# Train on a random subset of each domain to keep epochs short. Indices are
# drawn WITHOUT replacement (unlike np.random.randint) so the subset contains
# no duplicate images; min() guards domains with fewer images than requested.
n_train_images = 500
horse_idx = np.random.choice(len(horse_images_all), size=min(n_train_images, len(horse_images_all)), replace=False)
zebra_idx = np.random.choice(len(zebra_images_all), size=min(n_train_images, len(zebra_images_all)), replace=False)
horse_images = horse_images_all[horse_idx]
zebra_images = zebra_images_all[zebra_idx]
training_dataset = [horse_images, zebra_images]  # define training dataset as [domainA, domainB]

In [None]:
def buildD(input_shape):
    """Build and compile the discriminator network.

    The network is a stack of 4x4 Conv2D -> [InstanceNormalization] ->
    LeakyReLU(0.2) stages followed by a 4x4 Conv2D with a sigmoid. Its output
    is a spatial map of per-patch real/fake probabilities, not a single
    scalar (every conv uses 'same' padding and nothing is flattened).

    Args:
        input_shape: shape of one input image, without the batch axis.

    Returns:
        A Keras Model compiled with binary cross-entropy, Adam(0.0002, 0.5) and
        loss weight 0.5 (presumably mirroring the paper's halving of the
        discriminator objective -- confirm if tuning).
    """
    weight_init = RandomNormal(stddev=0.02, seed=1000)
    img_in = Input(shape=input_shape)

    # (filters, stride, apply InstanceNormalization) for each conv stage.
    stages = [
        (64, 2, False),
        (128, 2, True),
        (256, 2, True),
        (512, 2, True),  # Not in the original paper; delete this entry to remove it.
        (512, 1, True),
    ]
    x = img_in
    for n_filters, stride, use_norm in stages:
        x = Conv2D(n_filters, (4, 4), strides=(stride, stride), padding='same', kernel_initializer=weight_init)(x)
        if use_norm:
            x = InstanceNormalization(axis=-1)(x)
        x = LeakyReLU(alpha=0.2)(x)

    patch_scores = Conv2D(1, (4, 4), strides=(1, 1), padding='same', activation='sigmoid', kernel_initializer=weight_init)(x)
    discriminator = Model(img_in, patch_scores)
    discriminator.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5), loss_weights=[0.5])
    return discriminator

horseD = buildD(horse_images.shape[1:])
zebraD = buildD(zebra_images.shape[1:])
# Resume from saved discriminator weights when present (train() writes them
# to this folder). Each model is tried separately so the message names exactly
# which file failed. OSError is caught because FileNotFoundError is a subclass
# of it and some h5py releases report a missing/unreadable file as a plain OSError.
for net, weights_file in (
    (horseD, 'cycleGAN_horseD_weights.h5'),
    (zebraD, 'cycleGAN_zebraD_weights.h5'),
):
    try:
        net.load_weights(data_path + weights_file)
    except OSError:
        print(f"weights NOT LOADED!!!! ({data_path + weights_file})")
horseD.summary(expand_nested=True)
plot_model(horseD, show_shapes=True, expand_nested=True, show_layer_activations=True, show_layer_names=False, dpi=70)

In [None]:
def buildG(input_shape):
    """Build the (uncompiled) generator network.

    Layout, using the paper's notation:
    c7s1-64 -> d128 -> d256 -> 9 x R256 -> u128 -> u64 -> c7s1-3 (tanh).

    NOTE: each "R256" block concatenates its input onto its output (it does not
    add them as a standard residual block would). The channel count therefore
    grows by 256 per block (256 -> 2560 after nine blocks).

    Args:
        input_shape: shape of one input image, without the batch axis.

    Returns:
        A Keras Model whose tanh output lies in [-1, 1].
    """
    weight_init = RandomNormal(stddev=0.02, seed=1000)

    def norm_relu(x):
        # InstanceNormalization followed by ReLU: the post-conv pattern used throughout.
        x = InstanceNormalization(axis=-1)(x)
        return Activation('relu')(x)

    img_in = Input(shape=input_shape)

    # c7s1-64: 7x7 conv, stride 1
    x = norm_relu(Conv2D(64, (7, 7), padding='same', kernel_initializer=weight_init)(img_in))

    # d128, d256: 3x3 stride-2 downsampling convs
    for n_filters in (128, 256):
        x = norm_relu(Conv2D(n_filters, (3, 3), strides=(2, 2), padding='same', kernel_initializer=weight_init)(x))

    # R256 x 9: conv-norm-relu, conv-norm, then concatenate with the block input
    for _ in range(9):
        block_in = x
        y = norm_relu(Conv2D(256, (3, 3), padding='same', kernel_initializer=weight_init)(block_in))
        y = Conv2D(256, (3, 3), padding='same', kernel_initializer=weight_init)(y)
        y = InstanceNormalization(axis=-1)(y)
        x = Concatenate()([y, block_in])

    # u128, u64: 3x3 stride-2 transposed (upsampling) convs
    for n_filters in (128, 64):
        x = norm_relu(Conv2DTranspose(n_filters, (3, 3), strides=(2, 2), padding='same', kernel_initializer=weight_init)(x))

    # c7s1-3: project back to 3 channels, squash to [-1, 1]
    x = Conv2D(3, (7, 7), padding='same', kernel_initializer=weight_init)(x)
    x = InstanceNormalization(axis=-1)(x)
    img_out = Activation('tanh')(x)

    return Model(img_in, img_out)

horseG = buildG(horse_images.shape[1:])
zebraG = buildG(zebra_images.shape[1:])
# Resume from saved generator weights when present (train() writes them
# to this folder). Each model is tried separately so the message names exactly
# which file failed. OSError is caught because FileNotFoundError is a subclass
# of it and some h5py releases report a missing/unreadable file as a plain OSError.
for net, weights_file in (
    (horseG, 'cycleGAN_horseG_weights.h5'),
    (zebraG, 'cycleGAN_zebraG_weights.h5'),
):
    try:
        net.load_weights(data_path + weights_file)
    except OSError:
        print(f"weights NOT LOADED!!!! ({data_path + weights_file})")
horseG.summary(expand_nested=True)
plot_model(horseG, show_shapes=True, expand_nested=True, show_layer_activations=True, show_layer_names=False, dpi=70)

In [None]:
# composite cycleGAN model used to update one generator
def buidGAN(g_model_1, d_model_1, g_model_2, input_shape):
    """Build and compile the composite model that trains generator ``g_model_1``.

    Inputs: [image from g_model_1's source domain, image from its target domain].
    Outputs, with their loss and loss weight:
        0. d_model_1(g_model_1(source))          -- adversarial,    BCE, 1
        1. g_model_1(target)                     -- identity,       MAE, 5
        2. g_model_2(g_model_1(source))          -- forward cycle,  MAE, 10
        3. g_model_1(g_model_2(target))          -- backward cycle, MAE, 10

    The ``trainable`` flags are set before compiling so that only g_model_1 is
    updated by this model; d_model_1 and g_model_2 are frozen.
    """
    # Freeze everything except the generator this composite is meant to update.
    g_model_1.trainable = True
    d_model_1.trainable = False
    g_model_2.trainable = False

    source_img = Input(shape=input_shape)
    fake_target = g_model_1(source_img)
    adversarial_out = d_model_1(fake_target)

    target_img = Input(shape=input_shape)
    identity_out = g_model_1(target_img)
    forward_cycle_out = g_model_2(fake_target)
    fake_source = g_model_2(target_img)
    backward_cycle_out = g_model_1(fake_source)

    composite = Model(
        [source_img, target_img],
        [adversarial_out, identity_out, forward_cycle_out, backward_cycle_out],
    )
    composite.compile(
        loss=['binary_crossentropy', 'mae', 'mae', 'mae'],
        loss_weights=[1, 5, 10, 10],
        optimizer=Adam(0.0002, 0.5),
    )
    return composite

# horseGAN updates horseG (fed horse images); zebraGAN updates zebraG (fed zebra images).
horseGAN = buidGAN(g_model_1=horseG, d_model_1=horseD, g_model_2=zebraG, input_shape=horse_images.shape[1:])
zebraGAN = buidGAN(g_model_1=zebraG, d_model_1=zebraD, g_model_2=horseG, input_shape=zebra_images.shape[1:])
horseGAN.summary(expand_nested=True)
plot_model(
    horseGAN,
    show_shapes=True,
    expand_nested=True,
    show_layer_activations=True,
    show_layer_names=False,
    dpi=70,
)

In [21]:
def generate_real_samples(data, n_samples, patch_shape):
    """Pick ``n_samples`` random entries of ``data`` and label them 'real'.

    Indices are drawn uniformly with replacement via ``np.random.randint``.

    Returns:
        (X, y): the selected entries, and an all-ones float array of shape
        (n_samples, patch_shape, patch_shape, 1).
    """
    picks = np.random.randint(0, data.shape[0], n_samples)
    real_labels = np.ones((n_samples, patch_shape, patch_shape, 1))
    return data[picks], real_labels

def generate_fake_samples(g_model, samples, patch_shape):
    """Run ``g_model`` on ``samples`` and label every output 'fake'.

    Returns:
        (X, y): the model's predictions (computed with ``verbose=0``), and an
        all-zeros float array of shape (len(X), patch_shape, patch_shape, 1).
    """
    fakes = g_model.predict(samples, verbose=0)
    label_shape = (len(fakes), patch_shape, patch_shape, 1)
    return fakes, np.zeros(label_shape)

# generate samples and plot them
def plot_sample_images(g_model_1, g_model_2, dataset):
    """Plot translation, cycle-reconstruction and identity results for one random image per domain.

    Args:
        g_model_1: generator mapping domain1 -> domain2.
        g_model_2: generator mapping domain2 -> domain1.
        dataset: [domain1_images, domain2_images] with pixel values in [-1, 1].

    Figure layout (4 rows x 3 columns, unused cells left blank):
        row 0: domain1 image -> g_model_1 output -> g_model_2 reconstruction
        row 1: domain2 image -> g_model_2 output -> g_model_1 reconstruction
        row 2: domain2 image -> g_model_1 output (identity test)
        row 3: domain1 image -> g_model_2 output (identity test)
    """
    def to_unit_range(images):
        # Map pixel values from the training range [-1, 1] to [0, 1] for imshow.
        return (images + 1) / 2.0

    # select one random image from domain1 and one from domain2
    input_images1, _ = generate_real_samples(dataset[0], 1, 1)
    input_images2, _ = generate_real_samples(dataset[1], 1, 1)
    # domain1 -> domain2 -> domain1: the reconstruction should resemble the input
    g12_1, _ = generate_fake_samples(g_model_1, input_images1, 1)
    g21_1, _ = generate_fake_samples(g_model_2, g12_1, 1)
    # domain2 -> domain1 -> domain2: the reconstruction should resemble the input
    g21_2, _ = generate_fake_samples(g_model_2, input_images2, 1)
    g12_2, _ = generate_fake_samples(g_model_1, g21_2, 1)
    # identity test: a generator fed an image already in its output domain
    # should return it (almost) unchanged
    g11, _ = generate_fake_samples(g_model_1, input_images2, 1)
    g22, _ = generate_fake_samples(g_model_2, input_images1, 1)

    # (image batch, title) per cell; every array is rescaled exactly once below.
    rows = [
        [(input_images1, "domain1 ---->"), (g12_1, "generator1 ---->"), (g21_1, "generator2 (reconstructed)")],
        [(input_images2, "domain2 ---->"), (g21_2, "generator2 ---->"), (g12_2, "generator1 (reconstructed)")],
        [(input_images2, "domain2 ---->"), (g11, "generator1 (identity test)")],
        [(input_images1, "domain1 ---->"), (g22, "generator2 (identity test)")],
    ]

    fig, axs = plt.subplots(4, 3, figsize=(10,10))
    for i, row in enumerate(rows):
        for j, (images, title) in enumerate(row):
            axs[i, j].imshow(to_unit_range(images[0]))
            axs[i, j].set_title(title, size=10)
    for ax in axs.flat:
        ax.axis('off')
    plt.show()

In [22]:
# train the cycleGAN model (both generator/discriminator pairs together)
def train(g_model_1, d_model_1, g_model_2, d_model_2, gan_model_1, gan_model_2, dataset, n_epochs=100, n_batch=1, plot_interval=10, save_dir=None):
    """Jointly train the two cycleGAN generators and discriminators.

    Args:
        g_model_1: generator mapping domain1 (dataset[0]) -> domain2.
        d_model_1: discriminator trained on real domain2 images vs. g_model_1 outputs.
        g_model_2: generator mapping domain2 (dataset[1]) -> domain1.
        d_model_2: discriminator trained on real domain1 images vs. g_model_2 outputs.
        gan_model_1: composite model updating g_model_1; its targets are
            [adversarial 'real' labels, identity, forward cycle, backward cycle].
        gan_model_2: the same composite model for g_model_2.
        dataset: [domain1_images, domain2_images].
        n_epochs: number of epochs; one epoch is len(dataset[0]) // n_batch steps.
        n_batch: number of images drawn per domain at every step.
        plot_interval: plot samples and save all four models' weights at the end
            of every ``plot_interval`` epochs, and always after the final step.
        save_dir: directory for the .h5 weight files; defaults to ``data_path``.
    """
    import os

    if save_dir is None:
        save_dir = data_path
    # side length of the discriminator's square output patch map
    n_patch = d_model_1.output_shape[1]
    # batches per epoch (samples are drawn randomly, so this only sets the epoch length)
    bat_per_epo = len(dataset[0]) // n_batch
    n_steps = bat_per_epo * n_epochs
    # steps between checkpoints; max() avoids a zero modulus
    steps_per_checkpoint = max(1, bat_per_epo * plot_interval)
    checkpoints = (
        (d_model_1, 'cycleGAN_horseD_weights.h5'),
        (g_model_1, 'cycleGAN_horseG_weights.h5'),
        (d_model_2, 'cycleGAN_zebraD_weights.h5'),
        (g_model_2, 'cycleGAN_zebraG_weights.h5'),
    )
    for i in tqdm(range(n_steps)):
        # select a batch of real samples from domain1 and domain2
        data1, y_real_1 = generate_real_samples(dataset[0], n_batch, n_patch)
        data2, y_real_2 = generate_real_samples(dataset[1], n_batch, n_patch)
        # translate them: g1 are fake domain2 images, g2 fake domain1 images
        g1, y_fake_1 = generate_fake_samples(g_model_1, data1, n_patch)
        g2, y_fake_2 = generate_fake_samples(g_model_2, data2, n_patch)
        # update each discriminator on real images of the domain it judges...
        d_model_1.train_on_batch(data2, y_real_2)
        d_model_2.train_on_batch(data1, y_real_1)
        # ...and on the matching generated images
        d_model_1.train_on_batch(g1, y_fake_1)
        d_model_2.train_on_batch(g2, y_fake_2)
        # update the generators through their composite models
        gan_model_1.train_on_batch([data1, data2], [y_real_1, data2, data1, data2])
        gan_model_2.train_on_batch([data2, data1], [y_real_2, data1, data2, data1])
        step = i + 1
        # checkpoint at the end of every `plot_interval` epochs and after the final step
        if step % steps_per_checkpoint == 0 or step == n_steps:
            plot_sample_images(g_model_1, g_model_2, dataset)
            # save the model weights
            for net, filename in checkpoints:
                net.save_weights(os.path.join(save_dir, filename))

In [None]:
train(
    horseG, horseD,            # generator 1 and its discriminator
    zebraG, zebraD,            # generator 2 and its discriminator
    horseGAN, zebraGAN,        # composite models updating each generator
    training_dataset,
    n_epochs=10,
    n_batch=1,
    plot_interval=1,           # plot/save interval, measured in epochs
)