Saving TFVisionEncoderDecoderModel as SavedModel: The following keyword arguments are not supported by this model: ['attention_mask', 'token_type_ids']. #22731

Closed
DevinTDHa opened this issue Apr 12, 2023 · 9 comments · Fixed by #22743

@DevinTDHa

System Info

  • transformers version: 4.27.4
  • Platform: Linux-6.2.6-76060206-generic-x86_64-with-debian-bookworm-sid
  • Python version: 3.7.16
  • Huggingface_hub version: 0.13.4
  • PyTorch version (GPU?): 1.13.1 (False)
  • Tensorflow version (GPU?): 2.11.0 (False)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: False
  • Using distributed or parallel set-up in script?: False

Who can help?

@gante Could be related to #16400?

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Hello,

I am trying to save a TFVisionEncoderDecoderModel in the SavedModel format. Specifically, I am using the nlpconnect/vit-gpt2-image-captioning pretrained model. The model initializes fine from the PyTorch checkpoint, but when I try to save it as a SavedModel, it fails with the following error:

ValueError: The following keyword arguments are not supported by this model: ['attention_mask', 'token_type_ids'].

Link to Google Colab Reproduction:
https://colab.research.google.com/drive/1N2TVejxiBT5S7bRJ2LSmJ8IIR45folGA#scrollTo=aIL92KqPDDjf
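
For reference, a minimal sketch of what triggers the error (my own reconstruction; the notebook may differ in details):

from transformers import TFVisionEncoderDecoderModel

MODEL_NAME = "nlpconnect/vit-gpt2-image-captioning"
model = TFVisionEncoderDecoderModel.from_pretrained(MODEL_NAME, from_pt=True)

# Exporting the SavedModel goes through the default serving signature and fails with:
# ValueError: The following keyword arguments are not supported by this model:
# ['attention_mask', 'token_type_ids'].
model.save_pretrained(f"exports/{MODEL_NAME}", saved_model=True)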

Thanks for your time!

Expected behavior

The model should be saved as a SavedModel without problems, just as other pretrained models are.

@amyeroberts
Collaborator

cc @ydshieh

@ydshieh ydshieh self-assigned this Apr 13, 2023
@ydshieh
Collaborator

ydshieh commented Apr 13, 2023

Hi @DevinTDHa, just a quick update: instead of input_ids in the signature, we have to use decoder_input_ids, as the text inputs go to the decoder.

                "pixel_values": tf.TensorSpec((None, None, None, None), tf.float32, name="pixel_values"),
                "decoder_input_ids": tf.TensorSpec((None, None), tf.int32, name="decoder_input_ids"),

This change fixes the error you mentioned, but saving still fails due to other problems; I am still looking into how to fix them.

@ydshieh
Collaborator

ydshieh commented Apr 13, 2023

Two extra steps are needed to make saving work:

  • First, after model = TFVisionEncoderDecoderModel.from_pretrained(MODEL_NAME, from_pt=True) in your code, add
    model.config.torch_dtype = None
  • Then, in the file src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py, for the class TFVisionEncoderDecoderModel, change the method from
        def serving_output(self, output):
            pkv = tf.tuple(output.past_key_values)[1] if self.config.use_cache else None
            ...
    to
        def serving_output(self, output):
            pkv = tf.tuple(output.past_key_values)[1] if self.config.decoder.use_cache else None
            ...

You can make these changes in your own fork if you want to proceed quickly.
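
Putting the two steps together, a rough sketch of the user-code side (assuming the serving_output edit above is applied in a local fork; model name and export path as in the original report):

from transformers import TFVisionEncoderDecoderModel

MODEL_NAME = "nlpconnect/vit-gpt2-image-captioning"
EXPORT_PATH = f"exports/{MODEL_NAME}"

model = TFVisionEncoderDecoderModel.from_pretrained(MODEL_NAME, from_pt=True)
# Step 1: clear the torch dtype picked up from the PyTorch checkpoint.
model.config.torch_dtype = None
# With the serving_output edit (step 2) and the decoder_input_ids signature from
# the earlier comment in place, the export should go through.
model.save_pretrained(EXPORT_PATH, saved_model=True)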

I will discuss the fix in our codebase with the team.

@DevinTDHa
Author

Thanks a lot, especially for the suggested edits!

@ydshieh
Collaborator

ydshieh commented Apr 13, 2023

@DevinTDHa

In fact, what I did that works is to add the following block to the class TFVisionEncoderDecoderModel in the file src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py:

@tf.function(
    input_signature=[
        {
            "pixel_values": tf.TensorSpec((None, None, None, None), tf.float32, name="pixel_values"),
            "decoder_input_ids": tf.TensorSpec((None, None), tf.int32, name="decoder_input_ids"),
        }
    ]
)
def serving(self, inputs):
    """
    Method used for serving the model.

    Args:
        inputs (`Dict[str, tf.Tensor]`):
            The input of the saved model as a dictionary of tensors.
    """
    output = self.call(inputs)

    return self.serving_output(output)
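
With that method added (and the two fixes from the earlier comment in place), saving works and the exported signature can be exercised on the reloaded SavedModel, roughly like this (my own sketch with dummy inputs; paths illustrative):

import tensorflow as tf

model.save_pretrained("exports/vit-gpt2-image-captioning", saved_model=True)

# save_pretrained(..., saved_model=True) writes the SavedModel under saved_model/1.
loaded = tf.saved_model.load("exports/vit-gpt2-image-captioning/saved_model/1")
serving_fn = loaded.signatures["serving_default"]
outputs = serving_fn(
    pixel_values=tf.zeros((1, 3, 224, 224), tf.float32),  # dummy image batch (channels first)
    decoder_input_ids=tf.constant([[50256]], tf.int32),   # dummy GPT-2 BOS token id
)
print(outputs.keys())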

I am not sure why the approach in your notebook (i.e. specifying serving_fn explicitly) doesn't work.

@ydshieh
Collaborator

ydshieh commented Apr 13, 2023

The fixes have been merged into the main branch. The only thing left to do manually is to add the correct input_signature in the proper place, as shown in the comment above. I believe this cannot be done in the transformers codebase itself, but you can still do it in your own fork.

I will discuss with our TF experts why specifying signatures the way you did does not work, but I am going to close this issue. If you still have any related questions, don't hesitate to leave a comment 🤗

@ydshieh
Collaborator

ydshieh commented Apr 14, 2023

Hi @Rocketknight1 Since you are a TF saving expert 🔥, could you take a look at the code snippet below and see why it doesn't work when we specify signatures manually, please? (It works if I add a serving method to TFVisionEncoderDecoderModel directly.)

(You have to pull the main branch first to incorporate the 2 fixes.)

Thank you in advanceeeeeeee ~

import tensorflow as tf
from transformers import TFVisionEncoderDecoderModel

# load a fine-tuned image captioning model and corresponding tokenizer and image processor
MODEL_NAME = "nlpconnect/vit-gpt2-image-captioning"
model = TFVisionEncoderDecoderModel.from_pretrained(MODEL_NAME, from_pt=True)
EXPORT_PATH = f"exports/{MODEL_NAME}"

# ========================================================================================================================
# This works

# Add this block to `TFVisionEncoderDecoderModel` in `src/transformers/models/vision_encoder_decoder/modeling_tf_vision_encoder_decoder.py`
"""
    @tf.function(
        input_signature=[
            {
                "pixel_values": tf.TensorSpec((None, None, None, None), tf.float32, name="pixel_values"),
                "decoder_input_ids": tf.TensorSpec((None, None), tf.int32, name="decoder_input_ids"),
            }
        ]
    )
    def serving(self, inputs):
        output = self.call(inputs)
        return self.serving_output(output)
"""
# model.save_pretrained(
#     EXPORT_PATH,
#     saved_model=True,
#     # signatures={"serving_default": my_serving_fn},
# )
# ========================================================================================================================
# Not working (without changing `TFVisionEncoderDecoderModel`)

@tf.function(
    input_signature=[
        {
            "pixel_values": tf.TensorSpec((None, None, None, None), tf.float32, name="pixel_values"),
            "decoder_input_ids": tf.TensorSpec((None, None), tf.int32, name="decoder_input_ids"),
        }
    ]
)
def my_serving_fn(inputs):
    output = model.call(inputs)
    return model.serving_output(output)

# This fails
model.save_pretrained(
    EXPORT_PATH,
    saved_model=True,
    signatures={"serving_default": my_serving_fn},
)
# ========================================================================================================================

@DevinTDHa
Author

@ydshieh I have a question regarding this actually:

Currently I'm trying to access the decoder (GPT-2) from the saved model, but as far as I can tell it is not possible. The default serving signature you suggested outputs the encoder (ViT) outputs only (or am I wrong in this regard?)

However, trying to create a serving function for model.generate() seems to cause the same error as saving the model with a custom signature. Would this be possible in theory (combining encoder and decoder in one serving function)?

@ydshieh
Collaborator

ydshieh commented Apr 19, 2023

> @ydshieh I have a question regarding this actually:
>
> Currently I'm trying to access the decoder (GPT-2) from the saved model, but as far as I can tell it is not possible. The default serving signature you suggested outputs the encoder (ViT) outputs only (or am I wrong in this regard?)

I believe it gives the outputs of both the encoder and the decoder. But if you find that is not the case, please open a new issue and we will be more than happy to look into it 🤗.

> However, trying to create a serving function for model.generate() seems to cause the same error as saving the model with a custom signature.

I have never created a SavedModel with generate and I am not sure it will work in most cases - @gante do you know whether this is supposed to work (in most cases)? cc @Rocketknight1 too.

> Would this be possible in theory (combining encoder and decoder in one serving function)?

See my comment in the first paragraph 😃
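
For what it's worth, one way to check which outputs end up in the export (a sketch; the path follows the EXPORT_PATH from the snippet above) is to inspect the structured outputs of the serving signature:

import tensorflow as tf

loaded = tf.saved_model.load("exports/nlpconnect/vit-gpt2-image-captioning/saved_model/1")
serving_fn = loaded.signatures["serving_default"]
# The keys should include the decoder logits alongside the encoder outputs
# if both halves are exported.
print(serving_fn.structured_outputs)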
