In [1]:
import skimage.io
import skimage.transform
import tensorflow as tf

# InteractiveSession installs itself as the default session, which is what lets
# later cells call `train_step.run(...)` and `synthesized_image.eval()` without
# passing a session explicitly.
sess = tf.InteractiveSession()

In [2]:
# Name scope the VGG graph is imported under (tf.import_graph_def(name=...));
# its tensors are later looked up as "<VGG_NETWORK_NAME>/conv<i>_1/Relu:0".
VGG_NETWORK_NAME = "vgg"

def load_image(path, size=224):
    """
    Load an image file as a float RGB array, center-cropped to a square and
    resized to (size, size).

    Args:
        path: image file path, read with skimage.io.imread.
        size: side length in pixels of the returned image (default 224, the
            spatial size of the network placeholders in this notebook).

    Returns:
        Array of shape (size, size, 3) with values in [0, 1].

    Raises:
        AssertionError: if the decoded pixel values fall outside [0, 1]
            (e.g. a float image that is not already normalized).
    """
    import numpy as np  # local import: the notebook's import cell does not import numpy

    img = skimage.io.imread(path)
    # Scale integer pixels by the real bit depth instead of assuming 8-bit
    # input; a 16-bit PNG divided by 255 would exceed 1.0. Float images are
    # taken to already be in [0, 1] (skimage's float convention).
    if np.issubdtype(img.dtype, np.integer):
        img = img / float(np.iinfo(img.dtype).max)
    else:
        img = img.astype(np.float64)
    assert (0 <= img).all() and (img <= 1.0).all(), "pixel values must lie in [0, 1]"

    # Normalize to exactly 3 channels so callers can reshape to (1, size, size, 3).
    if img.ndim == 2:  # plain grayscale
        img = img[:, :, np.newaxis]
    if img.shape[2] in (1, 2):  # grayscale, optionally with an alpha channel
        img = np.repeat(img[:, :, :1], 3, axis=2)
    elif img.shape[2] == 4:  # RGBA: drop alpha
        img = img[:, :, :3]

    # Crop the largest centered square.
    short_edge = min(img.shape[:2])
    yy = (img.shape[0] - short_edge) // 2
    xx = (img.shape[1] - short_edge) // 2
    crop_img = img[yy:yy + short_edge, xx:xx + short_edge]

    return skimage.transform.resize(crop_img, (size, size))

# Every image fed through VGG is a single-item batch of 224x224 RGB pixels.
IMAGE_SHAPE = [1, 224, 224, 3]

content_image = tf.placeholder("float", IMAGE_SHAPE)
style_image = tf.placeholder("float", IMAGE_SHAPE)
# The image being optimized. tf.Variable's second positional parameter is
# `trainable`, not `name`, so the name must be passed by keyword for the
# variable to actually be called "synth".
synthesized_image = tf.Variable(tf.random_uniform(IMAGE_SHAPE), name="synth")
# Batch order matters for the loss slices: 0 = content, 1 = style, 2 = synthesized.
network_input = tf.concat(0, [content_image, style_image, synthesized_image])

# Load the serialized VGG16 GraphDef and splice our 3-image batch into its
# "images" input, importing everything under the VGG_NETWORK_NAME scope.
with open("models/vgg16.tfmodel", mode='rb') as f:
    file_content = f.read()
graph_def = tf.GraphDef()
graph_def.ParseFromString(file_content)
tf.import_graph_def(graph_def, input_map={"images": network_input}, name=VGG_NETWORK_NAME)

In [3]:
from style_helpers import gramian

def gramian_for_layer(layer):
    """
    Gram matrix (channel-by-channel cross-correlations) of the first ReLU
    activation in VGG conv block `layer`, computed for every image in the batch.

    `layer` is the conv block number used in the tensor name "conv<layer>_1".
    """
    tensor_name = "{0}/conv{1}_1/Relu:0".format(VGG_NETWORK_NAME, layer)
    graph = tf.get_default_graph()
    nhwc = graph.get_tensor_by_name(tensor_name)

    # The activations are (batch, height, width, channels); transpose (not
    # reshape) so channels come second: (batch, channels, height, width),
    # which is the layout handed to `gramian`.
    nchw = tf.transpose(nhwc, perm=[0, 3, 1, 2])
    return gramian(nchw)

# VGG conv blocks whose first ReLU output (conv<i>_1) feeds both losses.
layers = list(range(1, 6))

# Position of each image inside the concatenated `network_input` batch
# (concatenation order: content, style, synthesized).
CONTENT_INDEX, STYLE_INDEX, SYNTH_INDEX = 0, 1, 2


def _batch_item(tensor, index):
    """Slice batch entry `index` out of `tensor`, keeping a batch axis of size 1."""
    rank = tensor.get_shape().ndims
    return tf.slice(tensor, [index] + [0] * (rank - 1), [1] + [-1] * (rank - 1))


activations = [tf.get_default_graph().get_tensor_by_name("{0}/conv{1}_1/Relu:0".format(VGG_NETWORK_NAME, i)) for i in layers]
gramians = [gramian_for_layer(x) for x in layers]

# Style term: compare Gram matrices of the style image and the synthesized image.
gramian_diffs = [tf.sub(_batch_item(g, STYLE_INDEX), _batch_item(g, SYNTH_INDEX)) for g in gramians]
# N_l: Gram matrix side (presumably the channel count, assuming `gramian`
# returns (batch, C, C) — TODO confirm in style_helpers).
Ns = [g.get_shape().as_list()[2] for g in gramians]
# M_l: spatial size H * W of the (batch, H, W, C) activations.
Ms = [a.get_shape().as_list()[1] * a.get_shape().as_list()[2] for a in activations]
# Squared Gram differences (name kept as-is; no scaling happens at this step).
scaled_diffs = [tf.square(g) for g in gramian_diffs]
# E_l = sum((G_l - A_l)^2) / (4 * N_l^2 * M_l^2), averaged with equal weight over layers.
per_layer_style = [tf.div(tf.reduce_sum(x), 4 * (N ** 2) * (M ** 2)) for x, N, M in zip(scaled_diffs, Ns, Ms)]
style_loss = tf.div(tf.add_n(per_layer_style), len(layers))

# Content term: half the summed squared activation difference between the
# content image and the synthesized image, over all selected layers.
activation_diffs = [tf.sub(_batch_item(a, CONTENT_INDEX), _batch_item(a, SYNTH_INDEX)) for a in activations]
content_loss = tf.div(tf.add_n([tf.reduce_sum(tf.square(a)) for a in activation_diffs]), 2.0)

# Relative weights of the content (alpha) and style (beta) terms.
alpha = 0.001
beta = 1.0
combined_loss = tf.add(tf.mul(beta, style_loss), tf.mul(alpha, content_loss))

In [5]:
# NOTE(review): in the recorded run below the loss grows every step at this
# rate, so the step size / optimizer likely still needs tuning.
LEARNING_RATE = 0.0000001
NUM_STEPS = 2

# Build the training op *before* initializing variables: minimize() may create
# extra variables (optimizer slots for e.g. Momentum/Adam) that the
# initializer must cover, otherwise the first train step fails on them.
optimizer = tf.train.GradientDescentOptimizer(LEARNING_RATE)
train_step = optimizer.minimize(combined_loss)
init = tf.initialize_all_variables()
sess.run(init)

style_image_input = load_image("img/style.jpg").reshape((1, 224, 224, 3))
content_image_input = load_image("img/content.jpg").reshape((1, 224, 224, 3))
# Only the two placeholders are fed; the synthesized image is a tf.Variable.
feed = {content_image: content_image_input, style_image: style_image_input}

print("Loss", sess.run(combined_loss, feed_dict=feed))
for i in range(NUM_STEPS):
    train_step.run(feed_dict=feed)
    print("Loss for step {0}: {1}".format(i, sess.run(combined_loss, feed_dict=feed)))
    print(synthesized_image.eval()[0][0][0]) # To make sure it looks reasonable

Loss 7.1462e+09
Loss for step 0: 9214424064.0
[ 0.22602895  0.25953799  0.55995113]
Loss for step 1: 16164547584.0
[ 0.26341969  0.28666052  0.58144754]
