Commit

make fixup

kamalkraj committed Nov 22, 2021
1 parent af93378 commit 64d78f8
Showing 3 changed files with 160 additions and 24 deletions.
159 changes: 141 additions & 18 deletions src/transformers/models/tapas/modeling_tf_tapas.py
@@ -32,7 +32,7 @@
requires_backends,
)
from ...modeling_tf_outputs import (
TFBaseModelOutputWithPastAndCrossAttentions,
TFBaseModelOutputWithPooling,
TFMaskedLMOutput,
TFSequenceClassifierOutput,
@@ -277,6 +277,8 @@ def __init__(self, config: TapasConfig, **kwargs):
)
self.dropout = tf.keras.layers.Dropout(rate=config.attention_probs_dropout_prob)

self.is_decoder = config.is_decoder

def transpose_for_scores(self, tensor: tf.Tensor, batch_size: int) -> tf.Tensor:
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
tensor = tf.reshape(tensor=tensor, shape=(batch_size, -1, self.num_attention_heads, self.attention_head_size))
@@ -289,16 +291,49 @@ def call(
hidden_states: tf.Tensor,
attention_mask: tf.Tensor,
head_mask: tf.Tensor,
encoder_hidden_states: tf.Tensor,
encoder_attention_mask: tf.Tensor,
past_key_value: Tuple[tf.Tensor],
output_attentions: bool,
training: bool = False,
) -> Tuple[tf.Tensor]:
batch_size = shape_list(hidden_states)[0]
mixed_query_layer = self.query(inputs=hidden_states)

# If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None

if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_layer = past_key_value[0]
value_layer = past_key_value[1]
attention_mask = encoder_attention_mask
elif is_cross_attention:
key_layer = self.transpose_for_scores(self.key(inputs=encoder_hidden_states), batch_size)
value_layer = self.transpose_for_scores(self.value(inputs=encoder_hidden_states), batch_size)
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)
key_layer = tf.concat([past_key_value[0], key_layer], axis=2)
value_layer = tf.concat([past_key_value[1], value_layer], axis=2)
else:
key_layer = self.transpose_for_scores(self.key(inputs=hidden_states), batch_size)
value_layer = self.transpose_for_scores(self.value(inputs=hidden_states), batch_size)

query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)

if self.is_decoder:
# If this is a cross-attention layer, save Tuple(tf.Tensor, tf.Tensor) of all cross-attention
# key/value states. Further calls to the cross-attention layer can then reuse them
# (first "if" case above).
# If this is uni-directional self-attention (decoder), save Tuple(tf.Tensor, tf.Tensor) of all
# previous decoder key/value states. Further calls to uni-directional self-attention can
# concat previous decoder key/value states to the current projected key/value states
# (third "elif" case above).
# If this is encoder bi-directional self-attention, `past_key_value` is always `None`.
past_key_value = (key_layer, value_layer)

# Take the dot product between "query" and "key" to get the raw attention scores.
# (batch size, num_heads, seq_len_q, seq_len_k)
@@ -328,6 +363,8 @@ def call(
attention_output = tf.reshape(tensor=attention_output, shape=(batch_size, -1, self.all_head_size))
outputs = (attention_output, attention_probs) if output_attentions else (attention_output,)

if self.is_decoder:
outputs = outputs + (past_key_value,)
return outputs


@@ -366,20 +403,27 @@ def call(
input_tensor: tf.Tensor,
attention_mask: tf.Tensor,
head_mask: tf.Tensor,
encoder_hidden_states: tf.Tensor,
encoder_attention_mask: tf.Tensor,
past_key_value: Tuple[tf.Tensor],
output_attentions: bool,
training: bool = False,
) -> Tuple[tf.Tensor]:
self_outputs = self.self_attention(
hidden_states=input_tensor,
attention_mask=attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
training=training,
)
attention_output = self.dense_output(
hidden_states=self_outputs[0], input_tensor=input_tensor, training=training
)
# add attentions (possibly with past_key_value) if we output them
outputs = (attention_output,) + self_outputs[1:]

return outputs

@@ -430,6 +474,12 @@ def __init__(self, config: TapasConfig, **kwargs):
super().__init__(**kwargs)

self.attention = TFTapasAttention(config, name="attention")
self.is_decoder = config.is_decoder
self.add_cross_attention = config.add_cross_attention
if self.add_cross_attention:
if not self.is_decoder:
raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
self.crossattention = TFTapasAttention(config, name="crossattention")
self.intermediate = TFTapasIntermediate(config, name="intermediate")
self.bert_output = TFTapasOutput(config, name="output")

@@ -438,22 +488,69 @@ def call(
hidden_states: tf.Tensor,
attention_mask: tf.Tensor,
head_mask: tf.Tensor,
encoder_hidden_states: Optional[tf.Tensor],
encoder_attention_mask: Optional[tf.Tensor],
past_key_value: Optional[Tuple[tf.Tensor]],
output_attentions: bool,
training: bool = False,
) -> Tuple[tf.Tensor]:
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
self_attention_outputs = self.attention(
input_tensor=hidden_states,
attention_mask=attention_mask,
head_mask=head_mask,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_value=self_attn_past_key_value,
output_attentions=output_attentions,
training=training,
)
attention_output = self_attention_outputs[0]

# if decoder, the last output is tuple of self-attn cache
if self.is_decoder:
outputs = self_attention_outputs[1:-1]
present_key_value = self_attention_outputs[-1]
else:
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights

cross_attn_present_key_value = None
if self.is_decoder and encoder_hidden_states is not None:
if not hasattr(self, "crossattention"):
raise ValueError(
f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers "
"by setting `config.add_cross_attention=True`"
)

# cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
cross_attention_outputs = self.crossattention(
input_tensor=attention_output,
attention_mask=attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_value=cross_attn_past_key_value,
output_attentions=output_attentions,
training=training,
)
attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights

# add cross-attn cache to positions 3,4 of present_key_value tuple
cross_attn_present_key_value = cross_attention_outputs[-1]
present_key_value = present_key_value + cross_attn_present_key_value

intermediate_output = self.intermediate(hidden_states=attention_output)
layer_output = self.bert_output(
hidden_states=intermediate_output, input_tensor=attention_output, training=training
)
outputs = (layer_output,) + outputs # add attentions if we output them

# if decoder, return the attn key/values as the last output
if self.is_decoder:
outputs = outputs + (present_key_value,)

return outputs

@@ -462,47 +559,69 @@ def call(
class TFTapasEncoder(tf.keras.layers.Layer):
def __init__(self, config: TapasConfig, **kwargs):
super().__init__(**kwargs)

self.config = config
self.layer = [TFTapasLayer(config, name=f"layer_._{i}") for i in range(config.num_hidden_layers)]

def call(
self,
hidden_states: tf.Tensor,
attention_mask: tf.Tensor,
head_mask: tf.Tensor,
encoder_hidden_states: Optional[tf.Tensor],
encoder_attention_mask: Optional[tf.Tensor],
past_key_values: Optional[Tuple[Tuple[tf.Tensor]]],
use_cache: Optional[bool],
output_attentions: bool,
output_hidden_states: bool,
return_dict: bool,
training: bool = False,
) -> Union[TFBaseModelOutputWithPastAndCrossAttentions, Tuple[tf.Tensor]]:
all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

next_decoder_cache = () if use_cache else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

past_key_value = past_key_values[i] if past_key_values is not None else None

layer_outputs = layer_module(
hidden_states=hidden_states,
attention_mask=attention_mask,
head_mask=head_mask[i],
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
training=training,
)
hidden_states = layer_outputs[0]

if use_cache:
next_decoder_cache += (layer_outputs[-1],)

if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
if self.config.add_cross_attention and encoder_hidden_states is not None:
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

# Add last layer
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

if not return_dict:
return tuple(
v for v in [hidden_states, all_hidden_states, all_attentions, all_cross_attentions] if v is not None
)

return TFBaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
hidden_states=all_hidden_states,
attentions=all_attentions,
cross_attentions=all_cross_attentions,
)


@@ -722,6 +841,10 @@ def call(
hidden_states=embedding_output,
attention_mask=extended_attention_mask,
head_mask=inputs["head_mask"],
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=inputs["output_attentions"],
output_hidden_states=inputs["output_hidden_states"],
return_dict=inputs["return_dict"],
@@ -858,7 +981,7 @@ def __init__(self, config: TapasConfig, *inputs, **kwargs):

@add_start_docstrings_to_model_forward(TAPAS_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TFBaseModelOutputWithPooling,
config_class=_CONFIG_FOR_DOC,
@@ -961,7 +1084,7 @@ def get_lm_head(self) -> tf.keras.layers.Layer:

@add_start_docstrings_to_model_forward(TAPAS_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TFMaskedLMOutput,
config_class=_CONFIG_FOR_DOC,
@@ -1190,7 +1313,7 @@ def __init__(self, config: TapasConfig, *inputs, **kwargs):

@add_start_docstrings_to_model_forward(TAPAS_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TFTableQuestionAnsweringOutput,
config_class=_CONFIG_FOR_DOC,
@@ -1539,7 +1662,7 @@ def __init__(self, config: TapasConfig, *inputs, **kwargs):

@add_start_docstrings_to_model_forward(TAPAS_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length"))
@add_code_sample_docstrings(
processor_class=_TOKENIZER_FOR_DOC,
checkpoint=_CHECKPOINT_FOR_DOC,
output_type=TFSequenceClassifierOutput,
config_class=_CONFIG_FOR_DOC,
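Aside: the caching logic added to TFTapasSelfAttention and TFTapasLayer above mirrors the BERT-style decoder plumbing. Each layer returns a past_key_value tuple whose first two entries are the self-attention key/value states and, when cross-attention is used, whose last two entries are the cross-attention key/value states. The snippet below is a small standalone sketch (not part of the commit) of why new key/value states are concatenated along axis 2; the shapes are illustrative assumptions.

import tensorflow as tf

# Illustrative shapes only: [batch, num_heads, seq_len, head_dim]
batch, num_heads, head_dim = 2, 12, 64

# Cache from previous decoding steps (5 tokens generated so far)
past_key = tf.zeros((batch, num_heads, 5, head_dim))
past_value = tf.zeros((batch, num_heads, 5, head_dim))

# The current step projects key/value states for a single new token
new_key = tf.zeros((batch, num_heads, 1, head_dim))
new_value = tf.zeros((batch, num_heads, 1, head_dim))

# As in the `elif past_key_value is not None` branch above, the cache grows
# along the sequence dimension, which is axis 2 after transpose_for_scores.
key_layer = tf.concat([past_key, new_key], axis=2)
value_layer = tf.concat([past_value, new_value], axis=2)
print(key_layer.shape)  # (2, 12, 6, 64)

# Per-layer cache as returned by the layer (self-attention only here);
# a cross-attention layer would append its own (key, value) pair.
past_key_value = (key_layer, value_layer)
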
18 changes: 18 additions & 0 deletions src/transformers/utils/dummy_tf_objects.py
@@ -367,6 +367,9 @@ def __init__(self, *args, **kwargs):
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["tf"])

def call(self, *args, **kwargs):
requires_backends(self, ["tf"])


class TFAutoModelForTokenClassification:
def __init__(self, *args, **kwargs):
@@ -2511,6 +2514,9 @@ def __init__(self, *args, **kwargs):
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["tf"])

def call(self, *args, **kwargs):
requires_backends(self, ["tf"])


class TFTapasForQuestionAnswering:
def __init__(self, *args, **kwargs):
@@ -2520,6 +2526,9 @@ def __init__(self, *args, **kwargs):
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["tf"])

def call(self, *args, **kwargs):
requires_backends(self, ["tf"])


class TFTapasForSequenceClassification:
def __init__(self, *args, **kwargs):
@@ -2529,6 +2538,9 @@ def __init__(self, *args, **kwargs):
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["tf"])

def call(self, *args, **kwargs):
requires_backends(self, ["tf"])


class TFTapasModel:
def __init__(self, *args, **kwargs):
@@ -2538,6 +2550,9 @@ def __init__(self, *args, **kwargs):
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["tf"])

def call(self, *args, **kwargs):
requires_backends(self, ["tf"])


class TFTapasPreTrainedModel:
def __init__(self, *args, **kwargs):
@@ -2547,6 +2562,9 @@ def __init__(self, *args, **kwargs):
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["tf"])

def call(self, *args, **kwargs):
requires_backends(self, ["tf"])


TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST = None

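Aside: the call stubs added to src/transformers/utils/dummy_tf_objects.py above follow the library's dummy-object pattern: when TensorFlow is not installed, the TF TAPAS classes are replaced by placeholders whose every entry point fails fast with a clear error. The sketch below is a simplified, self-contained illustration of that pattern; the real requires_backends helper lives in transformers.utils and produces a more detailed installation message.

import importlib.util

def requires_backends(obj, backends):
    # Simplified stand-in: raise if any required backend cannot be imported.
    name = getattr(obj, "__name__", type(obj).__name__)
    for backend in backends:
        module = "tensorflow" if backend == "tf" else backend
        if importlib.util.find_spec(module) is None:
            raise ImportError(f"{name} requires the '{backend}' backend, which is not installed.")

class TFTapasModel:
    # Placeholder used when TF is missing: __init__, from_pretrained and call all raise.
    def __init__(self, *args, **kwargs):
        requires_backends(self, ["tf"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["tf"])

    def call(self, *args, **kwargs):
        requires_backends(self, ["tf"])
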
7 changes: 1 addition & 6 deletions tests/test_modeling_tf_auto.py
@@ -17,7 +17,7 @@
import tempfile
import unittest

from transformers import CONFIG_MAPPING, AutoConfig, BertConfig, GPT2Config, T5Config, TapasConfig, is_tf_available
from transformers.testing_utils import (
DUMMY_UNKNOWN_IDENTIFIER,
SMALL_MODEL_IDENTIFIER,
@@ -31,11 +31,6 @@

if is_tf_available():
from transformers import (
TFAutoModel,
TFAutoModelForCausalLM,
TFAutoModelForMaskedLM,
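Aside: a hedged usage sketch of the TF TAPAS model that these tests now exercise through the auto classes. Note that TF TAPAS additionally depends on tensorflow_probability; the small config values and the commented-out google/tapas-base checkpoint name below are assumptions for illustration, not part of the commit.

import tensorflow as tf
from transformers import TapasConfig, TFTapasModel

# Randomly initialised toy model built from a config (no download required).
config = TapasConfig(hidden_size=128, num_hidden_layers=2, num_attention_heads=2, intermediate_size=256)
model = TFTapasModel(config)

# Alternatively (assumption): load pretrained weights from the Hub.
# model = TFTapasModel.from_pretrained("google/tapas-base")  # pass from_pt=True if only PyTorch weights exist

inputs = {
    "input_ids": tf.constant([[101, 2054, 2003, 102]]),
    "attention_mask": tf.constant([[1, 1, 1, 1]]),
}
outputs = model(inputs)
print(outputs.last_hidden_state.shape)  # (1, 4, 128)
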
