In [None]:
%matplotlib inline

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import os
import shutil
import time

# to retain original image color after style transfer
from matplotlib.colors import rgb_to_hsv, hsv_to_rgb

In [None]:
# --- Single-image inference paths --------------------------------------------
CONTENT_IMG = "./images/content.jpg"
STYLE_IMG = "./images/style.png"
RESULT_IMG = "./images/result.jpg"

# --- Training data, pretrained weights and output locations ------------------
CONTENT_DIR = "../contents/"
STYLE_DIR = "../pandorastyles/"
VGG_WEIGHT = "../vgg16_weights.npz"
LOG_DIR = "./logs"      # TensorBoard event files (tf.summary.FileWriter)
MODEL_DIR = "./models"  # exported SavedModel
CKPT_DIR = "./ckpts"    # training checkpoints

# --- Optimisation schedule ----------------------------------------------------
NUM_EPOCHS = 2
LEARNING_RATE = 1e-3
CONTENT_BATCH_SIZE = 8
STYLE_BATCH_SIZE = 1    # pnet_fc flattens away the batch axis, so keep this at 1
LOG_ITER = 100          # print the loss every LOG_ITER iterations
SAMPLE_ITER = 100       # write image summaries every SAMPLE_ITER iterations

# Side length (pixels) that style and content images are resized to.
STYLE_SIZE = CONTENT_SIZE = 256

# --- Loss weighting -----------------------------------------------------------
CONTENT_LOSS_WEIGHT = 1
STYLE_LOSS_WEIGHT = 10

In [None]:
def vgg16(x, weights):
    """Build a frozen VGG16 feature extractor (conv1_1 .. conv4_3) on top of `x`.

    The per-channel mean [123.68, 116.779, 103.939] is subtracted first
    (presumably ImageNet RGB means, so `x` is expected to be RGB in [0, 255];
    confirm against the weight source). Conv weights are non-trainable
    variables initialised from `weights` and shared between calls through
    `tf.AUTO_REUSE`. Stages are separated by 2x2/stride-2 *average* pooling.

    Args:
        x: 4-D NHWC image tensor with 3 channels.
        weights: mapping with keys "<layer>_W" / "<layer>_b" for every layer
            listed below (e.g. the npz archive loaded from VGG_WEIGHT).

    Returns:
        Tuple of the last activation of each stage:
        (conv1_2, conv2_2, conv3_3, conv4_3), all post-ReLU.
    """
    # One inner list per VGG stage; a pool follows every stage except the last.
    stages = [
        ["conv1_1", "conv1_2"],
        ["conv2_1", "conv2_2"],
        ["conv3_1", "conv3_2", "conv3_3"],
        ["conv4_1", "conv4_2", "conv4_3"],
    ]

    # Subtract the per-channel mean before the first convolution.
    imagenet_mean = tf.constant([123.68, 116.779, 103.939], dtype=tf.float32, shape=[1, 1, 1, 3],
                                name='imagenet_mean')
    h = x - imagenet_mean

    stage_outputs = []
    with tf.variable_scope("vgg16", reuse=tf.AUTO_REUSE):
        for stage_idx, layer_names in enumerate(stages, start=1):
            for layer in layer_names:
                with tf.name_scope(layer) as scope:
                    kernel = tf.get_variable(initializer=tf.constant(weights[layer + "_W"]), trainable=False,
                                             name=layer + "_W")
                    biases = tf.get_variable(initializer=tf.constant(weights[layer + "_b"]), trainable=False,
                                             name=layer + "_b")
                    h = tf.nn.conv2d(h, kernel, [1, 1, 1, 1], padding='SAME')
                    h = tf.nn.bias_add(h, biases)
                    h = tf.nn.relu(h, name=scope)
            stage_outputs.append(h)
            if stage_idx < len(stages):
                # Average (not max) pooling between stages.
                h = tf.nn.avg_pool(h, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME',
                                   name='pool{}'.format(stage_idx))

    return tuple(stage_outputs)

In [None]:
def instance_norm(x, name, epsilon=1e-5, gamma=None, beta=None):
    """Instance normalisation over the spatial axes of an NHWC tensor.

    Each (example, channel) slice of `x` is normalised with its own mean and
    variance over axes 1 and 2, then scaled by `gamma` and shifted by `beta`.

    Args:
        x: 4-D NHWC tensor (as produced by `conv`); statistics use axes [1, 2].
        name: variable scope holding the learned gamma/beta when they are created.
        epsilon: variance floor passed to tf.nn.batch_normalization.
        gamma: externally supplied per-channel scale (e.g. predicted by pnet).
            When None, a learned variable initialised to ones is created.
        beta: externally supplied per-channel offset. When None, a learned
            variable initialised to zeros is created.

    Returns:
        Tensor with the same shape as `x`.
    """
    with tf.variable_scope(name):
        num_channels = x.shape[-1]
        # Start from the identity transform. Without an explicit initializer,
        # tf.get_variable falls back to glorot-uniform, giving zero-centred
        # random (partly negative) scales that distort the normalised signal.
        if gamma is None:
            gamma = tf.get_variable(name="gamma", shape=[num_channels], initializer=tf.ones_initializer())
        if beta is None:
            beta = tf.get_variable(name="beta", shape=[num_channels], initializer=tf.zeros_initializer())
        mean, var = tf.nn.moments(x, axes=[1, 2], keep_dims=True)
        x = tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon, name="norm")
    return x

def conv(x, name, filters, kernel_size, strides, norm=instance_norm, act=tf.nn.relu, norm_gamma=None,
         norm_beta=None):
    """Reflection-padded 2-D convolution, then optional normalisation and activation.

    Args:
        x: 4-D NHWC tensor.
        name: variable scope for this layer's variables.
        filters: number of output channels.
        kernel_size: square kernel side; `kernel_size // 2` pixels of reflection
            padding are added on each spatial side before a 'valid' convolution.
        strides: convolution stride.
        norm: normalisation callable taking (x, name=, gamma=, beta=), or None to skip.
        act: activation callable taking (x, name=), or None to skip.
        norm_gamma, norm_beta: optional external scale/offset forwarded to `norm`.

    Returns:
        The layer output tensor.
    """
    pad = kernel_size // 2
    # Pad only the height/width axes, mirroring edge pixels instead of zero-filling.
    spatial_padding = [[0, 0], [pad, pad], [pad, pad], [0, 0]]
    with tf.variable_scope(name):
        out = tf.pad(x, paddings=spatial_padding, mode="REFLECT")
        out = tf.layers.conv2d(out, filters=filters, kernel_size=kernel_size, strides=strides, name="conv")
        if norm is not None:
            out = norm(out, name="norm", gamma=norm_gamma, beta=norm_beta)
        if act is not None:
            out = act(out, name="act")
    return out

def residual_block(x, name, filters, kernel_size, norm_gamma1=None, norm_gamma2=None, norm_beta1=None,
                   norm_beta2=None):
    """Two stride-1 convs with an identity skip connection.

    The first conv keeps `conv`'s default ReLU; the second has no activation,
    so the block returns ``F(x) + x`` with no trailing non-linearity. `x` is
    expected to already have `filters` channels so the sum is shape-compatible.

    Args:
        x: input tensor.
        name: variable scope for the block.
        filters, kernel_size: conv parameters shared by both convs.
        norm_gamma1, norm_beta1: external instance-norm params for the first conv.
        norm_gamma2, norm_beta2: external instance-norm params for the second conv.

    Returns:
        Tensor shaped like `x`.
    """
    with tf.variable_scope(name):
        hidden = conv(x, "conv1", filters, kernel_size, strides=1, norm_gamma=norm_gamma1, norm_beta=norm_beta1)
        branch = conv(hidden, "conv2", filters, kernel_size, strides=1, act=None, norm_gamma=norm_gamma2,
                      norm_beta=norm_beta2)
    return branch + x

def upsample(x, name, filters, kernel_size, strides, norm_gamma=None, norm_beta=None):
    """Resize-then-convolve upsampling.

    `x` is resized by a factor of `strides` along height and width, using
    tf.image.resize_images' default method. It is then passed through a
    stride-1 `conv` layer (reflection padding, instance norm, ReLU by default).

    Args:
        x: 4-D NHWC tensor. Height/width may be statically unknown; in that
            case the target size is computed from the runtime shape.
        name: variable scope for the conv layer.
        filters, kernel_size: forwarded to `conv`.
        strides: integer upscaling factor applied to both spatial axes.
        norm_gamma, norm_beta: optional external instance-norm scale/offset.

    Returns:
        NHWC tensor with `filters` channels and spatial size scaled by `strides`.
    """
    # NHWC: axis 1 is height, axis 2 is width; resize_images wants [new_height, new_width].
    _, height, width, _ = x.shape.as_list()
    if height is None or width is None:
        # Fall back to the dynamic shape so None-sized inputs do not hit `None * strides`.
        runtime_shape = tf.shape(x)
        height, width = runtime_shape[1], runtime_shape[2]
    x = tf.image.resize_images(x, size=[height * strides, width * strides])
    x = conv(x, name, filters, kernel_size, strides=1, norm_gamma=norm_gamma, norm_beta=norm_beta)
    return x

def tnet(x, gammas, betas):
    """Transformation network: content batch -> stylised batch.

    The pipeline is: 3 encoder convs -> 5 residual blocks -> 2 upsampling
    stages -> a 3-channel output conv. Every normalised layer takes its
    instance-norm scale/offset from `gammas` / `betas`, keyed as pnet produces
    them ("conv1".."conv3", "res<i>_1"/"res<i>_2", "up1", "up2").

    Args:
        x: content image tensor (NHWC).
        gammas, betas: dicts of per-layer instance-norm scale / offset tensors.

    Returns:
        3-channel tensor clipped to [0, 255].
    """
    def affine(key):
        # kwargs selecting the externally supplied norm params for layer `key`.
        return {"norm_gamma": gammas[key], "norm_beta": betas[key]}

    with tf.variable_scope("tnet", reuse=tf.AUTO_REUSE):
        h = conv(x, "conv1", filters=32, kernel_size=9, strides=1, **affine("conv1"))
        h = conv(h, "conv2", filters=64, kernel_size=3, strides=2, **affine("conv2"))
        h = conv(h, "conv3", filters=128, kernel_size=3, strides=2, **affine("conv3"))
        for idx in range(1, 6):
            block = "res{}".format(idx)
            h = residual_block(h, block, filters=128, kernel_size=3,
                               norm_gamma1=gammas[block + "_1"], norm_gamma2=gammas[block + "_2"],
                               norm_beta1=betas[block + "_1"], norm_beta2=betas[block + "_2"])
        h = upsample(h, "up1", filters=64, kernel_size=3, strides=2, **affine("up1"))
        h = upsample(h, "up2", filters=32, kernel_size=3, strides=2, **affine("up2"))
        # Output projection to RGB: no normalisation and no activation.
        h = conv(h, "conv4", filters=3, kernel_size=9, strides=1, norm=None, act=None)
    return tf.clip_by_value(h, 0., 255.)

def pnet_residual_block(x, name, filters, kernel_size):
    """Residual block for pnet that also exposes its two inner activations.

    Unlike `residual_block`, both convs keep `conv`'s default ReLU and create
    their own instance-norm variables (no external gamma/beta are passed).

    Returns:
        (x + second, first, second) where `first`/`second` are the outputs of
        the two convs; pnet feeds the latter two to its gamma/beta heads.
    """
    with tf.variable_scope(name):
        first = conv(x, "conv1", filters, kernel_size, strides=1)
        second = conv(first, "conv2", filters, kernel_size, strides=1)
    # The skip-sum is built outside the variable scope.
    return second + x, first, second

def pnet_fc(x, name):
    """Collapse a single feature map into one value per channel with learned spatial weights.

    Computes ``w @ X + b``, where X is `x` flattened to (H*W, C), `w` is a
    learned (1, H*W) weighting over spatial positions and `b` is a learned
    (1, C) bias. The result is a length-C vector that pnet uses as an
    instance-norm gamma or beta.

    Only one example is supported, because the flatten drops the batch axis.

    Args:
        x: 4-D NHWC tensor with static H, W, C and batch size 1.
        name: variable scope holding `w` and `b`.

    Returns:
        1-D tensor of shape (C,).

    Raises:
        ValueError: if the static batch size is known and is not 1.
    """
    batch = x.shape.as_list()[0]
    if batch is not None and batch != 1:
        raise ValueError(
            "pnet_fc expects a single style image (batch size 1), got batch size {}".format(batch))
    with tf.variable_scope(name):
        x = tf.reshape(x, shape=[x.shape[1] * x.shape[2], x.shape[3]])
        w = tf.get_variable(shape=[1, x.shape[0]], name="w")
        b = tf.get_variable(shape=[1, x.shape[1]], name="b")
        # Squeeze only the leading axis so the result stays 1-D even when C == 1.
        fc = tf.squeeze(w @ x + b, axis=[0])
    return fc

def pnet(x):
    """Parameter network: style image -> per-layer instance-norm params for tnet.

    It mirrors tnet's layer layout (conv1..conv3, res1..res5, up1, up2). After
    each layer, two `pnet_fc` heads predict that layer's gamma and beta from
    its activations.

    Args:
        x: style image tensor (NHWC, batch size 1 — see pnet_fc).

    Returns:
        (gammas, betas): dicts keyed "conv1".."conv3", "res<i>_1", "res<i>_2"
        for i in 1..5, "up1" and "up2".
    """
    gammas, betas = {}, {}

    def _predict_affine(features, key):
        # One gamma head and one beta head per tnet layer, created in that order.
        gammas[key] = pnet_fc(features, "fc_gamma_" + key)
        betas[key] = pnet_fc(features, "fc_beta_" + key)

    with tf.variable_scope("pnet", reuse=tf.AUTO_REUSE):
        h = conv(x, "conv1", filters=32, kernel_size=9, strides=1)
        _predict_affine(h, "conv1")

        h = conv(h, "conv2", filters=64, kernel_size=3, strides=2)
        _predict_affine(h, "conv2")

        h = conv(h, "conv3", filters=128, kernel_size=3, strides=2)
        _predict_affine(h, "conv3")

        for idx in range(1, 6):
            block = "res{}".format(idx)
            h, inner1, inner2 = pnet_residual_block(h, block, filters=128, kernel_size=3)
            _predict_affine(inner1, block + "_1")
            _predict_affine(inner2, block + "_2")

        h = upsample(h, "up1", filters=64, kernel_size=3, strides=2)
        _predict_affine(h, "up1")

        h = upsample(h, "up2", filters=32, kernel_size=3, strides=2)
        _predict_affine(h, "up2")
    return gammas, betas

In [None]:
def gram_matrix(x):
    """Per-example Gram matrix of channel activations, normalised by C*H*W.

    Args:
        x: 4-D NHWC feature map with static height, width and channel sizes.
            The batch size may be statically unknown.

    Returns:
        Tensor of shape (batch, C, C).
    """
    _, height, width, channels = x.shape.as_list()
    # -1 lets the batch axis stay dynamic (the original used the static batch size,
    # which is None for None-batch placeholders and broke the reshape).
    features = tf.reshape(x, [-1, height * width, channels])
    return tf.matmul(features, features, transpose_a=True) / (channels * height * width)

def loss_fun(target_style_features, target_content_features, transferred_features,
             style_loss_weight=STYLE_LOSS_WEIGHT, content_loss_weight=CONTENT_LOSS_WEIGHT):
    """Weighted perceptual loss (content + style), averaged over the batch.

    Args:
        target_style_features: per-layer feature tensors of the style image(s).
        target_content_features: per-layer feature tensors of the content batch.
        transferred_features: per-layer feature tensors of the stylised batch.
        style_loss_weight: multiplier for the style term.
        content_loss_weight: multiplier for the content term.

    Returns:
        Scalar tensor: batch mean of
        ``content_loss_weight * content_loss + style_loss_weight * style_loss``.
    """
    # Content term: per-example MSE at feature index 1 (relu2_2 for vgg16's outputs).
    content_gap = tf.subtract(transferred_features[1], target_content_features[1])
    content_loss = tf.reduce_mean(content_gap ** 2, [1, 2, 3])

    # Style term: sum over all layers of the mean squared Gram difference. With a
    # single style image, its Gram matrix broadcasts against the whole batch.
    style_loss = 0
    for layer_idx, transferred_layer in enumerate(transferred_features):
        target_gram = gram_matrix(target_style_features[layer_idx])
        transferred_gram = gram_matrix(transferred_layer)
        style_loss += tf.reduce_mean(tf.subtract(target_gram, transferred_gram) ** 2, [1, 2])

    weighted = content_loss_weight * content_loss + style_loss_weight * style_loss
    return tf.reduce_mean(weighted)

In [None]:
# Keras directory iterators yield (images, labels) batches resized to a square
# target size; the labels are discarded by the training loop.
# NOTE(review): DirectoryIterator looks for images in sub-directories (one per
# class) of `directory` — confirm the dataset layout matches.
iterator = tf.keras.preprocessing.image.DirectoryIterator
datagen = tf.keras.preprocessing.image.ImageDataGenerator()


def _image_batches(directory, batch_size, side):
    """Shuffled batches of `side` x `side` images read from `directory`."""
    return iterator(directory=directory, batch_size=batch_size, target_size=(side, side),
                    image_data_generator=datagen, shuffle=True)


content_iter = _image_batches(CONTENT_DIR, CONTENT_BATCH_SIZE, CONTENT_SIZE)
style_iter = _image_batches(STYLE_DIR, STYLE_BATCH_SIZE, STYLE_SIZE)

In [None]:
# The placeholders have a fixed batch size, so only full content batches can be
# fed: an epoch is floor(n / CONTENT_BATCH_SIZE) steps. Compute this before
# touching CKPT_DIR so an unusable dataset does not wipe existing checkpoints.
steps_per_epoch = content_iter.n // CONTENT_BATCH_SIZE
if steps_per_epoch == 0:
    raise ValueError("Need at least {} content images in {}, found {}".format(
        CONTENT_BATCH_SIZE, CONTENT_DIR, content_iter.n))
total_iteration = NUM_EPOCHS * steps_per_epoch

if os.path.exists(CKPT_DIR):
    shutil.rmtree(CKPT_DIR)

vgg_weights = np.load(VGG_WEIGHT)

tf.reset_default_graph()
config = tf.ConfigProto()
config.gpu_options.allow_growth = True

training_graph = tf.Graph()

with training_graph.as_default() as g, tf.Session(config=config, graph=training_graph) as sess:
    style = tf.placeholder(name="style", dtype=tf.float32,
                           shape=[STYLE_BATCH_SIZE, STYLE_SIZE, STYLE_SIZE, 3])
    content = tf.placeholder(name="content", dtype=tf.float32,
                             shape=[CONTENT_BATCH_SIZE, CONTENT_SIZE, CONTENT_SIZE, 3])

    target_style_features = vgg16(style, vgg_weights)
    target_content_features = vgg16(content, vgg_weights)
    gammas, betas = pnet(style)
    transferred = tnet(content, gammas, betas)
    transferred_features = vgg16(transferred, vgg_weights)
    loss = loss_fun(target_style_features, target_content_features, transferred_features)
    train_op = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE).minimize(loss)

    summary = tf.summary.FileWriter(graph=g, logdir=LOG_DIR)
    style_summary = tf.summary.image("style", style)
    content_summary = tf.summary.image("content", content)
    transferred_summary = tf.summary.image("transferred", transferred)
    image_summary = tf.summary.merge([style_summary, content_summary, transferred_summary])
    loss_summary = tf.summary.scalar("loss", loss)

    sess.run(tf.global_variables_initializer())
    # VGG variables are created with trainable=False, so this saves only pnet/tnet weights.
    saver = tf.train.Saver(var_list=tf.trainable_variables())

    start = time.time()
    it = 0
    try:
        for _epoch in range(NUM_EPOCHS):
            content_iter.reset()
            style_iter.reset()
            # Keras directory iterators wrap around to the next epoch instead of raising
            # StopIteration, so the epoch length must be bounded explicitly; iterating
            # the iterator directly never ends when n is a multiple of the batch size.
            for _step in range(steps_per_epoch):
                c, _labels = next(content_iter)
                s, _labels = next(style_iter)
                it += 1

                # Image summaries are comparatively expensive; only evaluate them when written.
                sample_now = it % SAMPLE_ITER == 0 or it == total_iteration
                fetches = [train_op, loss, loss_summary] + ([image_summary] if sample_now else [])
                fetched = sess.run(fetches, feed_dict={style: s, content: c})
                cur_loss, cur_loss_summary = fetched[1], fetched[2]
                summary.add_summary(cur_loss_summary, it)
                if sample_now:
                    summary.add_summary(fetched[3], it)

                if it % LOG_ITER == 0 or it == total_iteration:
                    print("Iteration: [{it}/{num_iter}], loss: {loss}".format(it=it, num_iter=total_iteration,
                                                                              loss=cur_loss))

                summary.flush()

            ckpt_path = saver.save(sess, save_path=os.path.join(CKPT_DIR, "ckpt"), write_meta_graph=False,
                                   global_step=it)
            print("Checkpoint saved as: {ckpt_path}".format(ckpt_path=ckpt_path))
    finally:
        # Release the event-file handle even if training is interrupted.
        summary.close()

end = time.time()
print("Finished {num_iter} in {time} seconds".format(num_iter=total_iteration, time=end-start))

In [None]:
# Locate the checkpoint *before* deleting the previous export, so a missing or
# empty CKPT_DIR does not destroy an existing model (saver.restore(None) would
# otherwise fail only after MODEL_DIR is gone).
latest_ckpt = tf.train.latest_checkpoint(CKPT_DIR)
if latest_ckpt is None:
    raise FileNotFoundError("No checkpoint found in {}; run the training cell first.".format(CKPT_DIR))

# The builder writes a fresh MODEL_DIR, so clear any previous export first.
if os.path.exists(MODEL_DIR):
    shutil.rmtree(MODEL_DIR)

tf.reset_default_graph()
eval_graph = tf.Graph()

config = tf.ConfigProto()
config.gpu_options.allow_growth = True
exporter = tf.saved_model.builder.SavedModelBuilder(MODEL_DIR)

with eval_graph.as_default() as g, tf.Session(config=config, graph=eval_graph) as sess:
    # Serving inputs: exactly one style image plus a content batch of any size.
    style = tf.placeholder(name="style", dtype=tf.float32, shape=[1, STYLE_SIZE, STYLE_SIZE, 3])
    inputs = tf.placeholder(name="inputs", dtype=tf.float32, shape=[None, CONTENT_SIZE, CONTENT_SIZE, 3])
    gammas, betas = pnet(style)
    outputs = tf.identity(tnet(inputs, gammas, betas), name="outputs")

    # Restores every variable of this graph (pnet + tnet) from the training checkpoint.
    saver = tf.train.Saver()
    saver.restore(sess, latest_ckpt)

    exporter.add_meta_graph_and_variables(
        sess,
        tags=[tf.saved_model.tag_constants.SERVING],
        signature_def_map={
            tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
            tf.saved_model.signature_def_utils.predict_signature_def(inputs={"inputs": inputs,
                                                                             "style": style},
                                                                     outputs={"outputs": outputs})
        })
    exporter.save()

In [None]:
# Set to True to keep the content image's hue/saturation and take only the
# brightness (HSV value channel) from the stylised result.
USE_ORIGINAL_COLOR = False


def use_original_color(original, result):
    """Combine the hue/saturation of `original` with the HSV value of `result`.

    Args:
        original: RGB float array with values in [0, 255].
        result: RGB float array of the same shape, values in [0, 255].

    Returns:
        RGB float array in [0, 255].
    """
    # matplotlib's HSV conversions expect floats in [0, 1].
    orig_hsv = rgb_to_hsv(np.clip(original / 255., 0., 1.))
    res_hsv = rgb_to_hsv(np.clip(result / 255., 0., 1.))
    # Channel names deliberately avoid `os`, which would shadow the os module.
    hue, sat, _ = np.split(orig_hsv, indices_or_sections=3, axis=-1)
    _, _, val = np.split(res_hsv, indices_or_sections=3, axis=-1)
    return hsv_to_rgb(np.concatenate([hue, sat, val], axis=-1)) * 255.


content_image = tf.keras.preprocessing.image.img_to_array(
    img=tf.keras.preprocessing.image.load_img(CONTENT_IMG, target_size=(CONTENT_SIZE, CONTENT_SIZE)))
# Use the configured style image instead of a hard-coded path.
style_image = tf.keras.preprocessing.image.img_to_array(
    img=tf.keras.preprocessing.image.load_img(STYLE_IMG, target_size=(STYLE_SIZE, STYLE_SIZE)))

tf.reset_default_graph()
eval_graph = tf.Graph()

config = tf.ConfigProto()
config.gpu_options.allow_growth = True

with eval_graph.as_default() as g, tf.Session(config=config, graph=eval_graph) as sess:
    tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], MODEL_DIR)
    inputs = g.get_tensor_by_name("inputs:0")
    style = g.get_tensor_by_name("style:0")
    outputs = g.get_tensor_by_name("outputs:0")
    # Add/remove the batch axis in numpy: building tf.expand_dims / tf.squeeze ops
    # here would grow the loaded graph and include op construction in the timing.
    c = np.expand_dims(content_image, axis=0)
    s = np.expand_dims(style_image, axis=0)
    start = time.time()
    result = np.squeeze(sess.run(outputs, feed_dict={inputs: c, style: s}))
    end = time.time()
    print("Inference time: {time} seconds".format(time=end-start))

final_result = use_original_color(content_image, result) if USE_ORIGINAL_COLOR else result

fig, ax = plt.subplots()
ax.imshow(final_result / 255.)
ax.set_title("Stylised result")
ax.axis("off")
plt.show()

result_image = tf.keras.preprocessing.image.array_to_img(final_result)
result_image.save(RESULT_IMG)