# Neural Style Transfer

Feature Extraction

In [None]:
import tensorflow as tf


class FeatureExtractor(object):
  """Factories for feature-extraction models plus helpers to run them.

  Every factory returns a Keras model with a list of outputs. By the
  convention used in `get_layers`, the LAST output is the content
  activation and all preceding outputs are style activations.
  """

  @classmethod
  def custom_model(cls, input_shape = (256, 256, 3)):
    """Build a small, randomly initialised convolutional extractor.

    Args:
      input_shape: (height, width, channels) of a single input image.
        Defaults to the previously hard-coded (256, 256, 3).

    Returns:
      A Keras model whose outputs are the activations of `conv_block_2`
      to `conv_block_5`, followed by the strided `conv_block_out`
      activation as the last element.
    """
    filters = [64, 128, 64, 32]

    input_layer = tf.keras.layers.Input(shape = input_shape,
                                        name = "input_layer")

    x = tf.keras.layers.Conv2D(filters = 32,
                               kernel_size = (2, 2),
                               strides = (1, 1),
                               padding = "same",
                               name = "conv_block_1")(input_layer)

    activation_layers = list()
    for i, num_filters in enumerate(filters):
      # ReLU + batch normalisation are only inserted before the
      # odd-indexed blocks (conv_block_3 and conv_block_5).
      if i % 2 == 1:
        x = tf.keras.layers.ReLU(name = f"relu_layer_{i}th_iter")(x)
        x = tf.keras.layers.BatchNormalization()(x)

      # Blocks are numbered from 2 because conv_block_1 is the stem above.
      x = tf.keras.layers.Conv2D(filters = num_filters,
                                 kernel_size = (3, 3),
                                 strides = (1, 1),
                                 padding = "same",
                                 name = f"conv_block_{i + 2}")(x)

      activation_layers.append(x)

    out = tf.keras.layers.Conv2D(filters = 3,
                                 kernel_size = (3, 3),
                                 strides = (2, 2),
                                 padding = "same",
                                 name = "conv_block_out")(x)

    activation_layers.append(out)

    model = tf.keras.models.Model(inputs = input_layer,
                                  outputs = activation_layers,
                                  name = "feature_extractor_model")

    return model


  @classmethod
  def vgg_extractor_model(cls):
    """Build an ImageNet-pretrained VGG19 that exposes intermediate layers.

    Returns:
      A Keras model whose outputs are `block1_conv1` to `block5_conv1`
      (style activations) followed by `block5_conv2` (content activation).
    """
    vgg_model = tf.keras.applications.VGG19(include_top = False,
                                            weights = "imagenet")

    style_conv_blocks = [f"block{i}_conv1" for i in range(1, 6)]
    content_conv_block = ["block5_conv2"]
    all_activation_layers = style_conv_blocks + content_conv_block

    input_layer = vgg_model.inputs
    output_layers = [vgg_model.get_layer(layer_name).output
                     for layer_name in all_activation_layers]

    model = tf.keras.models.Model(inputs = input_layer,
                                  outputs = output_layers,
                                  name = "vgg19_extractor")

    return model


  @classmethod
  def vgg_with_recurrent(cls, input_shape = (256, 256, 3), num_layers = 1):
    """Build a VGG19 backbone followed by a stack of ConvLSTM2D layers.

    Fixes relative to the previous implementation:
      * the VGG19 features are actually fed into the recurrent stack
        (they used to be computed and then discarded);
      * `num_layers` intermediate recurrent layers are created (the old
        `range(1, num_layers)` produced none, leaving no style outputs);
      * the time axis is inserted at axis 1, so every image in the batch
        is its own one-step sequence. Inserting it at axis 0 merged the
        whole batch into one sequence and collapsed the output to a
        single feature map.

    Args:
      input_shape: (height, width, channels) of a single input image.
      num_layers: number of intermediate ConvLSTM2D layers; each one
        contributes a style activation.

    Returns:
      A Keras model whose outputs are the intermediate recurrent
      activations followed by the final recurrent (content) activation.
    """
    input_layer = tf.keras.layers.Input(shape = input_shape)
    vgg_features = tf.keras.applications.VGG19(include_top = False,
                                               weights = "imagenet")(input_layer)

    recurrent_layers = [
        tf.keras.layers.ConvLSTM2D(64, (1, 1), (1, 1), padding = "same",
                                   name = f"rnn_conv_{i}")
        for i in range(1, num_layers + 1)
    ]
    content_activation = tf.keras.layers.ConvLSTM2D(64, (1, 1), (1, 1),
                                                    padding = "same")

    x = vgg_features
    model_outputs = list()
    for rnn_layer in recurrent_layers + [content_activation]:
      # ConvLSTM2D consumes (batch, time, rows, cols, channels); add a
      # time axis of length 1 while keeping the batch axis first.
      x = tf.expand_dims(x, axis = 1)
      x = rnn_layer(x)
      model_outputs.append(x)

    model = tf.keras.models.Model(inputs = input_layer,
                                  outputs = model_outputs,
                                  name = "vgg_19_with_recurrent")

    return model


  @classmethod
  def extract(cls, image_stack, model):
    """Run a feature extractor on a stack of images.

    Image stack is (3, None, None, 3) shaped image data which contains
    content, style and generated images.

    Args:
      image_stack: batch of images to feed through the extractor.
      model: either an already-built Keras model, or a zero-argument
        callable (such as one of the factories above) returning one.
        Passing a factory builds a NEW model on every call, so prefer
        a built model when calling this repeatedly.

    Returns:
      Whatever the model returns for `image_stack` (a list of
      activations for the factories defined in this class).
    """
    # Keras models are themselves callable, so an isinstance check is
    # needed to tell a built model apart from a factory.
    if isinstance(model, tf.keras.models.Model):
      extractor = model
    else:
      extractor = model()

    return extractor(image_stack)


  @staticmethod
  def get_layers(features):
    """Split extractor outputs into content and style activations.

    Args:
      features: sequence of activations as returned by `extract`.

    Returns:
      Tuple `(content, style)`: the last element of `features`, and
      all elements before it.
    """
    content = features[-1]
    style = features[:-1]

    return content, style

Loss Function

In [None]:
class Loss(object):
  """Content and style loss terms used for neural style transfer."""

  @classmethod
  def gram_matrix(cls, arr):
    """Return the Gramian matrix used by the style loss.

    The last axis of the rank-3 `arr` is moved to the front, the two
    remaining axes are flattened, and the result is multiplied by its
    own transpose.
    (assumes `arr` is one (height, width, channels) activation - TODO
    confirm against callers)
    """
    channels_first = tf.transpose(arr, (2, 0, 1))
    num_rows = tf.shape(channels_first)[0]
    flattened = tf.reshape(channels_first, (num_rows, -1))

    return tf.matmul(flattened, tf.transpose(flattened))


  @classmethod
  def content_loss(cls, content, generated):
    """Return 1/2 * sum((generated - content) ** 2)."""
    squared_error = tf.square(generated - content)

    return tf.reduce_sum(squared_error) * 5e-1


  @classmethod
  def style_loss(cls, style, generated):
    """Return the mean squared difference between the two Gram matrices."""
    gram_difference = cls.gram_matrix(generated) - cls.gram_matrix(style)

    return tf.reduce_mean(tf.square(gram_difference))

Train Function

In [None]:
import tensorflow as tf

from loss import Loss
from feature_extractor import FeatureExtractor


class Constants:
  """Weights applied to each loss term when forming the total loss."""

  # Multiplier for the style term.
  STYLE_WEIGHT = 1e-4
  # Multiplier for the content term.
  CONTENT_WEIGHT = 2e-5


class Train(Loss, FeatureExtractor):
  """Combines feature extraction and the loss terms into one step loss."""

  @classmethod
  def calculate_step_loss(cls, model, content, style, generated):
    """Return the weighted content + style loss for one optimisation step.

    The three images are concatenated along axis 0 in the order
    content, style, generated and passed through `cls.extract`, so in
    every returned activation index 0 is the content image, index 1
    the style image and index -1 the generated image.
    """
    image_stack = tf.concat([content, style, generated], axis = 0)
    content_features, style_features = cls.get_layers(
        cls.extract(image_stack = image_stack, model = model))

    # Content term: generated image versus content image.
    content_term = cls.content_loss(content_features[0], content_features[-1])

    # Style term: per-layer losses (generated versus style image) summed up.
    style_term = sum(
        (cls.style_loss(activation[1], activation[-1])
         for activation in style_features),
        0.)

    weighted_content = content_term * Constants.CONTENT_WEIGHT
    weighted_style = style_term * Constants.STYLE_WEIGHT

    return weighted_content + weighted_style


def train(model, content, style, generated, epochs = 10):
  """Optimise `generated` in place so that it minimises the step loss.

  Args:
    model: an already-built Keras feature extractor, or a zero-argument
      factory returning one (e.g. `FeatureExtractor.vgg_extractor_model`).
    content: content image batch.
    style: style image batch.
    generated: `tf.Variable` holding the image being optimised; it is
      updated in place by the optimizer.
    epochs: number of gradient steps to run.

  Returns:
    The same `generated` variable that was passed in, after the updates.
  """
  # Build the extractor exactly once. Calling the factory on every step
  # reloads pretrained weights and re-randomises untrained layers, so
  # the loss being minimised would change between steps.
  if isinstance(model, tf.keras.models.Model):
    extractor = model
  else:
    extractor = model()

  # `calculate_step_loss` hands its `model` argument to
  # `FeatureExtractor.extract`, which may call it to obtain the model;
  # this zero-argument wrapper always returns the extractor built above.
  def built_extractor():
    return extractor

  optimizer = tf.keras.optimizers.SGD(learning_rate = 1e-4)
  for epoch in range(epochs):
    with tf.GradientTape() as tape:
      loss = Train.calculate_step_loss(built_extractor, content, style,
                                       generated)

    print(f"EPOCH: {epoch + 1} \nLOSS: {loss}\n" + ("---" * 15))

    gradients = tape.gradient(loss, generated)
    optimizer.apply_gradients([(gradients, generated)])

  return generated

Loading image

In [None]:
from skimage import io
import tensorflow as tf


def load_image(path: str, size = None):
  """Read an image file into a float32 RGB batch of one image.

  Args:
    path: location of the image, passed to `skimage.io.imread`.
    size: optional (height, width) to resize to; `None` keeps the
      original resolution.

  Returns:
    A float32 tensor with a leading batch axis of length 1 and pixel
    values divided by 255.
  """
  image = tf.convert_to_tensor(io.imread(path))

  # Grayscale files come back without a channel axis; replicate the
  # single plane so downstream models always receive three channels.
  if image.shape.rank == 2:
    image = tf.stack([image, image, image], axis = -1)

  # Drop the alpha channel of RGBA files (common for PNG).
  image = image[..., :3]

  image = tf.cast(image, tf.float32) / 255.
  if size is not None:
    image = tf.image.resize(image, size)

  return tf.expand_dims(image, axis = 0)

### All Imports

In [None]:
import os
import tensorflow as tf
tf.compat.v1.enable_eager_execution()

import matplotlib.pyplot as plt
from skimage import io, transform

# Rewrite these paths for your own base folder and images.
# BASE is the project root; both image paths are resolved relative to
# its "images" sub-folder.
BASE = "./drive/MyDrive/neural-style-transfer"
CONTENT_IMAGE_PATH = os.path.join(BASE, "images", "content.jpeg")
STYLE_IMAGE_PATH = os.path.join(BASE, "images", "style.png")

### Loading Images - I/O

In [None]:
def load_image(path: str, size = None):
  """Read an image file into a float32 RGB batch of one image.

  Args:
    path: location of the image, passed to `skimage.io.imread`.
    size: optional (height, width) to resize to; `None` keeps the
      original resolution.

  Returns:
    A float32 tensor with a leading batch axis of length 1 and pixel
    values divided by 255.
  """
  image = tf.convert_to_tensor(io.imread(path))

  # Grayscale files come back without a channel axis; replicate the
  # single plane so downstream models always receive three channels.
  if image.shape.rank == 2:
    image = tf.stack([image, image, image], axis = -1)

  # Drop the alpha channel of RGBA files (common for PNG).
  image = image[..., :3]

  image = tf.cast(image, tf.float32) / 255.
  if size is not None:
    image = tf.image.resize(image, size)

  return tf.expand_dims(image, axis = 0)

class Constants:
  """Weights applied to each loss term when forming the total loss."""

  # Multiplier for the style term.
  STYLE_WEIGHT = 1e-4
  # Multiplier for the content term.
  CONTENT_WEIGHT = 2e-5

# Content, style and generated images are concatenated along the batch
# axis during training, so they must share one height, width and channel
# count. The custom and recurrent extractors are built for 256x256x3
# inputs, hence this common size.
IMAGE_SIZE = (256, 256)

# `[..., :3]` drops an alpha channel if the file has one (e.g. RGBA PNG).
CONTENT = tf.image.resize(load_image(CONTENT_IMAGE_PATH)[..., :3], IMAGE_SIZE)
STYLE = tf.image.resize(load_image(STYLE_IMAGE_PATH)[..., :3], IMAGE_SIZE)
# The image being optimised starts out as a copy of the content image.
COMBINED = tf.Variable(CONTENT)

### Loss Functions

In [None]:
class Loss(object):
  """Content and style loss terms used for neural style transfer."""

  @classmethod
  def gram_matrix(cls, arr):
    """Return the Gramian matrix used by the style loss.

    The last axis of the rank-3 `arr` is moved to the front, the two
    remaining axes are flattened, and the result is multiplied by its
    own transpose.
    (assumes `arr` is one (height, width, channels) activation - TODO
    confirm against callers)
    """
    channels_first = tf.transpose(arr, (2, 0, 1))
    num_rows = tf.shape(channels_first)[0]
    flattened = tf.reshape(channels_first, (num_rows, -1))

    return tf.matmul(flattened, tf.transpose(flattened))


  @classmethod
  def content_loss(cls, content, generated):
    """Return 1/2 * sum((generated - content) ** 2)."""
    squared_error = tf.square(generated - content)

    return tf.reduce_sum(squared_error) * 5e-1


  @classmethod
  def style_loss(cls, style, generated):
    """Return the mean squared difference between the two Gram matrices."""
    gram_difference = cls.gram_matrix(generated) - cls.gram_matrix(style)

    return tf.reduce_mean(tf.square(gram_difference))

### Models

In [None]:
class FeatureExtractor(object):
  """Factories for feature-extraction models plus helpers to run them.

  Every factory returns a Keras model with a list of outputs. By the
  convention used in `get_layers`, the LAST output is the content
  activation and all preceding outputs are style activations.
  """

  @classmethod
  def custom_model(cls, input_shape = (256, 256, 3)):
    """Build a small, randomly initialised convolutional extractor.

    Args:
      input_shape: (height, width, channels) of a single input image.
        Defaults to the previously hard-coded (256, 256, 3).

    Returns:
      A Keras model whose outputs are the activations of `conv_block_2`
      to `conv_block_5`, followed by the strided `conv_block_out`
      activation as the last element.
    """
    filters = [64, 128, 64, 32]

    input_layer = tf.keras.layers.Input(shape = input_shape,
                                        name = "input_layer")

    x = tf.keras.layers.Conv2D(filters = 32,
                               kernel_size = (2, 2),
                               strides = (1, 1),
                               padding = "same",
                               name = "conv_block_1")(input_layer)

    activation_layers = list()
    for i, num_filters in enumerate(filters):
      # ReLU + batch normalisation are only inserted before the
      # odd-indexed blocks (conv_block_3 and conv_block_5).
      if i % 2 == 1:
        x = tf.keras.layers.ReLU(name = f"relu_layer_{i}th_iter")(x)
        x = tf.keras.layers.BatchNormalization()(x)

      # Blocks are numbered from 2 because conv_block_1 is the stem above.
      x = tf.keras.layers.Conv2D(filters = num_filters,
                                 kernel_size = (3, 3),
                                 strides = (1, 1),
                                 padding = "same",
                                 name = f"conv_block_{i + 2}")(x)

      activation_layers.append(x)

    out = tf.keras.layers.Conv2D(filters = 3,
                                 kernel_size = (3, 3),
                                 strides = (2, 2),
                                 padding = "same",
                                 name = "conv_block_out")(x)

    activation_layers.append(out)

    model = tf.keras.models.Model(inputs = input_layer,
                                  outputs = activation_layers,
                                  name = "feature_extractor_model")

    return model


  @classmethod
  def vgg_extractor_model(cls):
    """Build an ImageNet-pretrained VGG19 that exposes intermediate layers.

    Returns:
      A Keras model whose outputs are `block1_conv1` to `block5_conv1`
      (style activations) followed by `block5_conv2` (content activation).
    """
    vgg_model = tf.keras.applications.VGG19(include_top = False,
                                            weights = "imagenet")

    style_conv_blocks = [f"block{i}_conv1" for i in range(1, 6)]
    content_conv_block = ["block5_conv2"]
    all_activation_layers = style_conv_blocks + content_conv_block

    input_layer = vgg_model.inputs
    output_layers = [vgg_model.get_layer(layer_name).output
                     for layer_name in all_activation_layers]

    model = tf.keras.models.Model(inputs = input_layer,
                                  outputs = output_layers,
                                  name = "vgg19_extractor")

    return model


  @classmethod
  def vgg_with_recurrent(cls, input_shape = (256, 256, 3), num_layers = 1):
    """Build a VGG19 backbone followed by a stack of ConvLSTM2D layers.

    Fixes relative to the previous implementation:
      * the VGG19 features are actually fed into the recurrent stack
        (they used to be computed and then discarded);
      * `num_layers` intermediate recurrent layers are created (the old
        `range(1, num_layers)` produced none, leaving no style outputs);
      * the time axis is inserted at axis 1, so every image in the batch
        is its own one-step sequence. Inserting it at axis 0 merged the
        whole batch into one sequence and collapsed the output to a
        single feature map.

    Args:
      input_shape: (height, width, channels) of a single input image.
      num_layers: number of intermediate ConvLSTM2D layers; each one
        contributes a style activation.

    Returns:
      A Keras model whose outputs are the intermediate recurrent
      activations followed by the final recurrent (content) activation.
    """
    input_layer = tf.keras.layers.Input(shape = input_shape)
    vgg_features = tf.keras.applications.VGG19(include_top = False,
                                               weights = "imagenet")(input_layer)

    recurrent_layers = [
        tf.keras.layers.ConvLSTM2D(64, (1, 1), (1, 1), padding = "same",
                                   name = f"rnn_conv_{i}")
        for i in range(1, num_layers + 1)
    ]
    content_activation = tf.keras.layers.ConvLSTM2D(64, (1, 1), (1, 1),
                                                    padding = "same")

    x = vgg_features
    model_outputs = list()
    for rnn_layer in recurrent_layers + [content_activation]:
      # ConvLSTM2D consumes (batch, time, rows, cols, channels); add a
      # time axis of length 1 while keeping the batch axis first.
      x = tf.expand_dims(x, axis = 1)
      x = rnn_layer(x)
      model_outputs.append(x)

    model = tf.keras.models.Model(inputs = input_layer,
                                  outputs = model_outputs,
                                  name = "vgg_19_with_recurrent")

    return model


  @classmethod
  def extract(cls, image_stack, model):
    """Run a feature extractor on a stack of images.

    Image stack is (3, None, None, 3) shaped image data which contains
    content, style and generated images.

    Args:
      image_stack: batch of images to feed through the extractor.
      model: either an already-built Keras model, or a zero-argument
        callable (such as one of the factories above) returning one.
        Passing a factory builds a NEW model on every call, so prefer
        a built model when calling this repeatedly.

    Returns:
      Whatever the model returns for `image_stack` (a list of
      activations for the factories defined in this class).
    """
    # Keras models are themselves callable, so an isinstance check is
    # needed to tell a built model apart from a factory.
    if isinstance(model, tf.keras.models.Model):
      extractor = model
    else:
      extractor = model()

    return extractor(image_stack)


  @staticmethod
  def get_layers(features):
    """Split extractor outputs into content and style activations.

    Args:
      features: sequence of activations as returned by `extract`.

    Returns:
      Tuple `(content, style)`: the last element of `features`, and
      all elements before it.
    """
    content = features[-1]
    style = features[:-1]

    return content, style

### Training step - calculating loss

In [None]:
class Train(Loss, FeatureExtractor):
  """Combines feature extraction and the loss terms into one step loss."""

  @classmethod
  def calculate_step_loss(cls, model, content, style, generated):
    """Return the weighted content + style loss for one optimisation step.

    The three images are concatenated along axis 0 in the order
    content, style, generated and passed through `cls.extract`, so in
    every returned activation index 0 is the content image, index 1
    the style image and index -1 the generated image.
    """
    image_stack = tf.concat([content, style, generated], axis = 0)
    content_features, style_features = cls.get_layers(
        cls.extract(image_stack = image_stack, model = model))

    # Content term: generated image versus content image.
    content_term = cls.content_loss(content_features[0], content_features[-1])

    # Style term: per-layer losses (generated versus style image) summed up.
    style_term = sum(
        (cls.style_loss(activation[1], activation[-1])
         for activation in style_features),
        0.)

    weighted_content = content_term * Constants.CONTENT_WEIGHT
    weighted_style = style_term * Constants.STYLE_WEIGHT

    return weighted_content + weighted_style

### Training Loop

In [None]:
def train(model, content, style, generated, epochs = 10):
  """Optimise `generated` in place so that it minimises the step loss.

  Args:
    model: an already-built Keras feature extractor, or a zero-argument
      factory returning one (e.g. `FeatureExtractor.vgg_extractor_model`).
    content: content image batch.
    style: style image batch.
    generated: `tf.Variable` holding the image being optimised; it is
      updated in place by the optimizer.
    epochs: number of gradient steps to run.

  Returns:
    The same `generated` variable that was passed in, after the updates.
  """
  # Build the extractor exactly once. Calling the factory on every step
  # reloads pretrained weights and re-randomises untrained layers, so
  # the loss being minimised would change between steps.
  if isinstance(model, tf.keras.models.Model):
    extractor = model
  else:
    extractor = model()

  # `calculate_step_loss` hands its `model` argument to
  # `FeatureExtractor.extract`, which may call it to obtain the model;
  # this zero-argument wrapper always returns the extractor built above.
  def built_extractor():
    return extractor

  optimizer = tf.keras.optimizers.SGD(learning_rate = 1e-4)
  for epoch in range(epochs):
    with tf.GradientTape() as tape:
      loss = Train.calculate_step_loss(built_extractor, content, style,
                                       generated)

    print(f"EPOCH: {epoch + 1} \nLOSS: {loss}\n" + ("---" * 15))

    gradients = tape.gradient(loss, generated)
    optimizer.apply_gradients([(gradients, generated)])

  return generated

### Main Training Process

##### Classical Style Transfer Approach with VGG19 Extractor

In [None]:
# `train` updates the `generated` variable in place and returns that same
# object, so each experiment gets its own fresh variable initialised from
# the content image instead of sharing one across runs.
styled_img_vgg = train(model = FeatureExtractor.vgg_extractor_model,
                       content = CONTENT,
                       style = STYLE,
                       generated = tf.Variable(CONTENT),
                       epochs = 50)

In [None]:
# Remove the batch axis (axis 0) before handing the image to matplotlib.
plt.imshow(tf.squeeze(styled_img_vgg, axis = 0))
plt.title("Styled Image - Classical VGG19 Extractor")
plt.show()

##### Style Transfer with Recurrent Convolution Layers (ConvLSTM2D)

In [None]:
# `train` updates the `generated` variable in place and returns that same
# object, so each experiment gets its own fresh variable initialised from
# the content image instead of sharing one across runs.
styled_img_rnn = train(model = FeatureExtractor.vgg_with_recurrent,
                       content = CONTENT,
                       style = STYLE,
                       generated = tf.Variable(CONTENT),
                       epochs = 30)

In [None]:
# Remove the batch axis (axis 0) before handing the image to matplotlib.
plt.imshow(tf.squeeze(styled_img_rnn, axis = 0))
plt.title("Styled Image - Recurrent Convolutional Layers")
plt.show()

##### Style Transfer with custom convolutional layers

In [None]:
# `train` updates the `generated` variable in place and returns that same
# object, so each experiment gets its own fresh variable initialised from
# the content image instead of sharing one across runs.
styled_img_custom = train(model = FeatureExtractor.custom_model,
                          content = CONTENT,
                          style = STYLE,
                          generated = tf.Variable(CONTENT),
                          epochs = 30)

In [None]:
# Remove the batch axis (axis 0) before handing the image to matplotlib.
plt.imshow(tf.squeeze(styled_img_custom, axis = 0))
plt.title("Styled Image - Custom Feature Extractor")
plt.show()