In [None]:
import numpy as np 
import matplotlib.pyplot as plt
import tensorflow as tf 
import tensorflow_datasets as tfds
import cv2

# Model Definition

In [None]:
class InstanceNormalization(tf.keras.layers.Layer):
    """Instance normalization with a learned per-channel scale and offset.

    Reference: Ulyanov et al., "Instance Normalization: The Missing Ingredient
    for Fast Stylization", https://arxiv.org/abs/1607.08022

    Each sample is normalized with its own mean/variance taken over axes 1
    and 2 (presumably the spatial axes of an NHWC tensor -- TODO confirm for
    other layouts), then scaled by `gamma` and shifted by `beta`, both of
    shape `input_shape[-1:]`.

    Args:
        epsilon: small constant added to the variance before the inverse
            square root, to avoid division by zero.
        **kwargs: forwarded to `tf.keras.layers.Layer` (e.g. `name`, `dtype`).
    """

    def __init__(self, epsilon=0.001, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        # One scale/offset pair per channel (the last axis).
        self.gamma = self.add_weight(name='gamma', shape=input_shape[-1:], initializer='ones', trainable=True)
        self.beta = self.add_weight(name='beta', shape=input_shape[-1:], initializer='zeros', trainable=True)

    def call(self, x):
        # Per-sample statistics over axes 1 and 2; keepdims so they broadcast back onto x.
        mean, variance = tf.nn.moments(x, axes=[1, 2], keepdims=True)
        inverse_std_dev = tf.math.rsqrt(variance + self.epsilon)
        normalized = (x - mean) * inverse_std_dev
        return self.gamma * normalized + self.beta

    def get_config(self):
        """Return the layer config (including `epsilon`) so the layer can be re-created from it."""
        config = super().get_config()
        config.update({'epsilon': self.epsilon})
        return config
    

def get_resnet_generator(num_input_channels=3, num_output_channels=3, ngf=64, n_blocks=9, n_downsampling_steps=2):
    """Build a ResNet-style image-to-image generator as a `tf.keras.Sequential`.

    Layout:
    - an 8x8 stem convolution;
    - `n_downsampling_steps` stride-2 convolutions;
    - `n_blocks` residual blocks;
    - `n_downsampling_steps` stride-2 transposed convolutions;
    - an 8x8 output layer squashed by tanh into (-1, 1).

    Height and width are left unspecified. If they are not divisible by
    ``2 ** n_downsampling_steps``, the output will not have the input's size.

    Args:
        num_input_channels: channels of the input images.
        num_output_channels: channels of the generated images.
        ngf: filters of the stem convolution; doubled at every down-sampling step.
        n_blocks: number of residual blocks at the bottleneck.
        n_downsampling_steps: number of stride-2 down-sampling (and matching
            up-sampling) stages.

    Returns:
        An uncompiled `tf.keras.Sequential` model.
    """
    kernel_size = 4
    bottleneck_filters = ngf * (2 ** n_downsampling_steps)

    def _residual_block(inputs, n_filters, dropout=False):
        # conv -> BN -> ReLU -> [dropout] -> conv -> BN, plus the identity skip connection.
        x = inputs
        x = tf.keras.layers.Conv2D(filters=n_filters, kernel_size=kernel_size, strides=1, padding='SAME', activation=None)(x)
        x = tf.keras.layers.BatchNormalization()(x)
        x = tf.keras.layers.ReLU()(x)
        if dropout:
            x = tf.keras.layers.Dropout(0.5)(x)
        x = tf.keras.layers.Conv2D(filters=n_filters, kernel_size=kernel_size, strides=1, padding='SAME', activation=None)(x)
        x = tf.keras.layers.BatchNormalization()(x)
        x = x + inputs
        return x

    class ResidualBlock(tf.keras.layers.Layer):
        # Wraps one residual block as a layer. The inner functional model is built
        # from the incoming shape (minus the batch axis) the first time the layer is built.
        def __init__(self, n_filters):
            super().__init__()
            self.n_filters = n_filters

        def build(self, input_shape):
            inputs = tf.keras.Input(input_shape[1:])
            outputs = _residual_block(inputs, self.n_filters)
            self.block = tf.keras.Model(inputs=inputs, outputs=outputs)

        def call(self, inputs):
            return self.block(inputs)

    layers = [
        tf.keras.Input((None, None, num_input_channels)),
        tf.keras.layers.Conv2D(filters=ngf, kernel_size=8, padding='SAME', activation=None),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.ReLU(),
    ]

    # Down-sampling: each stride-2 conv halves H/W and doubles the filter count.
    for i in range(n_downsampling_steps):
        layers += [
            tf.keras.layers.Conv2D(filters=ngf * 2 * (2 ** i), kernel_size=kernel_size, strides=2, padding='SAME', activation=None),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.ReLU(),
        ]

    layers += [ResidualBlock(bottleneck_filters) for _ in range(n_blocks)]

    # Up-sampling mirrors the down-sampling path back down to `ngf` filters.
    for i in range(n_downsampling_steps):
        n_filters = ngf * (2 ** (n_downsampling_steps - i - 1))
        layers += [
            tf.keras.layers.Conv2DTranspose(filters=n_filters, kernel_size=kernel_size, strides=2, padding='SAME', activation=None),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.ReLU(),
        ]

    layers += [
        tf.keras.layers.Conv2DTranspose(filters=num_output_channels, kernel_size=8, padding='SAME', activation=None),
        # A named Activation layer instead of a Python-lambda Lambda layer keeps the model serializable.
        tf.keras.layers.Activation('tanh'),
    ]

    return tf.keras.Sequential(layers)

def get_discriminator(num_output_channels=3, ndf=64, n_layers=3, kernel_size=4):
    """Build a fully convolutional discriminator as a `tf.keras.Sequential`.

    The model maps an image batch to a spatial map of unnormalized scores.
    The last Conv2D has no activation, so the outputs are logits.

    Args:
        num_output_channels: channels of the images being judged. The name
            refers to the generator output this model scores.
        ndf: filters of the first convolution; doubled at every later stage.
        n_layers: number of stride-2 convolutions. The default of 3 reproduces
            the original architecture.
        kernel_size: kernel size of every convolution.

    Returns:
        An uncompiled `tf.keras.Sequential` model.
    """
    layers = [
        tf.keras.Input((None, None, num_output_channels)),
        # No normalization on the first layer.
        tf.keras.layers.Conv2D(filters=ndf, kernel_size=kernel_size, strides=2, padding='SAME', activation=None),
        tf.keras.layers.LeakyReLU(0.2),
    ]

    for i in range(1, n_layers):
        layers += [
            tf.keras.layers.Conv2D(filters=ndf * (2 ** i), kernel_size=kernel_size, strides=2, padding='SAME', activation=None),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.LeakyReLU(0.2),
        ]

    layers += [
        tf.keras.layers.Conv2D(filters=ndf * (2 ** n_layers), kernel_size=kernel_size, strides=1, padding='SAME', activation=None),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.LeakyReLU(0.2),
        # One logit per output position.
        tf.keras.layers.Conv2D(filters=1, kernel_size=kernel_size, strides=1, padding='SAME', activation=None),
    ]

    return tf.keras.Sequential(layers)

# Dataset Preparation

In [None]:
IMAGE_SIZE = 128  # side length (pixels) of the square random crops taken by resize_and_crop for training

def normalize(img):
    """Map pixel values in [0, 255] to float32 values in [-1, 1].

    Dividing by 127.5 (not 128) sends 0 to -1 and 255 to exactly 1. That
    matches the (-1, 1) range of the generator's tanh output and makes
    `norm_gen_output` ((x + 1) / 2) return values in [0, 1].
    """
    return tf.cast(img, tf.float32) / 127.5 - 1.0

def norm_gen_output(img):
    """Shift and halve model-range values ([-1, 1]) into [0, 1] for display with imshow."""
    shifted = img + 1
    return shifted / 2

def resize_and_crop(img, margin=30):
    """Random-crop jitter for one image: resize up, then cut a random IMAGE_SIZE x IMAGE_SIZE patch.

    Args:
        img: a single image tensor laid out as (height, width, channels).
        margin: extra pixels added to the height and width before cropping.
            30 was the original hard-coded value.

    Returns:
        A tensor of shape (IMAGE_SIZE, IMAGE_SIZE, channels).
    """
    img = tf.image.resize(img, (IMAGE_SIZE + margin, IMAGE_SIZE + margin))
    # Prefer the static channel count, which keeps the output shape fully known.
    # Fall back to the runtime value when the pipeline did not record it.
    channels = img.shape[-1]
    if channels is None:
        channels = tf.shape(img)[-1]
    return tf.image.random_crop(img, (IMAGE_SIZE, IMAGE_SIZE, channels))

def _images_only(ds):
    """Drop the labels from an `(image, label)` dataset and rescale the images with `normalize`.

    This is the shared first stage of `perp_train_ds` and `perp_test_ds`.
    """
    autotune = tf.data.experimental.AUTOTUNE
    ds = ds.map(lambda image, _label: image, num_parallel_calls=autotune)
    return ds.map(normalize, num_parallel_calls=autotune)


def perp_train_ds(ds, batch_size=1):
    """Build the training pipeline.

    Steps: normalize, cache, random crop + horizontal flip, shuffle, repeat
    forever, batch, prefetch.

    Args:
        ds: a dataset of `(image, label)` pairs (e.g. a tfds split loaded with
            `as_supervised=True`). Labels are discarded.
        batch_size: images per batch (default 1, as before).

    Returns:
        An infinite, batched, prefetched `tf.data.Dataset` of images rescaled
        by `normalize`.
    """
    autotune = tf.data.experimental.AUTOTUNE
    # Cache before the random ops so the crop/flip are re-drawn on every pass.
    ds = _images_only(ds).cache()
    ds = ds.map(resize_and_crop, num_parallel_calls=autotune)
    ds = ds.map(tf.image.random_flip_left_right, num_parallel_calls=autotune)
    ds = ds.shuffle(1000).repeat().batch(batch_size)
    # Overlap input preparation with the training step.
    return ds.prefetch(autotune)


def perp_test_ds(ds, batch_size=1):
    """Build the evaluation pipeline: like `perp_train_ds` but without crop/flip augmentation.

    Args:
        ds: a dataset of `(image, label)` pairs. Labels are discarded.
        batch_size: images per batch (default 1, as before).

    Returns:
        An infinite, shuffled, batched, prefetched `tf.data.Dataset` of
        full-size images rescaled by `normalize`.
    """
    autotune = tf.data.experimental.AUTOTUNE
    ds = _images_only(ds).cache()
    ds = ds.shuffle(1000).repeat().batch(batch_size)
    return ds.prefetch(autotune)

# Other tfds CycleGAN pairs that fit the same pipeline:
#   'cycle_gan/vangogh2photo', 'cycle_gan/monet2photo'
dataset, metadata = tfds.load('cycle_gan/horse2zebra', with_info=True, as_supervised=True)

# NOTE: the domains are crossed with the tfds split names.
# Domain A comes from the tfds '*B' splits and domain B from the '*A' splits.
train_A = perp_train_ds(dataset['trainB'])
train_B = perp_train_ds(dataset['trainA'])
test_A = perp_test_ds(dataset['testB'])
test_B = perp_test_ds(dataset['testA'])

# Endless Python iterators over the (repeated) training pipelines.
iter_A, iter_B = iter(train_A), iter(train_B)

In [None]:
# Preview the first image of one augmented training batch from each domain, side by side.
img_A = next(iter_A)[0]
img_B = next(iter_B)[0]
for panel, preview in enumerate((img_A, img_B), start=1):
    plt.subplot(1, 2, panel)
    plt.imshow(norm_gen_output(preview))
plt.show()

# Training

In [None]:
# Shared BCE loss. from_logits=True because the discriminators end in a Conv2D
# with no activation, so their outputs are raw logits.
binary_crossentropy_loss = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(disc_out_real, disc_out_generated):
    """Return the average BCE of a discriminator that should score real inputs 1 and generated inputs 0.

    Both arguments are discriminator logits; the loss is computed with the
    module-level `binary_crossentropy_loss`.
    """
    loss_on_real = binary_crossentropy_loss(tf.ones_like(disc_out_real), disc_out_real)
    loss_on_fake = binary_crossentropy_loss(tf.zeros_like(disc_out_generated), disc_out_generated)
    return 0.5 * (loss_on_real + loss_on_fake)

def generator_loss(disc_out_generated):
    """Adversarial generator loss: BCE pulling the discriminator's logits on generated images toward label 1."""
    target_labels = tf.ones_like(disc_out_generated)
    return binary_crossentropy_loss(target_labels, disc_out_generated)

def mean_absolute_error(a, b):
    """Scalar L1 distance: the mean of |a - b| over every element."""
    difference = a - b
    return tf.reduce_mean(tf.abs(difference))

In [None]:
# generator_A translates domain A -> B and generator_B translates B -> A.
# discriminator_A judges domain-B images and discriminator_B judges domain-A images.
generator_A, generator_B = get_resnet_generator(), get_resnet_generator()
discriminator_A, discriminator_B = get_discriminator(), get_discriminator()

# One Adam optimizer per network (lr 2e-4, beta_1 0.5), created in the order
# generator A, generator B, discriminator A, discriminator B.
(generator_optimizer_A, generator_optimizer_B,
 discriminator_optimizer_A, discriminator_optimizer_B) = [
    tf.keras.optimizers.Adam(2e-4, beta_1=0.5) for _ in range(4)
]

# History of per-step cycle losses, appended to by the training loop.
losses = []

In [None]:
# Warm-up: call every network once on a real batch before the training loop.
# Presumably this ensures all weights exist before `train` is traced -- confirm.
batch_A = next(iter_A)
batch_B = next(iter_B)
generator_A(batch_A)
generator_B(batch_B)

# Pair each discriminator with the domain it judges in `train`:
# discriminator_A scores domain-B images, discriminator_B scores domain-A images.
discriminator_A(batch_B)
discriminator_B(batch_A)
pass  # keeps the cell from echoing the last call's output

In [None]:
@tf.function
def train(batch_A, batch_B):
    """Run one CycleGAN optimization step on a pair of image batches.

    generator_A maps domain A -> B and generator_B maps B -> A.
    discriminator_A scores domain-B images (real batch_B vs. generated_B) and
    discriminator_B scores domain-A images (real batch_A vs. generated_A).
    Each of the four networks receives one optimizer update.

    Args:
        batch_A: batch of domain-A images.
        batch_B: batch of domain-B images.

    Returns:
        The scalar cycle-consistency loss: the mean of the two directions' L1
        reconstruction errors.
    """
    # NOTE(review): keep the statement order as is. Every call with training=True
    # runs the models' BatchNormalization layers in training mode, so calls are
    # order-sensitive.
    # persistent=True because four separate gradients are taken from this tape.
    with tf.GradientTape(persistent=True) as tape:
        # A -> B -> A cycle.
        generated_B = generator_A(batch_A, training=True)
        cycle_A = generator_B(generated_B, training=True)
    
        # B -> A -> B cycle.
        generated_A = generator_B(batch_B, training=True)
        cycle_B = generator_A(generated_A, training=True)
    

        # Mean of both directions' reconstruction L1 errors.
        cycle_loss = (mean_absolute_error(batch_A, cycle_A) + mean_absolute_error(batch_B, cycle_B)) / 2.0
    
        # discriminator_A: real domain-B images vs. generator_A's output.
        disc_out_real_A = discriminator_A(batch_B, training=True)
        disc_out_generated_A = discriminator_A(generated_B, training=True)
    
        # discriminator_B: real domain-A images vs. generator_B's output.
        disc_out_real_B = discriminator_B(batch_A, training=True)
        disc_out_generated_B = discriminator_B(generated_A, training=True)

        discriminator_loss_A = discriminator_loss(disc_out_real_A, disc_out_generated_A)
        discriminator_loss_B = discriminator_loss(disc_out_real_B, disc_out_generated_B)

        # Identity terms: a generator fed an image already in its target domain
        # is penalized (L1) for changing it.
        identity_loss_A = mean_absolute_error(batch_B, generator_A(batch_B, training=True))
        identity_loss_B = mean_absolute_error(batch_A, generator_B(batch_A, training=True))
        
        # Total generator objective: 10 x cycle + adversarial + 5 x identity.
        generator_loss_A = 10 * cycle_loss + generator_loss(disc_out_generated_A) + 5 * identity_loss_A
        generator_loss_B = 10 * cycle_loss + generator_loss(disc_out_generated_B) + 5 * identity_loss_B
        
        
        
    # Generator updates: each loss is differentiated only w.r.t. its own generator's variables.
    gen_grads_A = tape.gradient(generator_loss_A, generator_A.trainable_variables)
    gen_grads_B = tape.gradient(generator_loss_B, generator_B.trainable_variables)

    generator_optimizer_A.apply_gradients(zip(gen_grads_A, generator_A.trainable_variables))
    generator_optimizer_B.apply_gradients(zip(gen_grads_B, generator_B.trainable_variables))

    # Discriminator updates.
    disc_grads_A = tape.gradient(discriminator_loss_A, discriminator_A.trainable_variables)
    disc_grads_B = tape.gradient(discriminator_loss_B, discriminator_B.trainable_variables)

    discriminator_optimizer_A.apply_gradients(zip(disc_grads_A, discriminator_A.trainable_variables))
    discriminator_optimizer_B.apply_gradients(zip(disc_grads_B, discriminator_B.trainable_variables))
    return cycle_loss

In [None]:
%pylab inline
pylab.rcParams['figure.figsize'] = (20, 20)
for epoche in range(100):
    for step in range(20):
        batch_A = next(iter_A)
        batch_B = next(iter_B)
        cycle_loss = train(batch_A, batch_B)
        losses.append(cycle_loss)
        print(step)
    batch_A = next(iter_A)
    batch_B = next(iter_B)

    plt.subplot(1,2,1)
    plt.imshow(norm_gen_output(batch_A)[0])

    plt.subplot(1,2,2)
    plt.imshow(norm_gen_output(generator_A(batch_A, training=True))[0])

    plt.show()

    plt.subplot(1,2,1)
    plt.imshow(norm_gen_output(batch_A)[0])

    plt.subplot(1,2,2)
    plt.imshow(norm_gen_output(generator_B(generator_A(batch_A, training=True), training=True))[0])

    plt.show()
    
    plt.plot(losses)
    plt.show()

# Model Export

In [None]:
class GeneratorModel(tf.Module):
    """Wrap a Keras generator in a `tf.Module` with a fixed serving signature for `tf.saved_model.save`.

    The traced `__call__` accepts float32 tensors of shape [None, None, None, 3]
    (presumably NHWC image batches already rescaled like the training data -- confirm).
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    @tf.function(input_signature=[tf.TensorSpec([None, None, None, 3], tf.float32)])
    def __call__(self, x):
        # training=True matches how the generators are called during training and
        # monitoring. Presumably chosen so BatchNormalization normalizes with each
        # call's own batch statistics rather than its moving averages -- confirm.
        return self.model(x, training=True)

In [None]:
# Export generator_A (domain A -> B) with its fixed serving signature.
tf.saved_model.save(obj=GeneratorModel(generator_A), export_dir='./generator_A_2')

In [None]:
# Export generator_B (domain B -> A) with its fixed serving signature.
tf.saved_model.save(obj=GeneratorModel(generator_B), export_dir='./generator_B_2')

# Model Import

In [None]:
# Reload the exported generator_A; calling the result runs the traced `__call__`.
model = tf.saved_model.load(export_dir='./generator_A_2')

In [None]:
# One batch of augmented, normalized domain-A training images for trying the reloaded model.
img = next(iter_A)

In [None]:

# Run the reloaded generator on a 6x-upscaled copy of the sample. This checks it
# on a larger input than the IMAGE_SIZE crops it was trained on.
UPSCALE_FACTOR = 6
upscaled_input = cv2.resize(img[0].numpy(), (0, 0), fx=UPSCALE_FACTOR, fy=UPSCALE_FACTOR)
translated = model(upscaled_input[np.newaxis, ...])[0]

plt.subplot(1, 2, 1)
plt.imshow(norm_gen_output(translated))
plt.title(f'exported generator output ({UPSCALE_FACTOR}x upscaled input)')

plt.subplot(1, 2, 2)
plt.imshow(norm_gen_output(img[0]))
plt.title('input')
plt.show()