# Load libraries

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
from PIL import Image

import tensorflow as tf
from tensorflow import keras

from tqdm import tqdm_notebook as tqdm

import numpy as np

# Open TensorFlow session

In [None]:
# Build the session configuration with GPU memory growth enabled, so that
# GPU memory is allocated only as it is needed, then open the interactive
# session that the rest of the notebook uses through `sess`.
config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
sess = tf.InteractiveSession(config=config)

# Load the pretrained VGG19 model
(The VGG19 weights are downloaded first if they are not already available.)

In [None]:
print("Load model ...")
# Instantiate VGG19 with the Keras default arguments (no overrides are passed)
# and print its layer-by-layer summary. The layer names used by the loss cells
# (e.g. "block4_conv2") are looked up on this model with `get_layer`.
vgg19 = keras.applications.VGG19()
vgg19.summary()

# Load image and plot

In [None]:
print("Load images ...")


def _open_rgb(path, size=(224, 224)):
    """Open the image at `path`, convert it to RGB and resize it to `size`."""
    return Image.open(path).convert("RGB").resize(size)


# Style image in use. Alternatives that were tried with this notebook:
#   ./landscape.png, ./gogh.jpg, ./style1.png, ./frozen.jpg, ./udnie.jpg,
#   ./the_scream.jpg, ./rain_princess.jpg, ./the_shipwreck_of_the_minotaur.jpg
img_style = _open_rgb("./style2.jpg")

# Content image in use. Alternatives that were tried with this notebook:
#   ./cat.jpg, ./seoul.jpg, ./deadpool.jpg,
#   ./chicago.jpg  (annotated "4e-4" in the original notes; presumably a
#                   learning rate that worked for it -- unverified)
img_content = _open_rgb("./willy_wonka_old.jpg")

# Show the style image (left) and the content image (right) side by side.
fig, (fig_style, fig_content) = plt.subplots(1, 2, figsize=(25, 10))
for axis, image, title in (
    (fig_style, img_style, "Style"),
    (fig_content, img_content, "Content"),
):
    axis.imshow(image)
    axis.set_title(title)

# Set hyperparameters
<img src="style_transfer_model.png">

In [None]:
# Layer whose activations define the content loss.
content_layer = "block4_conv2"

# Layers whose Gram matrices define the style loss: the first convolution of
# each of the five blocks ("block1_conv1" ... "block5_conv1").
style_layer_list = ["block%d_conv1" % block for block in range(1, 6)]

# Per-layer weight applied to each style layer's loss (one entry per layer,
# all equal to 0.2).
w_style = [0.2] * len(style_layer_list)

# Weights of the three terms in the combined objective:
#   loss = alpha * content_loss + beta * total_style_loss + gamma * denoise_loss
alpha = 3.0
beta = 10.0
gamma = 0.3

# Content Loss
<img src="content_loss.PNG">

In [None]:
# Symbolic activations of the content layer for whatever image is fed to the
# network input.
content_img_output = vgg19.get_layer(content_layer).output

# Evaluate those activations once for the content image (with a leading batch
# axis added) and freeze the result as a constant target.
content_img_content = tf.constant(
    sess.run(
        content_img_output,
        feed_dict={vgg19.input: np.expand_dims(img_content, 0)},
    )
)

# Half the sum of squared differences between current and target activations.
content_loss = 0.5 * tf.reduce_sum(
    tf.pow(content_img_output - content_img_content, 2)
)

# Compute gram matrix
## 1. Flatten filters
<img src="flatten_filter.png">

## 2. Multiply the transposed flattened matrix by the flattened matrix to obtain a (\# of filters) × (\# of filters) Gram matrix
\begin{equation}
G_l = (\hat{F}_l)^T\hat{F}_l
\end{equation}

In [None]:
def gram_matrix(feature_matrix):
    """Return the Gram matrix of a feature tensor's filters.

    The last axis of ``feature_matrix`` is treated as the filter axis. All
    other axes are collapsed into one, giving a 2-D matrix ``F`` with one
    column per filter; the result is ``F^T F``.

    Args:
        feature_matrix: tensor whose last dimension is statically known
            (it is converted with ``int``).

    Returns:
        A tensor of shape ``(n_filters, n_filters)``.
    """
    channels = int(feature_matrix.get_shape()[-1])
    flat = tf.reshape(feature_matrix, [-1, channels])
    # transpose_a=True multiplies flat^T by flat without a separate transpose op.
    return tf.matmul(flat, flat, transpose_a=True)

# Style loss
<img src="style_loss1.PNG">
<img src="style_loss2.PNG">

In [None]:
# Accumulate the weighted style loss over every layer in `style_layer_list`.
total_style_loss = 0.0
for layer_idx, layer_name in enumerate(style_layer_list):
    # Symbolic activations of this layer for whatever image is fed in.
    layer_output = vgg19.get_layer(layer_name).output

    # Activations of the same layer for the style image, evaluated once.
    style_activations = sess.run(
        layer_output,
        feed_dict={vgg19.input: np.expand_dims(img_style, 0)},
    )
    gram_target = gram_matrix(tf.constant(style_activations))
    gram_current = gram_matrix(layer_output)

    # Normalisation: 4 * (number of filters * feature-map area)^2, with the
    # sizes read from the last three axes of the fetched activations.
    map_height, map_width, map_channels = style_activations.shape[-3:]
    n_filter = tf.constant(int(map_channels), dtype=tf.float32)
    filter_size = tf.constant(int(map_width) * int(map_height), dtype=tf.float32)

    # NOTE: `style_loss` is a notebook-level name; after the loop it holds
    # only the last layer's (unweighted) loss.
    style_loss = tf.reduce_sum(tf.pow(gram_current - gram_target, 2)) / (
        4.0 * tf.pow(n_filter * filter_size, 2)
    )
    total_style_loss += w_style[layer_idx] * style_loss

# Denoise Loss

In [None]:
# Sum of absolute differences between neighbouring entries of the network
# input along axis 1 and along axis 2 (the two spatial axes of the
# (1, 224, 224, 3) arrays fed in by the optimisation cells). This penalises
# pixel-to-pixel variation in the generated image.
input_image = vgg19.input
diff_axis1 = input_image[:, 1:, :, :] - input_image[:, :-1, :, :]
diff_axis2 = input_image[:, :, 1:, :] - input_image[:, :, :-1, :]
denoise_loss = tf.reduce_sum(tf.abs(diff_axis1)) + tf.reduce_sum(tf.abs(diff_axis2))

# Part 1. Effect of Content Loss

In [None]:
# Objective for this experiment: the content term only.
loss = alpha * content_loss

# The Adam instance is used only to build d(loss)/d(input); the update applied
# in the loop below is a plain fixed-step gradient-descent step.
opt = keras.optimizers.Adam()
grad_output = opt.get_gradients(loss, vgg19.input)

learning_rate = 3e-5
n_epoch = 60

# Start from a nearly uniform image with values in [128, 129).
img_output = 128 + np.random.rand(224, 224, 3)

# Show the starting point between the style and content images.
fig, (fig_style, fig_output, fig_content) = plt.subplots(1, 3, figsize=(25, 10))
for axis, image, title in (
    (fig_style, img_style, "Style"),
    (fig_output, np.int32(img_output), "Initial Output"),
    (fig_content, img_content, "Content"),
):
    axis.imshow(image)
    axis.set_title(title)

for ep in tqdm(range(n_epoch)):
    # Feed the current image with a leading batch axis.
    grad, l_content = sess.run(
        [grad_output, content_loss],
        feed_dict={vgg19.input: np.expand_dims(img_output, 0)},
    )
    # grad[0] is the gradient w.r.t. the input; [0] drops the batch axis.
    # Step against the gradient and keep pixel values inside [0, 255].
    img_output = np.clip(img_output - learning_rate * grad[0][0], 0.0, 255.0)

    n_done = ep + 1
    if n_done % 10 != 0:
        continue

    # Every 10 iterations, plot the current output and report the loss.
    fig, (fig_style, fig_output, fig_content) = plt.subplots(1, 3, figsize=(25, 10))
    for axis, image, title in (
        (fig_style, img_style, "Style"),
        (fig_output, np.int32(img_output), "Output after %d iter" % (n_done)),
        (fig_content, img_content, "Content"),
    ):
        axis.imshow(image)
        axis.set_title(title)

    print("Content loss: %f" % (l_content))

# Part 2. Effect of Style Loss

In [None]:
# Objective for this experiment: the style term only.
loss = beta * total_style_loss

# The Adam instance is used only to build d(loss)/d(input); the update applied
# in the loop below is a plain fixed-step gradient-descent step.
opt = keras.optimizers.Adam()
grad_output = opt.get_gradients(loss, vgg19.input)

learning_rate = 3e-5
n_epoch = 60

# Start from a nearly uniform image with values in [128, 129).
img_output = np.random.rand(224,224,3) + 128

# Show the starting point between the style and content images.
fig = plt.figure(figsize=(25,10))
fig_style = fig.add_subplot(1, 3, 1)
fig_output = fig.add_subplot(1, 3, 2)
fig_content = fig.add_subplot(1, 3, 3)

fig_style.imshow(img_style)
fig_style.set_title("Style")
fig_output.imshow(np.int32(img_output))
fig_output.set_title("Initial Output")
fig_content.imshow(img_content)
fig_content.set_title("Content")

for ep in tqdm(range(n_epoch)):
    # Fetch `total_style_loss` (the weighted sum over all style layers, i.e.
    # the quantity being optimised). `style_loss` is only the leftover loop
    # variable from the style-loss cell and holds the last layer's loss alone.
    grad, l_style = sess.run([grad_output, total_style_loss], feed_dict={vgg19.input:np.expand_dims(img_output, 0)})
    # grad[0] is the gradient w.r.t. the input; [0] drops the batch axis.
    img_output -= learning_rate*grad[0][0]
    # Keep pixel values inside the displayable [0, 255] range.
    img_output = np.clip(img_output, 0.0, 255.0)
    if ((ep + 1) % 10 == 0):
        # Every 10 iterations, plot the current output and report the loss.
        fig = plt.figure(figsize=(25,10))
        fig_style = fig.add_subplot(1, 3, 1)
        fig_output = fig.add_subplot(1, 3, 2)
        fig_content = fig.add_subplot(1, 3, 3)
        
        fig_style.imshow(img_style)
        fig_style.set_title("Style")
        fig_output.imshow(np.int32(img_output))
        fig_output.set_title("Output after %d iter" % (ep+1))
        fig_content.imshow(img_content)
        fig_content.set_title("Content")
        
        print("Style loss: %f" % (l_style))


# Part 3. Basic Style Transfer

In [None]:
# Full objective: weighted sum of the content, style and denoise terms.
loss = alpha * content_loss + beta * total_style_loss + gamma * denoise_loss

# The Adam instance is used only to build d(loss)/d(input); the update applied
# in the loop below is a plain fixed-step gradient-descent step.
opt = keras.optimizers.Adam()
grad_output = opt.get_gradients(loss, vgg19.input)

learning_rate = 3e-5
n_epoch = 60

# Start from a nearly uniform image with values in [128, 129).
img_output = np.random.rand(224,224,3) + 128

# Show the starting point between the style and content images.
fig = plt.figure(figsize=(25,10))
fig_style = fig.add_subplot(1, 3, 1)
fig_output = fig.add_subplot(1, 3, 2)
fig_content = fig.add_subplot(1, 3, 3)

fig_style.imshow(img_style)
fig_style.set_title("Style")
fig_output.imshow(np.int32(img_output))
fig_output.set_title("Initial Output")
fig_content.imshow(img_content)
fig_content.set_title("Content")

for ep in tqdm(range(n_epoch)):
    # Fetch the unweighted value of each term in the objective. The style
    # term is `total_style_loss` (sum over all style layers); `style_loss` is
    # only the leftover loop variable holding the last layer's loss.
    grad, l_content, l_style, l_denoise = sess.run([grad_output, content_loss, total_style_loss, denoise_loss], feed_dict={vgg19.input:np.expand_dims(img_output, 0)})
    # grad[0] is the gradient w.r.t. the input; [0] drops the batch axis.
    img_output -= learning_rate*grad[0][0]
    # Keep pixel values inside the displayable [0, 255] range.
    img_output = np.clip(img_output, 0.0, 255.0)
    if ((ep + 1) % 10 == 0):
        # Every 10 iterations, plot the current output and report the losses.
        fig = plt.figure(figsize=(25,10))
        fig_style = fig.add_subplot(1, 3, 1)
        fig_output = fig.add_subplot(1, 3, 2)
        fig_content = fig.add_subplot(1, 3, 3)
        
        fig_style.imshow(img_style)
        fig_style.set_title("Style")
        fig_output.imshow(np.int32(img_output))
        fig_output.set_title("Output after %d iter" % (ep+1))
        fig_content.imshow(img_content)
        fig_content.set_title("Content")
        
        print("Content loss: %f \tStyle loss: %f\tDenoise loss: %f" % (l_content, l_style, l_denoise))
