## Neural Style Transfer using Keras

NST can be implmented using any pretrained convnet. Here we will use the VGG19 network.

NOTE: What this technique achieves is merely a form of image retexturing or texture trasnfer. It works best with style refernce images with strong textures. and with content target that dont require high level of details to recognize.

This algorithm is closer to classical signal processing that to AI, so dont expect it to work like magic.

### Defining initial variables

In [56]:
from keras.preprocessing.image import load_img, img_to_array

# Path of the img imput files
target_image_path = './img/portrait.png'
style_reference_image_path = './img/transfer_style_reference.png'

# Dimension of the generated picture
width, height = load_img(target_image_path).size
img_height = 400
img_width = int(width * img_height / height)

In [57]:
width, height, img_width

(283, 427, 265)

### Auxilary functions

In [58]:
import numpy as np
from keras.applications import vgg19

In [59]:
def preprocess_image(image_path):
    img = load_img(image_path, target_size=(img_height, img_width))
    img = img_to_array(img)
    img = np.expand_dims(img, axis=0)
    img = vgg19.preprocess_input(img)
    return img

In [60]:
def deprocess_image(x):
    """ Zero centering by removing the mean pixel value from ImageNet.
        This reverses a transformation done by vgg19.preprocess_input """
    x[:, :, 0] += 103.939
    x[:, :, 1] += 116.779
    x[:, :, 2] += 123.68
    # Convert images from BGR to RGB(Part of vgg19 reversal)
    x = x[:, :, ::-1]
    x = np.clip(x, 0, 255).astype('uint8')
    return x

### Loading the pretrained VGG19 network and applying it to the three images

In [61]:
from tensorflow.keras import backend as K
import tensorflow as tf
tf.compat.v1.disable_eager_execution()

In [62]:
target_image = K.constant(preprocess_image(target_image_path))
style_reference_image = K.constant(preprocess_image(style_reference_image_path))
combination_image = K.placeholder((1, img_height, img_width, 3))

# Combining three images into a single batch
input_tensor = K.concatenate([target_image, style_reference_image, combination_image], axis=0)

# Build the VGG19 network with the batch of three images as input. The model will be laoded with pretrained ImageNet weights
model = vgg19.VGG19(input_tensor=input_tensor, weights='imagenet', include_top=False)
print('Model Loaded.')

Model Loaded.


### Content Loss

In [63]:
def content_loss(base, combination):
    return K.sum(K.square(combination -  base))

### Style Loss

In [64]:
def gram_matrix(x):
    features = K.batch_flatten(K.permute_dimensions(x, (2, 0,1)))
    gram = K.dot(features, K.transpose(features))
    return gram

def style_loss(style, combination):
    s = gram_matrix(style)
    c = gram_matrix(combination)
    channels = 3
    size = img_height * img_width
    return K.sum(K.square(s-c)) / (4. * (channels ** 2) * (size ** 2))

### Total variation loss

With the above 2 losses we are adding 3rd loss total variation loss which opeartes on the pixels of the generated combination image. It encourages spatial continuity in the generated image, thus avioding overly pixalated results. We can interprete as a regulisation loss.

In [65]:
def total_variation_loss(x):
    a = K.square(x[:, :img_height - 1, :img_width - 1, :] - x[:, 1:, :img_width - 1, :])
    b = K.square(x[:, :img_height - 1, :img_width - 1, :] - x[:, :img_height - 1, 1:, :])
    return  K.sum(K.pow(a + b, 1.25))

### Defining final loss we will minimize

In [66]:
# Dinctinary that maps layer names to activation tensors
output_dict = dict([(layer.name, layer.output) for layer in model.layers])
content_layer = 'block5_conv2' #For content loss
style_layers = ['block1_conv1', 'block2_conv1', 'block3_conv1', 'block4_conv1', 'block5_conv1'] # FOr style loss

# Weights in the weighted average of the loss components
total_variation_weight = 1e-4
style_weight = 1.
content_weight = 0.25

## Adds the content loss
# Defining the loss by adding all components to this scalar variable
loss = K.variable(0.)
layer_features = output_dict[content_layer]
target_image_features = layer_features[0, :, :, :]
combination_features = layer_features[2, :, :, :]
loss = loss + content_weight * content_loss(target_image_features, combination_features)

# Add a style loss component for each target layer
for layer_name in style_layers:
    layers_features = output_dict[layer_name]
    style_reference_features = layer_features[1, :, :, :]
    combination_features = layer_features[2, :, :, :]
    sl = style_loss(style_reference_features, combination_features)
    loss = loss + (style_weight / len(style_layers)) * sl
    
# Add the total variation loss
loss = loss + total_variation_weight * total_variation_loss(combination_image)

### Setting up the gradient-descent process

Computing loss and Gradient process seprately will be very slow as a lot of redundant process will run in between them hence we will setup a Python class named Evaluator that computes both the loss value and gradient value at once returns the loss when called the first time and caches the gradient for the next call.

In [67]:
# Get the gradients of the generated image with regard to the loss
grads = K.gradients(loss, combination_image)[0]

# Function to fetch the values of the current loss and the current gradients
fetch_loss_and_grads = K.function([combination_image], [loss, grads])

class Evaluator(object):
    
    def __init__(self):
        self.loss_value = None
        self.grads_value = None
        
    def loss(self, x):
        assert self.loss_value is None
        x = x.reshape((1, img_height, img_width, 3))
        outs = fetch_loss_and_grads([x])
        loss_value = outs[0]
        grad_values = outs[1].flatten().astype('float64')
        self.loss_value = loss_value
        self.grad_values = grad_values
        return self.loss_value
    
    def grads(self, x):
        assert self.loss_value is not None
        grad_values = np.copy(self.grad_values)
        self.loss_value = None
        self.grad_values = None
        return grad_values
    
evaluator = Evaluator()

### Style transfer loop

In [68]:
from scipy.optimize import fmin_l_bfgs_b
import matplotlib.pyplot as plt
import time

In [69]:
result_prefix = 'my_result'
iterations = 20

x = preprocess_image(target_image_path)
# Need to flatten bcz fmin_l_bfgs_b can works on flattened image
x = x.flatten()

# Run L-BFGS optimizations over the pixels of the generated image to minimize the style loss.
# We have to pass the function that computes the loss ad the function that computs the gradients as two separate arguments

for i in range(iterations):
    print(f"Start of iterations: {i}")
    start_time = time.time()
    x, min_val, info = fmin_l_bfgs_b(evaluator.loss, x, fprime=evaluator.grads, maxfun=20)
    print(f"Current loss value: {min_val}")
    img = x.copy().reshape((img_height, img_width, 3))
    img = deprocess_image(img)
    fname = f"./img/generated_img/{result_prefix}_at_iteration_{i}.png"
    plt.imsave(fname, img)
    print(f"Image saved as :{fname}")
    end_time = time.time()
    print(f"Iteration {i} completed in {end_time - start_time}.")

Start of iterations: 0
Current loss value: 123301.0546875
Image saved as :./img/generated_img/my_result_at_iteration_0.png
Iteration 0 completed in 122.82134366035461.
Start of iterations: 1
Current loss value: 123296.71875
Image saved as :./img/generated_img/my_result_at_iteration_1.png
Iteration 1 completed in 129.85703492164612.
Start of iterations: 2
Current loss value: 123293.453125
Image saved as :./img/generated_img/my_result_at_iteration_2.png
Iteration 2 completed in 142.89585065841675.
Start of iterations: 3
Current loss value: 123291.0859375
Image saved as :./img/generated_img/my_result_at_iteration_3.png
Iteration 3 completed in 138.08231687545776.
Start of iterations: 4
Current loss value: 123289.1484375
Image saved as :./img/generated_img/my_result_at_iteration_4.png
Iteration 4 completed in 133.66913509368896.
Start of iterations: 5
Current loss value: 123288.2109375
Image saved as :./img/generated_img/my_result_at_iteration_5.png
Iteration 5 completed in 146.11634063720