# **DD2424 Project - Image Style Transfer**
This is the Colab notebook for the image style transfer project in DD2424 at KTH.

# **Setup**

In [None]:
# Colab magic: select the TensorFlow 2.x runtime before tensorflow is imported.
%tensorflow_version 2.x
import tensorflow as tf
# Abort early unless TensorFlow reports the first GPU device.
device_name = tf.test.gpu_device_name()
if device_name == '/device:GPU:0':
    print('Found GPU at: {}'.format(device_name))
else:
    raise SystemError('GPU device not found')

In [None]:
from google.colab import drive
# Mount Google Drive at /content/gdrive so the project folder can be copied from it.
drive.mount('/content/gdrive')

In [None]:
# Work from /content and delete any previous local copy of the project folder.
%cd /content
!rm -rf image_style_transfer/

In [None]:
# Copy the project folder from Drive into /content (-a preserves attributes, -v lists files).
%cp -av "/content/gdrive/MyDrive/image_style_transfer" "/content"

In [None]:
import glob
from PIL import Image
import cv2                
from google.colab.patches import cv2_imshow

import IPython.display as display

import matplotlib.pyplot as plt
import matplotlib as mpl
# Disable grid lines on matplotlib axes by default (cleaner image plots).
mpl.rcParams['axes.grid'] = False

import numpy as np
import time
import functools

# **Preprocessing**

In [None]:
def resize_images(images, basewidth=480, resample=Image.LANCZOS):
    """Resize image files in place to a fixed width, keeping the aspect ratio.

    Each file is opened, converted to RGB, scaled so that its width equals
    ``basewidth`` (the height is scaled by the same factor) and written back
    to the same path, overwriting the original file.

    Args:
        images: Iterable of image file paths.
        basewidth: Target width in pixels (default 480).
        resample: Pillow resampling filter. LANCZOS avoids the aliasing that
            NEAREST produces when downscaling photographs.
    """
    for path in images:
        # Use a context manager so the file handle is released before the
        # same path is overwritten below.
        with Image.open(path) as img:
            # Convert before resizing: Pillow silently falls back to NEAREST
            # for palette ("P") and bilevel ("1") images.
            rgb = img.convert("RGB")
        scale = basewidth / float(rgb.size[0])
        # Guard against a zero height for extremely wide images.
        hsize = max(1, int(rgb.size[1] * scale))
        rgb.resize((basewidth, hsize), resample).save(path)

In [None]:
def show_images(images):
    """Display each image file inline with OpenCV, followed by a blank separator."""
    for path in images:
        cv2_imshow(cv2.imread(path))
        print("\n")

In [None]:
# Content and style image paths, sorted so that indices such as contents[0]
# and styles[0] refer to the same files on every run.
contents = sorted(glob.glob('image_style_transfer/images/*'))
styles = sorted(glob.glob('image_style_transfer/styles/*'))

In [None]:
# Overwrite every content and style image on disk with a resized RGB copy.
resize_images(contents)
resize_images(styles)

In [None]:
# Preview the (resized) content images.
show_images(contents)

In [None]:
# Preview the (resized) style images.
show_images(styles)

# **Utils**

In [None]:
def tensor_to_image(tensor):
    """Convert a float image tensor with values in [0, 1] to a PIL image.

    Accepts a 3-D image or a 4-D batch that holds exactly one image; in the
    latter case the leading batch axis is dropped.

    Raises:
        ValueError: If a 4-D input holds more than one image.
    """
    # Clip before casting so out-of-range values saturate instead of wrapping.
    array = np.clip(np.asarray(tensor) * 255, 0, 255).astype(np.uint8)
    if array.ndim > 3:
        if array.shape[0] != 1:
            raise ValueError(
                "expected a batch of one image, got shape {}".format(array.shape))
        # Strip the batch axis only when there is one; doing this
        # unconditionally would drop the first row of a 3-D image.
        array = array[0]

    return Image.fromarray(array)

In [None]:
def load_img(path_to_img, max_dim=512):
    """Load an image file as a float32 batch of one, scaled to ``max_dim``.

    The image is decoded to 3 channels, converted to float32 via
    ``tf.image.convert_image_dtype`` and resized so that its longer side is
    ``max_dim`` pixels. A leading batch axis is then added.

    Args:
        path_to_img: Path of the image file to read.
        max_dim: Length in pixels of the longer side after resizing (default 512).
    """
    raw = tf.io.read_file(path_to_img)
    # expand_animations=False makes animated GIFs decode to a single 3-D
    # frame, which the height/width logic below relies on.
    img = tf.image.decode_image(raw, channels=3, expand_animations=False)
    img = tf.image.convert_image_dtype(img, tf.float32)

    # (height, width) as floats so the scale factor can be applied.
    shape = tf.cast(tf.shape(img)[:-1], tf.float32)
    scale = max_dim / tf.reduce_max(shape)
    new_shape = tf.cast(shape * scale, tf.int32)

    img = tf.image.resize(img, new_shape)
    return img[tf.newaxis, :]

In [None]:
def imshow(image, title=None):
    """Show ``image`` with matplotlib, squeezing a leading axis off 4-D inputs."""
    shown = tf.squeeze(image, axis=0) if len(image.shape) > 3 else image
    plt.imshow(shown)
    if title:
        plt.title(title)

In [None]:
def clip_0_1(image):
    """Clamp every value of ``image`` into the closed range [0, 1]."""
    lower, upper = 0.0, 1.0
    return tf.clip_by_value(image, clip_value_min=lower, clip_value_max=upper)

In [None]:
def plot_metrics(data, metrics, y_label, x_label='epoch', factor=1):
    """Plot selected loss curves recorded during training.

    NOTE: a later cell redefines ``plot_metrics`` with a different signature
    (``data, names, ...``); once that cell has run, this version is gone.

    Args:
        data: Sequence of per-epoch loss records; element ``i`` of each record
            is the loss named ``legends[i]`` below.
        metrics: Names of the losses to plot (any of the legend names).
        y_label: Y-axis label.
        x_label: X-axis label (default 'epoch').
        factor: Multiplier applied to the 1-based x positions.

    NOTE(review): ImageTransfer.transfer in this notebook records only the
    total loss per epoch, so metrics other than 'total_loss' presumably fail
    to index those records — confirm before requesting them.
    """
    legends = ['total_loss',
               'content_loss',
               'style_loss',
               'variation_loss']
    plt.figure(figsize=(8, 6))
    plt.rc('font', size=14)

    # The x positions are identical for every curve, so compute them once.
    x = (np.arange(len(data)) + 1) * factor
    for i, legend in enumerate(legends):
        if legend in metrics:
            y = [loss[i] for loss in data]
            plt.plot(x, y, label=legend, lw=2)

    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.legend()
    plt.show()

In [None]:
def plot_images_list(images, col=5):
    """Plot a sequence of images in a grid with ``col`` columns.

    Each subplot is titled with the image's 1-based position. Nothing is
    drawn for an empty sequence.
    """
    n = len(images)
    if n == 0:
        return
    # plt.subplot requires integer grid dimensions; np.ceil returns a float,
    # which recent Matplotlib versions reject.
    row = int(np.ceil(n / col))
    plt.figure(figsize=(col * 6, row * 6))
    for i, img in enumerate(images, start=1):
        plt.subplot(row, col, i)
        plt.xticks([])
        plt.yticks([])
        imshow(img, i)
    plt.show()

In [None]:
def plot_images(images, col=5):
    """Plot a ``{title: image}`` mapping in a grid with ``col`` columns.

    Each subplot is titled with its dictionary key. Nothing is drawn for an
    empty mapping.
    """
    n = len(images)
    if n == 0:
        return
    # plt.subplot requires integer grid dimensions; np.ceil returns a float,
    # which recent Matplotlib versions reject.
    row = int(np.ceil(n / col))
    plt.figure(figsize=(col * 8, row * 8))
    for i, (key, val) in enumerate(images.items(), start=1):
        plt.subplot(row, col, i)
        plt.xticks([])
        plt.yticks([])
        imshow(val, key)
    plt.show()

In [None]:
def print_vgg_layers():
    """Print a blank line, then the name of each layer of VGG19 without its top."""
    model = tf.keras.applications.VGG19(include_top=False, weights='imagenet')
    layer_names = [layer.name for layer in model.layers]
    print()
    for name in layer_names:
        print(name)

# **VGG extractor**

In [None]:
def vgg_layers(layer_names):
    """Return a frozen ImageNet VGG19 model that outputs the named layers' activations."""
    base = tf.keras.applications.VGG19(include_top=False, weights='imagenet')
    base.trainable = False
    return tf.keras.Model([base.input],
                          [base.get_layer(name).output for name in layer_names])

In [None]:
def gram_matrix(input_tensor):
    """Return channel-by-channel Gram matrices averaged over spatial positions.

    For a 4-D input indexed (batch, i, j, channel) the result is indexed
    (batch, channel, channel), divided by the number of (i, j) locations.
    """
    dims = tf.shape(input_tensor)
    locations = tf.cast(dims[1] * dims[2], tf.float32)
    gram = tf.linalg.einsum('bijc,bijd->bcd', input_tensor, input_tensor)
    return gram / locations

In [None]:
class StyleContentModel(tf.keras.models.Model):
    """Keras model mapping an image to style Gram matrices and content features.

    Inputs are multiplied by 255 and passed through VGG19 preprocessing, so
    they are expected to hold values in [0, 1].
    """

    def __init__(self, style_layers, content_layers):
        super(StyleContentModel, self).__init__()
        # The style layers come first, so outputs can be split by count.
        self.vgg = vgg_layers(style_layers + content_layers)
        self.style_layers = style_layers
        self.content_layers = content_layers
        self.num_style_layers = len(style_layers)
        self.vgg.trainable = False

    def call(self, inputs):
        """Return ``{'content': {layer: features}, 'style': {layer: gram}}``."""
        scaled = inputs * 255.0
        features = self.vgg(tf.keras.applications.vgg19.preprocess_input(scaled))
        split = self.num_style_layers
        style_grams = [gram_matrix(feature) for feature in features[:split]]
        content_features = features[split:]
        return {
            'content': dict(zip(self.content_layers, content_features)),
            'style': dict(zip(self.style_layers, style_grams)),
        }

# **Image transfer**

In [None]:
class ImageTransfer:
    """Optimise an image to keep one image's content and another's style.

    The generated image starts as a copy of the content image and is updated
    with Adam to minimise

        style_weight * mean over style layers of MSE(Gram(output), Gram(style))
        + content_weight * mean over content layers of MSE(output, content)
        + total_variation_weight * total_variation(image)

    The loss is evaluated on features from ``StyleContentModel``.
    """

    def __init__(self, content_path, style_path, content_layers, style_layers,
                 epochs=20, steps_per_epoch=100, content_weight=0.1,
                 style_weight=0.9, total_variation_weight=30,
                 learning_rate=0.02, beta_1=0.99, beta_2=0.999, epsilon=1e-7):
        """Load both images, build the extractor and precompute the targets.

        Args:
            content_path, style_path: Image file paths, loaded with ``load_img``.
            content_layers, style_layers: VGG19 layer names used for the losses.
            epochs: Number of outer iterations (one preview/loss record each).
            steps_per_epoch: Optimisation steps per epoch.
            content_weight, style_weight, total_variation_weight: Loss weights.
            learning_rate, beta_1, beta_2, epsilon: Adam optimiser settings.
        """
        self.content_image = load_img(content_path)
        self.style_image = load_img(style_path)
        self.content_layers = content_layers
        self.style_layers = style_layers
        self.num_content_layers = len(content_layers)
        self.num_style_layers = len(style_layers)
        self.extractor = StyleContentModel(style_layers, content_layers)
        # Fixed targets: style Gram matrices of the style image and content
        # features of the content image.
        self.style_targets = self.extractor(self.style_image)['style']
        self.content_targets = self.extractor(self.content_image)['content']

        self.epochs = epochs
        self.steps_per_epoch = steps_per_epoch
        self.content_weight = content_weight
        self.style_weight = style_weight
        self.total_variation_weight = total_variation_weight
        self.learning_rate = learning_rate
        self.beta_1 = beta_1
        self.beta_2 = beta_2
        self.epsilon = epsilon

        self.opt = tf.optimizers.Adam(learning_rate=learning_rate,
                                      beta_1=beta_1,
                                      beta_2=beta_2,
                                      epsilon=epsilon)

    @tf.function()
    def train_step(self, image):
        """Apply one Adam update to ``image`` (a tf.Variable) and return its loss.

        The returned loss includes the total-variation term; for a 4-D image
        ``tf.image.total_variation`` yields one value per batch entry, so the
        loss keeps that per-image shape.
        """
        with tf.GradientTape() as tape:
            outputs = self.extractor(image)
            loss = self.style_content_loss(outputs)
            v_loss = self.total_variation_weight * tf.image.total_variation(image)
            loss += v_loss

        grad = tape.gradient(loss, image)
        self.opt.apply_gradients([(grad, image)])
        # Keep pixel values in the valid [0, 1] range after every update.
        image.assign(clip_0_1(image))

        return loss

    def style_content_loss(self, outputs):
        """Return the weighted style + content loss for extractor ``outputs``."""
        style_outputs = outputs['style']
        content_outputs = outputs['content']

        style_loss = tf.add_n([tf.reduce_mean((style_outputs[name]-self.style_targets[name])**2)
                              for name in style_outputs.keys()])
        style_loss *= self.style_weight / self.num_style_layers
        content_loss = tf.add_n([tf.reduce_mean((content_outputs[name]-self.content_targets[name])**2)
                                for name in content_outputs.keys()])
        content_loss *= self.content_weight / self.num_content_layers
        loss = style_loss + content_loss

        return loss

    def save_to_dict(self, image):
        """Return ``(hyper_dict, image_dict)`` describing a run that produced ``image``."""
        hyper_dict = {'content_layers': self.content_layers,
                      'style_layers': self.style_layers,
                      'content_weight': self.content_weight,
                      'style_weight': self.style_weight,
                      'variation_weight': self.total_variation_weight,
                      'learning_rate': self.learning_rate,
                      'beta_1': self.beta_1,
                      'beta_2': self.beta_2,
                      'epsilon': self.epsilon,
                      'epochs': self.epochs,
                      'steps_per_epoch': self.steps_per_epoch}

        image_dict = {'Content Image': self.content_image,
                      'Style Image': self.style_image,
                      'Generated Image': image}

        return hyper_dict, image_dict

    def transfer(self):
        """Run the optimisation and return ``(image_dict, hyper_dict, losses)``.

        ``image_dict`` holds the content, style and lowest-loss generated
        images. ``losses`` holds the loss of the last step of every epoch.
        After each epoch the current image is displayed inline.
        """
        image = tf.Variable(self.content_image)
        losses = []
        best_img, best_loss = None, float('inf')
        start = time.time()
        step = 0

        for _ in range(self.epochs):
            for _ in range(self.steps_per_epoch):
                step += 1
                loss = self.train_step(image)
                # Compare as a Python float so the check does not depend on
                # the loss tensor's shape.
                loss_value = float(tf.reduce_sum(loss))
                if loss_value < best_loss:
                    best_loss = loss_value
                    # Snapshot the pixels. Storing ``image`` itself would alias
                    # the Variable, which keeps changing, so the "best" image
                    # would always be the final one.
                    best_img = tf.identity(image)
                print(".", end='')
            display.clear_output(wait=True)
            display.display(tensor_to_image(image))
            print("Train step: {}".format(step))
            losses.append(loss)

        end = time.time()
        print("Total time: {:.1f}".format(end-start))

        if best_img is None:
            # No optimisation step ran (epochs or steps_per_epoch is 0).
            best_img = tf.identity(image)
        hyper_dict, image_dict = self.save_to_dict(best_img)

        return image_dict, hyper_dict, losses

In [None]:
def run_training():
    """Run one style transfer configured by the notebook's global settings.

    Reads ``content_path``, ``style_path``, ``content_layers``,
    ``style_layers``, ``epochs``, the loss weights, the Adam settings,
    ``plot_loss`` and ``metrics`` from module globals. It shows the content,
    style and generated images, and plots the loss curve if ``plot_loss`` is
    set.

    Returns:
        The per-epoch losses from ``ImageTransfer.transfer``, so that several
        runs can be kept for comparison (e.g. ``losses1 = run_training()``).
        When called as the last expression of a cell, the list is echoed;
        assign it or end the line with ``;`` to suppress that.
    """
    transfer = ImageTransfer(content_path=content_path,
                             style_path=style_path,
                             content_layers=content_layers,
                             style_layers=style_layers,
                             epochs=epochs,
                             content_weight=content_weight,
                             style_weight=style_weight,
                             total_variation_weight=total_variation_weight,
                             learning_rate=learning_rate,
                             beta_1=beta_1,
                             beta_2=beta_2,
                             epsilon=epsilon)

    image_dict, _, losses = transfer.transfer()

    plot_images(image_dict)

    if plot_loss:
        plot_metrics(data=losses, metrics=metrics, y_label='loss')

    return losses

In [None]:
# Hyper-parameters read as globals by run_training().
# VGG19 layer(s) whose activations define the content loss.
content_layers = ['block3_conv4']
# VGG19 layers whose Gram matrices define the style loss.
style_layers = ['block1_conv1',
                'block2_conv2',
                'block3_conv2',
                'block4_conv4',
                'block5_conv3']
epochs = 10

# Relative weights of the content and style losses.
content_weight = 0.5
style_weight = 0.5

# Weight of the total-variation (smoothness) penalty.
total_variation_weight = 5

# Adam optimiser settings; note epsilon is far larger than
# ImageTransfer's default of 1e-7.
learning_rate = 0.02
beta_1 = 0.99
beta_2 = 0.999
epsilon = 1e-1

# plot loss: whether run_training() plots the curves, and which of them.
plot_loss = False
metrics = ['total_loss']

In [None]:
# Run one transfer on the first content image and the first style image
# (both lists are sorted, so the choice is stable across runs).
content_path = contents[0]
style_path = styles[0]
run_training()

In [None]:
def plot_metrics(data, names, y_label='loss', x_label='epoch', factor=1):
    """Plot several loss curves on one figure, one labelled line per series.

    NOTE: this replaces the earlier ``plot_metrics(data, metrics, y_label,
    ...)`` definition. After this cell runs, calls that pass ``metrics=``
    (as ``run_training`` does) raise TypeError.

    Args:
        data: Sequence of loss series, one per curve.
        names: Legend label for each series, in the same order as ``data``.
        y_label: Y-axis label (default 'loss').
        x_label: X-axis label (default 'epoch').
        factor: Multiplier applied to the 1-based x positions.
    """
    plt.figure(figsize=(8, 6))
    plt.rc('font', size=14)

    for series, name in zip(data, names):
        # Size x from each series so curves of different lengths still plot.
        x = (np.arange(len(series)) + 1) * factor
        plt.plot(x, series, label=name, lw=2)

    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.legend()
    plt.show()

In [None]:
# Compare the loss curves of three earlier training runs on one figure.
# NOTE(review): losses1, losses2 and losses3 are not assigned anywhere in the
# visible cells; they must hold the per-epoch loss lists of three previous
# training runs, otherwise this cell raises NameError.
loss_data = [losses1, losses2, losses3]
names = ['Generated image 1', 'Generated image 2', 'Generated image 3']
plot_metrics(loss_data, names)