
Remove value error #8985

Merged: 5 commits into huggingface:master on Dec 10, 2020
Conversation

@jplu (Contributor) commented Dec 8, 2020

What does this PR do?

This PR updates how inputs are handled: instead of raising an error when a tensor's name is not among the accepted parameter names, we should behave as if the tensor had no name. This is more elegant and less annoying.

@sgugger (Collaborator) commented Dec 8, 2020

Could you elaborate on the use case? It seems dangerous and magical to me. When a user passes a parameter that is not in a function's signature, they get a ValueError.

@jplu (Contributor, Author) commented Dec 8, 2020

Sure! An EagerTensor doesn't have a .name attribute, so in that case we assume the values are given in parameter order. That's OK because we have no choice, but why not have the same behavior when someone decides to name the tensors however they wish?

This is very nitpicky, and I won't fight at all if it's not accepted, haha.
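
For context, a minimal standalone sketch of that distinction (my own illustration, not code from the PR; the exact name strings vary by TF version):

```python
import tensorflow as tf

eager = tf.constant([[1, 2, 3]])                          # EagerTensor
symbolic = tf.keras.Input(shape=(3,), name="input_ids")   # symbolic Keras input

print(symbolic.name)  # e.g. "input_ids", usable for matching parameter names
try:
    print(eager.name)
except AttributeError:
    # Eager tensors carry no meaningful graph name, so only their
    # position in the argument list can be used to match parameters.
    print("EagerTensor has no usable .name")
```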

@sgugger (Collaborator) commented Dec 8, 2020

Mmm, but in this test these aren't eager tensors, since there is a .name attribute, or am I missing something?

@jplu (Contributor, Author) commented Dec 8, 2020

While I was trying to explain this, a use case came to mind, and indeed this behavior is not correct for an edge case:

```python
from transformers import AutoTokenizer, TFBertForSequenceClassification, BertConfig
import tensorflow as tf
import datasets

config = BertConfig.from_pretrained("bert-base-cased", num_labels=6)
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

# Tokenize the "emotion" dataset and expose it as dense tensors.
ds = datasets.load_dataset('emotion')
encoded_train = ds['train'].map(
    lambda examples: tokenizer(examples['text'], truncation=True, padding='max_length', max_length=128),
    batched=True,
)
encoded_train.set_format(type='tensorflow', columns=['input_ids', 'attention_mask', 'label'])
features_train = {
    x: encoded_train[x].to_tensor(default_value=0, shape=[None, 128])
    for x in ['input_ids', 'attention_mask']
}
train_ds = tf.data.Dataset.from_tensor_slices((features_train, encoded_train["label"])).batch(16)

# Wrap the transformer in a functional Keras model with named symbolic inputs.
input_ids = tf.keras.Input(shape=(128,), dtype='int32', name="input_ids")
attention_mask = tf.keras.Input(shape=(128,), dtype='int32', name="attention_mask")
transformer = TFBertForSequenceClassification.from_pretrained("bert-base-cased", num_labels=6)
encoded = transformer([input_ids, attention_mask])
logits = encoded[0]
model = tf.keras.models.Model(inputs=[input_ids, attention_mask], outputs=logits)

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy('accuracy')],
)
model.fit(train_ds, epochs=1, steps_per_epoch=1)
```

We get:

```
ValueError: The tensor named IteratorGetNext:1 does not belong to the authorized list of names ['input_ids', 'attention_mask', 'token_type_ids', 'position_ids', 'head_mask', 'inputs_embeds', 'output_attentions', 'output_hidden_states', 'return_dict', 'labels', 'training'].
```

This is expected, because .fit() wraps the dataset in an iterator, and the tensors are renamed accordingly. Thanks @sgugger for asking the question :)
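
To see where those IteratorGetNext names come from, here is a small standalone sketch (my own, with toy shapes) of a tf.data iterator being consumed inside a graph, which is essentially what fit() does:

```python
import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices(
    ({"input_ids": [[1, 2, 3]], "attention_mask": [[1, 1, 1]]}, [0])
).batch(1)

@tf.function
def step(iterator):
    features, labels = next(iterator)
    for key, tensor in features.items():
        # Trace-time print: inside the graph the tensors come from an
        # IteratorGetNext op, so their names look like "IteratorGetNext:0"
        # and "IteratorGetNext:1", not the original dict keys.
        print(key, "->", tensor.name)
    return labels

step(iter(ds))
```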

@sgugger (Collaborator) commented Dec 8, 2020

Thanks for explaining, I understand better now :-)

@jplu (Contributor, Author) commented Dec 8, 2020

OK, I just realized it is even worse: the inputs get an ID, here IteratorGetNext:0 and IteratorGetNext:1, but the order of the list is never guaranteed. I'm trying to think of a fix for this.

@jplu jplu marked this pull request as draft December 8, 2020 14:55
@jplu jplu marked this pull request as ready for review December 8, 2020 15:46
@jplu (Contributor, Author) commented Dec 8, 2020

OK, as long as we name the inputs after the parameters, the order is safe. For example:

```python
input_ids = tf.keras.Input(shape=(128,), dtype='int32', name="input_ids")
attention_mask = tf.keras.Input(shape=(128,), dtype='int32', name="attention_mask")

model = tf.keras.models.Model(inputs=[input_ids, attention_mask], outputs=...)
```

is perfectly fine and works as expected, but:

```python
input_ids = tf.keras.Input(shape=(128,), dtype='int32')
attention_mask = tf.keras.Input(shape=(128,), dtype='int32')

model = tf.keras.models.Model(inputs=[input_ids, attention_mask], outputs=...)
```

brings undefined behavior into the ordering.

Nevertheless, there is still an issue. Let's imagine this case:

```python
input_embeds = tf.keras.Input(shape=(768,), dtype='float32')
attention_mask = tf.keras.Input(shape=(128,), dtype='int32')

model = tf.keras.models.Model(inputs=[input_embeds, attention_mask], outputs=...)
```

won't work, because internally the input_ids parameter will take the value of the input_embeds input. This could be solved by integrating the name of each parameter directly inside the model, but we cannot do that because of a bug in TF <= 2.4 that will only be fixed in the TF 2.5 release. So as long as that release is not out, we cannot fix this and have to live with this bug, even though it is an edge case.

What do you think?
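
To make the discussion concrete, here is a simplified, hypothetical sketch of the name-or-position matching being described (my own illustration, not the actual input_processing implementation):

```python
def match_inputs(tensors, parameter_names):
    """Assign a list of input tensors to parameter names.

    Tensors whose graph name matches a parameter are matched by name;
    everything else falls back to positional order, which is only safe
    if the remaining inputs arrive in parameter order.
    """
    output, unnamed = {}, []
    for tensor in tensors:
        name = getattr(tensor, "name", "")   # EagerTensors expose no name
        name = name.split(":")[0]            # drop the ":0" output index
        if name in parameter_names:
            output[name] = tensor            # matched by name: order-independent
        else:
            unnamed.append(tensor)           # e.g. "IteratorGetNext:1"
    # Positional fallback for unmatched tensors; this is where the
    # undefined behavior creeps in when symbolic inputs have no names.
    remaining = [p for p in parameter_names if p not in output]
    for param, tensor in zip(remaining, unnamed):
        output[param] = tensor
    return output
```

Under this scheme, two unnamed symbolic inputs [input_embeds, attention_mask] would positionally fill input_ids first, which is exactly the failure described above.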

@sgugger (Collaborator) commented Dec 8, 2020

I think we should document that this does not work and encourage users to use named inputs then.

@jplu (Contributor, Author) commented Dec 8, 2020

I have completed the documentation of the input_processing function. Does that sound like enough of an explanation to you?

@sgugger (Collaborator) commented Dec 8, 2020

LGTM!

```diff
@@ -365,9 +367,7 @@ def input_processing(func, config, input_ids, **kwargs):
             if tensor_name in parameter_names:
                 output[tensor_name] = input
             else:
-                raise ValueError(
```
Contributor:

From the comment that was added above, it seems like the tensor_name has to be in parameter_names, no? But we still allow tensors without any parameter name?

Contributor:

Should we then maybe at least raise a warning here for the moment? And extend the function's docstring with a sentence saying there might be unexpected behavior if the inputs have no names?

Contributor (Author):

As you can see in the example, during explicit model creation the tensors cannot have a name that belongs to any parameter name. So yes, we allow tensors whose names differ from the parameters. But the symbolic inputs must have a name that belongs to the parameters in order to get a proper order in the list. It is more obvious when reading the full example above.

@jplu (Author) commented Dec 8, 2020:

@patrickvonplaten I have updated the comment to make it clearer. Does it sound more understandable to you? I think there was some confusion between the terms symbolic input and tensor.

```diff
@@ -365,9 +367,7 @@ def input_processing(func, config, input_ids, **kwargs):
             if tensor_name in parameter_names:
                 output[tensor_name] = input
             else:
-                raise ValueError(
-                    f"The tensor named {input.name} does not belong to the authorized list of names {parameter_names}."
-                )
```
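
Given the PR title, the raise is presumably replaced by a positional assignment, roughly along these lines (my reconstruction from the discussion, the merged diff may differ; i would be the tensor's position in the input list):

```python
if tensor_name in parameter_names:
    output[tensor_name] = input
else:
    # Name doesn't match any parameter (e.g. "IteratorGetNext:1" produced
    # by fit()), so fall back to the tensor's position in the input list.
    output[parameter_names[i]] = input
```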
Contributor:

Should we instead raise a warning here, maybe?

@jplu (Author) commented Dec 8, 2020:

There is no need for an error or a warning now that we allow tensors with a different name. We have no choice, because the tensors that represent a symbolic input simply cannot have a name that belongs to the parameters. A warning would therefore be raised every time you use a tf.data.Dataset, once per batch, and it would simply pollute your standard output for nothing.

Contributor (Author):

More specifically, in the example given in #8985 (comment), the symbolic inputs are:

```python
input_ids = tf.keras.Input(shape=(128,), dtype='int32', name="input_ids")
attention_mask = tf.keras.Input(shape=(128,), dtype='int32', name="attention_mask")
```

and the tensors that represent these symbolic inputs are named IteratorGetNext:0 and IteratorGetNext:1 respectively, and those two names cannot be changed.

@jplu (Contributor, Author) commented Dec 10, 2020

LGTM! @LysandreJik feel free to merge if the PR gets your approval!

@LysandreJik (Member) left a comment:

OK!

@LysandreJik LysandreJik merged commit b01ddc9 into huggingface:master Dec 10, 2020
@jplu jplu deleted the fix-inputs branch December 11, 2020 15:16