In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import scipy.misc
import os

import tensorflow as tf
slim = tf.contrib.slim
from nets import vgg

In [2]:
# Per-channel RGB offsets subtracted from images before they enter VGG-19.
# NOTE(review): these look like rounded ImageNet channel means -- confirm they
# match the preprocessing the vgg_19 checkpoint was trained with.
# Shaped (1, 1, 1, 3) so they broadcast over a batched (1, H, W, 3) image.
MEAN_VALUES = np.reshape(np.array([123, 117, 104]), (1, 1, 1, 3))

# (height, width, channels) that every input image is expected to have.
shape = (227, 227, 3)

# Pretrained-weights directory and the two input images.
checkpoints_dir = './checkpoints'
CONTENT_IMG = './images/Taipei1012.jpg'
STYLE_IMG = './images/StarryNight2.jpg'

In [3]:
def vgg_(images):
    """Run slim's VGG-19 in inference mode and return its endpoints.

    Args:
        images: input tensor handed to ``vgg.vgg_19``. Callers in this notebook
            pass a mean-subtracted batch (presumably NHWC float32 -- confirm).

    Returns:
        The ``endpoints`` value produced by ``vgg.vgg_19``. Other cells index it
        by layer name (e.g. ``'vgg_19/conv4/conv4_2'``). The logits are discarded.
    """
    scope = vgg.vgg_arg_scope()
    with slim.arg_scope(scope):
        _, layer_outputs = vgg.vgg_19(
            images,
            num_classes=1000,
            is_training=False,
            spatial_squeeze=False,
        )
    return layer_outputs

In [4]:
# Run the content and style images through VGG-19 once and keep every endpoint
# as a NumPy array; these serve as fixed targets in the optimisation graph.
with tf.Graph().as_default():
    image = tf.placeholder(tf.float32, shape=shape)
    # Add the batch axis explicitly. Previously the un-batched placeholder was
    # used and only became (1, H, W, C) by broadcasting against MEAN_VALUES,
    # which left `images` unused.
    images = tf.expand_dims(image, 0)
    processed_images = images - MEAN_VALUES
    image_endpoints = vgg_(processed_images)

    init_fn = slim.assign_from_checkpoint_fn(
        os.path.join(checkpoints_dir, 'vgg_19.ckpt'),
        slim.get_model_variables('vgg_19'))

    with tf.Session() as sess:
        # Load the pretrained VGG-19 weights; no other variables exist in this graph.
        init_fn(sess)
        content_img_ends = sess.run(image_endpoints,
                                    feed_dict={image: scipy.misc.imread(CONTENT_IMG)})
        style_img_ends = sess.run(image_endpoints,
                                  feed_dict={image: scipy.misc.imread(STYLE_IMG)})

In [5]:
# VGG-19 endpoint whose activations define the content target.
CONTENT_LAYERS = 'vgg_19/conv4/conv4_2'

# (endpoint name, weight) pairs summed into the style loss; the weight grows
# with depth, from conv1_1 to conv5_1.
STYLE_LAYERS = [
    ('vgg_19/conv1/conv1_1', 1.),
    ('vgg_19/conv2/conv2_1', 1.5),
    ('vgg_19/conv3/conv3_1', 2.),
    ('vgg_19/conv4/conv4_1', 2.5),
    ('vgg_19/conv5/conv5_1', 3.),
]

In [6]:
def gram_matrix_val(layer):
    """Compute the Gram matrix of one NumPy feature map.

    Args:
        layer: array of shape (1, H, W, C), e.g. a VGG endpoint fetched with
            ``sess.run``.

    Returns:
        (C, C) array ``F.T @ F``, where ``F`` is ``layer`` flattened to (H*W, C).

    Raises:
        ValueError: if ``layer`` is not 4-D or its batch dimension is not 1.
            Previously both cases also raised ValueError, but from the tuple
            unpack or the reshape, with an unhelpful message.
    """
    if layer.ndim != 4 or layer.shape[0] != 1:
        raise ValueError('gram_matrix_val expects a (1, H, W, C) array, got shape %s'
                         % (layer.shape,))
    _, height, width, depth = layer.shape
    # One row per spatial position, one column per channel.
    features = layer.reshape(height * width, depth)
    return np.dot(features.T, features)

In [7]:
def gram_matrix_tensor(layer):
    """Build the Gram matrix of a feature-map tensor inside the TF graph.

    Graph-side counterpart of ``gram_matrix_val``.

    Args:
        layer: tensor with a fully defined static shape (1, H, W, C). The batch
            axis is dropped by the reshape below, so it must be 1.

    Returns:
        (C, C) tensor ``F^T F``, where ``F`` is ``layer`` reshaped to (H*W, C).
    """
    # Named `dims` so it does not shadow the module-level `shape` constant.
    dims = layer.get_shape().as_list()
    area = dims[1] * dims[2]
    depth = dims[3]

    features = tf.reshape(layer, (area, depth))
    # transpose_a lets matmul compute F^T F without materialising a transposed copy.
    return tf.matmul(features, features, transpose_a=True)

In [8]:
def build_style_loss(canvas , style, style_layers=None):
    """Build the weighted Gram-matrix style loss as a graph expression.

    For each (layer_key, weight) pair, compares the Gram matrix of the canvas
    activations (a graph tensor) with the fixed Gram matrix of the style
    activations (NumPy). The comparison is a sum of squared differences scaled
    by 1 / (4 N^2 M^2), where N is the layer's channel count and M = H * W.

    Args:
        canvas: mapping from endpoint name to the (1, H, W, C) activation tensor
            of the image being optimised.
        style: mapping from endpoint name to the style image's (1, H, W, C)
            NumPy activations.
        style_layers: iterable of (endpoint name, weight) pairs. Defaults to the
            module-level ``STYLE_LAYERS``, which preserves the original behaviour.

    Returns:
        The weighted sum of the per-layer losses, or the int 0 when
        ``style_layers`` is empty.
    """
    if style_layers is None:
        style_layers = STYLE_LAYERS

    style_loss = 0
    for layer_key, w in style_layers:
        layer_shape = style[layer_key].shape
        A = gram_matrix_val(style[layer_key])     # fixed target, NumPy
        G = gram_matrix_tensor(canvas[layer_key])  # depends on the canvas variable
        M = layer_shape[1] * layer_shape[2]
        N = layer_shape[3]
        layer_loss = (1. / (4 * N ** 2 * M ** 2)) * tf.reduce_sum(tf.square(G - A))
        style_loss = style_loss + w * layer_loss

    return style_loss

In [9]:
def build_content_loss(canvas , content):
    """Build the content loss between canvas activations and a fixed target.

    Args:
        canvas: tensor of the canvas activations at the content layer.
        content: NumPy activations of the content image at the same layer.
            Axes 1 and 2 are treated as spatial and axis 3 as channels.

    Returns:
        Scalar tensor equal to sum((canvas - content)^2) scaled by
        1 / (2 * sqrt(C) * sqrt(H * W)).
    """
    num_positions = content.shape[1] * content.shape[2]
    num_channels = content.shape[3]
    # NOTE(review): the extra sqrt(C) * sqrt(H * W) normalisation (beyond the
    # usual 1/2 factor) looks deliberate -- confirm it is intended.
    scale = 1. / (2 * num_channels ** 0.5 * num_positions ** 0.5)
    squared_error = tf.pow(canvas - content, 2)
    return scale * tf.reduce_sum(squared_error)

In [14]:
# Starting canvas: a 70/30 blend of the mean-subtracted content image and
# uniform noise in [-20, 20). The noise is seeded so that Restart & Run All
# always starts the optimisation from the same canvas.
NOISE_SEED = 0
noise_rng = np.random.RandomState(NOISE_SEED)

# Use the shared `shape` constant rather than repeating 227 here.
content_img = scipy.misc.imread(CONTENT_IMG).reshape((1,) + shape) - MEAN_VALUES
noise = noise_rng.uniform(-20, 20, content_img.shape)
init_canvas = 0.3 * noise + 0.7 * content_img

In [None]:
# Optimise the canvas pixels so that their VGG-19 activations match:
#   - the content target at CONTENT_LAYERS, and
#   - the style Gram matrices at STYLE_LAYERS.
STYLE_WEIGHT = 500   # multiplier on the style term in the total loss
LEARNING_RATE = 2    # Adam step size applied to the canvas values
NUM_STEPS = 1001     # steps 0..1000, so the final step is also saved
SAVE_EVERY = 100     # print the loss and write an image every SAVE_EVERY steps

with tf.Graph().as_default():
    canvas = tf.Variable(init_canvas, dtype=tf.float32)
    canvas_endpoints = vgg_(canvas)

    content_loss = build_content_loss(canvas_endpoints[CONTENT_LAYERS],
                                      content_img_ends[CONTENT_LAYERS])
    style_loss = build_style_loss(canvas_endpoints, style_img_ends)
    total_loss = content_loss + STYLE_WEIGHT * style_loss

    optimizer = tf.train.AdamOptimizer(LEARNING_RATE)
    # Only the canvas is trained; the VGG weights stay at their checkpoint values.
    train = optimizer.minimize(total_loss, var_list=[canvas])

    init_fn = slim.assign_from_checkpoint_fn(
        os.path.join(checkpoints_dir, 'vgg_19.ckpt'),
        slim.get_model_variables('vgg_19'))

    with tf.Session() as sess:
        # Initialise every variable (canvas, Adam slots, VGG), then overwrite the
        # VGG weights from the checkpoint. tf.initialize_all_variables is
        # deprecated in favour of tf.global_variables_initializer.
        sess.run(tf.global_variables_initializer())
        init_fn(sess)

        for i in range(NUM_STEPS):
            # Do not fetch `canvas` alongside `train`. That copies the image on
            # every step, and the read is not ordered relative to the update.
            _, total_loss_ = sess.run([train, total_loss])
            if i % SAVE_EVERY == 0:
                print(total_loss_)
                canvas_ = sess.run(canvas)  # value after this step's update
                output_img = canvas_ + MEAN_VALUES
                file_name = './results_%s.png' % i
                scipy.misc.imsave(file_name, np.clip(output_img[0], 0, 255).astype('uint8'))

Instructions for updating:
Use `tf.global_variables_initializer` instead.
2.73779e+13
1.89883e+11
1.09599e+11
8.19842e+10
6.74809e+10
