# Neural Style Transfer

In [None]:
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras import backend as K
from imageio import mimsave
from IPython.display import display as display_fn
from IPython.display import Image, clear_output
%matplotlib inline

In [None]:
def tensor_to_image(tensor):
  '''converts a tensor to an image

  Args:
    tensor: image tensor of shape (height, width, channels), or
      (1, height, width, channels) with a batch axis of size 1.

  Returns:
    the image object produced by Keras' `array_to_img`.

  Raises:
    ValueError: if the tensor has a batch axis whose size is not 1.
  '''
  # Static rank check, consistent with `imshow`; the previous
  # shape-of-shape tensor comparison relied on truthiness of a 1-element tensor.
  if len(tensor.shape) > 3:
    # explicit raise instead of `assert`, which is stripped under `python -O`
    if tensor.shape[0] != 1:
      raise ValueError(
          'expected a batch of exactly one image, got shape {}'.format(tensor.shape))
    tensor = tensor[0]
  return tf.keras.preprocessing.image.array_to_img(tensor)

def load_img(path_to_img, max_dim=512):
  '''loads an image as a tensor, resizing it so its longer side is `max_dim` pixels

  Args:
    path_to_img: path of the image file to read
    max_dim: target length in pixels of the image's longer side (default 512)

  Returns:
    uint8 tensor of shape (1, new_height, new_width, 3), aspect ratio preserved
  '''
  image = tf.io.read_file(path_to_img)
  # channels=3 forces RGB output (e.g. for grayscale files) because VGG-19
  # expects 3 channels; decode_image also accepts PNG/BMP/GIF in addition to JPEG.
  image = tf.io.decode_image(image, channels=3, expand_animations=False)
  image = tf.image.convert_image_dtype(image, tf.float32)

  # spatial (height, width) shape as floats so it can be scaled
  shape = tf.cast(tf.shape(image)[:-1], tf.float32)
  scale = max_dim / tf.reduce_max(shape)
  new_shape = tf.cast(shape * scale, tf.int32)

  image = tf.image.resize(image, new_shape)
  # add a leading batch axis
  image = image[tf.newaxis, :]
  # saturate guards against values drifting marginally outside [0, 1]
  image = tf.image.convert_image_dtype(image, tf.uint8, saturate=True)
  return image

def load_images(content_path, style_path):
  '''loads the content and style images as tensors

  Args:
    content_path: path to the content image (str or path-like)
    style_path: path to the style image (str or path-like)

  Returns:
    (content_image, style_image), each as returned by `load_img`
  '''
  # str() turns path-like objects into the plain string path that load_img reads
  content_image = load_img(str(content_path))
  style_image = load_img(str(style_path))
  return content_image, style_image

def imshow(image, title=None):
  '''plots a single image on the current axes, optionally titled

  A leading batch axis (rank > 3) is squeezed away before plotting;
  the title is only set when it is truthy.
  '''
  to_plot = tf.squeeze(image, axis=0) if len(image.shape) > 3 else image
  plt.imshow(to_plot)
  if not title:
    return
  plt.title(title)

def show_images_with_objects(images, titles=None):
  '''displays a row of images with corresponding titles

  Args:
    images: sequence of images; `imshow` drops a leading batch axis when
      an image has rank > 3
    titles: one title per image. Defaults to None, in which case the
      images are shown without titles.

  If `titles` is given but its length differs from `images`, nothing is
  drawn and a warning is emitted.
  '''
  import warnings

  # None instead of a mutable `[]` default; imshow skips falsy titles
  if titles is None:
    titles = [None] * len(images)

  if len(images) != len(titles):
    # previously returned silently, which hid caller mistakes
    warnings.warn(
        'show_images_with_objects: got {} images but {} titles; nothing displayed'
        .format(len(images), len(titles)))
    return

  plt.figure(figsize=(20, 12))
  for idx, (image, title) in enumerate(zip(images, titles)):
    plt.subplot(1, len(images), idx + 1)
    plt.xticks([])
    plt.yticks([])
    imshow(image, title)

def display_gif(gif_path):
  '''renders the animated GIF stored at `gif_path` inline in the notebook'''
  with open(gif_path, 'rb') as gif_file:
    gif_bytes = gif_file.read()
  # NOTE(review): GIF bytes are passed with format='png'; this looks like a
  # workaround for IPython versions whose Image cannot embed GIFs — confirm
  # before changing it to 'gif'.
  display_fn(Image(data=gif_bytes, format='png'))

def create_gif(gif_path, images):
  '''writes `images` as the frames of an animated GIF, one frame per second

  Returns:
    the path the GIF was written to (same as `gif_path`)
  '''
  output_path = gif_path
  mimsave(output_path, images, fps=1)
  return output_path

def clip_image_values(image, min_value=0.0, max_value=255.0):
  '''limits every pixel value of `image` to the closed range [min_value, max_value]'''
  clipped = tf.clip_by_value(image, min_value, max_value)
  return clipped

def preprocess_image(image):
  '''casts `image` to float32 and applies the VGG-19 input preprocessing (pixel centering)'''
  as_float = tf.cast(image, dtype=tf.float32)
  return tf.keras.applications.vgg19.preprocess_input(as_float)

## Download images

In [None]:
IMAGE_DIR = 'images'

!mkdir {IMAGE_DIR}

# download images to the directory you just created
!wget -q -O ./images/cafe.jpg https://cdn.pixabay.com/photo/2018/07/14/15/27/cafe-3537801_1280.jpg
!wget -q -O ./images/swan.jpg https://cdn.pixabay.com/photo/2017/02/28/23/00/swan-2107052_1280.jpg
!wget -q -O ./images/tnj.jpg https://i.dawn.com/large/2019/10/5db6a03a4c7e3.jpg
!wget -q -O ./images/rudolph.jpg https://cdn.pixabay.com/photo/2015/09/22/12/21/rudolph-951494_1280.jpg
!wget -q -O ./images/dynamite.jpg https://cdn.pixabay.com/photo/2015/10/13/02/59/animals-985500_1280.jpg
!wget -q -O ./images/painting.jpg https://storage.googleapis.com/download.tensorflow.org/example_images/Vassily_Kandinsky%2C_1913_-_Composition_7.jpg

print("image files you can choose from: ")
!ls images

In [None]:
# default content/style choices; swap in any other file from IMAGE_DIR
content_path = '{}/{}'.format(IMAGE_DIR, 'swan.jpg')
style_path = '{}/{}'.format(IMAGE_DIR, 'painting.jpg')

In [None]:
# load both images and show them side by side, labelled with their paths
content_image, style_image = load_images(content_path, style_path)
show_images_with_objects(
    [content_image, style_image],
    titles=[
        f'content image: {content_path}',
        f'style image: {style_path}',
    ],
)

In [None]:
# sanity-check value ranges and shapes of the loaded images, and the range
# of values after VGG-19 preprocessing (computed once, used for max and min)
preprocessed_content = preprocess_image(content_image)
print(np.max(content_image), content_image.shape)
print(np.max(style_image), style_image.shape)
print(np.max(preprocessed_content), np.min(preprocessed_content))

## Build model

In [None]:
# clear session to make layer naming consistent when re-running this cell
K.clear_session()

# Only the architecture is needed to inspect the layer names, so skip
# downloading the pretrained classifier weights; the summary is the same.
tmp_vgg = tf.keras.applications.vgg19.VGG19(weights=None)
tmp_vgg.summary()

del tmp_vgg

In [None]:
# style: first conv layer of each of the five VGG-19 blocks
style_layers = [f'block{block_idx}_conv1' for block_idx in range(1, 6)]
# content: second conv layer of the last block
content_layers = ['block5_conv2']

# the model emits style activations first, then content activations
output_layers = style_layers + content_layers

NUM_STYLE_LAYERS = len(style_layers)
NUM_CONTENT_LAYERS = len(content_layers)

In [None]:
def vgg_model(layer_names):
  """Builds a feature extractor on top of a frozen, ImageNet-pretrained VGG-19.

  Args:
    layer_names: names of the VGG-19 layers whose activations should be returned

  Returns:
    a tf.keras.Model that accepts VGG-19 inputs and returns one output per
    entry of `layer_names`, in the same order
  """
  base = tf.keras.applications.vgg19.VGG19(include_top=False, weights='imagenet')
  # the network is only used to compute features, never trained
  base.trainable = False
  layer_outputs = [base.get_layer(layer_name).output for layer_name in layer_names]
  return tf.keras.Model(inputs=base.input, outputs=layer_outputs)

In [None]:
# clear session to make layer naming consistent if re-running the cell
K.clear_session()

# feature extractor returning the style layer outputs followed by the content layer outputs
vgg = vgg_model(layer_names=output_layers)
vgg.summary()

## Define the loss

In [None]:
def get_style_loss(features, targets):
  """Mean squared error between two style representations.

  In this notebook both arguments are Gram matrices produced by
  `gram_matrix`, i.e. tensors of shape (batch, channels, channels),
  not raw images or feature maps.

  Args:
    features: Gram matrix of one style layer for the generated image
    targets: Gram matrix of the same style layer for the style image

  Returns:
    style loss (scalar): mean of the element-wise squared differences
  """
  return tf.reduce_mean(tf.square(features - targets))

def get_content_loss(features, targets):
  """Half the sum of squared differences between two content activations.

  Args:
    features: content-layer activations of the generated image
    targets: content-layer activations of the content image (same shape)

  Returns:
    content loss (scalar)
  """
  residual = features - targets
  squared_error_sum = tf.reduce_sum(tf.square(residual))
  return 0.5 * squared_error_sum

### Calculate Gram Matrix

In [None]:
def gram_matrix(input_tensor):
  """Computes per-image channel correlations, averaged over spatial locations.

  Args:
    input_tensor: activations of shape (batch, height, width, channels)

  Returns:
    tensor of shape (batch, channels, channels): the Gram matrix divided by
    height * width
  """
  dims = tf.shape(input_tensor)
  location_count = tf.cast(dims[1] * dims[2], tf.float32)
  # G[b, c, d] = sum over (i, j) of F[b, i, j, c] * F[b, i, j, d]
  raw_gram = tf.linalg.einsum('bijc,bijd->bcd', input_tensor, input_tensor)
  return raw_gram / location_count

### Get style image features

In [None]:
# inspect the symbolic output of every layer in the feature extractor
tmp_layer_list = [each_layer.output for each_layer in vgg.layers]
tmp_layer_list

In [None]:
def get_style_image_features(image):
  """Computes Gram matrices of the style-layer activations of `image`.

  Args:
    image: an input image (preprocessed for VGG-19 internally)

  Returns:
    list with one Gram matrix per style layer
  """
  activations = vgg(preprocess_image(image))
  # the extractor emits style layers first, so the leading NUM_STYLE_LAYERS outputs are style activations
  return [gram_matrix(activation) for activation in activations[:NUM_STYLE_LAYERS]]

### Get content image features

In [None]:
def get_content_image_features(image):
  """Returns the content-layer activations of `image`.

  Args:
    image: an input image (preprocessed for VGG-19 internally)

  Returns:
    the extractor outputs that follow the NUM_STYLE_LAYERS style outputs
  """
  activations = vgg(preprocess_image(image))
  return activations[NUM_STYLE_LAYERS:]

### Calculate style and content loss

In [None]:
def get_style_content_loss(style_targets, style_outputs, content_targets, content_outputs, style_weight, content_weight):
  """Weighted sum of the per-layer-averaged style and content losses.

  Args:
    style_targets: style features of the style image
    style_outputs: style features of the generated image
    content_targets: content features of the content image
    content_outputs: content features of the generated image
    style_weight: weight given to the style loss
    content_weight: weight given to the content loss

  Returns:
    total_loss: style_weight * mean style loss + content_weight * mean content loss
  """
  style_terms = [get_style_loss(out, tgt) for out, tgt in zip(style_outputs, style_targets)]
  content_terms = [get_content_loss(out, tgt) for out, tgt in zip(content_outputs, content_targets)]

  weighted_style = tf.add_n(style_terms) * style_weight / NUM_STYLE_LAYERS
  weighted_content = tf.add_n(content_terms) * content_weight / NUM_CONTENT_LAYERS
  return weighted_style + weighted_content

## Generate the stylized image

In [None]:
def calculate_gradients(image, style_targets, content_targets, style_weight, content_weight, var_weight):
  """Gradient of the style + content loss with respect to the generated image.

  Args:
    image: generated image (the variable being optimized)
    style_targets: style features of the style image
    content_targets: content features of the content image
    style_weight: weight given to the style loss
    content_weight: weight given to the content loss
    var_weight: accepted for a uniform signature; this version adds no
      total variation term, so it is unused

  Returns:
    gradients: gradients of the loss with respect to `image`
  """
  with tf.GradientTape() as tape:
    loss = get_style_content_loss(
        style_targets, get_style_image_features(image),
        content_targets, get_content_image_features(image),
        style_weight, content_weight)
  return tape.gradient(loss, image)

### Update the image with style

In [None]:
def update_image_with_style(image, style_targets, content_targets, style_weight, var_weight, content_weight, optimizer):
  """Runs one optimization step on `image` in place, then clips it to [0, 255].

  Note the parameter order: var_weight comes before content_weight.

  Args:
    image: generated image (a variable, updated in place)
    style_targets: style features of the style image
    content_targets: content features of the content image
    style_weight: weight given to the style loss
    var_weight: weight given to the total variation loss
    content_weight: weight given to the content loss
    optimizer: optimizer for updating the input image
  """
  grads = calculate_gradients(image, style_targets, content_targets,
                              style_weight=style_weight,
                              content_weight=content_weight,
                              var_weight=var_weight)
  optimizer.apply_gradients([(grads, image)])
  # keep pixel values inside the valid range after the step
  image.assign(clip_image_values(image, min_value=0.0, max_value=255.0))

## Style transfer

In [None]:
def fit_style_transfer(style_image, content_image, style_weight=1e-2, content_weight=1e-4, var_weight=0, optimizer='adam', epochs=1, steps_per_epoch=1):
  """ Performs neural style transfer.

  The generated image starts as a copy of the content image and is optimized
  directly.

  Args:
    style_image: image to get style features from
    content_image: image to stylize (also the starting point of the optimization)
    style_weight: weight given to the style loss
    content_weight: weight given to the content loss
    var_weight: weight given to the total variation loss
    optimizer: optimizer for updating the input image; either an optimizer
      instance or an identifier accepted by `tf.keras.optimizers.get`
      (e.g. the default 'adam')
    epochs: number of epochs
    steps_per_epoch: steps per epoch

  Returns:
    generated_image: generated image at final epoch, cast to uint8
    images: the content image followed by snapshots of the generated image
      taken every 10 steps and at the end of every epoch
  """
  # Resolve string identifiers such as the default 'adam'; passing the raw
  # string on would fail when `apply_gradients` is called on it.
  if isinstance(optimizer, str):
    optimizer = tf.keras.optimizers.get(optimizer)

  images = []
  step = 0
  style_targets = get_style_image_features(style_image)
  content_targets = get_content_image_features(content_image)

  generated_image = tf.Variable(tf.cast(content_image, dtype=tf.float32))

  images.append(content_image)

  for _ in range(epochs):
    for m in range(steps_per_epoch):
      step += 1

      # keywords guard against update_image_with_style's unusual
      # (..., var_weight, content_weight, ...) parameter order
      update_image_with_style(generated_image, style_targets, content_targets,
                              style_weight=style_weight, var_weight=var_weight,
                              content_weight=content_weight, optimizer=optimizer)
      print(".", end='')

      if (m + 1) % 10 == 0:
        # Snapshot the value: the Variable is updated in place, so appending
        # the Variable itself would make every frame show the final image.
        images.append(tf.identity(generated_image))

    clear_output(wait=True)
    display_fn(tensor_to_image(generated_image))

    images.append(tf.identity(generated_image))
    print("Train step: {}".format(step))

  generated_image = tf.cast(generated_image, tf.uint8)
  return generated_image, images

In [None]:
style_weight = 1e-4
content_weight = 1e-32

# Adam whose learning rate starts at 30 and decays by a factor of 0.8 per 100 steps
adam = tf.optimizers.Adam(
    tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=30.0,
        decay_steps=100,
        decay_rate=0.80,
    )
)

stylized_image, display_images = fit_style_transfer(
    style_image, content_image, style_weight, content_weight,
    var_weight=0, optimizer=adam, epochs=10, steps_per_epoch=100,
)

In [None]:
GIF_PATH = 'style_transfer.gif'
# drop the batch axis and convert each frame to uint8 for the GIF writer
gif_images = [
    np.squeeze(frame.numpy().astype(np.uint8), axis=0)
    for frame in display_images
]
gif_path = create_gif(GIF_PATH, gif_images)
display_gif(gif_path)

## Total variation loss

In [None]:
# plot utilities

def high_pass_x_y(image):
  '''differences between neighbouring pixels of a 4-D (batch, height, width, channels) image

  Returns:
    (x_var, y_var): x_var differences along the width axis (one column
    shorter), y_var differences along the height axis (one row shorter)
  '''
  right, left = image[:, :, 1:, :], image[:, :, :-1, :]
  below, above = image[:, 1:, :, :], image[:, :-1, :, :]
  return right - left, below - above

def plot_deltas_for_single_image(x_deltas, y_deltas, name="Original", row=1, num_rows=2):
  '''plots one image's y- and x-deltas side by side in row `row` of a figure

  The figure is a grid of `num_rows` rows by 2 columns. A new figure is
  created only when `row == 1`; later rows draw into the current figure, so
  calling this with row=1 and then row=2 stacks two images' deltas in one
  figure. (Previously `row` was used as the grid's row count, so row=2 drew
  into the top half of a fresh figure and left the bottom half empty.)

  Args:
    x_deltas: differences along the width axis (first output of `high_pass_x_y`)
    y_deltas: differences along the height axis (second output of `high_pass_x_y`)
    name: label appended to both subplot titles
    row: 1-based row of the grid to draw into
    num_rows: total number of rows in the grid; pass 1 to plot a single
      image on its own
  '''
  if row == 1:
    plt.figure(figsize=(14,10))

  # matplotlib subplot indices are 1-based, left-to-right then top-to-bottom
  left_index = 2 * (row - 1) + 1

  plt.subplot(num_rows, 2, left_index)
  plt.yticks([])
  plt.xticks([])

  # 2*delta + 0.5 maps zero change to mid-grey; clip to the displayable [0, 1] range.
  # NOTE(review): y_deltas (differences down the height axis) are titled
  # "Horizontal Deltas", presumably because they highlight horizontal edges — confirm.
  clipped_y_deltas = clip_image_values(2*y_deltas+0.5, min_value=0.0, max_value=1.0)
  imshow(clipped_y_deltas, "Horizontal Deltas: {}".format(name))

  plt.subplot(num_rows, 2, left_index + 1)
  plt.yticks([])
  plt.xticks([])

  clipped_x_deltas = clip_image_values(2*x_deltas+0.5, min_value=0.0, max_value=1.0)
  imshow(clipped_x_deltas, "Vertical Deltas: {}".format(name))


def plot_deltas(original_image_deltas, stylized_image_deltas):
  '''plots the (x, y) deltas of the original image, then those of the stylized image as row 2'''
  (orig_x, orig_y), (styl_x, styl_y) = original_image_deltas, stylized_image_deltas

  plot_deltas_for_single_image(orig_x, orig_y, name="Original")
  plot_deltas_for_single_image(styl_x, styl_y, name="Stylized Image", row=2)

In [None]:
# Display the frequency variations
# (uint8 images are converted to float first so differences are not computed in uint8)
content_as_float = tf.image.convert_image_dtype(content_image, dtype=tf.float32)
original_x_deltas, original_y_deltas = high_pass_x_y(content_as_float)

stylized_as_float = tf.image.convert_image_dtype(stylized_image, dtype=tf.float32)
stylized_image_x_deltas, stylized_image_y_deltas = high_pass_x_y(stylized_as_float)

plot_deltas(
    (original_x_deltas, original_y_deltas),
    (stylized_image_x_deltas, stylized_image_y_deltas),
)

In [None]:
def calculate_gradients(image, style_targets, content_targets, 
                        style_weight, content_weight, var_weight):
  """Gradient of style + content + total variation loss with respect to the image.

  Args:
    image: generated image (the variable being optimized)
    style_targets: style features of the style image
    content_targets: content features of the content image
    style_weight: weight given to the style loss
    content_weight: weight given to the content loss
    var_weight: weight given to the total variation loss

  Returns:
    gradients: gradients of the loss with respect to `image`
  """
  with tf.GradientTape() as tape:
    style_features = get_style_image_features(image)
    content_features = get_content_image_features(image)
    perceptual_loss = get_style_content_loss(
        style_targets, style_features,
        content_targets, content_features,
        style_weight, content_weight)
    # total variation sums absolute differences between neighbouring pixels
    smoothness_penalty = var_weight * tf.image.total_variation(image)
    loss = perceptual_loss + smoothness_penalty
  return tape.gradient(loss, image)

In [None]:
# same style/content weights as the first run, plus a total variation weight
# and a slower learning-rate decay (0.90 instead of 0.80)
style_weight = 1e-4
content_weight = 1e-32
var_weight = 1e-2

adam = tf.optimizers.Adam(
    tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=30.0,
        decay_steps=100,
        decay_rate=0.90,
    )
)

stylized_image_reg, display_images_reg = fit_style_transfer(
    style_image=style_image,
    content_image=content_image,
    style_weight=style_weight,
    content_weight=content_weight,
    var_weight=var_weight,
    optimizer=adam,
    epochs=10,
    steps_per_epoch=100,
)

In [None]:
# Display GIF
GIF_PATH = 'style_transfer_reg.gif'
# drop the batch axis and convert each frame to uint8 for the GIF writer
gif_images_reg = [
    np.squeeze(frame.numpy().astype(np.uint8), axis=0)
    for frame in display_images_reg
]
gif_path_reg = create_gif(GIF_PATH, gif_images_reg)
display_gif(gif_path_reg)

In [None]:
# Display Frequency Variations
# compare pixel deltas of the content image with those of the regularized result
content_as_float = tf.image.convert_image_dtype(content_image, dtype=tf.float32)
original_x_deltas, original_y_deltas = high_pass_x_y(content_as_float)

stylized_reg_as_float = tf.image.convert_image_dtype(stylized_image_reg, dtype=tf.float32)
stylized_image_reg_x_deltas, stylized_image_reg_y_deltas = high_pass_x_y(stylized_reg_as_float)

plot_deltas(
    (original_x_deltas, original_y_deltas),
    (stylized_image_reg_x_deltas, stylized_image_reg_y_deltas),
)

In [None]:
show_images_with_objects(
    [style_image, content_image, stylized_image],
    titles=['Style Image', 'Content Image', 'Stylized Image'],
)

In [None]:
show_images_with_objects(
    [style_image, content_image, stylized_image_reg],
    titles=['Style Image', 'Content Image', 'Stylized Image with Regularization'],
)