# **Neural Style Transfer**
This is a Colab implementation of neural style transfer, an image style-transfer technique. The inputs are a content image and a style image; the output is a new image that preserves the content of the content image while adopting the style of the style image.

# **Setup**

In [None]:
# Colab magic that selects the TensorFlow 2.x runtime.
# NOTE(review): Colab-only; newer Colab runtimes may no longer accept this magic — confirm.
%tensorflow_version 2.x
import tensorflow as tf
# Abort early unless TensorFlow reports a GPU at '/device:GPU:0'
# (gpu_device_name() returns an empty string when no GPU is available).
device_name = tf.test.gpu_device_name()
if device_name != '/device:GPU:0':
    raise SystemError('GPU device not found')
print('Found GPU at: {}'.format(device_name))

Found GPU at: /device:GPU:0


In [None]:
# Mount Google Drive at /content/gdrive so the image folder can be copied into the VM.
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
# Work from /content: the glob patterns below use paths relative to it.
%cd /content
# Uncomment to delete a previous copy of the image folder before re-copying.
# !rm -rf neural_style_transfer/

In [None]:
# Copy the neural_style_transfer folder (contents/ and styles/ image sub-folders)
# from Drive into /content; cp -a preserves file attributes, -v lists each copied file.
%cp -av "/content/gdrive/MyDrive/neural_style_transfer" "/content"

In [None]:
import glob        
import IPython.display as display
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['axes.grid'] = False
import numpy as np
import time
import functools

In [None]:
from tensorflow.keras.applications.vgg19 import VGG19, preprocess_input
from tensorflow.keras.models import Model

# **Utils**

In [None]:
def load_img(path_to_img, height=400):
    """Load an image file as a float32 tensor of shape (1, height, W, 3).

    The file is decoded to 3 channels and converted to floats in [0, 1]. It is
    then resized (up- or down-scaled) so its height is exactly ``height``
    pixels, keeping the aspect ratio.

    Args:
        path_to_img: path of the image file.
        height: target height in pixels (default 400).

    Returns:
        A 4-D float32 tensor with a leading batch axis of size 1.
    """
    img = tf.io.read_file(path_to_img)
    # expand_animations=False makes GIFs decode to a 3-D tensor (first frame),
    # so the (H, W, C) shape logic below holds for every file type.
    img = tf.image.decode_image(img, channels=3, expand_animations=False)
    img = tf.image.convert_image_dtype(img, tf.float32)

    shape = tf.cast(tf.shape(img)[:-1], tf.float32)
    scale = height / shape[0]
    # Round instead of truncating so float error cannot shave off a pixel
    # (e.g. 399.99997 -> 399).
    new_shape = tf.cast(tf.round(shape * scale), tf.int32)

    img = tf.image.resize(img, new_shape)
    return img[tf.newaxis, :]

In [None]:
def plot_img(image, title=None):
    """Draw one image tensor on the current matplotlib axes.

    A leading batch axis (rank > 3) is squeezed away first. The title is
    set only when a non-empty ``title`` is given.
    """
    to_show = tf.squeeze(image, axis=0) if len(image.shape) > 3 else image
    plt.imshow(to_show)
    if not title:
        return
    plt.title(title)

In [None]:
def plot_image_array(images, col=3):
    """Plot a dict of image tensors side by side in a single row.

    Args:
        images: mapping of title -> image tensor, plotted in iteration order.
        col: minimum number of grid columns (default 3). The grid is widened
            when ``images`` has more entries, so extra images do not index
            past the grid.
    """
    ncols = max(col, len(images))
    fig = plt.figure(figsize=(ncols * 6, 4), constrained_layout=True)
    gs = fig.add_gridspec(1, ncols)
    for i, (title, img) in enumerate(images.items()):
        fig.add_subplot(gs[0, i])
        plot_img(img, title)
        plt.xticks([])
        plt.yticks([])
    plt.show()

In [None]:
def tensor_to_image(tensor):
    """Convert an image tensor with values in [0, 1] to a PIL image.

    Accepts a single image (H, W, C) or a batch holding exactly one image
    (1, H, W, C); the batch axis is dropped only in the latter case.

    Raises:
        ValueError: if a 4-D input has a batch size other than 1.
    """
    array = np.asarray(tensor * 255, dtype=np.float32)
    # Saturate instead of letting out-of-range values wrap around in uint8.
    array = np.clip(array, 0, 255).astype(np.uint8)
    if array.ndim > 3:
        if array.shape[0] != 1:
            raise ValueError(
                'expected a batch of one image, got shape {}'.format(array.shape))
        # Only strip the batch axis; a 3-D image must not be indexed here,
        # or only its first row would survive.
        array = array[0]

    return Image.fromarray(array)

In [None]:
def display_images(paths):
    """Display each image file in ``paths`` inline, after loading it with load_img."""
    for path in paths:
        display.display(tensor_to_image(load_img(path)))

In [None]:
def clip_0_1(image):
    """Return ``image`` with every pixel value clamped to the range [0, 1]."""
    lower, upper = 0.0, 1.0
    return tf.clip_by_value(image, lower, upper)

In [None]:
def plot_data(data, names, y_label='loss', x_label='epoch'):
    """Plot several data series as lines on one set of axes.

    Args:
        data: sequence of series; series ``i`` is plotted against
            x = 1..len(data[i]).
        names: legend label for each series, indexed like ``data``.
        y_label: y-axis label (default 'loss').
        x_label: x-axis label (default 'epoch').
    """
    plt.figure(figsize=(8, 6))
    plt.rc('font', size=14)

    for i, series in enumerate(data):
        # Build the x-axis from each series' own length so series of
        # different lengths are plotted correctly.
        x = np.arange(1, len(series) + 1)
        plt.plot(x, series, label=names[i], lw=2)

    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.legend()
    plt.show()

# **Show images**

In [None]:
# Sorted paths of every content and style image under neural_style_transfer/,
# so that positional picks like contents[4] or styles[16] are stable.
contents = sorted(glob.glob('neural_style_transfer/contents/*'))
styles = sorted(glob.glob('neural_style_transfer/styles/*'))

In [None]:
# Show every content image.
display_images(contents)

In [None]:
# Show every style image.
display_images(styles)

# **VGG extractor**

In [None]:
# ImageNet-pretrained VGG19 without its classifier head, used purely as a
# feature extractor. It is frozen because only the generated image is optimized.
vgg = VGG19(include_top=False, weights='imagenet')
vgg.trainable = False

In [None]:
def vgg_layers(layer_names):
    """Build a Keras model from the shared VGG19 input to the named layers.

    Args:
        layer_names: names of layers of ``vgg``; the model's outputs follow
            this order.
    """
    selected = [vgg.get_layer(layer_name).output for layer_name in layer_names]
    return Model([vgg.input], selected)

In [None]:
def gram_matrix(input_tensor):
    """Return the per-image channel Gram matrix of a (B, H, W, C) feature map.

    Entry [b, c, d] is the sum over spatial positions of channel c times
    channel d, divided by the number of positions H * W.
    """
    dims = tf.shape(input_tensor)
    n_positions = tf.cast(dims[1] * dims[2], tf.float32)
    gram = tf.linalg.einsum('bijc,bijd->bcd', input_tensor, input_tensor)
    return gram / n_positions

In [None]:
class StyleContentModel(Model):
    """Frozen VGG19 extractor returning style Gram matrices and content activations.

    Calling the model returns
    ``{'content': {layer_name: activation}, 'style': {layer_name: gram_matrix}}``.
    """

    def __init__(self, style_layers, content_layers):
        super().__init__()
        # Style layers come first in the output list; call() splits on that.
        self.vgg = vgg_layers(style_layers + content_layers)
        self.style_layers = style_layers
        self.content_layers = content_layers
        self.num_style_layers = len(style_layers)
        self.vgg.trainable = False

    def call(self, inputs):
        # Inputs are expected in [0, 1] (as produced by load_img);
        # preprocess_input works on 0-255 pixel values.
        features = self.vgg(preprocess_input(inputs * 255.0))
        split = self.num_style_layers
        style_grams = [gram_matrix(feature) for feature in features[:split]]
        content_features = features[split:]

        return {
            'content': dict(zip(self.content_layers, content_features)),
            'style': dict(zip(self.style_layers, style_grams)),
        }

In [None]:
def vgg_convs():
    """Print and return the indices of the ``vgg`` layers whose name contains 'conv'.

    Each matching layer is printed as "index name output_shape".
    """
    conv_indices = []
    for index, layer in enumerate(vgg.layers):
        if 'conv' in layer.name:
            conv_indices.append(index)
            print(index, layer.name, layer.output.shape)
    return conv_indices

In [None]:
def display_feature_maps(ixs, img_path, n_maps=10):
    """Plot the first ``n_maps`` channels of selected VGG19 layers for one image.

    One row of grayscale maps is drawn per layer, titled with the layer name.

    Args:
        ixs: indices into ``vgg.layers`` of the layers to visualize; nothing
            is drawn when empty.
        img_path: path of the image fed through the network.
        n_maps: number of channels to show per layer (default 10). It is
            clamped to the layer's channel count, so narrow layers do not
            raise IndexError.
    """
    n = len(ixs)
    if n == 0:
        return
    outputs = [vgg.layers[i].output for i in ixs]
    names = [vgg.layers[i].name for i in ixs]
    model = Model(inputs=vgg.inputs, outputs=outputs)

    img = load_img(img_path)
    img = preprocess_input(img * 255)
    img = tf.image.resize(img, (224, 224))
    feature_maps = model.predict(img)
    # A single-output model returns one array instead of a list of arrays.
    if n == 1:
        feature_maps = [feature_maps]

    for name, f_map in zip(names, feature_maps):
        n_show = min(n_maps, f_map.shape[-1])
        plt.figure(figsize=(20, 2))
        for j in range(n_show):
            ax = plt.subplot(1, n_show, j + 1)
            ax.set_xticks([])
            ax.set_yticks([])
            plt.imshow(f_map[0, :, :, j], cmap='gray')
        plt.suptitle(name)
        plt.show()

# **Transfer functions**

In [None]:
class ImageTransfer:
    """Neural style transfer by direct optimization of the output image.

    The generated image starts as a copy of the content image. Adam then
    optimizes it to minimize a weighted sum of three terms: a content loss
    (feature MSE), a style loss (Gram-matrix MSE) and a total-variation loss.
    After every step the pixels are clipped to [0, 1].
    """

    def __init__(self, content_path, style_path, content_layers, style_layers,
                 epochs=10, steps_per_epoch=100, content_weight=1,
                 style_weight=1, total_variation_weight=1,
                 learning_rate=0.02, beta_1=0.99, beta_2=0.999, epsilon=1e-1):
        """Load both images, build the feature extractor, targets and optimizer.

        Args:
            content_path: path of the content image.
            style_path: path of the style image.
            content_layers: VGG19 layer names whose activations define content.
            style_layers: VGG19 layer names whose Gram matrices define style.
            epochs: number of epochs. Progress is displayed and losses are
                recorded once per epoch.
            steps_per_epoch: optimization steps per epoch.
            content_weight: multiplier of the layer-averaged content loss.
            style_weight: multiplier of the layer-averaged style loss.
            total_variation_weight: multiplier of the total-variation loss.
            learning_rate, beta_1, beta_2, epsilon: Adam hyper-parameters.
        """
        self.content_image = load_img(content_path)
        self.style_image = load_img(style_path)
        self.content_layers = content_layers
        self.style_layers = style_layers
        self.num_content_layers = len(content_layers)
        self.num_style_layers = len(style_layers)
        self.extractor = StyleContentModel(style_layers, content_layers)
        # Fixed targets: style Gram matrices of the style image and content
        # activations of the content image.
        self.style_targets = self.extractor(self.style_image)['style']
        self.content_targets = self.extractor(self.content_image)['content']

        self.epochs = epochs
        self.steps_per_epoch = steps_per_epoch
        self.content_weight = content_weight
        self.style_weight = style_weight
        self.total_variation_weight = total_variation_weight
        self.learning_rate = learning_rate
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.epsilon = epsilon

        self.opt = tf.optimizers.Adam(learning_rate=learning_rate,
                                      beta_1=beta_1,
                                      beta_2=beta_2,
                                      epsilon=epsilon)

    @tf.function()
    def train_step(self, image):
        """Run one optimization step, updating the variable ``image`` in place.

        The losses are computed on ``image`` as it was *before* the update.

        Returns:
            [total_loss, content_loss, style_loss, variation_loss], where
            total_loss includes the variation term.
        """
        with tf.GradientTape() as tape:
            outputs = self.extractor(image)
            total_loss, content_loss, style_loss = self.style_content_loss(outputs)
            v_loss = self.total_variation_weight * tf.image.total_variation(image)
            total_loss += v_loss

        # Differentiate the total only. Passing the whole loss list would sum
        # the gradients of every entry and count each term twice.
        grad = tape.gradient(total_loss, image)
        self.opt.apply_gradients([(grad, image)])
        image.assign(clip_0_1(image))

        return [total_loss, content_loss, style_loss, v_loss]

    def style_content_loss(self, outputs):
        """Return [style + content, content_loss, style_loss] for extractor ``outputs``.

        Each term is the sum over its layers of the mean squared difference
        to the stored target, scaled by weight / number_of_layers.
        """
        style_outputs = outputs['style']
        content_outputs = outputs['content']

        style_loss = tf.add_n([tf.reduce_mean((style_outputs[name]-self.style_targets[name])**2)
                              for name in style_outputs.keys()])
        style_loss *= self.style_weight / self.num_style_layers
        content_loss = tf.add_n([tf.reduce_mean((content_outputs[name]-self.content_targets[name])**2)
                                for name in content_outputs.keys()])
        content_loss *= self.content_weight / self.num_content_layers
        loss = style_loss + content_loss

        return [loss, content_loss, style_loss]

    def save_to_dict(self, image):
        """Return (hyper_dict, image_dict).

        hyper_dict holds the hyper-parameters used. image_dict maps display
        titles to the content image, the style image and ``image``.
        """
        hyper_dict = {'content_layers': self.content_layers,
                      'style_layers': self.style_layers,
                      'content_weight': self.content_weight,
                      'style_weight': self.style_weight,
                      'variation_weight': self.total_variation_weight,
                      'learning_rate': self.learning_rate,
                      'beta_1': self.beta_1,
                      'beta_2': self.beta_2,
                      'epsilon': self.epsilon,
                      'epochs': self.epochs}

        image_dict = {'Content Image': self.content_image,
                      'Style Image': self.style_image,
                      'Generated Image': image}

        return hyper_dict, image_dict

    def transfer(self):
        """Run the optimization and return its results.

        Returns:
            (image_dict, loss_list, hyper_dict, inter_images), where:
            - image_dict holds the content image, the style image and the
              lowest-loss generated image;
            - loss_list is [total, content, style, variation], with one entry
              per epoch (the losses returned by that epoch's last step);
            - inter_images holds a snapshot of the image at the end of each
              epoch.
        """
        image = tf.Variable(self.content_image)
        losses, c_losses, s_losses, v_losses = [], [], [], []
        inter_images = []
        best_img, best_loss = None, float('inf')
        start = time.time()
        step = 0

        for _ in range(self.epochs):
            for _ in range(self.steps_per_epoch):
                step += 1
                # train_step scores the image before updating it in place.
                # Snapshot that pre-step value: keeping a reference to the
                # Variable would just alias the final image.
                candidate = tf.identity(image)
                loss = self.train_step(image)
                if loss[0] < best_loss:
                    best_loss = loss[0]
                    best_img = candidate
                print(".", end='')
            display.clear_output(wait=True)
            display.display(tensor_to_image(image))
            print("Train step: {}".format(step))
            losses.append(loss[0])
            c_losses.append(loss[1])
            s_losses.append(loss[2])
            v_losses.append(loss[3])
            # Store a copy, since ``image`` keeps changing in place.
            inter_images.append(tf.identity(image))

        end = time.time()
        print("Total time: {:.1f}".format(end-start))

        if best_img is None:
            # No epochs ran: fall back to the unmodified starting image.
            best_img = tf.identity(image)
        hyper_dict, image_dict = self.save_to_dict(best_img)

        loss_list = [losses, c_losses, s_losses, v_losses]

        return image_dict, loss_list, hyper_dict, inter_images

In [None]:
def run_training():
    """Build an ImageTransfer from the notebook-level settings and run it.

    Reads these globals at call time: content_path, style_path,
    content_layers, style_layers, epochs, steps_per_epoch, content_weight,
    style_weight, total_variation_weight, learning_rate and beta_1.
    beta_2 and epsilon keep ImageTransfer's defaults.

    Returns:
        The 4-tuple from ImageTransfer.transfer():
        (image_dict, loss_list, hyper_dict, inter_images).
    """
    settings = dict(
        content_path=content_path,
        style_path=style_path,
        content_layers=content_layers,
        style_layers=style_layers,
        epochs=epochs,
        steps_per_epoch=steps_per_epoch,
        content_weight=content_weight,
        style_weight=style_weight,
        total_variation_weight=total_variation_weight,
        learning_rate=learning_rate,
        beta_1=beta_1,
    )
    return ImageTransfer(**settings).transfer()

# **Visualize feature maps**

In [None]:
# Show the content/style images used for the feature-map visualization.
# NOTE(review): the feature-map and training cells below use styles[16];
# confirm styles[0] is intended here.
display_images([contents[4], styles[0]])

In [None]:
# Indices of all VGG19 layers whose name contains 'conv' (also printed with their output shapes).
idx_all = vgg_convs()

In [None]:
# Feature maps of every conv layer for the selected content image.
display_feature_maps(idx_all, contents[4])

In [None]:
# Feature maps of every conv layer for the selected style image.
display_feature_maps(idx_all, styles[16])

# **Training**

In [None]:
# hyper-parameters
# VGG19 layers that define content and style.
content_layers = ['block3_conv4']
style_layers = [
    'block1_conv2',
    'block2_conv2',
    'block4_conv4',
]

# Schedule: epochs * steps_per_epoch optimization steps in total.
epochs, steps_per_epoch = 5, 200

# Relative weights of the content, style and total-variation losses.
content_weight, style_weight = 0.6, 0.4
total_variation_weight = 3

# Adam settings passed by run_training (beta_2 and epsilon keep their defaults).
learning_rate, beta_1 = 0.02, 0.9

In [None]:
# Pick the content/style pair, train, then show content, style and result side by side.
content_path, style_path = contents[4], styles[16]
image_dict, losses, _, inter_images = run_training()
plot_image_array(image_dict)

In [None]:
# Save the end-of-epoch snapshots as saved/image_1.jpg, saved/image_2.jpg, ...
import os

# Create the output directory first; PIL's save() does not create it and
# would raise FileNotFoundError.
os.makedirs('saved', exist_ok=True)
for i, image in enumerate(inter_images, start=1):
    tensor_to_image(image).save(f'saved/image_{i}.jpg')