
Add pretraining loss computation for TF Bert pretraining #8470

Merged · 7 commits · Nov 12, 2020

Changes from 3 commits
84 changes: 79 additions & 5 deletions src/transformers/modeling_tf_bert.py
@@ -89,6 +89,38 @@
]


class TFBertPreTrainingLoss:
    """
    Loss function suitable for BERT-like pretraining, that is, the task of pretraining a language model by combining
    NSP + MLM.

    .. note::
        Any label of -100 will be ignored (along with the corresponding logits) in the loss computation.
    """

    def compute_loss(self, labels, logits):
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )
        # make sure only labels that are not equal to -100
        # are taken into account for the loss computation
        masked_lm_active_loss = tf.not_equal(tf.reshape(labels["labels"], (-1,)), -100)
        masked_lm_reduced_logits = tf.boolean_mask(
            tf.reshape(logits[0], (-1, shape_list(logits[0])[2])),
            masked_lm_active_loss,
        )
        masked_lm_labels = tf.boolean_mask(tf.reshape(labels["labels"], (-1,)), masked_lm_active_loss)
        next_sentence_active_loss = tf.not_equal(tf.reshape(labels["next_sentence_label"], (-1,)), -100)
        next_sentence_reduced_logits = tf.boolean_mask(tf.reshape(logits[1], (-1, 2)), next_sentence_active_loss)
        next_sentence_label = tf.boolean_mask(
            tf.reshape(labels["next_sentence_label"], (-1,)), mask=next_sentence_active_loss
        )
        masked_lm_loss = loss_fn(masked_lm_labels, masked_lm_reduced_logits)
        next_sentence_loss = loss_fn(next_sentence_label, next_sentence_reduced_logits)
        masked_lm_loss = tf.reshape(masked_lm_loss, (-1, shape_list(next_sentence_loss)[0]))
        masked_lm_loss = tf.reduce_mean(masked_lm_loss, 0)

        return masked_lm_loss + next_sentence_loss
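
A minimal sketch of how this new mixin could be exercised in isolation (the toy shapes and label pattern are assumptions for illustration; the module path is the one touched by this PR):

    import tensorflow as tf
    from transformers.modeling_tf_bert import TFBertPreTrainingLoss

    loss_cls = TFBertPreTrainingLoss()
    mlm_logits = tf.random.normal((2, 8, 100))  # (batch, seq_len, vocab_size)
    nsp_logits = tf.random.normal((2, 2))       # (batch, 2): is-next vs. not-next
    labels = {
        # -100 marks positions that must not contribute to the MLM term
        "labels": tf.constant([[-100, 5, -100, 7, -100, -100, 2, -100],
                               [-100, -100, 3, -100, 9, -100, -100, 1]]),
        "next_sentence_label": tf.constant([0, 1]),
    }
    # note: the final reshape in compute_loss assumes the number of active MLM
    # labels is divisible by the batch size (true here: 6 labels, batch of 2)
    per_example_loss = loss_cls.compute_loss(labels=labels, logits=(mlm_logits, nsp_logits))
    print(per_example_loss.shape)  # (2,): one combined MLM + NSP loss per example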


class TFBertEmbeddings(tf.keras.layers.Layer):
    """Construct the embeddings from word, position and token_type embeddings."""

@@ -688,6 +720,7 @@ class TFBertForPreTrainingOutput(ModelOutput):
            heads.
    """

    loss: Optional[tf.Tensor] = None
    prediction_logits: tf.Tensor = None
    seq_relationship_logits: tf.Tensor = None
    hidden_states: Optional[Tuple[tf.Tensor]] = None
@@ -814,7 +847,7 @@ def call(self, inputs, **kwargs):
""",
BERT_START_DOCSTRING,
)
class TFBertForPreTraining(TFBertPreTrainedModel):
class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
def __init__(self, config, *inputs, **kwargs):
super().__init__(config, *inputs, **kwargs)

@@ -827,7 +860,21 @@ def get_output_embeddings(self):

    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @replace_return_docstrings(output_type=TFBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
-    def call(self, inputs, **kwargs):
    def call(
        self,
        inputs=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
        next_sentence_label=None,
        training=False,
    ):
r"""
Return:

Expand All @@ -843,17 +890,44 @@ def call(self, inputs, **kwargs):
>>> prediction_scores, seq_relationship_scores = outputs[:2]

"""
-        return_dict = kwargs.get("return_dict")
        return_dict = return_dict if return_dict is not None else self.bert.return_dict
-        outputs = self.bert(inputs, **kwargs)

        if isinstance(inputs, (tuple, list)):
            labels = inputs[9] if len(inputs) > 9 else labels
            next_sentence_label = inputs[10] if len(inputs) > 10 else next_sentence_label
            if len(inputs) > 9:
                inputs = inputs[:9]
        elif isinstance(inputs, (dict, BatchEncoding)):
            labels = inputs.pop("labels", labels)
            next_sentence_label = inputs.pop("next_sentence_label", next_sentence_label)
Comment on lines -846 to +902

Member:
Just realized that this seems to be an issue in all TF models: if return_dict is defined in the inputs (they're either a tuple, a list, or a dict), the value of return_dict won't be used here?

Contributor Author (jplu):
Yep! But this will be solved with the new input parsing. A fix will be in a PR that will arrive soon.


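To make the reviewer's point concrete, here is a toy stand-in for the keyword-only pattern above (not the real model code): a return_dict value packed inside a dict input never reaches the explicit keyword check.

    # only the explicit keyword is consulted, so a value tucked into the
    # input dict is silently ignored by this particular check
    def call(inputs=None, return_dict=None):
        return_dict = return_dict if return_dict is not None else False
        return return_dict

    print(call({"input_ids": [1, 2], "return_dict": True}))  # False
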
        outputs = self.bert(
            inputs,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            training=training,
        )
        sequence_output, pooled_output = outputs[:2]
-        prediction_scores = self.mlm(sequence_output, training=kwargs.get("training", False))
        prediction_scores = self.mlm(sequence_output, training=training)
        seq_relationship_score = self.nsp(pooled_output)
        total_loss = None

        if labels is not None and next_sentence_label is not None:
            d_labels = {"labels": labels}
            d_labels["next_sentence_label"] = next_sentence_label
            total_loss = self.compute_loss(labels=d_labels, logits=(prediction_scores, seq_relationship_score))

        if not return_dict:
            return (prediction_scores, seq_relationship_score) + outputs[2:]

        return TFBertForPreTrainingOutput(
            loss=total_loss,
            prediction_logits=prediction_scores,
            seq_relationship_logits=seq_relationship_score,
            hidden_states=outputs.hidden_states,
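With this change in place, the pretraining head can return a loss directly. A hedged usage sketch — the checkpoint name and the toy labeling are illustrative assumptions, and it presumes a transformers version that already contains this PR:

    import tensorflow as tf
    from transformers import BertTokenizer, TFBertForPreTraining

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = TFBertForPreTraining.from_pretrained("bert-base-uncased")

    inputs = tokenizer("The capital of France is Paris.", return_tensors="tf")
    # toy MLM labels: score every attended position against its own input id,
    # with -100 elsewhere so those positions are dropped from the loss
    mlm_labels = tf.where(inputs["attention_mask"] == 1, inputs["input_ids"], -100)

    outputs = model(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        token_type_ids=inputs["token_type_ids"],
        labels=mlm_labels,
        next_sentence_label=tf.constant([0]),
        return_dict=True,
    )
    print(outputs.loss)  # combined MLM + NSP loss, one value per example
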
8 changes: 0 additions & 8 deletions src/transformers/modeling_tf_t5.py
@@ -1196,7 +1196,6 @@ def call(
        output_hidden_states=None,
        return_dict=None,
        training=False,
-        **kwargs,
    ):
r"""
labels (:obj:`tf.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
@@ -1268,13 +1267,6 @@
        else:
            input_ids = inputs

-        if "past_key_value_states" in kwargs:
-            warnings.warn(
-                "The `past_key_value_states` argument is deprecated and will be removed in a future version, use `past_key_values` instead.",
-                FutureWarning,
-            )
-            past_key_values = kwargs.pop("past_key_value_states")

jplu marked this conversation as resolved.
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        output_attentions = output_attentions if output_attentions else self.config.output_attentions
        output_hidden_states = output_hidden_states if output_hidden_states else self.config.output_hidden_states
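The deleted lines are an instance of the library's usual keyword-deprecation shim, which only works while the signature still accepts **kwargs. A generic sketch of the pattern for context (old_name/new_name are illustrative placeholders, not T5 arguments):

    import warnings

    def call(new_name=None, **kwargs):
        # accept the deprecated keyword, warn, and forward it to the new one
        if "old_name" in kwargs:
            warnings.warn(
                "The `old_name` argument is deprecated and will be removed in a "
                "future version, use `new_name` instead.",
                FutureWarning,
            )
            new_name = kwargs.pop("old_name")
        return new_name

    print(call(old_name=3))  # warns, then returns 3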
34 changes: 26 additions & 8 deletions tests/test_modeling_tf_common.py
@@ -36,6 +36,7 @@
    TF_MODEL_FOR_MASKED_LM_MAPPING,
    TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
    TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING,
    TF_MODEL_FOR_PRETRAINING_MAPPING,
    TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
    TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
    TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
@@ -98,6 +99,14 @@ def _prepare_for_class(self, inputs_dict, model_class, return_labels=False) -> d
inputs_dict["labels"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
elif model_class in TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING.values():
inputs_dict["next_sentence_label"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
elif model_class in TF_MODEL_FOR_PRETRAINING_MAPPING.values():
# Only BERT needs the next sentence label for pre-training
if model_class.base_model_prefix in ["bert"]:
inputs_dict["next_sentence_label"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)

inputs_dict["labels"] = tf.zeros(
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=tf.int32
)
Member:
If it's BERT-specific, I'd prefer to override this in TFBertModelTest instead of putting model-specific stuff in a model-agnostic file.

Contributor Author (jplu):
How would you do that? This is only for BERT + pretraining.

Contributor Author (jplu), Nov 11, 2020:
@LysandreJik What do you think of this approach to solve this? (see last commit)

Member:
No, this isn't what I had in mind. I think overriding this in TFBertModelTest would be better, similarly to how test_saved_model_with_hidden_states_output is overridden in TFLxmertModelTest: redefining the test in the tester when it diverges from the common test.

We should always extrapolate that a model-specific change will lead to other model-specific changes down the road, and this would become unmaintainable. Overriding in the model-specific tester makes much more sense from a maintainability point of view.

Contributor Author (jplu), Nov 12, 2020:
Ok, got it! Sorry, I didn't really know how to override a test.

Contributor Author (jplu):
Just made an update. Does it look better?

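A hypothetical sketch of the override the reviewer is asking for — moving the BERT-specific labels into TFBertModelTest instead of branching in the model-agnostic file (names follow the surrounding tests; the final committed version may differ):

    class TFBertModelTest(TFModelTesterMixin, unittest.TestCase):
        def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
            inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)

            if return_labels and model_class in TF_MODEL_FOR_PRETRAINING_MAPPING.values():
                # BERT pretraining needs both the MLM labels and the NSP label
                inputs_dict["labels"] = tf.zeros(
                    (self.model_tester.batch_size, self.model_tester.seq_length), dtype=tf.int32
                )
                inputs_dict["next_sentence_label"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)

            return inputs_dict
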
            elif model_class in [
                *TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values(),
                *TF_MODEL_FOR_CAUSAL_LM_MAPPING.values(),
@@ -834,7 +843,9 @@ def test_loss_computation(self):
            if getattr(model, "compute_loss", None):
                # The number of elements in the loss should be the same as the number of elements in the label
                prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
-                added_label = prepared_for_class[list(prepared_for_class.keys() - inputs_dict.keys())[0]]
                added_label = prepared_for_class[
                    sorted(list(prepared_for_class.keys() - inputs_dict.keys()), reverse=True)[0]
                ]
                loss_size = tf.size(added_label)

                if model.__class__ in TF_MODEL_FOR_CAUSAL_LM_MAPPING.values():
@@ -859,23 +870,30 @@

                # Get keys that were added with the _prepare_for_class function
                label_keys = prepared_for_class.keys() - inputs_dict.keys()
-                signature = inspect.getfullargspec(model.call)[0]
                signature = inspect.signature(model.call).parameters
                signature_names = list(signature.keys())

                # Create a dictionary holding the location of the tensors in the tuple
-                tuple_index_mapping = {1: "input_ids"}
                tuple_index_mapping = {0: "input_ids"}
                for label_key in label_keys:
-                    label_key_index = signature.index(label_key)
                    label_key_index = signature_names.index(label_key)
                    tuple_index_mapping[label_key_index] = label_key
                sorted_tuple_index_mapping = sorted(tuple_index_mapping.items())

-                # Initialize a list with None, update the values and convert to a tuple
-                list_input = [None] * sorted_tuple_index_mapping[-1][0]
                # Initialize a list with the arguments' default values, update with the labels, and convert to a tuple
                list_input = []

                for name in signature_names:
                    list_input.append(signature[name].default)

                for index, value in sorted_tuple_index_mapping:
-                    list_input[index - 1] = prepared_for_class[value]
                    list_input[index] = prepared_for_class[value]

                tuple_input = tuple(list_input)

                # Send to model
-                loss = model(tuple_input)[0]
                loss = model(tuple_input[:-1])[0]

                self.assertEqual(loss.shape, [loss_size])

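The rewritten loop above builds the positional tuple from the call signature's declared defaults instead of padding with None. A minimal sketch of that inspect.signature technique (the toy call function is an assumption for illustration):

    import inspect

    def call(inputs=None, attention_mask=None, labels=None, training=False):
        return inputs, attention_mask, labels, training

    params = inspect.signature(call).parameters
    names = list(params.keys())                    # ['inputs', 'attention_mask', 'labels', 'training']
    defaults = [params[n].default for n in names]  # [None, None, None, False]

    # drop the labels tensor into its positional slot, keep defaults elsewhere
    defaults[names.index("labels")] = [1, 2, 3]
    print(tuple(defaults))  # (None, None, [1, 2, 3], False)
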
    def _generate_random_bad_tokens(self, num_bad_tokens, model):