In [None]:
import json
import tensorflow as tf
import functools
from tensorflow.python import saved_model
from magenta.models.image_stylization import model, ops

def build_prediction_graph(style_index=1, num_styles=32, height=384, width=512):
    """Build a TF1 inference graph that stylizes a batch of JPEG images.

    The graph takes raw JPEG bytes, decodes them to RGB, runs the Magenta
    multi-style ``model.transform`` network conditioned on a single style
    label, and re-encodes the stylized images as base64 JPEG strings.

    Args:
        style_index: Style label fed to the conditional normalizer
            (``normalizer_params['labels']``). Defaults to 1, the value this
            function previously hard-coded.
        num_styles: ``num_categories`` for the normalizer; must match the
            checkpoint that will be restored into this graph. Defaults to 32.
        height: Required height, in pixels, of every input image. Defaults to 384.
        width: Required width, in pixels, of every input image. Defaults to 512.

    Returns:
        A ``(graph, inputs_info, outputs_info)`` tuple: the ``tf.Graph`` that
        holds the ops, plus dicts mapping the signature aliases
        ``'image_byte'`` and ``'output_image'`` to ``TensorInfo`` protos
        suitable for ``build_signature_def``.
    """
    graph = tf.Graph()
    with graph.as_default():
        # One serialized JPEG file per batch element.
        image_byte = tf.placeholder(tf.string, shape=[None])

        # map_fn stacks the decoded images, so every image in a single
        # request must have the same dimensions.
        decoded_images = tf.map_fn(
            functools.partial(tf.image.decode_jpeg, channels=3),
            image_byte,
            dtype=tf.uint8)

        # Reject wrongly sized inputs at run time. Without this check the
        # reshape below would silently scramble e.g. a width x height
        # (portrait) image, because it has the same number of elements.
        size_check = tf.assert_equal(
            tf.shape(decoded_images)[1:3], [height, width],
            message='Input images must be %dx%d (height x width)' % (height, width))
        with tf.control_dependencies([size_check]):
            # The reshape only pins the static shape seen by the network.
            image_floats = tf.reshape(
                tf.cast(decoded_images, tf.float32) / 255.0,
                shape=[-1, height, width, 3])

        stylized_images = model.transform(
            image_floats,
            normalizer_params={
                'labels': tf.constant([style_index]),
                'num_categories': num_styles,
                'center': True,
                'scale': True})

        # Round rather than truncate (truncation biases every pixel down),
        # and saturate so out-of-range values clamp instead of wrapping.
        output_images = tf.saturate_cast(
            tf.round(stylized_images * 255.0), tf.uint8)

        encoded_jpegs = tf.map_fn(
            tf.image.encode_jpeg, output_images, dtype=tf.string)
        # NOTE: tf.encode_base64 emits web-safe base64 ('-' and '_') with no
        # padding by default; clients must decode with that alphabet.
        output = tf.encode_base64(encoded_jpegs)

        inputs_info = {
            'image_byte': saved_model.utils.build_tensor_info(image_byte)
        }

        outputs_info = {
            'output_image': saved_model.utils.build_tensor_info(output)
        }

    return graph, inputs_info, outputs_info

In [None]:
# Export the stylization graph, with weights restored from the pretrained
# checkpoint, as a SavedModel exposing one default serving signature.
export_dir = 'savemodel/demo'
checkpoint = 'checkpoints/multistyle-pastiche-generator-varied.ckpt'

graph, inputs_info, outputs_info = build_prediction_graph()

signature_def = saved_model.signature_def_utils.build_signature_def(
    inputs=inputs_info,
    outputs=outputs_info,
    method_name=saved_model.signature_constants.PREDICT_METHOD_NAME)

# SavedModelBuilder refuses to write into an existing export directory, so
# remove the previous export first; otherwise re-running this cell fails.
# This directory is this cell's own output, so deleting it is safe.
if tf.gfile.Exists(export_dir):
    tf.gfile.DeleteRecursively(export_dir)
exporter = saved_model.builder.SavedModelBuilder(export_dir)

# Enter `graph` explicitly so the Saver and the initializer ops are built
# in the prediction graph rather than in the global default graph.
with graph.as_default(), tf.Session(graph=graph) as session:
    saver = tf.train.Saver()
    session.run([tf.local_variables_initializer(), tf.tables_initializer()])
    saver.restore(session, checkpoint)
    exporter.add_meta_graph_and_variables(
        session,
        tags=[saved_model.tag_constants.SERVING],
        signature_def_map={
            saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
                signature_def
        })

exporter.save()