diff --git a/src/transformers/models/openai/modeling_tf_openai.py b/src/transformers/models/openai/modeling_tf_openai.py
index cb680603a1df..06a1edea52b9 100644
--- a/src/transformers/models/openai/modeling_tf_openai.py
+++ b/src/transformers/models/openai/modeling_tf_openai.py
@@ -37,8 +37,8 @@
     TFSequenceSummary,
     TFSharedEmbeddings,
     get_initializer,
-    input_processing,
     keras_serializable,
+    unpack_inputs,
 )
 from ...tf_utils import shape_list
 from ...utils import logging
@@ -234,6 +234,7 @@ def _prune_heads(self, heads_to_prune):
         """
         raise NotImplementedError
 
+    @unpack_inputs
     def call(
         self,
         input_ids=None,
@@ -248,42 +249,27 @@ def call(
         training=False,
         **kwargs,
     ):
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
-            position_ids=position_ids,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-            training=training,
-            kwargs_call=kwargs,
-        )
-        if inputs["input_ids"] is not None and inputs["inputs_embeds"] is not None:
+        if input_ids is not None and inputs_embeds is not None:
             raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-        elif inputs["input_ids"] is not None:
-            input_shape = shape_list(inputs["input_ids"])
-            inputs["input_ids"] = tf.reshape(inputs["input_ids"], [-1, input_shape[-1]])
-        elif inputs["inputs_embeds"] is not None:
-            input_shape = shape_list(inputs["inputs_embeds"])[:-1]
+        elif input_ids is not None:
+            input_shape = shape_list(input_ids)
+            input_ids = tf.reshape(input_ids, [-1, input_shape[-1]])
+        elif inputs_embeds is not None:
+            input_shape = shape_list(inputs_embeds)[:-1]
         else:
             raise ValueError("You have to specify either input_ids or inputs_embeds")
 
-        if inputs["position_ids"] is None:
-            inputs["position_ids"] = tf.expand_dims(tf.range(input_shape[-1]), axis=0)
+        if position_ids is None:
+            position_ids = tf.expand_dims(tf.range(input_shape[-1]), axis=0)
 
-        if inputs["attention_mask"] is not None:
+        if attention_mask is not None:
             # We create a 3D attention mask from a 2D tensor mask.
             # Sizes are [batch_size, 1, 1, to_seq_length]
             # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
             # this attention mask is more simple than the triangular masking of causal attention
             # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
-            inputs["attention_mask"] = tf.reshape(inputs["attention_mask"], (input_shape[0], 1, 1, input_shape[1]))
+            attention_mask = tf.reshape(attention_mask, (input_shape[0], 1, 1, input_shape[1]))
 
             # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
             # masked positions, this operation will create a tensor which is 0.0 for
@@ -292,69 +278,65 @@ def call(
             # effectively the same as removing these entirely.
             one_cst = tf.constant(1.0)
-            inputs["attention_mask"] = tf.cast(inputs["attention_mask"], dtype=one_cst.dtype)
-            inputs["attention_mask"] = tf.multiply(
-                tf.subtract(one_cst, inputs["attention_mask"]), tf.constant(-10000.0)
-            )
+            attention_mask = tf.cast(attention_mask, dtype=one_cst.dtype)
+            attention_mask = tf.multiply(tf.subtract(one_cst, attention_mask), tf.constant(-10000.0))
         else:
-            inputs["attention_mask"] = None
+            attention_mask = None
 
         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
         # attention_probs has shape bsz x n_heads x N x N
         # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
         # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
-        if inputs["head_mask"] is not None:
+        if head_mask is not None:
             raise NotImplementedError
         else:
-            inputs["head_mask"] = [None] * self.num_hidden_layers
+            head_mask = [None] * self.num_hidden_layers
             # head_mask = tf.constant([0] * self.num_hidden_layers)
 
-        inputs["position_ids"] = tf.reshape(inputs["position_ids"], [-1, shape_list(inputs["position_ids"])[-1]])
+        position_ids = tf.reshape(position_ids, [-1, shape_list(position_ids)[-1]])
 
-        if inputs["inputs_embeds"] is None:
-            inputs["inputs_embeds"] = self.tokens_embed(inputs["input_ids"], mode="embedding")
-        position_embeds = tf.gather(self.positions_embed, inputs["position_ids"])
-        if inputs["token_type_ids"] is not None:
-            inputs["token_type_ids"] = tf.reshape(
-                inputs["token_type_ids"], [-1, shape_list(inputs["token_type_ids"])[-1]]
-            )
-            token_type_embeds = self.tokens_embed(inputs["token_type_ids"], mode="embedding")
+        if inputs_embeds is None:
+            inputs_embeds = self.tokens_embed(input_ids, mode="embedding")
+        position_embeds = tf.gather(self.positions_embed, position_ids)
+        if token_type_ids is not None:
+            token_type_ids = tf.reshape(token_type_ids, [-1, shape_list(token_type_ids)[-1]])
+            token_type_embeds = self.tokens_embed(token_type_ids, mode="embedding")
         else:
             token_type_embeds = 0
-        hidden_states = inputs["inputs_embeds"] + position_embeds + token_type_embeds
-        hidden_states = self.drop(hidden_states, training=inputs["training"])
+        hidden_states = inputs_embeds + position_embeds + token_type_embeds
+        hidden_states = self.drop(hidden_states, training=training)
 
         output_shape = input_shape + [shape_list(hidden_states)[-1]]
 
-        all_attentions = () if inputs["output_attentions"] else None
-        all_hidden_states = () if inputs["output_hidden_states"] else None
+        all_attentions = () if output_attentions else None
+        all_hidden_states = () if output_hidden_states else None
         for i, block in enumerate(self.h):
-            if inputs["output_hidden_states"]:
+            if output_hidden_states:
                 all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
             outputs = block(
                 hidden_states,
-                inputs["attention_mask"],
-                inputs["head_mask"][i],
-                inputs["output_attentions"],
-                training=inputs["training"],
+                attention_mask,
+                head_mask[i],
+                output_attentions,
+                training=training,
             )
             hidden_states = outputs[0]
-            if inputs["output_attentions"]:
+            if output_attentions:
                 all_attentions = all_attentions + (outputs[1],)
 
         hidden_states = tf.reshape(hidden_states, output_shape)
         # Add last hidden state
-        if inputs["output_hidden_states"]:
+        if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
 
-        if inputs["output_attentions"]:
+        if output_attentions:
             # let the number of heads free (-1) so we can extract attention even after head pruning
             attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
             all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)
 
-        if not inputs["return_dict"]:
+        if not return_dict:
             return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
 
         return TFBaseModelOutput(
@@ -518,6 +500,7 @@ def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
         self.transformer = TFOpenAIGPTMainLayer(config, name="transformer")
 
+    @unpack_inputs
     @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -539,9 +522,8 @@ def call(
         training=False,
         **kwargs,
     ):
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+
+        outputs = self.transformer(
             input_ids=input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -552,19 +534,6 @@ def call(
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             training=training,
-            kwargs_call=kwargs,
-        )
-        outputs = self.transformer(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            token_type_ids=inputs["token_type_ids"],
-            position_ids=inputs["position_ids"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )
         return outputs
 
@@ -594,6 +563,7 @@ def get_output_embeddings(self):
     def set_output_embeddings(self, value):
         self.set_input_embeddings(value)
 
+    @unpack_inputs
     @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -621,9 +591,8 @@ def call(
             Labels for computing the cross entropy classification loss. Indices should be in `[0, ...,
             config.vocab_size - 1]`.
""" - inputs = input_processing( - func=self.call, - config=self.config, + + transformer_outputs = self.transformer( input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -633,34 +602,20 @@ def call( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - labels=labels, training=training, - kwargs_call=kwargs, - ) - transformer_outputs = self.transformer( - input_ids=inputs["input_ids"], - attention_mask=inputs["attention_mask"], - token_type_ids=inputs["token_type_ids"], - position_ids=inputs["position_ids"], - head_mask=inputs["head_mask"], - inputs_embeds=inputs["inputs_embeds"], - output_attentions=inputs["output_attentions"], - output_hidden_states=inputs["output_hidden_states"], - return_dict=inputs["return_dict"], - training=inputs["training"], ) hidden_states = transformer_outputs[0] logits = self.transformer.tokens_embed(hidden_states, mode="linear") loss = None - if inputs["labels"] is not None: + if labels is not None: # shift labels to the left and cut last logit token shifted_logits = logits[:, :-1] - labels = inputs["labels"][:, 1:] + labels = labels[:, 1:] loss = self.hf_compute_loss(labels, shifted_logits) - if not inputs["return_dict"]: + if not return_dict: output = (logits,) + transformer_outputs[1:] return ((loss,) + output) if loss is not None else output @@ -696,6 +651,7 @@ def __init__(self, config, *inputs, **kwargs): config, initializer_range=config.initializer_range, name="multiple_choice_head" ) + @unpack_inputs @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=TFOpenAIGPTDoubleHeadsModelOutput, config_class=_CONFIG_FOR_DOC) def call( @@ -746,58 +702,35 @@ def call( >>> lm_prediction_scores, mc_prediction_scores = outputs[:2] ```""" - inputs = input_processing( - func=self.call, - config=self.config, - input_ids=input_ids, - attention_mask=attention_mask, - token_type_ids=token_type_ids, - position_ids=position_ids, - head_mask=head_mask, - inputs_embeds=inputs_embeds, - mc_token_ids=mc_token_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - training=training, - kwargs_call=kwargs, - ) - - if inputs["input_ids"] is not None: - input_shapes = shape_list(inputs["input_ids"]) + if input_ids is not None: + input_shapes = shape_list(input_ids) else: - input_shapes = shape_list(inputs["inputs_embeds"])[:-1] + input_shapes = shape_list(inputs_embeds)[:-1] seq_length = input_shapes[-1] - flat_input_ids = tf.reshape(inputs["input_ids"], (-1, seq_length)) if inputs["input_ids"] is not None else None - flat_attention_mask = ( - tf.reshape(inputs["attention_mask"], (-1, seq_length)) if inputs["attention_mask"] is not None else None - ) - flat_token_type_ids = ( - tf.reshape(inputs["token_type_ids"], (-1, seq_length)) if inputs["token_type_ids"] is not None else None - ) - flat_position_ids = ( - tf.reshape(inputs["position_ids"], (-1, seq_length)) if inputs["position_ids"] is not None else None - ) + flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None + flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None + flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None + flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None transformer_outputs = self.transformer( flat_input_ids, 
             flat_attention_mask,
             flat_token_type_ids,
             flat_position_ids,
-            inputs["head_mask"],
-            inputs["inputs_embeds"],
-            inputs["output_attentions"],
-            inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
+            head_mask,
+            inputs_embeds,
+            output_attentions,
+            output_hidden_states,
+            return_dict=return_dict,
+            training=training,
         )
         hidden_states = transformer_outputs[0]
         hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:])
         lm_logits = self.transformer.tokens_embed(hidden_states, mode="linear")
-        mc_logits = self.multiple_choice_head(hidden_states, inputs["mc_token_ids"], training=inputs["training"])
+        mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids, training=training)
         mc_logits = tf.squeeze(mc_logits, axis=-1)
 
-        if not inputs["return_dict"]:
+        if not return_dict:
             return (lm_logits, mc_logits) + transformer_outputs[1:]
 
         return TFOpenAIGPTDoubleHeadsModelOutput(
@@ -857,6 +790,7 @@ def __init__(self, config, *inputs, **kwargs):
         )
         self.transformer = TFOpenAIGPTMainLayer(config, name="transformer")
 
+    @unpack_inputs
     @add_start_docstrings_to_model_forward(OPENAI_GPT_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         processor_class=_TOKENIZER_FOR_DOC,
@@ -884,9 +818,7 @@ def call(
             Labels for computing the cross entropy classification loss. Indices should be in `[0, ...,
             config.vocab_size - 1]`.
         """
-        inputs = input_processing(
-            func=self.call,
-            config=self.config,
+        transformer_outputs = self.transformer(
             input_ids=input_ids,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
@@ -896,22 +828,7 @@ def call(
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
-            labels=labels,
             training=training,
-            kwargs_call=kwargs,
-        )
-
-        transformer_outputs = self.transformer(
-            input_ids=inputs["input_ids"],
-            attention_mask=inputs["attention_mask"],
-            token_type_ids=inputs["token_type_ids"],
-            position_ids=inputs["position_ids"],
-            head_mask=inputs["head_mask"],
-            inputs_embeds=inputs["inputs_embeds"],
-            output_attentions=inputs["output_attentions"],
-            output_hidden_states=inputs["output_hidden_states"],
-            return_dict=inputs["return_dict"],
-            training=inputs["training"],
         )
 
         hidden_states = transformer_outputs[0]
@@ -920,12 +837,12 @@ def call(
         if self.config.pad_token_id is None:
             sequence_lengths = -1
         else:
-            if inputs["input_ids"] is not None:
+            if input_ids is not None:
                 sequence_lengths = (
                     tf.reduce_sum(
                         tf.cast(
-                            tf.math.not_equal(inputs["input_ids"], self.config.pad_token_id),
-                            dtype=inputs["input_ids"].dtype,
+                            tf.math.not_equal(input_ids, self.config.pad_token_id),
+                            dtype=input_ids.dtype,
                         ),
                         -1,
                         keepdims=False,
@@ -941,11 +858,11 @@ def call(
                 )
 
         loss = None
-        if inputs["labels"] is not None:
+        if labels is not None:
             if input_ids is not None:
-                batch_size, sequence_length = shape_list(inputs["input_ids"])[:2]
+                batch_size, sequence_length = shape_list(input_ids)[:2]
             else:
-                batch_size, sequence_length = shape_list(inputs["inputs_embeds"])[:2]
+                batch_size, sequence_length = shape_list(inputs_embeds)[:2]
             assert (
                 self.config.pad_token_id is not None or batch_size == 1
             ), "Cannot handle batch sizes > 1 if no padding token is defined."
@@ -953,13 +870,11 @@ def call(
             if not tf.is_tensor(sequence_lengths):
                 in_logits = logits[0:batch_size, sequence_lengths]
-            loss = self.hf_compute_loss(
-                tf.reshape(inputs["labels"], [-1, 1]), tf.reshape(in_logits, [-1, self.num_labels])
-            )
+            loss = self.hf_compute_loss(tf.reshape(labels, [-1, 1]), tf.reshape(in_logits, [-1, self.num_labels]))
 
         pooled_logits = in_logits if in_logits is not None else logits
 
-        if not inputs["return_dict"]:
+        if not return_dict:
             output = (pooled_logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output
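For context, the sketch below illustrates the decorator pattern this diff adopts: instead of each `call` body routing its arguments through `input_processing` and then reading them back out of an `inputs` dict, a decorator binds the arguments once and hands them to the body under their own names, filling config-backed defaults along the way. This is only a rough approximation, not the real `unpack_inputs` from `modeling_tf_utils.py` (which also normalizes legacy tuple/dict input formats, among other things); the names `unpack_inputs_sketch` and `ToyLayer` are made up for illustration.

```python
# Illustrative sketch only -- NOT the real `unpack_inputs` implementation from
# `transformers.modeling_tf_utils`. It shows the idea the diff relies on:
# bind the caller's arguments to the signature of `call`, fill config-backed
# defaults, and pass plain keyword arguments through, so the body can use
# `input_ids`, `return_dict`, etc. directly instead of an `inputs` dict.
import functools
import inspect
from types import SimpleNamespace


def unpack_inputs_sketch(call):
    """Toy stand-in for `@unpack_inputs` (hypothetical name, illustration only)."""

    @functools.wraps(call)
    def wrapper(self, *args, **kwargs):
        # Resolve every parameter of `call` to a named value.
        bound = inspect.signature(call).bind(self, *args, **kwargs)
        bound.apply_defaults()
        resolved = dict(bound.arguments)
        resolved.pop("self")

        # Fall back to config values when the caller left an output flag as None,
        # much as the real decorator resolves these against the model config.
        config = getattr(self, "config", None)
        for name in ("output_attentions", "output_hidden_states", "return_dict"):
            if name in resolved and resolved[name] is None and config is not None:
                resolved[name] = getattr(config, name, None)

        return call(self, **resolved)

    return wrapper


class ToyLayer:
    """Hypothetical layer demonstrating the decorated-call style used in the diff."""

    def __init__(self):
        # Stand-in for a model config carrying the output-flag defaults.
        self.config = SimpleNamespace(output_attentions=False, output_hidden_states=False, return_dict=True)

    @unpack_inputs_sketch
    def call(self, input_ids=None, attention_mask=None, return_dict=None):
        # The body reads its arguments directly, as in the rewritten `call` methods above.
        if return_dict:
            return {"input_ids": input_ids, "attention_mask": attention_mask}
        return (input_ids, attention_mask)


layer = ToyLayer()
print(layer.call(input_ids=[1, 2, 3]))  # return_dict is None -> filled from config -> dict output
```

The payoff visible throughout the diff is that argument handling lives in one decorator while each `call` body shrinks to plain Python over named parameters, which is what makes the per-model `inputs["..."]` plumbing removable.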