Issues in Keras model loading in Tensorflow Serving #2310

Closed
viksit opened this issue Apr 14, 2016 · 19 comments
Comments

@viksit

viksit commented Apr 14, 2016

Running on Keras 1.0.0 and TF Serving 0.7.

Here's how I'm approaching this problem:

  1. Create a model via Keras.
  2. Export it via the TensorFlow Serving exporter infrastructure.
  3. Load it into TF Serving's C++ wrappers and expose it on a port.

My problem is in (2).

Here's how I export the model.

  import tensorflow as tf
  from keras import backend as K
  from tensorflow_serving.session_bundle import exporter

  model = getKerasCompiledModel() # this is just a standard keras model
  sess = K.get_session()  # the TF session Keras built the model in
  saver = tf.train.Saver(sharded=True)
  model_exporter = exporter.Exporter(saver)
  signature = exporter.classification_signature(input_tensor=model.input,
                                                scores_tensor=model.output)
  model_exporter.init(sess.graph.as_graph_def(),
                      default_graph_signature=signature)
  export_path = "./"
  # FLAGS.export_version is defined elsewhere in the export script
  model_exporter.export(export_path, tf.constant(FLAGS.export_version), sess)

So far so good. We've managed to export the Keras model's session via the session exporter.

However, the problem is the way TF Serving's exporter module expects a classification signature:

signature = exporter.classification_signature(input_tensor=model.input,
                                                scores_tensor=model.output)

In a non-Keras model this works fine. But in a Keras model using the TF backend, unless your input is of the form

[<tf.Tensor 'lstm_input_2:0' shape=(?, 128, 300) dtype=float32>,
 <tf.Tensor 'keras_learning_phase:0' shape=<unknown> dtype=uint8>]

with the second item being the K.learning_phase() placeholder, you can't invoke the prediction function.

You get this error:

You must feed a value for placeholder tensor 'keras_learning_phase' with dtype uint8
     [[Node: keras_learning_phase = Placeholder[dtype=DT_UINT8, shape=[], _device="/job:localhost/replica:0/task:0/cpu:0"]()]]")

Since the classification API (https://github.com/tensorflow/serving/blob/master/tensorflow_serving/session_bundle/exporter.py#L66) only supports one input tensor and not a list, I can't see a way to export this model so that the Serving infrastructure can read it.

Is there a workaround?

@fchollet
Member

I will think about it. What you can do to fix it right now is to make the learning phase constant. Here's how:

K._LEARNING_PHASE = tf.constant(0) # test mode

(define your model here)

(export your model here)

In the future we might expose a built-in interface for exporting and serving models, so you don't have to do this by hand.

Let me know how it goes.
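
To see why a constant learning phase removes the need for a second input, here is a toy sketch in plain Python. It is not TensorFlow: `Placeholder`, `Constant`, `Op`, `run` and `build_model` are hypothetical names used only to illustrate the idea that a placeholder must be fed on every run, while a constant baked in at build time needs no feed.

```python
class Placeholder:
    """A value that must be supplied in `feed` on every run."""

    def __init__(self, name):
        self.name = name


class Constant:
    """A value fixed when the graph is built."""

    def __init__(self, value):
        self.value = value


class Op:
    """A computation over other nodes."""

    def __init__(self, fn, *inputs):
        self.fn = fn
        self.inputs = inputs


def run(node, feed):
    """Evaluate `node`, raising if a placeholder it depends on is not fed."""
    if isinstance(node, Placeholder):
        if node.name not in feed:
            raise ValueError(f"You must feed a value for placeholder '{node.name}'")
        return feed[node.name]
    if isinstance(node, Constant):
        return node.value
    return node.fn(*(run(i, feed) for i in node.inputs))


def build_model(learning_phase):
    """Wire an arbitrary phase-dependent op (a stand-in for layers like dropout)."""
    x = Placeholder("input")
    out = Op(lambda v, phase: v * 0.5 if phase == 1 else v, x, learning_phase)
    return x, out


# Learning phase as a placeholder: feeding only the input is not enough.
_, out = build_model(Placeholder("keras_learning_phase"))
try:
    run(out, {"input": 4.0})
except ValueError as err:
    print(err)  # You must feed a value for placeholder 'keras_learning_phase'

# Learning phase set to a constant 0 (test mode) before building: one input suffices.
_, out = build_model(Constant(0))
print(run(out, {"input": 4.0}))  # 4.0
```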

@viksit
Author

viksit commented Apr 14, 2016

@fchollet ah, I've already tried that, with no luck.

Steps

  • Set K._LEARNING_PHASE = tf.constant(0) when building the model, right after importing the backend as K.
  • Export via the methods described above

Does an exported model contain the learning phase value within it? I'm still seeing the same error.

@fchollet
Member

The method I described does work. Just make sure that all ops in your model are defined after you set the learning phase to a constant.

@viksit
Author

viksit commented Apr 14, 2016

Alright, let me retry with a fresh environment.

One thing to note: I train the model and save its architecture and weights to disk as JSON and .h5 files, then load and compile them again in the export script, which lives in a different location. I'm hoping that won't cause any issues.

@fchollet
Member

As long as you are setting the learning phase as constant before you build the fresh model (not just compile, everything), it will be fine. The learning phase does not affect weight loading.
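
The ordering point can be sketched in plain Python, assuming (for illustration only) that each layer reads the current learning-phase value once, when it is constructed. `_LEARNING_PHASE`, `PhaseDependentLayer` and `build_layers` here are hypothetical stand-ins, not Keras internals. Layers built before the value changes keep what they captured; only a fresh build picks up the constant.

```python
# Hypothetical module-level setting, read by layers at construction time.
_LEARNING_PHASE = "placeholder"


class PhaseDependentLayer:
    def __init__(self):
        # Captured once at build time, the way a graph op is wired to a tensor.
        self.phase = _LEARNING_PHASE


def build_layers(n_layers=2):
    return [PhaseDependentLayer() for _ in range(n_layers)]


layers_before = build_layers()  # built while the phase is still a placeholder
_LEARNING_PHASE = "constant 0"  # changing it now does not rewire those layers
layers_after = build_layers()   # a fresh build picks up the constant

print([layer.phase for layer in layers_before])  # ['placeholder', 'placeholder']
print([layer.phase for layer in layers_after])   # ['constant 0', 'constant 0']
```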

@viksit
Author

viksit commented Apr 14, 2016

Gotcha. Is there something I can check within a model to see that the learning phase constant is being set correctly?
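
One possible check (an assumption, not something confirmed in this thread): walk the graph backwards from the model's output and list every Placeholder op it depends on. The helper below is duck-typed against the TF 1.x graph API (`Tensor.op`, `Operation.inputs`, `Operation.type`, `Operation.name`) and is exercised with small fakes so it runs without TensorFlow. `placeholders_feeding`, `FakeOp` and `FakeTensor` are hypothetical names.

```python
from collections import namedtuple


def placeholders_feeding(output_tensor):
    """Sorted names of Placeholder ops that `output_tensor` transitively depends on.

    Assumes a tensor has `.op`, and an operation has `.name`, `.type` and
    `.inputs` (a list of tensors). Control dependencies are ignored here.
    """
    found, seen = set(), set()
    stack = [output_tensor.op]
    while stack:
        op = stack.pop()
        if op.name in seen:
            continue
        seen.add(op.name)
        if op.type == "Placeholder":
            found.add(op.name)
        stack.extend(t.op for t in op.inputs)
    return sorted(found)


# Minimal fakes so the sketch runs without TensorFlow installed.
FakeOp = namedtuple("FakeOp", ["name", "type", "inputs"])
FakeTensor = namedtuple("FakeTensor", ["op"])

x = FakeTensor(FakeOp("dense_input_1", "Placeholder", []))
phase = FakeTensor(FakeOp("keras_learning_phase", "Placeholder", []))
matmul = FakeTensor(FakeOp("dense_1/MatMul", "MatMul", [x]))
switch = FakeTensor(FakeOp("dropout_1/Switch", "Switch", [matmul, phase]))

print(placeholders_feeding(switch))  # ['dense_input_1', 'keras_learning_phase']
```

With TensorFlow, passing `model.output` should give the real list; if `keras_learning_phase` shows up, the exported graph will presumably still ask for that feed. Since control dependencies are not followed, treat the result as a hint rather than proof.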

@viksit viksit changed the title Issues in Keras model loading in TF Serving Issues in Keras model loading in Tensorflow Serving Apr 14, 2016
@viksit
Author

viksit commented Apr 14, 2016

No luck. I redid the entire model creation in one go with K._LEARNING_PHASE set, then exported via the Exporter.

Then loaded it via the C++ interface - and the same error message.

NetworkError(code=StatusCode.INVALID_ARGUMENT, details="You must feed a value for placeholder tensor 'keras_learning_phase' with dtype uint8
     [[Node: keras_learning_phase = Placeholder[dtype=DT_UINT8, shape=[], _device="/job:localhost/replica:0/task:0/cpu:0"]()]]")

Somewhere, this constant is not being picked up. I'm not sure where, in my case.

@viksit
Author

viksit commented Apr 14, 2016

This issue should be reproducible locally too: if I create a Keras model, build a K.function() from its input and output, and execute it directly, we bypass the internal _make_predict_function() (which adds K.learning_phase() to the inputs), so we should see the same error.

Here's the code to reproduce.

What am I missing?

from __future__ import print_function
import sys
import tensorflow as tf
import numpy as np
np.random.seed(8008)  # for reproducibility

from keras.datasets import mnist
from keras.models import Sequential
from keras.layers.core import Dense, Dropout, Activation
from keras.optimizers import SGD, Adam, RMSprop
from keras.utils import np_utils
from keras import backend as K

K._LEARNING_PHASE = tf.constant(0) # try with 1

batch_size = 128
nb_classes = 10
nb_epoch = 1

# the data, shuffled and split between train and test sets
(X_train, y_train), (X_test, y_test) = mnist.load_data()

X_train = X_train.reshape(60000, 784)
X_test = X_test.reshape(10000, 784)
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train /= 255
X_test /= 255
print(X_train.shape[0], 'train samples')
print(X_test.shape[0], 'test samples')
# convert class vectors to binary class matrices
Y_train = np_utils.to_categorical(y_train, nb_classes)
Y_test = np_utils.to_categorical(y_test, nb_classes)

model = Sequential()
model.add(Dense(512, input_shape=(784,)))
model.add(Activation('relu'))
model.add(Dropout(0.2))
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dropout(0.2))
model.add(Dense(10))
model.add(Activation('softmax'))

model.summary()

model.compile(loss='categorical_crossentropy',
              optimizer=RMSprop(),
              metrics=['accuracy'])

history = model.fit(X_train, Y_train,
                    batch_size=batch_size, nb_epoch=nb_epoch,
                    verbose=1, validation_data=(X_test, Y_test))


# since we've set the constant for learning_phase, this should work?
pred = K.function([model.input], [model.output])
pred([X_test])

"""
W tensorflow/core/common_runtime/executor.cc:1102] 0x1102d1e50 Compute status: Invalid argument: You must feed a value for placeholder tensor 'keras_learning_phase' with dtype uint8
         [[Node: keras_learning_phase = Placeholder[dtype=DT_UINT8, shape=[], _device="/job:localhost/replica:0/task:0/cpu:0"]()]]
"""

# setting it explicitly works fine
# (1 = training mode, so dropout is active; pass 0 for test-mode predictions)
pred = K.function([model.input, K.learning_phase()], [model.output])
pred([X_test, 1])
"""
Out[63]:
[array([[  1.00347834e-05,   3.74655019e-06,   1.54657057e-03, ...,
           9.96998191e-01,   2.86282466e-06,   2.00398354e-05],
        [  2.67221221e-05,   9.46493747e-05,   9.94248748e-01, ...,
           4.02291249e-07,   2.70981618e-05,   1.71050818e-08],
        [  3.57027639e-05,   9.91683960e-01,   2.65269703e-03, ...,
           3.92835867e-03,   2.83869915e-04,   5.37958549e-05],
        ...,
        [  4.45166961e-08,   4.54224505e-07,   8.53104893e-07, ...,
           1.04896884e-04,   3.14820500e-04,   1.58036346e-04],
        [  2.00932700e-05,   3.58963516e-05,   3.47329410e-06, ...,
           6.35720880e-07,   1.79862080e-03,   4.55387266e-07],
        [  3.58260934e-07,   2.49237320e-09,   2.99912131e-06, ...,
           9.83266024e-10,   1.59663855e-08,   8.79595341e-10]], dtype=float32)]
"""

@tboquet
Contributor

tboquet commented Apr 14, 2016

Weird stuff:

...
K._LEARNING_PHASE = "bob" # try with 1
...

Still working:

60000 train samples
10000 test samples
____________________________________________________________________________________________________
Layer (type)                       Output Shape        Param #     Connected to                     
====================================================================================================
dense_4 (Dense)                    (None, 512)         401920      dense_input_2[0][0]              
____________________________________________________________________________________________________
activation_4 (Activation)          (None, 512)         0           dense_4[0][0]                    
____________________________________________________________________________________________________
dropout_3 (Dropout)                (None, 512)         0           activation_4[0][0]               
____________________________________________________________________________________________________
dense_5 (Dense)                    (None, 512)         262656      dropout_3[0][0]                  
____________________________________________________________________________________________________
activation_5 (Activation)          (None, 512)         0           dense_5[0][0]                    
____________________________________________________________________________________________________
dropout_4 (Dropout)                (None, 512)         0           activation_5[0][0]               
____________________________________________________________________________________________________
dense_6 (Dense)                    (None, 10)          5130        dropout_4[0][0]                  
____________________________________________________________________________________________________
activation_6 (Activation)          (None, 10)          0           dense_6[0][0]                    
====================================================================================================
Total params: 669706
____________________________________________________________________________________________________
Train on 60000 samples, validate on 10000 samples
Epoch 1/1
59904/60000 [============================>.] - ETA: 0s - loss: 0.2772 - acc: 0.9160

With the same error:

...
InvalidArgumentError: You must feed a value for placeholder tensor 'keras_learning_phase' with dtype uint8
     [[Node: keras_learning_phase = Placeholder[dtype=DT_UINT8, shape=[], _device="/job:localhost/replica:0/task:0/cpu:0"]
...

When I try to compile:

pred = K.function([model.input], [model.output])
pred([X_test])

@tboquet
Contributor

tboquet commented Apr 14, 2016

It works if I modify it directly in tensorflow_backend:

pred = K.function([model.input], [model.output])
pred([X_test])

[array([[  2.50758745e-07,   3.05174382e-07,   4.29301508e-05, ...,
           9.99435365e-01,   1.76608160e-06,   1.35888593e-04],
        [  2.79599544e-06,   4.31858498e-05,   9.99716461e-01, ...,
           1.24848683e-07,   2.41637827e-05,   3.72424114e-09],
        [  2.41596881e-05,   9.93469715e-01,   1.15284347e-03, ...,
           3.07259732e-03,   3.31207732e-04,   1.67881037e-04],
        ..., 
        [  3.63393269e-07,   2.36019332e-07,   3.89473325e-06, ...,
           2.09421967e-04,   1.90011575e-04,   4.02843840e-02],
        [  3.75664458e-05,   1.40146096e-06,   6.29712886e-08, ...,
           1.97701837e-07,   1.85844809e-04,   4.84997031e-07],
        [  1.27407093e-05,   1.05844613e-07,   2.26282689e-04, ...,
           7.78336506e-09,   1.69879925e-07,   6.13535818e-08]], dtype=float32)]

@tboquet
Contributor

tboquet commented Apr 14, 2016

Import K from keras.layers.core and it should work!

...
from keras.layers.core import K
...

@fchollet
Member

These are Python import shenanigans; I get it now. This is easily solved by adding a K.set_learning_phase() method, which will work regardless of where you import K from.

@fchollet
Member

I added K.set_learning_phase(), but only for the TF backend, since it creates additional issues with Theano. In general it should only be used in such cases, since its effect will be to globally and permanently set the learning phase to a constant.
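
The thread doesn't pin down exactly which import path caused this, but the general Python behaviour can be sketched with two hypothetical in-memory modules (`backend_impl` and `facade` are illustrative names, not Keras's real module layout). A function looks up globals in the module where it was defined, so assigning an attribute on a different module that merely re-exports the function has no effect on it, while a setter that uses `global` rebinds the value the getter actually reads.

```python
import textwrap
import types

IMPL_SOURCE = textwrap.dedent(
    """
    _LEARNING_PHASE = "placeholder"

    def learning_phase():
        # Looks up _LEARNING_PHASE in *this* module's globals.
        return _LEARNING_PHASE

    def set_learning_phase(value):
        # Rebinds the global that learning_phase() reads.
        global _LEARNING_PHASE
        _LEARNING_PHASE = value
    """
)

# Hypothetical implementation module whose functions own the global.
backend_impl = types.ModuleType("backend_impl")
exec(IMPL_SOURCE, backend_impl.__dict__)

# Hypothetical facade that only re-exports the public functions.
facade = types.ModuleType("facade")
for name in ("learning_phase", "set_learning_phase"):
    setattr(facade, name, getattr(backend_impl, name))

# Assigning on the facade just adds a new, unrelated attribute there.
facade._LEARNING_PHASE = "constant 0"
print(facade.learning_phase())  # placeholder

# The setter changes the value where learning_phase() looks it up.
facade.set_learning_phase("constant 0")
print(facade.learning_phase())  # constant 0
```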

@viksit
Author

viksit commented Apr 14, 2016

Thanks @fchollet - will update and test out. Hopefully I can make this a repeatable process as well.

This Python import problem was such a headache :)

@SargamModak

Running on Keras 2.0.5, TensorFlow 1.2.1 and tensorflow-serving-api 1.0.0.
I called K.set_learning_phase(False), then loaded the model, defined the signature and exported the model. That solved the keras_learning_phase issue when loading the model.

@2M-kotb

2M-kotb commented Feb 5, 2018

It is not working for me at all.

@razmik

razmik commented Mar 6, 2018

This worked for me.

@haydenth

This worked wonders for my export script for h5 -> pb. THANK YOU!

@moyuli-vps

Reinstalling tensorflow==1.11.0, keras==2.1.2 and tensorflow-serving-api==1.0.0 worked for me.
