In [None]:
#!pip install git+https://www.github.com/keras-team/keras-contrib.git

In [24]:
from random import random
from numpy import load
from numpy import zeros
from numpy import ones
from numpy import asarray
from numpy.random import randint
from tensorflow.keras.optimizers import Adam
from keras.initializers import RandomNormal
from keras.models import Model
from keras.models import Input
from keras.layers import Conv2D
from keras.layers import Conv2DTranspose
from keras.layers import LeakyReLU
from keras.layers import Activation
from keras.layers import Concatenate
from keras_contrib.layers.normalization.instancenormalization import InstanceNormalization
from matplotlib import pyplot
import tensorflow as tf
import keras
import os
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output
from zipfile import ZipFile
import numpy as np
from PIL import Image
from sklearn.model_selection import train_test_split


In [25]:
AUTOTUNE = tf.data.AUTOTUNE

# --- Training schedule -------------------------------------------------------
EPOCHS = 20
# NOTE(review): LAMBDA and loss_obj are not referenced by any visible cell;
# train_step defines its own lambda_cycle / lambda_identity weights.
LAMBDA = 10
loss_obj = keras.losses.BinaryCrossentropy(from_logits=True)

# --- Weight initializers (zero-mean normal, std 0.02) -------------------------
# NOTE(review): not referenced by the model builders, which create their own
# RandomNormal(stddev=0.02) instances.
kernel_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)
gamma_init = keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

# --- Image geometry and input-pipeline sizes ----------------------------------
orig_img_size = (286, 286)  # NOTE(review): currently unused
input_img_size = (256, 256, 3)  # (height, width, channels) fed to the networks
buffer_size = 256
batch_size = 1

In [26]:
def normalize_img(img):
    """Cast `img` to float32 and rescale it linearly so that 0 -> -1.0 and 255 -> 1.0.

    Assumes pixel values are in the usual [0, 255] range, which gives outputs in
    [-1, 1] (the range of the generators' tanh output).
    """
    as_float = tf.cast(img, dtype=tf.float32)
    return as_float / 127.5 - 1.0


def preprocess_train_image(img):
    """Training-time preprocessing: random horizontal flip, then scale to [-1, 1].

    No resizing is done here; the image is expected to already have the network
    input size.
    """
    flipped = tf.image.random_flip_left_right(img)
    return normalize_img(flipped)


def preprocess_test_image(img):
    """Deterministic test-time preprocessing: resize to the model input size, then scale to [-1, 1]."""
    height, width = input_img_size[:2]
    resized = tf.image.resize(img, [height, width])
    return normalize_img(resized)

                    


In [27]:
# Unlabelled image folders for the two domains. The loader yields batches of
# size 1 resized to the network input size; unbatch() turns them back into
# single images. (An alternative photo source used earlier was
# "images/photos_activities/".)
paint = tf.keras.utils.image_dataset_from_directory(
    "images/afremov/", label_mode=None, batch_size=1, image_size=input_img_size[:2])
photo = tf.keras.utils.image_dataset_from_directory(
    "images/landscape/", label_mode=None, batch_size=1, image_size=input_img_size[:2])
photo_test = tf.keras.utils.image_dataset_from_directory(
    "images/test/", label_mode=None, batch_size=1, image_size=input_img_size[:2])


def _train_pipeline(ds):
    """Cache the decoded images, then augment and reshuffle them on every epoch."""
    return (
        ds.unbatch()
        # Cache *before* the random map: caching after it would freeze the
        # first epoch's random flips (and order) for the rest of training.
        .cache()
        .map(preprocess_train_image, num_parallel_calls=AUTOTUNE)
        .shuffle(buffer_size)
        .batch(batch_size)
    )


photo = _train_pipeline(photo)
paint = _train_pipeline(paint)

# Test images use the deterministic test preprocessing (no random flip). No
# cache: the set is tiny and is only ever partially read (take(1) / next()),
# which would make tf.data discard a partial cache with a warning.
photo_test = (
    photo_test.unbatch()
    .map(preprocess_test_image, num_parallel_calls=AUTOTUNE)
    .batch(1)
)
sample_photo = next(iter(photo_test))


Found 61 files belonging to 1 classes.
Found 266 files belonging to 1 classes.
Found 1 files belonging to 1 classes.


In [28]:
def define_discriminator(image_shape):
    """Build and compile a PatchGAN-style discriminator.

    Four 4x4 stride-2 convolutions (64, 128, 256, 512 filters) reduce the spatial
    size by 16x. They are followed by a 4x4 stride-1 conv with 512 filters and a
    final 4x4 conv with 1 filter, which outputs one unbounded score per patch
    (no activation). Every conv except the first is followed by
    InstanceNormalization. All hidden convs use LeakyReLU(0.2).

    Args:
        image_shape: shape of one input image, e.g. (256, 256, 3).

    Returns:
        A keras Model compiled with MSE loss (weight 0.5) and Adam(2e-4, beta_1=0.5).
        NOTE(review): the custom train_step in this notebook uses its own
        optimizers and loss functions, so this compile configuration appears unused.
    """
    init = RandomNormal(stddev=0.02)
    in_image = Input(shape=image_shape)
    d = in_image
    # C64, C128, C256, C512: downsampling blocks; the first has no normalization.
    for i, n_filters in enumerate((64, 128, 256, 512)):
        d = Conv2D(n_filters, (4, 4), strides=(2, 2), padding='same', kernel_initializer=init)(d)
        if i > 0:
            d = InstanceNormalization(axis=-1)(d)
        d = LeakyReLU(alpha=0.2)(d)
    # Second-to-last layer: stride 1, so the spatial size is kept.
    d = Conv2D(512, (4, 4), padding='same', kernel_initializer=init)(d)
    d = InstanceNormalization(axis=-1)(d)
    d = LeakyReLU(alpha=0.2)(d)
    # Patch output: one score per spatial location.
    patch_out = Conv2D(1, (4, 4), padding='same', kernel_initializer=init)(d)
    model = Model(in_image, patch_out)
    # `learning_rate` instead of the deprecated `lr`, which newer Keras rejects.
    model.compile(loss='mse', optimizer=Adam(learning_rate=0.0002, beta_1=0.5), loss_weights=[0.5])
    return model

In [29]:
def resnet_block(n_filters, input_layer, merge='concat'):
    """Two 3x3 conv + instance-norm layers joined back to the input by a skip connection.

    Args:
        n_filters: number of filters in each of the two 3x3 convolutions.
        input_layer: Keras tensor the block is applied to.
        merge: how the skip connection is combined with the conv branch:
            'concat' (default; what define_generator currently gets):
                channel-wise Concatenate. The output has
                n_filters + input channels, so stacked blocks keep widening the
                feature map.
            'add': element-wise Add, i.e. the standard identity-shortcut
                residual connection. The output keeps the input's channel count,
                and the input must have exactly n_filters channels.

    Returns:
        The merged Keras tensor.

    Raises:
        ValueError: if `merge` is neither 'concat' nor 'add'.
    """
    # Validate before creating any layers so a bad argument builds nothing.
    if merge not in ('concat', 'add'):
        raise ValueError("merge must be 'concat' or 'add', got {!r}".format(merge))
    init = RandomNormal(stddev=0.02)
    # First conv + norm + ReLU.
    g = Conv2D(n_filters, (3, 3), padding='same', kernel_initializer=init)(input_layer)
    g = InstanceNormalization(axis=-1)(g)
    g = Activation('relu')(g)
    # Second conv + norm, with no activation before the merge.
    g = Conv2D(n_filters, (3, 3), padding='same', kernel_initializer=init)(g)
    g = InstanceNormalization(axis=-1)(g)
    if merge == 'add':
        from keras.layers import Add  # only needed for the residual variant
        return Add()([g, input_layer])
    return Concatenate()([g, input_layer])

def define_generator(image_shape, n_resnet=9):
    """Build the encoder / resnet-stack / decoder generator (uncompiled).

    Layout: c7s1-64 -> d128 -> d256 -> n_resnet x resnet_block(256) -> u128 -> u64 -> c7s1-3.
    Every convolution is followed by InstanceNormalization. Activations are ReLU
    everywhere except the output, which uses tanh, so predictions lie in (-1, 1).

    NOTE(review): resnet_block joins its skip with Concatenate, so each block
    adds 256 channels (ending at 256 * (n_resnet + 1) before u128) rather than
    keeping 256 as an additive residual would.
    NOTE(review): the output conv is instance-normalized before tanh, which
    fixes each output channel's per-image mean/variance; confirm this is intended.

    Args:
        image_shape: shape of one input image, e.g. (256, 256, 3).
        n_resnet: number of resnet_block applications between encoder and decoder.

    Returns:
        A keras Model mapping an image to a 3-channel image of the same spatial
        size (for heights/widths divisible by 4).
    """
    weight_init = RandomNormal(stddev=0.02)

    def conv_norm_act(x, filters, kernel, stride=1, conv_layer=Conv2D, activation='relu'):
        # 'same'-padded conv (or transposed conv) -> instance norm -> activation.
        x = conv_layer(filters, kernel, strides=(stride, stride), padding='same',
                       kernel_initializer=weight_init)(x)
        x = InstanceNormalization(axis=-1)(x)
        return Activation(activation)(x)

    image_in = Input(shape=image_shape)
    x = conv_norm_act(image_in, 64, (7, 7))                      # c7s1-64
    for filters in (128, 256):                                    # d128, d256
        x = conv_norm_act(x, filters, (3, 3), stride=2)
    for _ in range(n_resnet):                                     # R256 x n_resnet
        x = resnet_block(256, x)
    for filters in (128, 64):                                     # u128, u64
        x = conv_norm_act(x, filters, (3, 3), stride=2, conv_layer=Conv2DTranspose)
    image_out = conv_norm_act(x, 3, (7, 7), activation='tanh')    # c7s1-3
    return Model(image_in, image_out)

In [30]:
image_shape = (256, 256, 3)

# Networks: two generators (9 resnet blocks each) and two discriminators.
photo_generator = define_generator(image_shape, n_resnet=9)
style_generator = define_generator(image_shape, n_resnet=9)
style_discriminator = define_discriminator(image_shape)
photo_discriminator = define_discriminator(image_shape)

# One Adam optimizer per network, all with the same hyper-parameters.
_adam_kwargs = dict(learning_rate=2e-4, beta_1=0.5)
photo_generator_optimizer = tf.keras.optimizers.Adam(**_adam_kwargs)
style_generator_optimizer = tf.keras.optimizers.Adam(**_adam_kwargs)
style_discriminator_optimizer = tf.keras.optimizers.Adam(**_adam_kwargs)
photo_discriminator_optimizer = tf.keras.optimizers.Adam(**_adam_kwargs)



In [31]:
# Adversarial objective is least-squares (MSE against 1/0 targets) on the
# discriminators' raw patch scores.
adv_loss_fn = keras.losses.MeanSquaredError()


def generator_loss_fn(fake):
    """Generator adversarial loss: push the discriminator's scores on fakes toward 1."""
    target = tf.ones_like(fake)
    return adv_loss_fn(target, fake)


def discriminator_loss_fn(real, fake):
    """Discriminator loss: mean of MSE(real scores, 1) and MSE(fake scores, 0)."""
    loss_on_real = adv_loss_fn(tf.ones_like(real), real)
    loss_on_fake = adv_loss_fn(tf.zeros_like(fake), fake)
    return 0.5 * (loss_on_real + loss_on_fake)


# L1 losses for cycle-consistency and identity terms.
cycle_loss_fn = keras.losses.MeanAbsoluteError()
identity_loss_fn = keras.losses.MeanAbsoluteError()

In [32]:
# CycleGAN loss weights. These must be plain floats: a trailing comma would turn
# them into 1-tuples and silently make every weighted loss a shape-(1,) tensor.
# tf.function captures these Python values when train_step is traced, so
# changing them later has no effect until the function is retraced.
lambda_cycle = 10.0
lambda_identity = 0.5


@tf.function
def train_step(style, photo):
    """Run one CycleGAN optimization step on a pair of unpaired batches.

    Naming used below:
      * photo_generator maps `style`-domain images to the `photo` domain and is
        judged by photo_discriminator;
      * style_generator maps `photo`-domain images to the `style` domain and is
        judged by style_discriminator.
    NOTE(review): the training loop in this notebook calls
    train_step(<landscape photo batch>, <painting batch>) positionally. So
    `style` actually receives photos and `photo` receives paintings, which means
    photo_generator learns photo -> painting. Confirm the naming is intended.

    Args:
        style: batch of images from the first domain (presumably scaled to [-1, 1]).
        photo: batch of images from the second domain (same assumption).

    Returns:
        dict of scalar loss tensors ('gen_G', 'gen_F', 'disc_X', 'disc_Y') for
        logging; callers may ignore it.
    """
    # Persistent tape: four separate gradients are taken from the same forward pass.
    with tf.GradientTape(persistent=True) as tape:
        # Cross-domain translations.
        photo_generated = photo_generator(style, training=True)
        style_generated = style_generator(photo, training=True)

        # Round trips back to the source domain (cycle consistency).
        cycled_style = style_generator(photo_generated, training=True)
        cycled_photo = photo_generator(style_generated, training=True)

        # Identity mappings: each generator fed an image already in its target domain.
        style_identity = style_generator(style, training=True)
        photo_identity = photo_generator(photo, training=True)

        disc_real_x = style_discriminator(style, training=True)
        disc_fake_x = style_discriminator(style_generated, training=True)

        disc_real_y = photo_discriminator(photo, training=True)
        disc_fake_y = photo_discriminator(photo_generated, training=True)

        gen_G_loss = generator_loss_fn(disc_fake_y)
        gen_F_loss = generator_loss_fn(disc_fake_x)

        cycle_loss_G = cycle_loss_fn(photo, cycled_photo) * lambda_cycle
        cycle_loss_F = cycle_loss_fn(style, cycled_style) * lambda_cycle

        id_loss_G = identity_loss_fn(photo, photo_identity) * lambda_cycle * lambda_identity
        id_loss_F = identity_loss_fn(style, style_identity) * lambda_cycle * lambda_identity

        total_loss_G = gen_G_loss + cycle_loss_G + id_loss_G
        total_loss_F = gen_F_loss + cycle_loss_F + id_loss_F

        disc_X_loss = discriminator_loss_fn(disc_real_x, disc_fake_x)
        disc_Y_loss = discriminator_loss_fn(disc_real_y, disc_fake_y)

    grads_G = tape.gradient(total_loss_G, photo_generator.trainable_variables)
    grads_F = tape.gradient(total_loss_F, style_generator.trainable_variables)

    disc_X_grads = tape.gradient(disc_X_loss, style_discriminator.trainable_variables)
    disc_Y_grads = tape.gradient(disc_Y_loss, photo_discriminator.trainable_variables)

    photo_generator_optimizer.apply_gradients(
        zip(grads_G, photo_generator.trainable_variables)
    )
    style_generator_optimizer.apply_gradients(
        zip(grads_F, style_generator.trainable_variables)
    )
    style_discriminator_optimizer.apply_gradients(
        zip(disc_X_grads, style_discriminator.trainable_variables)
    )
    photo_discriminator_optimizer.apply_gradients(
        zip(disc_Y_grads, photo_discriminator.trainable_variables)
    )

    return {
        "gen_G": total_loss_G,
        "gen_F": total_loss_F,
        "disc_X": disc_X_loss,
        "disc_Y": disc_Y_loss,
    }

In [33]:
def generate_images(model, test_input):
    """Plot the first image of `test_input` next to `model`'s translation of it.

    Both images are mapped with x * 0.5 + 0.5 for display, which assumes they
    are in [-1, 1].
    """
    prediction = model(test_input)
    panels = (
        ('Input Image', test_input[0]),
        ('Predicted Image', prediction[0]),
    )

    plt.figure(figsize=(12, 12))
    for position, (caption, image) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.title(caption)
        plt.imshow(image * 0.5 + 0.5)
        plt.axis('off')
    plt.show()

In [34]:
checkpoint_path = "./checkpoints/train"

# Everything needed to resume training, keyed by the name it is saved under.
_tracked_objects = dict(
    photo_generator=photo_generator,
    style_generator=style_generator,
    style_discriminator=style_discriminator,
    photo_discriminator=photo_discriminator,
    photo_generator_optimizer=photo_generator_optimizer,
    style_generator_optimizer=style_generator_optimizer,
    style_discriminator_optimizer=style_discriminator_optimizer,
    photo_discriminator_optimizer=photo_discriminator_optimizer,
)
ckpt = tf.train.Checkpoint(**_tracked_objects)

# Keeps at most the 5 most recent checkpoints in checkpoint_path.
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# Resume from the newest checkpoint when one exists.
_latest_ckpt = ckpt_manager.latest_checkpoint
if _latest_ckpt:
    ckpt.restore(_latest_ckpt)
    print('Latest checkpoint restored!!')

In [35]:
# Epoch count comes from EPOCHS in the configuration cell (not redefined here).
SAVE_EVERY = 1  # save a checkpoint every SAVE_EVERY epochs

for epoch in range(EPOCHS):
    start = time.time()

    # Dataset.zip stops at the shorter dataset, so an epoch has
    # min(#photos, #paintings) steps.
    for step, (photo_batch, paint_batch) in enumerate(tf.data.Dataset.zip((photo, paint))):
        # train_step's parameters are (style, photo); this positional order is
        # what the networks have been trained with, so it is kept as-is.
        train_step(photo_batch, paint_batch)
        if step % 10 == 0:
            print(f'{step}.', end='')

    clear_output(wait=True)
    generate_images(photo_generator, sample_photo)

    if (epoch + 1) % SAVE_EVERY == 0:
        ckpt_save_path = ckpt_manager.save()
        print('Saving checkpoint for epoch {} at {}'.format(epoch + 1, ckpt_save_path))
    print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, time.time() - start))

0.

KeyboardInterrupt: 

In [None]:
# Translate one held-out test batch with the photo generator (no-op if the test set is empty).
for test_batch in photo_test.take(1):
    generate_images(photo_generator, test_batch)