In [1]:
%matplotlib inline

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import os
import shutil
import time

  return f(*args, **kwds)


In [2]:
STYLE_DIR = "./styles/"
CONTENT_DIR = "./train2014/"
VGG_WEIGHT = "./vgg16_weights.npz"
LOG_DIR = "./logs"
MODEL_DIR = "./models"
CKPT_DIR = "./ckpts"

NUM_EPOCHS = 60
LEARNING_RATE = 1e-3
CONTENT_BATCH_SIZE = 8
STYLE_BATCH_SIZE = 1
LOG_ITER = 100
SAMPLE_ITER = 100
STYLE_SIZE = 256
CONTENT_SIZE = 256

CONTENT_LOSS_WEIGHT = 1
STYLE_LOSS_WEIGHT = 250

In [3]:
def vgg16(x, weights):
    # substract imagenet mean
    mean = tf.constant([123.68, 116.779, 103.939], dtype=tf.float32, shape=[1, 1, 1, 3], name='imagenet_mean')
    x = x - mean
    
    with tf.variable_scope("vgg16", reuse=tf.AUTO_REUSE):
        with tf.name_scope('conv1_1') as scope:
            kernel = tf.get_variable(initializer=tf.constant(weights["conv1_1_W"]), trainable=False, name='conv1_1_W')
            biases = tf.get_variable(initializer=tf.constant(weights["conv1_1_b"]), trainable=False, name='conv1_1_b')
            conv1_1 = tf.nn.conv2d(x, kernel, [1, 1, 1, 1], padding='SAME')
            conv1_1 = tf.nn.bias_add(conv1_1, biases)
            conv1_1 = tf.nn.relu(conv1_1, name=scope)

        with tf.name_scope('conv1_2') as scope:
            kernel = tf.get_variable(initializer=tf.constant(weights["conv1_2_W"]), trainable=False, name='conv1_2_W')
            biases = tf.get_variable(initializer=tf.constant(weights["conv1_2_b"]), trainable=False, name='conv1_2_b')
            conv1_2 = tf.nn.conv2d(conv1_1, kernel, [1, 1, 1, 1], padding='SAME')
            conv1_2 = tf.nn.bias_add(conv1_2, biases)
            conv1_2 = tf.nn.relu(conv1_2, name=scope)

        pool1 = tf.nn.avg_pool(conv1_2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME', 
                               name='pool1')

        with tf.name_scope('conv2_1') as scope:
            kernel = tf.get_variable(initializer=tf.constant(weights["conv2_1_W"]), trainable=False, name='conv2_1_W')
            biases = tf.get_variable(initializer=tf.constant(weights["conv2_1_b"]), trainable=False, name='conv2_1_b')
            conv2_1 = tf.nn.conv2d(pool1, kernel, [1, 1, 1, 1], padding='SAME')
            conv2_1 = tf.nn.bias_add(conv2_1, biases)
            conv2_1 = tf.nn.relu(conv2_1, name=scope)

        with tf.name_scope('conv2_2') as scope:
            kernel = tf.get_variable(initializer=tf.constant(weights["conv2_2_W"]), trainable=False, name='conv2_2_W')
            biases = tf.get_variable(initializer=tf.constant(weights["conv2_2_b"]), trainable=False, name='conv2_2_b')
            conv2_2 = tf.nn.conv2d(conv2_1, kernel, [1, 1, 1, 1], padding='SAME')
            conv2_2 = tf.nn.bias_add(conv2_2, biases)
            conv2_2 = tf.nn.relu(conv2_2, name=scope)

        pool2 = tf.nn.avg_pool(conv2_2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME',
                               name='pool2')

        with tf.name_scope('conv3_1') as scope:
            kernel = tf.get_variable(initializer=tf.constant(weights["conv3_1_W"]), trainable=False, name='conv3_1_W')
            biases = tf.get_variable(initializer=tf.constant(weights["conv3_1_b"]), trainable=False, name='conv3_1_b')
            conv3_1 = tf.nn.conv2d(pool2, kernel, [1, 1, 1, 1], padding='SAME')
            conv3_1 = tf.nn.bias_add(conv3_1, biases)
            conv3_1 = tf.nn.relu(conv3_1, name=scope)

        with tf.name_scope('conv3_2') as scope:
            kernel = tf.get_variable(initializer=tf.constant(weights["conv3_2_W"]), trainable=False, name='conv3_2_W')
            biases = tf.get_variable(initializer=tf.constant(weights["conv3_2_b"]), trainable=False, name='conv3_2_b')
            conv3_2 = tf.nn.conv2d(conv3_1, kernel, [1, 1, 1, 1], padding='SAME')
            conv3_2 = tf.nn.bias_add(conv3_2, biases)
            conv3_2 = tf.nn.relu(conv3_2, name=scope)

        with tf.name_scope('conv3_3') as scope:
            kernel = tf.get_variable(initializer=tf.constant(weights["conv3_3_W"]), trainable=False, name='conv3_3_W')
            biases = tf.get_variable(initializer=tf.constant(weights["conv3_3_b"]), trainable=False, name='conv3_3_b')
            conv3_3 = tf.nn.conv2d(conv3_2, kernel, [1, 1, 1, 1], padding='SAME')
            conv3_3 = tf.nn.bias_add(conv3_3, biases)
            conv3_3 = tf.nn.relu(conv3_3, name=scope)

        pool3 = tf.nn.avg_pool(conv3_3, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME',
                               name='pool3')

        with tf.name_scope('conv4_1') as scope:
            kernel = tf.get_variable(initializer=tf.constant(weights["conv4_1_W"]), trainable=False, name='conv4_1_W')
            biases = tf.get_variable(initializer=tf.constant(weights["conv4_1_b"]), trainable=False, name='conv4_1_b')
            conv4_1 = tf.nn.conv2d(pool3, kernel, [1, 1, 1, 1], padding='SAME')
            conv4_1 = tf.nn.bias_add(conv4_1, biases)
            conv4_1 = tf.nn.relu(conv4_1, name=scope)

        with tf.name_scope('conv4_2') as scope:
            kernel = tf.get_variable(initializer=tf.constant(weights["conv4_2_W"]), trainable=False, name='conv4_2_W')
            biases = tf.get_variable(initializer=tf.constant(weights["conv4_2_b"]), trainable=False, name='conv4_2_b')
            conv4_2 = tf.nn.conv2d(conv4_1, kernel, [1, 1, 1, 1], padding='SAME')
            conv4_2 = tf.nn.bias_add(conv4_2, biases)
            conv4_2 = tf.nn.relu(conv4_2, name=scope)

        with tf.name_scope('conv4_3') as scope:
            kernel = tf.get_variable(initializer=tf.constant(weights["conv4_3_W"]), trainable=False, name='conv4_3_W')
            biases = tf.get_variable(initializer=tf.constant(weights["conv4_3_b"]), trainable=False, name='conv4_3_b')
            conv4_3 = tf.nn.conv2d(conv4_2, kernel, [1, 1, 1, 1], padding='SAME')
            conv4_3 = tf.nn.bias_add(conv4_3, biases)
            conv4_3 = tf.nn.relu(conv4_3, name=scope)
            
    return conv1_2, conv2_2, conv3_3, conv4_3

In [4]:
def instance_norm(x, name, epsilon=1e-5):
    with tf.variable_scope(name):
        gamma = tf.get_variable(shape=x.shape[-1], name="gamma")
        beta = tf.get_variable(shape=x.shape[-1], name="beta")
        mean, var = tf.nn.moments(x, axes=[1,2], keep_dims=True)
        x = tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon, name="norm")
    return x

def conv(x, name, filters, kernel_size, strides, norm=instance_norm, act=tf.nn.relu):
    padding = kernel_size//2
    with tf.variable_scope(name):
        x = tf.pad(x, paddings=[[0,0],[padding,padding],[padding,padding],[0,0]], mode="REFLECT")
        x = tf.layers.conv2d(x, filters=filters, kernel_size=kernel_size, strides=strides, name="conv")
        if norm is not None: x = norm(x, name="norm")
        if act is not None: x = act(x, name="act")
    return x

def fixed_conv(x, name, conv_w, conv_b, strides, norm=instance_norm, act=tf.nn.relu):
    padding = conv_w.shape[1]//2
    with tf.variable_scope(name):
        x = tf.pad(x, paddings=[[0,0],[padding,padding],[padding,padding],[0,0]], mode="REFLECT")
        x = tf.nn.conv2d(x, conv_w, [1, strides, strides, 1], "VALID", name="conv")
        x = tf.nn.bias_add(x, conv_b)
        if norm is not None: x = norm(x, name="norm")
        if act is not None: x = act(x, name="act")
    return x

def residual_block(x, name, conv1_w, conv1_b, conv2_w, conv2_b):
    with tf.variable_scope(name):
        residual = x
        x = fixed_conv(x, "conv1", conv1_w, conv1_b, strides=1)
        x = fixed_conv(x, "conv2", conv2_w, conv2_b, strides=1, act=None)
    return x + residual

def upsample(x, name, conv_w, conv_b, strides):
    shape = x.shape.as_list()
    inferred_shape = tf.shape(x)
    w, h = shape[1] or inferred_shape[1], shape[2] or inferred_shape[2]
    x = tf.image.resize_images(x, size=[w * strides, h * strides])
    x = fixed_conv(x, name, conv_w, conv_b, strides=1)
    return x

def tnet(x, weights, biases):   
    with tf.variable_scope("tnet", reuse=tf.AUTO_REUSE):
        conv1 = conv(x, "conv1", filters=32, kernel_size=9, strides=1)
        conv2 = fixed_conv(conv1, "conv2", weights["conv2"], biases["conv2"], strides=2)
        conv3 = fixed_conv(conv2, "conv3", weights["conv3"], biases["conv3"], strides=2,)
        res1 = residual_block(conv3, "res1", weights["res1_1"], biases["res1_1"], weights["res1_2"], 
                              biases["res1_2"])
        res2 = residual_block(res1, "res2", weights["res2_1"], biases["res2_1"], weights["res2_1"], 
                              biases["res2_1"])
        res3 = residual_block(res2, "res3", weights["res3_1"], biases["res3_1"], weights["res3_1"], 
                              biases["res3_1"])
        res4 = residual_block(res3, "res4", weights["res4_1"], biases["res4_1"], weights["res4_1"], 
                              biases["res4_1"])
        res5 = residual_block(res4, "res5", weights["res5_1"], biases["res5_1"], weights["res5_1"], 
                              biases["res5_1"])
        up1 = upsample(res5, "up1", weights["up1"], biases["up1"], strides=2)
        up2 = upsample(up1, "up2", weights["up2"], biases["up2"], strides=2)
        conv4 = conv(up2, "conv4", filters=3, kernel_size=9, strides=1, norm=None, act=None)
    return tf.clip_by_value(conv4, 0., 255.)

def meta(vgg_out):
    conv1_2, conv2_2, conv3_3, conv4_3 = vgg_out
    conv1_2_mean, conv1_2_var = tf.nn.moments(conv1_2, axes=[1,2])
    conv2_2_mean, conv2_2_var = tf.nn.moments(conv2_2, axes=[1,2])
    conv3_3_mean, conv3_3_var = tf.nn.moments(conv3_3, axes=[1,2])
    conv4_3_mean, conv4_3_var = tf.nn.moments(conv4_3, axes=[1,2])
    concat = tf.concat([conv1_2_mean, conv1_2_var, conv2_2_mean, conv2_2_var, conv3_3_mean, conv3_3_var,
                        conv4_3_mean, conv4_3_var], axis=1)
    dense = tf.layers.dense(concat, units=1792)
    split = tf.split(dense, num_or_size_splits=14, axis=1)
    
    weights = {}
    biases = {}
    
    weights["conv2"] = tf.reshape(tf.layers.dense(split[0], units=3 * 3 * 32 * 64), shape=(3, 3, 32, 64))
    biases["conv2"] = tf.squeeze(tf.layers.dense(split[0], units=64))
    weights["conv3"] = tf.reshape(tf.layers.dense(split[1], units=3 * 3 * 64 * 128), shape=(3, 3, 64, 128))
    biases["conv3"] = tf.squeeze(tf.layers.dense(split[1], units=128))
    
    weights["res1_1"] = tf.reshape(tf.layers.dense(split[2], units=3 * 3 * 128 * 128), shape=(3, 3, 128, 128))
    biases["res1_1"] = tf.squeeze(tf.layers.dense(split[2], units=128))
    weights["res1_2"] = tf.reshape(tf.layers.dense(split[3], units=3 * 3 * 128 * 128), shape=(3, 3, 128, 128))
    biases["res1_2"] = tf.squeeze(tf.layers.dense(split[3], units=128))
    weights["res2_1"] = tf.reshape(tf.layers.dense(split[4], units=3 * 3 * 128 * 128), shape=(3, 3, 128, 128))
    biases["res2_1"] = tf.squeeze(tf.layers.dense(split[4], units=128))
    weights["res2_2"] = tf.reshape(tf.layers.dense(split[5], units=3 * 3 * 128 * 128), shape=(3, 3, 128, 128))
    biases["res2_2"] = tf.squeeze(tf.layers.dense(split[5], units=128))
    weights["res3_1"] = tf.reshape(tf.layers.dense(split[6], units=3 * 3 * 128 * 128), shape=(3, 3, 128, 128))
    biases["res3_1"] = tf.squeeze(tf.layers.dense(split[6], units=128))
    weights["res3_2"] = tf.reshape(tf.layers.dense(split[7], units=3 * 3 * 128 * 128), shape=(3, 3, 128, 128))
    biases["res3_2"] = tf.squeeze(tf.layers.dense(split[7], units=128))
    weights["res4_1"] = tf.reshape(tf.layers.dense(split[8], units=3 * 3 * 128 * 128), shape=(3, 3, 128, 128))
    biases["res4_1"] = tf.squeeze(tf.layers.dense(split[8], units=128))
    weights["res4_2"] = tf.reshape(tf.layers.dense(split[9], units=3 * 3 * 128 * 128), shape=(3, 3, 128, 128))
    biases["res4_2"] = tf.squeeze(tf.layers.dense(split[9], units=128))
    weights["res5_1"] = tf.reshape(tf.layers.dense(split[10], units=3 * 3 * 128 * 128), shape=(3, 3, 128, 128))
    biases["res5_1"] = tf.squeeze(tf.layers.dense(split[10], units=128))
    weights["res5_2"] = tf.reshape(tf.layers.dense(split[11], units=3 * 3 * 128 * 128), shape=(3, 3, 128, 128))
    biases["res5_2"] = tf.squeeze(tf.layers.dense(split[11], units=128))
    
    weights["up1"] = tf.reshape(tf.layers.dense(split[12], units=3 * 3 * 128 * 64), shape=(3, 3, 128, 64))
    biases["up1"] = tf.squeeze(tf.layers.dense(split[12], units=64))
    weights["up2"] = tf.reshape(tf.layers.dense(split[13], units=3 * 3 * 64 * 32), shape=(3, 3, 64, 32))
    biases["up2"] = tf.squeeze(tf.layers.dense(split[13], units=32))
    
    return weights, biases

In [5]:
def gram_matrix(x):
    batch_size, w, h, ch = x.shape.as_list()
    x = tf.reshape(x, [batch_size, w * h, ch])
    return tf.matmul(x, x, transpose_a=True) / (ch * w * h)

def loss_fun(target_style_features, target_content_features, transferred_features, transferred,
             style_loss_weight=STYLE_LOSS_WEIGHT, content_loss_weight=CONTENT_LOSS_WEIGHT):
    content_loss = tf.reduce_mean(tf.subtract(transferred_features[1], target_content_features[1]) ** 2, 
                                  [1, 2, 3])  

    style_loss = 0
    for i in range(len(transferred_features)):
        gram_target = gram_matrix(target_style_features[i])
        gram_transferred = gram_matrix(transferred_features[i])
        style_loss += tf.reduce_mean(tf.subtract(gram_target, gram_transferred) ** 2, [1, 2])  
        
    return tf.reduce_mean(content_loss_weight * content_loss 
                          + style_loss_weight * style_loss
                          + 1e-5 * tf.image.total_variation(transferred))

In [None]:
iterator = tf.keras.preprocessing.image.DirectoryIterator
datagen = tf.keras.preprocessing.image.ImageDataGenerator()
content_iter = iterator(directory=CONTENT_DIR, batch_size=CONTENT_BATCH_SIZE, 
                        target_size=(CONTENT_SIZE,CONTENT_SIZE), image_data_generator=datagen, shuffle=True)
style_iter = iterator(directory=STYLE_DIR, batch_size=STYLE_BATCH_SIZE,
                      target_size=(STYLE_SIZE,STYLE_SIZE), image_data_generator=datagen)

Found 82783 images belonging to 1 classes.
Found 21 images belonging to 1 classes.


In [None]:
if os.path.exists(CKPT_DIR):
    shutil.rmtree(CKPT_DIR)
    
vgg_weights = np.load(VGG_WEIGHT)

tf.reset_default_graph()
config = tf.ConfigProto()
config.gpu_options.allow_growth = True

total_iteration = NUM_EPOCHS * content_iter.n // CONTENT_BATCH_SIZE

training_graph = tf.Graph()

with training_graph.as_default() as g, tf.Session(config=config, graph=training_graph) as sess:    
    style = tf.placeholder(name="style", dtype=tf.float32, 
                           shape=[STYLE_BATCH_SIZE,STYLE_SIZE,STYLE_SIZE,3])
    content = tf.placeholder(name="content", dtype=tf.float32, 
                             shape=[CONTENT_BATCH_SIZE,CONTENT_SIZE,CONTENT_SIZE,3])

    target_style_features = vgg16(style, vgg_weights)
    target_content_features = vgg16(content, vgg_weights)
    
    vgg_out = vgg16(style, vgg_weights)
    weights, biases = meta(vgg_out)
    transferred = tnet(content, weights, biases)
    transferred_features = vgg16(transferred, vgg_weights)
    loss = loss_fun(target_style_features, target_content_features, transferred_features, transferred)
    train_op = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE).minimize(loss)

    summary = tf.summary.FileWriter(graph=g, logdir=LOG_DIR)
    style_summary = tf.summary.image("style", style)
    content_summary = tf.summary.image("content", content)
    transferred_summary = tf.summary.image("transferred", transferred)
    image_summary = tf.summary.merge([style_summary, content_summary, transferred_summary])
    loss_summary = tf.summary.scalar("loss", loss)

    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver(var_list=tf.trainable_variables())

    start = time.time()
    it = 0
    for i in range(NUM_EPOCHS):
        content_iter.reset()
        style_iter.reset()
        for c, _ in content_iter:
            it += 1
            
            if c.shape[0] < CONTENT_BATCH_SIZE:
                break
                 
            try:
                s, _ = style_iter.next()
            except StopIteration:
                style_iter.reset()
                s, _ = style_iter.next()
                    
            _, cur_loss, cur_loss_summary, cur_image_summary = sess.run([train_op, loss, loss_summary, 
                                                                         image_summary], 
                                                                        feed_dict={style: s, content: c})
            summary.add_summary(cur_loss_summary, it)

            if it % LOG_ITER == 0 or it == total_iteration:
                print("Iteration: [{it}/{num_iter}], loss: {loss}".format(it=it, num_iter=total_iteration,
                                                                          loss=cur_loss))
                
            if it % SAMPLE_ITER == 0 or it == total_iteration:
                summary.add_summary(cur_image_summary, it)
                
            summary.flush()
            
        ckpt_path = saver.save(sess, save_path=os.path.join(CKPT_DIR, "ckpt"), write_meta_graph=False, 
                               global_step=it)
        print("Checkpoint saved as: {ckpt_path}".format(ckpt_path=ckpt_path))
        
end = time.time()
print("Finished {num_iter} in {time} seconds".format(num_iter=total_iteration, time=end-start))

Iteration: [100/620872], loss: 203612288.0
Iteration: [200/620872], loss: 37005680.0
Iteration: [300/620872], loss: 88067048.0
Iteration: [400/620872], loss: 156271200.0
Iteration: [500/620872], loss: 106962256.0
Iteration: [600/620872], loss: 37562332.0
Iteration: [700/620872], loss: 104418896.0
Iteration: [800/620872], loss: 58532180.0
Iteration: [900/620872], loss: 30773192.0
Iteration: [1000/620872], loss: 55904152.0
Iteration: [1100/620872], loss: 7077652.0
Iteration: [1200/620872], loss: 12791204.0
Iteration: [1300/620872], loss: 34451856.0
Iteration: [1400/620872], loss: 41571176.0
Iteration: [1500/620872], loss: 53679216.0
Iteration: [1600/620872], loss: 9196070.0
Iteration: [1700/620872], loss: 43755056.0
Iteration: [1800/620872], loss: 9651416.0
Iteration: [1900/620872], loss: 17144158.0
Iteration: [2000/620872], loss: 22871446.0
Iteration: [2100/620872], loss: 22925834.0
Iteration: [2200/620872], loss: 17135156.0
Iteration: [2300/620872], loss: 14873421.0
Iteration: [2400/62