In [1]:
import tensorflow as tf

import sys
sys.path.insert(1, '../libs')

from stnet import STNet
from utils import get_images, save_images, list_images


def stylize(contents_path, styles_path, output_dir, encoder_path, model_path, 
            resize_height=None, resize_width=None, suffix=None):
    """Stylize every content image with every style image using a trained STNet.

    Args:
        contents_path: a single content-image path (str) or an iterable of paths.
        styles_path: a single style-image path (str) or an iterable of paths.
        output_dir: output directory, forwarded to ``save_images``.
        encoder_path: encoder weights location, forwarded to ``STNet``.
        model_path: TensorFlow checkpoint restored into the session with
            ``tf.train.Saver``.
        resize_height: optional height forwarded to ``get_images`` for the
            content images only (style images are loaded without resizing).
        resize_width: optional width, handled like ``resize_height``.
        suffix: optional suffix, forwarded to ``save_images``.

    Returns:
        List of stylized images ordered content-major: for each content image
        (in input order), one result per style image (in input order). Each
        entry is ``result[0]``, the single item of the batch-of-one output.
    """
    # Normalise both inputs to lists. A bare string becomes a one-element
    # list; any other iterable is materialised because styles_path must be
    # walked once per content image and both are passed to save_images
    # afterwards -- a generator would be exhausted after the first pass.
    contents_path = [contents_path] if isinstance(contents_path, str) else list(contents_path)
    styles_path = [styles_path] if isinstance(styles_path, str) else list(styles_path)

    # Style images do not depend on the content image, so read each one once
    # instead of once per content image. Doing it before building/restoring
    # the model also surfaces a bad style path early.
    # NOTE(review): holds all style images in memory at once; assumed to be a
    # small set -- confirm for large style collections.
    style_imgs = [get_images(style_path) for style_path in styles_path]

    with tf.Graph().as_default(), tf.Session() as sess:
        # Build the dataflow graph. Both placeholders accept a batch of exactly
        # one 3-channel image of arbitrary height/width, so get_images must
        # return arrays of shape (1, H, W, 3) for the feed to succeed.
        content = tf.placeholder(
            tf.float32, shape=(1, None, None, 3), name='content')
        style   = tf.placeholder(
            tf.float32, shape=(1, None, None, 3), name='style')

        stn = STNet(encoder_path)

        output_image = stn.transform(content, style)

        # Initialise all variables, then load the trained values from the
        # checkpoint on top of them.
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        saver.restore(sess, model_path)

        outputs = []
        for content_path in contents_path:

            content_img = get_images(content_path, 
                height=resize_height, width=resize_width)

            for style_img in style_imgs:

                result = sess.run(output_image, 
                    feed_dict={content: content_img, style: style_img})

                outputs.append(result[0])

    save_images(outputs, contents_path, styles_path, output_dir, suffix=suffix)

    return outputs

In [2]:
# Per-model training hyper-parameters. The three lists below are zipped
# together by the stylize loop, so entry i of each list must describe the
# same checkpoint.
# NOTE(review): the checkpoint name says style weight 2e0 (= 2.0) while
# STYLE_WEIGHTS lists 3.0; the weight only feeds the `suffix` argument of the
# (commented-out) stylize call -- confirm which value is correct.
STYLE_WEIGHTS, CONTENT_WEIGHTS = [3.0], [1.0]
MODEL_SAVE_PATHS = [
    '../../models/style_weight_2e0.ckpt',
]

# Directories used at inference time (stylization): inputs and results.
INFERRING_CONTENT_DIR = '../../_inference/content'
INFERRING_STYLE_DIR   = '../../_inference/style'
OUTPUTS_DIR           = '../../_inference/output'

In [3]:
# Gather the content and style images found in the inference directories
# (presumably a list of file paths -- see utils.list_images). Content is
# listed first, then style, exactly as before.
content_imgs_path, style_imgs_path = (
    list_images(INFERRING_CONTENT_DIR),
    list_images(INFERRING_STYLE_DIR),
)

In [13]:
# zip() silently stops at the shortest input, so a missing entry in one of the
# parallel config lists would skip a model without any warning; check first.
assert len(STYLE_WEIGHTS) == len(CONTENT_WEIGHTS) == len(MODEL_SAVE_PATHS), \
    'STYLE_WEIGHTS, CONTENT_WEIGHTS and MODEL_SAVE_PATHS must have equal length'

# Run inference once per trained model configuration.
for style_weight, content_weight, model_save_path in zip(STYLE_WEIGHTS, CONTENT_WEIGHTS, MODEL_SAVE_PATHS):
    print('\n>>> Begin to stylize images')

    # TODO: stylization is currently disabled. Re-enabling the call below
    # requires ENCODER_WEIGHTS_PATH, which is not defined in any cell shown
    # here -- add it to the config cell first, otherwise this raises NameError
    # on a fresh kernel.
#     stylize(content_imgs_path, style_imgs_path, OUTPUTS_DIR, 
#             ENCODER_WEIGHTS_PATH, model_save_path, 
#             suffix='-' + str(style_weight) + '-' + str(content_weight))


>>> Begin to stylize images
