-
Notifications
You must be signed in to change notification settings - Fork 25.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add pretraining loss computation for TF Bert pretraining #8470
Changes from 3 commits
b0e5391
768a40b
82b6b87
3c700a7
a387203
9686870
472f2e4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -36,6 +36,7 @@ | |
TF_MODEL_FOR_MASKED_LM_MAPPING, | ||
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING, | ||
TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, | ||
TF_MODEL_FOR_PRETRAINING_MAPPING, | ||
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING, | ||
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, | ||
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, | ||
|
@@ -98,6 +99,14 @@ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False) -> d | |
inputs_dict["labels"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32) | ||
elif model_class in TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.values(): | ||
inputs_dict["next_sentence_label"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32) | ||
elif model_class in TF_MODEL_FOR_PRETRAINING_MAPPING.values(): | ||
# Only BERT needs the next sentence label for pre-training | ||
if model_class.base_model_prefix in ["bert"]: | ||
inputs_dict["next_sentence_label"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32) | ||
|
||
inputs_dict["labels"] = tf.zeros( | ||
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=tf.int32 | ||
) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If it's BERT specific, I'd prefer to override this in There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. How would you do that? As this is only for BERT + Pretraining. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @LysandreJik What do you think of this approach to solve this? (see last commit) There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. No, this isn't what I had in mind. I think overriding this in We should always extrapolate that a model-specific change will lead to other model-specific changes down the road, and this would become unmaintainable. Overriding in the model-specific tester makes much more sense from a maintainability point of view. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ok, got it! Sorry, I didn't really know how to override a test. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Just made an update. Does it look better? |
||
elif model_class in [ | ||
*TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values(), | ||
*TF_MODEL_FOR_CAUSAL_LM_MAPPING.values(), | ||
|
@@ -834,7 +843,9 @@ def test_loss_computation(self): | |
if getattr(model, "compute_loss", None): | ||
# The number of elements in the loss should be the same as the number of elements in the label | ||
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True) | ||
added_label = prepared_for_class[list(prepared_for_class.keys() - inputs_dict.keys())[0]] | ||
added_label = prepared_for_class[ | ||
sorted(list(prepared_for_class.keys() - inputs_dict.keys()), reverse=True)[0] | ||
] | ||
loss_size = tf.size(added_label) | ||
|
||
if model.__class__ in TF_MODEL_FOR_CAUSAL_LM_MAPPING.values(): | ||
|
@@ -859,23 +870,30 @@ def test_loss_computation(self): | |
|
||
# Get keys that were added with the _prepare_for_class function | ||
label_keys = prepared_for_class.keys() - inputs_dict.keys() | ||
signature = inspect.getfullargspec(model.call)[0] | ||
signature = inspect.signature(model.call).parameters | ||
signature_names = list(signature.keys()) | ||
|
||
# Create a dictionary holding the location of the tensors in the tuple | ||
tuple_index_mapping = {1: "input_ids"} | ||
tuple_index_mapping = {0: "input_ids"} | ||
for label_key in label_keys: | ||
label_key_index = signature.index(label_key) | ||
label_key_index = signature_names.index(label_key) | ||
tuple_index_mapping[label_key_index] = label_key | ||
sorted_tuple_index_mapping = sorted(tuple_index_mapping.items()) | ||
|
||
# Initialize a list with None, update the values and convert to a tuple | ||
list_input = [None] * sorted_tuple_index_mapping[-1][0] | ||
# Initialize a list with their default values, update the values and convert to a tuple | ||
list_input = [] | ||
|
||
for name in signature_names: | ||
list_input.append(signature[name].default) | ||
|
||
for index, value in sorted_tuple_index_mapping: | ||
list_input[index - 1] = prepared_for_class[value] | ||
list_input[index] = prepared_for_class[value] | ||
|
||
tuple_input = tuple(list_input) | ||
|
||
# Send to model | ||
loss = model(tuple_input)[0] | ||
loss = model(tuple_input[:-1])[0] | ||
|
||
self.assertEqual(loss.shape, [loss_size]) | ||
|
||
def _generate_random_bad_tokens(self, num_bad_tokens, model): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just realized that this seems to be an issue in all TF models: if
`return_dict` is defined in the `inputs` (they're either a tuple, a list, or a
dict), the value of `return_dict` won't be used here?

There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
Yep! But it will be solved with the new input parsing. A fix will be in a PR that will arrive soon.