
model.export() fails to save output layer name in the SavedModel #19771

Closed

chriscarollo opened this issue May 28, 2024 · 9 comments

@chriscarollo

I'm training a model and saving it in the SavedModel format (necessary because I want to serve it with Triton Server, which does not support .keras models). The problem I'm running into is that my output Dense layer, on which I've specified a name, does not get exported with that name, which breaks my Triton server config that relies on it being the name I specified.

If I model.summary() before exporting, I see this:
│ output_1 (Dense) │ (None, 1) │ 257 │ 11_20[0][0] │

but when I model.export() it says:

Output Type:
TensorSpec(shape=(None, 1), dtype=tf.float32, name=None)

If I also save as a .keras file, it correctly saves that layer with my "output_1" name. But as I said, I can't use the .keras file, because Triton does not support it.

@chriscarollo (Author)

Trivially reproduced:

import keras

input = keras.layers.Input((1,), name='input_1')
output = keras.layers.Dense(8, name='output_1')(input)
m = keras.Model(inputs=[input], outputs=[output], name='model_1')
m.compile()
m.summary()
Model: "model_1"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓
┃ Layer (type)                         ┃ Output Shape                ┃         Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩
│ input_1 (InputLayer)                 │ (None, 1)                   │               0 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ output_1 (Dense)                     │ (None, 8)                   │              16 │
└──────────────────────────────────────┴─────────────────────────────┴─────────────────┘
 Total params: 16 (64.00 B)
 Trainable params: 16 (64.00 B)
 Non-trainable params: 0 (0.00 B)
>>> m.export('test')
INFO:tensorflow:Assets written to: test/assets
INFO:tensorflow:Assets written to: test/assets
Saved artifact at 'test'. The following endpoints are available:

* Endpoint 'serve'
  args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 1), dtype=tf.float32, name='input_1')
Output Type:
  TensorSpec(shape=(None, 8), dtype=tf.float32, name=None)
Captures:
  140410840896736: TensorSpec(shape=(), dtype=tf.resource, name=None)
  140410840896912: TensorSpec(shape=(), dtype=tf.resource, name=None)

mehtamansi29 self-assigned this on May 29, 2024
@mehtamansi29 (Collaborator)

Hi @chriscarollo -

Thanks for reporting the issue. I have tested the code snippet and reproduced the reported behaviour. I've attached a gist for reference.

We will look into the issue and keep you updated.

gowthamkpr added the keras-team-review-pending label (Pending review by a Keras team member) on May 29, 2024
@chriscarollo (Author)

FWIW I'm running TensorFlow 2.16.1 with Keras 3.3.3, but it does seem like this repros with TensorFlow 2.15 as well.

grasskin removed the keras-team-review-pending label (Pending review by a Keras team member) on May 30, 2024
@grasskin (Member)

Would you mind taking a look @hertschuh? CC: @nkovela1

@hertschuh (Contributor)

Hi @chriscarollo ,

Thanks for the report. I'm not really clear about what Triton needs from the saved model, but the outputs appear to have a name (more on that below).

but when I model.export() it says:

Output Type:
TensorSpec(shape=(None, 1), dtype=tf.float32, name=None)

Apparently, the only thing that is relevant in this message is the "output type". Unfortunately (and I don't know why) the name is missing.

If I also save as a .keras file, it correctly saves that layer with my "output_1" name.

Correct. Layers and outputs are very different concepts though. Do you need the layer? Or do you care about the output?

It appears that the output is named. It's just numbered starting from zero to support multiple outputs. So in your case, the output is named output_0 (independent of the layer name).

Here's how you can find out by adding this to your code above:

# Reload the exported SavedModel and inspect the named outputs
import tensorflow as tf

loaded = tf.saved_model.load('test')
print("Outputs", loaded.signatures['serve'].structured_outputs.items())

Which prints

Outputs dict_items([('output_0', TensorSpec(shape=(None, 8), dtype=tf.float32, name='output_0'))])

But overall, I'm surprised Triton needs more than the name of the function, which is serve by default.

@chriscarollo (Author)

It's true that the outputs show up in Triton as output_N (with N starting at 0). But previously I could give my outputs names by naming the Dense layer(s) listed in the outputs parameter of the Model constructor, and then use matching names in the output section of Triton's config.pbtxt file. For example:

output [ { name: "output_win_prob" data_type: TYPE_FP32 dims: [ 1 ] } ]

It was definitely useful to be able to name outputs using the name of their output layer -- it's nice to be confident about what output you're talking about by using a unique name rather than just "the third output" -- and it doesn't seem like that's the case anymore.
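For context, here is a minimal config.pbtxt sketch for Triton's TensorFlow SavedModel backend showing where that output name has to match; the model name, dims, and "output_win_prob" are illustrative, not from the original repro:

```
name: "model_1"
platform: "tensorflow_savedmodel"
max_batch_size: 8
input [
  {
    name: "input_1"
    data_type: TYPE_FP32
    dims: [ 1 ]
  }
]
output [
  {
    name: "output_win_prob"
    data_type: TYPE_FP32
    dims: [ 1 ]
  }
]
```

Triton refuses to load the model if an output name in config.pbtxt does not match a name in the SavedModel signature, which is why the output_N renaming breaks existing configs.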

@hertschuh (Contributor)

@chriscarollo ,

If your output is a dict, then the outputs are named by the key in the dict. This is the way you can control the names of the outputs even if you have a single output. So if you prefer, just add a layer to your model that wraps the output in a dict.

In all other cases (single output, list, tuple) the outputs are numbered and called output_<n>.

model.export() works like this with both Keras 2 and Keras 3. Now, it looks like tf.saved_model.save works differently.
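Putting the suggestion above together, here is a minimal sketch (assuming TensorFlow 2.16+ with Keras 3) of exporting with a dict output so the dict key becomes the output name in the SavedModel signature; the key "output_win_prob" is just an illustrative name:

```python
import keras
import tensorflow as tf

inputs = keras.layers.Input((1,), name="input_1")
dense = keras.layers.Dense(8)(inputs)

# Passing a dict as `outputs` makes the key the exported output name,
# instead of the positional output_0, output_1, ... numbering.
model = keras.Model(inputs=inputs, outputs={"output_win_prob": dense})
model.export("test_dict")

# Reload and verify that the signature uses the dict key as the name.
loaded = tf.saved_model.load("test_dict")
print(loaded.signatures["serve"].structured_outputs)
```

With this, the name in Triton's config.pbtxt output section can match the dict key rather than output_0.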

@chriscarollo (Author)

Oh! I didn't realize that outputs could take a dict, and the keys would become output names. Just tested it with Keras 3 and it works great -- Triton is happily loading my newly-trained model with all the output names it's expecting. Thanks!


7 participants