ValueError: You have to specify either input_ids or inputs_embeds! #3626

Closed
innat opened this issue Apr 4, 2020 · 22 comments · Fixed by #3636

@innat

innat commented Apr 4, 2020


I'm quite new to NLP tasks. I was trying to train the T5-large model and set things up as follows, but unfortunately I got an error.

import tensorflow as tf
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model

def build_model(transformer, max_len=512):
    input_word_ids = Input(shape=(max_len,), dtype=tf.int32, name="input_word_ids")
    sequence_output = transformer(input_word_ids)[0]
    cls_token = sequence_output[:, 0, :]
    out = Dense(1, activation='sigmoid')(cls_token)
    model = Model(inputs=input_word_ids, outputs=out)
    return model

model = build_model(transformer_layer, max_len=MAX_LEN)

It throws:

ValueError: in converted code:
ValueError                                Traceback (most recent call last)
<ipython-input-19-8ad6e68cd3f5> in <module>
----> 5     model = build_model(transformer_layer, max_len=MAX_LEN)
      6 
      7 model.summary()

<ipython-input-17-e001ed832ed6> in build_model(transformer, max_len)
     31     """
     32     input_word_ids = Input(shape=(max_len,), dtype=tf.int32, name="input_word_ids")
---> 33     sequence_output = transformer(input_word_ids)[0]
     34     cls_token = sequence_output[:, 0, :]
     35     out = Dense(1, activation='sigmoid')(cls_token)
ValueError: You have to specify either input_ids or inputs_embeds
@patrickvonplaten patrickvonplaten self-assigned this Apr 5, 2020
@patrickvonplaten
Contributor

Hi @innat,

T5 is an encoder-decoder model so you will have to provide both input_ids and decoder_input_ids to the model. Maybe taking a look at the T5 docs (especially the "Examples") can help you :-)
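For illustration, a minimal sketch of calling the model with both tensors (assuming a recent transformers version; the docstring example mentioned below is the authoritative reference):

```python
import tensorflow as tf
from transformers import T5Tokenizer, TFT5Model

tokenizer = T5Tokenizer.from_pretrained('t5-small')
model = TFT5Model.from_pretrained('t5-small')

input_ids = tokenizer("Hello, world!", return_tensors='tf').input_ids
# An encoder-decoder model needs both the encoder ids and the decoder ids.
outputs = model(input_ids, decoder_input_ids=input_ids)
sequence_output = outputs[0]  # last hidden states of the decoder
```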

@patrickvonplaten patrickvonplaten linked a pull request Apr 5, 2020 that will close this issue
@patrickvonplaten
Contributor

Just noticed that the Examples docstring for TF T5 was wrong. It is fixed with #3636.

@innat
Author

innat commented Apr 5, 2020

@patrickvonplaten
Hello, sorry to bother you. Would you please check the following piece of code?

Imports

import numpy as np
from transformers import TFAutoModel, AutoTokenizer

# First load the real tokenizer
tokenizer = AutoTokenizer.from_pretrained('t5-small')
transformer_layer = TFAutoModel.from_pretrained('t5-small')

Define Encoder

def encode(texts, tokenizer, maxlen=512):
    enc_di = tokenizer.batch_encode_plus(
        texts, 
        return_attention_masks=False, 
        return_token_type_ids=False,
        pad_to_max_length=True,
        max_length=maxlen
    )
    return np.array(enc_di['input_ids'])

# tokenized (batch_encode_plus expects a list of texts)
x_train = encode(['text'], tokenizer, maxlen=200)
y_train  # labels, defined elsewhere

Define Model and Call

from tensorflow.keras.optimizers import Adam

def build_model(transformer, max_len=512):
    input_word_ids = Input(shape=(max_len,), dtype=tf.int32, name="input_word_ids")
    sequence_output = transformer(input_word_ids)[0]
    cls_token = sequence_output[:, 0, :]
    out = Dense(1, activation='sigmoid')(cls_token)
    
    model = Model(inputs=input_word_ids, outputs=out)
    model.compile(Adam(lr=1e-5), loss='binary_crossentropy', metrics=['accuracy'])

    return model

# calling
model = build_model(transformer_layer, max_len=200)

Now, according to the docstring, should I do,

outputs = model(input_ids=x_train, decoder_input_ids=x_train)[0]

?

@patrickvonplaten
Contributor

I'm not 100% sure what you want to do here exactly. T5 is always trained in a text-to-text format. We have a section here on how to train T5: https://huggingface.co/transformers/model_doc/t5.html#training

Otherwise I'd recommend taking a look at the official paper.
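For reference, a minimal PyTorch sketch along the lines of that training section (assuming a recent transformers version; at the time of this thread the loss argument was named lm_labels rather than labels):

```python
from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained('t5-small')
model = T5ForConditionalGeneration.from_pretrained('t5-small')

# T5 is text-to-text: both the input and the target are plain strings.
input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors='pt').input_ids
labels = tokenizer("Das Haus ist wunderbar.", return_tensors='pt').input_ids

# Passing labels lets the model build decoder_input_ids internally
# (by shifting the labels right) and return the LM loss.
loss = model(input_ids=input_ids, labels=labels).loss
```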

@enzoampil
Contributor

@patrickvonplaten Thanks for this. I encountered the same issue and this resolved it!

I'm wondering if it makes sense for the error message to capture the requirement of having both input_ids and decoder_input_ids, since this is an encoder-decoder model. This may make the fix clearer for users of encoder-decoder models in the future.

I.e., for encoder-decoder models, switch the error message from:

ValueError: You have to specify either input_ids or inputs_embeds

to:

ValueError: You have to specify either (input_ids and decoder_input_ids) or inputs_embeds

I can send this as a PR as well if you think it makes sense!

@patrickvonplaten
Contributor

Hi @enzoampil,

A PR for a cleaner error message would be nice if you feel like it :-). It would be good if the error message switched between ValueError: You have to specify either input_ids or inputs_embeds when self.is_decoder == False and ValueError: You have to specify either decoder_input_ids or decoder_inputs_embeds when self.is_decoder == True. So adding a simple if statement to the error message is definitely a good idea!
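For illustration, a sketch of the suggested branching (the names follow the message text above; the actual PR may differ):

```python
def check_inputs(input_ids, inputs_embeds, is_decoder):
    # Pick the message prefix based on whether this stack is the decoder.
    if input_ids is None and inputs_embeds is None:
        prefix = "decoder_" if is_decoder else ""
        raise ValueError(
            f"You have to specify either {prefix}input_ids or {prefix}inputs_embeds"
        )
```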

@enzoampil
Contributor

Got it, will do. Thanks for the pointers! 😄

@ratthachat
Contributor

ratthachat commented May 7, 2020

Hi, I also got the same error when training seq2seq on tf.keras, and I could not follow the example you provided at https://huggingface.co/transformers/model_doc/t5.html#training (this example is for PyTorch, I think).

I create x_encoder as input_ids and x_decoder_in as decoder_input_ids:

model = TFT5Model.from_pretrained('t5-base')
model.compile('adam',loss='sparse_binary_crossentropy')

So when I want to train the model I simply do
model.fit({'input_ids': x_encoder, 'decoder_input_ids': x_decoder_in})

where I clearly provide input_ids, but I still get this error message:
ValueError: You have to specify either input_ids or inputs_embeds

Note that changing the input from a dict to a list gives the same error. So does changing the model from TFT5Model to TFT5ForConditionalGeneration, or changing the loss to BCE.

Moreover, passing only one array,
model.fit({'input_ids': x_encoder})
also raises an error:

ValueError: No data provided for "decoder_input_ids". Need data for each key in: ['decoder_input_ids', 'input_ids']

@ratthachat
Contributor

ratthachat commented May 7, 2020

In class TFT5Model(TFT5PreTrainedModel), I found these lines (899-900):

```
# retrieve arguments
input_ids = kwargs.get("inputs", None)
```

Shouldn't it be kwargs.get("input_ids", None)?

@patrickvonplaten
Contributor

@ratthachat - thanks for your message!
We definitely need to provide more TF examples for the T5 model. I want to tackle this problem in ~2 weeks.

In TF we use the naming convention inputs, so you should change this to model.fit({"inputs": x_encoder}). I very much agree that the error message is quite misleading, and I correct it in this PR: #4401.
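For illustration, a hedged sketch of the renamed feature dict (the shapes and the target array are placeholders, not from this thread):

```python
import numpy as np

# Stand-ins for the x_encoder / x_decoder_in arrays above.
x_encoder = np.random.randint(1, 100, size=(8, 16))
x_decoder_in = np.random.randint(1, 100, size=(8, 16))

# Under the TF naming convention, the encoder ids are keyed as "inputs",
# not "input_ids":
features = {"inputs": x_encoder, "decoder_input_ids": x_decoder_in}
# model.fit(features, y_decoder_out, ...)  # y_decoder_out is an assumed target
```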

@ratthachat
Contributor

Thanks for your consideration, Patrick!

@ratthachat
Contributor

@patrickvonplaten Sorry to tag you in this old thread, but is there any official T5 TF example (as you mentioned in the last thread)?

@patrickvonplaten
Contributor

@ratthachat - no worries, we should definitely add more TF T5 examples, and we still don't have a good TF T5 notebook.
I am moving the discussion to the forum, and if no one answers I will spend some time porting a T5 PT notebook to TF.

@HarrisDePerceptron
Contributor

Hi @patrickvonplaten, I wanted to fine-tune T5 using TF 2.0, but it's very confusing at every step compared to PyTorch, which is really well documented; all current examples (community + official) are for PyTorch. Is the work on the TF T5 notebook underway?

@patrickvonplaten
Contributor

Okay, it seems like no one has a complete TF T5 notebook. I will start working on it this week: https://discuss.huggingface.co/t/how-to-train-t5-with-tensorflow/641/6

Should be done by next week sometime :-)

@prashant-kikani

prashant-kikani commented Sep 15, 2020

Hi @patrickvonplaten
Please help me with this error.

I'm doing inference with a T5-base model which I fine-tuned on GLUE tasks.

It's giving an error like:
ValueError: You have to specify either decoder_input_ids or decoder_inputs_embeds

While doing inference, we just need to provide input_ids for the encoder, right?
Why do we need decoder_input_ids?

And as it's inference, my labels will also be None, so this part will not execute:
decoder_input_ids = self._shift_right(labels)

Waiting for your reply.
Thank you.

@HarrisDePerceptron
Contributor

HarrisDePerceptron commented Sep 15, 2020

@prashant-kikani it is indeed strange behavior. Have you tried passing input_ids as decoder_input_ids, like:

input_ids = tokenizer(..., return_tensors='tf')   # or 'pt' for PyTorch
outputs = model(input_ids=input_ids, decoder_input_ids=input_ids)

assert len(outputs)==3, 'must return 3 tensors when inferencing'

@prashant-kikani

prashant-kikani commented Sep 16, 2020

Hi @HarrisDePerceptron
We can do that, and it gives some output too, but it's not the right thing to do.

You see, T5, like the Transformer itself, is a text-to-text model.
So when the label is available, it can run the decoder over the whole target in a single parallel pass (teacher forcing).

But when the label is not available, we need to go sequentially, doing a forward pass in the decoder for each token until </s> comes.
We need to concatenate the last output of the decoder with the new input of the decoder each time.

What do you think?
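For illustration, a minimal greedy-decoding sketch of that sequential loop in PyTorch (assuming a recent transformers version; model.generate does this, plus beam search etc., for you):

```python
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained('t5-base')
model = T5ForConditionalGeneration.from_pretrained('t5-base')

input_ids = tokenizer("cola sentence: The books is on the table.", return_tensors='pt').input_ids

# Start from the decoder start token and append one token per step until </s>.
decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])
for _ in range(20):  # illustrative max length
    logits = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids).logits
    next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
    decoder_input_ids = torch.cat([decoder_input_ids, next_token], dim=-1)
    if next_token.item() == tokenizer.eos_token_id:
        break

print(tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True))
```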

@ratthachat
Contributor

ratthachat commented Oct 27, 2020

@prashant-kikani @HarrisDePerceptron

For decoder_input_ids, we just need to put a single BOS token so that the decoder will know that this is the beginning of the output sentence. (Even in a GLUE task, T5 still treats every output label as a complete sentence.)

We can see a concrete example by looking at the function prepare_inputs_for_generation, which is called by model.generate
(the generate function is here: https://github.com/huggingface/transformers/blob/master/src/transformers/generation_tf_utils.py).

See line 298 in the above link:

if self.config.is_encoder_decoder:
    if decoder_start_token_id is None:
        decoder_start_token_id = bos_token_id

and line 331:

# create empty decoder_input_ids
input_ids = (
    tf.ones(
        (effective_batch_size * num_beams, 1),
        dtype=tf.int32,
    )
    * decoder_start_token_id
)

and see T5's prepare_inputs_for_generation implementation, which turns the above input_ids into decoder_input_ids:

def prepare_inputs_for_generation(self, inputs, past, attention_mask, use_cache, **kwargs):
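Putting that together, a minimal TF sketch of manually seeding the decoder with a single start token (assuming a recent transformers version; for T5 the decoder start token is the pad token):

```python
import tensorflow as tf
from transformers import T5Tokenizer, TFT5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained('t5-small')
model = TFT5ForConditionalGeneration.from_pretrained('t5-small')

input_ids = tokenizer("summarize: studies have shown that owning a dog is good for you", return_tensors='tf').input_ids

# One start token is enough for the decoder to know this is the beginning
# of the output sentence.
decoder_input_ids = tf.fill((input_ids.shape[0], 1), model.config.decoder_start_token_id)
outputs = model(input_ids, decoder_input_ids=decoder_input_ids)
logits = outputs.logits  # distribution over the first output token
```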

@dxlong2000

dxlong2000 commented Mar 15, 2022

Hi @patrickvonplaten Patrick,

Thanks for your great work and great comments. I mimicked the process of running inference with T5 as below and hit a bug. Could you help me figure out what has happened?

from transformers import AutoModel, AutoTokenizer 
model_name = "castorini/t5-base-canard" 

model = AutoModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

context = '''
    Frank Zappa ||| Disbandment ||| What group disbanded ||| Zappa and the Mothers of Invention ||| When did they disband?
'''

encoded_input = tokenizer(
    context,
    padding='max_length',
    max_length=512,
    truncation=True,
    return_tensors="pt",
)
decoder_input = tokenizer(
    context,
    padding='max_length',
    max_length=512,
    truncation=True,
    return_tensors="pt",
)

encoder_output = model.generate(input_ids=encoded_input["input_ids"], decoder_input_ids=decoder_input["input_ids"])
output = tokenizer.decode(
    encoder_output[0],
    skip_special_tokens=True
)
output

I got an error, though I already provided decoder_input_ids:

Some weights of the model checkpoint at castorini/t5-base-canard were not used when initializing T5Model: ['lm_head.weight']
- This IS expected if you are initializing T5Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing T5Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Input length of decoder_input_ids is 512, but ``max_length`` is set to 20. This can lead to unexpected behavior. You should consider increasing ``config.max_length`` or ``max_length``.
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-11-b9fe12b71812> in <module>()
     24 )
     25 
---> 26 encoder_output = model.generate(input_ids=encoded_input["input_ids"], decoder_input_ids=decoder_input["input_ids"])
     27 output = tokenizer.decode(
     28     encoder_output[0],

6 frames
/usr/local/lib/python3.7/dist-packages/transformers/models/t5/modeling_t5.py in forward(self, input_ids, attention_mask, encoder_hidden_states, encoder_attention_mask, inputs_embeds, head_mask, cross_attn_head_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)
    925         else:
    926             err_msg_prefix = "decoder_" if self.is_decoder else ""
--> 927             raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
    928 
    929         if inputs_embeds is None:

ValueError: You have to specify either decoder_input_ids or decoder_inputs_embeds

Thanks!

@patrickvonplaten
Contributor

Hey @dxlong2000,

I'll open a new issue for this to make it more visible as I think this error happens quite often. See: #16234
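For what it's worth, a hedged sketch of one likely fix (an assumption, not stated in this thread): the warning in the log above shows that AutoModel dropped lm_head.weight, so loading the checkpoint with a seq2seq LM head lets generate() seed the decoder itself, with no manual decoder_input_ids:

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

model_name = "castorini/t5-base-canard"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)  # keeps lm_head.weight
tokenizer = AutoTokenizer.from_pretrained(model_name)

context = "Frank Zappa ||| Disbandment ||| What group disbanded ||| Zappa and the Mothers of Invention ||| When did they disband?"
inputs = tokenizer(context, return_tensors="pt", truncation=True, max_length=512)

# generate() creates the decoder start token on its own.
output_ids = model.generate(inputs["input_ids"], max_length=64)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))
```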

@sunyuhan19981208

Good issue! Really helps me.
